In [None]:
from fishjaw.util import files

in_dir = files.script_out_dir() / "jaw_segmentations"
img_in_dir = in_dir / "imgs"
mask_in_dir = in_dir / "masks"

img_paths = sorted(img_in_dir.glob("*.tif"))
mask_paths = sorted(mask_in_dir.glob("*.tif"))

# Images and masks are paired purely by their sorted position (later cells zip
# the two lists together), so a missing or extra file would silently misalign
# every pair after it, and zip would quietly truncate the longer list.
# NOTE(review): this assumes matching image/mask files sort into the same
# order (e.g. identical filenames) - confirm.
assert img_paths, f"No .tif images found in {img_in_dir}"
assert len(img_paths) == len(mask_paths), (
    f"Found {len(img_paths)} images but {len(mask_paths)} masks under {in_dir}"
)

In [None]:
%%capture
from fishjaw.inference import read

# Flag the contrast-enhanced fish and those with bad segmentations. The flags
# are computed from img_paths and applied positionally to both path lists.
exclude = [
    read.is_excluded(fish_n, exclude_train_data=False, exclude_unknown_age=False)
    for fish_n in map(read.fish_number, img_paths)
]

# Keep only the entries that were not flagged
mask_paths = [path for path, drop in zip(mask_paths, exclude) if not drop]
img_paths = [path for path, drop in zip(img_paths, exclude) if not drop]

In [None]:
# Load every segmentation mask into memory, in the same order as mask_paths
import tifffile
from tqdm.notebook import tqdm

masks = list(map(tifffile.imread, tqdm(mask_paths)))

In [None]:
# Load the greyscale images, in the same order as img_paths
imgs = list(map(tifffile.imread, tqdm(img_paths)))

In [None]:
# Look up the metadata for each image via the fish number of its path

metadata = [read.metadata(fish_n) for fish_n in map(read.fish_number, img_paths)]

In [None]:
from radiomics import featureextractor
import SimpleITK as sitk
import pandas as pd
import numpy as np

params_file = "radiomics_config.yaml"
extractor = featureextractor.RadiomicsFeatureExtractor(params_file)

# (image, mask, metadata) for each fish, in matching order
cases = list(zip(imgs, masks, metadata))

features_list = []
for img_array, mask_array, mdata in tqdm(cases):
    # Convert numpy arrays to SimpleITK images
    img = sitk.GetImageFromArray(img_array)
    mask = sitk.GetImageFromArray(mask_array.astype(np.uint8))

    # NOTE(review): GetImageFromArray maps numpy axes (z, y, x) onto SimpleITK
    # axes (x, y, z), and SetSpacing expects SimpleITK (x, y, z) order.
    # Confirm mdata.voxel_size is given in that order (reverse it otherwise).
    img.SetSpacing(mdata.voxel_size)
    mask.SetSpacing(mdata.voxel_size)

    # Extract features
    result = extractor.execute(img, mask)

    # Keep only numeric radiomic features, stored as plain floats:
    # - skip the "diagnostics_*" entries, which describe the extraction rather
    #   than the jaw; some of them are Python numbers and would otherwise slip
    #   through the type check below;
    # - accept Python numbers and 0-d numpy arrays (the form PyRadiomics uses
    #   for many feature values), converting to float so every column of the
    #   resulting DataFrame is a plain float column. Multi-element arrays are
    #   not scalar features and are skipped.
    result_clean = {
        k: float(v)
        for k, v in result.items()
        if not k.startswith("diagnostics_")
        and (
            isinstance(v, (int, float))
            or (isinstance(v, np.ndarray) and v.ndim == 0)
        )
    }
    result_clean["ID"] = mdata.n

    features_list.append(result_clean)

In [None]:
# One row per fish, indexed by the "ID" column (mdata.n)
features_df = pd.DataFrame(features_list)
features_df = features_df.set_index("ID")
print(features_df.shape)
features_df.head()

In [None]:
# Cache the extracted features so the extraction above need not be re-run
features_df.to_csv(path_or_buf="features.csv")

# If reading from disk, start here...

In [None]:
# Entry point when the features have already been extracted: load them from
# the CSV cache, using the first column (the ID) as the index
import pandas as pd

features_df = pd.read_csv(filepath_or_buffer="features.csv", index_col=0)
features_df.head()

In [None]:
"""
Add a column describing the mutation status (wt/het/hom/mosaic)
"""
from fishjaw.inference import feature_selection

# NOTE(review): later cells read the features via features_df["Features"], so
# add_metadata_cols presumably nests them under a "Features" column key
# alongside the new metadata columns - confirm in feature_selection.
features_df = features_df.pipe(feature_selection.add_metadata_cols)
features_df.head()

In [None]:
"""
Remove features with zero variance
"""

# A feature is constant when it has at most one distinct non-NaN value.
# Counting distinct values is more robust than testing `var() == 0`:
# floating-point round-off in the mean can leave a constant float column with
# a tiny non-zero variance, and columns with fewer than two non-NaN values
# have a NaN variance, so neither would be caught by the equality test.
n_unique = features_df["Features"].nunique()
null_variance_cols = n_unique[n_unique <= 1].index
features_df.drop(columns=null_variance_cols, inplace=True, level=1)

print("Dropped:\n\t", ", ".join(null_variance_cols))
features_df.head()

In [None]:
"""
Plot correlations
"""

import seaborn as sns

# Pairwise (Pearson) correlation between the remaining features, shown on a
# diverging colour map fixed to the full [-1, 1] range
corr = features_df["Features"].corr()
sns.heatmap(data=corr, cmap="seismic", vmin=-1, vmax=1)

In [None]:
import numpy as np
import matplotlib.pyplot as plt

# The correlation matrix is symmetric with self-correlations on the diagonal.
# Take only the strict upper triangle so each feature pair is counted once,
# and so genuinely perfect (|r| == 1) correlations between different features
# are kept rather than being masked out along with the diagonal.
corr_values = corr.to_numpy()
upper_rows, upper_cols = np.triu_indices_from(corr_values, k=1)
c = np.abs(corr_values[upper_rows, upper_cols])
# NaN correlations (e.g. involving a constant column) cannot be binned
c = c[~np.isnan(c)]

fig, axis = plt.subplots()
axis.hist(c, bins=100)
axis.set_title(r"$\left|\mathrm{Correlations}\right|$")
axis.set_xlabel("Absolute pairwise correlation")
axis.set_ylabel("Number of feature pairs");

In [None]:
"""
Drop highly correlated features
"""

from typing import Iterable, Tuple, List, Optional


def drop_correlated_features(
    df: pd.DataFrame,
    threshold: float = 0.8,
    protected: Optional[Iterable[str]] = None,
    prefer: str = "lower_variance",  # or "higher_variance" or "mean_corr"
) -> Tuple[List[str], List[str]]:
    """
    Greedily drop a small (not guaranteed minimal) set of columns so that all
    remaining pairwise absolute correlations are <= threshold.

    At each step the column with the most above-threshold correlations (its
    "degree") is dropped; ties are broken according to ``prefer``. Pairs whose
    correlation is NaN (e.g. involving a constant column) are treated as
    uncorrelated.

    :param df: frame whose columns are the features to prune.
    :param threshold: maximum allowed absolute pairwise correlation, in [0, 1].
    :param protected: columns never to drop. If every highest-degree column at
        some step is protected, the unprotected neighbour most correlated with
        the chosen protected column is dropped instead.
    :param prefer: tie-breaker among the highest-degree columns (unprotected
        columns are always preferred over protected ones):
        "lower_variance" drops the one with the smallest variance,
        "higher_variance" the one with the largest variance,
        "mean_corr" the one with the largest mean absolute correlation with
        the remaining columns; any other value drops the first tied column in
        column order.
    :returns: (kept column names, dropped column names in the order dropped).
    :raises ValueError: if threshold is outside [0, 1].
    :raises RuntimeError: if the threshold cannot be met without dropping a
        protected column.
    """
    if not 0 <= threshold <= 1:
        raise ValueError("threshold must be in [0, 1]")

    prot = set(protected or [])

    # Absolute correlation matrix with self-correlations zeroed so they never
    # count as edges. Work on an explicit writable copy: under pandas
    # Copy-on-Write, ``DataFrame.values`` can be read-only, so filling the
    # diagonal through it would fail.
    abs_corr = df.corr().abs()
    corr_values = abs_corr.to_numpy(dtype=float, copy=True)
    np.fill_diagonal(corr_values, 0.0)
    # Replace NaNs with 0 (e.g., constant columns). Ideally drop NaNs beforehand.
    corr = pd.DataFrame(
        corr_values, index=abs_corr.index, columns=abs_corr.columns
    ).fillna(0.0)

    to_drop: List[str] = []
    remaining = corr.index.tolist()

    while True:
        # Edges above threshold
        mask = corr > threshold
        if not mask.to_numpy().any():
            break

        # Degree = number of correlations above threshold
        deg = mask.sum(axis=1)

        # Candidate nodes with max degree
        cand = deg[deg == deg.max()].index.tolist()

        # Among tied candidates prefer ones we are allowed to drop; only when
        # all of them are protected do we fall back to dropping a neighbour.
        unprotected = [col for col in cand if col not in prot]
        if unprotected:
            cand = unprotected

        # Apply tie-breaker
        if prefer == "lower_variance":
            pick = df[cand].var(numeric_only=True).idxmin()
        elif prefer == "higher_variance":
            pick = df[cand].var(numeric_only=True).idxmax()
        elif prefer == "mean_corr":
            pick = corr.loc[cand].mean(axis=1).idxmax()
        else:
            pick = cand[0]  # deterministic: first tied column in column order

        if pick in prot:
            # Every top-degree candidate is protected: drop the unprotected
            # neighbour with the highest correlation to it instead
            neighbors = [
                n for n in corr.columns[mask.loc[pick].to_numpy()] if n not in prot
            ]
            if not neighbors:
                raise RuntimeError(
                    f"Cannot satisfy threshold={threshold} without dropping protected feature '{pick}'"
                )
            pick = corr.loc[pick, neighbors].idxmax()

        # Drop the picked column/row from the working correlation matrix
        to_drop.append(pick)
        corr = corr.drop(index=pick, columns=pick)
        remaining.remove(pick)

    return remaining, to_drop

In [None]:
kept, dropped = drop_correlated_features(df=features_df["Features"], threshold=0.8)

# Remove the dropped feature names (column level 1) from the frame; `kept`
# holds the names of the surviving features
features_df.drop(dropped, axis="columns", level=1, inplace=True)
print(f"Dropped {len(dropped)} cols:\n\t", ", ".join(dropped))

features_df.head()

In [None]:
"""
Z-normalise the remaining features
"""

In [None]:
"""
Show variance of the remaining features
"""


In [None]:
import seaborn as sns

# Plot only the radiomic features: the frame also carries the metadata
# columns added by add_metadata_cols, and its columns form a two-level index
# (the features are reached elsewhere as features_df["Features"]).
# NOTE(review): one panel per feature pair - grows quadratically with the
# number of remaining features.
sns.pairplot(features_df["Features"]);

In [None]:
"""
PCA and biplot to get an idea of what good descriptors might be
"""