In [1]:
# --- CELL 1: Install & Setup ---
!pip install diffusers["torch"] transformers

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import transforms
from diffusers import UNet2DModel, DDPMScheduler
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
import sys
import os

# Setup Paths & Device
sys.path.append('/kaggle/working/oa-survival-model/src')
from dataset import TriModalDataset

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Training on: {DEVICE}")

# Kaggle Paths
PARQUET_PATH = '/kaggle/input/oai-preprocessed-data/OAI_model_ready_data.parquet'
IMAGE_ROOT = '/kaggle/input/knee-osteoarthritis-dataset-with-severity'

zsh:1: no matches found: diffusers[torch]


ModuleNotFoundError: No module named 'diffusers'

In [None]:
# --- CELL 2: Semantic Encoder Architecture ---
import torchvision.models as models


class SemanticEncoder(nn.Module):
    """Encode an image into a flat semantic latent vector ``z``.

    Architecture: a torchvision ResNet-18 with its final fully-connected layer
    removed, followed by a linear projection to ``latent_dim``.

    Args:
        latent_dim: Size of the output latent vector.
        pretrained: If True (default, matches the previous behaviour), load
            torchvision's default ResNet-18 weights (downloaded on first use).
            If False, the backbone starts from random init, which is handy
            offline or for quick smoke tests.

    NOTE(review): the pretrained weights were trained with ImageNet mean/std
    normalisation, while this notebook normalises inputs to [-1, 1]. The backbone
    is fine-tuned jointly with the UNet, so this is tolerated rather than corrected.
    """

    def __init__(self, latent_dim=512, pretrained=True):
        super().__init__()
        weights = models.ResNet18_Weights.DEFAULT if pretrained else None
        resnet = models.resnet18(weights=weights)
        feature_dim = resnet.fc.in_features  # 512 for ResNet-18
        # Keep every child up to and including the global average pool; drop `fc`.
        self.features = nn.Sequential(*list(resnet.children())[:-1])
        self.projection = nn.Linear(feature_dim, latent_dim)

    def forward(self, x):
        """Map images ``[B, 3, H, W]`` to latents ``[B, latent_dim]``.

        The backbone ends in an adaptive average pool, so H and W need not be
        224; this notebook feeds 64x64 images.
        """
        x = self.features(x)       # -> [B, 512, 1, 1]
        x = torch.flatten(x, 1)    # -> [B, 512]
        return self.projection(x)  # -> [B, latent_dim]


print("Semantic Encoder defined.")

In [None]:
# --- CELL 3: Initialize Diffusion Components ---

# 1. The noise scheduler (1000-step DDPM schedule).
scheduler = DDPMScheduler(num_train_timesteps=1000)

# 2. The UNet (generator): a standard 2D UNet mapping 3 channels (RGB) to 3 channels.
UNET_BLOCK_CHANNELS = (128, 128, 256, 256)

# class_embed_type="identity" makes the UNet add `class_labels` *unchanged* to its
# timestep embedding. The semantic latent z is passed as `class_labels`, so it must
# have exactly the timestep-embedding width. diffusers' default width is
# block_out_channels[0] * 4 (= 512 here); any other latent size fails with a shape
# mismatch on the first forward pass.
COND_DIM = UNET_BLOCK_CHANNELS[0] * 4

unet = UNet2DModel(
    sample_size=64,  # 64x64 for speed (128 if the GPU allows)
    in_channels=3,
    out_channels=3,
    layers_per_block=2,
    block_out_channels=UNET_BLOCK_CHANNELS,
    down_block_types=(
        "DownBlock2D", "DownBlock2D", "AttnDownBlock2D", "DownBlock2D"
    ),
    up_block_types=(
        "UpBlock2D", "AttnUpBlock2D", "UpBlock2D", "UpBlock2D"
    ),
    # Condition on a continuous vector (see COND_DIM above).
    class_embed_type="identity",
).to(DEVICE)

# 3. The encoder; its latent size must equal the UNet's timestep-embedding width.
encoder = SemanticEncoder(latent_dim=COND_DIM).to(DEVICE)

# 4. One optimizer over both networks (they are trained jointly).
LEARNING_RATE = 1e-4
optimizer = torch.optim.Adam(
    list(unet.parameters()) + list(encoder.parameters()), lr=LEARNING_RATE
)

print("Diffusion Model & Encoder initialized.")

In [None]:
# --- CELL 4: Data Prep & Training ---
import pandas as pd
from sklearn.model_selection import train_test_split

# 1. Data
# Resize to the UNet's configured sample size (64x64, chosen for fast training on a T4),
# so the data resolution and the model can never drift apart.
IMAGE_SIZE = unet.config.sample_size
BATCH_SIZE = 32
EPOCHS = 5        # keep it short for the prototype
SPLIT_SEED = 42   # fixed so the 90/10 split is identical across runs

train_transform = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),               # -> [0, 1]
    transforms.Normalize([0.5], [0.5]),  # -> [-1, 1], the range diffusion models expect
])

# Same dataframe source as the earlier notebooks; only the 90% training split is used here.
df = pd.read_parquet(PARQUET_PATH)
train_df, _ = train_test_split(df, test_size=0.1, random_state=SPLIT_SEED)

dataset = TriModalDataset(train_df, IMAGE_ROOT, transform=train_transform, mode='sandbox')
loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
if len(loader) == 0:
    raise ValueError("Training DataLoader is empty - check PARQUET_PATH / IMAGE_ROOT.")


def train_step(images):
    """Run one joint encoder + UNet optimisation step on a batch; return the loss.

    The encoder maps the clean batch to a latent ``z`` that conditions the UNet.
    The UNet is trained with the DDPM noise-prediction MSE loss.

    Args:
        images: Image batch (as produced by ``loader``) in [-1, 1].

    Returns:
        The batch loss as a Python float.
    """
    images = images.to(DEVICE)

    # A. Encode image -> semantic latent z (passed to the UNet as `class_labels`).
    z = encoder(images)

    # B/C. Sample noise and timesteps, then apply forward diffusion.
    noise = torch.randn_like(images)
    timesteps = torch.randint(
        0, scheduler.config.num_train_timesteps, (images.shape[0],), device=DEVICE
    )
    noisy_images = scheduler.add_noise(images, noise, timesteps)

    # D/E. Predict the added noise, conditioned on z, and regress it with MSE.
    noise_pred = unet(noisy_images, timestep=timesteps, class_labels=z).sample
    loss = F.mse_loss(noise_pred, noise)

    optimizer.zero_grad(set_to_none=True)
    loss.backward()  # gradients reach both the UNet and the encoder
    optimizer.step()
    return loss.item()


# 2. Training loop
print("Starting Diffusion Training...")
for epoch in range(EPOCHS):
    unet.train()
    encoder.train()
    total_loss = 0.0
    # NOTE(review): assumes TriModalDataset yields 4-tuples with the image first - confirm.
    for images, _, _, _ in tqdm(loader, desc=f"Epoch {epoch + 1}/{EPOCHS}"):
        total_loss += train_step(images)
    print(f"Epoch {epoch+1} | Loss: {total_loss/len(loader):.4f}")

In [None]:
# --- CELL 5: Generate Counterfactual ---

def generate_counterfactual(original_image, modification_factor=0.0, seed=None,
                            num_inference_steps=None):
    """Regenerate an image from its semantic latent, optionally perturbing the latent.

    Args:
        original_image: Unbatched image tensor ``[C, H, W]``, normalised like the
            training data.
        modification_factor: Standard deviation of the Gaussian noise added to the
            latent ``z``. ``0.0`` gives a plain reconstruction. The random noise is a
            placeholder; a real counterfactual would move ``z`` along survival-model
            gradients.
        seed: If given, every random draw (start noise, latent perturbation and
            per-step DDPM noise) comes from a CPU ``torch.Generator`` seeded with it.
            The output is then reproducible, and calls that differ only in
            ``modification_factor`` share identical noise. If None, the global RNGs
            are used.
        num_inference_steps: If given, ``scheduler.set_timesteps`` is called with it
            before sampling. This mutates the shared scheduler, so the setting
            persists for later calls. Otherwise the scheduler's current
            ``timesteps`` are used.

    Returns:
        The generated image ``[C, H, W]`` on the CPU.
    """
    generator = None if seed is None else torch.Generator().manual_seed(seed)

    # Sample in eval mode. In train mode the encoder's BatchNorm layers would use
    # single-image batch statistics and overwrite their running stats, even under
    # no_grad. Restore the previous modes afterwards so later training is unaffected.
    was_training = (encoder.training, unet.training)
    encoder.eval()
    unet.eval()
    try:
        with torch.no_grad():
            image = original_image.unsqueeze(0).to(DEVICE)  # batch of 1

            # Draw the start noise before anything else. For a fixed seed it is then
            # the same regardless of modification_factor.
            sample = torch.randn(image.shape, generator=generator).to(
                device=image.device, dtype=image.dtype
            )

            # 1. Encode real image -> latent z.
            z = encoder(image)

            # 2. Modify z (the counterfactual step).
            perturbation = torch.randn(z.shape, generator=generator).to(
                device=z.device, dtype=z.dtype
            )
            z_modified = z + perturbation * modification_factor

            # 3. Reverse diffusion from pure noise, conditioned on z_modified.
            if num_inference_steps is not None:
                scheduler.set_timesteps(num_inference_steps, device=DEVICE)
            for t in scheduler.timesteps:
                model_output = unet(sample, t, class_labels=z_modified).sample
                sample = scheduler.step(
                    model_output, t, sample, generator=generator
                ).prev_sample
    finally:
        encoder.train(was_training[0])
        unet.train(was_training[1])

    return sample.squeeze(0).cpu()

# Compare one training image with (a) its reconstruction (latent unchanged) and
# (b) a counterfactual whose latent is perturbed by Gaussian noise of std 1.0.
CF_SEED = 0
sample_img, _, _, _ = dataset[0]

# Re-seed torch's global RNGs before each call. Both generations then draw the same
# noise and differ only in the latent perturbation.
torch.manual_seed(CF_SEED)
recon_img = generate_counterfactual(sample_img, modification_factor=0.0)
torch.manual_seed(CF_SEED)
cf_img = generate_counterfactual(sample_img, modification_factor=1.0)


def _to_display(img):
    """Convert a [-1, 1] ``[C, H, W]`` tensor to a [0, 1] ``[H, W, C]`` array for imshow.

    Clamping keeps out-of-range generated pixels from triggering imshow clipping warnings.
    """
    return (img.permute(1, 2, 0) * 0.5 + 0.5).clamp(0, 1).cpu().numpy()


panels = [
    (sample_img, "Original Patient X-Ray"),
    (recon_img, "AI Reconstruction (Sanity Check)"),
    # The latent edit is random noise for now, not survival-guided.
    (cf_img, "Counterfactual (Random Latent Shift)"),
]
fig, axs = plt.subplots(1, len(panels), figsize=(15, 5))
for ax, (img, title) in zip(axs, panels):
    ax.imshow(_to_display(img))
    ax.set_title(title)
    ax.axis("off")
plt.show()