# Infection_grade: End-to-End ML Pipeline

This notebook implements a complete, reproducible pipeline for predicting a binary outcome derived from `Infection_grade` (grade > 2 vs. ≤ 2) using static and dynamic features. It includes:

- data loading and patient-level stratified split
- dynamic time-series aggregation (Day -15 .. +2)
- a no-leak preprocessing pipeline (imputation, encoding, scaling)
- Optuna hyperparameter optimization with 5-fold cross-validation and pruning
- final training on the full training set and evaluation on an independent test set
- bootstrap confidence intervals for metrics
- SHAP explanations and visualizations

**Before running:** update the `CONFIG` cell with paths to your data and adjust feature lists.


In [None]:
# Cell: Dependencies & Config
# Install required packages if needed (uncomment to run in a fresh environment)
# !pip install pandas numpy scikit-learn lightgbm optuna shap matplotlib seaborn joblib

import os
import random
import json
from pathlib import Path
from typing import Dict, Any, Tuple, List

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

from sklearn.model_selection import StratifiedShuffleSplit, StratifiedKFold
from sklearn.pipeline import Pipeline
from sklearn.compose import ColumnTransformer
from sklearn.preprocessing import OneHotEncoder, OrdinalEncoder, StandardScaler
from sklearn.impute import SimpleImputer
from sklearn.metrics import (
    roc_auc_score, average_precision_score, precision_score, recall_score,
    f1_score, brier_score_loss, precision_recall_curve, roc_curve
)

import lightgbm as lgb
import optuna
import shap
import joblib

RANDOM_SEED = 42
np.random.seed(RANDOM_SEED)
random.seed(RANDOM_SEED)

# CONFIG - update paths and parameters as needed
CONFIG = {
    "STATIC_CSV": "data/static/encoded_standardized.csv",
    "DYNAMIC_DIR": "data/dynamic/processed_standardized",
    "OUTPUT_DIR": "artifacts/infection_pipeline",
    "TEST_SIZE": 0.30,
    "N_FOLDS": 5,
    "OPTUNA_TRIALS": 40,
    "OBS_START": -15,
    "OBS_END": 2,
    "N_BOOTSTRAP": 1000,
    "N_SHAP_SAMPLES": 300
}

Path(CONFIG["OUTPUT_DIR"]).mkdir(parents=True, exist_ok=True)

print('CONFIG loaded. Output dir:', CONFIG['OUTPUT_DIR'])

## Utilities: label binarization, dynamic aggregation, stratified patient split, bootstrap CI

These helper functions are used throughout the notebook. Dynamic aggregation converts each patient's time series CSV into aggregated numeric features (mean/std/min/max/count/slope/auc/last).
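As a toy illustration of these per-variable statistics, the sketch below applies the same definitions to one synthetic variable. The column name `crp` and its values are invented for illustration and are not study data. The AUC uses the trapezoidal rule on forward-filled values, and the slope is a least-squares fit with `np.polyfit` over the observed points only.

```python
import numpy as np
import pandas as pd

# Hypothetical single-patient time series: one variable ("crp") on a few days
ts = pd.DataFrame({"Day": [-10, -5, 0, 2], "crp": [1.0, 3.0, np.nan, 7.0]})

s = ts["crp"]
observed = s.notna()
x_obs = ts.loc[observed, "Day"].to_numpy(dtype=float)
y_obs = s[observed].to_numpy(dtype=float)

features = {
    "crp_mean": s.mean(),          # NaN values are skipped
    "crp_std": s.std(),            # sample standard deviation
    "crp_min": s.min(),
    "crp_max": s.max(),
    "crp_count": int(s.count()),   # number of non-missing values
    "crp_last": y_obs[-1],         # last observed value
    # least-squares slope of value vs. day over observed points
    "crp_slope": np.polyfit(x_obs, y_obs, 1)[0],
}

# AUC on forward-filled values, trapezoidal rule written out explicitly
y_ffill = s.ffill().fillna(0).to_numpy(dtype=float)
x_all = ts["Day"].to_numpy(dtype=float)
features["crp_auc"] = float(np.sum(np.diff(x_all) * (y_ffill[1:] + y_ffill[:-1]) / 2))

print(features)
```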

In [None]:
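def binarize_label(df: pd.DataFrame, col: str = "Infection_grade", threshold: int = 2) -> pd.Series:
    """Binarize the grade label: <= threshold -> 0, > threshold -> 1.
    Raises if a grade is missing or non-numeric, since NaN > threshold would silently become 0."""
    grade = pd.to_numeric(df[col], errors='coerce')
    if grade.isna().any():
        raise ValueError(f"{int(grade.isna().sum())} rows have a missing or non-numeric '{col}'")
    return (grade > threshold).astype(int)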


def aggregate_dynamic_for_patient(csv_path: str, obs_start: int, obs_end: int) -> Dict[str, float]:
    """Aggregate a single patient's dynamic CSV into per-variable summary features.
    Returns a dict of features (empty if the file is missing).
    """
    if not os.path.exists(csv_path):
        return {}
    df = pd.read_csv(csv_path)
    if 'Day' not in df.columns:
        # assume the first column holds the day index
        df = df.rename(columns={df.columns[0]: 'Day'})
    df['Day'] = pd.to_numeric(df['Day'], errors='coerce')
    # keep the observation window and sort chronologically, so that
    # "last", AUC and slope follow time order
    df = df[(df['Day'] >= obs_start) & (df['Day'] <= obs_end)].sort_values('Day')
    x = df['Day'].to_numpy(dtype=float)
    features = {}
    for col in df.columns:
        if col == 'Day':
            continue
        s = pd.to_numeric(df[col], errors='coerce')
        observed = s.notna()
        features[f"{col}_mean"] = s.mean()
        features[f"{col}_std"] = s.std()
        features[f"{col}_min"] = s.min()
        features[f"{col}_max"] = s.max()
        features[f"{col}_count"] = s.count()
        # last observed value in the window
        features[f"{col}_last"] = s[observed].iloc[-1] if observed.any() else np.nan
        # AUC by the trapezoidal rule on forward-filled values (leading gaps -> 0);
        # written out explicitly because np.trapz is deprecated in NumPy 2.0
        # (fillna(method='ffill') is likewise deprecated in pandas 2.1; use .ffill())
        y = s.ffill().fillna(0).to_numpy(dtype=float)
        features[f"{col}_auc"] = float(np.sum(np.diff(x) * (y[1:] + y[:-1]) / 2.0))
        # least-squares slope of value vs. day (needs >= 2 observed points)
        if observed.sum() >= 2:
            xs = df.loc[observed, 'Day'].to_numpy(dtype=float)
            ys = s[observed].to_numpy(dtype=float)
            features[f"{col}_slope"] = float(np.polyfit(xs, ys, 1)[0])
        else:
            features[f"{col}_slope"] = np.nan
    return features


def aggregate_dynamic_table(static_df: pd.DataFrame, dynamic_dir: str, obs_start=-15, obs_end=2) -> pd.DataFrame:
    """Aggregate dynamic features for every patient in static_df['patient_id'].
    Returns a DataFrame aligned with static_df.index. Patients without a dynamic
    file (or without a given variable) get NaN, which is imputed later.
    """
    per_patient = []
    for pid in static_df['patient_id']:
        p = os.path.join(dynamic_dir, f"{pid}.csv")
        per_patient.append(aggregate_dynamic_for_patient(p, obs_start, obs_end))
    # union of feature names over all patients (first-seen order), rather than
    # the keys of the first file only, so no variable is silently dropped
    feat_keys = list(dict.fromkeys(k for feat in per_patient for k in feat))
    rows = [{k: feat.get(k, np.nan) for k in feat_keys} for feat in per_patient]
    return pd.DataFrame(rows, index=static_df.index, columns=feat_keys)


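def stratified_patient_split(df: pd.DataFrame, label_col: str, test_size: float = 0.3, seed: int = 42) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """Stratified train/test split at the patient level.
    Assumes one row per patient; raises otherwise, because duplicated patients
    could end up in both train and test."""
    if df['patient_id'].duplicated().any():
        raise ValueError("df must contain exactly one row per patient_id")
    sss = StratifiedShuffleSplit(n_splits=1, test_size=test_size, random_state=seed)
    train_idx, test_idx = next(sss.split(df, df[label_col].values))
    train_df = df.iloc[train_idx].reset_index(drop=True)
    test_df = df.iloc[test_idx].reset_index(drop=True)
    return train_df, test_df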


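def bootstrap_ci(metric_fn, y_true, y_score, n_boot=1000, alpha=0.95, seed=42):
    """Percentile bootstrap for metric_fn(y_true, y_score).
    `alpha` is the confidence level (0.95 -> 95% CI). Resamples containing a single
    class are skipped, because ranking metrics are undefined on them.
    Returns (mean of bootstrap replicates, lower bound, upper bound)."""
    y_true = np.asarray(y_true)
    y_score = np.asarray(y_score)
    rng = np.random.RandomState(seed)
    stats = []
    n = len(y_true)
    for _ in range(n_boot):
        idx = rng.randint(0, n, n)
        if len(np.unique(y_true[idx])) < 2:
            continue
        try:
            stats.append(metric_fn(y_true[idx], y_score[idx]))
        except ValueError:
            continue
    arr = np.array(stats, dtype=float)
    low = np.nanpercentile(arr, (1 - alpha) / 2 * 100)
    high = np.nanpercentile(arr, (1 + alpha) / 2 * 100)
    return float(np.nanmean(arr)), float(low), float(high)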

print('Utility functions defined.')

## Load static data and basic checks
Edit the path in `CONFIG` before running. The static CSV must contain `patient_id` and `Infection_grade` columns.

In [None]:
# Load static dataframe
static_path = CONFIG['STATIC_CSV']
if not os.path.exists(static_path):
    raise FileNotFoundError(f"Static CSV not found: {static_path}")
static_df = pd.read_csv(static_path)
print('Static shape:', static_df.shape)
if 'patient_id' not in static_df.columns:
    raise ValueError("STATIC CSV must contain 'patient_id' column")
if 'Infection_grade' not in static_df.columns:
    raise ValueError("STATIC CSV must contain 'Infection_grade' column")

## Binarize label and drop constant columns
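We transform `Infection_grade` into a binary label: `label = 1` if `Infection_grade > 2`, else 0. Constant columns (at most one distinct non-missing value) carry no information and can be dropped before the split without leakage, because the check uses neither the label nor any split-specific statistic.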
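A minimal sketch of the thresholding rule on a handful of made-up grades (illustrative values only). Missing grades need care: `NaN > 2` evaluates to `False`, so a plain comparison would label them 0; the sketch drops them explicitly before thresholding.

```python
import numpy as np
import pandas as pd

# Illustrative grades, including one missing value
grades = pd.Series([0, 1, 2, 3, 4, np.nan], name="Infection_grade")

# NaN > 2 is False, so remove missing grades instead of silently labeling them 0
known = grades.dropna()
label = (known > 2).astype(int)

print(label.tolist())  # [0, 0, 0, 1, 1]
```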

In [None]:
# Binarize label
static_df['label'] = binarize_label(static_df, col='Infection_grade', threshold=2)
print('Label distribution:\n', static_df['label'].value_counts())

# Drop constant columns (e.g., disease if all B-NHL)
const_cols = [c for c in static_df.columns if static_df[c].nunique() <= 1]
if const_cols:
    print('Dropping constant columns:', const_cols)
    static_df = static_df.drop(columns=const_cols)

## Define feature types
Update the `numeric_cols`, `categorical_cols`, and `ordinal_cols` lists to match your dataset columns. These lists must refer to columns present in the static CSV.

In [None]:
# === MODIFY these lists to match your dataset columns ===
# Example placeholders (replace with your real column names)
numeric_cols = [
    'age', 'bm_disease_burden'  # replace with real numeric static columns
]
categorical_cols = [
    'sex', 'bridging_therapy'  # replace with real categorical columns
]
ordinal_cols = [
    'ann_arbor_stage'  # replace with real ordinal columns (ensure order when encoding)
]

# Validate existence
all_cols = set(static_df.columns)
for lst in (numeric_cols, categorical_cols, ordinal_cols):
    for c in lst:
        if c not in all_cols:
            print(f"Warning: column {c} not found in static data. Remove or correct list.")
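The comment on `ordinal_cols` asks for a meaningful order. By default `OrdinalEncoder` sorts the categories it sees (alphabetically for strings), which may not match the clinical order; passing `categories` fixes the order explicitly. The sketch uses a hypothetical `risk_group` column, not a column from this dataset, and maps categories unseen during fit to -1 via `handle_unknown='use_encoded_value'` (available in scikit-learn 0.24 and later).

```python
import pandas as pd
from sklearn.preprocessing import OrdinalEncoder

# Hypothetical ordinal column (not from the dataset)
X = pd.DataFrame({"risk_group": ["low", "high", "intermediate", "low"]})

default_enc = OrdinalEncoder().fit(X)
print(default_enc.categories_[0].tolist())  # ['high', 'intermediate', 'low'] (alphabetical)

ordered_enc = OrdinalEncoder(
    categories=[["low", "intermediate", "high"]],  # explicit order
    handle_unknown="use_encoded_value",
    unknown_value=-1,
).fit(X)
print(ordered_enc.transform(X).ravel().tolist())  # [0.0, 2.0, 1.0, 0.0]

# a category unseen during fit is mapped to -1 instead of raising
print(ordered_enc.transform(pd.DataFrame({"risk_group": ["unknown"]})).ravel().tolist())  # [-1.0]
```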

## Patient-level stratified split (70/30)

In [None]:
train_df, test_df = stratified_patient_split(static_df, label_col='label', test_size=CONFIG['TEST_SIZE'], seed=RANDOM_SEED)
print('Train shape:', train_df.shape, 'Test shape:', test_df.shape)
train_df['label'].value_counts(), test_df['label'].value_counts()
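To confirm that the split behaves as intended, one can check that no `patient_id` appears in both sets and that each set keeps a similar positive rate. A self-contained sketch on a synthetic cohort (20 hypothetical patients, 6 positives), using the same `StratifiedShuffleSplit` call as `stratified_patient_split`:

```python
import numpy as np
import pandas as pd
from sklearn.model_selection import StratifiedShuffleSplit

# Synthetic cohort: one row per patient, 6 positives out of 20
df = pd.DataFrame({"patient_id": [f"P{i:02d}" for i in range(20)],
                   "label": [1] * 6 + [0] * 14})

sss = StratifiedShuffleSplit(n_splits=1, test_size=0.3, random_state=42)
train_idx, test_idx = next(sss.split(df, df["label"]))
train, test = df.iloc[train_idx], df.iloc[test_idx]

# no patient may appear in both sets
assert set(train["patient_id"]).isdisjoint(test["patient_id"])
print(len(train), len(test))                        # 14 6
print(train["label"].mean(), test["label"].mean())  # positive rate per split
```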

## Aggregate dynamic features for train and test
This step reads per-patient CSVs and computes aggregated statistics. It may be time-consuming depending on data size.

In [None]:
print('Aggregating dynamic for train...')
train_dyn = aggregate_dynamic_table(train_df, CONFIG['DYNAMIC_DIR'], CONFIG['OBS_START'], CONFIG['OBS_END'])
print('Train dynamic shape:', train_dyn.shape)

print('Aggregating dynamic for test...')
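test_dyn = aggregate_dynamic_table(test_df, CONFIG['DYNAMIC_DIR'], CONFIG['OBS_START'], CONFIG['OBS_END'])
# align test columns with train: variables unseen in train are dropped, missing ones become NaN
test_dyn = test_dyn.reindex(columns=train_dyn.columns)
print('Test dynamic shape:', test_dyn.shape)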
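The dynamic tables for train and test are built independently, so their column sets can differ (for example, a variable recorded only for some patients). Reindexing the test table to the training columns drops columns unknown to training and inserts missing ones as NaN, which the imputer later fills. A sketch on two tiny made-up tables (the column names are hypothetical):

```python
import numpy as np
import pandas as pd

# Hypothetical aggregated tables with mismatched columns
train_dyn = pd.DataFrame({"crp_mean": [1.0, 2.0], "il6_mean": [0.5, 0.7]})
test_dyn = pd.DataFrame({"crp_mean": [1.5], "ferritin_mean": [3.0]})

# keep exactly the training columns, in training order
aligned = test_dyn.reindex(columns=train_dyn.columns)
print(aligned)
```

Here `ferritin_mean` is dropped and `il6_mean` appears as an all-NaN column.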

## Construct model input tables
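Concatenate the selected static features with the aggregated dynamic features. `patient_id` is not used as a model input: it stays in `train_df`/`test_df`, whose row order matches `X_train`/`X_test`, and the labels are held separately in `y_train`/`y_test`.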

In [None]:
static_features = [c for c in (numeric_cols + categorical_cols + ordinal_cols) if c in train_df.columns]
X_train_static = train_df[static_features].reset_index(drop=True)
X_test_static = test_df[static_features].reset_index(drop=True)

X_train = pd.concat([X_train_static.reset_index(drop=True), train_dyn.reset_index(drop=True)], axis=1)
X_test = pd.concat([X_test_static.reset_index(drop=True), test_dyn.reset_index(drop=True)], axis=1)

y_train = train_df['label'].values
y_test = test_df['label'].values

print('X_train shape:', X_train.shape, 'X_test shape:', X_test.shape)

## Build a no-leak preprocessing ColumnTransformer
- numeric: median imputer + StandardScaler
- categorical: most_frequent imputer + OneHotEncoder
- ordinal: most_frequent imputer + OrdinalEncoder

We treat dynamic aggregated features as numeric.

In [None]:
# dynamic columns are everything in X_train that's not static features
dyn_cols = [c for c in X_train.columns if c not in static_features]
numeric_pipeline_cols = [c for c in numeric_cols if c in X_train.columns] + dyn_cols
categorical_pipeline_cols = [c for c in categorical_cols if c in X_train.columns]
ordinal_pipeline_cols = [c for c in ordinal_cols if c in X_train.columns]

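from sklearn.pipeline import make_pipeline

num_transformer = make_pipeline(SimpleImputer(strategy='median'),
                                StandardScaler())
# `sparse_output` replaced `sparse` in scikit-learn 1.2 (`sparse` was removed in 1.4)
cat_transformer = make_pipeline(SimpleImputer(strategy='most_frequent'),
                                OneHotEncoder(handle_unknown='ignore', sparse_output=False))
# OrdinalEncoder sorts categories by default; pass `categories=[[...]]` to enforce a
# clinical order. Categories unseen during fit are mapped to -1 instead of raising.
ord_transformer = make_pipeline(SimpleImputer(strategy='most_frequent'),
                                OrdinalEncoder(handle_unknown='use_encoded_value', unknown_value=-1))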

preprocessor = ColumnTransformer(
    transformers=[
        ('num', num_transformer, numeric_pipeline_cols),
        ('cat', cat_transformer, categorical_pipeline_cols),
        ('ord', ord_transformer, ordinal_pipeline_cols)
    ],
    remainder='drop',
    sparse_threshold=0
)

# Fit preprocessor on training data only (no leakage)
preprocessor.fit(X_train)
X_train_t = preprocessor.transform(X_train)
X_test_t = preprocessor.transform(X_test)
print('Transformed shapes:', X_train_t.shape, X_test_t.shape)
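Fitting a preprocessor once on all training rows keeps the test set untouched, but cross-validating on the already-transformed matrix lets each validation fold influence the imputation medians and scaling statistics. Wrapping preprocessing and model in one `Pipeline` and passing it to `cross_val_score` refits every step inside each training fold. A minimal sketch on synthetic data with missing values; logistic regression stands in for LightGBM only to keep the example small.

```python
import numpy as np
from sklearn.impute import SimpleImputer
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import StratifiedKFold, cross_val_score
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

rng = np.random.RandomState(0)
X = rng.normal(size=(200, 5))
y = (X[:, 0] + 0.5 * rng.normal(size=200) > 0).astype(int)
X[rng.rand(*X.shape) < 0.1] = np.nan  # roughly 10% missing values

# imputer and scaler are refit on the training part of every fold
pipe = make_pipeline(SimpleImputer(strategy="median"), StandardScaler(), LogisticRegression())
cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=0)
scores = cross_val_score(pipe, X, y, cv=cv, scoring="average_precision")
print(scores.round(3), scores.mean().round(3))
```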

## Optuna objective with 5-fold CV and LightGBM
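We use `StratifiedKFold` on the training set. Inside each fold the preprocessor is cloned and refit on the training part only, so imputation medians and scaling statistics never see the validation fold. Each fold uses LightGBM early stopping on its validation part (which makes the CV estimate slightly optimistic), and the objective returns the mean AUPRC (average precision) across folds. After every fold the running mean is reported to Optuna, and the `MedianPruner` stops trials that fall below the median of earlier trials at the same fold.

In [None]:
from sklearn.base import clone

skf = StratifiedKFold(n_splits=CONFIG['N_FOLDS'], shuffle=True, random_state=RANDOM_SEED)

def objective(trial):
    param = {
        'objective': 'binary',
        'metric': 'binary_logloss',
        'verbosity': -1,
        'boosting_type': 'gbdt',
        'seed': RANDOM_SEED,
        'num_leaves': trial.suggest_int('num_leaves', 16, 256),
        'max_depth': trial.suggest_int('max_depth', 3, 12),
        # suggest_float(..., log=True) replaces the deprecated suggest_loguniform
        'learning_rate': trial.suggest_float('learning_rate', 1e-3, 0.2, log=True),
        'min_child_samples': trial.suggest_int('min_child_samples', 5, 100),
        # subsample only takes effect when subsample_freq (bagging_freq) > 0
        'subsample': trial.suggest_float('subsample', 0.5, 1.0),
        'subsample_freq': 1,
        'colsample_bytree': trial.suggest_float('colsample_bytree', 0.3, 1.0),
        'reg_alpha': trial.suggest_float('reg_alpha', 1e-8, 10.0, log=True),
        'reg_lambda': trial.suggest_float('reg_lambda', 1e-8, 10.0, log=True),
        'min_split_gain': trial.suggest_float('min_split_gain', 1e-8, 1.0, log=True)
    }
    ap_scores, best_iters = [], []
    for fold, (tr_idx, val_idx) in enumerate(skf.split(X_train, y_train), 1):
        # refit preprocessing on the training part of this fold only
        fold_pre = clone(preprocessor)
        X_tr = fold_pre.fit_transform(X_train.iloc[tr_idx])
        X_val = fold_pre.transform(X_train.iloc[val_idx])
        y_tr, y_val = y_train[tr_idx], y_train[val_idx]
        # scale_pos_weight from the fold's training labels
        pos = y_tr.sum()
        neg = len(y_tr) - pos
        param['scale_pos_weight'] = float(neg / pos) if pos > 0 else 1.0
        dtrain = lgb.Dataset(X_tr, label=y_tr)
        dval = lgb.Dataset(X_val, label=y_val, reference=dtrain)
        # LightGBM >= 4 takes early stopping as a callback (early_stopping_rounds was removed)
        bst = lgb.train(param, dtrain, num_boost_round=2000, valid_sets=[dval],
                        callbacks=[lgb.early_stopping(50, verbose=False)])
        pred = bst.predict(X_val, num_iteration=bst.best_iteration)
        ap_scores.append(average_precision_score(y_val, pred))
        best_iters.append(bst.best_iteration)
        # fold-level pruning on the running mean AUPRC (maximized)
        trial.report(float(np.mean(ap_scores)), fold)
        if trial.should_prune():
            raise optuna.TrialPruned()
    # stored so the final model can reuse a CV-derived number of boosting rounds
    trial.set_user_attr('mean_best_iteration', int(np.mean(best_iters)))
    return float(np.mean(ap_scores))

study = optuna.create_study(direction='maximize', sampler=optuna.samplers.TPESampler(seed=RANDOM_SEED),
                            pruner=optuna.pruners.MedianPruner())
study.optimize(objective, n_trials=CONFIG['OPTUNA_TRIALS'], show_progress_bar=True)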

print('Best trial params:', study.best_trial.params)
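The fold-level pruning in the objective compares a trial's running score with earlier trials at the same fold. As a simplified, hypothetical illustration of the median rule for a maximized score (Optuna's `MedianPruner` adds startup and warm-up settings and other details, so this is not its actual implementation), with made-up fold scores:

```python
from statistics import median
from typing import Dict, List

def should_prune_median(current: float, fold: int,
                        history: List[Dict[int, float]], n_startup: int = 5) -> bool:
    """Hypothetical median rule for a maximized score.

    current: running mean AUPRC of the current trial after `fold`
    history: one dict per completed trial, mapping fold -> running mean AUPRC
    """
    if len(history) < n_startup:
        return False  # too few completed trials to compare against
    peers = [h[fold] for h in history if fold in h]
    if not peers:
        return False
    return current < median(peers)

history = [{1: 0.40, 2: 0.42}, {1: 0.55, 2: 0.53}, {1: 0.35, 2: 0.37},
           {1: 0.60, 2: 0.58}, {1: 0.45, 2: 0.46}]
print(should_prune_median(0.30, 1, history))  # True  (median at fold 1 is 0.45)
print(should_prune_median(0.50, 1, history))  # False
```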

## Save Optuna study and plot simple optimization history

In [None]:
joblib.dump(study, os.path.join(CONFIG['OUTPUT_DIR'], 'optuna_study.pkl'))
joblib.dump(study.best_trial.params, os.path.join(CONFIG['OUTPUT_DIR'], 'best_params.pkl'))

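# Plot the mean CV AUPRC of each completed (non-pruned) trial against its trial number
completed = [t for t in study.trials if t.state == optuna.trial.TrialState.COMPLETE]
plt.figure(figsize=(8,4))
plt.plot([t.number for t in completed], [t.value for t in completed], marker='o')
plt.xlabel('Trial number')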
plt.ylabel('AUPRC')
plt.title('Optuna: AUPRC per trial')
plt.grid(True)
plt.tight_layout()
plt.savefig(os.path.join(CONFIG['OUTPUT_DIR'], 'optuna_history.png'), dpi=150)
plt.close()
print('Optuna artifacts saved.')

## Retrain final model on full training set using best params and save pipeline
We will use LightGBM's sklearn wrapper within a sklearn Pipeline for easier serialization.

In [None]:
best = study.best_trial.params
# Number of boosting rounds: mean early-stopping iteration over the CV folds of the best
# trial, if the objective recorded it; otherwise fall back to 2000 rounds
n_rounds = int(study.best_trial.user_attrs.get('mean_best_iteration', 2000))
# Prepare LGBMClassifier with chosen params
lgbm = lgb.LGBMClassifier(
    objective='binary',
    random_state=RANDOM_SEED,
    verbose=-1,
    n_estimators=n_rounds,
    num_leaves=best.get('num_leaves', 31),
    max_depth=best.get('max_depth', -1),
    learning_rate=best.get('learning_rate', 0.05),
    min_child_samples=best.get('min_child_samples', 20),
    subsample=best.get('subsample', 1.0),
    subsample_freq=1,  # bagging (subsample) is only active when subsample_freq > 0
    colsample_bytree=best.get('colsample_bytree', 1.0),
    reg_alpha=best.get('reg_alpha', 0.0),
    reg_lambda=best.get('reg_lambda', 0.0),
    min_split_gain=best.get('min_split_gain', 0.0)
)

# set scale_pos_weight on full train
pos = y_train.sum(); neg = len(y_train) - pos
if pos > 0:
    lgbm.set_params(scale_pos_weight=float(neg / pos))

final_pipeline = Pipeline([
    ('preprocessor', preprocessor),
    ('classifier', lgbm)
])

# fit final pipeline
final_pipeline.fit(X_train, y_train)
joblib.dump(final_pipeline, os.path.join(CONFIG['OUTPUT_DIR'], 'final_pipeline.pkl'))
print('Final pipeline trained and saved.')

## Evaluate final model on the independent test set
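Compute point estimates on the test set and a bootstrap 95% CI for AUPRC, then save the metrics. Because `scale_pos_weight` reweights the classes during training, the predicted probabilities are not calibrated; interpret the 0.5 threshold (used for precision, recall and F1) and the Brier score with that in mind.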

In [None]:
pipe = joblib.load(os.path.join(CONFIG['OUTPUT_DIR'], 'final_pipeline.pkl'))
probs_test = pipe.predict_proba(X_test)[:, 1]
preds_test = (probs_test >= 0.5).astype(int)

metrics = {
    'AUPRC': float(average_precision_score(y_test, probs_test)),
    'ROC_AUC': float(roc_auc_score(y_test, probs_test)),
    'Precision': float(precision_score(y_test, preds_test, zero_division=0)),
    'Recall': float(recall_score(y_test, preds_test, zero_division=0)),
    'F1': float(f1_score(y_test, preds_test, zero_division=0)),
    'Brier': float(brier_score_loss(y_test, probs_test))
}
print('Test metrics:', metrics)
pd.DataFrame([metrics]).to_csv(os.path.join(CONFIG['OUTPUT_DIR'], 'test_metrics.csv'), index=False)

# Bootstrap CI for AUPRC
mean_ap, low_ap, high_ap = bootstrap_ci(average_precision_score, np.array(y_test), np.array(probs_test),
                                        n_boot=CONFIG['N_BOOTSTRAP'], alpha=0.95, seed=RANDOM_SEED)
print(f'AUPRC mean={mean_ap:.4f}, 95% CI=({low_ap:.4f}, {high_ap:.4f})')
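To see what `bootstrap_ci` computes, the percentile bootstrap can be reproduced on synthetic scores: resample (label, score) pairs with replacement, recompute AUPRC on each resample, skip resamples containing only one class, and take the 2.5th and 97.5th percentiles. The data below are simulated and are not results from this study.

```python
import numpy as np
from sklearn.metrics import average_precision_score

rng = np.random.RandomState(0)
y = rng.binomial(1, 0.3, size=150)
# informative but noisy scores: positives are shifted upward
scores = np.clip(0.3 * y + rng.normal(0.35, 0.15, size=150), 0, 1)

boot = []
for _ in range(500):
    idx = rng.randint(0, len(y), len(y))
    if len(np.unique(y[idx])) < 2:
        continue  # AUPRC needs both classes
    boot.append(average_precision_score(y[idx], scores[idx]))

low, high = np.percentile(boot, [2.5, 97.5])
point = average_precision_score(y, scores)
print(f"AUPRC={point:.3f}, 95% CI=({low:.3f}, {high:.3f})")
```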

## ROC and PR curve plots

In [None]:
# ROC
fpr, tpr, _ = roc_curve(y_test, probs_test)
plt.figure(figsize=(6,5))
plt.plot(fpr, tpr, label=f'AUC={metrics["ROC_AUC"]:.3f}')
plt.plot([0,1],[0,1],'--', color='grey')
plt.xlabel('FPR'); plt.ylabel('TPR'); plt.title('ROC Curve'); plt.legend()
plt.tight_layout()
plt.savefig(os.path.join(CONFIG['OUTPUT_DIR'], 'roc_curve.png'), dpi=150)
plt.close()

# PR
prec, rec, _ = precision_recall_curve(y_test, probs_test)
plt.figure(figsize=(6,5))
plt.plot(rec, prec, label=f'AUPRC={metrics["AUPRC"]:.3f}')
plt.xlabel('Recall'); plt.ylabel('Precision'); plt.title('Precision-Recall Curve'); plt.legend()
plt.tight_layout()
plt.savefig(os.path.join(CONFIG['OUTPUT_DIR'], 'pr_curve.png'), dpi=150)
plt.close()
print('ROC/PR plots saved.')

## SHAP explanations (global & feature importance)
We use TreeExplainer for LightGBM. To save memory, we explain a random subset of training samples.

In [None]:
# Prepare a subset for SHAP
model = final_pipeline.named_steps['classifier']
preproc = final_pipeline.named_steps['preprocessor']

X_train_proc = preproc.transform(X_train)
n_samples = min(CONFIG['N_SHAP_SAMPLES'], X_train_proc.shape[0])
rng = np.random.RandomState(RANDOM_SEED)
idx = rng.choice(X_train_proc.shape[0], n_samples, replace=False)
X_shap = X_train_proc[idx]

# Feature names after preprocessing (one-hot columns expanded); fall back to indices
try:
    feat_names = list(preproc.get_feature_names_out())
except Exception:
    feat_names = [f"f{i}" for i in range(X_shap.shape[1])]

explainer = shap.TreeExplainer(model)
shap_values = explainer.shap_values(X_shap)
# Depending on the SHAP version, a binary model yields a list [class 0, class 1],
# a 2-D array, or a 3-D array (samples, features, classes); keep the positive class
if isinstance(shap_values, list):
    shap_values = shap_values[1]
shap_values = np.asarray(shap_values)
if shap_values.ndim == 3:
    shap_values = shap_values[:, :, 1]

# Summary plot (dot)
plt.figure()
shap.summary_plot(shap_values, X_shap, feature_names=feat_names, show=False)
plt.savefig(os.path.join(CONFIG['OUTPUT_DIR'], 'shap_summary.png'), bbox_inches='tight', dpi=150)
plt.close()

# Bar plot
plt.figure()
shap.summary_plot(shap_values, X_shap, feature_names=feat_names, plot_type='bar', show=False)
plt.savefig(os.path.join(CONFIG['OUTPUT_DIR'], 'shap_bar.png'), bbox_inches='tight', dpi=150)
plt.close()

# Save mean absolute SHAP value per feature
mean_abs = np.abs(shap_values).mean(axis=0)
fi = pd.DataFrame({'feature': feat_names, 'mean_abs_shap': mean_abs})
fi = fi.sort_values('mean_abs_shap', ascending=False)
fi.to_csv(os.path.join(CONFIG['OUTPUT_DIR'], 'shap_feature_importance.csv'), index=False)

print('SHAP artifacts saved.')
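The shape returned by `TreeExplainer.shap_values` for a binary model has varied across SHAP versions: a list with one array per class, a single 2-D array, or a 3-D array with a trailing class axis. The small helper below is hypothetical (it is not part of the SHAP API) and reduces any of these to a `(samples, features)` array for the positive class before computing mean absolute SHAP values; it is exercised on synthetic arrays only.

```python
import numpy as np

def positive_class_shap(values) -> np.ndarray:
    """Return a (n_samples, n_features) SHAP array for the positive class."""
    if isinstance(values, list):   # [class_0, class_1]
        values = values[1]
    values = np.asarray(values)
    if values.ndim == 3:           # (n_samples, n_features, n_classes)
        values = values[:, :, 1]
    return values

sv = np.arange(6, dtype=float).reshape(2, 3)
assert positive_class_shap(sv).shape == (2, 3)
assert positive_class_shap([-sv, sv]).tolist() == sv.tolist()
assert positive_class_shap(np.stack([-sv, sv], axis=-1)).tolist() == sv.tolist()

mean_abs = np.abs(positive_class_shap([-sv, sv])).mean(axis=0)
print(mean_abs)  # [1.5 2.5 3.5]
```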

## Save run metadata and artifacts

In [None]:
joblib.dump({'config': CONFIG, 'seed': RANDOM_SEED}, os.path.join(CONFIG['OUTPUT_DIR'], 'run_metadata.pkl'))
print('All artifacts saved to', CONFIG['OUTPUT_DIR'])

## Notes and next steps
- Adjust `OPTUNA_TRIALS` to control HPO time and cost.
- If classes are extremely imbalanced, consider alternative metrics or resampling inside the CV folds (resample only the training part of each fold, never before splitting, or duplicated samples leak into validation).
- To productionize, extract the modular parts into scripts under `split/`, `train/`, `eval/`, `explain/`.
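For the resampling suggestion, the key point is that only the training part of each fold is resampled, so duplicated minority rows never reach the validation fold. A minimal sketch of plain random oversampling with NumPy on synthetic labels; `oversample_minority` is a hypothetical helper, not part of this notebook or of any library.

```python
import numpy as np
from sklearn.model_selection import StratifiedKFold

def oversample_minority(idx: np.ndarray, y: np.ndarray, rng: np.random.RandomState) -> np.ndarray:
    """Return training indices with minority-class rows duplicated up to the majority count."""
    pos, neg = idx[y[idx] == 1], idx[y[idx] == 0]
    minority, majority = (pos, neg) if len(pos) < len(neg) else (neg, pos)
    extra = rng.choice(minority, size=len(majority) - len(minority), replace=True)
    return np.concatenate([idx, extra])

rng = np.random.RandomState(0)
y = np.array([1] * 10 + [0] * 40)
X = rng.normal(size=(len(y), 3))

skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=0)
for tr_idx, val_idx in skf.split(X, y):
    tr_bal = oversample_minority(tr_idx, y, rng)  # resample the training fold only
    assert set(tr_bal).isdisjoint(val_idx)        # validation rows stay untouched
    print(np.bincount(y[tr_bal]), np.bincount(y[val_idx]))
```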

---

End of notebook.