In [2]:
# --- CELL 1: Local Setup (Mac M1) ---
import sys
import os
import torch
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from torchvision import transforms
from sklearn.model_selection import train_test_split
from tqdm.auto import tqdm

# Diffusers imports
from diffusers import UNet2DModel, DDPMScheduler
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models

# 1. Setup Paths
# Add '../src' to the python path so we can import our modules.
# The membership check keeps this cell idempotent: re-running it in the same
# kernel no longer appends a duplicate entry to sys.path each time.
SRC_DIR = os.path.abspath('../src')
if SRC_DIR not in sys.path:
    sys.path.append(SRC_DIR)

# 2. Verify Imports
# Best-effort check only: a failure is reported but does not stop the cell.
try:
    from dataset import TriModalDataset
    print("Success: 'src' module loaded.")
except ImportError as e:
    print(f"Error: Could not import from src. {e}")

# 3. Setup Device (Use Metal Performance Shaders for M1)
# Falls back to CPU when the MPS backend reports it is unavailable.
if torch.backends.mps.is_available():
    DEVICE = torch.device("mps")
    print("Training on: Apple M1 GPU (MPS)")
else:
    DEVICE = torch.device("cpu")
    print("Training on: CPU (Warning: Slow)")

# 4. Define Local Data Paths (relative to the notebook's working directory)
PARQUET_PATH = '../data/processed/OAI_model_ready_data.parquet'
IMAGE_ROOT = '../data/sandbox'

Success: 'src' module loaded.
Training on: Apple M1 GPU (MPS)


In [3]:
# --- CELL 2: Generative AI Setup (Grayscale) ---
import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers import UNet2DModel, DDPMScheduler
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
import torchvision.models as models
from torch.utils.data import Dataset
from PIL import Image
import numpy as np
import os
import glob

# Redefine Encoder for 1-Channel Input
class SemanticEncoder(nn.Module):
    """ResNet-18 feature extractor with a 1-channel stem and a linear head.

    The pretrained ResNet-18 is kept up to (but not including) its last child
    module; its output is flattened per sample and projected to `latent_dim`.

    Args:
        latent_dim: Size of the vector returned by `forward` for each sample.
    """

    def __init__(self, latent_dim=256):
        super().__init__()
        backbone = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)

        # The stock stem takes 3 channels; build a 1-channel stem with the same
        # geometry and seed it with the per-filter average of the RGB kernels.
        rgb_stem = backbone.conv1
        gray_stem = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        with torch.no_grad():
            gray_stem.weight.copy_(rgb_stem.weight.sum(dim=1, keepdim=True) / 3.0)
        backbone.conv1 = gray_stem

        # Drop the final child of the ResNet and attach our own projection.
        self.features = nn.Sequential(*list(backbone.children())[:-1])
        self.projection = nn.Linear(512, latent_dim)

    def forward(self, x):
        """Return the `latent_dim`-sized embedding for each sample in `x`."""
        pooled = self.features(x)
        flat = torch.flatten(pooled, start_dim=1)  # keep batch dim, merge the rest
        return self.projection(flat)

print("Generative AI Architecture Updated (Grayscale Enabled).")

Generative AI Architecture Updated (Grayscale Enabled).


In [None]:
# --- CELL 3: Train Optimized Diffusion Model ---

# 1. Configuration
LR = 1e-4
GEN_EPOCHS = 50  # raised for better sample quality

# 2. Initialize Models (1-Channel Configuration)
# NOTE(review): with class_embed_type="identity" the encoder output is passed
# to the UNet as class_labels; presumably latent_dim must match the UNet's
# time-embedding width (4 * block_out_channels[0] = 256) — confirm against the
# diffusers documentation before changing either value.
unet_kwargs = {
    "sample_size": 64,
    "in_channels": 1,    # grayscale input
    "out_channels": 1,   # grayscale output
    "layers_per_block": 2,
    "block_out_channels": (64, 128, 128, 256),
    "down_block_types": ("DownBlock2D", "DownBlock2D", "AttnDownBlock2D", "DownBlock2D"),
    "up_block_types": ("UpBlock2D", "AttnUpBlock2D", "UpBlock2D", "UpBlock2D"),
    "class_embed_type": "identity",
}
unet = UNet2DModel(**unet_kwargs).to(DEVICE)

encoder = SemanticEncoder(latent_dim=256).to(DEVICE)
scheduler = DDPMScheduler(num_train_timesteps=1000)

# A single optimizer updates the UNet and the encoder jointly.
trainable_params = [*unet.parameters(), *encoder.parameters()]
optimizer = torch.optim.Adam(trainable_params, lr=LR)

# 3. Transforms (1-Channel)
gen_transform = transforms.Compose([
    transforms.Resize(size=(64, 64)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5]),  # one mean/std entry: single channel
])

class GrayscaleDataset(torch.utils.data.Dataset):
    """Sandbox dataset that serves randomly chosen grayscale images.

    The dataframe only determines the dataset length. Every item is an image
    picked at random from all ``.png`` / ``.jpg`` files found recursively
    under ``image_dir``; the requested index is not used to select it.

    Args:
        dataframe: Object whose ``len()`` defines the number of items.
        image_dir: Root directory searched recursively for images.
        transform: Optional callable applied to the PIL image before return.

    Raises:
        ValueError: If no ``.png`` or ``.jpg`` file exists under ``image_dir``.
    """

    def __init__(self, dataframe, image_dir, transform=None):
        self.df = dataframe
        self.transform = transform
        # Get all images using glob (recursively)
        self.all_image_paths = glob.glob(f"{image_dir}/**/*.png", recursive=True) + \
                               glob.glob(f"{image_dir}/**/*.jpg", recursive=True)
        print(f"Grayscale Dataset: Found {len(self.all_image_paths)} images.")
        # Fail here rather than on the first batch, where the error would come
        # from the random sampler and not mention the directory at all.
        if not self.all_image_paths:
            raise ValueError(
                f"No .png or .jpg images found under {image_dir!r}; "
                "check the image directory path."
            )

    def __len__(self):
        """Number of items, taken from the dataframe (not the image count)."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return one random image as single-channel data; ``idx`` is ignored."""
        # Sandbox Mode Logic: Pick Random Image.
        # Draw an index instead of np.random.choice(list), which would convert
        # the whole path list to an array on every call.
        pick = np.random.randint(len(self.all_image_paths))
        img_path = self.all_image_paths[pick]

        # Force Grayscale ('L'); the context manager closes the file handle as
        # soon as the converted copy has been produced.
        with Image.open(img_path) as raw:
            image = raw.convert('L')
        if self.transform:
            image = self.transform(image)

        return image

# 4. Load Data (Local Paths)
# Reuse the path constants defined in the setup cell instead of repeating the
# literal strings, so the data location is configured in exactly one place.
df = pd.read_parquet(PARQUET_PATH)
train_df, _ = train_test_split(df, test_size=0.1, random_state=42)

# Initialize Dataset & Loader
# Note: only a small subset of rows is used for the local speed test. The
# dataframe slice sets the dataset length (rows per epoch); raise
# SANDBOX_ROWS for a longer run.
SANDBOX_ROWS = 200
GEN_BATCH_SIZE = 32
gen_dataset = GrayscaleDataset(train_df.head(SANDBOX_ROWS), IMAGE_ROOT, transform=gen_transform)
gen_loader = DataLoader(gen_dataset, batch_size=GEN_BATCH_SIZE, shuffle=True)

# 5. Training Loop
print(f"Starting Grayscale Training ({GEN_EPOCHS} Epochs)...")

num_train_steps = scheduler.config.num_train_timesteps

for epoch in range(1, GEN_EPOCHS + 1):
    unet.train()
    encoder.train()
    running_loss = 0

    pbar = tqdm(gen_loader, desc=f"Epoch {epoch}/{GEN_EPOCHS}", leave=False)

    for clean in pbar:
        clean = clean.to(DEVICE)  # batch of image tensors from gen_loader

        # Conditioning vector computed from the clean (un-noised) images
        latent = encoder(clean)

        # Forward diffusion: one random timestep per sample
        target_noise = torch.randn_like(clean)
        steps = torch.randint(0, num_train_steps, (clean.shape[0],), device=DEVICE).long()
        noised = scheduler.add_noise(clean, target_noise, steps)

        # The UNet predicts the noise that was added, conditioned on the latent
        predicted_noise = unet(noised, timestep=steps, class_labels=latent).sample

        # MSE between predicted and true noise, then one optimizer update
        loss = F.mse_loss(predicted_noise, target_noise)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        batch_loss = loss.item()
        running_loss += batch_loss
        pbar.set_postfix({"loss": batch_loss})

    # Mean of the per-batch losses for this epoch
    print(f"Epoch {epoch} | Avg Loss: {running_loss/len(gen_loader):.4f}")

print("Training Complete.")

Grayscale Dataset: Found 9786 images.
Starting Grayscale Training (50 Epochs)...


Epoch 1/50:   0%|          | 0/7 [00:00<?, ?it/s]

In [None]:
# --- CELL 4: Visualization (Grayscale) ---

@torch.no_grad()
def generate_image(ref_image, modification=0.0):
    """Sample an image from the diffusion model conditioned on `ref_image`.

    The reference image is encoded, the encoding is optionally perturbed with
    Gaussian noise scaled by `modification`, and the UNet then denoises a
    random tensor by running one scheduler step per entry of
    `scheduler.timesteps`.

    Uses the module-level `unet`, `encoder`, `scheduler` and `DEVICE`, and
    leaves both models in eval mode.

    Args:
        ref_image: Single un-batched image tensor; a batch dimension is added
            here before it is moved to `DEVICE`.
        modification: Scale of the random perturbation added to the encoding
            (0.0 leaves the encoding unchanged).

    Returns:
        The final sample as a CPU tensor with all size-1 dimensions removed.
    """
    unet.eval()
    encoder.eval()

    batch = ref_image.unsqueeze(0).to(DEVICE)
    latent = encoder(batch)

    # Simulate Counterfactual: jitter the conditioning vector
    latent = latent + (modification * torch.randn_like(latent))

    # Generate: start from pure noise and denoise step by step
    sample = torch.randn_like(batch)
    for step in scheduler.timesteps:
        predicted_noise = unet(sample, step, class_labels=latent).sample
        sample = scheduler.step(predicted_noise, step, sample).prev_sample

    return sample.cpu().squeeze()

# Pick a sample
sample = gen_dataset[0]
recon = generate_image(sample, modification=0.0)
counterfactual = generate_image(sample, modification=1.0)

# Display in Grayscale: one panel per (title, image) pair, left to right
panels = (
    ("Original X-Ray", sample.permute(1, 2, 0).squeeze()),
    ("AI Reconstruction", recon.squeeze()),
    ("Counterfactual", counterfactual.squeeze()),
)
fig, axs = plt.subplots(1, len(panels), figsize=(15, 5))
for ax, (title, img) in zip(axs, panels):
    ax.imshow(img, cmap='gray')
    ax.set_title(title)
plt.show()

# Save models (weights only, written to the current working directory)
checkpoints = ((unet, "diffusion_unet.pth"), (encoder, "semantic_encoder.pth"))
for module, filename in checkpoints:
    torch.save(module.state_dict(), filename)
print("Grayscale models saved.")