Shape Analysis
====
This notebook runs through the shape analysis once the segmentation has been performed.

First, we'll extract the metadata (growth stage, sex, age, mutation, etc.) from the file paths and summarise it.

In [None]:
import pathlib

# Directory holding the cleaned segmentation TIFFs; "~" is expanded to the
# current user's home directory
parent_dir = pathlib.Path("~/zebrafish_rdsf/Carran/Postgrad/segmentations_cleaned")
parent_dir = parent_dir.expanduser()
assert parent_dir.exists()

# Sorted list of the TIFF paths (as strings) so the ordering is reproducible
segmentation_paths = sorted(str(tif_path) for tif_path in parent_dir.glob("*.tif"))
f"{len(segmentation_paths)} segmentations"

In [None]:
import numpy as np

from scale_morphology.scales import metadata

# Build the metadata table from the (string) segmentation paths, then show a
# summary of every column
df = metadata.df(list(map(str, segmentation_paths)))
df.describe(include="all")

In [None]:
"""
Drop scales with missing data
"""

# Keep only the rows whose "no_scale" flag is False
df = df.loc[~df["no_scale"]]

In [None]:
"""
If necessary, read in the scales from the RDSF and perform EFA on them.

Otherwise read in the EFA coefficients from a cache
"""

from concurrent.futures import ThreadPoolExecutor

import tifffile
import numpy as np
import pandas as pd
from tqdm.notebook import tqdm
from skimage.measure import euler_number
from scipy.ndimage import binary_fill_holes

from scale_morphology.scales import efa
from scale_morphology.scales.segmentation import largest_connected_component


def load_scale_data(segmentation_path):
    """
    Read one segmentation and return it, cleaning it up first if needed

    A segmentation whose Euler number is already 1 is returned exactly as it
    was read. Otherwise holes are filled, the result is passed through
    `largest_connected_component` (presumably keeping only the largest
    object - see that helper) and rescaled to a 0/255 uint8 image.

    :param segmentation_path: path of the TIFF to read
    :returns: the (possibly cleaned) segmentation
    :raises ValueError: if the Euler number is still not 1 after cleaning
    """
    raw = tifffile.imread(segmentation_path)
    if euler_number(raw) == 1:
        # Nothing to clean
        return raw

    # Fill holes, then remove small objects
    filled = binary_fill_holes(raw)
    scale = (largest_connected_component(filled) * 255).astype(np.uint8)

    # The cleaning may have removed everything, so check again
    if euler_number(scale) != 1:
        raise ValueError(f"Got {euler_number(scale)=}")

    return scale


# EFA coefficients are cached here; delete the file to force a recompute
coeff_dump = pathlib.Path("carran_coeffs.npy")

if coeff_dump.is_file():
    coeffs = np.load(coeff_dump)

else:
    n_edge_points, order = 300, 50

    # Reading the images is I/O bound, so do it in a thread pool
    with ThreadPoolExecutor(max_workers=32) as executor:
        scales = np.array(
            list(
                tqdm(
                    executor.map(load_scale_data, df["path"]),
                    total=len(df),
                )
            )
        )

    # The default magnification is 4.0
    # Work on an explicit copy: assigning through df["magnification"] directly
    # would write into a column of the (already filtered) dataframe
    magnifications = df["magnification"].copy()
    magnifications.loc[magnifications.isna()] = 4.0

    coeffs = [
        efa.coefficients(scale, n_edge_points, order, magnification=4 / magnification)
        for scale, magnification in zip(tqdm(scales), magnifications)
    ]
    coeffs = np.stack(coeffs)
    np.save(coeff_dump, coeffs)

# The coefficients are matched to the dataframe rows by position, so a cache
# built from a different set of scales would silently mislabel everything
if len(coeffs) != len(df):
    raise ValueError(
        f"{coeff_dump} holds {len(coeffs)} sets of coefficients but there are "
        f"{len(df)} scales; delete the cache file to recompute"
    )

In [None]:
"""
Helpers for plotting component vectors (e.g. the PCA contributions) as a heatmap and as bar charts
"""

from matplotlib import colors
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA


def _heatmap(scalings):
    """
    Show a 2d array of scalings as a heatmap on a new figure

    The array is transposed for display, so each column of `scalings` becomes
    one row of the image. The colour scale is centred on zero.

    :param scalings: 2d array of values to plot
    """
    fig, axis = plt.subplots(figsize=(10, 2))

    # CenteredNorm picks a range symmetric about vcenter from the data, so no
    # explicit limits are needed
    im = axis.imshow(
        scalings.T,
        aspect="auto",
        cmap="seismic",
        norm=colors.CenteredNorm(vcenter=0.0),
    )
    # One tick per column of `scalings`
    axis.set_yticks(range(scalings.shape[1]))
    fig.colorbar(im, ax=axis)


def _pca_barplot(scalings):
    """
    Bar charts of the leading entries of each column of `scalings`

    One panel is drawn per column. Within a panel the bars are labelled
    "size", d_1, then (a, b, c, d) for each subsequent harmonic, and at most 14
    bars are shown. Only complete harmonics are plotted.

    :param scalings: 2d array; rows are coefficients, columns are components
    """
    n_coeffs, n_components = scalings.shape

    # The bars index the rows (coefficients) of `scalings`, so that is the axis
    # that limits how many we can draw: two leading bars + whole harmonics
    n_harmonics = max(0, (min(14, n_coeffs) - 2) // 4)

    # First two ticks for size, d_1; then calculate the rest + build their labels
    xticks = [-3, -1]
    xticklabels = ["size", r"$d_1$"]
    vlines = [-2, 0]
    for i in range(n_harmonics):
        xticklabels += [rf"$a_{i+2}$", rf"$b_{i+2}$", rf"$c_{i+2}$", rf"$d_{i+2}$"]

        start, end = 1 + 5 * i, 5 + 5 * i
        xticks += list(range(start, end))

        # The separator before the first harmonic (at 0) is already in the list
        if i > 0:
            vlines.append(start - 1)

    # Keep ticks, labels and bar heights the same length, even if there are
    # fewer than two coefficients
    n_bars = min(n_coeffs, len(xticks))
    xticks, xticklabels = xticks[:n_bars], xticklabels[:n_bars]

    # squeeze=False so that we always get an array of axes, even for one panel
    fig, axes = plt.subplots(
        1, n_components, figsize=(n_components * 4, 4), squeeze=False
    )

    for axis, scaling in zip(axes[0], scalings.T, strict=True):
        axis.bar(xticks, scaling[:n_bars])

        axis.set_xticks(
            xticks,
            xticklabels,
            ha="right",
        )

        # Separate out the different conceptual bits
        for v in vlines:
            axis.axvline(v, color="k", linestyle="--")

        axis.set_ylim(-1.1, 1.1)


def _plot_vectors(scalings: np.ndarray):
    """
    Visualise the vectors from a dimensionality reduction

    Draws the heatmap first and then the per-component bar charts, each on its
    own new figure.
    """
    for draw in (_heatmap, _pca_barplot):
        draw(scalings)

In [None]:
"""
Run LDA on the PCA coefficients, given some categories, and plot a scatter plot
"""

from matplotlib import colors
import matplotlib.pyplot as plt
from matplotlib.patches import Patch
from matplotlib.offsetbox import AnnotationBbox, OffsetImage

from scipy.stats import gaussian_kde
from sklearn.pipeline import Pipeline
from sklearn.metrics import classification_report
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import StratifiedKFold, cross_val_score
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis


def clear2colour_cmap(colour) -> colors.Colormap:
    """
    Two-entry colormap: fully transparent white, then `colour` at 50% opacity

    :param colour: colour specification for the second entry
    :returns: a ListedColormap named "clear2<colour>"
    """
    endpoints = [
        colors.colorConverter.to_rgba(name, alpha=alpha)
        for name, alpha in (("white", 0), (colour, 0.5))
    ]
    return colors.ListedColormap(endpoints, name=f"clear2{colour}")


def _plot_kde_scatter(
    axis: plt.Axes,
    x_coeffs: np.ndarray,
    y_coeffs: np.ndarray,
    lda_labels,
    colours,
    grouping_cols,
    uniques,
) -> list[Patch]:
    """
    Draw, for every label, filled KDE contours and a scatter of its points.

    :param axis: axes to draw on
    :param x_coeffs: x co-ordinate of each point
    :param y_coeffs: y co-ordinate of each point
    :param lda_labels: integer code of the group each point belongs to
    :param colours: one colour per distinct label, in sorted label order
    :param grouping_cols: column names, used to build the legend text
    :param uniques: the group values; position `i` describes label code `i`

    :returns: one legend handle per label
    :raises ValueError: if a label has fewer than 2 points
    """
    unique_labels = np.unique(lda_labels)
    assert len(unique_labels) == len(colours), f"{len(unique_labels)=}, {len(colours)=}"

    # 200 x 200 evaluation grid spanning the data in this pair of dimensions
    grid_x, grid_y = np.mgrid[
        x_coeffs.min() : x_coeffs.max() : 200j,
        y_coeffs.min() : y_coeffs.max() : 200j,
    ]
    grid_points = np.vstack([grid_x.ravel(), grid_y.ravel()])

    # Human readable description of each label code, e.g. "sex=M, age=18"
    pretty_labels = {}
    for code, values in enumerate(uniques):
        parts = [f"{col}={val}" for col, val in zip(grouping_cols, values)]
        pretty_labels[code] = ", ".join(parts)

    handles = []
    for label, colour in zip(unique_labels, colours):
        members = lda_labels == label
        n = int(np.sum(members))
        if n < 2:
            raise ValueError("KDE needs at least 2 points")

        points = np.vstack([x_coeffs[members], y_coeffs[members]])

        # Density of this group's points, evaluated on the grid
        density = gaussian_kde(points)(grid_points).reshape(grid_x.shape)
        axis.contourf(
            grid_x, grid_y, density, levels=15, cmap=clear2colour_cmap(colour)
        )

        handles.append(Patch(color=colour, label=f"{pretty_labels[label]}, {n=}"))

        axis.scatter(
            *points, color=colour, s=5, marker="s", linewidth=0.5, edgecolor="k"
        )

    return handles


def _kde_scatter(uniques, lda_coeffs, lda_labels, colours, grouping_cols):
    """
    KDE + scatter plots of each consecutive pair of columns of `lda_coeffs`

    Column `i` is plotted against column `i + 1`, one panel per pair, on a new
    figure. The legend is drawn on the first panel only.

    :param uniques: the group values; position `i` describes label code `i`
    :param lda_coeffs: 2d array, one row per point
    :param lda_labels: integer code of the group each point belongs to
    :param colours: one colour per distinct label
    :param grouping_cols: column names, used to build the legend text

    :raises ValueError: if there are fewer than two columns to plot
    """
    n_axes = lda_coeffs.shape[1] - 1
    if n_axes < 1:
        # Otherwise matplotlib fails with an unhelpful message about the grid
        raise ValueError(
            "Need at least 2 columns of coefficients to draw pairwise scatter "
            f"plots, got {lda_coeffs.shape[1]}"
        )

    # squeeze=False so that we always get an array of axes, even for one panel
    fig, axes = plt.subplots(1, n_axes, figsize=(5 * n_axes, 5), squeeze=False)
    axes = axes.ravel()

    for i, axis in enumerate(axes):
        handles = _plot_kde_scatter(
            axis,
            lda_coeffs[:, i],
            lda_coeffs[:, i + 1],
            lda_labels,
            colours,
            grouping_cols,
            uniques,
        )

    # Every panel has the same groups, so the handles from the last one will do
    axes[0].legend(handles=handles)


def _plot_extrema(coeffs: np.ndarray, paths: pd.Series):
    """
    Show some scale images on the scatter plots

    For each consecutive pair of columns of `coeffs`, scatter the points and
    overlay a thumbnail of the point that lies furthest along each of 16
    evenly spaced directions.

    :param coeffs: 2d array, one row per scale
    :param paths: path of the image for each row of `coeffs`, in the same order

    :raises ValueError: if there are fewer than two columns to plot
    """
    n_axes = coeffs.shape[1] - 1
    if n_axes < 1:
        # Otherwise matplotlib fails with an unhelpful message about the grid
        raise ValueError(
            "Need at least 2 columns of coefficients to draw pairwise scatter "
            f"plots, got {coeffs.shape[1]}"
        )

    # squeeze=False so that we always get an array of axes, even for one panel
    _, axes = plt.subplots(1, n_axes, figsize=(4 * n_axes, 4), squeeze=False)

    # directions - we'll dot with these to get the extrema
    angles = np.deg2rad(np.arange(0, 360, 22.5))
    dirns = np.stack([np.cos(angles), np.sin(angles)], axis=1)

    cmap = clear2colour_cmap("k")

    # The same scale is often extreme in several directions (and panels), so
    # only read and downsample each image once
    thumbnails = {}

    for i, axis in enumerate(axes.ravel()):
        co_ords = np.stack([coeffs[:, i], coeffs[:, i + 1]])
        axis.scatter(*co_ords, s=2)

        # Projection of each point onto each direction; the `@` operator is
        # used since np.linalg.matmul is only available in newer numpy
        scores = co_ords.T @ dirns.T
        extrema_idx = scores.argmax(axis=0)

        extrema_locs = co_ords.T[extrema_idx]
        extrema_paths = paths.values[extrema_idx]

        for loc, path in zip(extrema_locs, extrema_paths, strict=True):
            if path not in thumbnails:
                # Keep every 20th pixel in each direction
                thumbnails[path] = tifffile.imread(path)[::20, ::20]

            ab = AnnotationBbox(
                OffsetImage(thumbnails[path], zoom=0.3, cmap=cmap),
                loc,
                pad=0.05,
                frameon=True,
                bboxprops={"edgecolor": "k", "linewidth": 1},
            )
            axis.add_artist(ab)


def plot_lda(
    df: pd.DataFrame,
    coeffs: np.ndarray,
    grouping_cols: str | list[str],
    colours: list[str],
) -> None:
    """
    Perform LDA on a dataframe using the values in columns.

    Requires the coeffs to be in the same order as the df rows.

    The coefficients go through PCA (10 components), standardisation and then
    LDA, with each distinct combination of values in `grouping_cols` as one
    class. Plots the PCA vectors, KDE + scatter plots of the LDA projection, a
    heatmap of the LDA scalings and the scatter plots annotated with the most
    extreme scales. Prints the 10-fold cross-validated balanced accuracy and a
    classification report.

    Note that the classification report is computed on the same data that the
    pipeline was fitted to, so only the cross-validation score estimates
    out-of-sample performance.

    :param df: metadata; must contain `grouping_cols` and a "path" column
    :param coeffs: one row of coefficients per row of `df`
    :param grouping_cols: column name(s) defining the classes
    :param colours: one colour per class
    """
    assert len(df) == len(coeffs), f"{len(df)=} {coeffs.shape=}"

    # First, get an encoding for the grouping categories and their labels
    if isinstance(grouping_cols, str):
        grouping_cols = [grouping_cols]
    combos = df[grouping_cols].apply(lambda row: tuple(row.values), axis=1)
    lda_labels, uniques = pd.factorize(combos)

    n_pcs = 10
    pipeline = Pipeline(
        [
            ("pca", PCA(n_components=n_pcs)),
            ("scale", StandardScaler()),
            ("lda", LinearDiscriminantAnalysis()),
        ]
    )

    # cross_val_score fits clones, so `pipeline` itself is still unfitted here
    cv = StratifiedKFold(n_splits=10, shuffle=True, random_state=0)
    scores = cross_val_score(
        pipeline, coeffs, lda_labels, cv=cv, scoring="balanced_accuracy"
    )

    # Plot the contribution from each EFA coefficient to the PCA
    pipeline.fit(coeffs, lda_labels)
    _plot_vectors(pipeline.named_steps["pca"].components_.T)
    plt.gcf().suptitle("PCs")

    # Project once and reuse the result for both sets of scatter plots
    lda_coeffs = pipeline.transform(coeffs)

    # Then plot KDE + scatter plots
    _kde_scatter(uniques, lda_coeffs, lda_labels, colours, grouping_cols)

    # Plot a heatmap of LDA vectors in terms of PCA coefficients
    lda = pipeline.named_steps["lda"]
    _heatmap(lda.scalings_)

    # Plot extrema
    _plot_extrema(lda_coeffs, df["path"])

    print(f"Crossval score: {scores.mean():.3f}" "\u00b1" f"{scores.std():.3f}")
    # In-sample: predictions are for the data the pipeline was fitted to
    print(classification_report(
        lda_labels, pipeline.predict(coeffs), target_names=[str(u) for u in uniques]
    ))

In [None]:
# Every scale whose sex is known, grouped by sex and age
keep = df["sex"].ne("?")
plot_lda(
    df[keep],
    coeffs[keep],
    ["sex", "age"],
    colours=[
        "royalblue",
        "cornflowerblue",
        "brown",
        "red",
        "blue",
        "indianred",
        "lightcoral",
    ],
)


In [None]:
# Males with age 18 only, grouped by growth and mutation
keep = df["sex"].eq("M") & df["age"].eq(18)
plot_lda(
    df[keep],
    coeffs[keep],
    ["growth", "mutation"],
    colours=["indianred", "lightblue", "red", "blue"],
)
# Titles the most recently created figure
plt.gcf().suptitle("M18 only")

In [None]:
# Wild-type females, grouped by age
keep = df["mutation"].eq("WT") & df["sex"].eq("F")
plot_lda(
    df[keep],
    coeffs[keep],
    "age",
    colours=["indianred", "lightcoral", "red"],
)
# Titles the most recently created figure
plt.gcf().suptitle("F WTs")

In [None]:
# Wild-type males, grouped by age
# NB the number of colours must equal the number of distinct ages in this subset
keep = (df["mutation"] == "WT") & (df["sex"] == "M")
plot_lda(df[keep], coeffs[keep], "age", colours=["indianred", "lightcoral", "red"])
plt.gcf().suptitle("M WTs")

In [None]:
# Sanity check: assign the scales to meaningless random groups and see how
# well LDA appears to separate them
copy_df = df.copy()

# Fixed seed so that the random groupings - and therefore the plots and
# scores below - are the same on every run of the notebook
rng = np.random.default_rng(0)
copy_df["random"] = rng.integers(0, 2, size=len(df))
copy_df["random2"] = rng.integers(0, 4, size=len(df))

plot_lda(
    copy_df,
    coeffs,
    ["random", "random2"],
    colours=list(colors.TABLEAU_COLORS.keys())[:8],
)
plt.gcf().suptitle("Spurious groupings are unlikely with the entire dataset")

plot_lda(copy_df, coeffs, "random2", colours=list(colors.TABLEAU_COLORS.keys())[:4])

# Repeat with a much smaller subset of the data
keep = (df["sex"] == "M") & (df["age"] == 18)
plot_lda(
    copy_df[keep],
    coeffs[keep],
    ["random", "random2"],
    colours=list(colors.TABLEAU_COLORS.keys())[:8],
)
plt.gcf().suptitle(
    "But if we only use a subset, it becomes quite easy to separate them based on random noise"
)

plot_lda(
    copy_df[keep],
    coeffs[keep],
    "random2",
    colours=list(colors.TABLEAU_COLORS.keys())[:4],
)