Metadata
====
First, we'll find how many scales were manually edited (and how):

In [None]:
import pathlib

# Directory that holds the raw (un-edited) segmentation masks
parent_dir = pathlib.Path(
    "~/zebrafish_rdsf/Carran/Postgrad/Scale images from WT_spp1_sost/TIFs/segmentations"
)
parent_dir = parent_dir.expanduser()
assert parent_dir.exists()

# Every TIFF in that directory, as strings in sorted order
segmentation_paths = [str(tif_path) for tif_path in parent_dir.glob("*.tif")]
segmentation_paths.sort()

# Cell output: how many segmentations were found
f"{len(segmentation_paths)} segmentations"

In [None]:
# The cleaned masks live in "segmentations_cleaned", three directory levels
# above the raw segmentation directory, under the same file names
clean_seg_dir = parent_dir.parents[2] / "segmentations_cleaned"
clean_seg_paths = [clean_seg_dir / pathlib.Path(p).name for p in segmentation_paths]

# Every raw segmentation must have a cleaned counterpart; report which ones
# are missing instead of failing on the first with no context
missing_clean = [str(p) for p in clean_seg_paths if not p.exists()]
assert not missing_clean, (
    f"{len(missing_clean)} cleaned segmentation(s) missing, e.g. {missing_clean[:5]}"
)

In [None]:
"""
Read in pairs of the clean/raw segmentations. Scale the raw ones by 255 so that they can be compared with the clean ones; if they don't match, keep track of them...
"""

import tifffile
import numpy as np
from tqdm.notebook import tqdm

# Original scale image for each mask: same file name without the
# "_segmentation" suffix, in the directory above the raw segmentations
scale_dir = parent_dir.parent
scale_paths = [
    scale_dir / cleaned.name.replace("_segmentation.tif", ".tif")
    for cleaned in clean_seg_paths
]

# Parallel lists, one entry per scale (scale_imgs holds None for unedited scales)
scale_imgs, raw_segs, clean_segs = [], [], []
# File names of the cleaned masks that differ from their raw mask
edited_names = []

path_triples = zip(tqdm(scale_paths), segmentation_paths, clean_seg_paths)
for scale_path, raw_seg_path, clean_seg_path in path_triples:
    clean_seg = tifffile.imread(clean_seg_path)
    # Raw masks are multiplied by 255 before being compared with the clean ones
    raw_seg = tifffile.imread(raw_seg_path) * 255

    clean_segs.append(clean_seg)
    raw_segs.append(raw_seg)

    was_edited = not (clean_seg == raw_seg).all()
    if was_edited:
        edited_names.append(clean_seg_path.name)
        # Only edited scales need their image for the comparison plot
        scale_imgs.append(tifffile.imread(scale_path))
    else:
        scale_imgs.append(None)

In [None]:
"""
Show the edited scales
"""

import textwrap
import matplotlib.pyplot as plt
from matplotlib import colors


def clear_seismic() -> colors.Colormap:
    """
    Three-entry colormap: blue, fully transparent white, red.

    Intended for signed difference images, so that zero shows whatever is
    drawn underneath while negative/positive values are blue/red.
    """
    palette = [
        colors.to_rgba("blue"),
        colors.to_rgba("white", alpha=0),
        colors.to_rgba("red"),
    ]
    return colors.ListedColormap(palette, "clear2seismic")


# One panel per edited scale: image, cleaned mask overlay and a blue/red map
# of where the raw and cleaned masks differ
fig, axes = plt.subplots(10, 10, figsize=(24, 24))

# Blank every panel up front so unused ones do not show empty axes
for axis in axes.flat:
    axis.set_axis_off()

names = list(edited_names)
if len(names) > axes.size:
    # zip() below stops at the last panel; say so rather than dropping silently
    print(f"Showing the first {axes.size} of {len(names)} edited scales")

# Position of each scale image in the parallel lists (first occurrence wins,
# matching a front-to-back search)
index_by_scale_name = {}
for position, p in enumerate(scale_paths):
    index_by_scale_name.setdefault(p.name, position)

diff_cmap = clear_seismic()

for axis, name in zip(axes.flat, names):
    i = index_by_scale_name[name.replace("_segmentation.tif", ".tif")]

    axis.imshow(scale_imgs[i])
    axis.imshow(clean_segs[i], alpha=0.5)

    # Subtract as floats: if the masks are stored as unsigned integers the
    # difference would wrap around and negative values could never be shown
    difference = np.asarray(raw_segs[i], dtype=float) - np.asarray(
        clean_segs[i], dtype=float
    )
    axis.imshow(
        difference,
        cmap=diff_cmap,
        vmin=-1,
        vmax=1,
        interpolation="none",
    )

    # Drop the extension as a suffix; str.strip(".tif") would instead remove
    # any leading/trailing '.', 't', 'i' or 'f' characters from the name
    stem = name[: -len(".tif")] if name.endswith(".tif") else name
    axis.set_title(
        "\n".join(textwrap.wrap(stem, width=20)).replace("_", " "),
        fontsize=8,
    )

fig.tight_layout()

In [None]:
# Number of scales processed: scale_imgs has one entry per scale
# (the image for edited scales, None for unedited ones)
len(scale_imgs)

In [None]:
# We don't need the paths of the non-clean masks any more
# NOTE(review): this only removes the list of paths; the raw mask arrays
# held in `raw_segs` are not deleted here
del segmentation_paths

Next we'll get some stats on the growth stages, sex, age, mutation etc.

In [None]:
import numpy as np
import matplotlib.pyplot as plt

from scale_morphology.scales import metadata

# Build the metadata table from the cleaned segmentation paths (as strings)
clean_seg_path_strings = list(map(str, clean_seg_paths))
df = metadata.df(clean_seg_path_strings)

# Replace NaN magnifications with -1
magnification_missing = np.isnan(df["magnification"])
df.loc[magnification_missing, "magnification"] = -1

def plot_mdata(df):
    """
    Plot a histogram of every column of `df` except "path".

    One panel is created per column, so no column is silently dropped when
    the table has more (or fewer) than five of them. With five columns the
    figure size is the same (16, 4) as a fixed five-panel layout.
    """
    labels = [c for c in df.columns if c != "path"]
    if not labels:
        # Nothing to plot
        return

    fig, axes = plt.subplots(
        1, len(labels), figsize=(3.2 * len(labels), 4), squeeze=False
    )
    for axis, label in zip(axes.flat, labels):
        axis.hist(df[label], bins=25)
        axis.set_title(label)


plot_mdata(df)

In [None]:
import seaborn as sns

# Count how many scales fall into each combination of two metadata
# categories, for the pairs (age, sex), (sex, growth) and (growth, age),
# and show the counts as annotated heatmaps
category_pairs = [("age", "sex"), ("sex", "growth"), ("growth", "age")]

fig, axes = plt.subplots(1, 3, figsize=(18, 6))
for axis, (row_label, column_label) in zip(axes, category_pairs):
    counts = df.pivot_table(
        index=row_label, columns=column_label, aggfunc="size", fill_value=0
    )
    sns.heatmap(counts, annot=True, fmt="d", ax=axis, cbar=False)
    axis.set_ylabel(row_label)
    axis.set_xlabel(column_label)
In [None]:
import tifffile
from tqdm.notebook import tqdm
from scale_morphology.scales import efa, errors

# Elliptic Fourier coefficients are expensive to compute, so they are cached
# on disk and re-used when the cache file exists
coeff_dump = pathlib.Path("carran_coeffs.npy")

if coeff_dump.is_file():
    coeffs = np.load(coeff_dump)
else:
    n_edge_points, order = 300, 50
    coeffs = []
    # Read one mask at a time so only a single image is held in memory
    for path in tqdm(clean_seg_paths):
        # Binarise before scaling so foreground is exactly 255: multiplying
        # a uint8 mask that is already 0/255 by 255 wraps around to 1
        scale = (tifffile.imread(path) > 0).astype(np.uint8) * 255
        try:
            coeffs.append(efa.coefficients(scale, n_edge_points, order))
        except errors.BadImgError as e:
            # Keep a NaN placeholder so rows stay aligned with clean_seg_paths
            coeffs.append(np.full((order, 4), np.nan))
            print(f"\nError processing scale: {e}. NaN coeffs")
    coeffs = np.stack(coeffs)
    np.save(coeff_dump, coeffs)

# A cache written for a different set of files would misalign with the metadata
assert len(coeffs) == len(clean_seg_paths), (
    f"{coeff_dump} holds {len(coeffs)} rows but there are "
    f"{len(clean_seg_paths)} segmentations; delete the cache to recompute"
)

In [None]:
"""
Drop scales with NaN coefficients, then sort the metadata AND the coefficient
rows together so that we get the right plotting order for our scatter plots
"""

flat_coeffs = coeffs.reshape((coeffs.shape[0], -1))
nan_mask = np.isnan(flat_coeffs).any(axis=1)

df_nan_removed = df[~nan_mask].copy()
flat_coeffs_nan_removed = flat_coeffs[~nan_mask]

# Remember each row's position so the coefficient rows can follow the sort
df_nan_removed["_row"] = np.arange(len(df_nan_removed))

# Sort by sex, then age, then magnification in one pass; age is still in its
# original dtype here, it is only converted to str afterwards
df_nan_removed = df_nan_removed.sort_values(by=["sex", "age", "magnification"])

# Apply the same permutation to the coefficients: sorting only the metadata
# would pair every scale's labels with another scale's coefficients
row_order = df_nan_removed.pop("_row").to_numpy()
flat_coeffs_nan_removed = flat_coeffs_nan_removed[row_order]
assert len(df_nan_removed) == len(flat_coeffs_nan_removed)

df_nan_removed["age"] = df_nan_removed["age"].astype(str)

In [None]:
"""
Create some labels for the sex/age encoding
"""

# Keep only the scales whose sex is not "?"
mf_mask = df_nan_removed["sex"] != "?"
mf_df = df_nan_removed[mf_mask].copy()
mf_coeffs = flat_coeffs_nan_removed[mf_mask]

# Want these groups - M/F age 7/12/18/40
# Each group code is a sex offset plus an age offset; sex "F" and age "7"
# contribute 0, so they are the baseline
sex_offsets = {"F": 0, "M": 8}
age_offsets = {"7": 0, "12": 1, "18": 2, "40": 4}

groups = np.zeros(mf_coeffs.shape[0], dtype=int)
for column, offsets in (("sex", sex_offsets), ("age", age_offsets)):
    for value, offset in offsets.items():
        groups[(mf_df[column] == value).to_numpy()] += offset

# Group code -> readable label, e.g. 0 -> "F7", 12 -> "M40"
encoding = {
    sex_offset + age_offset: f"{sex}{age}"
    for sex, sex_offset in sex_offsets.items()
    for age, age_offset in age_offsets.items()
}
# Group code -> plotting colour, in the same order as `encoding`
colours = dict(
    zip(
        encoding,
        [
            "lightcoral",
            "indianred",
            "brown",
            "red",
            "cornflowerblue",
            "royalblue",
            "darkblue",
            "blue",
        ],
    )
)

# Check we've correctly encoded it: each code must map to one sex and one age
mf_df["age_sex_group"] = groups
for val, group in mf_df.groupby("age_sex_group"):
    sexes = group["sex"].unique()
    ages = group["age"].unique()
    assert len(sexes) == 1, (val, sexes, ages)
    assert len(ages) == 1, (val, sexes, ages)

In [None]:
"""
Perform LDA
"""

from sklearn.discriminant_analysis import LinearDiscriminantAnalysis

# Fit LDA on the flattened coefficients with the age/sex group codes as
# labels, then project the same samples onto the discriminant axes
lda = LinearDiscriminantAnalysis()
lda.fit(mf_coeffs, groups)
lda_coeffs = lda.transform(mf_coeffs)

In [None]:
from matplotlib import colors
from matplotlib.patches import Patch
from scipy.stats import gaussian_kde

def clear2colour_cmap(colour) -> colors.Colormap:
    """
    Two-entry colormap: fully transparent white, then `colour` at 50% opacity.
    """
    transparent = colors.to_rgba("white", alpha=0)
    tinted = colors.to_rgba(colour, alpha=0.5)
    return colors.ListedColormap([transparent, tinted], f"clear2{colour}")


# Each panel shows one pair of consecutive discriminants (LD1 vs LD2, ...).
# Show at most 5 panels, and fewer if the LDA produced fewer components,
# instead of indexing past the last column
n_panels = min(5, lda_coeffs.shape[1] - 1)
if n_panels < 1:
    raise ValueError("Need at least two linear discriminants to plot pairs")

# 5.6 inches per panel keeps the (28, 6) figure size when there are 5 panels
fig, axes = plt.subplots(
    1, n_panels, figsize=(5.6 * n_panels, 6), squeeze=False
)
axes = axes.ravel()

# Legend entries: one per non-empty group, with its sample count
handles = []
for g, label in encoding.items():
    n = int(np.sum(groups == g))
    if n > 0:
        handles.append(Patch(color=colours[g], label=f"{label}, {n=}"))

for i, axis in enumerate(axes):
    x = lda_coeffs[:, i]
    y = lda_coeffs[:, i + 1]

    # Grid for this pair
    xmin, xmax = x.min(), x.max()
    ymin, ymax = y.min(), y.max()
    xx, yy = np.mgrid[xmin:xmax:200j, ymin:ymax:200j]
    positions = np.vstack([xx.ravel(), yy.ravel()])

    for g in encoding:
        mask = groups == g
        n = int(mask.sum())
        if n == 0:
            continue

        data = np.vstack([x[mask], y[mask]])

        # A 2-D KDE needs more points than dimensions: with one or two points
        # the covariance is singular, so only the points themselves are drawn
        if n > 2:
            kde = gaussian_kde(data)
            f = np.reshape(kde(positions).T, xx.shape)
            axis.contourf(xx, yy, f, levels=15, cmap=clear2colour_cmap(colours[g]))

        axis.scatter(*data, c=colours[g], s=5, marker="s", linewidth=0.5, edgecolor="k")

    axis.set_title(f"LD{i+1} vs LD{i+2}")
    axis.set_xlabel(f"LD{i+1}")
    axis.set_ylabel(f"LD{i+2}")

if handles:
    axes[0].legend(handles=handles, loc="upper left")

In [None]:
import pandas as pd
from sklearn.model_selection import StratifiedKFold
from sklearn.preprocessing import label_binarize
from sklearn.metrics import (
    roc_auc_score,
    average_precision_score,
    recall_score,
    precision_score,
)


def per_class_separability(X, y, cv=5, random_state=42, in_sample: bool = False):
    """
    Per-class separability for LDA, estimated with stratified k-fold splits.

    Parameters
    ----------
    X : array-like, one row of features per sample
    y : array-like, class label per sample
    cv : number of stratified folds
    random_state : seed for the fold shuffling
    in_sample : if True, each fold's model is scored on its own training rows
        instead of the held-out rows; a row's stored prediction is then the
        one from the last fold whose training set contained it

    Returns
    -------
    DataFrame indexed by class, with per-class support, recall, precision,
    ROC-AUC (OvR), PR-AUC (OvR) and mean true-class posterior. Index labels
    are mapped through the notebook-level `encoding` dict.

    A class that is absent from a fold's training set gets probability 0 for
    that fold.
    """
    X = np.asarray(X)
    y = np.asarray(y)
    classes = np.unique(y)
    # Column of each sample's true class; `classes` is sorted and contains
    # every label, so searchsorted gives exact positions
    true_col = np.searchsorted(classes, y)

    skf = StratifiedKFold(n_splits=cv, shuffle=True, random_state=random_state)
    y_proba = np.zeros((len(y), len(classes)))
    y_pred = np.empty(len(y), dtype=classes.dtype)

    for train_idx, test_idx in skf.split(X, y):
        if in_sample:
            test_idx = train_idx

        fold_lda = LinearDiscriminantAnalysis()
        fold_lda.fit(X[train_idx], y[train_idx])

        # Place the fold's probabilities into the fixed class order; columns
        # of classes the fold never saw in training stay at zero
        fold_cols = np.searchsorted(classes, fold_lda.classes_)
        proba = np.zeros((len(test_idx), len(classes)))
        proba[:, fold_cols] = fold_lda.predict_proba(X[test_idx])

        y_proba[test_idx] = proba
        y_pred[test_idx] = classes[np.argmax(proba, axis=1)]

    # One-vs-rest AUCs
    y_bin = label_binarize(y, classes=classes)
    auc_ovr = roc_auc_score(y_bin, y_proba, average=None)
    ap_ovr = average_precision_score(y_bin, y_proba, average=None)

    # Per-class precision/recall (hard predictions)
    recall = recall_score(y, y_pred, labels=classes, average=None, zero_division=0)
    precision = precision_score(
        y, y_pred, labels=classes, average=None, zero_division=0
    )

    # Mean posterior for the true class (soft, interpretable as "confidence")
    true_post = y_proba[np.arange(len(y)), true_col]
    mean_true_post = np.array([true_post[y == c].mean() for c in classes])

    support = np.bincount(true_col, minlength=len(classes))

    out = (
        pd.DataFrame(
            {
                "support": support,
                "recall": recall,  # "X% of the time we correctly pick this class"
                "precision": precision,
                "roc_auc_ovr": auc_ovr,  # threshold-free separability vs rest
                "ap_ovr": ap_ovr,  # PR AUC (useful for imbalance)
                "mean_true_posterior": mean_true_post,  # avg P(class | x) for true class
            },
            index=classes,
        )
        .sort_index()
        .rename(index=encoding)
    )

    return out


# Out-of-sample metrics: each scale is scored by a model fitted on the other folds
per_class_separability(mf_coeffs, groups, cv=5, random_state=42)

In [None]:
# In-sample metrics: each fold's model is scored on its own training rows
per_class_separability(mf_coeffs, groups, cv=5, random_state=42, in_sample=True)