Shape Analysis
====
This notebook runs the shape analysis on the scales, once their segmentation has been performed.

First, we'll get some statistics on the metadata (growth stage, sex, age, mutation, etc.), which is parsed from the file paths.

In [None]:
import pathlib

parent_dir = pathlib.Path(
    "~/zebrafish_rdsf/Carran/Postgrad/segmentations_cleaned"
).expanduser()
# The data lives on the RDSF; fail early with a readable message if it isn't reachable
assert parent_dir.is_dir(), f"Segmentation directory not found: {parent_dir}"

# Sorted so the scale order (and hence the cached coefficients' order) is deterministic
segmentation_paths = sorted(str(x) for x in parent_dir.glob("*.tif"))
assert segmentation_paths, f"No .tif segmentations found in {parent_dir}"
f"{len(segmentation_paths)} segmentations"

In [None]:
import numpy as np

from scale_morphology.scales import metadata

# One metadata row per segmentation, parsed from its file path
df = metadata.df(list(map(str, segmentation_paths)))
df.head()

In [None]:
"""
Drop scales with missing data, keeping the list of segmentation paths in step
so that the EFA below only runs on (and is indexed like) the scales we keep
"""

# NOTE(review): assumes metadata.df returns one row per input path, in input order
has_scale = (~df["no_scale"]).to_numpy()
assert len(has_scale) == len(segmentation_paths), (len(has_scale), len(segmentation_paths))
segmentation_paths = [path for path, keep in zip(segmentation_paths, has_scale) if keep]
df = df[has_scale].reset_index(drop=True)

In [None]:
"""
If necessary, read in the scales from the RDSF and perform EFA on them.

Otherwise read in the EFA coefficients from a cache
"""

from concurrent.futures import ThreadPoolExecutor

import tifffile
import numpy as np
from tqdm.notebook import tqdm
from skimage.measure import euler_number
from scipy.ndimage import binary_fill_holes

from scale_morphology.scales import efa
from scale_morphology.scales.segmentation import largest_connected_component


def load_scale_data(segmentation_path):
    """
    Read one segmentation and, if needed, clean it up.

    If the mask's Euler number (objects minus holes) is not 1, holes are filled
    and only the largest connected component is kept.

    :param segmentation_path: path to a segmentation TIFF
    :returns: the segmentation; a cleaned mask is returned as uint8 scaled by 255
        (presumably 0/255 if largest_connected_component returns a binary mask)
    :raises ValueError: if the cleaned mask still does not have Euler number 1,
        e.g. because nothing was left after cleaning
    """
    scale = tifffile.imread(segmentation_path)
    if euler_number(scale) != 1:
        # Fill holes
        scale = binary_fill_holes(scale)
        # Keep only the largest object, discarding any small specks
        scale = (largest_connected_component(scale) * 255).astype(np.uint8)

        # It's possible we might have removed everything, so make sure we haven't.
        # The path is included because this runs in a thread pool, where the
        # failing file would otherwise be impossible to identify.
        euler = euler_number(scale)
        if euler != 1:
            raise ValueError(f"{segmentation_path}: Got euler_number(scale)={euler}")

    return scale


# NOTE: the cache does not record which scales it was computed from; delete it
# whenever the set of segmentations (or the EFA parameters) changes
coeff_dump = pathlib.Path("carran_coeffs.npy")

if coeff_dump.is_file():
    coeffs = np.load(coeff_dump)

else:
    n_edge_points, order = 300, 50

    # Keep the loaded images in a list: np.array() on an iterator does not consume it,
    # and the images need not all share a shape anyway
    with ThreadPoolExecutor(max_workers=32) as executor:
        scales = list(
            tqdm(
                executor.map(load_scale_data, segmentation_paths),
                total=len(segmentation_paths),
                desc="loading",
            )
        )

    # One row of EFA coefficients per scale, in segmentation_paths order
    coeffs = np.stack(
        [efa.coefficients(scale, n_edge_points, order) for scale in tqdm(scales, desc="EFA")]
    )
    np.save(coeff_dump, coeffs)

In [None]:
from sklearn.preprocessing import StandardScaler

scaler = StandardScaler()
# StandardScaler only accepts 2D input, so flatten each scale's coefficients into one
# row first (a no-op if they are already 2D); each feature is then scaled independently
coeffs = scaler.fit_transform(coeffs.reshape((coeffs.shape[0], -1)))

In [None]:
"""
Sort the scales by sex, then age, then magnification, so that we get the right
plotting order for our scatter plots. The coefficients are reordered in step.
"""

# Later cells index the coefficient arrays with boolean masks built from `df`,
# so the rows of both must stay in the same order
assert len(df) == len(coeffs), (len(df), len(coeffs))
df = df.reset_index(drop=True)

# One multi-key sort (three successive unstable sorts would not preserve the
# secondary orderings), with the same permutation applied to the coefficients
sort_order = df.sort_values(
    by=["sex", "age", "magnification"], kind="stable"
).index.to_numpy()
df = df.iloc[sort_order].reset_index(drop=True)
coeffs = coeffs[sort_order]

flat_coeffs = coeffs.reshape((coeffs.shape[0], -1))

# For the seaborn plotting - we want to encode as a categorical variable
df["age"] = df["age"].astype(str)

In [None]:
"""
Create some labels for the sex/age encoding
"""

mf_mask = df["sex"] != "?"
mf_df = df.loc[mf_mask].copy()
mf_coeffs = flat_coeffs[mf_mask.to_numpy()]

# Want these groups - M/F age 7/12/18/40
# Each (sex, age) pair gets its own integer code: male adds 8, ages 12/18/40 add 1/2/4
# and age 7 adds nothing
mf_sex = mf_df["sex"].to_numpy()
mf_age = mf_df["age"].to_numpy()
groups = np.where(mf_sex == "M", 8, 0)
for age_label, offset in (("12", 1), ("18", 2), ("40", 4)):
    groups[mf_age == age_label] += offset

# (code, label, colour) for every group, in legend/plotting order
group_styles = [
    (0, "F7", "lightcoral"),
    (1, "F12", "indianred"),
    (2, "F18", "brown"),
    (4, "F40", "red"),
    (8, "M7", "cornflowerblue"),
    (9, "M12", "royalblue"),
    (10, "M18", "darkblue"),
    (12, "M40", "blue"),
]
encoding = {code: label for code, label, _ in group_styles}
colours = {code: colour for code, _, colour in group_styles}

# Check we've correctly encoded it: every code must map to exactly one sex and one age
mf_df.loc[:, "age_sex_group"] = groups
for code, members in mf_df.groupby("age_sex_group"):
    member_sexes = members["sex"].unique()
    member_ages = members["age"].unique()
    assert len(member_sexes) == 1, (code, member_sexes, member_ages)
    assert len(member_ages) == 1, (code, member_sexes, member_ages)

In [None]:
"""
Perform LDA
"""

from sklearn.discriminant_analysis import LinearDiscriminantAnalysis

# Fit on the M/F scales and project them onto the discriminant axes
lda = LinearDiscriminantAnalysis()
lda_coeffs = lda.fit(mf_coeffs, groups).transform(mf_coeffs)

In [None]:
from matplotlib import colors
import matplotlib.pyplot as plt
from matplotlib.patches import Patch
from scipy.stats import gaussian_kde

fig, axes = plt.subplots(nrows=1, ncols=5, figsize=(28, 6))  # one panel per LD pair


def clear2colour_cmap(colour, alpha: float = 0.5) -> colors.Colormap:
    """
    Colormap that varies smoothly from fully transparent to a colour.

    :param colour: any matplotlib colour specification, e.g. "red"
    :param alpha: opacity at the top end of the colormap
    :returns: a LinearSegmentedColormap named ``clear2<colour>``
    """
    # Interpolate from a transparent copy of the same colour (rather than transparent
    # white) so intermediate levels are not washed out towards white.
    # A two-entry ListedColormap would only give a hard step, not a gradient.
    clear = colors.to_rgba(colour, alpha=0)
    opaque = colors.to_rgba(colour, alpha=alpha)
    return colors.LinearSegmentedColormap.from_list(f"clear2{colour}", [clear, opaque])


for i, axis in enumerate(axes):
    x = lda_coeffs[:, i]
    y = lda_coeffs[:, i + 1]

    # Shared grid for this pair, so every group's KDE is evaluated on the same points
    xmin, xmax = x.min(), x.max()
    ymin, ymax = y.min(), y.max()
    xx, yy = np.mgrid[xmin:xmax:200j, ymin:ymax:200j]
    positions = np.vstack([xx.ravel(), yy.ravel()])

    handles = []
    for g, label in encoding.items():
        mask = groups == g
        n = int(np.sum(mask))
        if n == 0:
            continue

        data = np.vstack([x[mask], y[mask]])
        # A 2D KDE needs at least 3 points (more than its dimension); with fewer its
        # covariance matrix is singular, so only draw densities for large enough groups
        if n >= 3:
            kde = gaussian_kde(data)
            f = np.reshape(kde(positions).T, xx.shape)
            axis.contourf(xx, yy, f, levels=15, cmap=clear2colour_cmap(colours[g]))

        # Small groups are still shown as points and listed in the legend
        handles.append(Patch(color=colours[g], label=f"{label}, {n=}"))
        axis.scatter(*data, c=colours[g], s=5, marker="s", linewidth=0.5, edgecolor="k")

    axis.set_title(f"LD{i+1} vs LD{i+2}")
    axis.set_xlabel(f"LD{i+1}")
    axis.set_ylabel(f"LD{i+2}")

# Every panel has the same groups, so one legend (from the last panel's handles) suffices
if handles:
    axes[0].legend(handles=handles, loc="upper left")

In [None]:
import pandas as pd
from sklearn.model_selection import StratifiedKFold
from sklearn.preprocessing import label_binarize
from sklearn.metrics import (
    roc_auc_score,
    average_precision_score,
    recall_score,
    precision_score,
)


def per_class_separability(
    X, y, cv=5, random_state=42, in_sample: bool = False, label_names=None
):
    """
    Per-class separability for LDA, estimated with stratified k-fold cross-validation.

    Returns a DataFrame with per-class ROC-AUC (OvR), PR-AUC (OvR),
    recall, precision, support, and mean true-class posterior.

    :param X: feature matrix, one row per sample
    :param y: class label per sample
    :param cv: number of stratified folds
    :param random_state: seed for the fold shuffling
    :param in_sample: if True, each fold's model is scored on its own training rows
        instead of the held-out fold, giving optimistic (in-sample) metrics. Each
        sample then keeps the prediction from the last fold that trained on it.
    :param label_names: mapping from class label to display name for the output index;
        defaults to the notebook-level ``encoding`` dict
    :returns: DataFrame indexed by (renamed) class, sorted by class label
    """
    X = np.asarray(X)
    y = np.asarray(y)
    # np.unique sorts, so np.searchsorted below maps labels to column indices
    classes = np.unique(y)
    if label_names is None:
        label_names = encoding

    skf = StratifiedKFold(n_splits=cv, shuffle=True, random_state=random_state)
    y_proba = np.zeros((len(y), len(classes)))
    y_pred = np.empty(len(y), dtype=classes.dtype)

    for train_idx, test_idx in skf.split(X, y):
        eval_idx = train_idx if in_sample else test_idx

        lda = LinearDiscriminantAnalysis()
        lda.fit(X[train_idx], y[train_idx])

        # Scatter this fold's probabilities into the fixed class order; a class absent
        # from the fold's training data keeps probability 0 instead of raising
        proba = np.zeros((len(eval_idx), len(classes)))
        proba[:, np.searchsorted(classes, lda.classes_)] = lda.predict_proba(X[eval_idx])

        y_proba[eval_idx] = proba
        y_pred[eval_idx] = classes[np.argmax(proba, axis=1)]

    # One-vs-rest AUCs. The one-hot truth matrix is built by hand because
    # label_binarize returns a single column for two classes, which would not
    # match the two-column y_proba.
    y_onehot = (y[:, None] == classes[None, :]).astype(int)
    auc_ovr = roc_auc_score(y_onehot, y_proba, average=None)
    ap_ovr = average_precision_score(y_onehot, y_proba, average=None)

    # Per-class precision/recall (hard predictions)
    recall = recall_score(y, y_pred, labels=classes, average=None, zero_division=0)
    precision = precision_score(
        y, y_pred, labels=classes, average=None, zero_division=0
    )

    # Mean posterior for the true class (soft, interpretable as "confidence")
    true_post = y_proba[np.arange(len(y)), np.searchsorted(classes, y)]
    mean_true_post = np.array([true_post[y == c].mean() for c in classes])

    support = np.array([(y == c).sum() for c in classes])

    out = (
        pd.DataFrame(
            {
                "support": support,
                "recall": recall,  # "X% of the time we correctly pick this class"
                "precision": precision,
                "roc_auc_ovr": auc_ovr,  # threshold-free separability vs rest
                "ap_ovr": ap_ovr,  # PR AUC (useful for imbalance)
                "mean_true_posterior": mean_true_post,  # avg P(class | x) for true class
            },
            index=classes,
        )
        .sort_index()
        .rename(index=label_names)
    )

    return out


# Metrics on held-out folds (out-of-sample)
per_class_separability(X=mf_coeffs, y=groups, cv=5, random_state=42)

In [None]:
# Metrics on each fold's own training data (in-sample, so optimistic)
per_class_separability(X=mf_coeffs, y=groups, cv=5, random_state=42, in_sample=True)

In [None]:
"""
Do the LDA using different labels
"""

m18_mask = (df["sex"] == "M") & (df["age"] == "18")
m18_df = df[m18_mask]
# Use the flattened coefficients, as for the M/F analysis: LDA needs 2D input
# (one row per scale)
m18_coeffs = flat_coeffs[m18_mask.to_numpy()]

m18_df

In [None]:
# Want these groups - onto/regen wt/omd
assert set(np.unique(m18_df["mutation"])) == {"OMD", "WT"}
assert set(np.unique(m18_df["growth"])) == {np.inf, 10}

# Two-bit code per scale: +2 for ontogenetic (growth == inf), +1 for WT
is_onto = (m18_df["growth"] == np.inf).to_numpy()
is_wt = (m18_df["mutation"] == "WT").to_numpy()
groups = 2 * is_onto.astype(int) + is_wt.astype(int)

encoding = {
    0: "Regen (10); OMD",
    1: "Regen (10); WT",
    2: "Onto; OMD",
    3: "Onto; WT",
}
colours = {
    0: "blue",
    1: "red",
    2: "darkblue",
    3: "lightcoral",
}

lda = LinearDiscriminantAnalysis()
lda_coeffs = lda.fit(m18_coeffs, groups).transform(m18_coeffs)

In [None]:
fig, axes = plt.subplots(1, 3, figsize=(14, 6))

for i, axis in enumerate(axes):
    # Plot each discriminant against the next, wrapping the last one back round to LD1
    j = (i + 1) % len(axes)
    x = lda_coeffs[:, i]
    y = lda_coeffs[:, j]

    # Shared grid for this pair, so every group's KDE is evaluated on the same points
    xmin, xmax = x.min(), x.max()
    ymin, ymax = y.min(), y.max()
    xx, yy = np.mgrid[xmin:xmax:200j, ymin:ymax:200j]
    positions = np.vstack([xx.ravel(), yy.ravel()])

    handles = []
    for g, label in encoding.items():
        mask = groups == g
        n = int(np.sum(mask))
        if n == 0:
            continue

        data = np.vstack([x[mask], y[mask]])
        # A 2D KDE needs at least 3 points (more than its dimension); with fewer its
        # covariance matrix is singular, so only draw densities for large enough groups
        if n >= 3:
            kde = gaussian_kde(data)
            f = np.reshape(kde(positions).T, xx.shape)
            axis.contourf(xx, yy, f, levels=15, cmap=clear2colour_cmap(colours[g]))

        # Small groups are still shown as points and listed in the legend
        handles.append(Patch(color=colours[g], label=f"{label}, {n=}"))
        axis.scatter(*data, c=colours[g], s=5, marker="s", linewidth=0.5, edgecolor="k")

    # Title uses the actual y-axis component (the last panel wraps back to LD1)
    axis.set_title(f"LD{i+1} vs LD{j+1}")
    axis.set_xlabel(f"LD{i+1}")
    axis.set_ylabel(f"LD{j+1}")

if handles:
    axes[0].legend(handles=handles)

fig.suptitle("M18")