In [2]:
# --- Local Tri-Modal Training Script ---
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms
from sklearn.model_selection import train_test_split
import pandas as pd
import sys
import os
import numpy as np

# Add src to path
sys.path.append(os.path.abspath('../src'))
from dataset import TriModalDataset
from model import WideAndDeepSurvivalModel
import torchsurv.loss

# Config
# Compute device: use the MPS backend when this PyTorch build reports it as
# available, otherwise stay on the CPU.
if torch.backends.mps.is_available():
    DEVICE = torch.device("mps")
else:
    DEVICE = torch.device("cpu")

# Input locations (relative paths, resolved against the working directory).
PARQUET_PATH = '../data/processed/OAI_tri_modal_real.parquet'
IMAGE_ROOT = '../data/sandbox'

# Training hyperparameters: samples per batch and passes over the training set.
BATCH_SIZE, EPOCHS = 32, 5

# 1. Load Data
df = pd.read_parquet(PARQUET_PATH)

# One-hot encode the categorical clinical features. drop_first=True omits the
# first level of each column, so that level acts as the reference category.
categorical_cols = ['KL_Grade', 'Sex']
df = pd.get_dummies(df, columns=categorical_cols, drop_first=True)

# Make sure every expected indicator column exists even when a level never
# occurs in this file: an absent indicator is added as an all-zero column.
expected_dummies = ('KL_Grade_1.0', 'KL_Grade_2.0', 'KL_Grade_3.0', 'KL_Grade_4.0', 'Sex_2')
for dummy_name in expected_dummies:
    if dummy_name in df.columns:
        continue
    df[dummy_name] = 0

# Hold out 20% of the rows; the fixed seed makes the split reproducible.
train_df, val_df = train_test_split(df, test_size=0.2, random_state=42)

# 2. Transforms
# Per-channel normalization statistics for three-channel input (these are the
# values commonly quoted as the ImageNet mean and standard deviation).
channel_mean = [0.485, 0.456, 0.406]
channel_std = [0.229, 0.224, 0.225]

# Applied in order: resize to 224x224, convert to a tensor, normalize channels.
preprocessing_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=channel_mean, std=channel_std),
]
train_transform = transforms.Compose(preprocessing_steps)

# 3. Dataset
# Wrap the training split in the project dataset, then batch it with the
# sample order reshuffled on every epoch.
train_dataset = TriModalDataset(
    train_df,
    IMAGE_ROOT,
    transform=train_transform,
    mode='sandbox',
)
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
)

# 4. Model
# Build the network, move it onto the selected device, and let Adam optimize
# all of its parameters with a learning rate of 1e-4.
model = WideAndDeepSurvivalModel(wide_input_dim=8, bio_input_dim=5)
model = model.to(DEVICE)
optimizer = optim.Adam(params=model.parameters(), lr=1e-4)

# Loss function selection
def _fallback_cox_loss(risk, events, times):
    """Return the negative Cox partial log-likelihood, averaged over events.

    Samples are sorted by descending time so that a running log-sum-exp over
    the sorted risks gives, at each position, the log-sum of the risks of all
    samples sorted at or before it (i.e. with an equal or later time).

    Args:
        risk: Predicted risk scores, one per sample (any shape; flattened).
        events: Event indicators, non-zero/True where the event was observed
            (same number of elements as ``risk``; bool or numeric).
        times: Observed times (same number of elements as ``risk``).

    Returns:
        A scalar tensor on ``risk``'s device. When the batch contains no
        observed event the result is a zero that is still connected to
        ``risk``'s autograd graph, so ``backward()`` yields zero gradients.

    NOTE: samples with tied times are processed in sort order and do not
    share a common risk set.
    """
    # Flatten first: a single-sample batch arrives 0-dim after ``squeeze()``
    # and a 0-dim tensor cannot be indexed with the sort order below.
    risk = risk.reshape(-1)
    events = events.reshape(-1).to(risk.dtype)
    times = times.reshape(-1)

    n_events = events.sum()
    if n_events <= 0:
        # Nothing to learn from this batch; stay attached to the graph.
        return risk.sum() * 0.0

    # Sort by time (descending)
    order = torch.argsort(times, descending=True)
    risk = risk[order]
    events = events[order]

    # --- MPS FIX: compute logcumsumexp on CPU, then move the result back to
    # the device the risks live on ---
    log_cumsum = torch.logcumsumexp(risk.cpu(), dim=0).to(risk.device)

    return -torch.sum(events * (risk - log_cumsum)) / n_events


try:
    _torchsurv_cox = torchsurv.loss.cox
except (AttributeError, NameError):
    # AttributeError: this torchsurv build exposes no ``loss.cox`` attribute.
    # NameError: torchsurv is not imported in the current session.
    _torchsurv_cox = None

# ``torchsurv.loss.cox`` is only usable here if it is callable. torchsurv
# documents ``cox`` as a module (its loss is ``neg_partial_log_likelihood``),
# and binding a module would make the first ``cox_loss(...)`` call fail with
# a TypeError, so fall back to the local implementation in that case.
cox_loss = _torchsurv_cox if callable(_torchsurv_cox) else _fallback_cox_loss

# 5. Train
# Runs EPOCHS passes over train_loader, taking one optimizer step per batch,
# and prints the mean per-batch loss after each pass.
print(f"Training Tri-Modal Model on {DEVICE}...")
model.train()

for epoch in range(EPOCHS):
    total_loss = 0.0
    for batch in train_loader:
        # Each batch unpacks into 5 items: image, clinical features,
        # biomarker features, event indicator, and time.
        img, clin, bio, event, time = batch
        img, clin, bio = img.to(DEVICE), clin.to(DEVICE), bio.to(DEVICE)
        event, time = event.to(DEVICE), time.to(DEVICE)

        optimizer.zero_grad()
        # reshape(-1) instead of squeeze(): squeeze() also removes the batch
        # dimension when the last batch holds a single sample, producing a
        # 0-dim tensor the loss cannot index.
        # Assumes the model emits one risk score per sample - TODO confirm.
        risk = model(img, clin, bio).reshape(-1)
        loss = cox_loss(risk, event, time)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    # Guard the average against an empty loader (would divide by zero).
    mean_loss = total_loss / max(len(train_loader), 1)
    print(f"Epoch {epoch+1}: Loss {mean_loss:.4f}")

# 6. Save
# Persist only the learned weights (state_dict), not the full model object.
SAVE_PATH = "../models/tri_modal_survival_model.pth"
# torch.save does not create missing parent directories; create the folder
# first so a finished training run is not lost to a missing "../models".
os.makedirs(os.path.dirname(SAVE_PATH), exist_ok=True)
torch.save(model.state_dict(), SAVE_PATH)
print(f"Saved updated model to {SAVE_PATH}")

Training Tri-Modal Model on mps...
Epoch 1: Loss 2.0794
Epoch 2: Loss 2.0327
Epoch 3: Loss 2.0060
Epoch 4: Loss 2.0245
Epoch 5: Loss 2.0284
Saved updated model to ../models/tri_modal_survival_model.pth
