In [None]:
import pandas as pd

# Read the feature table from disk; the first CSV column becomes the row index.
features_df = pd.read_csv(filepath_or_buffer="features.csv", index_col=0)

# Preview the first few rows.
features_df.head()

In [None]:
"""
Add a column describing the mutation status (wt/het/hom/mosaic)
"""

from fishjaw.inference import feature_selection

# Rebind features_df to whatever feature_selection.add_metadata_cols returns.
# Later cells index the result as features_df["Features"] and
# features_df[("Metadata", "genotype")], so the helper presumably produces
# two-level columns -- TODO confirm against fishjaw.inference.feature_selection.
# NOTE(review): re-running this cell passes the already-processed frame back
# into the helper; whether that is safe is not visible from here.
features_df = feature_selection.add_metadata_cols(features_df)
features_df.head()

In [None]:
"""
Remove features with zero variance
"""

# Variance of every feature column; entries that are exactly zero mark constant features.
feature_variances = features_df["Features"].var()
is_constant = feature_variances == 0
null_variance_cols = features_df["Features"].columns[is_constant]

# Remove those names from the second column level, mutating features_df in place.
features_df.drop(columns=null_variance_cols, level=1, inplace=True)

print("Dropped:\n\t", ", ".join(null_variance_cols))
features_df.head()

In [None]:
"""
Plot correlations
"""

import seaborn as sns
import numpy as np
import matplotlib.pyplot as plt

# Pairwise correlation between all feature columns, shown as a heatmap.
corr = features_df["Features"].corr()
sns.heatmap(corr, vmin=-1, vmax=1, cmap="seismic")

# Histogram of |r| over distinct feature pairs.
# The strict upper triangle excludes the diagonal by position (rather than by
# testing for == 1.0, which would also discard genuinely perfectly correlated
# pairs) and counts every pair once instead of twice.
corr_values = corr.to_numpy()
c = np.abs(corr_values[np.triu_indices_from(corr_values, k=1)])

fig, axis = plt.subplots()
axis.hist(c, bins=100)
axis.set_title(r"$\left|\mathrm{Correlations}\right|$")

In [None]:
"""
Drop highly correlated features
"""

from typing import Iterable, Tuple, List, Optional


def drop_correlated_features(
    df: pd.DataFrame,
    threshold: float = 0.8,
    protected: Optional[Iterable[str]] = None,
    prefer: str = "lower_variance",  # or "higher_variance" or "mean_corr"
) -> Tuple[List[str], List[str]]:
    """
    Greedily drop columns until no remaining pair has |correlation| > threshold.

    Each round counts, for every remaining column, how many other columns it is
    correlated with above ``threshold`` (its degree), picks one column of
    maximal degree and drops it. This is a heuristic: the dropped set is small
    but not guaranteed to be minimal.

    Parameters
    ----------
    df
        Frame whose columns are the features; correlations come from
        ``df.corr()``. The frame itself is not modified.
    threshold
        Largest absolute correlation allowed between two kept columns; must lie
        in [0, 1].
    protected
        Column names that are never dropped. When the column chosen for removal
        is protected, its non-protected neighbour with the highest correlation
        to it is dropped instead.
    prefer
        Tie-breaker among the columns of maximal degree:
        ``"lower_variance"`` drops the one with the smallest variance,
        ``"higher_variance"`` the one with the largest variance,
        ``"mean_corr"`` the one with the largest mean absolute correlation.
        Any other value drops the first candidate in column order.

    Returns
    -------
    (remaining, to_drop)
        Names of the kept columns (in their original order) and names of the
        dropped columns (in the order they were removed).

    Raises
    ------
    ValueError
        If ``threshold`` is outside [0, 1].
    RuntimeError
        If a protected column exceeds the threshold only with other protected
        columns, so the constraint cannot be met.
    """
    if not 0 <= threshold <= 1:
        raise ValueError("threshold must be in [0, 1]")

    prot = set(protected or [])

    # Absolute correlation matrix.
    abs_corr = df.corr().abs()
    # Work on an explicit writable copy: with pandas Copy-on-Write enabled,
    # `.values` hands back a read-only array and filling it in place raises.
    corr_values = abs_corr.to_numpy(dtype=float, copy=True)
    # Remove self-correlation to simplify logic.
    np.fill_diagonal(corr_values, 0.0)
    # NaNs (e.g. from constant columns) are treated as "not correlated".
    corr = pd.DataFrame(
        corr_values, index=abs_corr.index, columns=abs_corr.columns
    ).fillna(0.0)

    # Column variances do not change while columns are being dropped, so they
    # are computed once instead of once per round.
    variances = None
    if prefer in ("lower_variance", "higher_variance"):
        variances = df.var(numeric_only=True)

    to_drop: List[str] = []
    remaining = corr.index.tolist()

    while True:
        # Edges above threshold.
        mask = corr > threshold
        if not mask.values.any():
            break

        # Degree = number of correlations above threshold.
        deg = mask.sum(axis=1)

        # Candidate nodes with max degree.
        cand = deg[deg == deg.max()].index.tolist()

        # Apply tie-breaker.
        if prefer == "lower_variance":
            pick = variances[cand].idxmin()
        elif prefer == "higher_variance":
            pick = variances[cand].idxmax()
        elif prefer == "mean_corr":
            pick = corr.loc[cand].mean(axis=1).idxmax()
        else:
            pick = cand[0]  # unknown preference: first candidate in column order

        if pick in prot:
            # A protected column cannot go; drop its most strongly correlated
            # non-protected neighbour instead.
            neighbors = [n for n in corr.columns[mask.loc[pick]] if n not in prot]
            if not neighbors:
                raise RuntimeError(
                    f"Cannot satisfy threshold={threshold} without dropping protected feature '{pick}'"
                )
            pick = corr.loc[pick, neighbors].idxmax()

        # Drop the picked column/row from the working correlation matrix.
        to_drop.append(pick)
        corr = corr.drop(index=pick, columns=pick)
        remaining.remove(pick)

    return remaining, to_drop

In [None]:
# Run the greedy correlation filter on the feature sub-frame only.
feature_block = features_df["Features"]
kept, dropped = drop_correlated_features(feature_block, threshold=0.8)

# Remove the rejected features (matched on the second column level) in place,
# which leaves exactly the columns listed in `kept` under "Features".
features_df.drop(columns=dropped, level=1, inplace=True)
n_dropped = len(dropped)
print(f"Dropped {n_dropped} cols:\n\t", ", ".join(dropped))

features_df.head()

In [None]:
# Correlation heatmap of the features that survived the correlation filter.
corr = features_df["Features"].corr()
sns.heatmap(corr, vmin=-1, vmax=1, cmap="seismic")

# Histogram of |r| over distinct feature pairs.
# The strict upper triangle excludes the diagonal by position (rather than by
# testing for == 1.0, which would also discard genuinely perfectly correlated
# pairs) and counts every pair once instead of twice.
corr_values = corr.to_numpy()
c = np.abs(corr_values[np.triu_indices_from(corr_values, k=1)])

fig, axis = plt.subplots()
axis.hist(c, bins=100)
axis.set_title(r"$\left|\mathrm{Correlations}\right|$")

In [None]:
import numpy as np

# z-score every feature column, using the population standard deviation (ddof=0).
X = features_df["Features"]
mu = X.mean()
sigma = X.std(ddof=0)

# A zero standard deviation becomes NaN, so a constant column turns into NaN
# instead of triggering a division by zero.
safe_sigma = sigma.where(sigma != 0.0)
features_df["Features"] = X.sub(mu, axis="columns").div(safe_sigma, axis="columns")
features_df.head()

In [None]:
"""Mann-Whitney U"""

from scipy.stats import mannwhitneyu
from statsmodels.stats.multitest import multipletests


def feature_tests(df):
    """
    Compare wild-type rows against all other rows, one feature at a time.

    For each column under ``df["Features"]`` a two-sided Mann-Whitney U test is
    run between the rows whose ``("Metadata", "genotype")`` equals ``"wt"`` and
    every other row. The p-values are then adjusted across features with the
    Benjamini-Hochberg procedure (``method="fdr_bh"``).

    Returns a DataFrame with one row per feature and the columns
    ``feature``, ``U``, ``pval``, ``effect_size`` and ``pval_adj``, sorted by
    ``pval_adj``. ``effect_size`` is the fraction of wild-type values that are
    strictly greater than the median of the other group; it is a quick proxy,
    not an AUC.
    """
    features = df["Features"]
    genotype = df[("Metadata", "genotype")]

    # The group masks are the same for every feature, so build them once.
    is_wt = genotype == "wt"
    is_other = genotype != "wt"

    rows = []
    for col in features.columns:
        wt_values = features.loc[is_wt, col]
        other_values = features.loc[is_other, col]
        stat, p = mannwhitneyu(wt_values, other_values, alternative="two-sided")

        # Take the median once per feature instead of once per wild-type value.
        other_median = other_values.median()
        frac_above_median = (wt_values > other_median).mean()

        rows.append((col, stat, p, frac_above_median))

    # Separate name for the output so the input frame is not shadowed.
    table = pd.DataFrame(rows, columns=["feature", "U", "pval", "effect_size"])
    table["pval_adj"] = multipletests(table.pval, method="fdr_bh")[1]
    return table.sort_values("pval_adj")


# One row per feature, ordered by adjusted p-value (smallest first).
results = feature_tests(df=features_df)
results

In [None]:
from sklearn.ensemble import RandomForestClassifier

# Inputs: the feature block as predictors, the genotype column as class labels.
feature_matrix = features_df["Features"]
genotype_labels = features_df[("Metadata", "genotype")]

# 500 trees with a fixed seed so the importances are reproducible.
rf = RandomForestClassifier(n_estimators=500, random_state=0)
rf.fit(feature_matrix, genotype_labels)

# Label each importance with its feature name and list them largest first.
importances = pd.Series(rf.feature_importances_, index=feature_matrix.columns)
print(importances.sort_values(ascending=False))


In [None]:
# Plot the per-genotype distributions of the top Mann-Whitney features,
# i.e. the features with the smallest adjusted p-value in `results`.
# The cell previously plotted every feature despite being described as
# "top three"; the selection below makes the code match that description.
N_TOP_FEATURES = 3
top_features = (
    results.sort_values("pval_adj")["feature"].head(N_TOP_FEATURES).tolist()
)

X = features_df["Features"][top_features]
y = features_df[("Metadata", "genotype")].rename("genotype")

# Long form for FacetGrid: one row per (sample, feature) pair.
df_long = X.assign(genotype=y).melt(
    id_vars="genotype", var_name="feature", value_name="value"
)

g = sns.FacetGrid(
    df_long,
    col="feature",
    col_wrap=4,
    height=3.0,
    sharex=False,
    sharey=False,
    hue="genotype",
)
# common_norm=False: each genotype's density integrates to one on its own.
g.map_dataframe(sns.kdeplot, x="value", common_norm=False)
g.add_legend()
g.set_titles("{col_name}")
plt.tight_layout()
plt.show()

In [None]:
# Scatter-matrix of the frame's numeric columns, coloured by the genotype column.
sns.pairplot(data=features_df, hue=("Metadata", "genotype"))

In [None]:
# --- Inputs ---
# `results`: Mann-Whitney table with columns 'feature' and 'pval_adj'
# `importances`: random-forest importances, a Series indexed by feature
mwu_df = results[["feature", "pval_adj"]].copy()
rf_importances = importances.copy()

# Merge MWU and RF on the feature name.
summary = mwu_df.set_index("feature").join(rf_importances.rename("rf_importance"))
# Only the RF column gets a 0 default. A missing adjusted p-value must stay
# missing: filling it with 0 would turn it into -log10(0) = inf and rank that
# feature as the most significant one.
summary["rf_importance"] = summary["rf_importance"].fillna(0)

# Compute -log10 FDR for MWU; clip so that an exact 0 gives a large finite bar.
smallest_positive = np.finfo(float).tiny
summary["minus_log10_fdr"] = -np.log10(summary["pval_adj"].clip(lower=smallest_positive))

# Sort features by MWU significance (missing values end up last).
summary = summary.sort_values("minus_log10_fdr", ascending=False)

# --- Plotting ---
fig, ax1 = plt.subplots(figsize=(10, 6))

# Bar plot: MWU significance.
ax1.bar(summary.index, summary["minus_log10_fdr"], color="skyblue", label="-log10(FDR)")
ax1.set_ylabel("-log10(FDR) Mann-Whitney U", color="blue")
ax1.tick_params(axis="y", labelcolor="blue")
# Restyle the tick labels that the categorical axis already created instead of
# overwriting them with set_xticklabels, which expects fixed tick positions.
plt.setp(ax1.get_xticklabels(), rotation=45, ha="right")

# Overlay: RF importance as a line on a second y axis.
ax2 = ax1.twinx()
ax2.plot(
    summary.index,
    summary["rf_importance"],
    color="red",
    marker="o",
    label="Random Forest importance",
)
ax2.set_ylabel("RF importance", color="red")
ax2.tick_params(axis="y", labelcolor="red")

# Add legends (one per axis).
ax1.legend(loc="upper right")
ax2.legend(loc="right")

ax1.set_title("Feature importance: Mann-Whitney U vs Random Forest")
fig.tight_layout()
plt.show()

In [None]:
"""
PCA and biplot to get an idea of what good descriptors might be
"""

from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA

# --- Standardize features ---
X = features_df["Features"].copy()
# Boolean labels: True for wild-type rows, False for every other genotype.
y = (features_df[("Metadata", "genotype")] == "wt").copy()
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)

# --- Run PCA ---
pca = PCA(n_components=2)  # first two PCs for biplot
X_pca = pca.fit_transform(X_scaled)

# --- Explained variance ---
explained_var = pca.explained_variance_ratio_
print(f"PC1 explains {explained_var[0]*100:.1f}% of variance")
print(f"PC2 explains {explained_var[1]*100:.1f}% of variance")

# --- Biplot ---
fig, ax = plt.subplots(figsize=(10, 8))

# One scatter call per group so the legend states which colour is which.
# The colours are the ones the plot already used: wild-type red, the rest blue.
is_wt = y.to_numpy(dtype=bool)
for group_mask, group_color, group_label in (
    (is_wt, "red", "wt"),
    (~is_wt, "blue", "non-wt"),
):
    ax.scatter(
        X_pca[group_mask, 0],
        X_pca[group_mask, 1],
        color=group_color,
        alpha=0.6,
        label=group_label,
    )
ax.set_xlabel(f"PC1 ({explained_var[0]*100:.1f}%)")
ax.set_ylabel(f"PC2 ({explained_var[1]*100:.1f}%)")

# Plot feature vectors: component weights scaled by the square root of each
# component's explained variance.
loadings = pca.components_.T * np.sqrt(pca.explained_variance_)
for i, feature in enumerate(X.columns):
    ax.arrow(0, 0, loadings[i, 0], loadings[i, 1], color="black", alpha=0.7)
    ax.text(loadings[i, 0] * 1.15, loadings[i, 1] * 1.15, feature, fontsize=9)

ax.set_title("PCA Biplot of Radiomics Features")
ax.grid(True)
ax.axhline(0, color="grey", linewidth=0.8)
ax.axvline(0, color="grey", linewidth=0.8)
ax.legend(title="genotype")
fig.tight_layout()
plt.show()

# --- Optional: Plot feature importance from PCA (absolute loadings) ---
feature_importance = pd.Series(
    np.abs(pca.components_[0]) + np.abs(pca.components_[1]), index=X.columns
)
feature_importance = feature_importance.sort_values(ascending=False)

fig_importance, ax_importance = plt.subplots(figsize=(10, 5))
feature_importance.plot(kind="bar", color="purple", ax=ax_importance)
ax_importance.set_ylabel("Sum of absolute loadings (PC1 + PC2)")
ax_importance.set_title("Feature contribution to first two principal components")
plt.setp(ax_importance.get_xticklabels(), rotation=45, ha="right")
fig_importance.tight_layout()
plt.show()

In [None]:
# Component weights as a labelled table: one row per feature, one column per PC.
loadings = pd.DataFrame(
    pca.components_.T,
    index=features_df["Features"].columns,
    columns=[f"PC{i+1}" for i in range(pca.n_components_)],
)

# --- Plot top features per PC ---
num_top_features = 10  # number of top contributing features to show per PC
num_pcs_to_plot = 2  # how many PCs to visualize

# Each panel shows a different set of features, so the x axes must NOT be
# shared: a shared categorical axis merges both feature lists into one set of
# tick positions and the bars end up under the wrong labels.
# squeeze=False keeps `axes` two-dimensional even when a single PC is plotted.
fig, axes = plt.subplots(
    num_pcs_to_plot, 1, figsize=(10, 3 * num_pcs_to_plot), squeeze=False
)

for ax, pc in zip(axes[:, 0], loadings.columns[:num_pcs_to_plot]):
    # Largest absolute weights first.
    sorted_loadings = loadings[pc].abs().sort_values(ascending=False)
    top_features = sorted_loadings.head(num_top_features)

    # Barplot; the categorical axis labels every bar with its own feature name.
    ax.bar(top_features.index, top_features.values, color="teal")
    ax.set_ylabel(f"{pc} |loading|")
    ax.set_title(f"Top {num_top_features} features contributing to {pc}")
    # Restyle the existing tick labels rather than replacing them positionally.
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right")

fig.tight_layout()
plt.show()