Metadata
====

In [None]:
import pathlib

# Folder holding the segmentation TIFs; "~" is resolved by expanduser().
segmentation_root = (
    "~/zebrafish_rdsf/Carran/Postgrad/Scale images from WT_spp1_sost/TIFs/segmentations"
)
parent_dir = pathlib.Path(segmentation_root).expanduser()
assert parent_dir.exists()

# Sorted so that the order of the paths does not depend on directory listing order.
segmentation_paths = sorted(map(str, parent_dir.glob("*.tif")))
f"{len(segmentation_paths)} segmentations"

In [None]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# modules are picked up live, without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import numpy as np
import matplotlib.pyplot as plt

from scale_morphology.scales import metadata

# Metadata table built from the segmentation paths
# (presumably one row per path - later cells zip it with the loaded images).
df = metadata.df(segmentation_paths)

# Rows with no magnification (NaN) get the sentinel value -1
missing_magnification = np.isnan(df["magnification"])
df.loc[missing_magnification, "magnification"] = -1

def plot_mdata(df):
    """
    Plot a histogram of every column of `df` except "path".

    One panel is drawn per column, side by side in a single row of a new
    figure. Nothing is returned; use `plt.gcf()` to reach the figure.
    """
    labels = [c for c in df.columns if c != "path"]

    # Size the grid from the columns rather than hard-coding four panels, so no
    # column is silently dropped by zip(). squeeze=False keeps `axes` 2-D even
    # when there is a single panel.
    n_panels = max(len(labels), 1)
    fig, axes = plt.subplots(1, n_panels, figsize=(4 * n_panels, 4), squeeze=False)

    for axis, label in zip(axes[0], labels):
        axis.hist(df[label], bins=25)
        axis.set_title(label)


plot_mdata(df)

In [None]:
# Same histograms, leaving out the rows whose magnification is the -1 sentinel.
has_magnification = df["magnification"] != -1
plot_mdata(df[has_magnification])
plt.gcf().suptitle("High-mag ones only")

In [None]:
import seaborn as sns

# Number of rows for every (age, sex) combination; combinations that never
# occur are filled with 0.
heatmap_data = df.pivot_table(index="age", columns="sex", aggfunc="size", fill_value=0)

# Annotated heatmap of those counts
heatmap_fig, heatmap_ax = plt.subplots(figsize=(10, 6))
sns.heatmap(heatmap_data, annot=True, fmt="d", cbar=True, ax=heatmap_ax)
heatmap_ax.set_title("Counts of Age and Sex")
heatmap_ax.set_ylabel("Age")
heatmap_ax.set_xlabel("Sex")
plt.show()

In [None]:
import tifffile
from tqdm.notebook import tqdm
from scale_morphology.scales import efa, errors

# Load every segmentation as a uint8 image with its stored values scaled by 255.
# NOTE(review): presumably the TIFs hold 0/1 masks (giving 0/255); a mask that
# is already 0/255 would wrap around in uint8 - confirm the on-disk values.
scales = [
    tifffile.imread(path).astype(np.uint8) * 255 for path in tqdm(segmentation_paths)
]

coeff_dump = pathlib.Path("rough_carran_coeffs.npy")
n_edge_points, order = 300, 50

# Reuse the cached coefficients only if they still have one row per loaded
# image. Later cells pair `coeffs` positionally with `df` and `scales`, so a
# cache written for a different set of TIFs must not be used.
coeffs = np.load(coeff_dump) if coeff_dump.is_file() else None
if coeffs is not None and coeffs.shape[0] != len(scales):
    print(
        f"Ignoring stale cache {coeff_dump}: {coeffs.shape[0]} cached rows "
        f"for {len(scales)} segmentations. Recomputing"
    )
    coeffs = None

if coeffs is None:
    coeffs = []
    for scale in tqdm(scales):
        try:
            coeffs.append(efa.coefficients(scale, n_edge_points, order))
        except errors.BadImgError as e:
            # Keep a NaN placeholder so row i still corresponds to image i;
            # these rows are filtered out before any model is fitted.
            coeffs.append(np.ones((order, 4)) * np.nan)
            print(f"\nError processing scale: {e}. NaN coeffs")
    coeffs = np.stack(coeffs)
    np.save(coeff_dump, coeffs)

In [None]:
from sklearn.decomposition import PCA

# One feature vector per scale: flatten everything after the first axis.
flat_coeffs = coeffs.reshape((coeffs.shape[0], -1))

# Rows containing any NaN (the placeholders written for images that could not
# be processed) are excluded; `nan_mask` is reused later to filter the
# metadata and the images in the same way.
nan_mask = np.isnan(flat_coeffs).any(axis=1)
flat_coeffs_nan_removed = flat_coeffs[~nan_mask]

# Fixed seed: with svd_solver="auto" scikit-learn may pick the randomized
# solver, and the dashboards written from this projection are cached on disk,
# so the projection has to be identical from run to run.
pca = PCA(n_components=2, random_state=0)
pca_coeffs = np.ascontiguousarray(pca.fit_transform(flat_coeffs_nan_removed))

In [None]:
# Metadata for the scales that survived NaN filtering, plus their PCA coordinates.
#
# The row order is deliberately left untouched. Later cells pair this frame
# *positionally* with `pca_coeffs`, `flat_coeffs_nan_removed` and the filtered
# image list (as LDA class labels, boolean masks and dashboard labels).
# Sorting the rows here would attach every label to some other scale's
# coefficients. If a particular drawing or legend order is wanted, sort a copy
# at plot time instead.
df_nan_removed = df[~nan_mask].copy()
assert len(df_nan_removed) == len(pca_coeffs), (len(df_nan_removed), len(pca_coeffs))

df_nan_removed["pca1"] = pca_coeffs[:, 0]
df_nan_removed["pca2"] = pca_coeffs[:, 1]

# Ages are compared against strings ("7", "12", ...) further down.
df_nan_removed["age"] = df_nan_removed["age"].astype(str)

In [None]:
# One PCA scatter panel per covariate, each with its own palette.
covariates = ["age", "sex", "magnification"]
cmaps = ["magma", "Set1", "Set1"]
fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(12, 6))

# Settings common to all three panels
scatter_kw = dict(data=df_nan_removed, x="pca1", y="pca2", alpha=1, s=10)

for axis, hue, cmap in zip(axes, covariates, cmaps, strict=True):
    sns.scatterplot(hue=hue, palette=cmap, ax=axis, **scatter_kw)

In [None]:
from scale_morphology.scales import dashboard

# Images and file stems of the scales that were not dropped by `nan_mask`,
# in the same order as the rows of `pca_coeffs`.
kept = [
    (scale, pathlib.Path(name).stem)
    for scale, name, dropped in zip(scales, df["path"], nan_mask)
    if not dropped
]
scales_nan_removed = [scale for scale, _ in kept]
names_nan_removed = [stem for _, stem in kept]


dashboard_dir = pathlib.Path("rough_carran_dashboards/")
dashboard_dir.mkdir(exist_ok=True)

age_db = dashboard_dir / "pca_age.html"
sex_db = dashboard_dir / "pca_sex.html"

# Each dashboard is only written when its HTML file does not exist yet.
pca_dashboards = (
    ("age", age_db, "PCA - age"),
    ("sex", sex_db, "PCA - sex"),
)
for column, out_path, title in pca_dashboards:
    if out_path.is_file():
        continue
    dashboard.write_dashboard(
        pca_coeffs,
        scales_nan_removed,
        df_nan_removed[column],
        names_nan_removed,
        out_path,
        title,
    )

In [None]:
"""
Repeat with LDA
"""

from sklearn.discriminant_analysis import LinearDiscriminantAnalysis

# Fit LDA on the flattened coefficients with the "sex" column as the class
# label, then project the same data onto the discriminant axes.
lda = LinearDiscriminantAnalysis()
lda.fit(flat_coeffs_nan_removed, df_nan_removed["sex"])
lda_coeffs = lda.transform(flat_coeffs_nan_removed)


In [None]:
fig, axis = plt.subplots()

plot_kw = {"s": 6}

# One scatter per sex label, drawn in this order: "?", "M", "F".
sex_colours = (("?", "grey"), ("M", "blue"), ("F", "red"))
for sex_label, colour in sex_colours:
    selected = lda_coeffs[df_nan_removed["sex"] == sex_label]
    axis.scatter(*selected.T, **plot_kw, color=colour, label=sex_label)

axis.legend()

In [None]:
"""
KDE
"""

from scipy.stats import gaussian_kde
from matplotlib import colors
from matplotlib.patches import Patch

# Bounding box of the LDA projection (which must have exactly two columns)
(xmin, ymin), (xmax, ymax) = lda_coeffs.min(axis=0), lda_coeffs.max(axis=0)

# Transposed view of the projection: one row per discriminant axis
values = lda_coeffs.T

# 100 x 100 evaluation grid over the bounding box, and the same grid points
# stacked as a (2, n_points) array
xx, yy = np.mgrid[xmin:xmax:100j, ymin:ymax:100j]
positions = np.vstack((xx.ravel(), yy.ravel()))


def clear2colour_cmap(colour) -> colors.Colormap:
    """
    Two-entry colormap going from fully transparent to `colour`

    The first entry is white with alpha 0, the second is `colour` with
    alpha 0.5. The colormap is named "clear2<colour>".
    """
    to_rgba = colors.colorConverter.to_rgba
    entries = [to_rgba("white", alpha=0), to_rgba(colour, alpha=0.5)]
    return colors.ListedColormap(entries, f"clear2{colour}")


def plot_kde(values, axis, cmap):
    """
    Draw a filled-contour Gaussian KDE of `values` on `axis`

    :param values: array with one row per dimension and one column per point;
        it must have two rows to match the 2-D evaluation grid.
    :param axis: matplotlib axes to draw on.
    :param cmap: colormap passed to `contourf`.
    :returns: the contour set created by `axis.contourf`.

    The density is evaluated on the notebook-level grid `xx`, `yy` /
    `positions`, which must exist (and cover the region of interest) when this
    is called.
    """
    kernel = gaussian_kde(values)
    density = np.reshape(kernel(positions).T, xx.shape)
    return axis.contourf(xx, yy, density, cmap=cmap)


fig, axis = plt.subplots()

axis.set_xlim(xmin, xmax)
axis.set_ylim(ymin, ymax)

# One KDE per sex label, drawn in this order: F, M, "?".
for sex_label, colour in (("F", "r"), ("M", "b"), ("?", "grey")):
    plot_kde(
        values.T[df_nan_removed["sex"] == sex_label].T,
        axis,
        clear2colour_cmap(colour),
    )

# Filled contours carry no legend entries of their own, so build them by hand.
axis.legend(
    handles=[
        Patch(color="b", label="M"),
        Patch(color="r", label="F"),
        Patch(color="grey", label="?"),
    ]
)

In [None]:
# Dashboard of the sex LDA projection; skipped when the HTML file already exists.
lda_db_sex = dashboard_dir / "lda_sex.html"
lda_dashboard_args = (
    lda_coeffs,
    scales_nan_removed,
    df_nan_removed["sex"],
    names_nan_removed,
    lda_db_sex,
    "LDA - sex",
)
if not lda_db_sex.is_file():
    dashboard.write_dashboard(*lda_dashboard_args)

In [None]:
"""
Ignore the unknown sex ones
"""

# Contour colour for every (age, sex) pair: red shades for F, blue shades for M.
colours = {
    ("7", "F"): "lightcoral",
    ("12", "F"): "indianred",
    ("18", "F"): "brown",
    ("40", "F"): "red",
    ("7", "M"): "cornflowerblue",
    ("12", "M"): "royalblue",
    ("18", "M"): "darkblue",
    ("40", "M"): "blue",
}

fig, axis = plt.subplots()

# Attach the two discriminant coordinates to the metadata, then drop the
# scales whose sex is unknown ("?").
sexed_df = df_nan_removed.copy()
sexed_df[["lda1", "lda2"]] = values.T
sexed_df = sexed_df[sexed_df["sex"] != "?"]

handles = []
for (age, sex), group in sexed_df.groupby(["age", "sex"]):
    n = len(group)
    if n < 3:
        # A 2-D KDE needs at least 3 points: with fewer the sample covariance
        # is singular and gaussian_kde raises.
        continue

    colour = colours[(age, sex)]
    plot_kde(
        group[["lda1", "lda2"]].to_numpy().T,
        axis,
        clear2colour_cmap(colour),
    )
    handles.append(Patch(color=colour, label=f"{sex}, {age}, {n=}"))

axis.legend(handles=handles, loc="upper left")

In [None]:
"""
Do LDA again but split by age
"""

# Keep only the scales with a known sex
mf_mask = df_nan_removed["sex"] != "?"
mf_df = df_nan_removed[mf_mask].copy()
mf_coeffs = flat_coeffs_nan_removed[mf_mask]

# Want these groups - M/F age 7/12/18/40
# Each group is an integer code built by adding up offsets: 8 for male, plus
# 1 / 2 / 4 for age 12 / 18 / 40. Anything matching no rule contributes 0.
groups = np.zeros(mf_coeffs.shape[0], dtype=int)
code_offsets = (
    ("sex", "M", 8),
    ("age", "12", 1),
    ("age", "18", 2),
    ("age", "40", 4),
)
for column, value, offset in code_offsets:
    groups[mf_df[column] == value] += offset

# Human-readable name of every group code
encoding = {
    0: "F7",
    1: "F12",
    2: "F18",
    4: "F40",
    8: "M7",
    9: "M12",
    10: "M18",
    12: "M40",
}

# Check we've correctly encoded it: every code must hold one sex and one age
mf_df["age_sex_group"] = groups
for val, group in mf_df.groupby("age_sex_group"):
    sexes, ages = group["sex"].unique(), group["age"].unique()
    found = (val, sexes, ages)
    assert len(sexes) == 1, found
    assert len(ages) == 1, found

In [None]:
# LDA again, this time with the combined age/sex group code as the class label.
# Note that this rebinds `lda` and `lda_coeffs` from the sex-only analysis.
lda = LinearDiscriminantAnalysis().fit(mf_coeffs, groups)
lda_coeffs = lda.transform(mf_coeffs)
lda_coeffs.shape

In [None]:
# Scatter colour for every age/sex group code: red shades for the F groups,
# blue shades for the M groups.
colours = {
    0: "lightcoral",
    1: "indianred",
    2: "brown",
    4: "red",
    8: "cornflowerblue",
    9: "royalblue",
    10: "darkblue",
    12: "blue",
}

# Panel i plots discriminant i against discriminant i + 1, so n discriminants
# allow at most n - 1 panels. LDA yields fewer discriminants when some age/sex
# group is absent, so cap the usual five panels at what is available instead of
# indexing past the last column. squeeze=False keeps `axes` indexable even for
# a single panel.
n_panels = min(5, lda_coeffs.shape[1] - 1)
fig, axes = plt.subplots(1, n_panels, figsize=(4 * n_panels, 4), squeeze=False)
axes = axes[0]

for i, axis in enumerate(axes):
    for j, label in encoding.items():
        members = lda_coeffs[groups == j]
        axis.scatter(
            members[:, i],
            members[:, i + 1],
            c=colours[j],
            label=label,
            s=6,
        )

axes[0].legend()

In [None]:
# Panel i shows discriminant i against discriminant i + 1; cap the usual five
# panels at the number of consecutive pairs actually available.
# squeeze=False keeps `axes` indexable even for a single panel.
n_panels = min(5, lda_coeffs.shape[1] - 1)
fig, axes = plt.subplots(1, n_panels, figsize=(4 * n_panels, 4), squeeze=False)
axes = axes[0]

for i, axis in enumerate(axes):
    x = lda_coeffs[:, i]
    y = lda_coeffs[:, i + 1]

    # Grid for this pair
    xmin, xmax = x.min(), x.max()
    ymin, ymax = y.min(), y.max()
    xx, yy = np.mgrid[xmin:xmax:200j, ymin:ymax:200j]
    positions = np.vstack([xx.ravel(), yy.ravel()])

    # Rebuilt for every panel; the group sizes do not depend on the panel, so
    # the handles left after the last panel serve for the shared legend.
    handles = []
    for g, label in encoding.items():
        mask = groups == g
        n = int(np.sum(mask))
        if n < 3:
            # A 2-D KDE needs at least 3 points: with fewer the sample
            # covariance is singular and gaussian_kde raises.
            continue

        data = np.vstack([x[mask], y[mask]])
        kde = gaussian_kde(data)
        f = np.reshape(kde(positions).T, xx.shape)

        axis.contourf(xx, yy, f, levels=15, cmap=clear2colour_cmap(colours[g]))
        handles.append(Patch(color=colours[g], label=f"{label}, {n=}"))

    axis.set_title(f"LD{i+1} vs LD{i+2}")

if handles:
    axes[0].legend(handles=handles, loc="upper left")