In [None]:
# ============================================================
# 1. SETUP
# ============================================================

import os
import cv2
import glob
import math
import random
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision.utils import make_grid
import torchvision.transforms as T


# ============================================================
# RANDOM SEED
# ============================================================

def seed_everything(seed=42, deterministic=True):
    """Seed every RNG this notebook uses so runs are repeatable.

    Args:
        seed: Seed applied to Python's ``random``, NumPy's global RNG and
            torch (CPU plus all CUDA devices).
        deterministic: When True (default), turn cuDNN autotuning off and
            request deterministic cuDNN kernels. ``cudnn.benchmark = True``
            picks convolution algorithms by timing them, so the chosen
            kernels (and results) can differ between runs even with
            identical seeds. Pass False to trade reproducibility for speed.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = deterministic
    torch.backends.cudnn.benchmark = not deterministic


seed_everything(42)


# ============================================================
# UPDATED DATASET PATHS (YOUR WINDOWS PATH)
# ============================================================

# Root of the CCDS split; must contain train/, val/ and test/ sub-folders.
ROOT = r"D:\Rushi_OCT_Diffusion\CCDS_Split_10K-20260120T043900Z-3-001\CCDS_Split_10K"

TRAIN_DIR = os.path.join(ROOT, "train")
VAL_DIR   = os.path.join(ROOT, "val")
TEST_DIR  = os.path.join(ROOT, "test")

print("Train folder:", TRAIN_DIR)
print("Val folder:", VAL_DIR)
print("Test folder:", TEST_DIR)

# isdir rather than exists: a stray *file* named like a split must not pass,
# and the message names the exact path that is missing.
for _split_name, _split_dir in (
    ("Train", TRAIN_DIR),
    ("Validation", VAL_DIR),
    ("Test", TEST_DIR),
):
    if not os.path.isdir(_split_dir):
        raise FileNotFoundError(f"{_split_name} directory not found: {_split_dir}")

print("Dataset folders verified ✓")


# ============================================================
# OUTPUT FOLDERS
# ============================================================

# All run outputs go below OUT_ROOT, split into samples / checkpoints /
# plots / json / csv sub-folders.
OUT_ROOT = r"D:\Rushi_OCT_Diffusion\oct_ldm_output"

DIR_SAMPLES, DIR_CHECKPOINTS, DIR_PLOTS, DIR_JSON, DIR_CSV = (
    os.path.join(OUT_ROOT, sub)
    for sub in ("samples", "checkpoints", "plots", "json", "csv")
)

for _out_dir in (DIR_SAMPLES, DIR_CHECKPOINTS, DIR_PLOTS, DIR_JSON, DIR_CSV):
    os.makedirs(_out_dir, exist_ok=True)
del _out_dir

print("Output folders created at:", OUT_ROOT)


# ============================================================
# DEVICE
# ============================================================

# Use the first CUDA GPU when one is visible, otherwise the CPU.
_cuda_ok = torch.cuda.is_available()
device = torch.device("cuda" if _cuda_ok else "cpu")
print("Using device:", device)

if _cuda_ok:
    print("GPU:", torch.cuda.get_device_name(0))


# ============================================================
# VISUALIZATION FUNCTION
# ============================================================

def show_grid(tensor, nrow=8, title=""):
    """Display an image batch stored in [-1, 1] as one grid.

    Args:
        tensor: Image batch in model space [-1, 1] (a 4-D batch is what
            ``make_grid`` expects here).
        nrow: Images per grid row.
        title: Optional figure title.

    ``make_grid`` expands single-channel batches to 3 channels, so the grid
    is drawn as RGB. Matplotlib clips float RGB data to [0, 1], which would
    black out every negative pixel. Values are therefore mapped from
    [-1, 1] to [0, 1] before building the grid.
    """
    images = (tensor.detach().float().clamp(-1, 1) + 1) / 2
    grid = make_grid(images, nrow=nrow, normalize=False)
    grid = grid.permute(1, 2, 0).cpu().numpy()

    fig, ax = plt.subplots(figsize=(8, 8))

    if grid.shape[-1] == 1:
        ax.imshow(grid[..., 0], cmap="gray", vmin=0.0, vmax=1.0)
    else:
        ax.imshow(grid)

    if title:
        ax.set_title(title)

    ax.axis("off")
    fig.tight_layout()
    plt.show()
    plt.close(fig)


print("OCT LDM environment ready ✓")

In [None]:
import os
import cv2
import torch
from torch.utils.data import Dataset, DataLoader


# ============================================================
# DATASET CLASS (OPTIMIZED)
# ============================================================

class OCTDataset(Dataset):
    """Grayscale OCT PNG dataset with an aspect-preserving letterbox resize.

    Every ``*.png`` under ``root`` (recursive, case-insensitive extension)
    is read as one channel. It is resized so it fits inside
    ``image_size x image_size``, reflect-padded to that square and scaled
    to [-1, 1].
    """

    def __init__(self, root, image_size=512):
        """
        Args:
            root: Directory scanned recursively for PNG files.
            image_size: Side length of the square output image.

        Raises:
            FileNotFoundError: ``root`` does not exist.
            RuntimeError: no PNG file was found under ``root``.
        """
        if not os.path.exists(root):
            raise FileNotFoundError(f"Dataset folder not found: {root}")

        self.image_size = image_size

        # os.walk yields entries in filesystem order, which differs between
        # machines and filesystems. Sort so the index -> file mapping (and
        # hence seeded shuffles) is reproducible.
        self.paths = sorted(
            os.path.join(dirpath, f)
            for dirpath, _, files in os.walk(root)
            for f in files
            if f.lower().endswith(".png")
        )

        if not self.paths:
            raise RuntimeError(f"No PNG files found in {root}")

        print(f"Loaded {len(self.paths)} images from {root}")

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        """Return {"image": (1, S, S) float tensor in [-1, 1], "patient", "path"}."""
        path = self.paths[idx]
        img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)

        if img is None:
            raise RuntimeError(f"Failed to load image: {path}")

        h, w = img.shape
        t = self.image_size

        scale = min(t / h, t / w)
        # round() rather than int(): float error can make h*scale 511.999...,
        # which int() truncates into a spurious 1-px pad. Clamp to [1, t] so
        # extreme aspect ratios never request a zero-sized resize (cv2 raises)
        # or a negative pad.
        nh = min(t, max(1, round(h * scale)))
        nw = min(t, max(1, round(w * scale)))

        img = cv2.resize(img, (nw, nh), interpolation=cv2.INTER_AREA)

        pad_h, pad_w = t - nh, t - nw

        img = cv2.copyMakeBorder(
            img,
            pad_h // 2, pad_h - pad_h // 2,
            pad_w // 2, pad_w - pad_w // 2,
            cv2.BORDER_REFLECT_101
        )

        img = torch.from_numpy(img).float().unsqueeze(0)

        # uint8 [0, 255] -> [-1, 1]
        img = img / 127.5 - 1.0

        # Assumes files are laid out as <patient>/<image>.png.
        patient = os.path.basename(os.path.dirname(path))

        return {
            "image": img,
            "patient": patient,
            "path": path
        }


# ============================================================
# PATHS
# ============================================================

# Same split root as the setup cell, repeated so this cell can run alone.
DATA_ROOT = r"D:\Rushi_OCT_Diffusion\CCDS_Split_10K-20260120T043900Z-3-001\CCDS_Split_10K"

TRAIN_DIR, VAL_DIR, TEST_DIR = (
    os.path.join(DATA_ROOT, split) for split in ("train", "val", "test")
)

print("Using dataset:", DATA_ROOT)


# ============================================================
# LOADERS
# ============================================================

def build_loader(path, batch_size=4, shuffle=False, image_size=512, num_workers=0):
    """Wrap an ``OCTDataset`` rooted at ``path`` in a DataLoader.

    Args:
        path: Split directory handed to ``OCTDataset``.
        batch_size: Images per batch.
        shuffle: Reshuffle every epoch (use for the training split).
        image_size: Square output resolution forwarded to the dataset. It is
            passed positionally so it works with every ``OCTDataset``
            variant in this notebook.
        num_workers: DataLoader worker processes; 0 (default) loads in the
            notebook process, which is the notebook-safe choice.

    Returns:
        A ``DataLoader``; pinned memory is enabled only when CUDA is available.
    """
    ds = OCTDataset(path, image_size)

    return DataLoader(
        ds,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=torch.cuda.is_available(),
    )


train_loader = build_loader(TRAIN_DIR, shuffle=True)
val_loader = build_loader(VAL_DIR)
test_loader = build_loader(TEST_DIR)

print("\nLoaders ready ✓")
for _split, _loader in (("Train", train_loader), ("Val", val_loader), ("Test", test_loader)):
    print(f"{_split} batches:", len(_loader))


# ============================================================
# SANITY TEST
# ============================================================

# Pull one (shuffled) training batch to confirm shapes end to end.
batch = next(iter(train_loader))

print("\nBatch image shape:", batch["image"].shape)
print("Sample patients:", batch["patient"][:4])

In [None]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F


# ============================================================
# PATHS (ALIGNED WITH YOUR PREVIOUS STRUCTURE)
# ============================================================

# Checkpoint folder; the LDM training/evaluation cells read vae_best.pth
# from this same directory.
CHECKPOINT_DIR = r"D:\Rushi_OCT_Diffusion\OCT_Local_pths"
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

VAE_SAVE_PATH = os.path.join(CHECKPOINT_DIR, "vae_best.pth")

_device_kind = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_kind)
print("Using device:", device)


# ============================================================
# Utility: Safe GroupNorm
# ============================================================

def norm_layer(ch):
    """GroupNorm over ``ch`` channels using at most 32 groups.

    Picks the largest group count <= 32 that divides ``ch``. Whenever
    ``min(32, ch)`` already divides ``ch`` (all widths used here: 64, 128,
    256) this equals ``GroupNorm(min(32, ch), ch)``. Widths such as 48,
    which made the old formula raise, now get a valid group count (24).
    """
    groups = next(g for g in range(min(32, ch), 0, -1) if ch % g == 0)
    return nn.GroupNorm(groups, ch)


# ============================================================
# Residual Block
# ============================================================

class ResBlock(nn.Module):
    """Pre-activation residual block: x + conv(act(norm(conv(act(norm(x)))))).

    Channels and spatial size are unchanged, so the skip is a plain add.
    Attribute names define the checkpoint keys and are kept fixed.
    """

    def __init__(self, ch):
        super().__init__()
        self.norm1 = norm_layer(ch)
        self.conv1 = nn.Conv2d(ch, ch, kernel_size=3, stride=1, padding=1)
        self.norm2 = norm_layer(ch)
        self.conv2 = nn.Conv2d(ch, ch, kernel_size=3, stride=1, padding=1)
        self.act = nn.SiLU()

    def forward(self, x):
        out = x
        for norm, conv in ((self.norm1, self.conv1), (self.norm2, self.conv2)):
            out = conv(self.act(norm(out)))
        return x + out


# ============================================================
# Down Block
# ============================================================

class DownBlock(nn.Module):
    """Stride-2 4x4 convolution (halves H and W), then GroupNorm and SiLU."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1)
        self.norm = norm_layer(out_ch)
        self.act = nn.SiLU()

    def forward(self, x):
        y = self.conv(x)
        y = self.norm(y)
        return self.act(y)


# ============================================================
# Up Block
# ============================================================

class UpBlock(nn.Module):
    """Stride-2 4x4 transposed convolution (doubles H and W), then GroupNorm and SiLU."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.ConvTranspose2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1)
        self.norm = norm_layer(out_ch)
        self.act = nn.SiLU()

    def forward(self, x):
        y = self.conv(x)
        y = self.norm(y)
        return self.act(y)


# ============================================================
# VAE (CONSISTENT WITH YOUR LDM)
# ============================================================

class VAE(nn.Module):
    """Convolutional VAE for single-channel OCT images.

    Encoder: ``conv_in`` followed by three (stride-2 DownBlock + ResBlock)
    stages, so a 512x512 input becomes a 64x64 latent with ``z_channels``
    channels. The decoder mirrors this with stride-2 transposed
    convolutions and ends in tanh, so reconstructions lie in [-1, 1].

    Sub-module names and creation order are kept fixed: the names are the
    state-dict keys other cells load, and the order fixes how the seeded
    RNG is consumed during initialisation.
    """

    def __init__(self, im_channels=1, z_channels=16, base_ch=64):
        super().__init__()
        narrow, medium, wide = base_ch, base_ch * 2, base_ch * 4

        # ---------------- Encoder ----------------
        self.conv_in = nn.Conv2d(im_channels, narrow, 3, padding=1)
        self.down1 = DownBlock(narrow, medium)
        self.res1 = ResBlock(medium)
        self.down2 = DownBlock(medium, wide)
        self.res2 = ResBlock(wide)
        self.down3 = DownBlock(wide, wide)
        self.res3 = ResBlock(wide)
        # One conv emits mean and log-variance stacked along the channel axis.
        self.to_stats = nn.Conv2d(wide, 2 * z_channels, 3, padding=1)

        # ---------------- Decoder ----------------
        self.from_latent = nn.Conv2d(z_channels, wide, 3, padding=1)
        self.res4 = ResBlock(wide)
        self.up1 = UpBlock(wide, wide)
        self.res5 = ResBlock(wide)
        self.up2 = UpBlock(wide, medium)
        self.res6 = ResBlock(medium)
        self.up3 = UpBlock(medium, narrow)
        self.res7 = ResBlock(narrow)
        self.norm_out = norm_layer(narrow)
        self.conv_out = nn.Conv2d(narrow, im_channels, 3, padding=1)

    def encode(self, x):
        """Return (mean, logvar) of the posterior; logvar is clamped to [-10, 10]."""
        h = self.conv_in(x)
        for down, res in ((self.down1, self.res1),
                          (self.down2, self.res2),
                          (self.down3, self.res3)):
            h = res(down(h))
        mean, logvar = self.to_stats(h).chunk(2, dim=1)
        return mean, logvar.clamp(-10, 10)

    def reparameterize(self, mean, logvar):
        """Draw z = mean + std * eps with eps ~ N(0, I)."""
        std = (0.5 * logvar).exp()
        return mean + torch.randn_like(std) * std

    def decode(self, z):
        """Map a latent back to image space; tanh keeps the output in [-1, 1]."""
        h = self.res4(self.from_latent(z))
        for up, res in ((self.up1, self.res5),
                        (self.up2, self.res6),
                        (self.up3, self.res7)):
            h = res(up(h))
        return torch.tanh(self.conv_out(F.silu(self.norm_out(h))))

    def forward(self, x):
        """Encode, sample a latent, decode. Returns (recon, mean, logvar)."""
        mean, logvar = self.encode(x)
        return self.decode(self.reparameterize(mean, logvar)), mean, logvar


# ============================================================
# TEST (512 → 64 → 512)
# ============================================================

if __name__ == "__main__":

    # Shape smoke test: 512x512 input -> 64x64 latent -> 512x512 recon.
    vae = VAE().to(device)

    x = torch.randn(1, 1, 512, 512).to(device)

    # no_grad: this only checks shapes. Building the autograd graph for a
    # 512x512 pass just holds activation memory that training cells need.
    with torch.no_grad():
        recon, mean, logvar = vae(x)

    print("Input:", x.shape)
    print("Latent:", mean.shape)   # should be [1, 16, 64, 64]
    print("Recon:", recon.shape)

    assert tuple(mean.shape) == (1, 16, 64, 64), f"unexpected latent shape {tuple(mean.shape)}"
    assert recon.shape == x.shape, f"unexpected recon shape {tuple(recon.shape)}"

In [None]:
from torch.utils.data import Dataset, DataLoader
import glob
import cv2
import os
import torch


# ============================================================
# ROOT PATH (MATCHES YOUR PREVIOUS CODE)
# ============================================================

# Split root, repeated so this cell can run on its own.
DATA_ROOT = r"D:\Rushi_OCT_Diffusion\CCDS_Split_10K-20260120T043900Z-3-001\CCDS_Split_10K"

TRAIN_DIR = os.path.join(DATA_ROOT, "train")
VAL_DIR   = os.path.join(DATA_ROOT, "val")
TEST_DIR  = os.path.join(DATA_ROOT, "test")

for _split_dir, _missing_msg in (
    (TRAIN_DIR, "Train directory not found"),
    (VAL_DIR, "Validation directory not found"),
    (TEST_DIR, "Test directory not found"),
):
    if not os.path.exists(_split_dir):
        raise FileNotFoundError(_missing_msg)

print("Dataset folders verified ✓")


# ============================================================
# DATASET
# ============================================================

class OCTDataset(Dataset):
    """Grayscale OCT PNG dataset stretched to a fixed square size.

    NOTE: this redefines ``OCTDataset``. From here on the name refers to
    this version, whose items are ``{"image": (1, S, S) tensor in [-1, 1]}``.
    """

    def __init__(self, root, image_size=512):
        """
        Args:
            root: Directory searched recursively for ``*.png``.
            image_size: Output side length (images are resized, not padded).

        Raises:
            FileNotFoundError: ``root`` is not an existing directory. This
                matches the earlier OCTDataset; previously a missing folder
                surfaced as a misleading "No images found".
            RuntimeError: no PNG file exists under ``root``.
        """
        if not os.path.isdir(root):
            raise FileNotFoundError(f"Dataset folder not found: {root}")

        # Sorted so the index -> file mapping is stable across machines.
        self.paths = sorted(
            glob.glob(os.path.join(root, "**", "*.png"), recursive=True)
        )

        self.image_size = image_size

        if len(self.paths) == 0:
            raise RuntimeError(f"No images found in {root}")

        print(f"Loaded {len(self.paths)} images from {root}")

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        path = self.paths[idx]

        img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)

        # cv2.imread signals failure by returning None, not by raising.
        if img is None:
            raise RuntimeError(f"Failed to load image: {path}")

        # Fixed resolution (consistent with VAE symmetry)
        img = cv2.resize(
            img,
            (self.image_size, self.image_size),
            interpolation=cv2.INTER_AREA
        )

        img = torch.from_numpy(img).float().contiguous()

        img = img / 255.0

        img = img.unsqueeze(0) * 2 - 1   # [-1, 1]

        return {"image": img}


# ============================================================
# DATASETS
# ============================================================

train_ds = OCTDataset(TRAIN_DIR)
val_ds   = OCTDataset(VAL_DIR)
test_ds  = OCTDataset(TEST_DIR)


# ============================================================
# LOADER SETTINGS (MATCH PREVIOUS PIPELINE)
# ============================================================

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

BATCH_SIZE = 4   # Matches your previous LDM setup

# Settings shared by every split; only `shuffle` differs.
loader_kwargs = {
    "batch_size": BATCH_SIZE,
    "num_workers": 0,                      # Notebook-safe (as before)
    "pin_memory": torch.cuda.is_available(),
    "persistent_workers": False,
}

train_loader, val_loader, test_loader = (
    DataLoader(split_ds, shuffle=split_shuffle, **loader_kwargs)
    for split_ds, split_shuffle in ((train_ds, True), (val_ds, False), (test_ds, False))
)

print("\nDataLoaders ready ✓")
print("Batch size:", BATCH_SIZE)
for _split, _loader in (("Train", train_loader), ("Val", val_loader), ("Test", test_loader)):
    print(f"{_split} batches:", len(_loader))


# ============================================================
# SANITY CHECK
# ============================================================

batch = next(iter(train_loader))

print("\nBatch image shape:", batch["image"].shape)

In [None]:
import os
import torch
import torch.nn.functional as F
from tqdm import tqdm
import matplotlib.pyplot as plt
from torchvision.utils import make_grid, save_image
import lpips


# ============================================================
# PATHS (MATCHES YOUR PREVIOUS LOCAL SETUP)
# ============================================================

DATA_ROOT = r"D:\Rushi_OCT_Diffusion\CCDS_Split_10K-20260120T043900Z-3-001\CCDS_Split_10K"
TRAIN_DIR = os.path.join(DATA_ROOT, "train")
VAL_DIR   = os.path.join(DATA_ROOT, "val")

# VAE training artefacts: per-epoch preview images under previews/.
OUTPUT_ROOT = r"D:\Rushi_OCT_Diffusion\oct_ldm_output"
VAE_OUT_DIR = os.path.join(OUTPUT_ROOT, "vae_training")
PREVIEW_DIR = os.path.join(VAE_OUT_DIR, "previews")

for _folder in (VAE_OUT_DIR, PREVIEW_DIR):
    os.makedirs(_folder, exist_ok=True)


# ============================================================
# DEVICE
# ============================================================

_has_cuda = torch.cuda.is_available()
device = torch.device("cuda" if _has_cuda else "cpu")
print("Using device:", device)

if _has_cuda:
    print("GPU:", torch.cuda.get_device_name(0))


# ============================================================
# SAFE NORMALIZATION
# ============================================================

def normalize_batch(x, value_range=None):
    """Map an image batch to [-1, 1] and clamp.

    Args:
        x: Image tensor.
        value_range: Optional ``(lo, hi)`` range that ``x`` is known to lie
            in; it is mapped linearly onto [-1, 1]. When None (default), the
            range is guessed from the batch itself:
              * max > 1.5      -> treated as [0, 255]
              * else min >= 0  -> treated as [0, 1]
              * otherwise      -> treated as already in [-1, 1]
            The guess is ambiguous. A batch already in [-1, 1] whose values
            happen to be all >= 0 is treated as [0, 1] and wrongly rescaled.
            Pass ``value_range`` when the source range is known.

    Returns:
        Tensor of the same shape with values clamped to [-1, 1].

    Raises:
        ValueError: ``value_range`` is given with ``hi <= lo``.
    """
    if value_range is not None:
        lo, hi = value_range
        if hi <= lo:
            raise ValueError(f"value_range must satisfy lo < hi, got {value_range}")
        return ((x - lo) / (hi - lo) * 2 - 1).clamp(-1, 1)

    if x.max() > 1.5:
        x = x / 127.5 - 1
    elif x.min() >= 0:
        x = x * 2 - 1

    return x.clamp(-1, 1)


# ============================================================
# PERCEPTUAL LOSS (LPIPS)
# ============================================================

# VGG-based LPIPS used only as a fixed perceptual loss: eval mode, and no
# parameter ever receives gradients.
lpips_loss = lpips.LPIPS(net="vgg").to(device)
lpips_loss.eval()
lpips_loss.requires_grad_(False)


def perceptual_loss(pred, target):
    """Batch-mean LPIPS distance between two single-channel image batches.

    The VGG backbone behind ``lpips_loss`` takes 3-channel input, so the
    gray channel is tiled three times before comparison.
    """
    def _as_rgb(t):
        return t.repeat(1, 3, 1, 1)

    distances = lpips_loss(_as_rgb(pred), _as_rgb(target))
    return distances.mean()


# ============================================================
# LOSSES
# ============================================================

def kl_loss(mean, logvar):
    """KL(N(mean, exp(logvar)) || N(0, I)) summed per sample, averaged over dim 0.

    The sum runs over dims 1-3, so inputs are expected to be 4-D. logvar is
    clamped to [-30, 20] before ``exp`` to avoid overflow.
    """
    safe_logvar = logvar.clamp(-30, 20)
    per_element = 1 + safe_logvar - mean.pow(2) - safe_logvar.exp()
    per_sample = per_element.sum(dim=[1, 2, 3])
    return -0.5 * per_sample.mean()


def recon_loss(pred, target):
    """Mean absolute (L1) error between reconstruction and target."""
    return F.l1_loss(input=pred, target=target, reduction="mean")


# ============================================================
# DISPLAY GRID
# ============================================================

def show_grid(tensor, title="", nrow=4):
    """Display an image batch in [-1, 1] as a grid (per-epoch preview).

    This redefinition shadows the earlier ``show_grid(tensor, nrow=8,
    title="")``. ``nrow`` is accepted again, so keyword calls written for
    either version keep working.

    Args:
        tensor: Image batch in [-1, 1]; it is detached, so tensors that
            require grad are accepted.
        title: Figure title.
        nrow: Images per grid row (default 4, as before).
    """
    images = ((tensor.detach().float() + 1) / 2).clamp(0, 1)

    grid = make_grid(images, nrow=nrow)
    grid = grid.cpu().permute(1, 2, 0).numpy()

    fig, ax = plt.subplots(figsize=(6, 6))
    ax.imshow(grid.squeeze(), cmap="gray")
    ax.set_title(title)
    ax.axis("off")
    plt.show()
    # Called once per epoch during training: release the figure explicitly
    # so non-inline backends don't accumulate open figures.
    plt.close(fig)


# ============================================================
# TRAIN FUNCTION
# ============================================================

def _reconstruct_mean(vae, img):
    """Deterministic reconstruction: decode the posterior mean (no sampling noise)."""
    mean, _ = vae.encode(img)
    return vae.decode(mean)


def train_vae(vae, train_loader, val_loader, epochs=150, lr=1e-4, save_path=None):
    """Train ``vae`` with L1 + LPIPS + warmed-up KL and keep the best weights.

    Args:
        vae: Model with ``encode`` -> (mean, logvar), ``decode`` and
            ``forward`` -> (recon, mean, logvar).
        train_loader, val_loader: Loaders yielding dicts with an ``"image"``
            batch.
        epochs: Number of training epochs.
        lr: AdamW learning rate.
        save_path: File the best (lowest validation L1) state dict is
            written to. Defaults to ``VAE_OUT_DIR/vae_best.pth`` as before.

    Returns:
        None. Training stops early if the loss becomes non-finite.

    Validation and previews decode the posterior *mean*. Sampling
    ``mean + std * eps`` there made the validation L1, and therefore the
    "best" checkpoint choice, depend on random noise.
    """
    if save_path is None:
        save_path = os.path.join(VAE_OUT_DIR, "vae_best.pth")
    os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)

    vae = vae.to(device)
    use_amp = device.type == "cuda"

    optimizer = torch.optim.AdamW(vae.parameters(), lr=lr)
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

    # Fixed preview batch so epochs are visually comparable.
    preview = normalize_batch(
        next(iter(val_loader))["image"][:4].to(device)
    )

    best_val = float("inf")
    perc_weight = 0.005

    for epoch in range(1, epochs + 1):

        torch.cuda.empty_cache()

        # KL warmup: linear ramp, capped at 1e-4 from epoch 20 onward.
        beta = min(1e-4, epoch * 5e-6)

        # ================= TRAIN =================
        vae.train()
        total_recon = 0.0

        pbar = tqdm(train_loader, desc=f"[TRAIN] Epoch {epoch}/{epochs}")

        for batch in pbar:

            img = normalize_batch(batch["image"].to(device))

            optimizer.zero_grad(set_to_none=True)

            with torch.cuda.amp.autocast(enabled=use_amp):
                recon, mean, logvar = vae(img)

                loss_l1 = recon_loss(recon, img)
                loss_p  = perceptual_loss(recon, img)
                loss_k  = kl_loss(mean, logvar)

                loss = loss_l1 + perc_weight * loss_p + beta * loss_k

            if not torch.isfinite(loss):
                print("NaN detected — stopping")
                return

            scaler.scale(loss).backward()
            # Unscale first so the clipping threshold applies to true gradients.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(vae.parameters(), 1.0)
            scaler.step(optimizer)
            scaler.update()

            total_recon += loss_l1.item()

            pbar.set_postfix(
                L1=f"{loss_l1.item():.4f}",
                KL=f"{loss_k.item():.4f}",
                beta=f"{beta:.6f}"
            )

        total_recon /= max(1, len(train_loader))

        # ================= VALIDATION =================
        vae.eval()
        val_recon = 0.0

        with torch.no_grad():
            for batch in val_loader:
                img = normalize_batch(batch["image"].to(device))
                val_recon += recon_loss(_reconstruct_mean(vae, img), img).item()

        val_recon /= max(1, len(val_loader))

        print(f"\nEpoch {epoch}")
        print(f"Train L1: {total_recon:.4f} | Val L1: {val_recon:.4f}")

        # ================= PREVIEW =================
        with torch.no_grad():
            recon = _reconstruct_mean(vae, preview)

        vis = torch.cat([preview, recon], dim=0)

        save_image(
            (vis + 1) / 2,
            os.path.join(PREVIEW_DIR, f"epoch_{epoch}.png"),
            nrow=4
        )

        show_grid(vis, title=f"Epoch {epoch}")

        # ================= SAVE BEST =================
        if val_recon < best_val:
            best_val = val_recon
            torch.save(vae.state_dict(), save_path)
            print(f"New best model saved -> {save_path}")

    print("\nVAE Training Complete")


# ============================================================
# RUN
# ============================================================

# Save the best weights where the LDM training/evaluation cells load them
# (CHECKPOINT_DIR/vae_best.pth == VAE_SAVE_PATH from the VAE definition
# cell). The old default location (VAE_OUT_DIR) left those cells raising
# "VAE checkpoint not found".
vae = VAE(z_channels=16).to(device)
train_vae(vae, train_loader, val_loader, save_path=VAE_SAVE_PATH)

In [None]:
import os, glob, cv2
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
from torchvision.utils import save_image
from diffusers import UNet2DModel, DDPMScheduler, DDIMScheduler


ROOT = r"D:\Rushi_OCT_Diffusion\CCDS_Split_10K-20260120T043900Z-3-001\CCDS_Split_10K"

TRAIN_PATH, VAL_PATH, TEST_PATH = (
    os.path.join(ROOT, split) for split in ("train", "val", "test")
)

# LDM checkpoints live next to the trained VAE weights.
CHECKPOINT_DIR = r"D:\Rushi_OCT_Diffusion\OCT_Local_pths"
BEST_MODEL_PATH = os.path.join(CHECKPOINT_DIR, "ldm_best.pth")
CHECKPOINT_PATH = os.path.join(CHECKPOINT_DIR, "ldm_checkpoint.pth")
VAE_PATH = os.path.join(CHECKPOINT_DIR, "vae_best.pth")

# Per-epoch DDIM preview samples are written here.
PREVIEW_DIR = os.path.join(CHECKPOINT_DIR, "epoch_samples")
os.makedirs(PREVIEW_DIR, exist_ok=True)

for _required, _missing_msg in (
    (TRAIN_PATH, "Train folder not found"),
    (VAE_PATH, "VAE checkpoint not found"),
):
    if not os.path.exists(_required):
        raise FileNotFoundError(_missing_msg)


_device_kind = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_kind)
print("Using device:", device)


# ============================================================
# DATASET
# ============================================================

class OCTDataset(Dataset):
    """Grayscale OCT PNG dataset used to feed the frozen VAE during LDM training.

    Redefines the earlier ``OCTDataset``. Images are stretched to
    ``size x size`` and scaled to [-1, 1]; items are
    ``{"image": (1, size, size) float tensor}``.
    """

    def __init__(self, root, size=512):
        """
        Raises:
            RuntimeError: no ``*.png`` found under ``root``. An empty split
                would otherwise only fail later (empty shuffled loader, or a
                division by zero when averaging validation loss).
        """
        self.paths = sorted(
            glob.glob(os.path.join(root, "**/*.png"), recursive=True)
        )
        self.size = size
        print(f"{root} -> {len(self.paths)} images")
        if not self.paths:
            raise RuntimeError(f"No images found in {root}")

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, i):
        path = self.paths[i]
        img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
        # cv2.imread returns None (instead of raising) for unreadable files.
        if img is None:
            raise RuntimeError(f"Failed to load image: {path}")
        # INTER_AREA matches the VAE-training dataset. The cv2 default
        # (bilinear) would feed the frozen VAE differently resampled pixels
        # than it was trained on.
        img = cv2.resize(img, (self.size, self.size), interpolation=cv2.INTER_AREA)
        img = torch.from_numpy(img).float() / 255.0
        img = img.unsqueeze(0) * 2 - 1
        return {"image": img}


# Shared loader settings. pin_memory only benefits (and only runs without a
# warning) when a CUDA device is present, so make it conditional, as the
# other cells do.
_ldm_loader_kwargs = dict(
    batch_size=8,
    num_workers=0,
    pin_memory=torch.cuda.is_available(),
)

train_loader = DataLoader(OCTDataset(TRAIN_PATH), shuffle=True, **_ldm_loader_kwargs)

val_loader = DataLoader(OCTDataset(VAL_PATH), shuffle=False, **_ldm_loader_kwargs)


# ============================================================
# VAE ARCHITECTURE (FOR LATENT ENCODING)
# ============================================================

def norm_layer(ch):
    """GroupNorm over ``ch`` channels using at most 32 groups.

    Uses the largest group count <= 32 that divides ``ch``. It is identical
    to ``GroupNorm(min(32, ch), ch)`` whenever that is valid (64, 128, 256
    here), so VAE checkpoints load unchanged; widths like 48 no longer raise.
    """
    groups = next(g for g in range(min(32, ch), 0, -1) if ch % g == 0)
    return nn.GroupNorm(groups, ch)


class ResBlock(nn.Module):
    """Pre-activation residual block (norm -> SiLU -> 3x3 conv, twice, plus skip).

    Attribute names must match the training VAE so its state dict loads.
    """

    def __init__(self, ch):
        super().__init__()
        self.norm1 = norm_layer(ch)
        self.conv1 = nn.Conv2d(ch, ch, kernel_size=3, stride=1, padding=1)
        self.norm2 = norm_layer(ch)
        self.conv2 = nn.Conv2d(ch, ch, kernel_size=3, stride=1, padding=1)
        self.act = nn.SiLU()

    def forward(self, x):
        branch = self.act(self.norm1(x))
        branch = self.conv1(branch)
        branch = self.act(self.norm2(branch))
        branch = self.conv2(branch)
        return x + branch


class DownBlock(nn.Module):
    """Stride-2 4x4 conv (halves H and W) -> GroupNorm -> SiLU."""

    def __init__(self, i, o):
        super().__init__()
        self.conv = nn.Conv2d(i, o, kernel_size=4, stride=2, padding=1)
        self.norm = norm_layer(o)
        self.act = nn.SiLU()

    def forward(self, x):
        features = self.norm(self.conv(x))
        return self.act(features)


class UpBlock(nn.Module):
    """Stride-2 4x4 transposed conv (doubles H and W) -> GroupNorm -> SiLU."""

    def __init__(self, i, o):
        super().__init__()
        self.conv = nn.ConvTranspose2d(i, o, kernel_size=4, stride=2, padding=1)
        self.norm = norm_layer(o)
        self.act = nn.SiLU()

    def forward(self, x):
        features = self.norm(self.conv(x))
        return self.act(features)


class VAE(nn.Module):
    """Frozen VAE used to move images to and from the LDM's latent space.

    Fixed widths (base 64, 16 latent channels). Layer names and creation
    order mirror the training VAE, so its state dict loads unchanged. Only
    ``encode`` and ``decode`` are defined here.
    """

    def __init__(self):
        super().__init__()

        base, latent = 64, 16
        c2, c4 = base * 2, base * 4

        # Encoder
        self.conv_in = nn.Conv2d(1, base, 3, 1, 1)
        self.down1 = DownBlock(base, c2)
        self.res1 = ResBlock(c2)
        self.down2 = DownBlock(c2, c4)
        self.res2 = ResBlock(c4)
        self.down3 = DownBlock(c4, c4)
        self.res3 = ResBlock(c4)
        self.to_stats = nn.Conv2d(c4, 2 * latent, 3, 1, 1)

        # Decoder
        self.from_latent = nn.Conv2d(latent, c4, 3, 1, 1)
        self.res4 = ResBlock(c4)
        self.up1 = UpBlock(c4, c4)
        self.res5 = ResBlock(c4)
        self.up2 = UpBlock(c4, c2)
        self.res6 = ResBlock(c2)
        self.up3 = UpBlock(c2, base)
        self.res7 = ResBlock(base)
        self.norm_out = norm_layer(base)
        self.conv_out = nn.Conv2d(base, 1, 3, 1, 1)

    def encode(self, x):
        """Return (mean, logvar); logvar is clamped to [-10, 10]."""
        h = self.conv_in(x)
        for down, res in ((self.down1, self.res1),
                          (self.down2, self.res2),
                          (self.down3, self.res3)):
            h = res(down(h))
        mean, logvar = self.to_stats(h).chunk(2, dim=1)
        return mean, logvar.clamp(-10, 10)

    def decode(self, z):
        """Latent -> image in [-1, 1] (tanh output)."""
        h = self.res4(self.from_latent(z))
        for up, res in ((self.up1, self.res5),
                        (self.up2, self.res6),
                        (self.up3, self.res7)):
            h = res(up(h))
        return torch.tanh(self.conv_out(F.silu(self.norm_out(h))))


vae = VAE().to(device)
_vae_state = torch.load(VAE_PATH, map_location=device)
vae.load_state_dict(_vae_state)
del _vae_state  # weights are copied into the module; drop the extra reference
vae.eval()
vae.requires_grad_(False)  # frozen: only used to encode/decode latents


# ============================================================
# LDM (UNet + Scheduler)
# ============================================================

# Latent-space UNet: 16-channel 64x64 latents in and out, matching the VAE's
# 16 latent channels and its 8x downsampling of 512x512 images. Self-attention
# is used at every level except the full-resolution (first down / last up) one.
unet_config = dict(
    sample_size=64,
    in_channels=16,
    out_channels=16,
    layers_per_block=2,
    block_out_channels=(128, 256, 512, 512),
    down_block_types=("DownBlock2D",) + ("AttnDownBlock2D",) * 3,
    up_block_types=("AttnUpBlock2D",) * 3 + ("UpBlock2D",),
)

unet = UNet2DModel(**unet_config).to(device)

# Both schedulers are built with the same 1000 training steps and default
# beta settings, so DDIM sampling matches the DDPM noising used in training.
#
# clip_sample=False: diffusers clips the predicted x0 to [-1, 1] by default,
# a convention meant for pixel space. The VAE's latent mean comes straight
# from a conv with no bounding activation (and the KL weight is <= 1e-4),
# so latents are not range-limited. Clipping would distort every DDIM
# sample. For the DDPM scheduler it only affects step(), not add_noise(),
# but is set to match.
scheduler = DDPMScheduler(num_train_timesteps=1000, clip_sample=False)
ddim = DDIMScheduler(num_train_timesteps=1000, clip_sample=False)

optimizer = torch.optim.AdamW(unet.parameters(), lr=1e-4)
scaler = torch.cuda.amp.GradScaler(enabled=(device.type == "cuda"))

start_epoch = 1
end_epoch = 100
best_loss = float("inf")


# Resume UNet, optimizer and AMP-scaler state, plus the best loss and epoch
# counter, when a previous run left a checkpoint behind.
if os.path.exists(CHECKPOINT_PATH):
    ck = torch.load(CHECKPOINT_PATH, map_location=device)
    for _component, _key in ((unet, "model"), (optimizer, "optimizer"), (scaler, "scaler")):
        _component.load_state_dict(ck[_key])
    best_loss = ck["best_loss"]
    start_epoch = ck["epoch"] + 1
    print(f"Resuming from epoch {ck['epoch']}")


@torch.no_grad()
def sample_images(epoch, num_images=4, num_inference_steps=50, seed=None):
    """Generate DDIM samples, decode them with the VAE and save them as PNGs.

    Args:
        epoch: Used only in the file names ``epoch_XXX_imgN.png``.
        num_images: Number of samples (default 4, as before).
        num_inference_steps: DDIM steps (default 50, as before).
        seed: If given, the starting latent noise comes from a generator
            seeded with it, so previews from different epochs start from the
            same noise and are directly comparable. None keeps fresh noise.

    The UNet's train/eval mode is restored afterwards, even on error.
    """
    was_training = unet.training
    unet.eval()
    try:
        ddim.set_timesteps(num_inference_steps)

        generator = None
        if seed is not None:
            generator = torch.Generator(device=device).manual_seed(seed)

        cfg = unet.config
        z = torch.randn(
            num_images, cfg.in_channels, cfg.sample_size, cfg.sample_size,
            generator=generator, device=device,
        )
        z = z * ddim.init_noise_sigma

        for t in ddim.timesteps:
            model_input = ddim.scale_model_input(z, t)
            noise_pred = unet(model_input, t).sample
            z = ddim.step(noise_pred, t, z).prev_sample

        imgs = (vae.decode(z) + 1) / 2

        for i, img in enumerate(imgs, start=1):
            save_image(
                img,
                os.path.join(PREVIEW_DIR, f"epoch_{epoch:03d}_img{i}.png")
            )
    finally:
        unet.train(was_training)


# ============================================================
# TRAINING LOOP
# ============================================================

# Read the timestep count from the scheduler instead of repeating 1000.
_num_timesteps = scheduler.config.num_train_timesteps

for epoch in range(start_epoch, end_epoch + 1):

    unet.train()
    total_loss = 0.0

    for batch in tqdm(train_loader, desc=f"Epoch {epoch}"):

        img = batch["image"].to(device)

        # Frozen VAE: sample a latent from the posterior.
        with torch.no_grad():
            mean, logvar = vae.encode(img)
            z = mean + torch.randn_like(mean) * torch.exp(0.5 * logvar)

        noise = torch.randn_like(z)
        t = torch.randint(0, _num_timesteps, (z.size(0),), device=device)

        xt = scheduler.add_noise(z, noise, t)

        with torch.cuda.amp.autocast(enabled=(device.type == "cuda")):
            pred = unet(xt, t).sample
            loss = F.mse_loss(pred, noise)  # epsilon-prediction objective

        optimizer.zero_grad(set_to_none=True)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        total_loss += loss.item()

    avg_train = total_loss / max(1, len(train_loader))
    print("Train Loss:", avg_train)

    # ================= VALIDATION =================
    # Noise and timesteps come from a generator re-seeded every epoch. With
    # fresh randomness each epoch, the validation loss moved with the draw
    # alone, which made the "best model" choice noisy.
    unet.eval()
    val_total = 0.0
    val_gen = torch.Generator(device=device).manual_seed(0)

    with torch.no_grad():
        for batch in val_loader:

            img = batch["image"].to(device)
            z, _ = vae.encode(img)  # posterior mean

            noise = torch.randn(z.shape, generator=val_gen, device=device, dtype=z.dtype)
            t = torch.randint(0, _num_timesteps, (z.size(0),), generator=val_gen, device=device)

            xt = scheduler.add_noise(z, noise, t)
            pred = unet(xt, t).sample

            val_total += F.mse_loss(pred, noise).item()

    avg_val = val_total / max(1, len(val_loader))
    print("Val Loss:", avg_val)

    if avg_val < best_loss:
        best_loss = avg_val
        torch.save(unet.state_dict(), BEST_MODEL_PATH)

    # Checkpoint before sampling previews so a failure while generating
    # images never loses this epoch's training progress.
    torch.save({
        "epoch": epoch,
        "model": unet.state_dict(),
        "optimizer": optimizer.state_dict(),
        "scaler": scaler.state_dict(),
        "best_loss": best_loss
    }, CHECKPOINT_PATH)

    sample_images(epoch)


print("Training Complete")

In [None]:
import os
import glob
import cv2
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.utils.data import Dataset, DataLoader
from diffusers import UNet2DModel, DDIMScheduler
from pytorch_msssim import ssim
import lpips


# =======================
# PATH CONFIGURATION
# =======================

DATA_ROOT = r"D:\Rushi_OCT_Diffusion\CCDS_Split_10K-20260120T043900Z-3-001\CCDS_Split_10K"
CHECKPOINT_DIR = r"D:\Rushi_OCT_Diffusion\OCT_Local_pths"

TEST_PATH = os.path.join(DATA_ROOT, "test")
BEST_MODEL_PATH, VAE_PATH = (
    os.path.join(CHECKPOINT_DIR, name) for name in ("ldm_best.pth", "vae_best.pth")
)

# Output folder for this evaluation cell.
EVAL_DIR = os.path.join(CHECKPOINT_DIR, "Evaluation_Folder")
os.makedirs(EVAL_DIR, exist_ok=True)

_device_kind = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_kind)
print("Using device:", device)


# =======================
# DATASET
# =======================

class OCTDataset(Dataset):
    """Test-split OCT dataset; items are bare (1, size, size) tensors in [-1, 1].

    Redefines the earlier ``OCTDataset``. Unlike the training versions,
    ``__getitem__`` returns the tensor itself, not a dict.
    """

    def __init__(self, root, size=512):
        """
        Raises:
            RuntimeError: no ``*.png`` found under ``root``. Otherwise an
                empty test split would silently yield zero batches.
        """
        self.paths = sorted(
            glob.glob(os.path.join(root, "**/*.png"), recursive=True)
        )
        self.size = size
        if not self.paths:
            raise RuntimeError(f"No images found in {root}")

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, i):
        path = self.paths[i]
        img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
        # cv2.imread returns None (instead of raising) for unreadable files.
        if img is None:
            raise RuntimeError(f"Failed to load image: {path}")
        # INTER_AREA matches the preprocessing the VAE was trained on.
        img = cv2.resize(img, (self.size, self.size), interpolation=cv2.INTER_AREA)
        img = torch.from_numpy(img).float() / 255.0
        img = img.unsqueeze(0) * 2 - 1
        return img


# shuffle=False keeps the evaluation order fixed across runs.
test_loader = DataLoader(
    dataset=OCTDataset(TEST_PATH),
    shuffle=False,
    batch_size=8,
)


# =======================
# MODEL ARCHITECTURE
# =======================

def norm_layer(ch):
    """GroupNorm over ``ch`` channels using at most 32 groups.

    Chooses the largest group count <= 32 that divides ``ch``. This is
    identical to ``GroupNorm(min(32, ch), ch)`` for every width where that
    works (including 64/128/256 used here), so checkpoints are unaffected;
    other widths (e.g. 48) no longer raise.
    """
    groups = next(g for g in range(min(32, ch), 0, -1) if ch % g == 0)
    return nn.GroupNorm(groups, ch)


class ResBlock(nn.Module):
    """Residual block: two (GroupNorm -> SiLU -> 3x3 conv) stages plus identity skip.

    Attribute names match the training VAE so its state dict loads.
    """

    def __init__(self, ch):
        super().__init__()
        self.norm1 = norm_layer(ch)
        self.conv1 = nn.Conv2d(ch, ch, kernel_size=3, stride=1, padding=1)
        self.norm2 = norm_layer(ch)
        self.conv2 = nn.Conv2d(ch, ch, kernel_size=3, stride=1, padding=1)
        self.act = nn.SiLU()

    def forward(self, x):
        out = x
        for norm, conv in ((self.norm1, self.conv1), (self.norm2, self.conv2)):
            out = conv(self.act(norm(out)))
        return x + out


class DownBlock(nn.Module):
    """Halve the spatial size: 4x4 stride-2 conv, then GroupNorm and SiLU.

    Maps ``(B, i, H, W)`` to ``(B, o, H/2, W/2)`` for even ``H`` and ``W``.
    Attribute names are state_dict keys and must stay unchanged.
    """

    def __init__(self, i, o):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels=i, out_channels=o, kernel_size=4, stride=2, padding=1
        )
        self.norm = norm_layer(o)
        self.act = nn.SiLU()

    def forward(self, x):
        h = self.conv(x)
        h = self.norm(h)
        return self.act(h)


class UpBlock(nn.Module):
    """Double the spatial size: 4x4 stride-2 transposed conv, then GroupNorm and SiLU.

    Maps ``(B, i, H, W)`` to ``(B, o, 2H, 2W)``.
    Attribute names are state_dict keys and must stay unchanged.
    """

    def __init__(self, i, o):
        super().__init__()
        self.conv = nn.ConvTranspose2d(
            in_channels=i, out_channels=o, kernel_size=4, stride=2, padding=1
        )
        self.norm = norm_layer(o)
        self.act = nn.SiLU()

    def forward(self, x):
        h = self.conv(x)
        h = self.norm(h)
        return self.act(h)


class VAE(nn.Module):
    """Convolutional VAE with an 8x spatially downsampled latent.

    Encoder: 3x3 conv, then three (stride-2 DownBlock + ResBlock) stages, then
    a 3x3 conv producing ``2 * latent_channels`` maps. ``encode`` splits those
    maps into two halves, presumably mean and log-variance (TODO confirm
    against the training code).

    Decoder: 3x3 conv from the latent, a ResBlock, three (UpBlock + ResBlock)
    stages, then GroupNorm/SiLU and a 3x3 conv with tanh. Outputs therefore
    lie in [-1, 1].

    Attribute names are the state_dict keys loaded from ``vae_best.pth`` and
    must not change. The defaults equal the originally hard-coded values that
    this notebook loads that checkpoint with. Other values produce a model
    that cannot load it.

    Args:
        in_channels: image channels (the OCT loader yields 1, grayscale).
        base_channels: width of the first stage; later stages use 2x and 4x.
            Since ``norm_layer`` uses ``min(32, ch)`` groups, every stage width
            that is >= 32 must be a multiple of 32.
        latent_channels: latent channels; must equal the diffusion UNet's
            ``in_channels``/``out_channels``.
    """

    def __init__(self, in_channels=1, base_channels=64, latent_channels=16):
        super().__init__()

        b = base_channels
        z = latent_channels

        self.conv_in = nn.Conv2d(in_channels, b, 3, 1, 1)

        self.down1 = DownBlock(b, b * 2)
        self.res1 = ResBlock(b * 2)

        self.down2 = DownBlock(b * 2, b * 4)
        self.res2 = ResBlock(b * 4)

        self.down3 = DownBlock(b * 4, b * 4)
        self.res3 = ResBlock(b * 4)

        # Two stacked halves of z channels each, split apart in encode().
        self.to_stats = nn.Conv2d(b * 4, z * 2, 3, 1, 1)

        self.from_latent = nn.Conv2d(z, b * 4, 3, 1, 1)

        self.res4 = ResBlock(b * 4)

        self.up1 = UpBlock(b * 4, b * 4)
        self.res5 = ResBlock(b * 4)

        self.up2 = UpBlock(b * 4, b * 2)
        self.res6 = ResBlock(b * 2)

        self.up3 = UpBlock(b * 2, b)
        self.res7 = ResBlock(b)

        self.norm_out = norm_layer(b)
        self.conv_out = nn.Conv2d(b, in_channels, 3, 1, 1)

    def encode(self, x):
        """Encode images into the two halves of the latent statistics.

        Returns:
            Tuple ``(m, l)``, each with ``latent_channels`` channels at 1/8 of
            the input resolution (for inputs divisible by 8). Presumably
            mean and log-variance.
        """
        x = self.conv_in(x)
        x = self.res1(self.down1(x))
        x = self.res2(self.down2(x))
        x = self.res3(self.down3(x))
        m, l = torch.chunk(self.to_stats(x), 2, 1)
        return m, l

    def decode(self, z):
        """Decode a latent to an image at 8x its spatial size, with values in [-1, 1] (tanh)."""
        x = self.from_latent(z)
        x = self.res4(x)
        x = self.res5(self.up1(x))
        x = self.res6(self.up2(x))
        x = self.res7(self.up3(x))
        return torch.tanh(self.conv_out(F.silu(self.norm_out(x))))


# Rebuild the VAE with its default architecture and load the trained weights.
# NOTE(review): torch.load unpickles the file, so only load trusted checkpoints.
# Recent torch versions also accept weights_only=True for plain state_dicts.
vae = VAE().to(device)
vae.load_state_dict(torch.load(VAE_PATH, map_location=device))
vae.eval()


# Latent-space denoiser: 16-channel 64x64 inputs and outputs. This matches the
# VAE latent (16 channels; 512 / 8 = 64 for the default dataset size). The block
# layout must match the configuration ldm_best.pth was trained with; otherwise
# load_state_dict (strict by default) fails on missing or mismatched keys.
unet = UNet2DModel(
    sample_size=64,
    in_channels=16,
    out_channels=16,
    layers_per_block=2,
    block_out_channels=(128, 256, 512, 512),
    down_block_types=(
        "DownBlock2D",
        "AttnDownBlock2D",
        "AttnDownBlock2D",
        "AttnDownBlock2D"
    ),
    up_block_types=(
        "AttnUpBlock2D",
        "AttnUpBlock2D",
        "AttnUpBlock2D",
        "UpBlock2D"
    )
).to(device)

unet.load_state_dict(torch.load(BEST_MODEL_PATH, map_location=device))
unet.eval()

# NOTE(review): only num_train_timesteps is given, so the beta schedule and
# other settings are diffusers' DDIMScheduler defaults. Confirm they match the
# noise scheduler used during training.
scheduler = DDIMScheduler(num_train_timesteps=1000)
# 50 DDIM sampling steps, drawn from the 1000 training timesteps.
scheduler.set_timesteps(50)

# Perceptual distance with an AlexNet backbone.
lpips_model = lpips.LPIPS(net="alex").to(device)


# =======================
# GENERATION
# =======================

@torch.no_grad()
def generate_samples(num, generator=None):
    """Draw ``num`` unconditional samples: DDIM-denoise latent noise, then VAE-decode.

    Uses the module-level ``unet``, ``scheduler`` (with timesteps already set
    via ``set_timesteps``), ``vae`` and ``device``.

    Args:
        num: number of images to generate.
        generator: optional ``torch.Generator`` on ``device`` for reproducible
            noise. When ``None``, the global torch RNG is used, as before.

    Returns:
        Tensor of shape ``(num, 1, H, W)`` with values in [0, 1]. The VAE decoder
        ends in tanh, so ``(x + 1) / 2`` maps [-1, 1] to [0, 1]. H and W are
        8x the UNet sample size (512 here).
    """
    # Derive the latent shape from the UNet config rather than repeating 16/64 here.
    sample_size = unet.config.sample_size
    if isinstance(sample_size, int):
        height, width = sample_size, sample_size
    else:
        height, width = tuple(sample_size)
    latents = torch.randn(
        num, unet.config.in_channels, height, width,
        device=device, generator=generator,
    )
    # For DDIM, init_noise_sigma is 1 and scale_model_input is the identity.
    # Using them keeps the loop correct if the scheduler is swapped for one that scales.
    latents = latents * scheduler.init_noise_sigma

    for t in scheduler.timesteps:
        model_input = scheduler.scale_model_input(latents, t)
        noise_pred = unet(model_input, t).sample
        latents = scheduler.step(noise_pred, t, latents).prev_sample

    imgs = vae.decode(latents)
    return (imgs + 1) / 2


# =======================
# EVALUATION LOOP
# =======================

# NOTE(review): generate_samples() draws *unconditional* samples from random
# noise, so each generated image is compared with an arbitrary, unpaired test
# image. Pixel-aligned metrics (MSE/PSNR/SSIM/LPIPS) are only a loose proxy in
# that setting; distribution-level metrics (e.g. FID) are the usual choice for
# unconditional generation. Confirm this is the intended protocol.

mse_total, psnr_total, ssim_total, lpips_total = 0.0, 0.0, 0.0, 0.0
count = 0      # number of batches processed
n_images = 0   # number of images processed; denominator of the averages

# The LPIPS call sits outside generate_samples' no_grad scope. Disable autograd
# so no graph is recorded if any LPIPS parameters require grad.
with torch.no_grad():
    for batch in test_loader:

        real = (batch.to(device) + 1) / 2          # [-1, 1] -> [0, 1]
        n_in_batch = real.size(0)
        fake = generate_samples(n_in_batch)        # already in [0, 1]

        mse = F.mse_loss(fake, real).item()
        psnr = float("inf") if mse == 0 else 10 * np.log10(1 / mse)
        ssim_val = ssim(fake, real, data_range=1).item()

        # LPIPS expects inputs in [-1, 1]; normalize=True rescales the [0, 1]
        # images. Grayscale is repeated to 3 channels for the AlexNet backbone.
        lp = lpips_model(
            fake.repeat(1, 3, 1, 1),
            real.repeat(1, 3, 1, 1),
            normalize=True,
        ).mean().item()

        # Each value above is a batch mean. Weight it by the batch size so a
        # smaller final batch does not skew the dataset-level average.
        mse_total += mse * n_in_batch
        psnr_total += psnr * n_in_batch
        ssim_total += ssim_val * n_in_batch
        lpips_total += lp * n_in_batch
        count += 1
        n_images += n_in_batch

if n_images == 0:
    raise RuntimeError(f"No test images found under {TEST_PATH}")

mse_avg = mse_total / n_images
psnr_avg = psnr_total / n_images
ssim_avg = ssim_total / n_images
lpips_avg = lpips_total / n_images


with open(os.path.join(EVAL_DIR, "metrics.txt"), "w") as f:
    f.write(f"MSE: {mse_avg}\n")
    f.write(f"PSNR: {psnr_avg}\n")
    f.write(f"SSIM: {ssim_avg}\n")
    f.write(f"LPIPS: {lpips_avg}\n")


print("Evaluation Completed")
print("MSE:", mse_avg)
print("PSNR:", psnr_avg)
print("SSIM:", ssim_avg)
print("LPIPS:", lpips_avg)