# Preprocessing

In [1]:
import re
from typing import List, Dict

import numpy as np
import pandas as pd
import plotly.graph_objects as go
from plotly.subplots import make_subplots

from sklearn.impute import SimpleImputer
from sklearn.model_selection import train_test_split, KFold
from sklearn.preprocessing import RobustScaler, StandardScaler

from sksurv.linear_model import CoxPHSurvivalAnalysis
from sksurv.metrics import concordance_index_ipcw
from sksurv.util import Surv

from lifelines import CoxPHFitter, KaplanMeierFitter
from lifelines.utils import concordance_index

import optuna
from tqdm.notebook import tqdm
import xgboost as xgb

# Set plotly as pandas plotting backend
pd.options.plotting.backend = "plotly"

## 0. Load data

In [2]:
df = pd.read_csv("../../data/clinical_train.csv")
df_eval = pd.read_csv("../../data/clinical_val.csv")
maf_df = pd.read_csv("../../data/molecular_train.csv")
maf_eval = pd.read_csv("../../data/molecular_val.csv")
target_df = pd.read_csv("../../data/target_train.csv")

## 1. Missing values â€” Clinical / Molecular

In [3]:
# Remove IDs and center info for missingness analysis
df_noid = df.drop(columns=["ID", "CENTER"], errors="ignore")
maf_noid = maf_df.drop(columns=["ID"], errors="ignore")

# -----------------------
# 1.1 Missing values (clinical)
# -----------------------
missing_clinical = df_noid.isnull().sum()
missing_clinical_pct = (missing_clinical / len(df_noid) * 100).sort_values()

fig1 = go.Figure(
    go.Bar(
        x=missing_clinical_pct.values,
        y=missing_clinical_pct.index,
        orientation="h",
        text=[f"{v:.1f}%" for v in missing_clinical_pct.values],
        textposition="outside",
    )
)

fig1.update_layout(
    xaxis_title="Percentage of Missing Values (%)",
    yaxis_title="Features",
    margin=dict(l=120, r=40, t=10, b=40),
    showlegend=False,
)
fig1.update_xaxes(showgrid=False)
fig1.update_yaxes(showgrid=False)

fig1.show()

# -----------------------
# 1.2 Missing-value heatmap (clinical)
# -----------------------
missing_matrix = df_noid.isnull().astype(int)

sample_size = min(500, len(df_noid))
missing_sample = missing_matrix.sample(n=sample_size, random_state=42)

z = missing_sample.T.values
x = missing_sample.index.astype(str)
y = missing_sample.columns.astype(str)

fig2 = go.Figure(
    go.Heatmap(
        z=z,
        x=x,
        y=y,
        colorscale=[[0, "lightblue"], [1, "darkred"]],
        colorbar=dict(title="Missing"),
    )
)

fig2.update_layout(
    xaxis_title="Patients",
    yaxis_title="Features",
    margin=dict(l=120, r=40, t=10, b=40),
)
fig2.update_xaxes(showgrid=False)
fig2.update_yaxes(showgrid=False)

fig2.show()

# -----------------------
# 1.3 Missing values (molecular)
# -----------------------
missing_molecular = maf_noid.isnull().sum()
missing_molecular_pct = (missing_molecular / len(maf_noid) * 100).sort_values()

fig3 = go.Figure(
    go.Bar(
        x=missing_molecular_pct.values,
        y=missing_molecular_pct.index,
        orientation="h",
        text=[f"{v:.1f}%" for v in missing_molecular_pct.values],
        textposition="outside",
    )
)

fig3.update_layout(
    xaxis_title="Percentage of Missing Values (%)",
    yaxis_title="Features",
    margin=dict(l=120, r=40, t=10, b=40),
    showlegend=False,
)
fig3.update_xaxes(showgrid=False)
fig3.update_yaxes(showgrid=False)

fig3.show()

# -----------------------
# 1.4 Mutations per patient
# -----------------------
mutations_per_patient = maf_df.groupby("ID").size()

fig4 = go.Figure(
    go.Histogram(
        x=mutations_per_patient.values,
        nbinsx=50,
    )
)

fig4.update_traces(marker_line_width=1, marker_line_color="black")
fig4.update_layout(
    xaxis_title="Number of Mutations",
    yaxis_title="Number of Patients",
    margin=dict(l=60, r=40, t=10, b=40),
    showlegend=False,
)
fig4.update_xaxes(showgrid=False)
fig4.update_yaxes(showgrid=False)

fig4.show()

# -----------------------
# 1.5 Clinical vs Molecular comparison (side-by-side)
# -----------------------
fig5 = make_subplots(rows=1, cols=2)

fig5.add_trace(
    go.Bar(
        x=missing_clinical_pct.values,
        y=missing_clinical_pct.index,
        orientation="h",
        text=[f"{v:.1f}%" for v in missing_clinical_pct.values],
        textposition="outside",
        marker_color="salmon",
    ),
    row=1,
    col=1,
)

fig5.add_trace(
    go.Bar(
        x=missing_molecular_pct.values,
        y=missing_molecular_pct.index,
        orientation="h",
        text=[f"{v:.1f}%" for v in missing_molecular_pct.values],
        textposition="outside",
        marker_color="skyblue",
    ),
    row=1,
    col=2,
)

fig5.update_xaxes(title_text="Missing %", showgrid=False, row=1, col=1)
fig5.update_yaxes(title_text="Features", showgrid=False, row=1, col=1)

fig5.update_xaxes(title_text="Missing %", showgrid=False, row=1, col=2)
fig5.update_yaxes(showgrid=False, row=1, col=2)

fig5.update_layout(
    margin=dict(l=120, r=40, t=10, b=40),
    showlegend=False,
)

fig5.show()

# -----------------------
# 1.6 Summary statistics
# -----------------------
print("\n=== Summary Statistics ===")
print("\nClinical data (df, excluding ID & CENTER):")
print(f"  Total samples: {len(df_noid)}")
print(
    f"  Features with missing values: "
    f"{(missing_clinical > 0).sum()}/{len(missing_clinical)}"
)
print(f"  Average missing rate: {missing_clinical_pct.mean():.2f}%")

print("\nMolecular data (maf_df, excluding ID):")
print(f"  Total mutations (rows): {len(maf_df)}")
print(
    f"  Features with missing values: "
    f"{(missing_molecular > 0).sum()}/{len(missing_molecular)}"
)
print(f"  Average missing rate: {missing_molecular_pct.mean():.2f}%")


=== Summary Statistics ===

Clinical data (df, excluding ID & CENTER):
  Total samples: 3323
  Features with missing values: 7/7
  Average missing rate: 7.72%

Molecular data (maf_df, excluding ID):
  Total mutations (rows): 10935
  Features with missing values: 8/10
  Average missing rate: 0.72%


## 2. Survival target and base clinical features

In [4]:
# Reload molecular train/val (if needed later).
maf_df2 = pd.read_csv("../../data/molecular_train.csv")
maf_eval2 = pd.read_csv("../../data/molecular_val.csv")

# Survival target columns
target = ["OS_YEARS", "OS_STATUS"]

# Clean target dataframe
target_df["OS_YEARS"] = pd.to_numeric(target_df["OS_YEARS"], errors="coerce")
target_df["OS_STATUS"] = target_df["OS_STATUS"].astype(bool)
target_df = target_df.dropna(subset=target)

# Base clinical features
features = ["ID", "BM_BLAST", "WBC", "ANC", "MONOCYTES", "HB", "PLT", "CYTOGENETICS"]

X = df.loc[df["ID"].isin(target_df["ID"]), features].copy()
X_eval = df_eval.loc[:, features].copy()

# sksurv-style survival object (not used later but kept for compatibility)
y = Surv.from_dataframe("OS_STATUS", "OS_YEARS", target_df)

# 3. Impute Missing Values (XGBoost + Optuna per feature)

In [None]:
# Numerical columns
num_cols = X.select_dtypes("number").columns

# Add missingness indicators for numerical columns before imputation
for df_ in (X, X_eval):
    for col in num_cols:
        df_[f"{col}_missing"] = df_[col].isna().astype("int8")

X_num = X[num_cols].to_numpy(dtype=float)
n_rows, n_features = X_num.shape

base_xgb_params = dict(
    objective="reg:squarederror",
    tree_method="hist",
    n_jobs=-1,
    random_state=0,
)

n_trials = 1000
n_splits_cv = 5  
best_params_per_feature: Dict[str, Dict] = {}

from sklearn.model_selection import KFold


def make_objective_for_feature(j: int, n_splits: int = 5):
    """Optuna objective for a single numerical feature j with K-fold CV."""
    def objective(trial):
        xgb_params_trial = {
            "max_depth": trial.suggest_int("max_depth", 2, 10),
            "learning_rate": trial.suggest_float(
                "learning_rate", 0.01, 0.3, log=True
            ),
            "n_estimators": trial.suggest_int("n_estimators", 50, 500),
            "subsample": trial.suggest_float("subsample", 0.5, 1.0),
            "colsample_bytree": trial.suggest_float("colsample_bytree", 0.5, 1.0),
        }

        y_col = X_num[:, j]
        X_features = np.delete(X_num, j, axis=1)

        not_nan_y = ~np.isnan(y_col)
        n_obs = not_nan_y.sum()
        if n_obs < max(20, n_splits):
            # Not enough data to train a stable model / do CV
            return 1e6

        y_obs = y_col[not_nan_y]
        X_obs = X_features[not_nan_y]

        kf = KFold(n_splits=n_splits, shuffle=True, random_state=0)

        fold_mses = []
        for train_idx, val_idx in kf.split(X_obs):
            X_train, X_val = X_obs[train_idx], X_obs[val_idx]
            y_train, y_val = y_obs[train_idx], y_obs[val_idx]

            model = xgb.XGBRegressor(
                **base_xgb_params,
                **xgb_params_trial,
            )
            model.fit(X_train, y_train)

            y_pred = model.predict(X_val)
            mse = np.mean((y_pred - y_val) ** 2)
            fold_mses.append(mse)
        return float(np.mean(fold_mses))

    return objective


# Optuna tuning for each numerical feature
for j, col in enumerate(num_cols):
    y_col = X_num[:, j]
    if (~np.isnan(y_col)).sum() < max(20, n_splits_cv):
        # Too few non-missing values, skip this feature
        continue

    print(f"Tuning feature {j + 1}/{len(num_cols)}: {col}")
    study = optuna.create_study(direction="minimize")

    with tqdm(total=n_trials, desc=col) as pbar:
        def callback(study_, trial_):
            pbar.set_description(f"{col} | best MSE={study_.best_value:.5f}")
            pbar.update(1)

        study.optimize(
            make_objective_for_feature(j, n_splits=n_splits_cv),
            n_trials=n_trials,
            callbacks=[callback],
        )

    best_params_per_feature[col] = {**base_xgb_params, **study.best_params}

print("Number of tuned numerical features:", len(best_params_per_feature))


def fit_xgb_imputers_per_feature(
    X_df: pd.DataFrame,
    numeric_cols: pd.Index,
    best_params: Dict[str, Dict],
) -> Dict[str, xgb.XGBRegressor]:
    """Train one XGBoost regressor per feature for imputation."""
    models: Dict[str, xgb.XGBRegressor] = {}
    X_values = X_df[numeric_cols].to_numpy(dtype=float)

    for j, col in enumerate(numeric_cols):
        y_col = X_values[:, j]
        missing_mask = np.isnan(y_col)
        not_missing = ~missing_mask

        if not_missing.sum() < 20:
            continue
        if col not in best_params:
            continue

        X_features = np.delete(X_values, j, axis=1)
        X_train = X_features[not_missing]
        y_train = y_col[not_missing]

        params = best_params[col]
        model = xgb.XGBRegressor(**params)
        model.fit(X_train, y_train)

        models[col] = model

    return models


def transform_with_xgb_imputers(
    X_df: pd.DataFrame,
    numeric_cols: pd.Index,
    models: Dict[str, xgb.XGBRegressor],
) -> pd.DataFrame:
    """Apply trained imputers to fill missing values."""
    X_values = X_df[numeric_cols].to_numpy(dtype=float)

    for j, col in enumerate(numeric_cols):
        missing_mask = np.isnan(X_values[:, j])

        if missing_mask.any():
            if col in models:
                X_features = np.delete(X_values, j, axis=1)
                X_missing = X_features[missing_mask]
                preds = models[col].predict(X_missing)
                X_values[missing_mask, j] = preds
            else:
                X_values[missing_mask, j] = np.nanmedian(X_values[:, j])

    return pd.DataFrame(X_values, columns=numeric_cols, index=X_df.index)


# Final imputation
xgb_models = fit_xgb_imputers_per_feature(X, num_cols, best_params_per_feature)

X[num_cols] = transform_with_xgb_imputers(X, num_cols, xgb_models)
X_eval[num_cols] = transform_with_xgb_imputers(X_eval, num_cols, xgb_models)

# Save imputed clinical data
X.to_csv("../../data/clinical_train_imputed.csv", index=False)
X_eval.to_csv("../../data/clinical_val_imputed.csv", index=False)

[I 2025-12-30 04:46:58,042] A new study created in memory with name: no-name-bbdebc25-22b9-474d-a909-b8d0c8a3ceb5


Tuning feature 1/6: BM_BLAST


BM_BLAST:   0%|          | 0/1000 [00:00<?, ?it/s]

[I 2025-12-30 04:46:58,254] Trial 0 finished with value: 57.36320223756497 and parameters: {'max_depth': 8, 'learning_rate': 0.1827554027630706, 'n_estimators': 55, 'subsample': 0.5527458340228233, 'colsample_bytree': 0.6312440369238128}. Best is trial 0 with value: 57.36320223756497.
[I 2025-12-30 04:46:58,505] Trial 1 finished with value: 48.44497073508845 and parameters: {'max_depth': 4, 'learning_rate': 0.2533736758389521, 'n_estimators': 343, 'subsample': 0.6242523208529147, 'colsample_bytree': 0.892639545162077}. Best is trial 1 with value: 48.44497073508845.
[I 2025-12-30 04:46:59,095] Trial 2 finished with value: 48.917274886022696 and parameters: {'max_depth': 9, 'learning_rate': 0.0527140594773034, 'n_estimators': 295, 'subsample': 0.8637360215208223, 'colsample_bytree': 0.6103420854260369}. Best is trial 1 with value: 48.44497073508845.
[I 2025-12-30 04:46:59,231] Trial 3 finished with value: 53.18868821924963 and parameters: {'max_depth': 4, 'learning_rate': 0.0154319561158

Tuning feature 2/6: WBC


WBC:   0%|          | 0/1000 [00:00<?, ?it/s]

[I 2025-12-30 04:50:24,995] Trial 0 finished with value: 30.32627806452922 and parameters: {'max_depth': 3, 'learning_rate': 0.19503075478691867, 'n_estimators': 120, 'subsample': 0.8464501659273154, 'colsample_bytree': 0.5955387679665406}. Best is trial 0 with value: 30.32627806452922.
[I 2025-12-30 04:50:25,117] Trial 1 finished with value: 68.31350961307841 and parameters: {'max_depth': 2, 'learning_rate': 0.011694581497871854, 'n_estimators': 196, 'subsample': 0.6598686048461386, 'colsample_bytree': 0.9242119034938916}. Best is trial 0 with value: 30.32627806452922.
[I 2025-12-30 04:50:26,637] Trial 2 finished with value: 17.392473765072182 and parameters: {'max_depth': 10, 'learning_rate': 0.07963808372615992, 'n_estimators': 418, 'subsample': 0.7830824873792886, 'colsample_bytree': 0.9230975885777267}. Best is trial 2 with value: 17.392473765072182.
[I 2025-12-30 04:50:26,785] Trial 3 finished with value: 18.631621534635233 and parameters: {'max_depth': 2, 'learning_rate': 0.1039

Tuning feature 3/6: ANC


ANC:   0%|          | 0/1000 [00:00<?, ?it/s]

[I 2025-12-30 04:57:52,644] Trial 0 finished with value: 17.01578005641804 and parameters: {'max_depth': 9, 'learning_rate': 0.04710955651775424, 'n_estimators': 144, 'subsample': 0.7719243634711319, 'colsample_bytree': 0.6652108362889635}. Best is trial 0 with value: 17.01578005641804.
[I 2025-12-30 04:57:52,756] Trial 1 finished with value: 7.710915534360851 and parameters: {'max_depth': 3, 'learning_rate': 0.014071484293268125, 'n_estimators': 166, 'subsample': 0.6472980467551754, 'colsample_bytree': 0.8561874442690285}. Best is trial 1 with value: 7.710915534360851.
[I 2025-12-30 04:57:53,348] Trial 2 finished with value: 13.355661556012086 and parameters: {'max_depth': 6, 'learning_rate': 0.17164003754476734, 'n_estimators': 478, 'subsample': 0.9881260407909004, 'colsample_bytree': 0.5650699178270444}. Best is trial 1 with value: 7.710915534360851.
[I 2025-12-30 04:57:54,555] Trial 3 finished with value: 3.860182135417691 and parameters: {'max_depth': 6, 'learning_rate': 0.1964134

Tuning feature 4/6: MONOCYTES


MONOCYTES:   0%|          | 0/1000 [00:00<?, ?it/s]

[I 2025-12-30 05:03:16,713] Trial 0 finished with value: 1.4933409051066655 and parameters: {'max_depth': 7, 'learning_rate': 0.04179762277282501, 'n_estimators': 189, 'subsample': 0.8403138350868453, 'colsample_bytree': 0.8917253791201287}. Best is trial 0 with value: 1.4933409051066655.
[I 2025-12-30 05:03:17,051] Trial 1 finished with value: 3.094710458073573 and parameters: {'max_depth': 10, 'learning_rate': 0.11332724213610339, 'n_estimators': 135, 'subsample': 0.5756461516644489, 'colsample_bytree': 0.9026338011082914}. Best is trial 0 with value: 1.4933409051066655.
[I 2025-12-30 05:03:17,322] Trial 2 finished with value: 3.6647147576131127 and parameters: {'max_depth': 8, 'learning_rate': 0.1543330519465815, 'n_estimators': 142, 'subsample': 0.7877374287613454, 'colsample_bytree': 0.5415608085865411}. Best is trial 0 with value: 1.4933409051066655.
[I 2025-12-30 05:03:17,624] Trial 3 finished with value: 1.8741970604536466 and parameters: {'max_depth': 6, 'learning_rate': 0.109

Tuning feature 5/6: HB


HB:   0%|          | 0/1000 [00:00<?, ?it/s]

[I 2025-12-30 05:10:35,721] Trial 0 finished with value: 5.171620370249816 and parameters: {'max_depth': 10, 'learning_rate': 0.09484665916463811, 'n_estimators': 301, 'subsample': 0.817578101143078, 'colsample_bytree': 0.5034287697602233}. Best is trial 0 with value: 5.171620370249816.
[I 2025-12-30 05:10:35,834] Trial 1 finished with value: 3.7055750336921833 and parameters: {'max_depth': 3, 'learning_rate': 0.010401056811700803, 'n_estimators': 414, 'subsample': 0.9426420801732199, 'colsample_bytree': 0.8024008491172782}. Best is trial 1 with value: 3.7055750336921833.
[I 2025-12-30 05:10:35,979] Trial 2 finished with value: 4.295327378681397 and parameters: {'max_depth': 9, 'learning_rate': 0.06200484636601686, 'n_estimators': 120, 'subsample': 0.9746539932323544, 'colsample_bytree': 0.5289495906098465}. Best is trial 1 with value: 3.7055750336921833.
[I 2025-12-30 05:10:36,390] Trial 3 finished with value: 3.740864905183953 and parameters: {'max_depth': 6, 'learning_rate': 0.06861

Tuning feature 6/6: PLT


PLT:   0%|          | 0/1000 [00:00<?, ?it/s]

[I 2025-12-30 05:13:47,388] Trial 0 finished with value: 26622.16156293909 and parameters: {'max_depth': 4, 'learning_rate': 0.06757222746183804, 'n_estimators': 403, 'subsample': 0.7632734690348582, 'colsample_bytree': 0.9226004825800465}. Best is trial 0 with value: 26622.16156293909.
[I 2025-12-30 05:13:49,289] Trial 1 finished with value: 31591.392479654274 and parameters: {'max_depth': 10, 'learning_rate': 0.12756092338623698, 'n_estimators': 476, 'subsample': 0.5649714452834955, 'colsample_bytree': 0.6535681428865586}. Best is trial 0 with value: 26622.16156293909.
[I 2025-12-30 05:13:49,528] Trial 2 finished with value: 21105.562432700513 and parameters: {'max_depth': 9, 'learning_rate': 0.0243487022723984, 'n_estimators': 94, 'subsample': 0.6348049433551908, 'colsample_bytree': 0.6473816218214448}. Best is trial 2 with value: 21105.562432700513.
[I 2025-12-30 05:13:50,321] Trial 3 finished with value: 19573.238242564592 and parameters: {'max_depth': 8, 'learning_rate': 0.046848

Number of tuned numerical features: 6


## 4. Enhanced Mutation Features (global mutation burden / VAF / depth)

In [190]:
X = pd.read_csv("../../data/clinical_train_imputed.csv")
X_eval = pd.read_csv("../../data/clinical_val_imputed.csv")

# Ensure EFFECT dummies exist for molecular train/val
pd.get_dummies(maf_df, columns=["EFFECT"])
pd.get_dummies(maf_eval, columns=["EFFECT"])

Unnamed: 0,ID,CHR,START,END,REF,ALT,GENE,PROTEIN_CHANGE,VAF,DEPTH,EFFECT_ITD,EFFECT_PTD,EFFECT_frameshift_variant,EFFECT_inframe_codon_gain,EFFECT_inframe_codon_loss,EFFECT_non_synonymous_codon,EFFECT_stop_gained,EFFECT_stop_lost
0,KYW961,1,1747229.0,1747229.0,T,C,GNB1,p.K57E,0.2620,485.0,False,False,False,False,False,True,False,False
1,KYW142,1,1747229.0,1747229.0,T,C,GNB1,p.K57E,0.0280,527.0,False,False,False,False,False,True,False,False
2,KYW453,1,1747229.0,1747229.0,T,C,GNB1,p.K57E,0.2920,277.0,False,False,False,False,False,True,False,False
3,KYW982,1,1747229.0,1747229.0,T,C,GNB1,p.K57E,0.0970,821.0,False,False,False,False,False,True,False,False
4,KYW845,1,36932209.0,36932209.0,G,A,CSF3R,p.Q754X,0.4300,358.0,False,False,False,False,False,False,True,False
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
3084,KYW1077,,,,,,MLL,MLL_PTD,0.4231,,False,True,False,False,False,False,False,False
3085,KYW1084,,,,,,MLL,MLL_PTD,0.0176,,False,True,False,False,False,False,False,False
3086,KYW1082,,,,,,MLL,MLL_PTD,0.2273,,False,True,False,False,False,False,False,False
3087,KYW1085,,,,,,MLL,MLL_PTD,0.2941,,False,True,False,False,False,False,False,False


In [191]:
def compute_mutation_features(
    maf_df: pd.DataFrame,
    X_df: pd.DataFrame,
    top_k_chr: int = 10,
) -> pd.DataFrame:
    """Compute aggregated mutation-level features per patient."""
    maf_df = maf_df.copy()

    if "CHR" not in maf_df.columns:
        raise ValueError(
            "Column 'CHR' is missing from maf_df. "
            "Do not one-hot encode CHR before calling compute_mutation_features."
        )

    # Mutation length and deletion length
    maf_df["LEN"] = maf_df["END"] - maf_df["START"] + 1
    maf_df["DELLEN"] = maf_df["LEN"] - maf_df["REF"].apply(lambda x: len(str(x)))

    # Keep only the top-k chromosomes by frequency for per-chromosome counts
    top_chr = maf_df["CHR"].value_counts().nlargest(top_k_chr).index
    unique_chr = sorted(top_chr)

    # EFFECT dummy columns (if already created upstream)
    effect_dummy_cols = [c for c in maf_df.columns if c.startswith("EFFECT_")]

    agg_dict = {
        "Nmut": ("ID", "size"),
        "VAF_avg": ("VAF", "mean"),
        "VAF_std": ("VAF", "std"),
        "VAF_max": ("VAF", "max"),
        "LEN_avg": ("LEN", "mean"),
        "LEN_max": ("LEN", "max"),
        "DELLEN_sum": ("DELLEN", "sum"),
        "DEPTH_avg": ("DEPTH", "mean"),
        "DEPTH_std": ("DEPTH", "std"),
        "DEPTH_max": ("DEPTH", "max"),
        "DEPTH_min": ("DEPTH", "min"),
        "CHR_nunique": ("CHR", "nunique"),
        "EFFECT_nunique": ("EFFECT", "nunique"),
        "EFFECT_FV_count": ("EFFECT", lambda x: (x == "frameshift_variant").sum()),
        "EFFECT_SG_count": ("EFFECT", lambda x: (x == "stop_gained").sum()),
        "EFFECT_NS_count": ("EFFECT", lambda x: (x == "non_synonymous_codon").sum()),
    }

    # Per-chromosome counts for top_k_chr
    for ch in unique_chr:
        col_name = f"CHR_{ch}_count"
        agg_dict[col_name] = ("CHR", lambda x, val=ch: (x == val).sum())

    # Aggregate EFFECT_* dummy columns (if present)
    for col in effect_dummy_cols:
        new_name = f"{col}_count"
        agg_dict[new_name] = (col, "sum")

    tmp = maf_df.groupby("ID").agg(**agg_dict).reset_index()

    # Fill NaN standard deviations (single mutation per patient)
    for std_col in ["VAF_std", "DEPTH_std"]:
        if std_col in tmp.columns:
            tmp[std_col] = tmp[std_col].fillna(0)

    # Loss-of-function: EFFECT_LOF_count & EFFECT_LOF_ratio
    lof_effect_cols = [
        c
        for c in tmp.columns
        if c.startswith("EFFECT_")
        and c.endswith("_count")
        and (
            "frameshift_variant" in c
            or "stop_gained" in c
        )
    ]

    if len(lof_effect_cols) > 0:
        tmp["EFFECT_LOF_count"] = tmp[lof_effect_cols].sum(axis=1)
    else:
        tmp["EFFECT_LOF_count"] = 0

    tmp["EFFECT_LOF_ratio"] = np.where(
        tmp["Nmut"] > 0,
        tmp["EFFECT_LOF_count"] / tmp["Nmut"],
        0.0,
    )

    # Merge with clinical data
    X_w_mutation = X_df.merge(tmp, on="ID", how="left")

    new_cols = [c for c in tmp.columns if c != "ID"]
    X_w_mutation[new_cols] = X_w_mutation[new_cols].fillna(0)

    return X_w_mutation


# Apply on train / eval
X_w_mutation = compute_mutation_features(maf_df, X)
X_eval_w_mutation = compute_mutation_features(maf_eval, X_eval)

base_mutation_features = [
    "Nmut",
    "VAF_avg",
    "VAF_std",
    "VAF_max",
    "LEN_avg",
    "LEN_max",
    "DELLEN_sum",
    "DEPTH_avg",
    "DEPTH_std",
    "DEPTH_max",
    "DEPTH_min",
    "CHR_nunique",
    "EFFECT_nunique",
    "EFFECT_FV_count",
    "EFFECT_SG_count",
    "EFFECT_NS_count",
    "EFFECT_LOF_count",
    "EFFECT_LOF_ratio",
]

chr_count_cols = [
    c
    for c in X_w_mutation.columns
    if c.startswith("CHR_") and c.endswith("_count")
]

effect_count_cols = [
    c
    for c in X_w_mutation.columns
    if c.startswith("EFFECT_")
    and c.endswith("_count")
    and c
    not in [
        "EFFECT_FV_count",
        "EFFECT_SG_count",
        "EFFECT_NS_count",
        "EFFECT_LOF_count",
    ]
]

mutation_features = base_mutation_features + chr_count_cols + effect_count_cols

# Harmonize columns between train and eval
for col in mutation_features:
    if col not in X_w_mutation.columns:
        X_w_mutation[col] = 0
    if col not in X_eval_w_mutation.columns:
        X_eval_w_mutation[col] = 0

X_w_mutation = X_w_mutation.copy()
X_eval_w_mutation = X_eval_w_mutation.copy()
X_w_mutation[mutation_features] = X_w_mutation[mutation_features]
X_eval_w_mutation[mutation_features] = X_eval_w_mutation[mutation_features]

# Scale mutation features
mutation_scaler = RobustScaler()
print(
    f"Fitting RobustScaler for mutation features "
    f"({len(mutation_features)} features) on training data."
)
X_w_mutation[mutation_features] = mutation_scaler.fit_transform(
    X_w_mutation[mutation_features]
)
X_eval_w_mutation[mutation_features] = mutation_scaler.transform(
    X_eval_w_mutation[mutation_features]
)

Fitting RobustScaler for mutation features (28 features) on training data.


# 5. Cytogenetics Feature Extraction (ISCN parsing)

In [192]:
# Regex patterns for ISCN parsing
_ISCN_EVENT_RE = re.compile(r"(del|dup|inv|ins|i|t|add|der)\s*\(", re.IGNORECASE)
_MONOSOMY_RE = re.compile(r"(?<![pq])-(\d{1,2}|X|Y)(?![pq])", re.IGNORECASE)
_TRISOMY_RE = re.compile(r"(?<![pq])\+(\d{1,2}|X|Y)(?![pq])", re.IGNORECASE)
_CHR_NUM_RE = re.compile(r"(?<![pq])(\d{1,2}|X|Y)(?![pq])", re.IGNORECASE)

_MINUS5_OR_DEL5Q_RE = re.compile(
    r"-(?:5)(?![pq])|del\s*\(\s*5\s*\)\s*\(\s*q", re.IGNORECASE
)
_MINUS7_OR_DEL7Q_RE = re.compile(
    r"-(?:7)(?![pq])|del\s*\(\s*7\s*\)\s*\(\s*q", re.IGNORECASE
)
_PLUS8_RE = re.compile(r"\+8(?![pq])", re.IGNORECASE)
_T_8_21_RE = re.compile(r"t\s*\(\s*8\s*;\s*21\s*\)", re.IGNORECASE)
_INV16_OR_T_16_16_RE = re.compile(
    r"(inv\s*\(\s*16\s*\)|t\s*\(\s*16\s*;\s*16\s*\))", re.IGNORECASE
)
_T_15_17_RE = re.compile(r"t\s*\(\s*15\s*;\s*17\s*\)", re.IGNORECASE)
_STRUCTURAL_RE = re.compile(r"(del|dup|inv|ins|i|t|add|der)\s*\(", re.IGNORECASE)

_INV3_OR_T3_3_RE = re.compile(
    r"(inv\s*\(\s*3\s*\)\s*\(q21q26\)|t\s*\(\s*3\s*;\s*3\s*\)\s*\(q21;q26\))",
    re.IGNORECASE,
)
_T_6_9_RE = re.compile(
    r"t\s*\(\s*6\s*;\s*9\s*\)\s*\(p23;q34\)",
    re.IGNORECASE,
)
_T_9_22_RE = re.compile(
    r"t\s*\(\s*9\s*;\s*22\s*\)\s*\(q34;q11\)",
    re.IGNORECASE,
)
_ABN_17P_RE = re.compile(
    r"(del\s*\(\s*17\s*\)\s*\(\s*p|del\s*\(\s*17p\s*\)|-17(?![pq])|add\s*\(\s*17\s*\)\s*\(\s*p)",
    re.IGNORECASE,
)

_BASELINE_CHR_RE = re.compile(r"^\s*(\d{2})\s*,", re.IGNORECASE)
_NORMAL_KARYO_RE = re.compile(
    r"^\s*46\s*,\s*(XX|XY)\s*(\[\d+\])?\s*$", re.IGNORECASE
)


def _split_clones(karyo: str) -> List[str]:
    """Split ISCN string into clones separated by '/'."""
    return [c.strip() for c in str(karyo).split("/") if c.strip()]


def _extract_metaphases(clone: str) -> int:
    """Extract number of metaphases [n] from a clone."""
    m = re.search(r"\[(\d+)\]", clone)
    return int(m.group(1)) if m else 0


def _count_events(clone: str) -> int:
    """
    Count events in a clone (structural + trisomies + autosomal monosomies, ignoring -Y).
    """
    n_struct = len(_ISCN_EVENT_RE.findall(clone))
    n_mono = len(_MONOSOMY_RE.findall(clone))
    n_tri = len(_TRISOMY_RE.findall(clone))
    n_mono_minusY = len(
        re.findall(r"(?<![pq])-(?:Y)(?![pq])", clone, flags=re.IGNORECASE)
    )
    return n_struct + n_tri + max(n_mono - n_mono_minusY, 0)


def _chromosomes_altered(clone: str) -> int:
    """Number of distinct altered chromosomes (autosomes + sex chromosomes, ignoring Y)."""
    nums = set()
    for m in _MONOSOMY_RE.finditer(clone):
        nums.add(m.group(1).upper())
    for m in _TRISOMY_RE.finditer(clone):
        nums.add(m.group(1).upper())
    for ev in re.finditer(
        r"(del|dup|inv|ins|i|t|add|der)\s*\(([^)]+)\)", clone, flags=re.IGNORECASE
    ):
        for x in re.split(r"[;,\s]+", ev.group(2)):
            if _CHR_NUM_RE.fullmatch(x.strip()):
                nums.add(x.strip().upper())
    nums.discard("Y")
    return len(nums)


def _has_structural(clone: str) -> bool:
    return bool(_STRUCTURAL_RE.search(clone))


def _autosomic_monosomies(clone: str) -> List[int]:
    """List autosomal monosomies in the clone."""
    return [
        int(m.group(1))
        for m in _MONOSOMY_RE.finditer(clone)
        if m.group(1).upper() not in ("X", "Y")
    ]


def _autosomic_trisomies(clone: str) -> List[int]:
    """List autosomal trisomies in the clone."""
    return [
        int(m.group(1))
        for m in _TRISOMY_RE.finditer(clone)
        if m.group(1).upper() not in ("X", "Y")
    ]


def _is_monosomal_karyotype(karyo: str) -> bool:
    """Monosomal karyotype: â‰¥ 2 autosomal monosomies or 1 autosomal monosomy + structural abnormality."""
    for c in _split_clones(karyo):
        autos = _autosomic_monosomies(c)
        has_struct = _has_structural(c)
        if len(autos) >= 2:
            return True
        if len(autos) >= 1 and has_struct:
            return True
    return False


def _is_complex_karyotype(karyo: str) -> bool:
    """Complex karyotype: â‰¥ 3 independent cytogenetic abnormalities."""
    for c in _split_clones(karyo):
        c_wo_minusY = re.sub(
            r"(?<![pq])-(?:Y)(?![pq])", "", c, flags=re.IGNORECASE
        )
        if _count_events(c_wo_minusY) >= 3:
            return True
    return False


def _extract_baseline_chr_count(karyo: str) -> int:
    """Extract baseline chromosome count at start of ISCN string (e.g. '46,XX,...')."""
    if not isinstance(karyo, str):
        return -1
    m = _BASELINE_CHR_RE.match(karyo)
    if not m:
        return -1
    try:
        return int(m.group(1))
    except ValueError:
        return -1


def _clone_flags(clone: str) -> Dict[str, bool]:
    """Compute per-clone flags for prognostic patterns."""
    return {
        "minus5_or_del5q": bool(_MINUS5_OR_DEL5Q_RE.search(clone)),
        "minus7_or_del7q": bool(_MINUS7_OR_DEL7Q_RE.search(clone)),
        "plus8": bool(_PLUS8_RE.search(clone)),
        "t_8_21": bool(_T_8_21_RE.search(clone)),
        "inv16_or_t_16_16": bool(_INV16_OR_T_16_16_RE.search(clone)),
        "t_15_17": bool(_T_15_17_RE.search(clone)),
        "inv3_or_t3_3": bool(_INV3_OR_T3_3_RE.search(clone)),
        "t_6_9": bool(_T_6_9_RE.search(clone)),
        "t_9_22": bool(_T_9_22_RE.search(clone)),
        "abn17p": bool(_ABN_17P_RE.search(clone)),
        "has_structural": _has_structural(clone),
        "events_count": _count_events(clone),
        "chrs_altered": _chromosomes_altered(clone),
        "has_any_abn": bool(
            _ISCN_EVENT_RE.search(clone)
            or _MONOSOMY_RE.search(clone)
            or _TRISOMY_RE.search(clone)
        ),
        "n_monosomies": len(_MONOSOMY_RE.findall(clone)),
        "n_trisomies": len(_TRISOMY_RE.findall(clone)),
        "n_structural_events": len(_ISCN_EVENT_RE.findall(clone)),
    }


def add_cytogenetics_features(
    df_input: pd.DataFrame,
    col: str = "CYTOGENETICS",
) -> pd.DataFrame:
    """Add cytogenetic features derived from ISCN karyotypes."""
    rows = []

    for k in df_input[col]:
        # Handle missing / failed cytogenetics
        if not isinstance(k, str) or not k.strip() or k.strip().lower() in {
            "nan",
            "na",
            "nd",
            "notdone",
            "failed",
            "failure",
        }:
            rows.append(
                {
                    "is_cyto_missing_or_failed": 1,
                    "is_normal_karyotype": 0,
                    "is_abnormal_karyotype": 0,
                    "has_any_abnormality": 0,
                    "n_events": 0,
                    "n_chromosomes_altered": 0,
                    "n_monosomies_total": 0,
                    "n_trisomies_total": 0,
                    "n_structural_events_total": 0,
                    "has_minus5_or_del5q": 0,
                    "has_minus7_or_del7q": 0,
                    "has_plus8": 0,
                    "has_t_8_21": 0,
                    "has_inv16_or_t_16_16": 0,
                    "has_t_15_17": 0,
                    "has_inv3_or_t3_3": 0,
                    "has_t_6_9": 0,
                    "has_t_9_22": 0,
                    "has_abn17p": 0,
                    "is_monosomal_karyotype": 0,
                    "is_complex_karyotype": 0,
                    "eln_like_flag_adverse_cyto": 0,
                    "eln_like_flag_favorable_cyto": 0,
                    "eln_like_flag_intermediate_cyto": 0,
                    "eln_like_risk_cyto": -1,
                    "baseline_chr_count": -1,
                    "is_hypodiploid": 0,
                    "is_hyperdiploid": 0,
                    "is_near_tetraploid": 0,
                    "total_metaphases": 0,
                    "max_clone_size": 0.0,
                    "max_adverse_clone_size": 0.0,
                    "has_small_adverse_subclone": 0,
                    "prop_any_abnormal": 0.0,
                    "prop_adverse_5_7": 0.0,
                    "prop_plus8": 0.0,
                    "prop_favorable_core": 0.0,
                    "n_autosomal_monosomies": 0,
                    "n_autosomal_trisomies": 0,
                    "worst_clone_events": 0,
                    "worst_clone_is_adverse": 0,
                    "has_any_rare_adverse_cyto": 0,
                }
            )
            continue

        clones = _split_clones(k)
        clone_info = []
        total_meta_known = 0

        n_autosomal_mono_tot = 0
        n_autosomal_tri_tot = 0

        for c in clones:
            n_meta = _extract_metaphases(c)
            flags = _clone_flags(c)

            n_autosomal_mono_tot += len(_autosomic_monosomies(c))
            n_autosomal_tri_tot += len(_autosomic_trisomies(c))

            clone_info.append((c, n_meta, flags))
            total_meta_known += n_meta

        any_abn = any(f["has_any_abn"] for _, _, f in clone_info)
        n_events = sum(f["events_count"] for _, _, f in clone_info)
        n_chrs = max([f["chrs_altered"] for _, _, f in clone_info] + [0])

        n_mono_tot = sum(f["n_monosomies"] for _, _, f in clone_info)
        n_tris_tot = sum(f["n_trisomies"] for _, _, f in clone_info)
        n_struct_tot = sum(f["n_structural_events"] for _, _, f in clone_info)

        has_minus5_or_del5q = any(
            f["minus5_or_del5q"] for _, _, f in clone_info
        )
        has_minus7_or_del7q = any(
            f["minus7_or_del7q"] for _, _, f in clone_info
        )
        has_plus8 = any(f["plus8"] for _, _, f in clone_info)
        has_t_8_21 = any(f["t_8_21"] for _, _, f in clone_info)
        has_inv16_or_t_16_16 = any(
            f["inv16_or_t_16_16"] for _, _, f in clone_info
        )
        has_t_15_17 = any(f["t_15_17"] for _, _, f in clone_info)
        has_inv3_or_t3_3 = any(f["inv3_or_t3_3"] for _, _, f in clone_info)
        has_t_6_9 = any(f["t_6_9"] for _, _, f in clone_info)
        has_t_9_22 = any(f["t_9_22"] for _, _, f in clone_info)
        has_abn17p = any(f["abn17p"] for _, _, f in clone_info)

        is_mk = _is_monosomal_karyotype(k)
        is_ck = _is_complex_karyotype(k)

        baseline_chr = _extract_baseline_chr_count(k)
        is_hypo = int(baseline_chr != -1 and baseline_chr < 46)
        is_hyper = int(baseline_chr != -1 and 46 < baseline_chr < 50)
        is_near_tet = int(baseline_chr != -1 and baseline_chr >= 80)

        is_normal = int(bool(_NORMAL_KARYO_RE.match(k)))
        is_abnormal = int(not is_normal and any_abn)
        is_missing = 0

        eln_favorable = bool(
            has_t_8_21 or has_inv16_or_t_16_16 or has_t_15_17
        )
        eln_adverse_basic = bool(
            is_mk or is_ck or has_minus5_or_del5q or has_minus7_or_del7q
        )
        eln_adverse_extended = bool(
            eln_adverse_basic
            or has_inv3_or_t3_3
            or has_t_6_9
            or has_t_9_22
            or has_abn17p
        )

        if is_missing:
            eln_risk = -1
        else:
            if eln_adverse_extended:
                eln_risk = 2
            elif eln_favorable:
                eln_risk = 0
            else:
                eln_risk = 1

        has_any_rare_adverse_cyto = int(
            has_inv3_or_t3_3 or has_t_6_9 or has_t_9_22 or has_abn17p
        )

        # Worst clone by number of events
        worst_clone_events = 0
        worst_clone_is_adverse = 0
        if clone_info:
            max_events = -1
            worst_is_adverse = 0
            for _, _, f in clone_info:
                ev = f["events_count"]
                is_extended_adverse_clone = (
                    f["minus5_or_del5q"]
                    or f["minus7_or_del7q"]
                    or f["inv3_or_t3_3"]
                    or f["t_6_9"]
                    or f["t_9_22"]
                    or f["abn17p"]
                )
                if ev > max_events:
                    max_events = ev
                    worst_is_adverse = int(is_extended_adverse_clone)
            worst_clone_events = int(max_events if max_events >= 0 else 0)
            worst_clone_is_adverse = int(worst_is_adverse)

        def _prop(cond_fn):
            if total_meta_known == 0:
                return 0.0
            pos = sum(
                n_meta
                for _, n_meta, f in clone_info
                if n_meta and cond_fn(f)
            )
            return pos / total_meta_known if total_meta_known else 0.0

        prop_any_abnormal = float(_prop(lambda f: f["has_any_abn"]))
        prop_adverse_5_7 = float(
            _prop(lambda f: f["minus5_or_del5q"] or f["minus7_or_del7q"])
        )
        prop_plus8 = float(_prop(lambda f: f["plus8"]))
        prop_favorable_core = float(
            _prop(lambda f: f["t_8_21"] or f["inv16_or_t_16_16"])
        )

        max_clone_prop = 0.0
        max_adverse_prop = 0.0
        has_small_adverse_subclone = 0

        if total_meta_known > 0:
            for _, n_meta, f in clone_info:
                if not n_meta:
                    continue
                p = n_meta / total_meta_known
                if p > max_clone_prop:
                    max_clone_prop = p

                is_extended_adverse_clone = (
                    f["minus5_or_del5q"]
                    or f["minus7_or_del7q"]
                    or f["inv3_or_t3_3"]
                    or f["t_6_9"]
                    or f["t_9_22"]
                    or f["abn17p"]
                )
                if is_extended_adverse_clone:
                    if p > max_adverse_prop:
                        max_adverse_prop = p
                    if 0.0 < p < 0.3:
                        has_small_adverse_subclone = 1

        rows.append(
            {
                "is_cyto_missing_or_failed": int(is_missing),
                "is_normal_karyotype": int(is_normal),
                "is_abnormal_karyotype": int(is_abnormal),
                "has_any_abnormality": int(any_abn),
                "n_events": int(n_events),
                "n_chromosomes_altered": int(n_chrs),
                "n_monosomies_total": int(n_mono_tot),
                "n_trisomies_total": int(n_tris_tot),
                "n_structural_events_total": int(n_struct_tot),
                "has_minus5_or_del5q": int(has_minus5_or_del5q),
                "has_minus7_or_del7q": int(has_minus7_or_del7q),
                "has_plus8": int(has_plus8),
                "has_t_8_21": int(has_t_8_21),
                "has_inv16_or_t_16_16": int(has_inv16_or_t_16_16),
                "has_t_15_17": int(has_t_15_17),
                "has_inv3_or_t3_3": int(has_inv3_or_t3_3),
                "has_t_6_9": int(has_t_6_9),
                "has_t_9_22": int(has_t_9_22),
                "has_abn17p": int(has_abn17p),
                "is_monosomal_karyotype": int(is_mk),
                "is_complex_karyotype": int(is_ck),
                "eln_like_flag_adverse_cyto": int(eln_adverse_extended),
                "eln_like_flag_favorable_cyto": int(eln_favorable),
                "eln_like_flag_intermediate_cyto": int(eln_risk == 1),
                "eln_like_risk_cyto": int(eln_risk),
                "baseline_chr_count": int(baseline_chr),
                "is_hypodiploid": int(is_hypo),
                "is_hyperdiploid": int(is_hyper),
                "is_near_tetraploid": int(is_near_tet),
                "total_metaphases": int(total_meta_known),
                "max_clone_size": float(max_clone_prop),
                "max_adverse_clone_size": float(max_adverse_prop),
                "has_small_adverse_subclone": int(has_small_adverse_subclone),
                "prop_any_abnormal": float(prop_any_abnormal),
                "prop_adverse_5_7": float(prop_adverse_5_7),
                "prop_plus8": float(prop_plus8),
                "prop_favorable_core": float(prop_favorable_core),
                "n_autosomal_monosomies": int(n_autosomal_mono_tot),
                "n_autosomal_trisomies": int(n_autosomal_tri_tot),
                "worst_clone_events": int(worst_clone_events),
                "worst_clone_is_adverse": int(worst_clone_is_adverse),
                "has_any_rare_adverse_cyto": int(has_any_rare_adverse_cyto),
            }
        )

    features_df = pd.DataFrame(rows, index=df_input.index)
    return pd.concat([df_input.copy(), features_df], axis=1).drop(columns=[col])

In [193]:
# Enrich clinical data with cytogenetics features
X_enhanced = add_cytogenetics_features(X_w_mutation)
X_eval_enhanced = add_cytogenetics_features(X_eval_w_mutation)

cytogenetics_features = [
    "is_cyto_missing_or_failed",
    "is_normal_karyotype",
    "is_abnormal_karyotype",
    "has_any_abnormality",
    "n_events",
    "n_chromosomes_altered",
    "n_monosomies_total",
    "n_trisomies_total",
    "n_structural_events_total",
    "has_minus5_or_del5q",
    "has_minus7_or_del7q",
    "has_plus8",
    "has_t_8_21",
    "has_inv16_or_t_16_16",
    "has_t_15_17",
    "has_inv3_or_t3_3",
    "has_t_6_9",
    "has_t_9_22",
    "has_abn17p",
    "is_monosomal_karyotype",
    "is_complex_karyotype",
    "eln_like_flag_adverse_cyto",
    "eln_like_flag_favorable_cyto",
    "eln_like_flag_intermediate_cyto",
    "eln_like_risk_cyto",
    "baseline_chr_count",
    "is_hypodiploid",
    "is_hyperdiploid",
    "is_near_tetraploid",
    "total_metaphases",
    "max_clone_size",
    "max_adverse_clone_size",
    "has_small_adverse_subclone",
    "prop_any_abnormal",
    "prop_adverse_5_7",
    "prop_plus8",
    "prop_favorable_core",
    "n_autosomal_monosomies",
    "n_autosomal_trisomies",
    "worst_clone_events",
    "worst_clone_is_adverse",
    "has_any_rare_adverse_cyto",
]

# Remove near-constant features (>=95% identical) on training
nearly_constant_features = []
protected = {
    "has_minus5_or_del5q",
    "has_minus7_or_del7q",
    "has_abn17p",
    "has_inv3_or_t3_3",
    "has_t_6_9",
    "has_t_9_22",
    "is_monosomal_karyotype",
    "is_complex_karyotype",
    "has_t_15_17",
    "has_t_8_21",
    "has_inv16_or_t_16_16",
}

for col in cytogenetics_features:
    value_counts = X_enhanced[col].value_counts(dropna=False)
    if len(value_counts) > 0:
        max_proportion = value_counts.iloc[0] / len(X_enhanced)
        if max_proportion >= 0.95 and col not in protected:
            nearly_constant_features.append(col)
            print(
                f"Feature '{col}' is nearly constant "
                f"({max_proportion:.2%} of samples identical) -> dropped."
            )

X_enhanced = X_enhanced.drop(columns=nearly_constant_features)
X_eval_enhanced = X_eval_enhanced.drop(columns=nearly_constant_features)

cytogenetics_features = [
    f for f in cytogenetics_features if f not in nearly_constant_features
]

# Scale remaining cytogenetics features
cytogenetics_scaler = RobustScaler()
print(
    f"Fitting RobustScaler for {len(cytogenetics_features)} cytogenetics features "
    "on training data."
)
X_enhanced[cytogenetics_features] = cytogenetics_scaler.fit_transform(
    X_enhanced[cytogenetics_features]
)
X_eval_enhanced[cytogenetics_features] = cytogenetics_scaler.transform(
    X_eval_enhanced[cytogenetics_features]
)

Feature 'eln_like_flag_favorable_cyto' is nearly constant (99.97% of samples identical) -> dropped.
Feature 'is_near_tetraploid' is nearly constant (99.94% of samples identical) -> dropped.
Feature 'has_small_adverse_subclone' is nearly constant (96.56% of samples identical) -> dropped.
Feature 'prop_favorable_core' is nearly constant (100.00% of samples identical) -> dropped.
Feature 'has_any_rare_adverse_cyto' is nearly constant (97.64% of samples identical) -> dropped.
Fitting RobustScaler for 37 cytogenetics features on training data.


## 6. Merge with survival targets

In [194]:
df_enhanced = X_enhanced.merge(target_df, on="ID", how="left")
df_eval_enhanced = X_eval_enhanced

## 7. Gene-level Features (one-hot per gene)

In [195]:
def add_gene_features(
    df_clinical_enhanced: pd.DataFrame,
    df_molecular: pd.DataFrame,
    gene_list=None,
    top_k: int = 150,
):
    """
    Add one-hot features for genes.
    - If gene_list is None: use top_k most frequent genes in df_molecular
      (frequency = number of distinct patients with a mutation in the gene).
    - Otherwise: use the provided gene_list (for consistent train/val columns).
    """
    if gene_list is None:
        gene_counts = (
            df_molecular[["ID", "GENE"]]
            .drop_duplicates()["GENE"]
            .value_counts()
        )
        gene_list = gene_counts.nlargest(top_k).index.tolist()
        print(f"Number of genes used (top {top_k}): {len(gene_list)}")

    df_filtered = df_molecular[df_molecular["GENE"].isin(gene_list)].copy()

    gene_pivot = pd.crosstab(df_filtered["ID"], df_filtered["GENE"])

    for g in gene_list:
        if g not in gene_pivot.columns:
            gene_pivot[g] = 0

    gene_pivot = gene_pivot[gene_list]

    gene_pivot = (gene_pivot > 0).astype(int)
    gene_pivot.columns = [f"Gene_{col}" for col in gene_pivot.columns]

    df_final = df_clinical_enhanced.merge(gene_pivot, on="ID", how="left")

    new_cols = [c for c in df_final.columns if c.startswith("Gene_")]
    df_final[new_cols] = df_final[new_cols].fillna(0)

    return df_final, gene_list


mol_train_raw = pd.read_csv("../../data/molecular_train.csv")
mol_val_raw = pd.read_csv("../../data/molecular_val.csv")

df_train_pivot, gene_list_ref = add_gene_features(
    df_clinical_enhanced=df_enhanced,
    df_molecular=mol_train_raw,
    gene_list=None,
    top_k=150,
)

df_val_pivot, _ = add_gene_features(
    df_clinical_enhanced=df_eval_enhanced,
    df_molecular=mol_val_raw,
    gene_list=gene_list_ref,
)

# Harmonize columns between train and val
gene_cols = [c for c in df_train_pivot.columns if c.startswith("Gene_")]

train_clinical_cols = [
    c for c in df_train_pivot.columns if not c.startswith("Gene_")
]
val_clinical_cols = [
    c for c in df_val_pivot.columns if not c.startswith("Gene_")
]

train_clinical_cols_for_common = [
    c for c in train_clinical_cols if c not in ["OS_YEARS", "OS_STATUS"]
]

common_clinical = sorted(
    set(train_clinical_cols_for_common).intersection(val_clinical_cols)
)

final_features_train = common_clinical + gene_cols + ["OS_YEARS", "OS_STATUS"]
final_features_val = common_clinical + gene_cols

df_train_pivot = df_train_pivot[final_features_train]
df_val_pivot = df_val_pivot[final_features_val]

print(f"Final train columns (with targets): {len(final_features_train)}")
print(f"Final val columns (without targets): {len(final_features_val)}")
print(f"Train shape: {df_train_pivot.shape}")
print(f"Val shape:   {df_val_pivot.shape}")
print(f"'OS_YEARS' in train: {'OS_YEARS' in df_train_pivot.columns}")
print(f"'OS_STATUS' in train: {'OS_STATUS' in df_train_pivot.columns}")

Number of genes used (top 150): 124
Final train columns (with targets): 204
Final val columns (without targets): 202
Train shape: (3173, 204)
Val shape:   (1193, 202)
'OS_YEARS' in train: True
'OS_STATUS' in train: True


# 8. Molecular Risk Score (regularized Cox model on gene features)

In [196]:
gene_cols = [c for c in df_train_pivot.columns if c.startswith("Gene_")]

cox_cols_train = ["OS_YEARS", "OS_STATUS"] + gene_cols
cox_df_train = df_train_pivot[cox_cols_train].copy()
cox_df_train["OS_STATUS"] = cox_df_train["OS_STATUS"].astype(int)

gene_mat = cox_df_train[gene_cols]
prevalence = gene_mat.mean(axis=0)

min_prop = 0.005
max_prop = 0.99

valid_genes = prevalence[
    (prevalence >= min_prop) & (prevalence <= max_prop)
].index.tolist()
print(f"Initial number of genes: {len(gene_cols)}")
print(f"Number of genes kept: {len(valid_genes)}")

cox_df_train = cox_df_train[["OS_YEARS", "OS_STATUS"] + valid_genes].copy()

X_all = cox_df_train[valid_genes].to_numpy(dtype=float)
y_time_all = cox_df_train["OS_YEARS"].to_numpy(dtype=float)
y_event_all = cox_df_train["OS_STATUS"].to_numpy(dtype=int)

outer_splits = 5
inner_splits = 3
n_trials_nested = 30

outer_cv = KFold(n_splits=outer_splits, shuffle=True, random_state=42)
inner_cv = KFold(n_splits=inner_splits, shuffle=True, random_state=43)

outer_c_indices = []
best_params_per_outer_fold = []

for fold_idx, (train_outer_idx, val_outer_idx) in enumerate(
    outer_cv.split(X_all), start=1
):
    print(f"\n=== Outer fold {fold_idx}/{outer_splits} ===")

    X_train_outer = X_all[train_outer_idx]
    y_time_train_outer = y_time_all[train_outer_idx]
    y_event_train_outer = y_event_all[train_outer_idx]

    X_val_outer = X_all[val_outer_idx]
    y_time_val_outer = y_time_all[val_outer_idx]
    y_event_val_outer = y_event_all[val_outer_idx]

    def objective(trial):
        penalizer = trial.suggest_float("penalizer", 1e-4, 10.0, log=True)
        l1_ratio = trial.suggest_float("l1_ratio", 0.0, 1.0)

        inner_c_indices = []

        for train_inner_idx, val_inner_idx in inner_cv.split(X_train_outer):
            X_train_inner = X_train_outer[train_inner_idx]
            y_time_train_inner = y_time_train_outer[train_inner_idx]
            y_event_train_inner = y_event_train_outer[train_inner_idx]

            X_val_inner = X_train_outer[val_inner_idx]
            y_time_val_inner = y_time_train_outer[val_inner_idx]
            y_event_val_inner = y_event_train_outer[val_inner_idx]

            scaler_inner = StandardScaler()
            X_train_inner_scaled = scaler_inner.fit_transform(X_train_inner)
            X_val_inner_scaled = scaler_inner.transform(X_val_inner)

            train_df_inner = pd.DataFrame(
                X_train_inner_scaled, columns=valid_genes
            )
            train_df_inner["OS_YEARS"] = y_time_train_inner
            train_df_inner["OS_STATUS"] = y_event_train_inner

            val_df_inner = pd.DataFrame(
                X_val_inner_scaled, columns=valid_genes
            )
            val_df_inner["OS_YEARS"] = y_time_val_inner
            val_df_inner["OS_STATUS"] = y_event_val_inner

            cph_inner = CoxPHFitter(
                penalizer=penalizer,
                l1_ratio=l1_ratio,
            )
            cph_inner.fit(
                train_df_inner,
                duration_col="OS_YEARS",
                event_col="OS_STATUS",
            )

            risk_scores_inner = -cph_inner.predict_partial_hazard(
                val_df_inner[valid_genes]
            ).values.ravel()

            c_idx_inner = concordance_index(
                y_time_val_inner,
                risk_scores_inner,
                y_event_val_inner,
            )
            inner_c_indices.append(c_idx_inner)

        return float(np.mean(inner_c_indices))

    study = optuna.create_study(direction="maximize")
    study.optimize(objective, n_trials=n_trials_nested)

    print("  Best inner CV C-index:", study.best_value)
    print("  Best params:", study.best_params)
    best_params_per_outer_fold.append(study.best_params)

    scaler_outer = StandardScaler()
    X_train_outer_scaled = scaler_outer.fit_transform(X_train_outer)
    X_val_outer_scaled = scaler_outer.transform(X_val_outer)

    train_df_outer = pd.DataFrame(
        X_train_outer_scaled, columns=valid_genes
    )
    train_df_outer["OS_YEARS"] = y_time_train_outer
    train_df_outer["OS_STATUS"] = y_event_train_outer

    val_df_outer = pd.DataFrame(
        X_val_outer_scaled, columns=valid_genes
    )
    val_df_outer["OS_YEARS"] = y_time_val_outer
    val_df_outer["OS_STATUS"] = y_event_val_outer

    cph_outer = CoxPHFitter(
        penalizer=study.best_params["penalizer"],
        l1_ratio=study.best_params["l1_ratio"],
    )
    cph_outer.fit(
        train_df_outer,
        duration_col="OS_YEARS",
        event_col="OS_STATUS",
    )

    risk_scores_outer = -cph_outer.predict_partial_hazard(
        val_df_outer[valid_genes]
    ).values.ravel()

    c_idx_outer = concordance_index(
        y_time_val_outer,
        risk_scores_outer,
        y_event_val_outer,
    )
    outer_c_indices.append(c_idx_outer)
    print(f"  Outer fold {fold_idx} C-index (nested): {c_idx_outer:.4f}")

print("\n=== Nested CV results ===")
print("Mean outer C-index:", np.mean(outer_c_indices))
print("Std outer C-index: ", np.std(outer_c_indices))
print("\nBest params per outer fold:")
for i, p in enumerate(best_params_per_outer_fold, start=1):
    print(f"  Fold {i}: {p}")

Initial number of genes: 124
Number of genes kept: 68

=== Outer fold 1/5 ===


[I 2025-12-30 05:56:16,650] A new study created in memory with name: no-name-0f661d5c-38a2-491d-bbb8-3573d5d1a617
[I 2025-12-30 05:56:28,301] Trial 0 finished with value: 0.6824534845726165 and parameters: {'penalizer': 0.01536824200947512, 'l1_ratio': 0.7782742016902697}. Best is trial 0 with value: 0.6824534845726165.
[I 2025-12-30 05:56:39,340] Trial 1 finished with value: 0.686926795019057 and parameters: {'penalizer': 0.1999952052647711, 'l1_ratio': 0.09735397665289158}. Best is trial 1 with value: 0.686926795019057.
[I 2025-12-30 05:56:49,872] Trial 2 finished with value: 0.6795472313763972 and parameters: {'penalizer': 0.029401626649746306, 'l1_ratio': 0.23212781376410108}. Best is trial 1 with value: 0.686926795019057.
[I 2025-12-30 05:56:59,877] Trial 3 finished with value: 0.6760481501133166 and parameters: {'penalizer': 0.6270661567172766, 'l1_ratio': 0.6626157144277959}. Best is trial 1 with value: 0.686926795019057.
[I 2025-12-30 05:57:10,732] Trial 4 finished with value: 

  Best inner CV C-index: 0.6890456646282148
  Best params: {'penalizer': 0.2956103925012571, 'l1_ratio': 0.48743822455290725}


[I 2025-12-30 06:01:18,364] A new study created in memory with name: no-name-7797e356-9be5-4633-89ff-567a35b5b330


  Outer fold 1 C-index (nested): 0.6777

=== Outer fold 2/5 ===


[I 2025-12-30 06:01:29,478] Trial 0 finished with value: 0.6852488991935713 and parameters: {'penalizer': 0.10070010141163016, 'l1_ratio': 0.4303894438673933}. Best is trial 0 with value: 0.6852488991935713.
[I 2025-12-30 06:01:40,374] Trial 1 finished with value: 0.689129737061168 and parameters: {'penalizer': 0.0995782952526039, 'l1_ratio': 0.2880797512606337}. Best is trial 1 with value: 0.689129737061168.
[I 2025-12-30 06:01:44,934] Trial 2 finished with value: 0.6849896899727314 and parameters: {'penalizer': 0.0004110223660022562, 'l1_ratio': 0.09802460326427587}. Best is trial 1 with value: 0.689129737061168.
[I 2025-12-30 06:01:53,169] Trial 3 finished with value: 0.6772474365507598 and parameters: {'penalizer': 7.868275247539285, 'l1_ratio': 0.7735316567139623}. Best is trial 1 with value: 0.689129737061168.
[I 2025-12-30 06:02:02,853] Trial 4 finished with value: 0.6776528308852118 and parameters: {'penalizer': 8.02447280565816, 'l1_ratio': 0.07868024338111235}. Best is trial 

  Best inner CV C-index: 0.6900844638189702
  Best params: {'penalizer': 0.025898639224771734, 'l1_ratio': 0.9944291398683284}


[I 2025-12-30 06:06:15,051] A new study created in memory with name: no-name-e8eece94-39cc-4652-9155-aaccc144a311


  Outer fold 2 C-index (nested): 0.7051

=== Outer fold 3/5 ===


[I 2025-12-30 06:06:23,930] Trial 0 finished with value: 0.678885288397633 and parameters: {'penalizer': 0.0031222316081863837, 'l1_ratio': 0.6481142001568531}. Best is trial 0 with value: 0.678885288397633.
[I 2025-12-30 06:06:32,473] Trial 1 finished with value: 0.6762413764981549 and parameters: {'penalizer': 6.5309844406113555, 'l1_ratio': 0.7236655576055052}. Best is trial 0 with value: 0.678885288397633.
[I 2025-12-30 06:06:41,543] Trial 2 finished with value: 0.6762626709072151 and parameters: {'penalizer': 4.870689386848968, 'l1_ratio': 0.3744826890356622}. Best is trial 0 with value: 0.678885288397633.
[I 2025-12-30 06:06:52,633] Trial 3 finished with value: 0.6830821996574766 and parameters: {'penalizer': 0.16109060280311957, 'l1_ratio': 0.5739052631500162}. Best is trial 3 with value: 0.6830821996574766.
[I 2025-12-30 06:06:58,968] Trial 4 finished with value: 0.6768595701416769 and parameters: {'penalizer': 0.00042189793015123426, 'l1_ratio': 0.7106105950747217}. Best is tr

  Best inner CV C-index: 0.6890996715598928
  Best params: {'penalizer': 0.12803792195163669, 'l1_ratio': 0.1493100736360229}


[I 2025-12-30 06:11:13,469] A new study created in memory with name: no-name-12deada2-b18a-4129-a471-8edb0c30c551


  Outer fold 3 C-index (nested): 0.6879

=== Outer fold 4/5 ===


[I 2025-12-30 06:11:22,129] Trial 0 finished with value: 0.680542281180382 and parameters: {'penalizer': 0.005684641319750673, 'l1_ratio': 0.28461326179161284}. Best is trial 0 with value: 0.680542281180382.
[I 2025-12-30 06:11:31,495] Trial 1 finished with value: 0.6753822835070169 and parameters: {'penalizer': 2.9590234578704435, 'l1_ratio': 0.5406025708097975}. Best is trial 0 with value: 0.680542281180382.
[I 2025-12-30 06:11:41,037] Trial 2 finished with value: 0.6753822835070169 and parameters: {'penalizer': 1.8375278584966914, 'l1_ratio': 0.8873210253749371}. Best is trial 0 with value: 0.680542281180382.
[I 2025-12-30 06:11:52,244] Trial 3 finished with value: 0.6842737537202416 and parameters: {'penalizer': 0.24085875053280023, 'l1_ratio': 0.6333622380189198}. Best is trial 3 with value: 0.6842737537202416.
[I 2025-12-30 06:12:03,680] Trial 4 finished with value: 0.67682592294384 and parameters: {'penalizer': 0.46738588730767716, 'l1_ratio': 0.638533906146218}. Best is trial 3

  Best inner CV C-index: 0.6876287140996555
  Best params: {'penalizer': 0.09951072159977802, 'l1_ratio': 0.10496828162184102}


[I 2025-12-30 06:17:20,695] A new study created in memory with name: no-name-bc08fece-38a4-456e-a669-79eb4f2cddde


  Outer fold 4 C-index (nested): 0.7099

=== Outer fold 5/5 ===


[I 2025-12-30 06:17:36,934] Trial 0 finished with value: 0.6880380626914268 and parameters: {'penalizer': 0.014141147414561675, 'l1_ratio': 0.27384224796016354}. Best is trial 0 with value: 0.6880380626914268.
[I 2025-12-30 06:17:51,060] Trial 1 finished with value: 0.6858534697504745 and parameters: {'penalizer': 0.0014359377754233084, 'l1_ratio': 0.9095140015970209}. Best is trial 0 with value: 0.6880380626914268.
[I 2025-12-30 06:18:08,445] Trial 2 finished with value: 0.6902939031286558 and parameters: {'penalizer': 0.010224072815257482, 'l1_ratio': 0.732853823839197}. Best is trial 2 with value: 0.6902939031286558.
[I 2025-12-30 06:18:20,471] Trial 3 finished with value: 0.6852404967672125 and parameters: {'penalizer': 0.0020358201509835597, 'l1_ratio': 0.2571399073604971}. Best is trial 2 with value: 0.6902939031286558.
[I 2025-12-30 06:18:35,612] Trial 4 finished with value: 0.6762510674605883 and parameters: {'penalizer': 5.170869633707368, 'l1_ratio': 0.4730062771920738}. Best

  Best inner CV C-index: 0.6950036786962605
  Best params: {'penalizer': 0.024990429132179057, 'l1_ratio': 0.8230930472739336}
  Outer fold 5 C-index (nested): 0.6850

=== Nested CV results ===
Mean outer C-index: 0.6931218383711261
Std outer C-index:  0.012307645362880225

Best params per outer fold:
  Fold 1: {'penalizer': 0.2956103925012571, 'l1_ratio': 0.48743822455290725}
  Fold 2: {'penalizer': 0.025898639224771734, 'l1_ratio': 0.9944291398683284}
  Fold 3: {'penalizer': 0.12803792195163669, 'l1_ratio': 0.1493100736360229}
  Fold 4: {'penalizer': 0.09951072159977802, 'l1_ratio': 0.10496828162184102}
  Fold 5: {'penalizer': 0.024990429132179057, 'l1_ratio': 0.8230930472739336}


## 9. Final Cox model on all training data (gene features only)

In [197]:
best_penalizers = [p["penalizer"] for p in best_params_per_outer_fold]
best_l1_ratios = [p["l1_ratio"] for p in best_params_per_outer_fold]

final_penalizer = float(np.median(best_penalizers))
final_l1_ratio = float(np.median(best_l1_ratios))

print("\nFinal hyperparameters (median over outer folds):")
print(f"  penalizer = {final_penalizer:.4g}")
print(f"  l1_ratio  = {final_l1_ratio:.4f}")

scaler_final = StandardScaler()
X_all_scaled = scaler_final.fit_transform(X_all)

cox_df_train_scaled = pd.DataFrame(X_all_scaled, columns=valid_genes)
cox_df_train_scaled["OS_YEARS"] = y_time_all
cox_df_train_scaled["OS_STATUS"] = y_event_all

cph_best = CoxPHFitter(
    penalizer=final_penalizer,
    l1_ratio=final_l1_ratio,
)
cph_best.fit(
    cox_df_train_scaled,
    duration_col="OS_YEARS",
    event_col="OS_STATUS",
)

print("\nFinal CoxPH model (cph_best) trained on full training set.")


Final hyperparameters (median over outer folds):
  penalizer = 0.09951
  l1_ratio  = 0.4874

Final CoxPH model (cph_best) trained on full training set.


## 10. Compute Molecular Risk Score and Quartiles

In [204]:
# Risk score on training data
X_train_genes = cox_df_train_scaled[valid_genes]
raw_score = cph_best.predict_log_partial_hazard(X_train_genes)

cindex_pos = concordance_index(
    cox_df_train_scaled["OS_YEARS"].values,
    raw_score.values,
    cox_df_train_scaled["OS_STATUS"].values,
)

cindex_neg = concordance_index(
    cox_df_train_scaled["OS_YEARS"].values,
    (-raw_score).values,
    cox_df_train_scaled["OS_STATUS"].values,
)

if cindex_neg > cindex_pos:
    print(
        f"Inverting score orientation "
        f"(C-index {cindex_pos:.3f} -> {cindex_neg:.3f})"
    )
    genetic_risk = -raw_score
    cindex_final = cindex_neg
else:
    print(f"Score orientation kept (C-index {cindex_pos:.3f})")
    genetic_risk = raw_score
    cindex_final = cindex_pos

print("Final C-index of genetic risk score (train):", cindex_final)

df_train_pivot["MRS"] = genetic_risk.values

# Risk score for validation data (scaled with same scaler)
features_used = valid_genes
X_val_features = df_val_pivot[features_used].to_numpy(dtype=float)
X_val_scaled = scaler_final.transform(X_val_features)
X_val_scaled_df = pd.DataFrame(X_val_scaled, columns=features_used)

df_val_pivot["MRS"] = cph_best.predict_log_partial_hazard(
    X_val_scaled_df
)

# Quartile-based risk groups (same cutpoints for train/val)
qs = np.quantile(df_train_pivot["MRS"], [0, 0.25, 0.5, 0.75, 1.0])

df_train_pivot["RiskQuartile"] = pd.cut(
    df_train_pivot["MRS"],
    bins=qs,
    labels=["Q1 (lowest)", "Q2", "Q3", "Q4 (highest)"],
    include_lowest=True,
)
df_val_pivot["RiskQuartile"] = pd.cut(
    df_val_pivot["MRS"],
    bins=qs,
    labels=["Q1 (lowest)", "Q2", "Q3", "Q4 (highest)"],
    include_lowest=True,
)

df_train_pivot = pd.get_dummies(
    df_train_pivot,
    columns=["RiskQuartile"],
    drop_first=False,
)
df_val_pivot = pd.get_dummies(
    df_val_pivot,
    columns=["RiskQuartile"],
    drop_first=False,
)

Inverting score orientation (C-index 0.305 -> 0.695)
Final C-index of genetic risk score (train): 0.6952588657675822


# 11. Additional Feature Engineering

In [205]:
# 11. Additional Feature Engineering
# Goal: keep only a small number of interpretable, well-motivated feature families


# 11.1 Cytogenetic risk score (single compact score)
df_train_pivot["cyto_risk_score"] = (
    3 * df_train_pivot["is_monosomal_karyotype"]
    + 3 * df_train_pivot["is_complex_karyotype"]
    + 2 * df_train_pivot["has_minus7_or_del7q"]
    + 2 * df_train_pivot["has_minus5_or_del5q"]
    + 1 * df_train_pivot["has_plus8"]
)

df_val_pivot["cyto_risk_score"] = (
    3 * df_val_pivot["is_monosomal_karyotype"]
    + 3 * df_val_pivot["is_complex_karyotype"]
    + 2 * df_val_pivot["has_minus7_or_del7q"]
    + 2 * df_val_pivot["has_minus5_or_del5q"]
    + 1 * df_val_pivot["has_plus8"]
)


# 11.2 Gene-level composite scores
df_train_pivot["risk_score_high_genes"] = (
    df_train_pivot["Gene_TP53"]
    + df_train_pivot["Gene_ASXL1"]
    + df_train_pivot["Gene_RUNX1"]
)

df_train_pivot["risk_score_favorable_genes"] = (
    df_train_pivot["Gene_NPM1"] + df_train_pivot["Gene_CEBPA"]
)

df_val_pivot["risk_score_high_genes"] = (
    df_val_pivot["Gene_TP53"]
    + df_val_pivot["Gene_ASXL1"]
    + df_val_pivot["Gene_RUNX1"]
)

df_val_pivot["risk_score_favorable_genes"] = (
    df_val_pivot["Gene_NPM1"] + df_val_pivot["Gene_CEBPA"]
)


# 11.3 Functional gene class counts
df_train_pivot["n_splicing_mut"] = df_train_pivot[
    ["Gene_U2AF1", "Gene_SRSF2", "Gene_SF3B1", "Gene_ZRSR2"]
].sum(axis=1)

df_train_pivot["n_signaling_mut"] = df_train_pivot[
    ["Gene_NRAS", "Gene_KRAS", "Gene_JAK2", "Gene_CBL"]
].sum(axis=1)

df_val_pivot["n_splicing_mut"] = df_val_pivot[
    ["Gene_U2AF1", "Gene_SRSF2", "Gene_SF3B1", "Gene_ZRSR2"]
].sum(axis=1)

df_val_pivot["n_signaling_mut"] = df_val_pivot[
    ["Gene_NRAS", "Gene_KRAS", "Gene_JAK2", "Gene_CBL"]
].sum(axis=1)


# 11.4 A single biologically-motivated interaction term
df_train_pivot["TP53_complex_interaction"] = (
    df_train_pivot["Gene_TP53"] * df_train_pivot["is_complex_karyotype"]
)

df_val_pivot["TP53_complex_interaction"] = (
    df_val_pivot["Gene_TP53"] * df_val_pivot["is_complex_karyotype"]
)


# 11.5 Simple & interpretable clinical ratios
df_train_pivot["ANC_WBC_ratio"] = df_train_pivot["ANC"] / (df_train_pivot["WBC"] + 1)
df_train_pivot["BLAST_WBC_ratio"] = df_train_pivot["BM_BLAST"] / (
    df_train_pivot["WBC"] + 1
)

df_val_pivot["ANC_WBC_ratio"] = df_val_pivot["ANC"] / (df_val_pivot["WBC"] + 1)
df_val_pivot["BLAST_WBC_ratio"] = df_val_pivot["BM_BLAST"] / (
    df_val_pivot["WBC"] + 1
)


# 11.6 Subclonality proxies
df_train_pivot["subclonality"] = df_train_pivot["VAF_std"] / (
    df_train_pivot["VAF_avg"] + 1e-6
)

df_val_pivot["subclonality"] = df_val_pivot["VAF_std"] / (
    df_val_pivot["VAF_avg"] + 1e-6
)

## 12. VAF Entropy (distribution of clone sizes)

In [200]:
def compute_vaf_entropy(df_mut: pd.DataFrame) -> pd.DataFrame:
    """
    Compute Shannon entropy of VAF distribution per patient.
    """

    def entropy_from_vaf(vaf_list):
        vaf_arr = np.array(vaf_list, dtype=float)
        total = vaf_arr.sum()
        if total <= 0:
            return 0.0
        p = vaf_arr / total
        p = p[p > 0]
        return float(-np.sum(p * np.log(p)))

    entropy_per_patient = (
        df_mut.groupby("ID")["VAF"]
        .apply(entropy_from_vaf)
        .reset_index()
        .rename(columns={"VAF": "vaf_entropy"})
    )

    return entropy_per_patient


entropy_train = compute_vaf_entropy(maf_df)
entropy_eval = compute_vaf_entropy(maf_eval)

df_train_pivot = df_train_pivot.merge(entropy_train, on="ID", how="left")
df_val_pivot = df_val_pivot.merge(entropy_eval, on="ID", how="left")

df_train_pivot["vaf_entropy"] = df_train_pivot["vaf_entropy"].fillna(0)
df_val_pivot["vaf_entropy"] = df_val_pivot["vaf_entropy"].fillna(0)


## 13. Additional ratios and transformations

In [201]:
def safe_ratio(num: pd.Series, den: pd.Series) -> pd.Series:
    num = num.astype(float)
    den = den.astype(float)
    res = num / den
    res[~np.isfinite(res)] = np.nan
    return res


# Replace sentinel HB == -1
for df_ in [df_train_pivot, df_val_pivot]:
    df_.loc[df_["HB"] == -1, "HB"] = np.nan

hb_median = df_train_pivot["HB"].median()

df_train_pivot["HB"] = df_train_pivot["HB"].fillna(hb_median)
df_val_pivot["HB"] = df_val_pivot["HB"].fillna(hb_median)


# Platelet / Hemoglobin ratio
for df_ in [df_train_pivot, df_val_pivot]:
    df_["PLT_HB_ratio"] = safe_ratio(df_["PLT"], df_["HB"] + 1)


# Log-transform skewed variables
for df_ in [df_train_pivot, df_val_pivot]:
    for col in [
        "WBC",
        "ANC",
        "PLT",
        "MONOCYTES",
        "BM_BLAST",
        "Nmut",
        "DEPTH_avg",
    ]:
        if col in df_.columns:
            df_[f"log1p_{col}"] = np.log1p(df_[col].clip(lower=0))


# Drop extremely rare mutation-effect features
mutation_cols = [
    c
    for c in df_train_pivot.columns
    if c.startswith("EFFECT_") and c.endswith("_count")
]

low_freq_cols = [c for c in mutation_cols if (df_train_pivot[c] > 0).mean() < 0.01]

df_train_pivot = df_train_pivot.drop(columns=low_freq_cols)
df_val_pivot = df_val_pivot.drop(columns=[c for c in low_freq_cols if c in df_val_pivot.columns])


print("EFFECT_LOF_count in train:", "EFFECT_LOF_count" in df_train_pivot.columns)
print([c for c in df_train_pivot.columns if "EFFECT" in c])


# Global mutation burden / clonal complexity score
for df_ in [df_train_pivot, df_val_pivot]:
    df_["mutation_burden_score"] = (
        0.5 * df_["Nmut"]
        + 1.5 * df_["EFFECT_LOF_ratio"]
        + 1.0 * df_["vaf_entropy"]
    )


# Log-transform selected cytogenetic counts
for df_ in [df_train_pivot, df_val_pivot]:
    for col in [
        "n_events",
        "n_chromosomes_altered",
        "n_monosomies_total",
        "n_trisomies_total",
    ]:
        if col in df_.columns:
            df_[f"log1p_{col}"] = np.log1p(df_[col])


# Final missing-value sanity check
na_counts = df_train_pivot.isna().sum()
print("Columns with missing values in train:")
print(na_counts[na_counts > 0])


EFFECT_LOF_count in train: False
['EFFECT_FV_count', 'EFFECT_LOF_ratio', 'EFFECT_NS_count', 'EFFECT_SG_count', 'EFFECT_nunique']
Columns with missing values in train:
Series([], dtype: int64)


## 14. Final scaling of continuous features

In [202]:
all_cols = df_train_pivot.columns.tolist()

binary_like = [
    c
    for c in all_cols
    if c.startswith("Gene_")
    or c.startswith("RiskQuartile_")
    or df_train_pivot[c].dropna().isin([0, 1]).all()
]

target_cols = ["OS_YEARS", "OS_STATUS"]

cont_cols = [
    c for c in all_cols if c not in binary_like + target_cols + ["ID"]
]

scaler_final_features = RobustScaler()
df_train_pivot[cont_cols] = scaler_final_features.fit_transform(
    df_train_pivot[cont_cols]
)
df_val_pivot[cont_cols] = scaler_final_features.transform(
    df_val_pivot[cont_cols]
)

## 15. Save final enhanced datasets

In [203]:
df_train_pivot.to_csv("../../data/train_enhanced.csv", index=False)
df_val_pivot.to_csv("../../data/eval_enhanced.csv", index=False)