In [2]:
# --- Local Tri-Modal Training Script (DenseNet + Platinum Data) ---
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms
from sklearn.model_selection import train_test_split
import pandas as pd
import sys
import os
import numpy as np
from tqdm.auto import tqdm

# Add src to path
sys.path.append(os.path.abspath('../src'))
from dataset import TriModalDataset
from model import WideAndDeepSurvivalModel
import torchsurv.loss

# --- CONFIGURATION ---
DEVICE = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
# Use the Platinum Cohort (High Quality)
PARQUET_PATH = '../data/processed/OAI_mega_cohort_imputed.parquet'
IMAGE_ROOT = '../data/sandbox'

# Platinum Cohort Dimensions: 
# 8 Basic (Age, BMI, WOMAC, Sex, KL*4) + 3 Advanced (KOOS, PASE, BML) = 11
CLINICAL_INPUT_DIM = 11 
BIOMARKER_INPUT_DIM = 5
BATCH_SIZE = 16
EPOCHS = 5

print(f"ðŸš€ Training Tri-Modal DenseNet on {DEVICE}")

# 1. Load Data
if not os.path.exists(PARQUET_PATH):
    raise FileNotFoundError(f"Platinum cohort not found at {PARQUET_PATH}. Please run src/build_platinum_cohort.py first.")

df = pd.read_parquet(PARQUET_PATH)

# One-Hot Encode Categorical Features
# We need to ensure all columns expected by dataset.py exist
df = pd.get_dummies(df, columns=['KL_Grade', 'Sex'], drop_first=True)

# Force-create missing columns with 0s (Robustness)
expected_cols = ['KL_Grade_1.0', 'KL_Grade_2.0', 'KL_Grade_3.0', 'KL_Grade_4.0', 'Sex_2']
for col in expected_cols:
    if col not in df.columns: 
        df[col] = 0

print(f"   Data Loaded. Shape: {df.shape}")

train_df, val_df = train_test_split(df, test_size=0.2, random_state=42)

# 2. Transforms (DenseNet expects 224x224)
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# 3. Dataset
train_dataset = TriModalDataset(train_df, IMAGE_ROOT, transform=train_transform, mode='sandbox')
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

# 4. Model Initialization
# UPDATED: wide_input_dim=11 to match Platinum features
model = WideAndDeepSurvivalModel(wide_input_dim=CLINICAL_INPUT_DIM, bio_input_dim=BIOMARKER_INPUT_DIM).to(DEVICE)
optimizer = optim.Adam(model.parameters(), lr=1e-4)

# Loss function check (MPS Safe)
try:
    base_cox = torchsurv.loss.cox
    def cox_loss(risk, events, times):
        # MPS Fix: Move to CPU for logcumsumexp calculation
        if risk.device.type == 'mps':
            return base_cox(risk.cpu(), events.cpu(), times.cpu()).to(DEVICE)
        return base_cox(risk, events, times)
except:
    # Custom fallback if library missing
    def cox_loss(risk, events, times):
        order = torch.argsort(times, descending=True)
        risk = risk[order]
        events = events[order]
        
        # CPU Fallback for cumulative sum
        risk_cpu = risk.cpu()
        log_cumsum = torch.logcumsumexp(risk_cpu, dim=0).to(risk.device)
        
        if events.sum() > 0:
            return -torch.sum(events * (risk - log_cumsum)) / events.sum()
        return torch.tensor(0.0, requires_grad=True).to(DEVICE)

# 5. Train Loop
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

print(f"   Starting Training Loop ({len(train_loader)} batches per epoch)...")
model.train()

for epoch in range(EPOCHS):
    total_loss = 0
    progress = tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}", leave=False)
    for batch_idx, batch in enumerate(progress):
        # Unpack 5 items
        img, clin, bio, event, time = batch
        
        img = img.to(DEVICE)
        clin = clin.to(DEVICE)
        bio = bio.to(DEVICE)
        event = event.to(DEVICE)
        time = time.to(DEVICE)
        
        optimizer.zero_grad()
        
        # Forward
        risk = model(img, clin, bio).squeeze()
        
        # Loss
        loss = cox_loss(risk, event, time)
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
        
        # Update progress bar with current loss
        progress.set_postfix({"loss": loss.item()})
        
    print(f"Epoch {epoch+1}: Average Loss {total_loss/len(train_loader):.4f}")

# 6. Save
SAVE_PATH = "../models/tri_modal_survival_model.pth"
if not os.path.exists('../models'): os.makedirs('../models')
torch.save(model.state_dict(), SAVE_PATH)
print(f"âœ… Saved DenseNet Tri-Modal model to {SAVE_PATH}")

ðŸš€ Training Tri-Modal DenseNet on mps
   Data Loaded. Shape: (3526, 28)
   Starting Training Loop (177 batches per epoch)...


Epoch 1/5:   0%|          | 0/177 [00:00<?, ?it/s]

Epoch 1: Average Loss 1.5369


Epoch 2/5:   0%|          | 0/177 [00:00<?, ?it/s]

Epoch 2: Average Loss 1.5605


Epoch 3/5:   0%|          | 0/177 [00:00<?, ?it/s]

Epoch 3: Average Loss 1.4963


Epoch 4/5:   0%|          | 0/177 [00:00<?, ?it/s]

Epoch 4: Average Loss 1.5181


Epoch 5/5:   0%|          | 0/177 [00:00<?, ?it/s]

Epoch 5: Average Loss 1.4780
âœ… Saved DenseNet Tri-Modal model to ../models/tri_modal_survival_model.pth
