In [1]:
# ==================================================================================
# L-DR-MEP: Latent Diabetic Retinopathy Markovian Entropy Processor
# Resolution: 512x512 | Visualization: 5-Plot Grid (Classic) | Resume & Train
# ==================================================================================

import os
import cv2
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
import torchvision.models as models
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image
from tqdm.auto import tqdm
from torch.utils.data import Dataset, DataLoader
from sklearn.manifold import TSNE
from skimage.metrics import structural_similarity as ssim

# --- INSTALL DEPENDENCIES (Quiet Mode) ---
print("Installing dependencies...")
os.system("pip install -q diffusers transformers accelerate torchmetrics clean-fid")

from diffusers import UNet2DModel, DDIMScheduler, AutoencoderKL

# ==================================================================================
# 1. CONFIGURATION
# ==================================================================================
class Config:
    EXPERIMENT_NAME = "L-DR-MEP_512_GridViz"
    
    # --- DIMENSIONS ---
    IMAGE_SIZE = 512       
    LATENT_SIZE = 64        # 512 / 8 = 64
    
    # --- TRAINING SETTINGS ---
    BATCH_SIZE = 4         
    LEARNING_RATE = 1e-4
    EPOCHS = 15             # Train for 15 ADDITIONAL epochs
    
    SAVE_INTERVAL = 5       
    NUM_TIMESTEPS = 1000    
    GRADIENT_ACCUMULATION = 2
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    
    # --- PATHS ---
    DATA_PATH = "/kaggle/input/aptos2019/train_images/train_images"
    CSV_PATH = "/kaggle/input/aptos2019/train_1.csv" 
    
    # Checkpoint to LOAD
    LOAD_CHECKPOINT_PATH = "/kaggle/input/ldr-120/pytorch/default/1/ldr_mep_512_120.pth"
    
    # Where to SAVE
    OUTPUT_DIR = "/kaggle/working/L-DR-MEP_Output"
    SAVE_CHECKPOINT_PATH = f"{OUTPUT_DIR}/ldr_mep_512_finetuned.pth"
    
    # Evaluation Directories
    EVAL_DIR = f"{OUTPUT_DIR}/evaluation_plots"
    PROGRESS_DIR = f"{OUTPUT_DIR}/training_progress"

os.makedirs(Config.OUTPUT_DIR, exist_ok=True)
os.makedirs(Config.EVAL_DIR, exist_ok=True)
os.makedirs(Config.PROGRESS_DIR, exist_ok=True)

# Path Verification
if not os.path.exists(Config.CSV_PATH):
    if os.path.exists("/kaggle/input/aptos2019/train.csv"):
        Config.CSV_PATH = "/kaggle/input/aptos2019/train.csv"

print(f"‚öôÔ∏è Device: {Config.DEVICE}")

# ==================================================================================
# 2. DATASET & PREPROCESSING
# ==================================================================================
def crop_image_from_gray(img, tol=7):
    if img.ndim == 2: mask = img > tol
    else: mask = img[:,:,0] > tol
    return img[np.ix_(mask.any(1),mask.any(0))]

def circle_crop_and_resize(img_path, size=512):
    img = cv2.imread(img_path)
    if img is None: return None
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = crop_image_from_gray(img)
    img = cv2.resize(img, (size, size))
    height, width, _ = img.shape
    x, y, r = int(width/2), int(height/2), int(np.amin((width/2, height/2)))
    circle_img = np.zeros((height, width), np.uint8)
    cv2.circle(circle_img, (x,y), int(r), 1, thickness=-1)
    img = cv2.bitwise_and(img, img, mask=circle_img)
    img = cv2.addWeighted(img, 4, cv2.GaussianBlur(img, (0,0), 30), -4, 128)
    return Image.fromarray(img)

class APTOSDataset(Dataset):
    def __init__(self, csv_file, root_dir):
        self.df = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.id_col = self.df.columns[0] 
        self.label_col = self.df.columns[1]
        self.transform = T.Compose([
            T.Resize((Config.IMAGE_SIZE, Config.IMAGE_SIZE)),
            T.RandomHorizontalFlip(),
            T.ToTensor(),
            T.Normalize([0.5], [0.5])
        ])

    def __len__(self): return len(self.df)

    def __getitem__(self, idx):
        img_id = str(self.df.iloc[idx][self.id_col])
        if not img_id.endswith(".png"): img_id += ".png"
        img_name = os.path.join(self.root_dir, img_id)
        image = circle_crop_and_resize(img_name, size=Config.IMAGE_SIZE)
        if image is None: image = Image.new('RGB', (Config.IMAGE_SIZE, Config.IMAGE_SIZE))
        return self.transform(image), int(self.df.iloc[idx][self.label_col])

print("Loading Data...")
dataset = APTOSDataset(Config.CSV_PATH, Config.DATA_PATH)
dataloader = DataLoader(
    dataset, 
    batch_size=Config.BATCH_SIZE, 
    shuffle=True, 
    num_workers=2,
    pin_memory=True if Config.DEVICE == "cuda" else False,
    prefetch_factor=2 if Config.DEVICE == "cuda" else None
)

# ==================================================================================
# 3. MODEL SETUP (512x512 | Scaled Linear)
# ==================================================================================
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse").to(Config.DEVICE)
vae.requires_grad_(False)

# Deep U-Net for 64x64 Latents (128->256->512->512)
model = UNet2DModel(
    sample_size=Config.LATENT_SIZE, 
    in_channels=4,
    out_channels=4,
    layers_per_block=2,
    block_out_channels=(128, 256, 512, 512), 
    down_block_types=("DownBlock2D", "DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D"),
    up_block_types=("AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D", "UpBlock2D"),
    class_embed_type="timestep", 
    num_class_embeds=5
).to(Config.DEVICE)

# SCHEDULER: Scaled Linear (Optimum for 512px)
scheduler = DDIMScheduler(
    num_train_timesteps=Config.NUM_TIMESTEPS,
    beta_schedule="scaled_linear", 
    clip_sample=False,
    prediction_type="epsilon"
)
optimizer = torch.optim.AdamW(model.parameters(), lr=Config.LEARNING_RATE)

# ==================================================================================
# 4. LOAD CHECKPOINT
# ==================================================================================
def load_checkpoint():
    start_epoch = 0
    if os.path.exists(Config.LOAD_CHECKPOINT_PATH):
        print(f"üì• Loading Checkpoint from {Config.LOAD_CHECKPOINT_PATH}...")
        try:
            checkpoint = torch.load(Config.LOAD_CHECKPOINT_PATH, map_location=Config.DEVICE)
            model.load_state_dict(checkpoint['model_state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            start_epoch = checkpoint['epoch'] + 1
            print(f"‚úÖ Successfully loaded. Resume from Global Epoch {start_epoch}.")
        except Exception as e:
            print(f"‚ùå Error loading: {e}. Starting from scratch.")
            start_epoch = 0
    else:
        print(f"‚ö†Ô∏è No checkpoint found. Starting from scratch.")
    return start_epoch

def save_checkpoint(epoch, loss):
    state = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
    }
    torch.save(state, Config.SAVE_CHECKPOINT_PATH)
    print(f"üíæ Checkpoint saved at Epoch {epoch+1}")

# ==================================================================================
# 5. VISUALIZATION: THE CLASSIC 5-PLOT GRID
# ==================================================================================
@torch.no_grad()
def generate_images(model, n_samples=1, class_label=2, steps=50, return_intermediates=False):
    model.eval()
    latents = torch.randn((n_samples, 4, Config.LATENT_SIZE, Config.LATENT_SIZE), device=Config.DEVICE)
    labels = torch.tensor([class_label] * n_samples, device=Config.DEVICE)
    scheduler.set_timesteps(steps)
    
    intermediates = []
    # Capture exactly 5 steps: 0%, 25%, 50%, 75%, 100%
    capture_indices = [0, int(steps*0.25), int(steps*0.50), int(steps*0.75), steps-1]
    
    for i, t in enumerate(scheduler.timesteps):
        noise_pred = model(latents, t, class_labels=labels).sample
        latents = scheduler.step(noise_pred, t, latents).prev_sample
        
        if return_intermediates and i in capture_indices:
            decoded = vae.decode(latents / 0.18215).sample
            decoded = (decoded / 2 + 0.5).clamp(0, 1).cpu()
            intermediates.append(decoded)

    final_images = vae.decode(latents / 0.18215).sample
    final_images = (final_images / 2 + 0.5).clamp(0, 1).cpu()
    
    if return_intermediates:
        return final_images, intermediates
    return final_images

def plot_evolution_grid(epoch, title_prefix="Diffusion"):
    """
    Plots the classic 2-row, 5-column grid showing evolution at 0, 25, 50, 75, 100%
    """
    classes_to_show = np.random.choice(range(5), 2, replace=False)
    fig, axes = plt.subplots(2, 5, figsize=(15, 6))
    
    for idx, cls in enumerate(classes_to_show):
        _, intermediates = generate_images(model, n_samples=1, class_label=cls, steps=50, return_intermediates=True)
        # intermediates is list of tensors [1, 3, 512, 512]
        for step_idx, img_tensor in enumerate(intermediates):
            img = img_tensor[0].permute(1, 2, 0).numpy()
            axes[idx, step_idx].imshow(img)
            axes[idx, step_idx].axis('off')
            if idx == 0:
                axes[idx, step_idx].set_title(f"Step {step_idx * 25}%")
        axes[idx, 0].set_ylabel(f"Class {cls}", fontsize=12, fontweight='bold')
    
    plt.suptitle(f"{title_prefix} Evolution (Global Epoch {epoch})")
    plt.tight_layout()
    save_path = f"{Config.PROGRESS_DIR}/epoch_{epoch}_evolution.png"
    plt.savefig(save_path)
    plt.close()
    print(f"üñºÔ∏è Evolution Grid saved to {save_path}")

# ==================================================================================
# 6. EXECUTION LOGIC
# ==================================================================================

# 1. Load Model
start_epoch = load_checkpoint()

# 2. VALIDATION: Check the loaded model visually (Using 5-Plot Grid)
print("\nüîé Validating loaded model (Pre-Train Visualization)...")
plot_evolution_grid(start_epoch, title_prefix="LOADED_MODEL_VALIDATION")

# 3. Training Loop
end_epoch = start_epoch + Config.EPOCHS
print(f"\nüöÄ Resuming training from Epoch {start_epoch+1}.")
print(f"üîÑ Target: {Config.EPOCHS} additional epochs (End: {end_epoch}).")

losses = []

for epoch in range(start_epoch, end_epoch):
    model.train()
    progress_bar = tqdm(dataloader, desc=f"Global Epoch {epoch+1}/{end_epoch}")
    epoch_loss = 0
    
    for step, (images, labels) in enumerate(progress_bar):
        images = images.to(Config.DEVICE, non_blocking=True)
        labels = labels.to(Config.DEVICE, non_blocking=True)
        
        # VAE Encode
        with torch.no_grad():
            latents = vae.encode(images).latent_dist.sample() * 0.18215
            
        # Add Noise
        noise = torch.randn_like(latents)
        bs = latents.shape[0]
        timesteps = torch.randint(0, Config.NUM_TIMESTEPS, (bs,), device=Config.DEVICE).long()
        noisy_latents = scheduler.add_noise(latents, noise, timesteps)
        
        # Train
        noise_pred = model(noisy_latents, timesteps, class_labels=labels).sample
        loss = F.mse_loss(noise_pred, noise)
        
        loss.backward()
        if (step + 1) % Config.GRADIENT_ACCUMULATION == 0:
            optimizer.step()
            optimizer.zero_grad()
            
        epoch_loss += loss.item()
        progress_bar.set_postfix({"Loss": loss.item()})
        
    avg_loss = epoch_loss / len(dataloader)
    losses.append(avg_loss)
    
    # Save/Plot every interval OR at the very last epoch
    if (epoch + 1) % Config.SAVE_INTERVAL == 0 or (epoch + 1) == end_epoch:
        save_checkpoint(epoch, avg_loss)
        plot_evolution_grid(epoch + 1)

# ==================================================================================
# 7. FINAL EVALUATION
# ==================================================================================
print("\n‚úÖ Training Complete. Generating Metrics...")

# Plot 1: Loss
if len(losses) > 0:
    plt.figure(figsize=(10, 5))
    plt.plot(range(start_epoch + 1, end_epoch + 1), losses, marker='o', color='b')
    plt.title("L-DR-MEP Training Loss")
    plt.grid(True)
    plt.savefig(f"{Config.EVAL_DIR}/1_Loss.png")
    plt.close()

# Plot 2: XAI/SSIM
try:
    fig, axes = plt.subplots(5, 3, figsize=(10, 15))
    for cls in range(5):
        if not any(dataset.df[dataset.label_col] == cls): continue
        real_idx = dataset.df[dataset.df[dataset.label_col] == cls].index[0]
        real_img, _ = dataset[real_idx]
        real_img_np = (real_img.permute(1, 2, 0).numpy() + 1) / 2
        
        syn_tensor = generate_images(model, n_samples=1, class_label=cls)[0]
        syn_img_np = syn_tensor.permute(1, 2, 0).numpy()
        
        # SSIM
        real_gray = cv2.cvtColor((real_img_np*255).astype('uint8'), cv2.COLOR_RGB2GRAY)
        syn_gray = cv2.cvtColor((syn_img_np*255).astype('uint8'), cv2.COLOR_RGB2GRAY)
        score, diff = ssim(real_gray, syn_gray, full=True)
        diff_heatmap = cv2.applyColorMap((diff * 255).astype("uint8"), cv2.COLORMAP_JET)
        
        axes[cls, 0].imshow(real_img_np); axes[cls, 0].set_ylabel(f"Class {cls}")
        axes[cls, 1].imshow(syn_img_np)
        axes[cls, 2].imshow(diff_heatmap)
        for j in range(3): axes[cls, j].axis('off')
    plt.savefig(f"{Config.EVAL_DIR}/3_XAI.png"); plt.close()
    print("‚úÖ XAI Plot Saved.")
except Exception as e: print(f"Error in XAI: {e}")

# Plot 3: Spectral
fig, axes = plt.subplots(5, 5, figsize=(15, 15))
for cls in range(5):
    imgs = generate_images(model, n_samples=5, class_label=cls)
    for i, img in enumerate(imgs):
        gray = cv2.cvtColor((img.permute(1,2,0).numpy() * 255).astype(np.uint8), cv2.COLOR_RGB2GRAY)
        f = np.fft.fftshift(np.fft.fft2(gray))
        spec = 20 * np.log(np.abs(f) + 1e-8)
        axes[cls, i].imshow(spec, cmap='inferno'); axes[cls, i].axis('off')
plt.savefig(f"{Config.EVAL_DIR}/5_Spectral.png"); plt.close()
print("‚úÖ Spectral Plot Saved.")

# Plot 4: t-SNE
print("Computing t-SNE...")
extractor = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
extractor.fc = nn.Identity()
extractor.to(Config.DEVICE)
extractor.eval()
features = []
labels_list = []

# Real Data
for cls in range(5):
    idxs = dataset.df[dataset.df[dataset.label_col] == cls].index[:20]
    for idx in idxs:
        img, _ = dataset[idx]
        with torch.no_grad():
            feat = extractor(img.unsqueeze(0).to(Config.DEVICE)).cpu().numpy().flatten()
        features.append(feat); labels_list.append(cls)

# Syn Data
for cls in range(5):
    imgs = generate_images(model, n_samples=20, class_label=cls, steps=30)
    imgs = T.Normalize([-0.5], [2.0])(imgs).to(Config.DEVICE)
    for i in range(20):
        with torch.no_grad():
            feat = extractor(imgs[i].unsqueeze(0)).cpu().numpy().flatten()
        features.append(feat); labels_list.append(cls + 5)

if len(features) > 0:
    tsne = TSNE(n_components=2, perplexity=10, random_state=42)
    embedded = tsne.fit_transform(np.array(features))
    plt.figure(figsize=(10, 8))
    colors = ['red', 'green', 'blue', 'orange', 'purple']
    labels = ["No DR", "Mild", "Moderate", "Severe", "Proliferative"]
    for cls in range(5):
        mask = np.array(labels_list) == cls
        if np.any(mask): plt.scatter(embedded[mask, 0], embedded[mask, 1], c=colors[cls], marker='o', alpha=0.6, label=f"Real {labels[cls]}")
    for cls in range(5):
        mask = np.array(labels_list) == (cls + 5)
        if np.any(mask): plt.scatter(embedded[mask, 0], embedded[mask, 1], c=colors[cls], marker='*', s=100, label=f"Syn {labels[cls]}")
    plt.legend()
    plt.savefig(f"{Config.EVAL_DIR}/6_tSNE.png"); plt.close()

print(f"\n‚úÖ All Output in {Config.OUTPUT_DIR}")

Installing dependencies...


2026-01-29 14:44:02.824190: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1769697842.987512      55 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1769697843.033965      55 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1769697843.419838      55 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1769697843.419872      55 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1769697843.419875      55 computation_placer.cc:177] computation placer alr

‚öôÔ∏è Device: cuda
Loading Data...


config.json:   0%|          | 0.00/547 [00:00<?, ?B/s]

diffusion_pytorch_model.safetensors:   0%|          | 0.00/335M [00:00<?, ?B/s]

üì• Loading Checkpoint from /kaggle/input/ldr-120/pytorch/default/1/ldr_mep_512_120.pth...
‚úÖ Successfully loaded. Resume from Global Epoch 120.

üîé Validating loaded model (Pre-Train Visualization)...
üñºÔ∏è Evolution Grid saved to /kaggle/working/L-DR-MEP_Output/training_progress/epoch_120_evolution.png

üöÄ Resuming training from Epoch 121.
üîÑ Target: 15 additional epochs (End: 135).


Global Epoch 121/135:   0%|          | 0/733 [00:00<?, ?it/s]

Global Epoch 122/135:   0%|          | 0/733 [00:00<?, ?it/s]

Global Epoch 123/135:   0%|          | 0/733 [00:00<?, ?it/s]

Global Epoch 124/135:   0%|          | 0/733 [00:00<?, ?it/s]

Global Epoch 125/135:   0%|          | 0/733 [00:00<?, ?it/s]

üíæ Checkpoint saved at Epoch 125
üñºÔ∏è Evolution Grid saved to /kaggle/working/L-DR-MEP_Output/training_progress/epoch_125_evolution.png


Global Epoch 126/135:   0%|          | 0/733 [00:00<?, ?it/s]

Global Epoch 127/135:   0%|          | 0/733 [00:00<?, ?it/s]

Global Epoch 128/135:   0%|          | 0/733 [00:00<?, ?it/s]

Global Epoch 129/135:   0%|          | 0/733 [00:00<?, ?it/s]

Global Epoch 130/135:   0%|          | 0/733 [00:00<?, ?it/s]

üíæ Checkpoint saved at Epoch 130
üñºÔ∏è Evolution Grid saved to /kaggle/working/L-DR-MEP_Output/training_progress/epoch_130_evolution.png


Global Epoch 131/135:   0%|          | 0/733 [00:00<?, ?it/s]

Global Epoch 132/135:   0%|          | 0/733 [00:00<?, ?it/s]

Global Epoch 133/135:   0%|          | 0/733 [00:00<?, ?it/s]

Global Epoch 134/135:   0%|          | 0/733 [00:00<?, ?it/s]

Global Epoch 135/135:   0%|          | 0/733 [00:00<?, ?it/s]

üíæ Checkpoint saved at Epoch 135
üñºÔ∏è Evolution Grid saved to /kaggle/working/L-DR-MEP_Output/training_progress/epoch_135_evolution.png

‚úÖ Training Complete. Generating Metrics...
‚úÖ XAI Plot Saved.
‚úÖ Spectral Plot Saved.
Computing t-SNE...
Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 44.7M/44.7M [00:00<00:00, 175MB/s] 


OutOfMemoryError: CUDA out of memory. Tried to allocate 5.00 GiB. GPU 0 has a total capacity of 14.74 GiB of which 3.31 GiB is free. Process 3383 has 11.43 GiB memory in use. Of the allocated memory 5.74 GiB is allocated by PyTorch, and 5.55 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)