In [None]:
import scanpy as sc
import pandas as pd

In [None]:
# Import the mRMR feature-selection helper; if the import fails, install the
# `mrmr-selection` package and retry the import.
# NOTE(review): `!pip` is IPython shell syntax, so this cell only runs inside a
# notebook kernel. The install is unpinned and may target a different
# environment than the running kernel — consider a pinned `%pip install`.
try:
    from mrmr import mrmr_classif
except ImportError:
    !pip install mrmr-selection
    from mrmr import mrmr_classif

In [None]:
# Location of the AnnData (.h5ad) file, relative to the notebook's working directory.
adata_path = '../data/spatial_single_cell_KS_adata.h5ad'
# Load the AnnData object; later cells read `adata.X`, `adata.var_names` and `adata.obs`.
adata = sc.read_h5ad(adata_path)

In [None]:
# -----------------------------------------------------------------------------
# Construct combined cell-level metadata and gene expression matrix
# -----------------------------------------------------------------------------

# Convert the expression matrix from the AnnData object into a dense
# pandas DataFrame. `adata.X` is often a sparse matrix, which exposes
# `toarray()`; if it is already dense it is used as-is, so this cell works for
# both storage formats.
# Rows correspond to cells; columns correspond to genes (adata.var_names).
expression_matrix = adata.X.toarray() if hasattr(adata.X, "toarray") else adata.X
expression_data = pd.DataFrame(
    expression_matrix,
    columns=adata.var_names,
)

# Select key cell-level annotations (metadata) from `adata.obs`.
# Later cells slice the gene columns with `.iloc[:, 3:]`, so exactly these
# three annotation columns must come first in `combined_data`.
cell_annotation_columns = [
    "broad_cell_types",
    "niche_with_tumor_proximity",
    "path_block_core",
]
cell_info = adata.obs[cell_annotation_columns].copy()

# The two tables are joined purely by row position below, so they must contain
# the same number of cells; fail loudly instead of producing NaN-padded rows.
if len(cell_info) != len(expression_data):
    raise ValueError(
        f"Row mismatch: {len(cell_info)} annotated cells vs "
        f"{len(expression_data)} expression rows."
    )

# Concatenate cell-level metadata with the expression matrix so that each row
# contains both annotations and gene-level expression for a single cell.
# `reset_index(drop=True)` is used on both DataFrames to ensure aligned indices
# (row order is preserved, but original index labels are discarded).
combined_data = pd.concat(
    [
        cell_info.reset_index(drop=True),
        expression_data.reset_index(drop=True),
    ],
    axis=1,
)


In [None]:
# -----------------------------------------------------------------------------
# Compute per-core average gene expression features
# -----------------------------------------------------------------------------

import pandas as pd
import numpy as np
from tqdm.notebook import tqdm
from warnings import simplefilter

# Keep the notebook output readable by silencing pandas PerformanceWarning.
simplefilter(action="ignore", category=pd.errors.PerformanceWarning)

# Annotation columns pulled out of `combined_data` for convenience.
cell_types = combined_data["broad_cell_types"]
niches = combined_data["niche_with_tumor_proximity"]
cores = combined_data["path_block_core"]

# One feature dictionary per core; converted to a DataFrame after the loop.
av_expression = []

for core in tqdm(cores.unique(), desc="Computing per-core expression"):
    # Boolean selector for the cells that belong to this core.
    mask = combined_data["path_block_core"] == core

    result_expression = {"path_block_core": core}

    if np.sum(mask) > 0:
        # The first three columns are annotations; everything from column 3
        # onward is gene expression, averaged here over the core's cells.
        mean_expression = (
            combined_data.loc[mask]
            .iloc[:, 3:]
            .mean()
            .reset_index()
        )
        mean_expression.columns = ["gene", "mean_expression"]

        # One feature per gene, keyed by the gene name.
        result_expression.update(
            (f"{gene}", mean_expr)
            for gene, mean_expr in mean_expression.itertuples(index=False)
        )

    # Stage label of the core: the first unique "Stage" value among its cells
    # (this assumes every core carries a single stage).
    core_obs = adata.obs[adata.obs["path_block_core"] == core]
    stage = core_obs["Stage"].unique()[0]
    result_expression["Stage"] = stage

    av_expression.append(result_expression)

# One row per pathology block core; columns are "path_block_core", the
# per-gene mean expression values and "Stage".
avx_features_1 = pd.DataFrame(av_expression)


In [None]:
# -----------------------------------------------------------------------------
# Compute per-core, per-cell-type average gene expression features
# -----------------------------------------------------------------------------

import pandas as pd
import numpy as np
from tqdm.notebook import tqdm
from warnings import simplefilter

# Suppress pandas PerformanceWarning that can arise when operating on large
# DataFrames with mixed dtypes. This is purely to keep notebook output clean.
simplefilter(action="ignore", category=pd.errors.PerformanceWarning)

# Convenience references to key annotation columns.
cell_types = combined_data["broad_cell_types"]
cores = combined_data["path_block_core"]

# Columns from index 3 onward hold gene expression (the first three columns of
# `combined_data` are annotations).
gene_names = combined_data.columns[3:]

# Cell-type labels to featurize. Missing labels are skipped because a NaN
# label can never match the `==` comparison used to select cells.
unique_cell_types = cell_types.dropna().unique()

# Value recorded for a (core, cell_type) group that contains no cells.
# Every core therefore gets an explicit value for every feature instead of
# inheriting the means of whichever group happened to be processed before it.
# NOTE(review): 0.0 keeps the feature table NaN-free for downstream models;
# use np.nan instead if absence should be treated as missing data — confirm.
EMPTY_GROUP_FILL = 0.0

# Initialize a list of dicts; each dict will store features for one core.
# Later converted to a DataFrame (`avx_features_2`).
av_expression = []

# Iterate over each unique pathology block core.
for core in tqdm(cores.unique(), desc="Computing core-level expression"):
    # Dictionary to collect feature values for this core.
    result_expression = {"path_block_core": core}

    # Subset the cells of this core once instead of re-scanning all cells for
    # every cell type.
    core_cells = combined_data.loc[combined_data["path_block_core"] == core]

    for cell_type in unique_cell_types:
        # Cells belonging to this (core, cell_type) group.
        group_cells = core_cells.loc[core_cells["broad_cell_types"] == cell_type]

        if len(group_cells) > 0:
            # Mean gene expression across the cells in this group.
            group_means = group_cells.iloc[:, 3:].mean()
        else:
            # No cells of this type in this core: record the fill value for
            # every gene.
            group_means = pd.Series(EMPTY_GROUP_FILL, index=gene_names)

        # Attach mean expression values to the result dict, with a distinct
        # feature name per (gene, cell_type), e.g., "GENE1_T_cell".
        for gene, mean_expr in group_means.items():
            result_expression[f"{gene}_{cell_type}"] = mean_expr

    # Attach the core-level "Stage" label from `adata.obs`.
    # Assumes exactly one unique Stage per core.
    stage = adata.obs[adata.obs["path_block_core"] == core]["Stage"].unique()[0]
    result_expression["Stage"] = stage

    # Store results for this core.
    av_expression.append(result_expression)

# Final DataFrame: one row per core, columns:
# - "path_block_core"
# - "Stage"
# - per-(gene, cell_type) mean expression features.
avx_features_2 = pd.DataFrame(av_expression)


In [None]:
# -----------------------------------------------------------------------------
# Compute per-core, per-cell-type, per-niche average gene expression features
# -----------------------------------------------------------------------------

import pandas as pd
import numpy as np
from tqdm.notebook import tqdm
from warnings import simplefilter

# Suppress non-critical performance warnings from pandas
simplefilter(action="ignore", category=pd.errors.PerformanceWarning)

# Extract key annotation columns for convenience
cell_types = combined_data["broad_cell_types"]
niches = combined_data["niche_with_tumor_proximity"]
cores = combined_data["path_block_core"]

# Columns from index 3 onward hold gene expression (the first three columns of
# `combined_data` are annotations).
gene_names = combined_data.columns[3:]

# Labels to featurize. Missing labels are skipped because a NaN label can
# never match the `==` comparison used to select cells.
unique_cell_types = cell_types.dropna().unique()
unique_niches = niches.dropna().unique()

# Value recorded for a (core, cell_type, niche) subset that contains no cells.
# Every core therefore gets an explicit value for every feature instead of
# inheriting the means of whichever subset happened to be processed before it.
# NOTE(review): 0.0 keeps the feature table NaN-free for downstream models;
# use np.nan instead if absence should be treated as missing data — confirm.
EMPTY_GROUP_FILL = 0.0

# Container for per-core feature dictionaries
av_expression = []

# Loop over all pathology block cores
for core in tqdm(cores.unique(), desc="Computing celltype-niche-core averages"):
    # Initialize result dictionary for the current core
    result_expression = {"path_block_core": core}

    # Narrow the data step by step (core -> cell type -> niche) so each
    # comparison only scans the rows that can still match.
    core_cells = combined_data.loc[combined_data["path_block_core"] == core]

    for cell_type in unique_cell_types:
        type_cells = core_cells.loc[core_cells["broad_cell_types"] == cell_type]

        for niche in unique_niches:
            # Cells belonging to the current (core, cell_type, niche) subset
            subset_cells = type_cells.loc[
                type_cells["niche_with_tumor_proximity"] == niche
            ]

            if len(subset_cells) > 0:
                # Mean expression of every gene within this subset
                subset_means = subset_cells.iloc[:, 3:].mean()
            else:
                # No cells in this combination: record the fill value for
                # every gene.
                subset_means = pd.Series(EMPTY_GROUP_FILL, index=gene_names)

            # Record each gene's mean expression, namespaced by cell type and niche
            for gene, mean_expr in subset_means.items():
                result_expression[f"{gene}_{cell_type}_{niche}"] = mean_expr

    # Retrieve the stage label corresponding to this core
    # (first unique value; assumes one stage per core).
    stage = adata.obs[adata.obs["path_block_core"] == core]["Stage"].unique()[0]
    result_expression["Stage"] = stage

    # Append the dictionary to the results list
    av_expression.append(result_expression)

# Combine all core-level dictionaries into a single DataFrame.
# Each row corresponds to one pathology block core.
# Columns include:
# - "path_block_core"
# - "Stage"
# - All per-(gene, cell_type, niche) average expression features.
avx_features_3 = pd.DataFrame(av_expression)


In [None]:
# Bundle the three feature tables under the keys that `feat_set_mapping`
# translates into display names: per core, per core + cell type, and
# per core + cell type + niche.
features_dict = {
    'per_core': avx_features_1,
    'per_core_celltype': avx_features_2,
    'per_core_celltype_niche': avx_features_3
}


In [None]:
import pickle
# Load a feature dictionary from disk. This rebinds `features_dict`, replacing
# the dictionary assembled from the in-memory tables in the previous cell.
# NOTE(review): `pickle.load` can execute arbitrary code — only load files
# from a trusted source.
with open('../data/KS_features_per_core_celltype_niche.pkl', 'rb') as f:
    features_dict = pickle.load(f)

## Model training

In [None]:
# Display name used in figures for each feature-set key.
feat_set_mapping = dict(
    per_core="Bulk",
    per_core_celltype="Single-cell",
    per_core_celltype_niche="Spatial Single-cell",
)

# Display name for each classifier key.
model_name_mapping = dict(l1="LASSO", l2="Ridge", xgb="XGBoost")

# Integer class code for each stage label, in the order patch, plaque, nodular.
label_mapping = {
    stage_name: class_code
    for class_code, stage_name in enumerate(("patch", "plaque", "nodular"))
}

In [None]:
# -----------------------------------------------------------------------------
# Model training and evaluation with different feature sets and penalties
#
# This code evaluates three different feature sets (e.g., avx_features_1–3)
# using multiple classifiers (L1, L2 regularized linear models, and XGBoost)
# for multi-class prediction of pathological stage ("patch", "plaque", "nodular").
#
# The results are stored in:
#   - `compiled_results`: list of dictionaries summarizing CV performance
#   - `compiled_models`: dictionary mapping (feature_set, model_type) → model instance
#
# Each classifier is evaluated using 5-fold cross-validation, with performance
# measured via one-vs-rest ROC AUC (`roc_auc_ovr`). 
# -----------------------------------------------------------------------------

from sklearn.metrics import accuracy_score
from sklearn.model_selection import cross_val_score, StratifiedKFold
from sklearn.linear_model import SGDClassifier
import xgboost as xgb
from collections import defaultdict
import numpy as np

# Containers for results and model references
compiled_results = []
compiled_models = defaultdict(dict)

# Iterate over each feature set (e.g., per-core, per-cell-type, etc.)
for feat_set, features_df in features_dict.items():
    # Exclude control samples for supervised classification.
    # `reset_index` and `.loc[...]` both return new frames, so the table stored
    # in `features_dict` is never modified.
    staged_features = features_df.reset_index(drop=True)
    staged_features = staged_features.loc[staged_features["Stage"] != "control"]

    # Extract labels (pathological stage) and the feature matrix. The columns
    # are dropped without `inplace=True` so no filtered slice is mutated.
    labels = staged_features["Stage"].values
    cdata = staged_features.drop(columns=["Stage", "path_block_core"])
    y = np.array([label_mapping[label] for label in labels])

    # Evaluate three classifier variants: L1, L2, and XGBoost
    for classifier in ["l1", "l2", "xgb"]:
        if classifier == "xgb":
            # XGBoost model for multi-class softmax classification
            clf = xgb.XGBClassifier(
                objective="multi:softmax",
                max_depth=6,
                learning_rate=0.5,
                subsample=0.8,
                colsample_bytree=0.8,
                num_class=3,
                eval_metric="auc",
                seed=42,
            )
        else:
            # Logistic-loss linear classifier with an L1 or L2 penalty
            clf = SGDClassifier(
                loss="log_loss",          # Logistic regression objective
                penalty=classifier,       # Regularization type ("l1" or "l2")
                alpha=0.01,               # Regularization strength
                l1_ratio=0.5,             # Only relevant for penalty="elasticnet"
                max_iter=1000,            # Maximum optimization iterations
                tol=1e-3,                 # Convergence tolerance
                random_state=42,
            )

        # Perform 5-fold cross-validation using one-vs-rest ROC AUC.
        cv_scores = cross_val_score(clf, cdata, y, cv=5, scoring="roc_auc_ovr")

        # Record results for this (feature_set, model_type); folds are
        # numbered from 1 (cv_fold_1 ... cv_fold_5).
        results = {
            "feature_set": feat_set_mapping[feat_set],
            "model": model_name_mapping[classifier],
        }
        results.update(
            {f"cv_fold_{fold}": score for fold, score in enumerate(cv_scores, start=1)}
        )
        compiled_results.append(results)

        # NOTE(review): `cross_val_score` fits clones of the estimator, so the
        # object stored here is the configured but unfitted model — fit it
        # explicitly before using it for prediction.
        compiled_models[feat_set][classifier] = clf

In [None]:
# Tabulate the cross-validation records: one row per (feature set, model),
# with one column per CV fold score.
results_df = pd.DataFrame(compiled_results)

## Figure 8A

In [None]:
# -----------------------------------------------------------------------------
# Visualization: Grouped Bar Plot of Model Performance (ROC AUC by Feature Set)
# -----------------------------------------------------------------------------

import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

# Global plot styling: seaborn "ticks" theme for every figure drawn afterwards.
sns.set_style("ticks")
# Font type 42 (TrueType) for PDF output.
# NOTE(review): presumably chosen so text stays editable in vector editors — confirm.
plt.rcParams["pdf.fonttype"] = 42


def create_grouped_bar_plot(
    df,
    feat_set_col,
    model_col,
    score_cols,
    error_bars="sd",
    y_label="Score",
    palette=None,
    bar_colors=None,
    edge_colors=None,
    error_bar_props=None,
    ylim=(0.85, 1.002),
):
    """
    Draw a grouped bar plot with error bars: one group per model, one bar per
    feature set, bar height = mean of the score columns.

    Args:
        df (pd.DataFrame): Input DataFrame containing model evaluation results.
        feat_set_col (str): Column representing the feature set category (hue).
        model_col (str): Column representing the model category (x-axis).
        score_cols (sequence of str): Column names of the individual CV fold scores.
        error_bars (str, optional): Value forwarded to seaborn's ``errorbar``
            argument, e.g. 'sd', 'se', 'ci', or None. Default 'sd'.
        y_label (str, optional): Label for the y-axis. Default 'Score'.
        palette (str or dict, optional): Seaborn palette name, or a
            feature-set → color mapping. A dict takes precedence over
            ``bar_colors``; a name is only used when ``bar_colors`` is empty.
        bar_colors (dict, optional): Feature-set → bar color mapping.
        edge_colors (dict, optional): Feature-set → bar edge color mapping.
            Feature sets without an entry get the edge color "0.5".
        error_bar_props (dict, optional): Error-bar styling with the keys
            "color", "linewidth" and "capsize".
        ylim (tuple or None, optional): (bottom, top) limits of the y-axis;
            None keeps the automatic limits. Default (0.85, 1.002).

    Returns:
        module: ``matplotlib.pyplot``, so the caller can call ``savefig`` /
        ``show`` on the current figure.

    Raises:
        ValueError: If ``df`` is empty, a requested column is missing, or a
            feature set has no bar color.
    """
    score_cols = list(score_cols)

    # Input validation
    if df.empty:
        raise ValueError("Input DataFrame is empty.")
    if not all(col in df.columns for col in [feat_set_col, model_col] + score_cols):
        raise ValueError("One or more specified columns not found in DataFrame.")

    # Work on copies so the dictionaries passed in by the caller are never
    # mutated (default edge colors are filled in further down).
    bar_colors = dict(bar_colors or {})
    edge_colors = dict(edge_colors or {})
    error_bar_props = error_bar_props or {
        "color": "0.2",
        "linewidth": 1.5,
        "capsize": 0.1,
    }

    # Reshape DataFrame from wide (folds as columns) to long format for Seaborn
    df_melted = df.melt(
        id_vars=[feat_set_col, model_col],
        value_vars=score_cols,
        var_name="Score Type",
        value_name="Score",
    )

    # Feature sets in order of first appearance; this order drives the bars,
    # the edge colors and the legend.
    feature_sets = list(df[feat_set_col].unique())

    # Resolve the bar colors: an explicit palette mapping wins, then the given
    # bar_colors, and otherwise colors are drawn from the named (or current
    # default) palette.
    if isinstance(palette, dict):
        bar_colors = dict(palette)
    elif not bar_colors:
        palette_colors = sns.color_palette(palette, n_colors=len(feature_sets))
        bar_colors = dict(zip(feature_sets, palette_colors))

    missing_colors = [feat for feat in feature_sets if feat not in bar_colors]
    if missing_colors:
        raise ValueError(f"No bar color defined for feature set(s): {missing_colors}")

    # Assign default edge colors
    for feat in feature_sets:
        edge_colors.setdefault(feat, "0.5")

    # Configure error bar styling
    err_kws = {
        "color": error_bar_props.get("color", "0.2"),
        "linewidth": error_bar_props.get("linewidth", 1.5),
    }
    capsize = error_bar_props.get("capsize", 0.1)

    # Initialize the plot
    plt.figure(figsize=(6, 6), dpi=300)

    # Create grouped bar plot; `hue_order` pins the bar order to `feature_sets`
    # so the edge-color assignment below matches the right bars.
    ax = sns.barplot(
        x=model_col,
        y="Score",
        hue=feat_set_col,
        hue_order=feature_sets,
        data=df_melted,
        errorbar=error_bars,
        err_kws=err_kws,
        capsize=capsize,
        palette=bar_colors,
    )

    # Apply custom edge colors for visual clarity (one bar container per hue level)
    for i, bars in enumerate(ax.containers):
        feat = feature_sets[i % len(feature_sets)]
        for bar in bars:
            bar.set_edgecolor(edge_colors[feat])

    # Build the legend manually, one square marker per feature set, so handles
    # and labels always stay aligned whatever the feature sets are called.
    handles = [
        plt.Line2D(
            [0],
            [0],
            linewidth=1,
            linestyle="-",
            marker="s",
            markerfacecolor=bar_colors[feat],
            markeredgecolor=edge_colors[feat],
            markersize=10,
            color="none",
        )
        for feat in feature_sets
    ]
    ax.legend(
        handles,
        [str(feat) for feat in feature_sets],
        title="",
        loc="upper center",
        bbox_to_anchor=(0.5, 1.1),
        ncol=3,
        frameon=False,
        fontsize=12,
    )

    # Final plot formatting
    if ylim is not None:
        ax.set_ylim(*ylim)
    # Resize the tick labels without replacing them (re-setting tick labels
    # from their current text would freeze them to fixed strings).
    ax.tick_params(axis="both", labelsize=14)
    ax.set_xlabel("Model", fontsize=16)
    ax.set_ylabel(y_label, fontsize=16)
    plt.tight_layout()

    return plt


# -----------------------------------------------------------------------------
# Figure 8A: ROC AUC per model, grouped by feature set
# -----------------------------------------------------------------------------

# Feature-set display names, in the order used for both color tables.
_feature_set_names = ("Bulk", "Single-cell", "Spatial Single-cell")

# Fill colors: blue, red, green.
bar_colors = dict(zip(_feature_set_names, ("#3498db", "#e74c3c", "#2ecc71")))

# Edge colors: a darker shade of each fill color.
edge_colors = dict(zip(_feature_set_names, ("#3980b0", "#f03920", "#27ae60")))

# Error-bar styling passed through to the plotting helper.
error_bar_props = dict(color="black", linewidth=1.2, capsize=0.15)

# Draw the grouped bar plot from the five per-fold score columns
# (cv_fold_1 ... cv_fold_5).
plot = create_grouped_bar_plot(
    df=results_df,
    feat_set_col="feature_set",
    model_col="model",
    score_cols=[f"cv_fold_{fold}" for fold in range(1, 6)],
    error_bars="sd",
    y_label="ROC AUC Score",
    bar_colors=bar_colors,
    edge_colors=edge_colors,
    error_bar_props=error_bar_props,
)

# Write the figure to PDF, then render it in the notebook.
plot.savefig("fig_8a.pdf", bbox_inches="tight")
plot.show()


In [None]:
# Tabulate whatever `compiled_results` holds at this point in the notebook.
# NOTE(review): when cells run top to bottom this is still the Figure 8A
# cross-validation output, not the feature-count sweep computed further down.
compiled_results_df = pd.DataFrame(compiled_results)

## Figure 8B

In [None]:
# Feature table of the per-core / cell-type / niche feature set, without the
# control cores. Every step returns a new frame, so the table stored in
# `features_dict` is left untouched and no filtered slice is modified in place.
spatial_features = features_dict['per_core_celltype_niche'].reset_index(drop=True)
spatial_features = spatial_features.loc[spatial_features['Stage'] != 'control']

# Stage labels (strings) and their integer codes.
labels = spatial_features['Stage'].values
y = np.array([label_mapping[label] for label in labels])

# Feature matrix: everything except the label and the core identifier.
cdata = spatial_features.drop(columns=['Stage', 'path_block_core'])


In [None]:
from tqdm.notebook import tqdm
import time

# Feature set evaluated in this sweep; `cdata` is built from this key of
# `features_dict`. Naming it explicitly avoids relying on a loop variable
# left over from an earlier cell.
MRMR_FEATURE_SET = 'per_core_celltype_niche'

# One record per (number of selected features, classifier).
compiled_results = []
mrmr_feature_count = [10, 20, 30, 40, 50, 60, 70, 80, 90]
for k in tqdm(mrmr_feature_count):
    # Select the top-k features with mRMR and time the selection step.
    # NOTE(review): selection uses all samples before cross-validation, so the
    # CV scores below may be optimistic.
    mrmr_start_time = time.perf_counter()
    selected_features = mrmr_classif(X=cdata, y=labels, K=k)
    x = cdata[selected_features]
    mrmr_time = time.perf_counter() - mrmr_start_time

    for classifier in ['l1', 'l2', 'xgb']:
        if classifier == 'xgb':
            # Multi-class XGBoost classifier
            clf = xgb.XGBClassifier(
                objective='multi:softmax',
                max_depth=6,
                learning_rate=0.5,
                subsample=0.8,
                colsample_bytree=0.8,
                num_class=3,
                eval_metric='auc',
                seed=42
            )
        else:
            # Logistic-loss linear classifier with an L1 or L2 penalty
            clf = SGDClassifier(
                loss='log_loss',       # Logistic loss for classification
                penalty=classifier,    # 'l1' or 'l2'
                alpha=0.01,            # Regularization strength
                l1_ratio=0.5,          # Only relevant for penalty='elasticnet'
                max_iter=1000,         # Maximum number of iterations
                tol=1e-3,              # Tolerance for stopping criteria
                random_state=42,
            )

        # 5-fold cross-validation scored with one-vs-rest ROC AUC.
        cv_scores = cross_val_score(clf, x, y, cv=5, scoring='roc_auc_ovr', n_jobs=-1)

        results = {
            'feature_set': feat_set_mapping[MRMR_FEATURE_SET],
            'model': model_name_mapping[classifier],
            'num_features': k,
            'mrmr_time': mrmr_time,
        }
        # Folds are numbered from 0 here (cv_fold_0 ... cv_fold_4).
        for i, value in enumerate(cv_scores):
            results[f'cv_fold_{i}'] = value
        compiled_results.append(results)

In [None]:
# -----------------------------------------------------------------------------
# Visualization: Model Performance vs. Number of Selected Features
#
# This script generates a line plot comparing ROC AUC scores across models
# (LASSO, Ridge, XGBoost) as a function of the number of selected features.
# Each model’s mean cross-validation performance is plotted with markers,
# allowing visual comparison of feature selection effects (e.g., for MRMR feature selection).
# -----------------------------------------------------------------------------

import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd

# --- Plot styling ---
sns.set_style("ticks")
plt.rcParams["pdf.fonttype"] = 42

# Define consistent color scheme for models
line_colors = {
    "LASSO": "#ffba49",    # yellow-orange
    "Ridge": "#20a39e",    # teal
    "XGBoost": "#ef5b5b",  # red
}

# --- Prepare data for plotting ---
# Build the table from the current `compiled_results` (the feature-count
# sweep) right here, so the plot never uses a table created before the sweep.
compiled_results_df = pd.DataFrame(compiled_results)

missing_columns = {"num_features", "model"} - set(compiled_results_df.columns)
if missing_columns:
    raise KeyError(
        f"`compiled_results` lacks {sorted(missing_columns)}; "
        "run the feature-count sweep cell before plotting."
    )

# Per-fold score columns as recorded by the sweep (cv_fold_0 ... cv_fold_4).
fold_columns = [
    column for column in compiled_results_df.columns if column.startswith("cv_fold_")
]

# Melt the cross-validation results into long format for Seaborn
df_melted = compiled_results_df.melt(
    id_vars=["num_features", "model"],
    value_vars=fold_columns,
    var_name="Score Type",
    value_name="Score",
)

# --- Create figure and axis ---
fig, axes = plt.subplots(1, 1, figsize=(8, 6), dpi=300)

# --- Plot the mean performance across CV folds for each model ---
sns.lineplot(
    x="num_features",
    y="Score",
    hue="model",
    data=df_melted,
    marker="o",
    ax=axes,
    palette=line_colors,  # Custom colors per model
)

# --- Customize legend: one square marker per plotted model ---
# Handles are generated from the labels seaborn actually drew, so colors and
# names cannot get out of step.
_, labels = axes.get_legend_handles_labels()
handles = [
    plt.Line2D(
        [0], [0],
        linewidth=1, linestyle="-", marker="s",
        markerfacecolor=line_colors[label],
        markeredgecolor=line_colors[label],
        markersize=10, color="none",
    )
    for label in labels
]
axes.legend(
    handles,
    labels,
    title="",
    loc="upper center",
    bbox_to_anchor=(0.5, 1.1),
    ncol=3,
    frameon=False,
    fontsize=12,
)

# --- Axis formatting ---
axes.set_ylim(0.9, 1.0)
axes.set_xlim(10, 90)
axes.set_xlabel("Number of Features", fontsize=16)
axes.set_ylabel("ROC AUC Score", fontsize=16)
# Resize the tick labels without replacing them (re-setting tick labels from
# their current text would freeze them to fixed strings).
axes.tick_params(axis="both", labelsize=14)

fig.tight_layout()

# --- Save and display ---
fig.savefig("mrmr_feature_selection.pdf", dpi=300, bbox_inches="tight")
plt.show()


## Figure 8C

In [None]:
import pickle
# Load the feature table used for Figure 8C from disk.
# NOTE(review): later cells read its 'Stage' and 'path_block_core' columns, so
# it is presumably a per-core DataFrame — confirm the schema.
# NOTE(review): `pickle.load` can execute arbitrary code — only load files
# from a trusted source.
with open('../data/gene_features.pkl', 'rb') as f:
    features = pickle.load(f)

In [None]:
# Raw niche annotation -> display label. Labels containing the LaTeX macro
# `\phi` are written as raw strings so the backslash is not parsed as a
# (non-existent) escape sequence.
niche_mapping = {
    # SKIN
    "Skin 1": "Basal epidermis",
    "Skin 2": "Diff epidermis",
    # STROMA
    "Stroma": "Stroma",
    "TA VEC Stroma": "TA VEC stroma",

    # IMMUNE
    "Macrophage Immune Stroma": r"M$\phi$ stroma",
    "T-cell Immune Stroma": "T cell stroma",
    "Immune": "Immune",
    # TUMOR
    "Tumor Core": "Tumor core",
    "Tumor": "Tumor",
    "Tumor Boundary": "Tumor boundary"
}

# Display label of a niche -> hex color.
niche_colors = {
    # SKIN
    "Basal epidermis": "#b299e3",
    "Diff epidermis": "#ffd000",
    # STROMA
    "Stroma": "#646500",
    "TA VEC stroma": "#00e50c",
    "VEC stroma": "#cccc33",

    # IMMUNE
    r"M$\phi$ stroma": "#00dbf4",
    "T cell stroma": "#0051f9",
    "Immune": "#c100f9",
    # TUMOR
    "Tumor core": "#450000",
    "Tumor": "#eb0000",
    "Tumor boundary": "#faa0aa"
}

# Raw cell-type annotation -> abbreviated display label.
celltype_mapping = {
    'Lymphatic Endothelial Cells': 'LECs',
    'Macrophages': r'M$\phi$',
    'Vascular Endothelial Cells': 'VECs',
    'Pericytes': 'PCs',
    'Fibroblasts': 'Fbs',
    'T-cells': 'T cells',
    'Keratinocytes': 'KCs',
    'Dendritic cells': 'DCs',
    'Spinous to Granular Cells': 'StoGs',
    'Pilosebaceous Cells': 'PSCs',
    'B-cells': 'B cells',
    'Melanocytes': 'MCs'
}

# Abbreviated cell-type label -> hex color.
broad_cell_types_color_mapping = {
    'LECs': '#ffb695',
    r'M$\phi$': '#ff40ff',
    'VECs': '#a4e000',
    'PCs': '#9f7704',
    'Fbs': '#c7d0c0',
    'T cells': '#941100',
    'KCs': '#181c82',
    'DCs': '#ff9300',
    'StoGs': '#034cff',
    'PSCs': '#bbbde2',
    'B cells': '#f12d00',
    'MCs': '#00bbbf'
}

# Gene-set name -> member genes.
gsea_genes = {
    'LEC Identity': ['PROX1', 'CD34'],
    'Angiogenesis': ['COL6A2', 'ITGB1', 'CD93', 'VIM'],
    'Immune remodeling': ['IL2RA', 'FCER1G', 'CXCL12', 'ARHGDIB'],
    'Mesenchymal/EndMT': ['SFRP2', 'FSCN1', 'GNG11', 'CTNNB1'],
    'Skin-associated': ['DEFB1', 'CLDN1', 'ID4', 'ASPN']
}

# Gene-set name -> hex color.
gsea_colors = {
    'LEC Identity': '#f15bb5',
    'Angiogenesis': '#f7567c',
    'Immune remodeling': '#a9def9',
    'Mesenchymal/EndMT': '#ffd60a',
    'Skin-associated': '#06d6a0'
}

# Order in which niches are arranged (display labels).
niche_order = ['Tumor core', 'Tumor', 'Tumor boundary', 'TA VEC stroma', r'M$\phi$ stroma', 'Stroma', 'Basal epidermis', 'Diff epidermis']


# Gene -> color of the gene set it belongs to.
gene2color_mapping = {}
for gsea, gene_list in gsea_genes.items():
    for gene in gene_list:
        gene2color_mapping[gene] = gsea_colors[gsea]
 
 
def forward_transform(x, xmin, p):
    """Shift `x` down by `xmin` and raise the result to the power `p`."""
    shifted = x - xmin
    return shifted ** p

def inverse_transform(y, xmin, p):
    """Take the `p`-th root of `y` and shift the result up by `xmin`."""
    root = y ** (1/p)
    return root + xmin

# def sci_fomrat(n, decimals=2):
def sci_notation(x):
    """Return the label for a known value, or None if `x` matches none.

    `x` is compared with `np.isclose` against 1e-5, 1e-3, 1e-2, 1e-1 and 0.9
    (in that order); the first four map to a LaTeX power of ten and 0.9 maps
    to the plain label '1'.
    """
    known_labels = (
        (1e-5, r'$10^{-5}$'),
        (1e-3, r'$10^{-3}$'),
        (1e-2, r'$10^{-2}$'),
        (1e-1, r'$10^{-1}$'),
        (0.9, r'1'),
    )
    for reference, text in known_labels:
        if np.isclose(x, reference):
            return text
    return None
    
# The five values for which `sci_notation` returns a label.
vals_reqd = [1e-5, 1e-3, 1e-2, 0.1, 0.9]
# Disabled draft that would pass the log (base 5) of these values through
# `forward_transform`; `xmin` and `power_exp` are not defined in this cell.
# base = 5
# epsilon = 1e-5
# exps = np.log(vals_reqd)/np.log(base)
# ft = forward_transform(exps, xmin, power_exp)


In [None]:
# for feat_set in range(3):

# Drop the control rows from the loaded feature table. Every step returns a
# new frame, so `features` itself is left untouched and no filtered slice is
# modified in place.
stage_features = features.reset_index(drop=True)
stage_features = stage_features.loc[stage_features['Stage'] != 'control']

# Stage labels and the feature matrix (label and core identifier removed).
labels = stage_features['Stage'].values
cdata = stage_features.drop(columns=['Stage', 'path_block_core'])


In [None]:
from mrmr import mrmr_classif
# Select 80 features with mRMR against the stage labels; with
# `return_scores=True` the call also returns relevance and redundancy scores.
selected_features, relevance, redundancy = mrmr_classif(X=cdata, y=labels, K=80, return_scores=True)


In [None]:
from sklearn.preprocessing import LabelBinarizer
from mrmr import mrmr_classif
import pandas as pd

# One-vs-rest relevance scores: binarize the stage labels so every class gets
# its own 0/1 target column, then score the selected features against each one.
# `cdata` is the feature DataFrame and `labels` the 1D array of class labels.
lb = LabelBinarizer()
binary_labels = lb.fit_transform(labels)

class_names = [str(cls) for cls in lb.classes_]

results = []
# Restrict the feature matrix to the features chosen by the multi-class run.
x = cdata[selected_features]
for class_index, class_name in enumerate(class_names):
    y_bin = binary_labels[:, class_index]
    # NOTE: this rebinds `relevance` / `redundancy` to the scores of the
    # current class.
    selected, relevance, redundancy = mrmr_classif(X=x, y=y_bin, K=80, return_scores=True)
    # `rank` is the position at which mRMR returned the feature for this class.
    for rank, feat in enumerate(selected):
        results.append({
            'id': rank,
            'label': class_name,
            'selected_features': feat,
            'relevance_scores': relevance[feat],
        })


In [None]:
# Tabulate the per-class relevance records and persist them as CSV.
# NOTE: this rebinds `results_df`, which earlier held the cross-validation table.
results_df = pd.DataFrame(results)
results_df.to_csv('../data/mrmr_relevance_scores.csv', index=False)

In [None]:
# Reload the relevance scores from the CSV file, so the cells below can run
# from the saved file without recomputing mRMR.
results_df = pd.read_csv('../data/mrmr_relevance_scores.csv')

In [None]:
# Distinct selected feature names in order of first appearance, followed by
# the 'Stage' label column.
top_features = [*results_df['selected_features'].unique(), 'Stage']

# Restrict the loaded feature table to those columns.
mrmr_avx_features = features[top_features]

In [None]:
# Summarise every selected feature by its mean and standard deviation within
# each stage; one record per feature.
genex_top_features = []
for feat in top_features:
    # 'Stage' is the label column appended to `top_features`, not a feature.
    if feat == 'Stage':
        continue

    mean_values, std_values = {}, {}
    for stage in ['nodular', 'plaque', 'patch']:
        values = mrmr_avx_features.loc[mrmr_avx_features['Stage'] == stage, feat]
        mean_values[stage] = values.mean()
        std_values[stage] = values.std()

    # The feature name is split on '_' and its first three parts are read as
    # niche, cell type and gene.
    # NOTE(review): assumes names follow "<niche>_<cell type>_<gene>" with no
    # underscores inside the parts — confirm against how `features` was built.
    name_parts = feat.split('_')
    expr = {
        'Niche': name_parts[0],
        'Cell Type': name_parts[1],
        'Gene': name_parts[2],
        'Patch': mean_values['patch'],
        'Patch_std': std_values['patch'],
        'Plaque': mean_values['plaque'],
        'Plaque_std': std_values['plaque'],
        'Nodular': mean_values['nodular'],
        'Nodular_std': std_values['nodular'],
        'selected_features': feat,
    }
    genex_top_features.append(expr)
    

In [None]:
# One row per selected feature: niche, cell type, gene, and the per-stage
# mean / standard deviation columns.
gx_features_df = pd.DataFrame(genex_top_features)


In [None]:
# Translate raw niche / cell-type names into display labels; names without an
# entry in the mapping are kept unchanged (the `fillna` falls back to the
# original value).
gx_features_df['Niche'] = gx_features_df['Niche'].map(niche_mapping).fillna(gx_features_df['Niche'])
gx_features_df['Cell Type'] = gx_features_df['Cell Type'].map(celltype_mapping).fillna(gx_features_df['Cell Type'])


In [None]:
# Count rows per (niche, cell type, gene, per-stage mean) combination.
gene_df = gx_features_df.groupby(['Niche', 'Cell Type', 'Gene', 'Patch', 'Plaque', 'Nodular'], observed=True).size()

# Arrange the niche level according to `niche_order`, the same list that
# defines the categorical order below.
# NOTE(review): niches that are not listed in `niche_order` end up without
# data after this step — confirm that dropping them is intended.
gene_df = gene_df.reindex(niche_order, level='Niche').reset_index()
gene_df['Niche'] = pd.Categorical(gene_df['Niche'], categories=niche_order, ordered=True)

# Sort by niche (in `niche_order`), then by nodular mean, then by cell type.
gene_df = gene_df.sort_values(['Niche', 'Nodular', 'Cell Type'], ignore_index=True)

In [None]:
# Number of `gene_df` rows per niche, and per (niche, cell type) pair.
# `sort=False` keeps the groups in the order in which they occur in `gene_df`.
niche_df = gene_df.groupby(['Niche'], observed=True, sort=False).size().reset_index()
ct_df = gene_df.groupby(['Niche', 'Cell Type'], observed=True, sort=False).size().reset_index()


In [None]:
def rotate_labels(mypie, sizes, labels, radius, genes=False, ax=None):
    """Write one radially rotated text label per pie wedge.

    Each label is anchored on the wedge's angular midpoint at distance
    ``radius`` from the origin and rotated to follow the radius. Labels whose
    midpoint lies in (90°, 270°] are rotated by a further 180° and
    right-aligned so they do not appear upside down.

    Parameters
    ----------
    mypie : iterable of wedge-like objects
        Objects exposing ``theta1`` and ``theta2`` in degrees (for example the
        wedges returned by ``Axes.pie``).
    sizes : any
        Unused; kept so existing calls keep working.
    labels : iterable of str
        One label per wedge; pairing stops at the shorter of the two iterables.
    radius : float
        Distance from the origin at which labels are anchored.
    genes : bool, default False
        If True, use font size 7 instead of 8.
    ax : axes-like, optional
        Object whose ``text`` method is used for drawing. Defaults to each
        wedge's own ``axes`` attribute, so the function no longer depends on a
        global ``ax`` variable.
    """
    # Labels drawn in white (they sit on dark wedge colours).
    white_labels = ('KCs', 'StoGs', 'T cells', 'Stroma', 'Tumor', 'Tumor core')
    # Labels pushed further out from the centre than the requested radius.
    radial_nudge = {'Stroma': 0.25, 'Tumor core': 0.05}
    fontsize = 7 if genes else 8

    for wedge, label in zip(mypie, labels):
        target_ax = ax if ax is not None else wedge.axes

        theta = (wedge.theta2 + wedge.theta1) / 2.0
        radians = np.deg2rad(theta)
        r = radius + radial_nudge.get(label, 0.0)

        if 90 < theta <= 270:
            # Left half of the circle: flip so the text reads left-to-right.
            rotation = theta - 180
            alignment = 'right'
        else:
            # Right half; midpoints beyond 360° are wrapped back into range.
            rotation = theta - 360 if theta > 360 else theta
            alignment = 'left'

        color = 'white' if label in white_labels else 'black'
        target_ax.text(np.cos(radians) * r, np.sin(radians) * r, label,
                       rotation=rotation, ha=alignment, va='center',
                       color=color, rotation_mode='anchor', fontsize=fontsize)


In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import Normalize
from matplotlib.cm import ScalarMappable
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
from matplotlib.patches import Patch

# 1 · Load → pivot → split → map
# Font type 42 (TrueType) for text in saved PDFs.
plt.rcParams['pdf.fonttype'] = 42

# Work on a copy so the summary table built earlier is left untouched.
results = gx_features_df.copy()

# Make sure the column is really string dtype before using the .str accessor.
results["selected_features"] = results["selected_features"].astype("string")

# Re-derive the three name components from the full feature name. Named
# arguments avoid ambiguity across pandas versions; n=2 keeps any '_' inside
# the gene name.
results[["Niche", "Cell Type", "Gene"]] = (
    results["selected_features"]
          .str.split(pat="_", n=2, expand=True)
)

# Shorten labels; names without a mapping entry are kept as they are.
results["Niche"]     = results["Niche"].map(niche_mapping).fillna(results["Niche"])
results["Cell Type"] = results["Cell Type"].map(celltype_mapping).fillna(results["Cell Type"])
results["Niche"] = pd.Categorical(results["Niche"], categories=niche_order, ordered=True)

# Cell-type order *within* each niche → first-appearance order.
# The single-column sort on "Niche" must be stable: the default quicksort does
# not guarantee that rows sharing a niche keep the order established by the
# preceding sort on "selected_features", which would make the resulting
# cell-type order non-reproducible.
celltype_order = (
    results.sort_values("selected_features")              # reproducible baseline
           .drop_duplicates(["Niche", "Cell Type"])       # first row per (niche, CT)
           .sort_values(["Niche"], kind="mergesort")["Cell Type"]  # stable; may repeat CTs
           .tolist()
)
# A cell type can occur in several niches; keep only its first occurrence.
celltype_order = list(dict.fromkeys(celltype_order))

results["Cell Type"] = pd.Categorical(
    results["Cell Type"],
    categories=celltype_order,
    ordered=True
)


In [None]:


# 4 · **Key line** – sort genes inside each group by nodular_norm ↓
gene_df = results.sort_values(["Niche", "Cell Type", "Nodular"],
                              ascending=[True, True, False],
                              ignore_index=True)

# Derive all ring ingredients straight from gene_df
#     (sizes, labels, heat-map colours) – zero manual syncing
# ring-3 gene layer: one equal-width wedge per row of gene_df
gene_sizes  = np.ones(len(gene_df))
gene_labels = gene_df["Gene"]
gene_cols   = [gene2color_mapping.get(gene, '#edf2f4') for gene in gene_labels]

# ring-2 cell-type layer
# Sizes and labels are read from the SAME groupby result so they cannot drift
# apart (previously the labels came from a separate drop_duplicates call).
ct_counts = gene_df.groupby(["Niche", "Cell Type"], observed=True, sort=False).size()
ct_sizes  = ct_counts.to_numpy()
ct_labels = ct_counts.index.get_level_values("Cell Type").tolist()

# ring-1 niche layer
# Labels are the niches actually present in gene_df. Using `niche_order`
# directly would mislabel and miscolour wedges whenever a niche in that list
# has no selected feature, because observed=True drops it from the sizes.
niche_counts = gene_df.groupby("Niche", observed=True, sort=False).size()
niche_sizes  = niche_counts.to_numpy()
niche_labels = niche_counts.index.tolist()

expr_cmap = plt.get_cmap("coolwarm")

# Log-scale (base 5) the per-stage means; epsilon guards against log(0).
base = 5
epsilon = 1e-5
# `vals_reqd` is defined outside this chunk — presumably the expression values
# at which colour-bar ticks are wanted; confirm.
exps = np.log(vals_reqd)/np.log(base)

patch_values = np.array(gene_df['Patch'].tolist()) + epsilon
plaque_values = np.array(gene_df['Plaque'].tolist()) + epsilon
nodular_values = np.array(gene_df['Nodular'].tolist()) + epsilon

patch_log = np.log(patch_values) / np.log(base)
plaque_log = np.log(plaque_values) / np.log(base)
nodular_log = np.log(nodular_values) / np.log(base)

# One shared colour scale across the three stages.
all_log = np.concatenate([patch_log, plaque_log, nodular_log])
xmin = all_log.min()

# Shift to zero and raise to a power before normalising.
power_exp = 2
scaled_log = np.power(all_log - xmin, power_exp)
norm = Normalize(vmin=scaled_log.min(), vmax=scaled_log.max())

patch_cols = [expr_cmap(norm(np.power(v - xmin, power_exp))) for v in patch_log]
plaque_cols = [expr_cmap(norm(np.power(v - xmin, power_exp))) for v in plaque_log]
nodular_cols = [expr_cmap(norm(np.power(v - xmin, power_exp))) for v in nodular_log]

# Colour-bar ticks. forward_transform / inverse_transform / sci_notation are
# defined outside this chunk; presumably the first maps log values onto the
# scaled axis used by `norm` and the second undoes it — confirm.
ft = forward_transform(exps, xmin, power_exp)
norm_tick_positions = ft

# Convert tick positions back to log space, then undo the log-base transform
# to label the ticks in the original expression units.
log_tick_values = inverse_transform(norm_tick_positions, xmin, power_exp)
original_tick_values = base ** log_tick_values  # since log_base(x) = log(x)/log(base)
tick_labels = [sci_notation(v) for v in original_tick_values]

# ──────────────────────────────────────────────────────────────────────────
# 6 · P L O T
# ──────────────────────────────────────────────────────────────────────────
fig, ax = plt.subplots(figsize=(8, 8), dpi=300)
ax.axis('equal')                    # keep circles perfectly round
start_angle = 90                   # put first slice at noon

# ---------- Ring 1  ·  N I C H E S ---------------------------------------
radius1, width1 = 1.0, 0.70
wedges1, _ = ax.pie(
    niche_sizes,
    radius      = radius1,
    startangle  = start_angle,
    colors      = [niche_colors[n] for n in niche_labels],
)
plt.setp(wedges1, width=width1, edgecolor='white')
rotate_labels(wedges1, niche_sizes, niche_labels, radius=radius1 - 0.65)

# ---------- Ring 2  ·  C E L L T Y P E S ----------------------------------
radius2, width2 = radius1 + 0.30, 0.30
wedges2, _ = ax.pie(
    ct_sizes,
    radius      = radius2,
    startangle  = start_angle,
    colors      = [broad_cell_types_color_mapping[c] for c in ct_labels],
)
plt.setp(wedges2, width=width2, edgecolor='white')
rotate_labels(wedges2, ct_sizes, ct_labels, radius=radius2 - 0.25)

# ---------- Ring 3  ·  G E N E S (sorted by mean nodular expression) ------
radius3, width3 = radius2 + 0.45, 0.45
wedges3, _ = ax.pie(
    gene_sizes,
    radius      = radius3,
    startangle  = start_angle,
    colors      = gene_cols,
)
plt.setp(wedges3, width=width3, edgecolor='white')
rotate_labels(wedges3, gene_sizes, gene_labels,
              radius=radius3 - 0.40, genes=True)

# ---------- Rings 4–6  ·  H E A T  M A P  T R A C K S ---------------------
# Thin rings outside the gene ring: patch (inner), plaque, nodular (outer).
thin_width = 0.10
for radius, colours in [(radius3 + 0.10, patch_cols),
                        (radius3 + 0.20, plaque_cols),
                        (radius3 + 0.30, nodular_cols)]:
    ax.pie(
        gene_sizes,
        radius      = radius,
        startangle  = start_angle,
        colors      = colours,
        wedgeprops  = dict(edgecolor='white', width=thin_width),
    )

# ---------- Shared colour-bar for the three heat-map tracks ---------------
sm = ScalarMappable(norm=norm, cmap=expr_cmap)
axins = inset_axes(
    ax,
    width          = "2%",   # relative to parent
    height         = "30%",
    loc            = 'upper center',
    bbox_to_anchor = (0.8, -0.5, 1, 2),
    bbox_transform = ax.transAxes,
    borderpad      = 0,
)

# Ticks sit at norm_tick_positions but are labelled with original values.
cbar = fig.colorbar(sm, cax=axins, orientation='vertical')
cbar.set_ticks(norm_tick_positions)
cbar.set_ticklabels(tick_labels)
cbar.set_label('Gene Expression', fontsize=10, labelpad=-45, loc='center')
cbar.ax.tick_params(labelsize=8)

# ---------- Legend for the gene-ring colours ------------------------------
# One handle per entry of `gsea_colors` (label → colour; defined outside this
# chunk).
legend_patches = [Patch(color=col, label=lab) for lab, col in gsea_colors.items()]

# NOTE(review): plt.legend draws on the *current* axes, which after
# inset_axes() is presumably the colour-bar inset rather than `ax`; the
# bbox_to_anchor offsets look tuned for that, so the call is left unchanged.
# Verify placement before switching to ax.legend.
plt.legend(handles=legend_patches,
          loc='upper left',
          bbox_to_anchor=(.5, -1.5),
          frameon=False,         # No border
          labelspacing=0.3,      # Reduce space between labels
          handlelength=1,        # Shorter handles
          handletextpad=0.3,     # Less space between handle and text
          borderpad=0.4,         # Less padding inside legend box
         )
plt.tight_layout()
# plt.savefig(output_folder/'fig_7b_mrmr_sunburst.pdf', dpi=300, bbox_inches='tight')
# plt.show()
 

## Save XGBoost model

In [None]:
import pickle
from mrmr import mrmr_classif
# Load the cached per-core feature tables. The context manager closes the
# file handle, which the bare `open(...)` call never did.
# NOTE: unpickling executes arbitrary code — only load pickle files that were
# produced by this project.
with open('features_per_core_celltype_niche.pkl', 'rb') as features_file:
    features = pickle.load(features_file)

In [None]:
import numpy as np
# Take the third cached feature table, renumber its rows, and discard the
# control cores.
cdata = (
    features[2]
    .copy()
    .reset_index(drop=True)
    .loc[lambda frame: frame['Stage'] != 'control']
)

# Stage names are the classification target; pull them out before the label
# and core-identifier columns are removed from the design matrix.
labels = cdata['Stage'].values
cdata = cdata.drop(columns=['Stage', 'path_block_core'])

# Integer class codes for the three stages.
label_mapping = dict(patch=0, plaque=1, nodular=2)
y = np.array([label_mapping[stage_name] for stage_name in labels])



In [None]:
import xgboost as xgb
# mRMR feature selection against the stage names, requesting K=80 features.
selected_features = mrmr_classif(
    X=cdata,
    y=labels,
    K=80,
)

In [None]:
from sklearn.model_selection import cross_val_score
# Design matrix restricted to the mRMR-selected columns.
x = cdata.loc[:, selected_features]

# Three-class gradient-boosted classifier; the fixed seed is passed through
# to XGBoost.
clf = xgb.XGBClassifier(
    objective='multi:softmax',
    num_class=3,
    max_depth=6,
    learning_rate=0.5,
    subsample=0.8,
    colsample_bytree=0.8,
    eval_metric='auc',
    seed=42,
)

# Fit on every non-control core (no hold-out split in this cell), then write
# the fitted model to disk in XGBoost's JSON format.
clf.fit(x, y)
clf.save_model("xgb_stage_classifier.json")