In [3]:
import sys
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms
from sklearn.model_selection import train_test_split
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sksurv.metrics import concordance_index_censored
import torchsurv.loss
import torchvision.models as models

sys.path.append(os.path.abspath('../src'))
from dataset import TriModalDataset
from model import WideAndDeepSurvivalModel

# --- CONFIGURATION ---
# Use the Platinum Cohort
PARQUET_PATH = '../data/processed/OAI_platinum_cohort.parquet' 
IMAGE_ROOT = '../data/sandbox'

# Tuning for Small Dataset (283 samples)
# We increase epochs because dataset is small (it sees less data per epoch)
# We decrease batch size to update weights more often
BATCH_SIZE = 16 
EPOCHS = 50 
LR = 2e-4 # Slightly higher LR for small batches

DEVICE = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print(f"Running Platinum Ablation on {DEVICE}")

# --- MODEL DEFINITIONS ---
# Clinical Input Dim = 11 (8 Basic + 3 Advanced)

class ClinicalOnlyModel(nn.Module):
    def __init__(self, input_dim):
        super().__init__()
        self.network = nn.Sequential(
            nn.Linear(input_dim, 32),
            nn.BatchNorm1d(32),
            nn.ReLU(),
            nn.Dropout(0.4), # Higher dropout for small data
            nn.Linear(32, 16),
            nn.BatchNorm1d(16),
            nn.ReLU(),
            nn.Linear(16, 1)
        )
    def forward(self, clin, img=None, bio=None):
        return self.network(clin)

class BiModalModel(nn.Module):
    def __init__(self, input_dim):
        super().__init__()
        # Image
        resnet = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
        self.img_enc = nn.Sequential(*list(resnet.children())[:-1])
        # Clinical
        self.clin_enc = nn.Sequential(
            nn.Linear(input_dim, 32),
            nn.BatchNorm1d(32),
            nn.ReLU()
        )
        # Fusion
        self.fusion = nn.Sequential(
            nn.Linear(512 + 32, 32),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.Linear(32, 1)
        )
    def forward(self, clin, img, bio=None):
        i = self.img_enc(img).view(img.size(0), -1)
        c = self.clin_enc(clin)
        return self.fusion(torch.cat((i, c), dim=1))

# --- DATA LOADING ---
df = pd.read_parquet(PARQUET_PATH)
print(f"Loaded Platinum Data: {df.shape}")

# Fix One-Hot Encoding for KL Grade (Ensure 1.0, 2.0 etc exist)
df['KL_Grade'] = df['KL_Grade'].astype(float) # Ensure float for consistency
df = pd.get_dummies(df, columns=['KL_Grade', 'Sex'], drop_first=True)

# Ensure all expected columns exist (fill with 0 if missing)
for i in [1.0, 2.0, 3.0, 4.0]:
    if f'KL_Grade_{i}' not in df.columns: df[f'KL_Grade_{i}'] = 0
if 'Sex_2' not in df.columns: df['Sex_2'] = 0

# Split
train_df, test_df = train_test_split(df, test_size=0.3, random_state=42) # 30% test for small data

# Transforms
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

train_ds = TriModalDataset(train_df, IMAGE_ROOT, transform=transform, mode='sandbox')
test_ds = TriModalDataset(test_df, IMAGE_ROOT, transform=transform, mode='sandbox')
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False)

# --- TRAINING ENGINE ---
def cox_loss(risk, events, times):
    # MPS-Safe Cox Loss
    risk_cpu = risk.cpu()
    log_cumsum = torch.logcumsumexp(risk_cpu, dim=0)
    log_cumsum = log_cumsum.to(risk.device)
    
    if events.sum() > 0:
        return -torch.sum(events * (risk - log_cumsum)) / events.sum()
    return torch.tensor(0.0, requires_grad=True).to(risk.device)

def train_eval(name, model):
    print(f"\nTraining {name}...")
    optimizer = optim.Adam(model.parameters(), lr=LR, weight_decay=1e-4)
    model.to(DEVICE)
    
    for ep in range(EPOCHS):
        model.train()
        for img, clin, bio, evt, time in train_loader:
            if evt.sum() == 0: continue
            img, clin, bio = img.to(DEVICE), clin.to(DEVICE), bio.to(DEVICE)
            evt, time = evt.to(DEVICE), time.to(DEVICE)
            
            optimizer.zero_grad()
            # Route inputs based on model type
            if "Clinical" in name and "Image" not in name: out = model(clin)
            elif "Bi" in name: out = model(clin, img)
            else: out = model(img, clin, bio)
            
            loss = cox_loss(out.squeeze(), evt, time)
            loss.backward()
            optimizer.step()
            
    # Eval
    model.eval()
    risks, evts, times = [], [], []
    with torch.no_grad():
        for img, clin, bio, evt, time in test_loader:
            img, clin, bio = img.to(DEVICE), clin.to(DEVICE), bio.to(DEVICE)
            if "Clinical" in name and "Image" not in name: out = model(clin)
            elif "Bi" in name: out = model(clin, img)
            else: out = model(img, clin, bio)
            
            risks.extend(out.squeeze().cpu().numpy())
            evts.extend(evt.numpy().astype(bool))
            times.extend(time.numpy())
            
    try:
        c = concordance_index_censored(np.array(evts), np.array(times), np.array(risks))[0]
        print(f"   >> C-Index: {c:.4f}")
        return c
    except:
        return 0.5

# --- RUN ABLATION ---
# 11 Inputs = 8 Basic + 3 Platinum
res_a = train_eval("Model A (Clinical)", ClinicalOnlyModel(11))
res_b = train_eval("Model B (Bi-Modal)", BiModalModel(11))
res_c = train_eval("Model C (Tri-Modal)", WideAndDeepSurvivalModel(wide_input_dim=11, bio_input_dim=5))

# Plot
res_df = pd.DataFrame({
    'Model': ['Clinical', 'Bi-Modal', 'Tri-Modal'],
    'C-Index': [res_a, res_b, res_c]
})
print("\nFinal Results:")
print(res_df)

plt.figure(figsize=(6,4))
sns.barplot(data=res_df, x='Model', y='C-Index', palette='Blues')
plt.axhline(0.5, color='red', linestyle='--')
plt.ylim(0.4, 1.0)
plt.title("Platinum Cohort Ablation Results")
plt.show()

Running Platinum Ablation on mps
Loaded Platinum Data: (283, 17)

Training Model A (Clinical)...
   >> C-Index: 0.4612

Training Model B (Bi-Modal)...


KeyboardInterrupt: 