Metadata
====

In [None]:
import pathlib

# Directory holding the segmentation TIFs; "~" is expanded to the user's home
parent_dir = pathlib.Path(
    "~/zebrafish_rdsf/Carran/Postgrad/Scale images from WT_spp1_sost/TIFs/segmentations"
).expanduser()
assert parent_dir.exists()

# Collect every *.tif directly inside the directory as a string path;
# sorting gives a deterministic order from one run to the next
segmentation_paths = sorted(map(str, parent_dir.glob("*.tif")))

# Last expression of the cell, so the count is shown as the cell output
f"{len(segmentation_paths)} segmentations"

In [None]:
import numpy as np
import matplotlib.pyplot as plt

from scale_morphology.scales import metadata

# Metadata table for the segmentations
# (presumably one row per path, in the same order - TODO confirm)
df = metadata.df(segmentation_paths)

# Use -1 as a sentinel wherever the magnification is NaN
missing_magnification = np.isnan(df["magnification"])
df.loc[missing_magnification, "magnification"] = -1

def plot_mdata(df):
    """
    Plot one histogram per metadata column, side by side.

    Every column of `df` except "path" gets its own panel, titled with the
    column name. The number of panels follows the number of columns, so no
    column is silently dropped and no panel is left empty.

    Parameters
    ----------
    df : pandas.DataFrame
        Metadata table; the "path" column (if present) is skipped.
    """
    labels = [c for c in df.columns if c != "path"]
    if not labels:
        # Nothing to plot; a zero-column subplot grid is not valid
        return

    # squeeze=False keeps `axes` a 2-D array even when there is a single panel
    fig, axes = plt.subplots(
        1, len(labels), figsize=(4 * len(labels), 4), squeeze=False
    )
    for axis, label in zip(axes.ravel(), labels):
        axis.hist(df[label], bins=25)
        axis.set_title(label)


plot_mdata(df)

In [None]:
import seaborn as sns

# Cross-tabulate the scales: one row per age, one column per sex, each cell
# holding the number of rows with that combination (0 where there are none)
heatmap_data = df.pivot_table(
    index="age",
    columns="sex",
    aggfunc="size",
    fill_value=0,
)

# Show the counts as an annotated grid; the numbers are printed in the cells,
# so the colour bar is switched off
fig, axis = plt.subplots(nrows=1, ncols=1, figsize=(6, 6))
sns.heatmap(data=heatmap_data, annot=True, fmt="d", cbar=False, ax=axis)
axis.set_ylabel("Age")
axis.set_xlabel("Sex")

In [None]:
import tifffile
from tqdm.notebook import tqdm
from scale_morphology.scales import efa, errors

# Elliptic Fourier coefficients are slow to compute, so they are cached on disk
coeff_dump = pathlib.Path("rough_carran_coeffs.npy")
n_edge_points, order = 300, 50

# Only trust the cache if it holds exactly one entry per current segmentation.
# Later cells mask the metadata table with a positional mask derived from
# `coeffs`, so a stale cache would break or misalign that step.
coeffs = np.load(coeff_dump) if coeff_dump.is_file() else None
if coeffs is not None and coeffs.shape[0] != len(segmentation_paths):
    print(
        f"Cached coefficients cover {coeffs.shape[0]} scales but there are "
        f"{len(segmentation_paths)} segmentations; recomputing"
    )
    coeffs = None

if coeffs is None:
    coeffs = []
    # Read one image at a time so that only a single scale is held in memory
    for path in tqdm(segmentation_paths):
        # NOTE(review): the *255 presumably turns a 0/1 mask into 0/255; a mask
        # already stored as 0/255 would wrap around in uint8 - TODO confirm
        scale = tifffile.imread(path).astype(np.uint8) * 255
        try:
            coeffs.append(efa.coefficients(scale, n_edge_points, order))
        except errors.BadImgError as e:
            # Keep a NaN placeholder so row i still corresponds to path i;
            # assumes efa.coefficients returns an (order, 4) array, which
            # np.stack below requires - TODO confirm
            coeffs.append(np.full((order, 4), np.nan))
            print(f"\nError processing scale {path}: {e}. NaN coeffs")
    coeffs = np.stack(coeffs)
    np.save(coeff_dump, coeffs)

In [None]:
"""
Sort values so that we get the right plotting order for our scatter plots
"""

# One flat feature vector per scale
flat_coeffs = coeffs.reshape((coeffs.shape[0], -1))

# Scales whose coefficients could not be computed are stored as all-NaN rows
nan_mask = np.isnan(flat_coeffs).any(axis=1)

# The mask is applied positionally to both objects, so they must describe the
# same scales in the same order
assert len(df) == flat_coeffs.shape[0], (len(df), flat_coeffs.shape[0])

valid_df = df[~nan_mask]
valid_coeffs = flat_coeffs[~nan_mask]

# Sort by sex, then age, then magnification in a single multi-key sort.
# (Chaining three single-key sorts only gives this order if each sort is
# stable, which the default algorithm does not guarantee.)
# The sort is done on a 0..n-1 index so that the resulting index *is* the row
# permutation, which is then applied to the coefficients as well. Sorting the
# table alone would leave the coefficient rows in the old order, and any later
# positional mask or label built from the table would pair each scale with
# another scale's coefficients.
sort_order = (
    valid_df.reset_index(drop=True)
    .sort_values(by=["sex", "age", "magnification"])
    .index.to_numpy()
)

df_nan_removed = valid_df.iloc[sort_order].copy()
flat_coeffs_nan_removed = valid_coeffs[sort_order]

# Ages are compared as strings from here on; converted only after sorting so
# the sort above uses the original age values
df_nan_removed["age"] = df_nan_removed["age"].astype(str)

In [None]:
"""
Create some labels for the sex/age encoding
"""

# Keep only the rows whose sex is not recorded as "?"
mf_mask = df_nan_removed["sex"] != "?"
mf_df = df_nan_removed[mf_mask].copy()
mf_coeffs = flat_coeffs_nan_removed[mf_mask]

# Integer label per row = sex offset + age offset.
# "M" adds 8; ages "12" / "18" / "40" add 1 / 2 / 4; anything else adds 0,
# so (per the table below) 0 is read as F7 and 8 as M7.
groups = np.zeros(mf_coeffs.shape[0], dtype=int)
groups[mf_df["sex"] == "M"] += 8
for age, offset in (("12", 1), ("18", 2), ("40", 4)):
    groups[mf_df["age"] == age] += offset

# (label value, display name, plot colour) for each sex/age group
group_styles = [
    (0, "F7", "lightcoral"),
    (1, "F12", "indianred"),
    (2, "F18", "brown"),
    (4, "F40", "red"),
    (8, "M7", "cornflowerblue"),
    (9, "M12", "royalblue"),
    (10, "M18", "darkblue"),
    (12, "M40", "blue"),
]
encoding = {code: name for code, name, _ in group_styles}
colours = {code: colour for code, _, colour in group_styles}

# Sanity check: every label value must correspond to exactly one sex and
# exactly one age among the rows that carry it
mf_df["age_sex_group"] = groups
for code, members in mf_df.groupby("age_sex_group"):
    sexes = members["sex"].unique()
    ages = members["age"].unique()
    assert len(sexes) == 1, (code, sexes, ages)
    assert len(ages) == 1, (code, sexes, ages)

In [None]:
"""
Perform LDA
"""

from sklearn.discriminant_analysis import LinearDiscriminantAnalysis

# Fit the discriminant model on the flattened coefficients with the sex/age
# labels, then project the same samples onto the discriminant axes
lda = LinearDiscriminantAnalysis().fit(mf_coeffs, groups)
lda_coeffs = lda.transform(mf_coeffs)

In [None]:
from matplotlib import colors
from matplotlib.patches import Patch
from scipy.stats import gaussian_kde

# One row of five panels, one per pair of consecutive discriminant axes
fig, axes = plt.subplots(nrows=1, ncols=5, figsize=(28, 6))


def clear2colour_cmap(colour) -> colors.Colormap:
    """
    Two-entry colormap: fully transparent white, then `colour` at 50% opacity

    The colormap is named "clear2<colour>".
    """
    transparent = colors.to_rgba("white", alpha=0)
    tinted = colors.to_rgba(colour, alpha=0.5)
    return colors.ListedColormap([transparent, tinted], f"clear2{colour}")


# A 2-D KDE needs at least three points: with only two, the points are
# necessarily collinear, the covariance matrix is singular and gaussian_kde
# raises LinAlgError.
min_kde_points = 3

# LDA may return fewer components than the panels need (each panel uses
# components i and i + 1), e.g. when some groups are absent from the data
n_components = lda_coeffs.shape[1]

# Legend entries depend only on the group sizes, not on the panel, so they are
# built once. Only groups that are actually drawn get an entry.
handles = []
for g, label in encoding.items():
    n = int(np.sum(groups == g))
    if n >= min_kde_points:
        handles.append(Patch(color=colours[g], label=f"{label}, {n=}"))

for i, axis in enumerate(axes):
    if i + 1 >= n_components:
        # Not enough discriminant axes for this panel: hide it instead of
        # failing with an IndexError
        axis.set_visible(False)
        continue

    x = lda_coeffs[:, i]
    y = lda_coeffs[:, i + 1]

    # Evaluation grid spanning the data range of this pair of axes
    xmin, xmax = x.min(), x.max()
    ymin, ymax = y.min(), y.max()
    xx, yy = np.mgrid[xmin:xmax:200j, ymin:ymax:200j]
    positions = np.vstack([xx.ravel(), yy.ravel()])

    for g in encoding:
        mask = groups == g
        if mask.sum() < min_kde_points:
            continue

        # Density estimate for this group, evaluated on the grid
        data = np.vstack([x[mask], y[mask]])
        kde = gaussian_kde(data)
        f = np.reshape(kde(positions).T, xx.shape)

        axis.contourf(xx, yy, f, levels=15, cmap=clear2colour_cmap(colours[g]))
        axis.scatter(*data, c=colours[g], s=5, marker="s", linewidth=0.5, edgecolor="k")

    axis.set_title(f"LD{i+1} vs LD{i+2}")
    axis.set_xlabel(f"LD{i+1}")
    axis.set_ylabel(f"LD{i+2}")

if handles:
    axes[0].legend(handles=handles, loc="upper left")