# 1. Introduction
# Breast Cancer Survival Analysis
# METABRIC Dataset - Comprehensive Survival Analysis Pipeline

## 1. Introduction and Clinical Objective

### Clinical Context
Breast cancer is the most common cancer among women worldwide. Survival analysis is critical for understanding patient prognosis, treatment effectiveness, and identifying risk factors associated with survival outcomes.

### Dataset
The METABRIC (Molecular Taxonomy of Breast Cancer International Consortium) dataset contains clinical and molecular data from breast cancer patients, including:
- Clinical variables (age, tumor characteristics, treatment history)
- Molecular subtypes (PAM50 classification)
- Survival outcomes (overall survival and relapse-free survival)

### Clinical Objectives
1. **Prognostic Modeling**: Develop models to predict overall survival in breast cancer patients
2. **Risk Stratification**: Identify patient subgroups with distinct survival patterns
3. **Feature Importance**: Understand which clinical and molecular factors most strongly influence survival
4. **Treatment Insights**: Evaluate associations between treatment modalities and survival outcomes

### Outcome Definition
- Overall Survival (OS)
  - Duration: Time from diagnosis to death (in months)
  - Event: Death from any cause (`Deceased`, coded 1); patients recorded as `Living` are right-censored at last follow-up (coded 0)
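
As a minimal sketch of this outcome coding, a status string can be mapped to an event indicator while keeping missing or unrecognized statuses distinct from censoring. The helper name `encode_event` and the mapping table `STATUS_TO_EVENT` are illustrative assumptions, not part of the notebook.

```python
from typing import Optional

# Hypothetical mapping for illustration: event = 1, censored = 0
STATUS_TO_EVENT = {"Deceased": 1, "Living": 0}


def encode_event(status: Optional[str]) -> Optional[int]:
    """Map a vital-status string to an event indicator.

    Returns None (not 0) for missing or unrecognized values, so that
    unknown outcomes are never silently treated as censored.
    """
    if status is None:
        return None
    return STATUS_TO_EVENT.get(status.strip())


statuses = ["Deceased", "Living", None, "Living "]
events = [encode_event(s) for s in statuses]
print(events)
```

Rows whose indicator comes back as `None` would then be dropped before fitting any survival model, rather than being counted as censored.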
  

### Clinical Relevance
- Help clinicians make informed treatment decisions
- Enable personalized risk assessment
- Support resource allocation and patient counseling
- Provide insights for clinical trial design

### Methods Overview
- Survival analysis using Kaplan-Meier estimation for descriptive statistics
- Cox proportional hazards models for multivariable analysis
- Machine learning approaches (Decision Trees, Random Forests) for prediction
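- Stratified train/validation/test splitting on the event indicator so event rates are comparable across splits; patient IDs are retained so the split can be made group-aware if needed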
- Leakage controls to prevent data leakage and ensure clinically meaningful predictions
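
To make the first method concrete, here is a hand-rolled sketch of the Kaplan-Meier product-limit estimator in plain Python. The notebook itself relies on `lifelines.KaplanMeierFitter`; the function name `kaplan_meier` and the toy data below are assumptions made only for illustration.

```python
from collections import Counter
from typing import List, Sequence, Tuple


def kaplan_meier(durations: Sequence[float], events: Sequence[int]) -> List[Tuple[float, float]]:
    """Return (time, survival probability) pairs at each distinct event time."""
    deaths = Counter(t for t, e in zip(durations, events) if e == 1)
    exits = Counter(durations)  # every subject leaves the risk set at their own time
    at_risk = len(durations)
    survival = 1.0
    curve = []
    for t in sorted(exits):
        d = deaths.get(t, 0)
        if d > 0:
            # Product-limit step: multiply by the conditional survival at time t
            survival *= 1.0 - d / at_risk
            curve.append((t, survival))
        # Events and censored subjects at t are both removed after t
        at_risk -= exits[t]
    return curve


# Toy data: durations in months, 1 = death observed, 0 = censored
toy_durations = [5, 8, 8, 12, 15]
toy_events = [1, 1, 0, 0, 1]
km_curve = kaplan_meier(toy_durations, toy_events)
for time, prob in km_curve:
    print(f"t={time:>2}  S(t)={prob:.2f}")
```

Censored patients never cause a drop in the curve, but they do shrink the risk set, which is why later events produce larger steps.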


In [None]:
# Install required packages if not already installed
# Run this cell first if you get import errors

import sys
import subprocess

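def install_package(package, import_name=None):
    """Install a pip package if its import name cannot be imported."""
    # The pip name and the import name can differ (e.g. scikit-survival -> sksurv)
    import_name = import_name or package
    try:
        __import__(import_name)
        print(f"✓ {package} is already installed")
    except ImportError:
        print(f"Installing {package}...")
        subprocess.check_call([sys.executable, "-m", "pip", "install", package])
        print(f"✓ {package} installed successfully")

# Install required packages (pip name -> import name)
required_packages = {"lifelines": "lifelines", "scikit-survival": "sksurv"}
for pkg, import_name in required_packages.items():
    install_package(pkg, import_name)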

print("\n" + "="*60)
print("Package installation check complete!")
print("="*60)


# 2. Setup and Reproducibility


In [3]:
# Setup and reproducibility

import warnings
warnings.filterwarnings("ignore", category=UserWarning)

import os
import sys
import math
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from IPython.display import display, Markdown

# Survival modeling
import lifelines
from lifelines import KaplanMeierFitter, CoxPHFitter
from lifelines.statistics import logrank_test
from lifelines.utils import restricted_mean_survival_time

# Scikit-learn core utilities
import sklearn
from sklearn import set_config
from sklearn.pipeline import Pipeline
from sklearn.compose import ColumnTransformer
from sklearn.impute import SimpleImputer
from sklearn.preprocessing import OneHotEncoder, StandardScaler
from sklearn.model_selection import train_test_split, StratifiedShuffleSplit, GroupShuffleSplit
from sklearn.ensemble import RandomForestClassifier
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import (
    roc_auc_score,
    average_precision_score,
    brier_score_loss,
    roc_curve,
    precision_recall_curve
)

# Optional scikit-survival for advanced metrics
try:
    from sksurv.metrics import cumulative_dynamic_auc, integrated_brier_score
    from sksurv.util import Surv
    SKSURV_AVAILABLE = True
except ImportError:
    print("scikit-survival not available. Some time-dependent metrics will be skipped.")
    SKSURV_AVAILABLE = False

# Global configuration for reproducibility
RANDOM_STATE = 42
SEED = RANDOM_STATE  # alias used when creating the splits; keeps a single source of truth
np.random.seed(RANDOM_STATE)

# Matplotlib defaults
plt.rcParams["figure.figsize"] = (10, 6)
plt.rcParams["axes.grid"] = True
plt.rcParams["font.size"] = 11
sns.set_style("whitegrid")

# Display versions for reproducibility
def _get_ver(mod):
    return getattr(mod, "__version__", "n/a")

print("=" * 60)
print("ENVIRONMENT AND VERSION INFORMATION")
print("=" * 60)
print(f"Python              : {sys.version.split()[0]}")
print(f"numpy               : {_get_ver(np)}")
print(f"pandas              : {_get_ver(pd)}")
print(f"lifelines           : {_get_ver(lifelines)}")
print(f"scikit-learn        : {_get_ver(sklearn)}")
print(f"scikit-survival     : {'available' if SKSURV_AVAILABLE else 'not available'}")
print(f"Random State        : {RANDOM_STATE}")
print("=" * 60)

# Pandas display options
pd.set_option("display.max_columns", 100)
pd.set_option("display.width", 120)
pd.set_option("display.max_rows", 50)

# Set sklearn to return DataFrames where possible
set_config(transform_output="pandas")


ENVIRONMENT AND VERSION INFORMATION
Python              : 3.12.4
numpy               : 1.26.4
pandas              : 2.2.3
lifelines           : 0.30.0
scikit-learn        : 1.6.1
scikit-survival     : available
Random State        : 42


# 3. Load data and define survival variables

## Dataset Description

The dataset includes detailed clinical, molecular, and treatment-related information for breast cancer patients.  
Below are the main columns and their descriptions:

---

### Common Columns

- **Patient ID:** Unique patient identifier. *(Categorical variable)*
- **Age at Diagnosis:** Age of the patient when first diagnosed with breast cancer. *(Numerical variable)*
- **Type of Breast Surgery:** Type of breast surgery the patient underwent. *(Categorical variable)*
- **Cancer Type:** General type of cancer. *(Categorical variable)*
- **Cancer Type Detailed:** Histological subtype of the cancer. *(Categorical variable)*
- **Cellularity:** Proportion of cancer within the residual tumor bed. *(Categorical variable)*
- **Chemotherapy:** Indicates whether the patient received chemotherapy. *(Categorical variable)*
- **Pam50 + Claudin-low subtype:** Molecular classification based on the PAM50 gene signature and Claudin-low subtype. *(Categorical variable)*
- **Cohort:** Specific group or population of patients included in the study, often categorized by shared characteristics or treatment protocols. *(Numerical variable)*
- **ER status measured by IHC:** Estrogen receptor (ER) expression status measured by immunohistochemistry. *(Categorical variable)*
- **ER Status:** Whether the tumor is positive or negative for estrogen receptors. *(Categorical variable)*
- **Neoplasm Histologic Grade:** Grade of the tumor based on how abnormal cancer cells appear under a microscope (indicator of aggressiveness). *(Numerical variable)*
- **HER2 status measured by SNP6:** HER2 receptor status measured using SNP6 technology. *(Categorical variable)*
- **HER2 Status:** Whether HER2 receptors are overexpressed in the tumor (positive/negative). *(Categorical variable)*
- **Tumor Other Histologic Subtype:** Additional histological classification of the tumor. *(Categorical variable)*
- **Hormone Therapy:** Indicates whether the patient received hormone therapy. *(Categorical variable)*
- **Inferred Menopausal State:** Patient’s menopausal status inferred from clinical data. *(Categorical variable)*
- **Integrative Cluster:** Tumor classification based on integrative molecular analysis. *(Categorical variable)*
- **Primary Tumor Laterality:** Side of the body (left/right breast) where the primary tumor is located. *(Categorical variable)*
- **Lymph nodes examined positive:** Number of lymph nodes that tested positive for cancer. *(Numerical variable)*
- **Mutation Count:** Total number of genetic mutations identified in the tumor. *(Numerical variable)*
- **Nottingham prognostic index:** Prognostic score assessing breast cancer outcome based on tumor size, lymph node involvement, and histologic grade. *(Numerical variable)*
- **Oncotree Code:** Standardized code classifying tumor types according to the Oncotree system. *(Categorical variable)*
- **PR Status:** Progesterone receptor status of the tumor (positive/negative). *(Categorical variable)*
- **Radio Therapy:** Indicates whether the patient received radiation therapy. *(Categorical variable)*
- **Sex:** Gender of the patient. *(Categorical variable)*
- **3-Gene classifier subtype:** Tumor subtype based on the expression of three specific genes. *(Categorical variable)*
- **Tumor Size:** Size of the tumor (in millimeters). *(Numerical variable)*
- **Tumor Stage:** Stage of the tumor, representing the extent of cancer spread. *(Numerical variable)*

---

### Durations and Events (Survival Information)

- **Patient's Vital Status:** Indicates whether the patient is alive or deceased. *(Categorical variable)*
- **Relapse Free Status (Months):** Duration (in months) that a patient remained free from cancer recurrence after initial treatment. *(Numerical variable)*
- **Relapse Free Status:** Whether the patient experienced a relapse or remained cancer-free. *(Categorical variable)*
- **Overall Survival (Months):** Total number of months a patient survived after breast cancer diagnosis. *(Numerical variable)*
- **Overall Survival Status:** Indicates whether the patient is alive or deceased at the time of follow-up. *(Categorical variable)*

---

In [9]:
# 3.1 Load the METABRIC Breast Cancer Dataset

DATA_PATH = "Breast Cancer METABRIC.csv"

# Check if file exists
if not os.path.exists(DATA_PATH):
    raise FileNotFoundError(
        f"Error: The data file was not found at '{DATA_PATH}'. "
        "Please ensure the dataset is in the correct directory."
    )

df_raw = pd.read_csv(DATA_PATH)

df_raw.head()


Unnamed: 0,Patient ID,Age at Diagnosis,Type of Breast Surgery,Cancer Type,Cancer Type Detailed,Cellularity,Chemotherapy,Pam50 + Claudin-low subtype,Cohort,ER status measured by IHC,ER Status,Neoplasm Histologic Grade,HER2 status measured by SNP6,HER2 Status,Tumor Other Histologic Subtype,Hormone Therapy,Inferred Menopausal State,Integrative Cluster,Primary Tumor Laterality,Lymph nodes examined positive,Mutation Count,Nottingham prognostic index,Oncotree Code,Overall Survival (Months),Overall Survival Status,PR Status,Radio Therapy,Relapse Free Status (Months),Relapse Free Status,Sex,3-Gene classifier subtype,Tumor Size,Tumor Stage,Patient's Vital Status
0,MB-0000,75.65,Mastectomy,Breast Cancer,Breast Invasive Ductal Carcinoma,,No,claudin-low,1.0,Positve,Positive,3.0,Neutral,Negative,Ductal/NST,Yes,Post,4ER+,Right,10.0,,6.044,IDC,140.5,Living,Negative,Yes,138.65,Not Recurred,Female,ER-/HER2-,22.0,2.0,Living
1,MB-0002,43.19,Breast Conserving,Breast Cancer,Breast Invasive Ductal Carcinoma,High,No,LumA,1.0,Positve,Positive,3.0,Neutral,Negative,Ductal/NST,Yes,Pre,4ER+,Right,0.0,2.0,4.02,IDC,84.633333,Living,Positive,Yes,83.52,Not Recurred,Female,ER+/HER2- High Prolif,10.0,1.0,Living
2,MB-0005,48.87,Mastectomy,Breast Cancer,Breast Invasive Ductal Carcinoma,High,Yes,LumB,1.0,Positve,Positive,2.0,Neutral,Negative,Ductal/NST,Yes,Pre,3,Right,1.0,2.0,4.03,IDC,163.7,Deceased,Positive,No,151.28,Recurred,Female,,15.0,2.0,Died of Disease
3,MB-0006,47.68,Mastectomy,Breast Cancer,Breast Mixed Ductal and Lobular Carcinoma,Moderate,Yes,LumB,1.0,Positve,Positive,2.0,Neutral,Negative,Mixed,Yes,Pre,9,Right,3.0,1.0,4.05,MDLC,164.933333,Living,Positive,Yes,162.76,Not Recurred,Female,,25.0,2.0,Living
4,MB-0008,76.97,Mastectomy,Breast Cancer,Breast Mixed Ductal and Lobular Carcinoma,High,Yes,LumB,1.0,Positve,Positive,3.0,Neutral,Negative,Mixed,Yes,Post,9,Right,8.0,2.0,6.08,MDLC,41.366667,Deceased,Positive,Yes,18.55,Recurred,Female,ER+/HER2- High Prolif,40.0,2.0,Died of Disease


In [6]:
# 3.2 Define Survival Variables and Validate Outcomes

# Standardize survival variable names
# Primary outcome: Overall Survival (OS)
OS_DURATION_COL = "Overall Survival (Months)"
OS_STATUS_COL = "Overall Survival Status"
OS_PATIENT_ID = "Patient ID"

# Check if required columns exist
required_cols = [OS_DURATION_COL, OS_STATUS_COL, OS_PATIENT_ID]
missing_cols = [col for col in required_cols if col not in df_raw.columns]
if missing_cols:
    raise ValueError(f"Missing required columns: {missing_cols}")

# Create standardized survival variables
df = df_raw.copy()

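# Convert survival status to binary event indicator
# "Deceased" = 1 (event occurred), "Living" = 0 (censored)
# Note: a missing status also evaluates to 0 here, so event_os is never NaN;
# rows without a known outcome must be dropped before modeling (see the cleaning step in 3.3)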
print("Overall Survival Status value counts:")
print(df[OS_STATUS_COL].value_counts())

# Create event indicator
df["event_os"] = (df[OS_STATUS_COL] == "Deceased").astype(int)
df["duration_os"] = df[OS_DURATION_COL].astype(float)

# Secondary outcome: Relapse-Free Survival (if available)
if "Relapse Free Status" in df.columns and "Relapse Free Status (Months)" in df.columns:
    RFS_STATUS_COL = "Relapse Free Status"
    RFS_DURATION_COL = "Relapse Free Status (Months)"
    df["event_rfs"] = (df[RFS_STATUS_COL] == "Recurred").astype(int)
    df["duration_rfs"] = df[RFS_DURATION_COL].astype(float)
    print("\nRelapse Free Status value counts:")
    print(df[RFS_STATUS_COL].value_counts())

print(f"\nSurvival variables created:")
print(f"- event_os: {df['event_os'].sum()} events out of {len(df)} patients ({df['event_os'].mean()*100:.1f}%)")
if "event_rfs" in df.columns:
    print(f"- event_rfs: {df['event_rfs'].sum()} events out of {df['duration_rfs'].notna().sum()} patients")


Overall Survival Status value counts:
Overall Survival Status
Deceased    1144
Living       837
Name: count, dtype: int64

Relapse Free Status value counts:
Relapse Free Status
Not Recurred    1486
Recurred        1002
Name: count, dtype: int64

Survival variables created:
- event_os: 1144 events out of 2509 patients (45.6%)
- event_rfs: 1002 events out of 2388 patients


In [None]:
# 3.3 Data Cleaning and Cohort Definition

# Keep track of original size
n_original = len(df)

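# Remove rows with missing survival outcomes (cannot analyze without outcome)
# Check the raw status column as well: event_os is never NaN because a missing status was coded as 0
print("Removing rows with missing survival outcomes...")
df = df.dropna(subset=["duration_os", OS_STATUS_COL]).copy()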
print(f"  Removed {n_original - len(df)} rows with missing outcomes")

# Remove rows with non-positive survival times (must be > 0)
n_before = len(df)
df = df[df["duration_os"] > 0].copy()
if len(df) < n_before:
    print(f"  Removed {n_before - len(df)} rows with non-positive survival times")

# Validate event indicator is binary
assert df["event_os"].isin([0, 1]).all(), "Event indicator must be 0 or 1"

# Check for duplicate patient IDs
n_duplicates = df[OS_PATIENT_ID].duplicated().sum()
if n_duplicates > 0:
    print(f"Warning: Found {n_duplicates} duplicate patient IDs")
    df = df.drop_duplicates(subset=[OS_PATIENT_ID], keep='first')
    print(f"  Kept first occurrence for each patient")

print(f"\nFinal cohort size: {len(df)} patients")
print(f"Event rate: {df['event_os'].mean()*100:.1f}% ({df['event_os'].sum()} events)")

# Display summary statistics
print("\n" + "="*60)
print("SURVIVAL OUTCOME SUMMARY")
print("="*60)
print(f"Duration (months) - Median: {df['duration_os'].median():.1f}, Mean: {df['duration_os'].mean():.1f}")
print(f"Duration (months) - Min: {df['duration_os'].min():.1f}, Max: {df['duration_os'].max():.1f}")
print(f"IQR: [{df['duration_os'].quantile(0.25):.1f}, {df['duration_os'].quantile(0.75):.1f}]")
print(f"Events: {df['event_os'].sum()} ({df['event_os'].mean()*100:.1f}%)")
print(f"Censored: {(df['event_os']==0).sum()} ({(df['event_os']==0).mean()*100:.1f}%)")


Removing rows with missing survival outcomes...
  Removed 528 rows with missing outcomes
  Removed 1 rows with non-positive survival times

Final cohort size: 1980 patients
Event rate: 57.8% (1144 events)

SURVIVAL OUTCOME SUMMARY
Duration (months) - Median: 116.5, Mean: 125.3
Duration (months) - Min: 0.1, Max: 355.2
IQR: [60.9, 185.1]
Events: 1144 (57.8%)
Censored: 836 (42.2%)


# 4. Single Consolidated Pipeline with Leakage Controls and Group-Aware Split
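
Because each patient contributes a single row after de-duplication, the split in this section is stratified on the event indicator rather than grouped by patient. The sketch below shows the 60/20/20 arithmetic: 20% is held out for testing, then 25% of the remaining 80% becomes the validation set. It is a simplified per-class shuffle in plain Python, not scikit-learn's `StratifiedShuffleSplit` algorithm, and the function name `stratified_three_way_split` is an assumption made for illustration.

```python
import random
from collections import defaultdict
from typing import Hashable, List, Sequence, Tuple


def stratified_three_way_split(
    labels: Sequence[Hashable],
    seed: int = 42,
    test_frac: float = 0.20,
    val_frac_of_rest: float = 0.25,
) -> Tuple[List[int], List[int], List[int]]:
    """Return (train, val, test) row positions with per-class proportions preserved."""
    rng = random.Random(seed)
    by_class = defaultdict(list)
    for position, label in enumerate(labels):
        by_class[label].append(position)

    train, val, test = [], [], []
    for label in sorted(by_class):
        positions = by_class[label]
        rng.shuffle(positions)
        n_test = round(len(positions) * test_frac)
        # 25% of the remaining 80% gives 20% of the whole class
        n_val = round((len(positions) - n_test) * val_frac_of_rest)
        test.extend(positions[:n_test])
        val.extend(positions[n_test:n_test + n_val])
        train.extend(positions[n_test + n_val:])
    return train, val, test


# Toy labels: 60 events and 40 censored observations
toy_labels = [1] * 60 + [0] * 40
train_pos, val_pos, test_pos = stratified_three_way_split(toy_labels, seed=42)
print(len(train_pos), len(val_pos), len(test_pos))
```

If a dataset did contain several rows per patient, the same idea would have to be applied to patient IDs instead of rows so that no patient appears in more than one split.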


### 4.1 Leakage Control Strategy

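**Leakage controls implemented:** outcome columns, their derived copies, and the patient identifier are removed from the feature matrix before any model sees the data.

**Variables to EXCLUDE:**
- `Patient ID` (identifier; kept separately in case grouping is needed)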
- `Overall Survival (Months)` (direct outcome)
- `Overall Survival Status` (direct outcome)
- `Relapse Free Status (Months)` (if using OS as outcome)
- `Relapse Free Status` (if using OS as outcome)
- `Patient's Vital Status` (likely redundant with OS status)
- Any variables derived from outcomes or measured post-treatment
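
The exclusion list above can be applied with a small name-based filter. The following is a minimal sketch in plain Python; the helper `split_feature_columns` and the case-insensitive matching are illustrative assumptions rather than the notebook's own implementation.

```python
from typing import Iterable, List, Tuple

# Column names taken from the exclusion list above
EXCLUDED_COLUMNS = (
    "Patient ID",
    "Overall Survival (Months)",
    "Overall Survival Status",
    "Relapse Free Status (Months)",
    "Relapse Free Status",
    "Patient's Vital Status",
)


def split_feature_columns(
    columns: Iterable[str],
    excluded: Iterable[str] = EXCLUDED_COLUMNS,
) -> Tuple[List[str], List[str]]:
    """Return (kept, dropped) column names, matching case-insensitively."""
    excluded_lower = {name.strip().lower() for name in excluded}
    kept, dropped = [], []
    for col in columns:
        target = dropped if col.strip().lower() in excluded_lower else kept
        target.append(col)
    return kept, dropped


example_columns = ["Age at Diagnosis", "Tumor Size", "Overall Survival Status", "patient id"]
kept_cols, dropped_cols = split_feature_columns(example_columns)
print("kept:   ", kept_cols)
print("dropped:", dropped_cols)
```

A name-based filter only catches columns that are listed explicitly; variables derived from the outcome under a different name still need manual review.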


In [None]:
# 4.1 Define Columns to Drop (Leakage Control)

# Columns that directly contain outcome information (LEAKAGE)
# Original column names from the dataset
OUTCOME_COLS = [
    "Overall Survival (Months)",
    "Overall Survival Status", 
    "Relapse Free Status (Months)",
    "Relapse Free Status",
    "Patient's Vital Status"
]

# Standardized outcome columns we created (also need to be excluded!)
STANDARDIZED_OUTCOME_COLS = [
    "duration_os",      # Duration for Overall Survival - OUTCOME!
    "duration_rfs",     # Duration for Relapse-Free Survival - OUTCOME!
    "event_os",         # Event indicator for OS - OUTCOME! (but we keep it in y)
    "event_rfs"         # Event indicator for RFS - OUTCOME!
]

# Identifier column (keep for grouping but exclude from features)
ID_COL = "Patient ID"

# Combine all columns to drop (original + standardized outcome columns)
cols_to_drop = []
cols_to_drop.extend([col for col in OUTCOME_COLS + [ID_COL] if col in df.columns])
cols_to_drop.extend([col for col in STANDARDIZED_OUTCOME_COLS if col in df.columns])

print("Columns to exclude from features (leakage control):")
print("\nOriginal outcome columns:")
for col in OUTCOME_COLS:
    if col in df.columns:
        print(f"  - {col}")

print("\nStandardized outcome columns (created in cell 7):")
for col in STANDARDIZED_OUTCOME_COLS:
    if col in df.columns:
        print(f"  - {col}")

if ID_COL in df.columns:
    print(f"\nIdentifier column:")
    print(f"  - {ID_COL}")

# Create X (features) and y (outcome) dataframes
# Note: event_os and duration_os are kept in y but removed from X
X = df.drop(columns=cols_to_drop, errors='ignore').copy()
y = df[["duration_os", "event_os"]].copy()

print(f"\n✓ Removed {len(cols_to_drop)} outcome/identifier columns from features")

# Also store patient IDs for group-aware splitting if needed
patient_ids = df[ID_COL].values if ID_COL in df.columns else None

print(f"\nFeature matrix shape: {X.shape}")
print(f"Outcome matrix shape: {y.shape}")



Columns to exclude from features (leakage control):

Original outcome columns:
  - Overall Survival (Months)
  - Overall Survival Status
  - Relapse Free Status (Months)
  - Relapse Free Status
  - Patient's Vital Status

Standardized outcome columns (created in cell 7):
  - duration_os
  - duration_rfs
  - event_os
  - event_rfs

Identifier column:
  - Patient ID

✓ Removed 10 outcome/identifier columns from features

Feature matrix shape: (1980, 28)
Outcome matrix shape: (1980, 2)


In [None]:
# 4.2 Create stratified train, validation, and test splits re-used across all models

def make_splits(X: pd.DataFrame, y: pd.Series, seed: int = 42):
    """
    Create stratified train/validation/test splits (60/20/20).
    
    Parameters:
    -----------
    X : pd.DataFrame
        Feature matrix
    y : pd.Series
        Outcome variable for stratification (e.g., event indicator)
    seed : int
        Random seed for reproducibility
    
    Returns:
    --------
    idx_train, idx_val, idx_test : array-like
        Indices for train, validation, and test sets
    """
    # First split off test 20 percent
    sss1 = StratifiedShuffleSplit(n_splits=1, test_size=0.20, random_state=seed)
    train_val_idx, test_idx = next(sss1.split(X, y))
    X_train_val, X_test = X.iloc[train_val_idx], X.iloc[test_idx]
    y_train_val, y_test = y.iloc[train_val_idx], y.iloc[test_idx]

    # Split train vs validation 75:25 within the remaining 80 percent to yield 60:20:20 overall
    sss2 = StratifiedShuffleSplit(n_splits=1, test_size=0.25, random_state=seed)
    train_idx, val_idx = next(sss2.split(X_train_val, y_train_val))

    idx_train = X_train_val.index[train_idx]
    idx_val = X_train_val.index[val_idx]
    idx_test = X_test.index

    return idx_train, idx_val, idx_test

# Create splits using event_os for stratification
idx_train, idx_val, idx_test = make_splits(X, y["event_os"], seed=SEED)

print("Split sizes",
      "train", len(idx_train),
      "val", len(idx_val),
      "test", len(idx_test))

# Materialize split datasets
X_train, X_val, X_test = X.loc[idx_train], X.loc[idx_val], X.loc[idx_test]
y_train, y_val, y_test = y.loc[idx_train], y.loc[idx_val], y.loc[idx_test]

# Display split information
print("\n" + "="*60)
print("SPLIT SUMMARY")
print("="*60)
print(f"Train set:      {len(X_train):4d} samples ({len(X_train)/len(X)*100:5.1f}%)")
print(f"Validation set: {len(X_val):4d} samples ({len(X_val)/len(X)*100:5.1f}%)")
print(f"Test set:       {len(X_test):4d} samples ({len(X_test)/len(X)*100:5.1f}%)")
print(f"Total:          {len(X):4d} samples")

# Check event rates across splits
print("\n" + "="*60)
print("EVENT RATES ACROSS SPLITS (should be similar)")
print("="*60)
event_rate_train = y_train["event_os"].mean()
event_rate_val = y_val["event_os"].mean()
event_rate_test = y_test["event_os"].mean()
print(f"Train:      {event_rate_train*100:5.1f}% ({y_train['event_os'].sum():3d} events)")
print(f"Validation: {event_rate_val*100:5.1f}% ({y_val['event_os'].sum():3d} events)")
print(f"Test:       {event_rate_test*100:5.1f}% ({y_test['event_os'].sum():3d} events)")


Split sizes train 1188 val 396 test 396

SPLIT SUMMARY
Train set:      1188 samples ( 60.0%)
Validation set:  396 samples ( 20.0%)
Test set:        396 samples ( 20.0%)
Total:          1980 samples

EVENT RATES ACROSS SPLITS (should be similar)
Train:       57.7% (686 events)
Validation:  57.8% (229 events)
Test:        57.8% (229 events)


### 4.3 Feature Type Detection and Preprocessing Pipeline

1. **Feature Type Detection**
   - Automatically identifies which features are categorical (discrete categories) vs numeric (continuous numbers)
   - Categorical features: text/object columns, plus numeric columns with at most 10 unique values (e.g. grade, stage, cohort)
   - Numeric features: continuous variables like age, tumor size, mutation count

2. **Preprocessing Pipeline Construction**
   - Creates separate transformation pipelines for categorical and numeric features
   - **Numeric pipeline**: Imputes missing values with median → Standardizes (z-score normalization)
   - **Categorical pipeline**: Imputes missing values with most frequent category → One-hot encodes

3. **Leakage Prevention**
   - Pipeline is fitted **ONLY** on training data
   - Validation and test sets are transformed using parameters learned from training
   - This ensures realistic performance estimates and prevents data leakage
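
The fit-on-train-only rule can be illustrated with a small plain-Python sketch of the numeric branch (median imputation followed by z-scoring). The class name `MedianImputeScaler` is hypothetical, and the code is a simplified stand-in for `SimpleImputer` plus `StandardScaler`, assuming the population standard deviation is used for scaling.

```python
import math
import statistics
from typing import List, Optional, Sequence


def _is_missing(value: Optional[float]) -> bool:
    """Treat None and NaN as missing."""
    return value is None or (isinstance(value, float) and math.isnan(value))


class MedianImputeScaler:
    """Learn median, mean, and standard deviation from training values only."""

    def fit(self, values: Sequence[Optional[float]]) -> "MedianImputeScaler":
        observed = [v for v in values if not _is_missing(v)]
        self.median_ = statistics.median(observed)
        filled = [self.median_ if _is_missing(v) else v for v in values]
        self.mean_ = statistics.fmean(filled)
        # Population standard deviation; fall back to 1.0 for constant columns
        self.std_ = statistics.pstdev(filled) or 1.0
        return self

    def transform(self, values: Sequence[Optional[float]]) -> List[float]:
        # Reuse the training statistics; nothing is re-estimated here
        filled = [self.median_ if _is_missing(v) else v for v in values]
        return [(v - self.mean_) / self.std_ for v in filled]


train_values = [1.0, 2.0, None, 3.0]
test_values = [None, 4.0]

scaler = MedianImputeScaler().fit(train_values)
scaled_test = scaler.transform(test_values)
print(scaler.median_, scaler.mean_, round(scaler.std_, 4))
print([round(v, 4) for v in scaled_test])
```

The missing test value is filled with the training median and scaled with the training mean and standard deviation, so no statistic is ever computed from validation or test rows.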


In [None]:
# 4.3 Feature Type Detection and Preprocessing Pipeline

def detect_feature_types(X):
    """Split the columns of X into (categorical, numeric) name lists by dtype and cardinality."""
    cat_cols = []
    num_cols = []
    
    for col in X.columns:
        if X[col].dtype == 'object':
            cat_cols.append(col)
        elif X[col].dtype in ['int64', 'float64']:
            # Treat low-cardinality numeric columns (e.g. grade, stage) as categorical:
            # at most 10 distinct values, and fewer distinct values than 5% of the rows
            n_unique = X[col].nunique()
            if n_unique <= 10 and n_unique < len(X) * 0.05:
                cat_cols.append(col)
            else:
                num_cols.append(col)
        else:
            # Default to numeric for other types
            num_cols.append(col)
    
    return cat_cols, num_cols


# Detect feature types from training data
cat_cols, num_cols = detect_feature_types(X_train)

print("="*60)
print("FEATURE TYPE DETECTION")
print("="*60)
print(f"Categorical features: {len(cat_cols)}")
print(f"Numeric features:     {len(num_cols)}")

if len(cat_cols) > 0:
    print(f"\nCategorical features (first 10):")
    for col in cat_cols[:10]:
        n_unique = X_train[col].nunique()
        n_missing = X_train[col].isnull().sum()
        print(f"  {col:40s} - {n_unique:3d} unique values, {n_missing:4d} missing")

if len(num_cols) > 0:
    print(f"\nNumeric features (first 10):")
    for col in num_cols[:10]:
        n_missing = X_train[col].isnull().sum()
        mean_val = X_train[col].mean()
        print(f"  {col:40s} - {n_missing:4d} missing, mean={mean_val:.2f}")

# 4.4 Build Consolidated Preprocessing Pipeline

# Numeric feature pipeline
# - Impute missing values with median (fit on train only)
# - Standardize to zero mean and unit variance (fit on train only)
numeric_transformer = Pipeline(steps=[
    ('imputer', SimpleImputer(strategy='median')),
    ('scaler', StandardScaler())
])

# Categorical feature pipeline
# - Impute missing values with most frequent category (fit on train only)
# - One-hot encode (fit on train only, ignore unknown categories at transform time)
categorical_transformer = Pipeline(steps=[
    ('imputer', SimpleImputer(strategy='most_frequent')),
    ('onehot', OneHotEncoder(drop='if_binary', handle_unknown='ignore', sparse_output=False))
])

# Combined preprocessor
# This ensures all transformations are fit on training data only
# and applied consistently to validation and test sets
preprocessor = ColumnTransformer(
    transformers=[
        ('num', numeric_transformer, num_cols),
        ('cat', categorical_transformer, cat_cols)
    ],
    remainder='drop',  # Drop any columns not explicitly handled
    verbose_feature_names_out=False  # Keep original feature names where possible
)

print("\n" + "="*60)
print("PREPROCESSING PIPELINE BUILT")
print("="*60)
print("\nPipeline components:")
print("  1. Numeric features:")
print("     - Median imputation")
print("     - Standardization (z-score)")
print(f"     - {len(num_cols)} features")
print("\n  2. Categorical features:")
print("     - Most frequent imputation")
print("     - One-hot encoding")
print(f"     - {len(cat_cols)} features")

# Fit preprocessor on training data only
print("\n" + "="*60)
print("FITTING PREPROCESSOR ON TRAINING DATA")
print("="*60)
preprocessor.fit(X_train)

# Transform all splits
print("\nTransforming datasets...")
X_train_transformed = preprocessor.transform(X_train)
X_val_transformed = preprocessor.transform(X_val)
X_test_transformed = preprocessor.transform(X_test)

print(f"\nTransformed feature dimensions:")
print(f"  Train:      {X_train_transformed.shape}")
print(f"  Validation: {X_val_transformed.shape}")
print(f"  Test:       {X_test_transformed.shape}")

# Create transformed DataFrames for convenience
# Convert to DataFrame if not already (when set_config is used)
if hasattr(X_train_transformed, 'columns'):
    Xt_train = X_train_transformed
    Xt_val = X_val_transformed
    Xt_test = X_test_transformed
else:
    # Convert numpy arrays to DataFrames with feature names
    try:
        feature_names = preprocessor.get_feature_names_out()
        Xt_train = pd.DataFrame(X_train_transformed, 
                                index=X_train.index, 
                                columns=feature_names)
        Xt_val = pd.DataFrame(X_val_transformed, 
                             index=X_val.index, 
                             columns=feature_names)
        Xt_test = pd.DataFrame(X_test_transformed, 
                               index=X_test.index, 
                               columns=feature_names)
    except Exception:
        # Fallback if feature names are not available
        Xt_train = pd.DataFrame(X_train_transformed, index=X_train.index)
        Xt_val = pd.DataFrame(X_val_transformed, index=X_val.index)
        Xt_test = pd.DataFrame(X_test_transformed, index=X_test.index)

print("\n" + "="*60)
print("TRANSFORMED DATA SUMMARY")
print("="*60)
print(f"Transformed DataFrames created:")
print(f"  Xt_train: {Xt_train.shape}")
print(f"  Xt_val:   {Xt_val.shape}")
print(f"  Xt_test:  {Xt_test.shape}")

# Quick leakage sanity check: confirm no outcome columns survived
print("\n" + "="*60)
print("LEAKAGE CHECK: Verifying no outcome columns in features")
print("="*60)
leakage_terms = [
    "overall survival",
    "relapse free status",
    "patient's vital status",
    "event_os",
    "duration_os",
    "event_rfs",
    "duration_rfs",
    "patient id"
]

leakage_found = []
for term in leakage_terms:
    matching_cols = [c for c in Xt_train.columns if term.lower() in str(c).lower()]
    if matching_cols:
        leakage_found.extend(matching_cols)

if leakage_found:
    print(f"⚠ WARNING: Found potential leakage columns: {leakage_found}")
else:
    print("✓ No leakage detected - all outcome-related columns properly excluded")

# Additional assertions
assert not any("overall survival" in str(c).lower() for c in Xt_train.columns), \
    "Leakage detected: Overall Survival column found in features"
assert not any("relapse free status" in str(c).lower() for c in Xt_train.columns), \
    "Leakage detected: Relapse Free Status column found in features"
assert not any("patient id" in str(c).lower() for c in Xt_train.columns), \
    "Leakage detected: Patient ID column found in features"

print("✓ All leakage checks passed!")
print("\n✓ Preprocessing pipeline completed successfully!")
print("  All transformations were fit on training data only")
print("  Validation and test sets transformed using training fit parameters")
print("  This prevents data leakage from validation/test into training")


FEATURE TYPE DETECTION
Categorical features: 23
Numeric features:     5

Categorical features (first 10):
  Type of Breast Surgery                   -   2 unique values,   16 missing
  Cancer Type                              -   2 unique values,    0 missing
  Cancer Type Detailed                     -   8 unique values,    0 missing
  Cellularity                              -   3 unique values,   36 missing
  Chemotherapy                             -   2 unique values,    1 missing
  Pam50 + Claudin-low subtype              -   7 unique values,    1 missing
  Cohort                                   -   5 unique values,    0 missing
  ER status measured by IHC                -   2 unique values,   24 missing
  ER Status                                -   2 unique values,    0 missing
  Neoplasm Histologic Grade                -   3 unique values,   50 missing

Numeric features (first 10):
  Age at Diagnosis                         -    0 missing, mean=60.99
  Lymph nodes examined p