13-02-2026

In [None]:
import os
import torch
import torch.nn.functional as F
from tqdm import tqdm
import matplotlib.pyplot as plt
from torchvision.utils import make_grid, save_image

# Run on the GPU when PyTorch can see one; otherwise stay on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# ============================================================
# LOSSES
# ============================================================

def kl_loss(mean, logvar):
    """Return the batch-averaged KL divergence of N(mean, exp(logvar)) from N(0, I).

    The closed-form term ``-0.5 * (1 + logvar - mean**2 - exp(logvar))`` is
    summed over every non-batch dimension of each sample, and the per-sample
    totals are then averaged over dimension 0.

    Args:
        mean: Posterior means, batch first, with at least two dimensions.
        logvar: Posterior log-variances, same shape as ``mean``.

    Returns:
        A scalar tensor.

    Raises:
        ValueError: If ``mean`` has fewer than two dimensions.
    """
    if mean.dim() < 2:
        raise ValueError(
            f"kl_loss expects batch-first input with at least 2 dims, got {mean.dim()}"
        )

    # Reduce over everything except the batch axis so both (B, C, H, W) feature
    # maps and flat (B, D) latents are supported.
    latent_dims = tuple(range(1, mean.dim()))
    per_sample = -0.5 * torch.sum(
        1 + logvar - mean.pow(2) - logvar.exp(),
        dim=latent_dims
    )
    return per_sample.mean()

def recon_loss(pred, target):
    """Mean absolute (L1) error between a reconstruction and its target."""
    l1 = F.l1_loss(input=pred, target=target, reduction="mean")
    return l1

# ============================================================
# VISUALIZATION
# ============================================================

def show_grid(tensor, title="", nrow=8):
    """Display a batch of images as one matplotlib figure.

    Args:
        tensor: Image batch with values nominally in [-1, 1]; it is mapped to
            [0, 1] and clamped before being tiled with ``make_grid``.
        title: Figure title.
        nrow: Number of images per grid row, forwarded to ``make_grid``.
    """
    # Detach first so a tensor that still carries autograd history can be
    # converted to NumPy without raising.
    tensor = (tensor.detach() + 1) / 2
    tensor = torch.clamp(tensor, 0, 1)

    grid = make_grid(tensor, nrow=nrow)
    # Move the channel axis last, which is the layout imshow expects.
    grid = grid.cpu().permute(1, 2, 0).numpy()

    plt.figure(figsize=(8, 8))
    plt.imshow(grid.squeeze(), cmap='gray')
    plt.title(title)
    plt.axis('off')
    plt.show()

# ============================================================
# STOCHASTICITY CHECK
# ============================================================

@torch.no_grad()
def check_latent_stochasticity(vae, loader):
    """Print the mean absolute gap between two latent samples of one batch.

    The first batch of ``loader`` is encoded once and ``vae.reparameterize`` is
    called twice on the same ``(mean, logvar)``; the printed value is the mean
    of the element-wise absolute difference of the two samples.

    Args:
        vae: Model exposing ``encode`` and ``reparameterize``; it is put in
            eval mode and left there.
        loader: Iterable of dicts with an ``"image"`` tensor.
    """
    vae.eval()
    first_batch = next(iter(loader))
    images = first_batch["image"].to(device)

    mu, log_var = vae.encode(images)

    # Two independent draws from the same posterior parameters.
    draws = [vae.reparameterize(mu, log_var) for _ in range(2)]

    gap = (draws[0] - draws[1]).abs().mean()
    print("Latent stochastic difference:", gap.item())

# ============================================================
# TRAIN
# ============================================================

def train_vae(vae, train_loader, val_loader,
              epochs=150, lr=1e-4,
              out_dir="./vae_outputs"):
    """Train ``vae`` with an L1 reconstruction loss plus a warmed-up KL term.

    Each epoch runs one pass over ``train_loader`` (mixed precision on CUDA
    only), then a validation pass that decodes the posterior mean, saves and
    shows a preview grid, keeps the best weights by validation L1, and writes
    a resume checkpoint. If ``resume_checkpoint.pt`` already exists in
    ``out_dir``, training continues from the epoch after the one stored there.

    Args:
        vae: Model whose call returns ``(recon, mean, logvar)`` and which
            exposes ``encode`` and ``decode``.
        train_loader: Iterable of dicts with an ``"image"`` tensor.
        val_loader: Same format; its first batch also supplies the preview
            images (at most 8).
        epochs: Last epoch number to run (epochs are 1-based).
        lr: AdamW learning rate.
        out_dir: Directory for checkpoints and the ``previews`` sub-directory.
    """

    def _save_atomically(obj, path):
        # Write to a sibling temp file and rename, so an interrupted save can
        # never leave a truncated file at `path`.
        tmp_path = path + ".tmp"
        torch.save(obj, tmp_path)
        os.replace(tmp_path, path)

    os.makedirs(out_dir, exist_ok=True)
    preview_dir = os.path.join(out_dir, "previews")
    os.makedirs(preview_dir, exist_ok=True)

    vae = vae.to(device)

    ckpt_file = os.path.join(out_dir, "resume_checkpoint.pt")
    best_model = os.path.join(out_dir, "vae_best.pth")

    # Mixed precision is only enabled on CUDA; elsewhere scaler/autocast are no-ops.
    use_amp = device.type == "cuda"

    optimizer = torch.optim.AdamW(vae.parameters(), lr=lr)
    scaler = torch.amp.GradScaler("cuda", enabled=use_amp)

    start_epoch = 1
    best_val = float("inf")

    # Resume training
    if os.path.exists(ckpt_file):
        ckpt = torch.load(ckpt_file, map_location=device)
        vae.load_state_dict(ckpt["model"])
        optimizer.load_state_dict(ckpt["optimizer"])
        # A disabled GradScaler saves an empty state, which an enabled scaler
        # refuses to load; skip it and start from the default scale instead.
        if ckpt.get("scaler"):
            scaler.load_state_dict(ckpt["scaler"])
        start_epoch = ckpt["epoch"] + 1
        best_val = ckpt["best_val"]
        print("Resuming from epoch", start_epoch)

    preview = next(iter(val_loader))["image"]
    preview = preview[:min(8, preview.size(0))].to(device)
    n_preview = preview.size(0)

    # ========================================================
    # TRAIN LOOP
    # ========================================================

    for epoch in range(start_epoch, epochs + 1):

        # KL warmup: the weight ramps linearly to 0.01 over the first 100 epochs.
        beta = 0.01 * min(1.0, epoch / 100)

        vae.train()
        total_recon = 0

        pbar = tqdm(train_loader, desc=f"[TRAIN] Epoch {epoch}/{epochs}")

        for batch in pbar:
            img = batch["image"].to(device)

            optimizer.zero_grad()

            with torch.amp.autocast("cuda", enabled=use_amp):
                recon, mean, logvar = vae(img)
                loss_r = recon_loss(recon, img)
                loss_k = kl_loss(mean, logvar)
                loss = 10 * loss_r + beta * loss_k

            scaler.scale(loss).backward()
            # Unscale before clipping so the threshold applies to true gradients.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(vae.parameters(), 1.0)
            scaler.step(optimizer)
            scaler.update()

            total_recon += loss_r.item()

            pbar.set_postfix(L1=f"{loss_r.item():.4f}",
                             KL=f"{loss_k.item():.4f}",
                             beta=f"{beta:.5f}")

        total_recon /= len(train_loader)

        # ================= VALIDATION (deterministic) =================

        vae.eval()
        val_recon = 0

        with torch.no_grad():
            for batch in val_loader:
                img = batch["image"].to(device)

                # Decode the posterior mean (no sampling) for a repeatable score.
                mean, logvar = vae.encode(img)
                recon = vae.decode(mean)

                val_recon += recon_loss(recon, img).item()

        val_recon /= len(val_loader)

        print(f"\nEpoch {epoch}")
        print(f"Train L1: {total_recon:.4f} | Val L1: {val_recon:.4f}")

        # ================= PREVIEW (deterministic) =================

        with torch.no_grad():
            mean, logvar = vae.encode(preview)
            recon = vae.decode(mean)

        vis = torch.cat([preview, recon], dim=0)

        save_path = os.path.join(preview_dir, f"epoch_{epoch}.png")
        # One row of originals above one row of reconstructions, even when the
        # validation batch holds fewer than 8 images.
        save_image((vis + 1) / 2, save_path, nrow=n_preview)

        show_grid(vis, title=f"Epoch {epoch}")

        # ================= SAVE =================

        if val_recon < best_val:
            best_val = val_recon
            _save_atomically(vae.state_dict(), best_model)
            print("New best model saved")

        _save_atomically({
            "epoch": epoch,
            "model": vae.state_dict(),
            "optimizer": optimizer.state_dict(),
            "scaler": scaler.state_dict(),
            "best_val": best_val
        }, ckpt_file)

        print("Checkpoint saved")

    print("\nTraining complete")

# ============================================================
# TEST
# ============================================================

@torch.no_grad()
def test_vae(vae, test_loader,
             ckpt_path="./vae_outputs/vae_best.pth"):
    """Load saved weights and show reconstructions of the first test batch.

    Up to 8 images from the first batch of ``test_loader`` are encoded, decoded
    from the posterior mean (no sampling), and displayed together with the
    originals via ``show_grid``.

    Args:
        vae: Model exposing ``encode`` and ``decode``; its weights are
            overwritten from ``ckpt_path`` and it is left in eval mode.
        test_loader: Iterable of dicts with an ``"image"`` tensor.
        ckpt_path: Path to a saved ``state_dict``.
    """
    vae.load_state_dict(torch.load(ckpt_path, map_location=device))
    # Inputs are moved to `device` below, so the model must live there too.
    vae = vae.to(device)
    vae.eval()

    batch = next(iter(test_loader))
    img = batch["image"][:min(8, batch["image"].size(0))].to(device)

    mean, logvar = vae.encode(img)
    recon = vae.decode(mean)

    show_grid(torch.cat([img, recon], dim=0),
              title="Test Reconstructions")

# ============================================================
# RUN
# ============================================================

# Build the model, train it, then sanity-check latent sampling and view
# reconstructions of the test split.
vae = VAE(z_channels=8)
vae = vae.to(device)

train_vae(vae=vae, train_loader=train_loader, val_loader=val_loader)

check_latent_stochasticity(vae=vae, loader=train_loader)

test_vae(vae=vae, test_loader=test_loader)


14-02-2026

In [None]:
import os
import torch
import torch.nn.functional as F
from tqdm import tqdm
import matplotlib.pyplot as plt
from torchvision.utils import make_grid, save_image
import lpips

# ============================================================
# DEVICE
# ============================================================

# Use CUDA when it is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# ============================================================
# SAFE NORMALIZATION
# ============================================================

def normalize_batch(x):
    """Map an image batch to [-1, 1], guessing its current value range.

    The guess is made from the batch as a whole:
      * max > 1.5 -> treated as [0, 255] and mapped with ``x / 127.5 - 1``
      * min >= 0  -> treated as [0, 1] and mapped with ``x * 2 - 1``
      * otherwise -> treated as already in [-1, 1]
    The result is always clamped to [-1, 1].

    NOTE: a batch that is already in [-1, 1] but happens to contain no
    negative value cannot be told apart from a [0, 1] batch and is rescaled.

    Args:
        x: Image tensor of any shape; integer dtypes are converted to float.

    Returns:
        A floating-point tensor with values in [-1, 1].
    """
    if not torch.is_floating_point(x):
        # Integer tensors (e.g. uint8) would wrap around in `x * 2 - 1`.
        x = x.float()
    if x.numel() == 0:
        # max()/min() raise on empty tensors and there is nothing to rescale.
        return x

    if x.max() > 1.5:
        x = x / 127.5 - 1
    elif x.min() >= 0:
        x = x * 2 - 1
    return x.clamp(-1, 1)

# ============================================================
# PERCEPTUAL LOSS
# ============================================================

# LPIPS network (VGG backbone) used only as a fixed metric: eval mode, and
# none of its parameters receive gradients.
lpips_loss = lpips.LPIPS(net='vgg')
lpips_loss = lpips_loss.to(device).eval()
lpips_loss.requires_grad_(False)

def perceptual_loss(pred, target):
    """Return the mean LPIPS distance between ``pred`` and ``target``.

    Single-channel inputs are tiled to three channels before being passed to
    ``lpips_loss``; three-channel inputs are passed through unchanged.

    Args:
        pred: Reconstructed images, channel axis at dim 1 (1 or 3 channels).
        target: Reference images, same layout as ``pred``.

    Returns:
        A scalar tensor.

    Raises:
        ValueError: If an input has a channel count other than 1 or 3.
    """

    def _to_three_channels(t):
        channels = t.size(1)
        if channels == 3:
            return t
        if channels == 1:
            return t.repeat(1, 3, 1, 1)
        raise ValueError(f"perceptual_loss expects 1 or 3 channels, got {channels}")

    return lpips_loss(_to_three_channels(pred), _to_three_channels(target)).mean()

# ============================================================
# LOSSES
# ============================================================

def kl_loss(mean, logvar):
    """Return the batch-averaged KL divergence of N(mean, exp(logvar)) from N(0, I).

    ``logvar`` is clamped to [-30, 20] first so ``exp`` stays finite. The
    closed-form term is then summed over every non-batch dimension of each
    sample and the per-sample totals are averaged over dimension 0.

    Args:
        mean: Posterior means, batch first, with at least two dimensions.
        logvar: Posterior log-variances, same shape as ``mean``.

    Returns:
        A scalar tensor.

    Raises:
        ValueError: If ``mean`` has fewer than two dimensions.
    """
    if mean.dim() < 2:
        raise ValueError(
            f"kl_loss expects batch-first input with at least 2 dims, got {mean.dim()}"
        )

    logvar = torch.clamp(logvar, -30, 20)
    # Reduce over everything except the batch axis so both (B, C, H, W) feature
    # maps and flat (B, D) latents are supported.
    latent_dims = tuple(range(1, mean.dim()))
    per_sample = -0.5 * torch.sum(
        1 + logvar - mean.pow(2) - logvar.exp(),
        dim=latent_dims
    )
    return per_sample.mean()

def recon_loss(pred, target):
    """Return the element-wise mean absolute error of ``pred`` against ``target``."""
    return F.l1_loss(pred, target, reduction="mean")

# ============================================================
# PREVIEW STRETCH
# ============================================================

def stretch(x):
    """Min-max rescale ``x`` to [-1, 1] using its global minimum and maximum.

    The extrema are taken over the whole tensor, not per image. A small
    epsilon in the denominator keeps a constant tensor from dividing by zero
    (it maps to -1 everywhere).
    """
    shifted = x - x.min()
    unit = shifted / (shifted.max() + 1e-8)
    return unit * 2 - 1

# ============================================================
# DISPLAY GRID
# ============================================================

def show_grid(tensor, title="", nrow=4):
    """Display a batch of images as one matplotlib figure.

    Args:
        tensor: Image batch with values nominally in [-1, 1]; it is mapped to
            [0, 1] and clamped before being tiled with ``make_grid``.
        title: Figure title.
        nrow: Number of images per grid row, forwarded to ``make_grid``.
    """
    # Detach first so a tensor that still carries autograd history can be
    # converted to NumPy without raising.
    tensor = (tensor.detach() + 1) / 2
    tensor = torch.clamp(tensor, 0, 1)

    grid = make_grid(tensor, nrow=nrow)
    # Move the channel axis last, which is the layout imshow expects.
    grid = grid.cpu().permute(1,2,0).numpy()

    plt.figure(figsize=(6,6))
    plt.imshow(grid.squeeze(), cmap='gray')
    plt.title(title)
    plt.axis('off')
    plt.show()


# TRAIN FUNCTION


def train_vae(vae, train_loader, val_loader,
              epochs=150, lr=1e-4,
              out_dir="./vae_outputs"):
    """Train ``vae`` with L1 + perceptual (LPIPS) + KL losses.

    Every batch is passed through ``normalize_batch`` before use. Each epoch
    runs one training pass (mixed precision on CUDA only), a validation pass
    that decodes the posterior mean, a preview of up to 4 validation images
    with contrast-stretched reconstructions, and saves the weights whenever
    the validation L1 improves. Training stops early, returning ``None``, the
    first time the loss is not finite.

    Args:
        vae: Model whose call returns ``(recon, mean, logvar)`` and which
            exposes ``encode`` and ``decode``.
        train_loader: Iterable of dicts with an ``"image"`` tensor.
        val_loader: Same format; its first batch supplies the preview images.
        epochs: Number of epochs to run (1-based, inclusive).
        lr: AdamW learning rate.
        out_dir: Directory for ``vae_best.pth`` and the ``previews`` folder.
    """

    os.makedirs(out_dir, exist_ok=True)
    preview_dir = os.path.join(out_dir, "previews")
    os.makedirs(preview_dir, exist_ok=True)

    # Mixed precision is only enabled on CUDA; elsewhere scaler/autocast are no-ops.
    use_amp = device.type == "cuda"

    vae = vae.to(device)
    optimizer = torch.optim.AdamW(vae.parameters(), lr=lr)
    scaler = torch.amp.GradScaler("cuda", enabled=use_amp)

    preview = normalize_batch(
        next(iter(val_loader))["image"][:4].to(device)
    )
    n_preview = preview.size(0)

    best_val = float("inf")

    # The KL weight is constant for the whole run, so it is set once here.
    beta = 1e-4

    for epoch in range(1, epochs+1):

        torch.cuda.empty_cache()

        # Perceptual weight ramps by 0.005 per epoch and is capped at 0.02.
        perc_weight = min(0.02, epoch * 0.005)

        vae.train()
        total_recon = 0

        pbar = tqdm(train_loader,
                    desc=f"[TRAIN] Epoch {epoch}/{epochs}")

        for batch in pbar:
            img = normalize_batch(batch["image"].to(device))
            optimizer.zero_grad()

            with torch.amp.autocast("cuda", enabled=use_amp):

                recon, mean, logvar = vae(img)

                loss_l1 = recon_loss(recon, img)
                loss_p  = perceptual_loss(recon, img)
                loss_k  = kl_loss(mean, logvar)

                loss = loss_l1 + perc_weight*loss_p + beta*loss_k

            if not torch.isfinite(loss):
                # Close the bar so the early exit does not leave a dangling
                # progress line behind the message.
                pbar.close()
                print("NaN detected — stopping")
                return

            scaler.scale(loss).backward()
            # Unscale before clipping so the threshold applies to true gradients.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(vae.parameters(), 1.0)
            scaler.step(optimizer)
            scaler.update()

            total_recon += loss_l1.item()

            pbar.set_postfix(
                L1=f"{loss_l1.item():.4f}",
                KL=f"{loss_k.item():.4f}",
                pw=f"{perc_weight:.3f}"
            )

        total_recon /= len(train_loader)

        # ===== VALIDATION =====

        vae.eval()
        val_recon = 0

        with torch.no_grad():
            for batch in val_loader:
                img = normalize_batch(batch["image"].to(device))
                # Decode the posterior mean (no sampling) for a repeatable score.
                mean, logvar = vae.encode(img)
                recon = vae.decode(mean)
                val_recon += recon_loss(recon, img).item()

        val_recon /= len(val_loader)

        print(f"\nEpoch {epoch}")
        print(f"Train L1: {total_recon:.4f} | Val L1: {val_recon:.4f}")

        # ===== PREVIEW =====

        with torch.no_grad():
            mean, logvar = vae.encode(preview)
            recon = vae.decode(mean)

        # The reconstruction is contrast-stretched for viewing only; the
        # losses above use the raw decoder output.
        vis = torch.cat([preview, stretch(recon)], dim=0)

        # One row of originals above one row of reconstructions, even when
        # the validation batch holds fewer than 4 images.
        save_image((vis+1)/2,
                   os.path.join(preview_dir,
                   f"epoch_{epoch}.png"),
                   nrow=n_preview)

        show_grid(vis, title=f"Epoch {epoch}")

        # ===== SAVE =====

        if val_recon < best_val:
            best_val = val_recon
            torch.save(vae.state_dict(),
                       os.path.join(out_dir,
                       "vae_best.pth"))
            print("New best model saved")

    print("\nTraining complete")

# ============================================================
# RUN


# Instantiate the 16-channel-latent model on the selected device and train it
# with the default settings.
vae = VAE(z_channels=16)
vae = vae.to(device)
train_vae(vae=vae, train_loader=train_loader, val_loader=val_loader)


17-02-2026

In [None]:
import os, glob, cv2
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
from torchvision.utils import save_image
from diffusers import UNet2DModel, DDPMScheduler, DDIMScheduler

# Pick the compute device once; models and batches below are moved onto it.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using device:", device)

# Dataset root and its three split folders.
ROOT = "/kaggle/input/datasets/rushikannan05/ccds-10k-daataset/CCDS_Split_10K"

TRAIN_PATH, VAL_PATH, TEST_PATH = (
    os.path.join(ROOT, split) for split in ("train", "val", "test")
)

# Read-only inputs: the VAE weights and an optional checkpoint to resume from.
VAE_PATH = "/kaggle/input/models/rushikannan05/vae-best/pytorch/default/1/vae_best.pth"
INPUT_CHECKPOINT = "/kaggle/input/models/rushikannan05/ldm-checkpoint/pytorch/default/1/ldm_checkpoint.pth"

# Everything this run writes goes under OUT_DIR.
OUT_DIR = "/kaggle/working/LDM"
PREVIEW_DIR = os.path.join(OUT_DIR, "previews")
BEST_MODEL_PATH = os.path.join(OUT_DIR, "ldm_best.pth")
CHECKPOINT_PATH = os.path.join(OUT_DIR, "ldm_checkpoint.pth")

# Creating the nested previews folder also creates OUT_DIR itself.
os.makedirs(PREVIEW_DIR, exist_ok=True)

class OCTDataset(Dataset):
    """Dataset of grayscale PNG images resized to a square and scaled to [-1, 1].

    Every ``*.png`` found recursively under ``root`` is one sample; samples are
    ordered by sorted path. Each item is a dict ``{"image": tensor}`` where the
    tensor is float with shape ``(1, size, size)``.

    Args:
        root: Directory searched recursively for PNG files.
        size: Side length, in pixels, that every image is resized to.
    """

    def __init__(self, root, size=512):
        self.paths = sorted(glob.glob(os.path.join(root, "**/*.png"), recursive=True))
        self.size = size
        print(f"{root} → {len(self.paths)} images")

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, i):
        path = self.paths[i]
        img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
        if img is None:
            # cv2.imread reports an unreadable or corrupt file by returning
            # None; fail here with the path instead of inside cv2.resize.
            raise OSError(f"Could not read image: {path}")
        img = cv2.resize(img, (self.size, self.size))
        # [0, 255] -> [0, 1], then add the channel axis and map to [-1, 1].
        img = torch.from_numpy(img).float()/255
        img = img.unsqueeze(0)*2 - 1
        return {"image": img}

# One loader per split; only the training split is shuffled.
train_loader = DataLoader(
    OCTDataset(TRAIN_PATH),
    batch_size=4,
    shuffle=True,
    num_workers=2,
)
val_loader = DataLoader(
    OCTDataset(VAL_PATH),
    batch_size=4,
    shuffle=False,
    num_workers=2,
)
test_loader = DataLoader(
    OCTDataset(TEST_PATH),
    batch_size=4,
    shuffle=False,
    num_workers=2,
)

def norm_layer(ch):
    """Return a ``GroupNorm`` over ``ch`` channels using at most 32 groups.

    ``GroupNorm`` requires the channel count to be divisible by the group
    count, so the largest divisor of ``ch`` that does not exceed 32 is used.
    For channel counts up to 32, or multiples of 32, this is ``min(32, ch)``.

    Args:
        ch: Number of channels to normalize.

    Returns:
        An ``nn.GroupNorm`` module.
    """
    groups = min(32, ch)
    # Step down to the nearest group count that divides the channels evenly.
    while ch % groups != 0:
        groups -= 1
    return nn.GroupNorm(groups, ch)

class ResBlock(nn.Module):
    """Residual block: two (norm -> SiLU -> 3x3 conv) stages added to the input.

    The channel count is unchanged, so the input is added back without a
    projection.
    """

    def __init__(self, ch):
        super().__init__()
        self.norm1 = norm_layer(ch)
        self.conv1 = nn.Conv2d(ch, ch, kernel_size=3, padding=1)
        self.norm2 = norm_layer(ch)
        self.conv2 = nn.Conv2d(ch, ch, kernel_size=3, padding=1)
        self.act = nn.SiLU()

    def forward(self, x):
        # First stage.
        out = self.norm1(x)
        out = self.act(out)
        out = self.conv1(out)
        # Second stage.
        out = self.norm2(out)
        out = self.act(out)
        out = self.conv2(out)
        # Skip connection.
        return x + out

class DownBlock(nn.Module):
    """Stride-2 4x4 convolution followed by normalization and SiLU."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1)
        self.norm = norm_layer(out_ch)
        self.act = nn.SiLU()

    def forward(self, x):
        out = self.conv(x)
        out = self.norm(out)
        return self.act(out)

class UpBlock(nn.Module):
    """Stride-2 4x4 transposed convolution followed by normalization and SiLU."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.ConvTranspose2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1)
        self.norm = norm_layer(out_ch)
        self.act = nn.SiLU()

    def forward(self, x):
        out = self.conv(x)
        out = self.norm(out)
        return self.act(out)

class VAE(nn.Module):
    """Convolutional VAE with three stride-2 stages in encoder and decoder.

    The encoder produces a mean and a log-variance map with ``z_channels``
    channels each; the decoder maps a latent of that shape back to an image
    squashed into [-1, 1] by ``tanh``. With the default arguments the layer
    layout and parameter names are the same as the hard-coded
    64-base-channel / 16-latent-channel / 1-image-channel configuration.

    Args:
        z_channels: Number of latent channels.
        base_channels: Channel width after the first convolution; deeper
            stages use 2x and 4x this width.
        in_channels: Number of image channels consumed and produced.
    """

    def __init__(self, z_channels=16, base_channels=64, in_channels=1):
        super().__init__()
        b = base_channels
        z = z_channels

        # ---- encoder ----
        self.conv_in = nn.Conv2d(in_channels, b, 3, 1, 1)
        self.down1 = DownBlock(b, b*2)
        self.res1 = ResBlock(b*2)
        self.down2 = DownBlock(b*2, b*4)
        self.res2 = ResBlock(b*4)
        self.down3 = DownBlock(b*4, b*4)
        self.res3 = ResBlock(b*4)

        # Mean and log-variance are predicted together and split in encode().
        self.to_stats = nn.Conv2d(b*4, z*2, 3, 1, 1)
        self.from_latent = nn.Conv2d(z, b*4, 3, 1, 1)

        # ---- decoder ----
        self.res4 = ResBlock(b*4)
        self.up1 = UpBlock(b*4, b*4)
        self.res5 = ResBlock(b*4)
        self.up2 = UpBlock(b*4, b*2)
        self.res6 = ResBlock(b*2)
        self.up3 = UpBlock(b*2, b)
        self.res7 = ResBlock(b)

        self.norm_out = norm_layer(b)
        self.conv_out = nn.Conv2d(b, in_channels, 3, 1, 1)

    def encode(self, x):
        """Return ``(mean, logvar)``; ``logvar`` is clamped to [-10, 10]."""
        x = self.conv_in(x)
        x = self.res1(self.down1(x))
        x = self.res2(self.down2(x))
        x = self.res3(self.down3(x))
        m, l = torch.chunk(self.to_stats(x), 2, 1)
        return m, l.clamp(-10, 10)

    def reparameterize(self, mean, logvar):
        """Sample ``mean + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, I)``."""
        return mean + torch.randn_like(mean) * torch.exp(0.5 * logvar)

    def decode(self, z):
        """Map a latent back to image space; the output lies in [-1, 1]."""
        x = self.from_latent(z)
        x = self.res4(x)
        x = self.res5(self.up1(x))
        x = self.res6(self.up2(x))
        x = self.res7(self.up3(x))
        return torch.tanh(self.conv_out(F.silu(self.norm_out(x))))

    def forward(self, x):
        """Encode, sample a latent, decode; return ``(recon, mean, logvar)``."""
        mean, logvar = self.encode(x)
        z = self.reparameterize(mean, logvar)
        return self.decode(z), mean, logvar

# Load the saved autoencoder weights and freeze the model: eval mode, and no
# parameter requires gradients.
vae = VAE()
vae = vae.to(device)
vae.load_state_dict(torch.load(VAE_PATH, map_location=device))
vae.eval()
vae.requires_grad_(False)

# Denoiser operating on 16-channel latents; attention is used at every
# resolution except the first down block and the last up block.
unet = UNet2DModel(
    sample_size=64,
    in_channels=16,
    out_channels=16,
    layers_per_block=2,
    block_out_channels=(128, 256, 512, 512),
    down_block_types=(
        "DownBlock2D",
        "AttnDownBlock2D",
        "AttnDownBlock2D",
        "AttnDownBlock2D",
    ),
    up_block_types=(
        "AttnUpBlock2D",
        "AttnUpBlock2D",
        "AttnUpBlock2D",
        "UpBlock2D",
    ),
)
unet = unet.to(device)

# Noise schedule, optimizer and gradient scaler for the training loop.
scheduler = DDPMScheduler(num_train_timesteps=1000)
opt = torch.optim.AdamW(unet.parameters(), lr=1e-4)
scaler = torch.amp.GradScaler(device.type)

# Defaults for a fresh run; both are overwritten when a checkpoint is found.
best = float("inf")
start = 1

if os.path.exists(INPUT_CHECKPOINT):
    ck = torch.load(INPUT_CHECKPOINT, map_location=device)
    unet.load_state_dict(ck["model"])
    opt.load_state_dict(ck["optimizer"])
    # A checkpoint may carry no scaler state, or an empty one (saved from a
    # disabled GradScaler); an enabled scaler refuses to load an empty state,
    # so keep the freshly initialised scaler in that case.
    if ck.get("scaler"):
        scaler.load_state_dict(ck["scaler"])
    best = ck.get("best_loss", best)
    start = ck["epoch"] + 1
    print("Resuming at epoch", start)

@torch.no_grad()
def evaluate(loader, name, seed=0):
    """Return the mean noise-prediction MSE of ``unet`` over ``loader``.

    For every batch the images are encoded with ``vae``, a latent is sampled,
    noise is added at a random timestep, and the MSE between the predicted and
    the true noise is accumulated. ``unet`` is switched to eval mode and left
    there.

    Args:
        loader: Iterable of dicts with an ``"image"`` tensor.
        name: Label printed in front of the loss.
        seed: Seed for the latent noise, diffusion noise and timesteps drawn
            here. A fixed seed makes every call use the same draws, so losses
            from different epochs are comparable for best-model selection.
            Pass ``None`` to draw from the global RNG instead.

    Returns:
        The average per-batch loss as a float (0.0 if ``loader`` is empty).
    """
    unet.eval()

    # Take the step count from the scheduler instead of repeating the literal.
    num_steps = scheduler.config.num_train_timesteps

    gen = None
    if seed is not None:
        gen = torch.Generator(device=device)
        gen.manual_seed(seed)

    total = 0.0
    batches = 0
    for b in loader:
        img = b["image"].to(device)
        m, l = vae.encode(img)
        eps = torch.randn(m.shape, generator=gen, device=m.device, dtype=m.dtype)
        z = m + eps * torch.exp(0.5 * l)
        n = torch.randn(z.shape, generator=gen, device=z.device, dtype=z.dtype)
        t = torch.randint(0, num_steps, (z.size(0),), generator=gen, device=z.device)
        xt = scheduler.add_noise(z, n, t)
        pred = unet(xt, t).sample
        total += F.mse_loss(pred, n).item()
        batches += 1

    # Count batches rather than calling len(), and avoid dividing by zero.
    avg = total / max(batches, 1)
    print(f"{name} Loss: {avg:.6f}")
    return avg

# Draw timesteps from the range the scheduler was configured with rather than
# repeating the literal.
num_train_steps = scheduler.config.num_train_timesteps

for e in range(start, 51):

    total = 0
    unet.train()

    for b in tqdm(train_loader, desc=f"Epoch {e}"):

        img = b["image"].to(device)

        # The VAE is frozen: encode and sample the latent without gradients.
        with torch.no_grad():
            m, l = vae.encode(img)
            z = m + torch.randn_like(m) * torch.exp(0.5 * l)

        # Noise the latent at a random timestep; the UNet learns to predict
        # the noise that was added.
        n = torch.randn_like(z)
        t = torch.randint(0, num_train_steps, (z.size(0),), device=device)
        xt = scheduler.add_noise(z, n, t)

        with torch.amp.autocast(device.type):
            pred = unet(xt, t).sample
            loss = F.mse_loss(pred, n)

        opt.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(opt)
        scaler.update()

        total += loss.item()

    train_loss = total / len(train_loader)
    print(f"Train Loss: {train_loss:.6f}")

    val_loss = evaluate(val_loader, "VAL")

    if val_loss < best:
        best = val_loss
        torch.save(unet.state_dict(), BEST_MODEL_PATH)
        print("⭐ Best model saved")

    # Write the resume checkpoint to a temp file and rename it into place so
    # an interrupted save cannot leave a truncated checkpoint behind.
    tmp_checkpoint = CHECKPOINT_PATH + ".tmp"
    torch.save({
        "epoch": e,
        "model": unet.state_dict(),
        "optimizer": opt.state_dict(),
        "scaler": scaler.state_dict(),
        "best_loss": best
    }, tmp_checkpoint)
    os.replace(tmp_checkpoint, CHECKPOINT_PATH)

# NOTE(review): this scores the weights from the last epoch, not the ones
# saved at BEST_MODEL_PATH — confirm that is intended.
print("\nFinal Test:")
evaluate(test_loader, "TEST")

print("✅ Training complete")
