In [None]:
# --- CELL 1: Setup, Sync & Install ---
import sys
import os

# 1. Clean up old repo to force a fresh clone
!rm -rf oa-survival-model

# 2. Clone your repo
from kaggle_secrets import UserSecretsClient
user_secrets = UserSecretsClient()
git_token = user_secrets.get_secret("GIT_TOKEN")
username = "AbhiGTM19"

!git clone https://{username}:{git_token}@github.com/{username}/oa-survival-model.git

# 3. Install Dependencies
# We force install the specific versions that work
%cd oa-survival-model
!pip install -r requirements.txt
!pip install torchsurv  # Ensure this is installed even if missing in requirements
%cd ..

# 4. Add source code to path
sys.path.append('/kaggle/working/oa-survival-model/src')

print("Environment Ready & Dependencies Installed.")

In [None]:
# --- CELL 2: Build & Impute Data (New Pipeline) ---
# This step generates the high-quality imputed dataset
%cd oa-survival-model
!python src/build_mega_cohort.py
!python src/impute_biomarkers.py
%cd ..

print("âœ… Data Built & Imputed Successfully.")

In [None]:
# --- CELL 3: Imports & Configuration ---
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms
from sklearn.model_selection import train_test_split
import torchsurv.loss
from torch.cuda.amp import GradScaler, autocast # Mixed Precision
import numpy as np
import os

# Import your custom modules
from model import WideAndDeepSurvivalModel, SemanticEncoder
from dataset import TriModalDataset

# Configuration
# Pick CUDA when it is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if DEVICE.type == "cuda":
    GPU_COUNT = torch.cuda.device_count()
    print(f"Training on: {GPU_COUNT} x NVIDIA GPU(s)")
else:
    GPU_COUNT = 0
    print("Training on: CPU")

# Paths
# The parquet file is the imputed cohort; IMAGE_ROOT is the image dataset root.
PARQUET_PATH = 'oa-survival-model/data/processed/OAI_mega_cohort_imputed.parquet'
IMAGE_ROOT = '/kaggle/input/knee-osteoarthritis-dataset-with-severity'

# Hyperparameters
# 32 samples per GPU; a CPU-only run uses a single multiple of 32.
BATCH_SIZE = max(1, GPU_COUNT) * 32
EPOCHS = 20
LEARNING_RATE = 1e-4

In [None]:
# --- CELL 4: Data Prep (Tri-Modal) ---

# 1. Load the cohort table
df = pd.read_parquet(PARQUET_PATH)
print(f"Loaded Cohort: {len(df)} patients")

# 2. Preprocessing
# One-hot encode KL_Grade, keeping every level (no reference level is dropped).
df = pd.get_dummies(df, columns=['KL_Grade'], drop_first=False)

# Add an all-zero indicator for any grade that does not occur in this cohort.
# NOTE(review): these names assume KL_Grade holds float values (0.0 .. 4.0); with
# another dtype get_dummies would emit differently named columns -- confirm.
expected_kl = ['KL_Grade_0.0', 'KL_Grade_1.0', 'KL_Grade_2.0', 'KL_Grade_3.0', 'KL_Grade_4.0']
for col in expected_kl:
    if col in df.columns:
        continue
    df[col] = 0

# Provide 'Sex_2' as a copy of 'Sex' when only the latter is present.
# NOTE(review): presumably dataset.py reads 'Sex_2' and 1 encodes female -- verify.
if 'Sex' in df.columns and 'Sex_2' not in df.columns:
    df['Sex_2'] = df['Sex']

# 3. 80/20 train/validation split with a fixed seed
train_df, val_df = train_test_split(df, test_size=0.2, random_state=42)

# 4. Image transforms: 224x224, 3-channel grayscale, then normalisation.
#    Only the training pipeline applies random flip/rotation augmentation.
_norm_mean = [0.485, 0.456, 0.406]
_norm_std = [0.229, 0.224, 0.225]

train_transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(10),
        transforms.Grayscale(num_output_channels=3),
        transforms.ToTensor(),
        transforms.Normalize(mean=_norm_mean, std=_norm_std),
    ]
)

val_transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.Grayscale(num_output_channels=3),
        transforms.ToTensor(),
        transforms.Normalize(mean=_norm_mean, std=_norm_std),
    ]
)

# 5. Datasets and loaders
# DATASET_MODE is forwarded to TriModalDataset unchanged.
# NOTE(review): per the original author, 'prod' matches images by ID while
# 'sandbox' picks random images -- behaviour lives in dataset.py, not shown here.
DATASET_MODE = 'sandbox'

train_dataset = TriModalDataset(
    train_df, IMAGE_ROOT, transform=train_transform, mode=DATASET_MODE
)
val_dataset = TriModalDataset(
    val_df, IMAGE_ROOT, transform=val_transform, mode=DATASET_MODE
)

# Only the training loader shuffles.
train_loader = DataLoader(
    train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2
)
val_loader = DataLoader(
    val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2
)

print(f"Train batches: {len(train_loader)} | Val batches: {len(val_loader)}")

In [None]:
# --- CELL 5: Tri-Modal Training Loop (Fixed Dimensions) ---

# 1. Build the survival model and its optimisation helpers.
# The model receives 15 "wide" input features and 5 biomarker features.
model = WideAndDeepSurvivalModel(wide_input_dim=15, bio_input_dim=5)
model = model.to(DEVICE)

# Wrap in DataParallel when more than one GPU was detected.
if GPU_COUNT > 1:
    model = nn.DataParallel(model)

# Adam over all model parameters, plus a gradient scaler for mixed precision.
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
scaler = GradScaler()

# 2. Robust Cox Loss Function
def custom_cox_loss(risk, events, times):
    """Negative Cox partial log-likelihood, averaged over observed events.

    Pure-PyTorch fallback used when torchsurv's implementation is unavailable.

    Args:
        risk: predicted log-risk scores; any shape holding one value per sample.
        events: event indicators (non-zero = event observed), one per sample.
        times: follow-up times, one per sample.

    Returns:
        Scalar tensor. When the batch holds no events the partial likelihood is
        undefined, so a zero that still supports ``backward()`` is returned.
    """
    # Flatten so an (N, 1) model output cannot broadcast against (N,) events.
    risk = risk.float().reshape(-1)
    events = events.float().reshape(-1)
    times = times.float().reshape(-1)

    # Sort by descending time: the cumulative log-sum-exp at position i then
    # covers exactly the samples whose time is >= that of sample i.
    order = torch.argsort(times, descending=True)
    risk = risk[order]
    events = events[order]
    log_cumsum = torch.logcumsumexp(risk, dim=0)

    n_events = events.sum()
    if n_events > 0:
        return -torch.sum(events * (risk - log_cumsum)) / n_events
    # Created on risk's own device rather than a notebook-level global.
    return risk.new_zeros((), requires_grad=True)


def _resolve_torchsurv_cox():
    """Return torchsurv's Cox negative partial log-likelihood, or None.

    torchsurv exposes this loss as ``neg_partial_log_likelihood`` inside the
    ``torchsurv.loss.cox`` module. The module object itself is not callable.
    """
    import importlib

    try:
        cox_module = importlib.import_module("torchsurv.loss.cox")
    except ImportError:
        return None
    return getattr(cox_module, "neg_partial_log_likelihood", None)


base_cox = _resolve_torchsurv_cox()
if base_cox is not None:
    def safe_cox_loss(risk, events, times):
        """Call torchsurv's Cox loss with float32 scores/times and boolean events."""
        return base_cox(risk.float(), events.bool(), times.float())
    cox_loss_func = safe_cox_loss
else:
    # Fallback
    cox_loss_func = custom_cox_loss

# 3. Loop
print("Starting Survival Training...")
for epoch in range(EPOCHS):
    model.train()
    total_loss = 0.0
    batches_used = 0  # batches that actually contributed a loss value

    # Each batch yields 5 items: image, clinical features, biomarkers, event, time
    for batch_idx, (images, clinical, bio, events, times) in enumerate(train_loader):
        # The Cox loss is undefined for a batch without any observed event.
        if events.sum() == 0:
            continue

        # Move to the training device
        images = images.to(DEVICE)
        clinical = clinical.to(DEVICE)
        bio = bio.to(DEVICE)
        events = events.to(DEVICE)
        times = times.to(DEVICE)

        optimizer.zero_grad()

        with autocast():
            # reshape(-1) always gives a 1-D risk vector; squeeze() would collapse
            # a batch of one into a 0-d tensor.
            risk_scores = model(images, clinical, bio).reshape(-1)
            loss = cox_loss_func(risk_scores, events, times)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        total_loss += loss.item()
        batches_used += 1

    # Average over the batches that were trained on; skipped batches would
    # otherwise deflate the reported loss.
    avg_loss = total_loss / max(1, batches_used)
    print(f"Epoch {epoch+1}/{EPOCHS} | Avg Loss: {avg_loss:.4f}")

# Save the Tri-Modal Model (unwrap DataParallel so keys carry no 'module.' prefix)
model_to_save = model.module if isinstance(model, nn.DataParallel) else model
torch.save(model_to_save.state_dict(), "tri_modal_survival_model.pth")
print("Saved Tri-Modal Model.")

In [None]:
# --- CELL 6: Evaluation ---
from sksurv.metrics import concordance_index_censored
import numpy as np

print("Evaluating on Validation Set...")
model.eval()

# Per-sample predictions and labels, accumulated across validation batches.
val_risk_scores = []
val_events = []
val_times = []

with torch.no_grad():
    for images, clinical, bio, events, times in val_loader:
        images = images.to(DEVICE)
        clinical = clinical.to(DEVICE)
        bio = bio.to(DEVICE)

        # reshape(-1) keeps the output 1-D even when the final batch holds a
        # single sample; squeeze() would give a 0-d array that extend() rejects.
        outputs = model(images, clinical, bio).reshape(-1)

        # Store
        val_risk_scores.extend(outputs.cpu().numpy())
        val_events.extend(events.numpy().astype(bool))
        val_times.extend(times.numpy())

# Arguments are passed as (event indicator, time, predicted risk).
c_index = concordance_index_censored(
    np.array(val_events),
    np.array(val_times),
    np.array(val_risk_scores)
)

# The first element of the returned tuple is the concordance index.
print(f"Validation C-index: {c_index[0]:.4f}")

In [None]:
# --- CELL 7: Save Survival Model ---
SAVE_PATH = "tri_modal_survival_model.pth"
# Unwrap DataParallel by inspecting the model itself, not GPU_COUNT: GPU_COUNT
# is reassigned by the generative-training cell, and the type check matches how
# that cell unwraps its own models.
model_to_save = model.module if isinstance(model, nn.DataParallel) else model
state_dict = model_to_save.state_dict()
torch.save(state_dict, SAVE_PATH)
print(f"Saved: {SAVE_PATH}")

In [None]:
# --- CELL 8: GenAI Setup ---
from diffusers import UNet2DModel, DDPMScheduler
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
import torchvision.models as models
from torch.utils.data import Dataset
from PIL import Image
import glob
import gc
import torch.nn.functional as F

# Reaching this line means every import above in this cell resolved.
print("Generative AI Setup Complete.")

In [None]:
# --- CELL 9: Production Generative Training (Multi-GPU T4) ---

# --- 1. CONFIGURATION ---
GEN_EPOCHS = 500
LR = 1e-4

# Device / multi-GPU detection. On GPU the batch size is 32 per device;
# on CPU a fixed batch of 16 is used.
# NOTE: this re-assigns DEVICE, GPU_COUNT and BATCH_SIZE set earlier in the notebook.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if DEVICE.type == "cuda":
    GPU_COUNT = torch.cuda.device_count()
    print(f"Training on: {GPU_COUNT} x NVIDIA GPU(s)")
    BATCH_SIZE = 32 * GPU_COUNT
else:
    GPU_COUNT = 0
    BATCH_SIZE = 16
    print("Training on: CPU (Not Recommended)")

# Directory that receives the periodic checkpoints.
if not os.path.exists("checkpoints"):
    os.makedirs("checkpoints")

# --- 2. ROBUST CLASSES ---
class SemanticEncoder(nn.Module):
    """Encode single-channel images into a `latent_dim`-dimensional vector.

    A pretrained ResNet-18 backbone (minus its final child module) produces a
    512-value feature per image, which a linear layer projects to `latent_dim`.

    NOTE(review): this definition shadows the `SemanticEncoder` imported from
    `model` earlier in the notebook -- confirm that the override is intended.
    """

    def __init__(self, latent_dim=256):
        super().__init__()
        backbone = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)

        # Swap the first convolution for a 1-input-channel one, initialised with
        # the pretrained kernel summed over its input channels and divided by 3.
        pretrained_conv = backbone.conv1
        backbone.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        with torch.no_grad():
            channel_avg = pretrained_conv.weight.sum(dim=1, keepdim=True) / 3.0
            backbone.conv1.weight[:] = channel_avg

        # Keep every child module except the last one.
        trunk = list(backbone.children())[:-1]
        self.features = nn.Sequential(*trunk)
        self.projection = nn.Linear(512, latent_dim)

    def forward(self, x):
        """Return the latent vectors for a batch of images `x`."""
        batch_size = x.size(0)
        flat = self.features(x).view(batch_size, -1)
        return self.projection(flat)

class GrayscaleDataset(Dataset):
    """Grayscale images found recursively under a directory.

    Collects every ``*.png`` and ``*.jpg`` file below `image_dir`. Each item is
    the image converted to mode ``'L'`` with the optional `transform` applied.

    Args:
        image_dir: root directory searched recursively for images.
        transform: optional callable applied to each loaded PIL image.
    """

    # Placeholder shape returned when an image cannot be loaded; it matches the
    # 1x64x64 tensors produced by this notebook's 64x64 transform.
    FALLBACK_SHAPE = (1, 64, 64)

    def __init__(self, image_dir, transform=None):
        self.transform = transform
        # Sorted so the index -> file mapping does not depend on filesystem order.
        self.image_paths = sorted(
            glob.glob(f"{image_dir}/**/*.png", recursive=True)
            + glob.glob(f"{image_dir}/**/*.jpg", recursive=True)
        )
        print(f"Generative Dataset: {len(self.image_paths)} images.")

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return the item at `idx`, or a zero tensor if it cannot be loaded."""
        # Honour the requested index so one epoch visits every image exactly
        # once; the DataLoader's shuffle already randomises the order.
        path = self.image_paths[idx]
        try:
            # The context manager closes the file once the pixels are converted.
            with Image.open(path) as handle:
                img = handle.convert('L')
            if self.transform:
                img = self.transform(img)
            return img
        except Exception as err:
            # Best effort: one unreadable file must not abort training, but
            # report it instead of hiding it. Unlike a bare `except`, this lets
            # KeyboardInterrupt and SystemExit propagate.
            import warnings

            warnings.warn(f"Could not load image {path!r}: {err}")
            return torch.zeros(*self.FALLBACK_SHAPE)

# --- 3. INITIALIZE ---
print("Initializing Models...")

# Single-channel 64x64 denoising UNet. The encoder output is fed to it as
# `class_labels` during training, with class_embed_type="identity".
# NOTE(review): presumably the identity embedding needs the latent size (256)
# to equal the UNet's time-embedding width -- confirm before changing either.
unet = UNet2DModel(
    sample_size=64,
    in_channels=1,
    out_channels=1,
    layers_per_block=2,
    block_out_channels=(64, 128, 128, 256),
    down_block_types=("DownBlock2D", "DownBlock2D", "AttnDownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "AttnUpBlock2D", "UpBlock2D", "UpBlock2D"),
    class_embed_type="identity",
)
unet = unet.to(DEVICE)

encoder = SemanticEncoder(latent_dim=256)
encoder = encoder.to(DEVICE)

# Wrap both networks in DataParallel when several GPUs were detected.
if GPU_COUNT > 1:
    print(f"Activating DataParallel on {GPU_COUNT} GPUs...")
    unet = nn.DataParallel(unet)
    encoder = nn.DataParallel(encoder)

# One Adam optimiser updates the UNet and the encoder together.
# NOTE: `optimizer` and `scaler` replace the survival-training objects of the same name.
scheduler = DDPMScheduler(num_train_timesteps=1000)
optimizer = torch.optim.Adam([*unet.parameters(), *encoder.parameters()], lr=LR)
scaler = GradScaler()

# --- 4. LOAD DATA ---
IMAGE_ROOT = '/kaggle/input/knee-osteoarthritis-dataset-with-severity'

# Resize to 64x64, convert to a tensor, then normalise with mean 0.5 / std 0.5.
gen_transform = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
])

dataset = GrayscaleDataset(IMAGE_ROOT, transform=gen_transform)

# Scale workers with the GPU count as before, but never beyond the number of
# CPU cores: the previous 4*GPU_COUNT asked for 8 workers on a 2-GPU machine,
# contradicting the intended cap of 4. A CPU-only run still uses 0 workers.
gen_num_workers = min(4 * GPU_COUNT, os.cpu_count() or 1)
loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=gen_num_workers)

# --- 5. TRAIN ---
def _unwrapped(module):
    """Return the underlying module when `module` is wrapped in DataParallel."""
    return module.module if isinstance(module, nn.DataParallel) else module


print(f"Starting Production Training ({GEN_EPOCHS} Epochs)...")

for epoch in range(GEN_EPOCHS):
    unet.train()
    encoder.train()
    total_loss = 0

    progress = tqdm(loader, desc=f"Epoch {epoch+1}/{GEN_EPOCHS}", leave=False)

    for images in progress:
        images = images.to(DEVICE)
        optimizer.zero_grad()

        with autocast():
            # Conditioning vector computed from the clean images.
            z = encoder(images)

            # Noise the images at a random timestep per sample.
            noise = torch.randn_like(images)
            t = torch.randint(
                0,
                scheduler.config.num_train_timesteps,
                (images.shape[0],),
                device=DEVICE,
            ).long()
            noisy_images = scheduler.add_noise(images, noise, t)

            # return_dict=False makes the UNet return a tuple; element 0 is the prediction.
            noise_pred = unet(noisy_images, t, class_labels=z, return_dict=False)[0]

            # The UNet is trained to predict the added noise.
            loss = F.mse_loss(noise_pred, noise)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        batch_loss = loss.item()
        total_loss += batch_loss
        progress.set_postfix({"loss": batch_loss})

    avg_loss = total_loss / len(loader)
    print(f"Epoch {epoch+1} | Avg Loss: {avg_loss:.4f}")

    # Checkpoint every 50 epochs
    if (epoch + 1) % 50 == 0:
        unet_save = _unwrapped(unet)
        enc_save = _unwrapped(encoder)

        torch.save(unet_save.state_dict(), f"checkpoints/unet_epoch_{epoch+1}.pth")
        torch.save(enc_save.state_dict(), f"checkpoints/encoder_epoch_{epoch+1}.pth")
        print(f"--> Saved Checkpoint {epoch+1}")

# Final Save
unet_save = _unwrapped(unet)
enc_save = _unwrapped(encoder)
torch.save(unet_save.state_dict(), "diffusion_unet.pth")
torch.save(enc_save.state_dict(), "semantic_encoder.pth")
print("Training Complete.")