In [None]:
import pandas as pd

# Load the feature table; the first CSV column is used as the row index.
features_df = pd.read_csv(filepath_or_buffer="features.csv", index_col=0)
features_df.head(n=5)

In [None]:
"""
Add a column describing the mutation status (wt/het/hom/mosaic)
"""

from fishjaw.inference import feature_selection

# Attach the metadata columns (mutation status etc.) via the project helper
features_df = features_df.pipe(feature_selection.add_metadata_cols)
features_df.head(n=5)

In [None]:
"""
Remove features with zero variance
"""

# Drop features that carry no information: zero variance (constant columns)
# or undefined variance (NaN, e.g. fewer than two non-missing values).
# Either would give NaN correlations and degenerate statistical tests later on.
feature_variances = features_df["Features"].var()
null_variance_cols = feature_variances.index[
    (feature_variances == 0) | feature_variances.isna()
].tolist()

# Drop by full (group, name) tuples so only "Features" columns are removed,
# never a same-named column in another column group.
features_df = features_df.drop(
    columns=[("Features", col) for col in null_variance_cols]
)

print("Dropped:\n\t", ", ".join(null_variance_cols))
features_df.head()

In [None]:
"""
Plot correlations
"""

import seaborn as sns
import numpy as np
import matplotlib.pyplot as plt

corr = features_df["Features"].corr()

# Signed correlation heatmap, drawn on its own explicit axes
fig, axis = plt.subplots()
sns.heatmap(corr, vmin=-1, vmax=1, cmap="seismic", ax=axis)
axis.set_title("Feature correlations")

# |r| for each unordered pair of distinct features, counted once: take the
# strict upper triangle (k=1) instead of masking values equal to 1, which
# would double-count every pair and also hide genuinely perfect correlations.
upper_rows, upper_cols = np.triu_indices(len(corr), k=1)
c = np.abs(corr.to_numpy()[upper_rows, upper_cols])
# Skip undefined correlations (e.g. caused by missing data)
c = c[np.isfinite(c)]

fig, axis = plt.subplots()
axis.hist(c, bins=100)
axis.set_title(r"$\left|\mathrm{Correlations}\right|$")
axis.set_xlabel("Absolute correlation")
axis.set_ylabel("Number of feature pairs")

In [None]:
"""
Drop highly correlated features
"""

from typing import Iterable, Tuple, List, Optional


def drop_correlated_features(
    df: pd.DataFrame,
    threshold: float = 0.8,
    protected: Optional[Iterable[str]] = None,
    prefer: str = "lower_variance",  # or "higher_variance" or "mean_corr"
) -> Tuple[List[str], List[str]]:
    """
    Greedily drop a minimal-ish set of columns so that all remaining
    pairwise absolute correlations are <= threshold.

    At each step the column involved in the most above-threshold pairs is
    dropped (ties broken by ``prefer``), until no such pair remains.

    :param df: feature table; correlations are computed between its columns.
    :param threshold: maximum allowed absolute correlation, in [0, 1].
    :param protected: columns never to drop (will raise if impossible).
    :param prefer: tie-breaker when several columns share the highest number
        of above-threshold correlations:
          - "lower_variance": drop the one with the smallest variance
          - "higher_variance": drop the one with the largest variance
          - "mean_corr": drop the one with the largest mean |correlation|
    :returns: (kept columns in original order, dropped columns in drop order).
    :raises ValueError: if ``threshold`` is outside [0, 1] or ``prefer`` is unknown.
    :raises RuntimeError: if the threshold cannot be met without dropping a
        protected column.
    """
    if not 0 <= threshold <= 1:
        raise ValueError("threshold must be in [0, 1]")
    valid_prefer = ("lower_variance", "higher_variance", "mean_corr")
    if prefer not in valid_prefer:
        # Fail loudly instead of silently using an arbitrary tie-breaker
        raise ValueError(f"prefer must be one of {valid_prefer}, got {prefer!r}")

    prot = set(protected or [])

    # Absolute correlation matrix. Work on an explicit NumPy copy: under pandas
    # Copy-on-Write, DataFrame.values can be read-only, so filling the diagonal
    # in place through it is not reliable.
    abs_corr = df.corr().abs()
    values = abs_corr.to_numpy(copy=True)
    # Remove self-correlation to simplify logic
    np.fill_diagonal(values, 0.0)
    # Replace NaNs with 0 (e.g., constant columns). Ideally drop NaNs beforehand.
    corr = pd.DataFrame(
        values, index=abs_corr.index, columns=abs_corr.columns
    ).fillna(0.0)

    to_drop: List[str] = []
    remaining = corr.index.tolist()

    while True:
        # Edges above threshold
        mask = corr > threshold
        if not mask.to_numpy().any():
            break

        # Degree = number of correlations above threshold
        deg = mask.sum(axis=1)

        # Candidate nodes with max degree
        cand = deg[deg == deg.max()].index.tolist()

        # If an equally-connected unprotected column exists, never let the
        # tie-breaker pick a protected one
        unprotected = [c for c in cand if c not in prot]
        if unprotected:
            cand = unprotected

        # Apply tie-breaker
        if prefer == "lower_variance":
            pick = df[cand].var(numeric_only=True).idxmin()
        elif prefer == "higher_variance":
            pick = df[cand].var(numeric_only=True).idxmax()
        else:  # "mean_corr"
            pick = corr.loc[cand].mean(axis=1).idxmax()

        if pick in prot:
            # Only protected columns are maximally connected: drop the
            # unprotected neighbour most correlated with the protected pick
            row = corr.loc[pick]
            neighbors = [n for n in row[row > threshold].index if n not in prot]
            if not neighbors:
                raise RuntimeError(
                    f"Cannot satisfy threshold={threshold} without dropping protected feature '{pick}'"
                )
            pick = row[neighbors].idxmax()

        # Drop the picked column/row from the working correlation matrix
        to_drop.append(pick)
        corr = corr.drop(index=pick, columns=pick)
        remaining.remove(pick)

    return remaining, to_drop

In [None]:
kept, dropped = drop_correlated_features(features_df["Features"], threshold=0.8)

# Keep only 'kept'. Drop by full (group, name) tuples so a same-named column
# outside the "Features" group can never be removed by accident.
features_df = features_df.drop(columns=[("Features", col) for col in dropped])
assert set(features_df["Features"].columns) == set(kept)

print(f"Dropped {len(dropped)} cols:\n\t", ", ".join(dropped))
features_df.head()

In [None]:
corr = features_df["Features"].corr()

# Signed correlation heatmap of the remaining features, on explicit axes
fig, axis = plt.subplots()
sns.heatmap(corr, vmin=-1, vmax=1, cmap="seismic", ax=axis)
axis.set_title("Feature correlations (after dropping correlated features)")

# |r| for each unordered pair of distinct features, counted once: take the
# strict upper triangle (k=1) instead of masking values equal to 1, which
# would double-count every pair and also hide genuinely perfect correlations.
upper_rows, upper_cols = np.triu_indices(len(corr), k=1)
c = np.abs(corr.to_numpy()[upper_rows, upper_cols])
# Skip undefined correlations (e.g. caused by missing data)
c = c[np.isfinite(c)]

fig, axis = plt.subplots()
axis.hist(c, bins=100)
axis.set_title(r"$\left|\mathrm{Correlations}\right|$")
axis.set_xlabel("Absolute correlation")
axis.set_ylabel("Number of feature pairs")

In [None]:
"""Z-normalise features"""

import numpy as np


def normalise(df):
    """
    Z-score every column in the top-level "Features" column group.

    Each feature is centred on its mean and divided by its population
    standard deviation (ddof=0). A zero standard deviation is replaced by NaN
    before dividing, so constant features become NaN rather than triggering a
    division by zero. Columns outside "Features" are left unchanged.

    :param df: DataFrame whose columns include a top-level "Features" group.
    :returns: a new DataFrame with the normalised features; ``df`` itself is
        not modified.
    """
    # Work on a copy so the caller's frame (and earlier cell outputs) stay intact
    out = df.copy()
    features = out["Features"]
    mu = features.mean()
    sigma = features.std(ddof=0)

    out["Features"] = (features - mu) / sigma.replace(0.0, np.nan)
    return out


# Z-score the features before the statistical tests / classifier
features_df = features_df.pipe(normalise)
features_df.head(n=5)

In [None]:
"""Mann-Whitney U"""

from scipy.stats import mannwhitneyu
from statsmodels.stats.multitest import multipletests

target_col = ("Metadata", "genotype")


def feature_tests(df, target_col, target_val):
    """
    Two-sided Mann-Whitney U test of every feature between two groups of fish.

    Rows whose ``df[target_col]`` equals ``target_val`` form the first group
    (e.g. "wt" for genotype); all other rows form the second group.
    Missing feature values are dropped per feature before testing.

    :param df: DataFrame with a top-level "Features" column group.
    :param target_col: column key holding the group label.
    :param target_val: label value defining the first group.
    :returns: DataFrame with one row per feature, sorted by "pval_adj":
        - "feature": feature name
        - "U": Mann-Whitney U statistic of the first group
        - "pval": two-sided p-value
        - "effect_size": AUC = U / (n1 * n2), the probability that a random
          first-group value exceeds a random second-group value (ties count
          half); 0.5 means no difference
        - "pval_adj": Benjamini-Hochberg FDR-adjusted p-value
        Features where either group has no values get NaN in all numeric columns.
    """
    features = df["Features"]
    in_group = (df[target_col] == target_val).to_numpy()

    records = []
    for col in features.columns:
        values = features[col]
        group1 = values[in_group].dropna()
        group2 = values[~in_group].dropna()
        if group1.empty or group2.empty:
            # Nothing to compare (e.g. the feature is all-NaN in this subset)
            records.append((col, np.nan, np.nan, np.nan))
            continue
        stat, p = mannwhitneyu(group1, group2, alternative="two-sided")
        # U counts (first, second) pairs where the first value is larger, so
        # dividing by the number of pairs gives the AUC
        auc = stat / (len(group1) * len(group2))
        records.append((col, stat, p, auc))

    results = pd.DataFrame(records, columns=["feature", "U", "pval", "effect_size"])

    # Adjust only the p-values that exist; NaNs would corrupt the BH ranking
    valid = results["pval"].notna()
    results["pval_adj"] = np.nan
    if valid.any():
        results.loc[valid, "pval_adj"] = multipletests(
            results.loc[valid, "pval"], method="fdr_bh"
        )[1]
    return results.sort_values("pval_adj")


mwu_results = feature_tests(features_df, target_col, "wt")
mwu_results.head(5)

In [None]:
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.model_selection import train_test_split


def rf_tests(df, target_col):
    """
    Feature importances from a gradient-boosted tree classifier.

    Despite the "rf" name, the model is a GradientBoostingClassifier. It is
    fit on a stratified 80/20 train/test split of ``df``. Train and test
    scores are printed, and the fitted model's ``feature_importances_`` are
    returned.

    :param df: DataFrame with a top-level "Features" column group.
    :param target_col: column key holding the class label to predict.
    :returns: single-column DataFrame of importances indexed by feature name,
        sorted in descending order.
    """
    X = df["Features"]
    y = df[target_col]
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, stratify=y, random_state=42
    )

    model = GradientBoostingClassifier(
        n_estimators=100,
        random_state=0,
        max_depth=2,
        learning_rate=0.1,
        min_samples_leaf=10,
        min_samples_split=10,
        subsample=0.9,
        criterion="friedman_mse",
        # sklearn rejects max_features larger than the number of features,
        # so cap it for small feature sets
        max_features=min(3, X.shape[1]),
    )

    model.fit(X_train, y_train)

    # Mean accuracy on both splits; a large gap between them suggests overfitting
    print("Train score: ", model.score(X_train, y_train))
    print("Test score: ", model.score(X_test, y_test))

    importances = pd.Series(model.feature_importances_, index=X_train.columns)

    # Convert to a DataFrame to make the notebook output nicer
    return pd.DataFrame(importances.sort_values(ascending=False))


rf_results = rf_tests(features_df, target_col)
rf_results.head(5)

In [None]:
"""
Plot 1d distributions of top features
"""

import textwrap


def plot_feature_dists(df, mwu_results, rf_results, target_col, n_top=5):
    """
    Pair-plot the top-ranked features from both selection methods, coloured by group.

    The first ``n_top`` rows of ``mwu_results`` ("feature" column) and of
    ``rf_results`` (index) are combined and de-duplicated, keeping first-seen
    order so the plot layout is the same on every run.

    :param df: DataFrame with a top-level "Features" column group.
    :param mwu_results: output of the Mann-Whitney step (has a "feature" column).
    :param rf_results: importances DataFrame indexed by feature name.
    :param target_col: column key used as the hue.
    :param n_top: how many features to take from each ranking.
    """
    # dict.fromkeys de-duplicates while preserving order; iterating a set of
    # strings gives an order that changes between Python processes
    features = list(
        dict.fromkeys(
            mwu_results.head(n_top)["feature"].tolist()
            + rf_results.head(n_top).index.tolist()
        )
    )

    plot_df = df[[("Features", f) for f in features] + [target_col]]

    pairgrid = sns.pairplot(
        plot_df,
        hue=target_col,
        plot_kws={"alpha": 0.5, "s": 5},
        diag_kind="kde",
        diag_kws=dict(common_norm=False),
    )

    def _short_label(var):
        # Show only the feature name, not the full (group, name) column key
        name = var[1] if isinstance(var, tuple) else str(var)
        return textwrap.fill(name, 20)

    # Relabel the outer axes directly from the grid's variables instead of
    # parsing the rendered label text
    for ax, var in zip(pairgrid.axes[-1, :], pairgrid.x_vars):
        if ax is not None:
            ax.set_xlabel(_short_label(var))
    for ax, var in zip(pairgrid.axes[:, 0], pairgrid.y_vars):
        if ax is not None:
            ax.set_ylabel(_short_label(var))


plot_feature_dists(features_df, mwu_results, rf_results, target_col=target_col)

Let's repeat the analysis above, but comparing age groups
====
For this we only want wildtype fish, so we keep those with a genotype of "wt" or "het" (heterozygotes are assumed to be close enough to wildtype for this comparison).

We compare young and older fish as our two age groups; the cut-off ages are defined in the cells below.

In [None]:
"""
Keep only wildtype-ish
"""
# Boolean mask of fish whose genotype is "wt" or "het"
is_wildtype_ish = features_df[("Metadata", "genotype")].isin(["wt", "het"])
age_df = features_df.loc[is_wildtype_ish].copy()
age_df.head(n=5)

In [None]:
"""
Split by age
"""

def age_grouper(age, young_max=12, old_min=30):
    """
    Map an age to an age group for the young-vs-old comparison.

    :param age: the age value (anything ``int()`` accepts). -1 or a missing
        value (NaN/None) is treated as unknown.
    :param young_max: ages <= this are the young group (0).
    :param old_min: ages >= this are the old group (1).
    :returns: 0 for young, 1 for old, NaN for unknown ages or ages between
        the two cut-offs.
    """
    # Missing ages cannot go through int() (int(nan) raises ValueError)
    if pd.isna(age):
        return np.nan

    age = int(age)
    # -1 is treated as an unknown age; checked before the cut-offs since -1 <= young_max
    if age == -1:
        return np.nan

    if age <= young_max:
        return 0
    if age >= old_min:
        return 1

    # Intermediate ages belong to neither group
    return np.nan


age_group_col = ("Metadata", "age_group")

# Derive the age group from the raw age; ages outside both groups map to NaN
age_df.loc[:, age_group_col] = age_df[("Metadata", "age")].map(age_grouper)

# Keep only fish that fall into one of the two age groups
age_df = age_df.dropna(subset=[age_group_col])
age_df[age_group_col].value_counts(dropna=False)

In [None]:
"""
Repeat the above feature selection steps, but for age groups
"""

target_col = ("Metadata", "age_group")

# Re-normalise within the age subset before testing
age_df = age_df.pipe(normalise)

# Group 0 (young) is the first group of the Mann-Whitney comparison
mwu_results = feature_tests(age_df, target_col, target_val=0)
rf_results = rf_tests(age_df, target_col)

display(mwu_results, rf_results)

plot_feature_dists(age_df, mwu_results, rf_results, target_col)