# Phenotype clustering with top XGBoost features

Use the top-50 features from `data/feature_importance_xgboost.csv` to cluster ICU patients into phenotypes with k-means (k = 2, 3, 4), then visualize per-cluster mortality and key feature patterns.

In [None]:
import os
import sys
from pathlib import Path

import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
from IPython.display import display
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA
from sklearn.metrics import silhouette_score
from sklearn.preprocessing import StandardScaler

# Make the project root (one level above the notebook directory) importable,
# so that the `src` package can be found.
ROOT = Path('..').resolve()
sys.path.append(str(ROOT))

# Prefer the project's preprocessor. If importing it fails for any reason, fall
# back to numeric-only features and print why, instead of aborting the notebook.
HAS_PREPROCESS = False
try:
    from src.preprocess import build_preprocessor
except Exception as exc:
    print(f"Failed to import build_preprocessor, using numeric-only fallback. Reason: {exc}")
else:
    HAS_PREPROCESS = True


In [None]:
# Paths and parameters
# Input files, resolved relative to the project root.
DATA_PATH = ROOT / 'data' / 'dataset.csv'
FI_PATH = ROOT / 'data' / 'feature_importance_xgboost.csv'

# Output folder for all figures (per-k sub-folders are created by the plotting cells).
SAVE_DIR = ROOT / 'results' / 'phenotypes'
SAVE_DIR.mkdir(parents=True, exist_ok=True)

# Clustering parameters.
TOP_N = 50          # how many top-importance encoded features to map back to raw columns
K_LIST = [2, 3, 4]  # numbers of clusters tried with KMeans

# Clinical variables summarized per cluster; only those present among the
# selected feature columns end up in the heatmap/boxplots.
KEY_VARS = [
    'age',
    'd1_lactate_max',
    'd1_creatinine_max',
    'd1_sysbp_min',
    'heart_rate_apache',
    'gcs_motor_apache',
    'ventilated_apache',
    'urineoutput_apache',
]

# Scatter-plot settings.
JITTER_STD = 0.02  # stddev of Gaussian jitter added to PCA coordinates to reduce overplotting
MAX_POINTS = 5000  # subsample size for the jittered scatter when there are more rows


In [None]:
def load_top_features(fi_path: Path, df_cols, top_n: int):
    """Return the top encoded feature names and the raw columns they map to.

    Parameters
    ----------
    fi_path : path or buffer readable by ``pd.read_csv``; must contain
        ``feature`` and ``importance`` columns.
    df_cols : iterable of raw dataset column names (e.g. ``DataFrame.columns``).
    top_n : number of most important encoded features to keep.

    Returns
    -------
    tuple[list, list]
        ``(top_feats, raw_cols)``. ``top_feats`` holds the encoded feature
        names, most important first. ``raw_cols`` holds the de-duplicated
        columns from ``df_cols`` that those features map to, in order of first
        appearance. Features that map to no column are dropped.

    Mapping rules: a transformer prefix such as ``num__`` / ``cat__`` is
    stripped, and an exact column match wins. Otherwise the feature is treated
    as a one-hot level (``ethnicity_Caucasian`` -> ``ethnicity``) and mapped
    to the *longest* column name it extends. This way
    ``apache_2_bodysystem_X`` maps to ``apache_2_bodysystem`` even if a
    shorter column such as ``apache`` also exists.
    """
    fi = pd.read_csv(fi_path)
    top_feats = fi.sort_values('importance', ascending=False).head(top_n)['feature'].tolist()

    columns = list(df_cols)
    column_set = set(columns)  # O(1) exact-match lookups

    raw_cols = []
    for feat in top_feats:
        # Strip a ColumnTransformer-style prefix ("num__age" -> "age");
        # split() leaves names without "__" unchanged.
        name = str(feat).split('__', 1)[-1]
        if name in column_set:
            raw_cols.append(name)
            continue
        # One-hot level: prefer the longest matching column so that a short
        # column name that happens to be a prefix does not steal the match.
        matches = [col for col in columns if name.startswith(f'{col}_')]
        if matches:
            raw_cols.append(max(matches, key=lambda col: len(str(col))))

    return top_feats, list(dict.fromkeys(raw_cols))  # dedupe, keep order


In [None]:
# Load data and prepare feature matrix using top features
raw_df = pd.read_csv(DATA_PATH)
if 'hospital_death' not in raw_df.columns:
    # Message text: "dataset.csv is missing the hospital_death column".
    raise RuntimeError('dataset.csv 缺少 hospital_death 列')

top_feats_encoded, raw_cols = load_top_features(FI_PATH, raw_df.columns, TOP_N)
print('Top encoded features (head):', top_feats_encoded[:10])
print('Mapped raw columns used:', raw_cols)

# The outcome must never be a clustering input. Keeping it here would also
# duplicate the column in feature_df below, because it is appended explicitly.
present_cols = [c for c in raw_cols if c in raw_df.columns and c != 'hospital_death']
missing = [c for c in raw_cols if c not in raw_df.columns]
if missing:
    print('Warning: missing columns skipped:', missing)
if not present_cols:
    # Fail fast with a clear message instead of an obscure error in preprocessing/KMeans.
    raise RuntimeError(
        'No top features could be mapped to dataset columns; '
        'check feature_importance_xgboost.csv against dataset.csv'
    )

feature_df = raw_df[present_cols + ['hospital_death']].copy()


In [None]:
def assign_risk_labels(cluster_risk: pd.Series):
    """Label each cluster 'High', 'Medium' or 'Low' by its mortality rate.

    ``cluster_risk`` maps cluster id -> death rate. Clusters are ranked from
    highest to lowest rate. A single cluster is 'Medium'; two clusters are
    'High'/'Low'; three or more are split into rank tertiles via ``pd.qcut``.

    Returns a dict {cluster id: label}. It is empty when ``cluster_risk`` is
    None or empty.
    """
    if cluster_risk is None or cluster_risk.empty:
        return {}

    ranked_ids = cluster_risk.sort_values(ascending=False).index
    count = len(ranked_ids)

    if count == 1:
        return dict(zip(ranked_ids, ['Medium']))
    if count == 2:
        return dict(zip(ranked_ids, ['High', 'Low']))

    # Rank 1 = highest death rate; cut the ranks into three equal-frequency bins.
    ranks = pd.Series(range(1, count + 1), index=ranked_ids)
    tertiles = pd.qcut(ranks, q=3, labels=['High', 'Medium', 'Low'])
    return tertiles.to_dict()


In [None]:
# Build feature matrix (X_dense) from selected columns
# Outcome vector; it is kept alongside the features but never used as a clustering input.
y = feature_df['hospital_death'].astype(int)
X_raw = feature_df.drop(columns=['hospital_death'])

if HAS_PREPROCESS:
    # Project preprocessor from src.preprocess; its exact steps are defined
    # there (scale_numeric=True presumably standardizes numeric columns --
    # TODO confirm against src/preprocess.py).
    preprocessor = build_preprocessor(X_raw, scale_numeric=True)
    X_matrix = preprocessor.fit_transform(X_raw)
    # Densify sparse output so KMeans/PCA/silhouette all receive a dense array.
    X_dense = X_matrix.toarray() if hasattr(X_matrix, 'toarray') else X_matrix
    feat_names = preprocessor.get_feature_names_out()
else:
    # Numeric-only fallback: non-numeric (e.g. categorical) columns are ignored.
    # 'number' covers every numeric width, not only int64/float64.
    X_numeric = X_raw.select_dtypes(include='number')
    # An all-NaN column has a NaN median, so fillna would leave it NaN and
    # KMeans would later reject the matrix; drop such columns explicitly.
    all_nan_cols = X_numeric.columns[X_numeric.isna().all()].tolist()
    if all_nan_cols:
        print('Warning: dropping all-NaN columns:', all_nan_cols)
        X_numeric = X_numeric.drop(columns=all_nan_cols)
    X_numeric = X_numeric.fillna(X_numeric.median())
    scaler = StandardScaler()
    X_dense = scaler.fit_transform(X_numeric)
    feat_names = X_numeric.columns.tolist()

if X_dense.shape[1] == 0:
    raise RuntimeError('No usable features left for clustering; check the selected columns.')

print('Feature matrix shape:', X_dense.shape)


In [None]:
# Run clustering for k in K_LIST
# The PCA projection depends only on X_dense, not on k, so compute it once and
# store the same 2-D coordinates with every k's results.
pca = PCA(n_components=2, random_state=42)
X_pca = pca.fit_transform(X_dense)

results = []
for k in K_LIST:
    # Pin n_init so results do not depend on scikit-learn's default
    # (10 before version 1.4, 'auto' from 1.4 on).
    model = KMeans(n_clusters=k, n_init=10, random_state=42)
    clusters = model.fit_predict(X_dense)
    # silhouette_score requires 2 <= number of distinct labels <= n_samples - 1.
    # KMeans can yield fewer than k distinct labels on heavily duplicated data.
    n_labels = len(np.unique(clusters))
    sil = silhouette_score(X_dense, clusters) if 2 <= n_labels < len(X_dense) else np.nan
    df_k = feature_df.copy()
    df_k['cluster'] = clusters
    # Mean of hospital_death per cluster (death rate), ordered by cluster id.
    cluster_risk = df_k.groupby('cluster')['hospital_death'].mean().sort_index()
    risk_labels = assign_risk_labels(cluster_risk)
    if risk_labels:
        df_k['risk_group'] = df_k['cluster'].map(risk_labels)
    results.append({
        'k': k,
        'df': df_k,
        'pca': X_pca,
        'silhouette': sil,
        'cluster_risk': cluster_risk,
        'risk_labels': risk_labels,
    })
    # The newline must be an escape sequence: a single-quoted f-string cannot span lines.
    print(f'k={k}: silhouette={sil:.3f}, death rates=\n{cluster_risk}')


In [None]:
# Visualize PCA by target (hospital_death) for each k
for res in results:
    k = res['k']
    df_k = res['df']
    X_pca = res['pca']
    k_dir = SAVE_DIR / f'k{k}'
    k_dir.mkdir(parents=True, exist_ok=True)

    # Subsample and jitter to reduce overplotting. A single seeded generator
    # drives both the subsample and the jitter, so the saved figure is the
    # same on every run.
    rng = np.random.default_rng(42)
    n = len(df_k)
    idx = np.arange(n)
    if n > MAX_POINTS:
        idx = rng.choice(n, size=MAX_POINTS, replace=False)
    X_plot = X_pca[idx] + rng.normal(0, JITTER_STD, size=(len(idx), X_pca.shape[1]))

    plt.figure(figsize=(6, 5))
    scatter_t = plt.scatter(
        X_plot[:, 0],
        X_plot[:, 1],
        c=df_k['hospital_death'].to_numpy()[idx],
        cmap='coolwarm',
        alpha=0.4,
        s=12,
        edgecolors='none',
    )
    plt.xlabel('PC1')
    plt.ylabel('PC2')
    plt.title(f'PCA (k={k}) — colored by hospital_death')
    cbar = plt.colorbar(scatter_t)
    cbar.set_label('hospital_death')
    plt.savefig(k_dir / 'pca_target.png', dpi=150, bbox_inches='tight')
    plt.show()

    # Mortality bar: mean hospital_death per cluster id.
    plt.figure(figsize=(5, 4))
    res['cluster_risk'].plot(kind='bar', color='steelblue')
    plt.ylabel('Death rate')
    plt.title(f'Cluster death rate (k={k})')
    plt.xlabel('Cluster')
    plt.savefig(k_dir / 'death_rate_bar.png', dpi=150, bbox_inches='tight')
    plt.show()


In [None]:
# Key feature patterns per cluster (z-scored means for KEY_VARS)
for res in results:
    k = res['k']
    df_k = res['df']
    k_dir = SAVE_DIR / f'k{k}'
    # Create the folder here too, so this cell does not depend on an earlier cell.
    k_dir.mkdir(parents=True, exist_ok=True)

    available_keys = [c for c in KEY_VARS if c in df_k.columns]
    if not available_keys:
        print(f'k={k}: no KEY_VARS found in dataset, skip heatmap/boxplots')
        continue

    # z-score each key variable over all patients, then average per cluster.
    # A cluster mean of 0 therefore equals the overall mean.
    scaler = StandardScaler()
    z = pd.DataFrame(
        scaler.fit_transform(df_k[available_keys]),
        columns=available_keys,
        index=df_k.index,
    )
    mean_z = z.join(df_k['cluster']).groupby('cluster').mean()

    # Symmetric colour limits centre the diverging colormap on z = 0, so
    # red/blue encode above/below the overall mean.
    abs_vals = np.abs(mean_z.to_numpy(dtype=float))
    abs_vals = abs_vals[np.isfinite(abs_vals)]
    vlim = float(abs_vals.max()) if abs_vals.size and abs_vals.max() > 0 else 1.0

    plt.figure(figsize=(8, max(4, len(available_keys) * 0.35)))
    plt.imshow(mean_z.T, aspect='auto', cmap='coolwarm', vmin=-vlim, vmax=vlim)
    plt.yticks(range(len(available_keys)), available_keys)
    plt.xticks(range(len(mean_z.index)), mean_z.index)
    plt.colorbar(label='Mean z-score')
    plt.title(f'Key feature z-scored means by cluster (k={k})')
    plt.tight_layout()
    plt.savefig(k_dir / 'key_feature_heatmap.png', dpi=150, bbox_inches='tight')
    plt.show()

    # Boxplots for key vars. Pass an explicit Axes: without `ax`,
    # DataFrame.boxplot(by=...) opens its own figure, which leaves an empty
    # stray figure and ignores the requested figsize.
    for col in available_keys:
        fig, ax = plt.subplots(figsize=(5, 4))
        df_k.boxplot(column=col, by='cluster', ax=ax)
        ax.set_title(f'{col} by cluster (k={k})')
        fig.suptitle('')  # drop pandas' automatic "Boxplot grouped by cluster" title
        ax.set_xlabel('Cluster')
        ax.set_ylabel(col)
        fig.tight_layout()
        fig.savefig(k_dir / f'box_{col}.png', dpi=150, bbox_inches='tight')
        plt.show()


## How to use / interpret
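- Run all cells to generate, for k=2,3,4 under `results/phenotypes/k*/`: PCA scatter plots colored by hospital_death (jittered/subsampled and full; the PCA projection itself does not depend on k), per-cluster death-rate bars, and key-feature heatmaps/boxplots.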
- Use the heatmap and boxplots to label clusters with clinical phenotypes (e.g., high lactate + ventilated → high-risk multi-organ failure (MOF); younger + low severity markers → low-risk infection).
- Compare the death-rate bars across k and pick the granularity whose clusters separate mortality clearly while keeping a reasonable silhouette score (printed by the clustering cell).

In [None]:
# Extra PCA by target (no jitter)
# Plot every patient (no subsampling, no jitter) in PCA space, colored by
# hospital_death, and save one figure per k.
for result in results:
    k_val = result['k']
    coords = result['pca']
    outcome = result['df']['hospital_death']

    target_dir = SAVE_DIR / f'k{k_val}'
    target_dir.mkdir(parents=True, exist_ok=True)
    out_path = target_dir / 'pca_target_full.png'

    plt.figure(figsize=(6, 5))
    points = plt.scatter(
        coords[:, 0],
        coords[:, 1],
        c=outcome,
        cmap='coolwarm',
        alpha=0.5,
        s=10,
        edgecolors='none',
    )
    plt.xlabel('PC1')
    plt.ylabel('PC2')
    plt.title(f'PCA (k={k_val}) — colored by hospital_death (full)')
    colorbar = plt.colorbar(points)
    colorbar.set_label('hospital_death')
    plt.savefig(out_path, dpi=150, bbox_inches='tight')
    plt.show()
    print(f'Saved: {out_path}')
