In [None]:
# --- CELL 1: Setup, Sync & Install ---
import sys
import os

# 1. Clean up old repo
!rm -rf oa-survival-model

# 2. Clone your repo
from kaggle_secrets import UserSecretsClient
user_secrets = UserSecretsClient()
git_token = user_secrets.get_secret("GIT_TOKEN")
username = "AbhiGTM19"

!git clone https://{username}:{git_token}@github.com/{username}/oa-survival-model.git

# 3. Install ALL dependencies from your updated requirements.txt
# This will now include 'diffusers' and 'transformers' automatically
%cd oa-survival-model
!pip install -r requirements.txt
%cd ..

# 4. Add source code to path
sys.path.append('/kaggle/working/oa-survival-model/src')

print("Environment Ready & Dependencies Installed.")

In [None]:
# --- CELL 2: Imports & Configuration (Multi-GPU Optimized) ---
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from sklearn.compose import ColumnTransformer
import torchsurv.loss
from torch.cuda.amp import GradScaler, autocast # For Mixed Precision

# Import your custom modules
from model import WideAndDeepSurvivalModel
from dataset import TriModalDataset

# Configuration
# Choose the compute device and record how many GPUs are visible
# (GPU_COUNT is 0 when running on CPU).
if not torch.cuda.is_available():
    DEVICE, GPU_COUNT = torch.device("cpu"), 0
    print("Training on: CPU")
else:
    DEVICE, GPU_COUNT = torch.device("cuda"), torch.cuda.device_count()
    print(f"Training on: {GPU_COUNT} x NVIDIA GPU(s)")

# Paths
PARQUET_PATH = '/kaggle/input/oai-preprocessed-data/OAI_tri_modal_real.parquet'
IMAGE_ROOT = '/kaggle/input/knee-osteoarthritis-dataset-with-severity'

# Hyperparameters
EPOCHS = 10
LEARNING_RATE = 1e-4
# The global batch holds 32 samples per GPU, and 32 in total on CPU.
BATCH_SIZE = max(1, GPU_COUNT) * 32

In [None]:
# --- CELL 3: Data Prep (Exact copy of your local logic) ---

# 1. Load the pre-processed table
df = pd.read_parquet(PARQUET_PATH)

# 2. Preprocessing
# One-hot encode the two categorical columns (first level of each is dropped),
# then add any expected dummy column that this extract did not produce as
# an all-zero column.
df = pd.get_dummies(df, columns=['KL_Grade', 'Sex'], drop_first=True)
expected_cols = ['KL_Grade_1.0', 'KL_Grade_2.0', 'KL_Grade_3.0', 'KL_Grade_4.0', 'Sex_2']
absent_cols = [name for name in expected_cols if name not in df.columns]
for col in absent_cols:
    df[col] = 0

# 3. Split: 80% train / 20% validation, with a fixed seed
train_df, val_df = train_test_split(df, test_size=0.2, random_state=42)


# 4. Transforms
def _image_pipeline(*augmentations):
    """Resize to 224x224, apply the given augmentations, then tensorise and normalise."""
    return transforms.Compose([
        transforms.Resize((224, 224)),
        *augmentations,
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])


# Only the training pipeline is augmented (random flip + rotation up to 10 degrees).
train_transform = _image_pipeline(
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
)
val_transform = _image_pipeline()

# 5. Datasets & Loaders
train_dataset = TriModalDataset(train_df, IMAGE_ROOT, transform=train_transform, mode='sandbox')
val_dataset = TriModalDataset(val_df, IMAGE_ROOT, transform=val_transform, mode='sandbox')

# Both loaders share the batch size and worker count; only training is shuffled.
_loader_options = dict(batch_size=BATCH_SIZE, num_workers=2)
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_options)
val_loader = DataLoader(val_dataset, shuffle=False, **_loader_options)

print(f"Train batches: {len(train_loader)} | Val batches: {len(val_loader)}")

In [None]:
# --- CELL 4: Tri-Modal Survival Training ---

# 1. Initialize Model (With Biomarker Input)
# Note: bio_input_dim=5 matches your 5 chosen markers
# NOTE(review): wide_input_dim=8 presumably matches the clinical feature width
# produced by TriModalDataset — confirm against dataset.py.
model = WideAndDeepSurvivalModel(wide_input_dim=8, bio_input_dim=5).to(DEVICE)

if GPU_COUNT > 1:
    model = nn.DataParallel(model)

optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
scaler = GradScaler()

# 2. Loss Function
# The notebook previously relied on a definition that is not in this file, so a
# fresh "Restart & Run All" failed with NameError. It is defined explicitly here.
# NOTE(review): reconstructed from the torchsurv import in Cell 2 — confirm it
# matches the loss used in your local training script.
from torchsurv.loss.cox import neg_partial_log_likelihood


def cox_loss_func(risk_scores, events, times):
    """Cox negative partial log-likelihood for one batch.

    Args:
        risk_scores: 1-D tensor of predicted log-hazards, one per sample.
        events: 1-D event indicators (non-zero = event observed).
        times: 1-D follow-up times aligned with ``events``.

    Returns:
        Scalar loss tensor.
    """
    # Evaluate the likelihood in float32 even when the forward pass ran under
    # autocast, and hand torchsurv the boolean event tensor it expects.
    return neg_partial_log_likelihood(risk_scores.float(), events.bool(), times)


# 3. Training Loop
print("Starting Tri-Modal Training...")
for epoch in range(EPOCHS):
    model.train()
    total_loss = 0.0
    batches_used = 0  # batches that actually contributed to the loss

    # Each batch carries 5 items: image, clinical, biomarkers, event flag, time
    for batch_idx, (images, clinical, bio, events, times) in enumerate(train_loader):
        # Batches without any observed event are skipped.
        if events.sum() == 0:
            continue

        # Move ALL 3 inputs (and the targets) to the training device
        images = images.to(DEVICE)
        clinical = clinical.to(DEVICE)
        bio = bio.to(DEVICE)
        events = events.to(DEVICE)
        times = times.to(DEVICE)

        optimizer.zero_grad()

        with autocast():
            # reshape(-1) keeps a 1-D vector even for a batch of one sample,
            # where squeeze() would collapse the output to a 0-dim scalar.
            risk_scores = model(images, clinical, bio).reshape(-1)
            loss = cox_loss_func(risk_scores, events, times)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        total_loss += loss.item()
        batches_used += 1

    # Average over the batches that were trained on, not over skipped ones.
    print(f"Epoch {epoch+1}/{EPOCHS} | Avg Loss: {total_loss/max(1, batches_used):.4f}")

# Save the Tri-Modal Model
# Unwrap DataParallel so the checkpoint keys carry no "module." prefix.
model_to_save = model.module if isinstance(model, nn.DataParallel) else model
torch.save(model_to_save.state_dict(), "tri_modal_survival_model.pth")
print("Saved Tri-Modal Model.")

In [None]:
# --- CELL 5: Evaluation ---
from sksurv.metrics import concordance_index_censored
import numpy as np

print("Evaluating on Validation Set...")
model.eval()

val_risk_scores = []
val_events = []
val_times = []

# 1. Collect predictions
with torch.no_grad():
    # The validation loader wraps the same TriModalDataset as training, so each
    # batch has 5 items and the model needs all three inputs.
    for images, clinical, bio, events, times in val_loader:
        images = images.to(DEVICE)
        clinical = clinical.to(DEVICE)
        bio = bio.to(DEVICE)

        # Forward pass; reshape(-1) keeps a 1-D vector even when the final
        # batch holds a single sample (squeeze() would yield a 0-dim value
        # that list.extend cannot iterate).
        outputs = model(images, clinical, bio).reshape(-1)

        # Store results (move to CPU for sksurv)
        val_risk_scores.extend(outputs.cpu().numpy())
        val_events.extend(events.numpy().astype(bool))  # sksurv expects boolean
        val_times.extend(times.numpy())

# 2. Calculate C-index (first element of the returned tuple)
c_index = concordance_index_censored(
    np.array(val_events),
    np.array(val_times),
    np.array(val_risk_scores)
)

print(f"Validation C-index: {c_index[0]:.4f}")
print(f"Baseline Target:    0.7468")

In [None]:
# --- CELL 6: Save Model ---
SAVE_PATH = "tri_modal_survival_model.pth"
# Save the underlying module when the model is wrapped in DataParallel;
# otherwise every key would be prefixed with "module." and the file would not
# load into a plain (single-device) model.
model_to_save = model.module if isinstance(model, nn.DataParallel) else model
torch.save(model_to_save.state_dict(), SAVE_PATH)
print(f"Model saved to: {SAVE_PATH}")

In [None]:
# --- CELL 7: Improved Generative AI Setup (Grayscale) ---
import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers import UNet2DModel, DDPMScheduler
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
import torchvision.models as models
from torch.utils.data import Dataset
from PIL import Image
import numpy as np
import os
import glob

# Redefine Encoder for 1-Channel Input
class SemanticEncoder(nn.Module):
    """ResNet-18 feature extractor adapted to 1-channel images, plus a linear head.

    Args:
        latent_dim: size of the vector returned by ``forward``.
    """

    def __init__(self, latent_dim=256):
        super().__init__()
        backbone = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)

        # Swap the 3-channel stem for a 1-channel convolution whose kernel is
        # the per-pixel average of the pretrained RGB kernels.
        rgb_stem = backbone.conv1
        gray_stem = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        with torch.no_grad():
            gray_stem.weight.copy_(rgb_stem.weight.sum(dim=1, keepdim=True) / 3.0)
        backbone.conv1 = gray_stem

        # Keep every child of the ResNet except its last one.
        *trunk, _last_child = backbone.children()
        self.features = nn.Sequential(*trunk)
        self.projection = nn.Linear(512, latent_dim)

    def forward(self, x):
        """Return the latent vector for ``x``: features flattened per sample, then projected."""
        pooled = self.features(x)
        flat = pooled.reshape(pooled.size(0), -1)
        return self.projection(flat)

print("Generative AI Architecture Updated (Grayscale Enabled).")

In [None]:
# --- CELL 8: Generative AI Training (Robust Self-Contained Version) ---
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from diffusers import UNet2DModel, DDPMScheduler
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
import torchvision.models as models
from PIL import Image
import numpy as np
import os
import glob
from torch.cuda.amp import GradScaler, autocast

# --- 1. CONFIGURATION ---
GEN_EPOCHS = 500   # Number of training epochs for the generative model
LR = 1e-4
BATCH_SIZE_PER_GPU = 32

# Check Device
# GPU_COUNT is 0 on CPU, matching the convention used in Cell 2; every consumer
# in this notebook goes through max(1, GPU_COUNT) or a "GPU_COUNT > 1" test.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
    GPU_COUNT = torch.cuda.device_count()
    print(f"Training on: {GPU_COUNT} x NVIDIA GPU(s)")
else:
    DEVICE = torch.device("cpu")
    GPU_COUNT = 0
    print("Training on: CPU")

# Create Checkpoint Folder (no error if it already exists)
os.makedirs("checkpoints", exist_ok=True)

# --- 2. ROBUST CLASS DEFINITIONS ---
# The classes are defined in this cell so that the grayscale variants are the
# ones in scope, independent of what model.py contains.

class SemanticEncoder(nn.Module):
    """Encode a 1-channel image into a ``latent_dim``-sized vector.

    A pretrained ResNet-18 (minus its last child) provides the features;
    a linear layer maps the 512 features to the latent size.
    """

    def __init__(self, latent_dim=256):
        super().__init__()
        backbone = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)

        # --- GRAYSCALE FIX (3 Channels -> 1 Channel) ---
        # The new 1-channel stem starts from the average of the pretrained
        # RGB kernels.
        pretrained_stem = backbone.conv1
        backbone.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        averaged_kernel = pretrained_stem.weight.sum(dim=1, keepdim=True) / 3.0
        with torch.no_grad():
            backbone.conv1.weight.copy_(averaged_kernel)

        layers = list(backbone.children())
        layers.pop()  # drop the last child of the ResNet
        self.features = nn.Sequential(*layers)
        self.projection = nn.Linear(512, latent_dim)

    def forward(self, x):
        """Run the feature trunk, flatten per sample, and project to the latent size."""
        feats = self.features(x)
        return self.projection(feats.reshape(feats.size(0), -1))

class GrayscaleDataset(Dataset):
    """Serve randomly chosen grayscale images found under ``image_dir``.

    Args:
        dataframe: only its length is used; it sets how many items make up
            one epoch.
        image_dir: directory searched recursively for ``.png`` and ``.jpg`` files.
        transform: optional callable applied to the PIL image.

    Raises:
        FileNotFoundError: if no image is found under ``image_dir``.
    """

    def __init__(self, dataframe, image_dir, transform=None):
        self.df = dataframe
        self.transform = transform
        # Find all images recursively
        self.all_image_paths = glob.glob(f"{image_dir}/**/*.png", recursive=True) + \
                               glob.glob(f"{image_dir}/**/*.jpg", recursive=True)
        print(f"Grayscale Dataset: Found {len(self.all_image_paths)} images.")
        # Fail here with a clear message instead of later inside a DataLoader
        # worker, where random sampling from an empty list would raise.
        if not self.all_image_paths:
            raise FileNotFoundError(
                f"No .png or .jpg images found under {image_dir!r}"
            )

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return one image; ``idx`` is ignored because sampling is random."""
        # Random sampling over every discovered file, so training is not
        # limited to the rows present in ``df``.
        img_path = np.random.choice(self.all_image_paths)

        try:
            # The context manager closes the file handle; convert('L') yields
            # an independent grayscale copy.
            with Image.open(img_path) as raw_image:
                image = raw_image.convert('L')
            if self.transform:
                image = self.transform(image)
            return image
        except Exception as e:
            # Best-effort fallback for unreadable images, but no longer silent:
            # a training run fed with blank images should be noticeable.
            # NOTE(review): the 1x64x64 shape matches the 64x64 transform used
            # in this notebook; other sizes would need a matching fallback.
            import warnings
            warnings.warn(
                f"GrayscaleDataset: could not load {img_path} ({e}); returning a blank image."
            )
            return torch.zeros(1, 64, 64)

# --- 3. INITIALIZE MODELS ---
print("Initializing Grayscale Models...")

# Denoising UNet: 1-channel 64x64 input and output; class_embed_type="identity"
# lets the encoder's latent vector be passed in as class_labels.
_unet_kwargs = dict(
    sample_size=64,
    in_channels=1,
    out_channels=1,
    layers_per_block=2,
    block_out_channels=(64, 128, 128, 256),
    down_block_types=("DownBlock2D", "DownBlock2D", "AttnDownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "AttnUpBlock2D", "UpBlock2D", "UpBlock2D"),
    class_embed_type="identity",
)
unet = UNet2DModel(**_unet_kwargs).to(DEVICE)

encoder = SemanticEncoder(latent_dim=256).to(DEVICE)

# Multi-GPU Wrapping
if GPU_COUNT > 1:
    print(">> Activating DataParallel for UNet & Encoder")
    unet, encoder = nn.DataParallel(unet), nn.DataParallel(encoder)

# Noise schedule, a single optimizer over both networks, and an AMP grad scaler
scheduler = DDPMScheduler(num_train_timesteps=1000)
optimizer = torch.optim.Adam([*unet.parameters(), *encoder.parameters()], lr=LR)
scaler = GradScaler()

# --- 4. DATA LOADING ---
# Resize to 64x64, convert to a tensor in [0, 1], then shift/scale to [-1, 1].
gen_transform = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
])

# train_df comes from Cell 3 and IMAGE_ROOT from Cell 2.
gen_dataset = GrayscaleDataset(train_df, IMAGE_ROOT, transform=gen_transform)

# The batch size scales with the number of GPUs (at least one "slot" on CPU).
_gen_batch_size = BATCH_SIZE_PER_GPU * max(1, GPU_COUNT)
gen_loader = DataLoader(gen_dataset, batch_size=_gen_batch_size, shuffle=True, num_workers=2)

# --- 5. TRAINING LOOP ---
def _plain_state_dict(net):
    """Return ``net``'s state dict, reading through an nn.DataParallel wrapper if present."""
    if isinstance(net, nn.DataParallel):
        return net.module.state_dict()
    return net.state_dict()


print(f"Starting Training: {GEN_EPOCHS} Epochs | Batch Size: {BATCH_SIZE_PER_GPU * max(1, GPU_COUNT)}")

for epoch in range(GEN_EPOCHS):
    unet.train()
    encoder.train()
    total_loss = 0

    # One progress bar per epoch; it is removed once the epoch completes.
    progress = tqdm(gen_loader, desc=f"Epoch {epoch+1}/{GEN_EPOCHS}", leave=False)

    for images in progress:
        images = images.to(DEVICE)

        optimizer.zero_grad()

        # Forward pass and loss under mixed precision
        with autocast():
            # 1. Encode the clean image into the conditioning vector z
            z = encoder(images)

            # 2. Draw noise and one random timestep per sample, then noise the image
            noise = torch.randn_like(images)
            timesteps = torch.randint(
                0,
                scheduler.config.num_train_timesteps,
                (images.shape[0],),
                device=DEVICE,
            ).long()
            noisy_images = scheduler.add_noise(images, noise, timesteps)

            # 3. Predict the noise, conditioned on z.
            # return_dict=False makes the UNet return a plain tuple, which
            # works with DataParallel.
            noise_pred = unet(noisy_images, timestep=timesteps, class_labels=z, return_dict=False)[0]

            # 4. Mean-squared error between predicted and true noise
            loss = F.mse_loss(noise_pred, noise)

        # Backward & optimizer step through the gradient scaler
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        batch_loss = loss.item()
        total_loss += batch_loss
        progress.set_postfix({"loss": batch_loss})

    avg_loss = total_loss / len(gen_loader)
    print(f"Epoch {epoch+1} | Avg Loss: {avg_loss:.4f}")

    # --- SAVE CHECKPOINT (Every 50 Epochs) ---
    if (epoch + 1) % 50 != 0:
        continue
    torch.save(_plain_state_dict(unet), f"checkpoints/unet_epoch_{epoch+1}.pth")
    torch.save(_plain_state_dict(encoder), f"checkpoints/encoder_epoch_{epoch+1}.pth")
    print(f"--> Checkpoint saved at Epoch {epoch+1}")

# Final Save
unet_state = _plain_state_dict(unet)
enc_state = _plain_state_dict(encoder)
torch.save(unet_state, "diffusion_unet.pth")
torch.save(enc_state, "semantic_encoder.pth")
print("Generative Training Finished Successfully.")

In [None]:
# --- CELL 9: Visualization (Grayscale) ---

def generate_image(ref_image, modification=0.0):
    """Generate an image from noise, conditioned on the encoding of ``ref_image``.

    Args:
        ref_image: a single image tensor without batch dimension (as returned
            by ``gen_dataset``); a batch dimension is added internally.
        modification: scale of the Gaussian perturbation added to the latent
            vector; 0.0 leaves the latent unchanged.

    Returns:
        The generated image on CPU with singleton dimensions squeezed out.
    """
    unet.eval()
    encoder.eval()

    # Run inference on the underlying modules. DataParallel splits inputs along
    # the batch dimension, which does not suit the single-image batch and the
    # scalar timestep used below.
    denoiser = unet.module if isinstance(unet, nn.DataParallel) else unet
    image_encoder = encoder.module if isinstance(encoder, nn.DataParallel) else encoder

    with torch.no_grad():
        ref_image = ref_image.unsqueeze(0).to(DEVICE)
        z = image_encoder(ref_image)

        # Simulate Counterfactual: perturb the latent with scaled Gaussian noise
        z_modified = z + (torch.randn_like(z) * modification)

        # Generate: start from pure noise and step through the scheduler's timesteps
        image = torch.randn_like(ref_image)
        for t in scheduler.timesteps:
            # return_dict=False mirrors the call used in the training loop.
            out = denoiser(image, t, class_labels=z_modified, return_dict=False)[0]
            image = scheduler.step(out, t, image).prev_sample

    return image.cpu().squeeze()

# Pick a sample
# Note: GrayscaleDataset ignores the index and returns a randomly chosen image.
sample = gen_dataset[0]
recon = generate_image(sample, modification=0.0)
counterfactual = generate_image(sample, modification=1.0)

# Display in Grayscale
fig, axs = plt.subplots(1, 3, figsize=(15, 5))
axs[0].imshow(sample.permute(1, 2, 0).squeeze(), cmap='gray')
axs[0].set_title("Original X-Ray")
axs[1].imshow(recon.squeeze(), cmap='gray')
axs[1].set_title("AI Reconstruction")
axs[2].imshow(counterfactual.squeeze(), cmap='gray')
axs[2].set_title("Counterfactual")
plt.show()

# Save models
# Unwrap DataParallel before saving. These paths are the same ones written at
# the end of Cell 8, and saving the wrapper would replace those files with
# state dicts whose keys are prefixed with "module.".
unet_to_save = unet.module if isinstance(unet, nn.DataParallel) else unet
encoder_to_save = encoder.module if isinstance(encoder, nn.DataParallel) else encoder
torch.save(unet_to_save.state_dict(), "diffusion_unet.pth")
torch.save(encoder_to_save.state_dict(), "semantic_encoder.pth")
print("Grayscale models saved.")