13-02-2026

In [None]:
import os
import torch
import torch.nn.functional as F
from tqdm import tqdm
import matplotlib.pyplot as plt
from torchvision.utils import make_grid, save_image

# Run on the GPU whenever PyTorch reports one as available; otherwise use the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# ============================================================
# LOSSES
# ============================================================

def kl_loss(mean, logvar):
    """Return the batch-averaged KL divergence from N(mean, exp(logvar)) to N(0, I).

    The closed-form term ``1 + logvar - mean^2 - exp(logvar)`` is summed over
    every non-batch axis of each sample, scaled by ``-0.5`` and then averaged
    over the batch (axis 0).

    Args:
        mean: Posterior means, batch-first, with at least two dimensions.
        logvar: Posterior log-variances, same shape as ``mean``.

    Returns:
        A scalar tensor.
    """
    per_element = 1 + logvar - mean.pow(2) - logvar.exp()
    # Flatten instead of hard-coding dim=[1, 2, 3] so vector latents (B, D)
    # and volumetric latents (B, C, D, H, W) are reduced correctly as well.
    per_sample = per_element.flatten(start_dim=1).sum(dim=1)
    return -0.5 * per_sample.mean()

def recon_loss(pred, target):
    """Return the mean absolute (L1) error between ``pred`` and ``target``."""
    return F.l1_loss(input=pred, target=target, reduction="mean")

# ============================================================
# VISUALIZATION
# ============================================================

def show_grid(tensor, title=""):
    """Display a batch of images as one tiled matplotlib figure.

    Values are remapped with ``(x + 1) / 2`` and clipped to [0, 1]
    (presumably the inputs live in [-1, 1] -- confirm against the data
    pipeline), tiled with ``make_grid(nrow=8)`` and shown without axes.

    Args:
        tensor: Image batch accepted by ``torchvision.utils.make_grid``.
        title: Text placed above the figure.
    """
    unit_range = ((tensor + 1) / 2).clamp(0, 1)
    tiled = make_grid(unit_range, nrow=8)

    # matplotlib expects a channels-last array living in host memory.
    canvas = tiled.cpu().permute(1, 2, 0).numpy().squeeze()

    plt.figure(figsize=(8, 8))
    plt.imshow(canvas, cmap='gray')
    plt.title(title)
    plt.axis('off')
    plt.show()

# ============================================================
# STOCHASTICITY CHECK
# ============================================================

@torch.no_grad()
def check_latent_stochasticity(vae, loader):
    """Print how much two latent samples of the same inputs differ.

    Switches ``vae`` to eval mode, encodes the ``"image"`` entry of the first
    batch from ``loader`` once, calls ``vae.reparameterize`` twice on the
    resulting ``(mean, logvar)`` and prints the mean absolute difference
    between the two draws. Nothing is returned.
    """
    vae.eval()

    first_batch = next(iter(loader))
    images = first_batch["image"].to(device)

    mu, log_variance = vae.encode(images)

    # Two independent draws from the same posterior parameters.
    draw_a = vae.reparameterize(mu, log_variance)
    draw_b = vae.reparameterize(mu, log_variance)

    gap = (draw_a - draw_b).abs().mean()
    print("Latent stochastic difference:", gap.item())

# ============================================================
# TRAIN
# ============================================================

def _atomic_torch_save(obj, path):
    """Write ``obj`` with ``torch.save`` so ``path`` never holds a partial file."""
    tmp_path = f"{path}.tmp"
    torch.save(obj, tmp_path)
    # os.replace swaps the finished file in with a single rename, so an
    # interruption during torch.save leaves the previous file untouched.
    os.replace(tmp_path, path)


def train_vae(vae, train_loader, val_loader,
              epochs=150, lr=1e-4,
              out_dir="./vae_outputs"):
    """Train a VAE with L1 reconstruction plus a warmed-up KL term.

    Every epoch runs one pass over ``train_loader`` (mixed precision when the
    module-level ``device`` is CUDA), measures the mean L1 reconstruction
    error on ``val_loader`` by decoding the posterior mean, writes and shows
    a preview grid, keeps the weights with the lowest validation L1, and
    stores a checkpoint from which a later call resumes automatically.

    Args:
        vae: Model whose forward pass returns ``(recon, mean, logvar)`` and
            which provides ``encode`` (returning ``(mean, logvar)``) and
            ``decode``.
        train_loader: Iterable of batches indexed by the key ``"image"``.
        val_loader: Same batch format; up to 8 images of its first batch
            serve as the fixed preview set.
        epochs: Number of the last epoch to run (epochs are 1-based). When a
            checkpoint exists, training continues from its epoch + 1.
        lr: AdamW learning rate.
        out_dir: Directory receiving ``vae_best.pth``,
            ``resume_checkpoint.pt`` and the ``previews/`` folder.
    """
    os.makedirs(out_dir, exist_ok=True)
    preview_dir = os.path.join(out_dir, "previews")
    os.makedirs(preview_dir, exist_ok=True)

    vae = vae.to(device)

    ckpt_file = os.path.join(out_dir, "resume_checkpoint.pt")
    best_model = os.path.join(out_dir, "vae_best.pth")

    use_amp = device.type == "cuda"
    optimizer = torch.optim.AdamW(vae.parameters(), lr=lr)
    # torch.amp.* replaces the deprecated torch.cuda.amp.* entry points.
    scaler = torch.amp.GradScaler("cuda", enabled=use_amp)

    start_epoch = 1
    best_val = float("inf")

    # Resume training
    if os.path.exists(ckpt_file):
        ckpt = torch.load(ckpt_file, map_location=device)
        vae.load_state_dict(ckpt["model"])
        optimizer.load_state_dict(ckpt["optimizer"])
        # A scaler that was disabled when the checkpoint was written (CPU
        # run) stores an empty state dict, which an enabled scaler refuses
        # to load -- skip it and start from the default scale instead.
        if ckpt.get("scaler"):
            scaler.load_state_dict(ckpt["scaler"])
        start_epoch = ckpt["epoch"] + 1
        best_val = ckpt["best_val"]
        print("Resuming from epoch", start_epoch)

    # Fixed preview images so the per-epoch grids are comparable.
    preview = next(iter(val_loader))["image"]
    preview = preview[:min(8, preview.size(0))].to(device)

    # ========================================================
    # TRAIN LOOP
    # ========================================================

    for epoch in range(start_epoch, epochs + 1):

        # KL warmup: weight rises linearly to 0.01 over the first 100 epochs
        beta = 0.01 * min(1.0, epoch / 100)

        vae.train()
        total_recon = 0

        pbar = tqdm(train_loader, desc=f"[TRAIN] Epoch {epoch}/{epochs}")

        for batch in pbar:
            img = batch["image"].to(device)

            optimizer.zero_grad()

            with torch.amp.autocast("cuda", enabled=use_amp):
                recon, mean, logvar = vae(img)
                loss_r = recon_loss(recon, img)
                loss_k = kl_loss(mean, logvar)
                loss = 10 * loss_r + beta * loss_k

            scaler.scale(loss).backward()
            # Unscale first so the clipping threshold applies to true gradients.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(vae.parameters(), 1.0)
            scaler.step(optimizer)
            scaler.update()

            total_recon += loss_r.item()

            pbar.set_postfix(L1=f"{loss_r.item():.4f}",
                             KL=f"{loss_k.item():.4f}",
                             beta=f"{beta:.5f}")

        total_recon /= len(train_loader)

        # ================= VALIDATION (deterministic) =================

        vae.eval()
        val_recon = 0

        with torch.no_grad():
            for batch in val_loader:
                img = batch["image"].to(device)

                # Decode the posterior mean rather than a random sample.
                mean, logvar = vae.encode(img)
                recon = vae.decode(mean)

                val_recon += recon_loss(recon, img).item()

        val_recon /= len(val_loader)

        print(f"\nEpoch {epoch}")
        print(f"Train L1: {total_recon:.4f} | Val L1: {val_recon:.4f}")

        # ================= PREVIEW (deterministic) =================

        with torch.no_grad():
            mean, logvar = vae.encode(preview)
            recon = vae.decode(mean)

        # Originals first, reconstructions after them.
        vis = torch.cat([preview, recon], dim=0)

        save_path = os.path.join(preview_dir, f"epoch_{epoch}.png")
        save_image((vis + 1) / 2, save_path, nrow=8)

        show_grid(vis, title=f"Epoch {epoch}")

        # ================= SAVE =================

        if val_recon < best_val:
            best_val = val_recon
            _atomic_torch_save(vae.state_dict(), best_model)
            print("New best model saved")

        _atomic_torch_save({
            "epoch": epoch,
            "model": vae.state_dict(),
            "optimizer": optimizer.state_dict(),
            "scaler": scaler.state_dict(),
            "best_val": best_val
        }, ckpt_file)

        print("Checkpoint saved")

    print("\nTraining complete")

# ============================================================
# TEST
# ============================================================

@torch.no_grad()
def test_vae(vae, test_loader,
             ckpt_path="./vae_outputs/vae_best.pth"):
    """Load saved weights and display reconstructions of one test batch.

    Args:
        vae: Model providing ``load_state_dict``, ``encode`` (returning
            ``(mean, logvar)``) and ``decode``.
        test_loader: Iterable of batches indexed by the key ``"image"``; only
            up to 8 images of the first batch are used.
        ckpt_path: State-dict file passed to ``torch.load``.
    """
    weights = torch.load(ckpt_path, map_location=device)
    vae.load_state_dict(weights)
    vae.eval()

    images = next(iter(test_loader))["image"]
    keep = min(8, images.size(0))
    images = images[:keep].to(device)

    # Reconstruct from the posterior mean; the log-variance is not needed.
    mu, _ = vae.encode(images)
    restored = vae.decode(mu)

    panel = torch.cat([images, restored], dim=0)
    show_grid(panel, title="Test Reconstructions")

# ============================================================
# RUN
# ============================================================

# Build the model (constructed with z_channels=8) and place it on `device`.
vae = VAE(z_channels=8)
vae = vae.to(device)

# Train with the default epochs, learning rate and output directory.
train_vae(vae=vae, train_loader=train_loader, val_loader=val_loader)

# Report how much two latent draws of the same training batch differ.
check_latent_stochasticity(vae=vae, loader=train_loader)

# Reload the best saved weights and show reconstructions of a test batch.
test_vae(vae=vae, test_loader=test_loader)


14-02-2026

In [None]:
import os
import torch
import torch.nn.functional as F
from tqdm import tqdm
import matplotlib.pyplot as plt
from torchvision.utils import make_grid, save_image
import lpips

# ============================================================
# DEVICE
# ============================================================

# Use CUDA when PyTorch reports it as available, else fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# ============================================================
# SAFE NORMALIZATION
# ============================================================

def normalize_batch(x):
    """Rescale an image batch to [-1, 1], guessing its current value range.

    The range is inferred per call:
      * integer dtype        -> treated as 8-bit pixels in [0, 255];
      * floating, max > 1.5  -> treated as [0, 255];
      * floating, min >= 0   -> treated as [0, 1];
      * otherwise            -> assumed to be in [-1, 1] already.
    The result is always clamped to [-1, 1].

    Args:
        x: Image tensor.

    Returns:
        A tensor of the same shape; integer inputs come back as float32.
    """
    if not torch.is_floating_point(x):
        # Integer math in `x * 2 - 1` would wrap for uint8, and a dark 8-bit
        # batch whose max is <= 1 would be mistaken for [0, 1] data, so
        # integer pixels are always converted on the 0..255 scale.
        x = x.float() / 127.5 - 1
    elif x.max() > 1.5:
        x = x / 127.5 - 1
    elif x.min() >= 0:
        x = x * 2 - 1
    return x.clamp(-1, 1)

# ============================================================
# PERCEPTUAL LOSS
# ============================================================

# LPIPS network with the VGG backbone, placed on `device` and kept in eval mode.
lpips_loss = lpips.LPIPS(net='vgg')
lpips_loss = lpips_loss.to(device)
lpips_loss.eval()

# Exclude the metric's parameters from autograd.
for p in lpips_loss.parameters():
    p.requires_grad_(False)

def perceptual_loss(pred, target):
    """Return the batch-mean LPIPS distance between ``pred`` and ``target``.

    Single-channel inputs are replicated to three channels before being passed
    to the module-level ``lpips_loss`` network; inputs with any other channel
    count are passed through unchanged (previously they were tripled too,
    which turned RGB batches into 9-channel tensors).

    Args:
        pred: Batch indexed as (N, C, H, W) -- assumed from the 4-axis repeat.
        target: Reference batch with the same layout.

    Returns:
        A scalar tensor.
    """
    if pred.size(1) == 1:
        pred = pred.repeat(1, 3, 1, 1)
    if target.size(1) == 1:
        target = target.repeat(1, 3, 1, 1)
    return lpips_loss(pred, target).mean()

# ============================================================
# LOSSES
# ============================================================

def kl_loss(mean, logvar):
    """Return the batch-averaged KL divergence from N(mean, exp(logvar)) to N(0, I).

    ``logvar`` is first clamped to [-30, 20] so that ``exp`` stays finite.
    The closed-form term ``1 + logvar - mean^2 - exp(logvar)`` is then summed
    over every non-batch axis of each sample, scaled by ``-0.5`` and averaged
    over the batch (axis 0).

    Args:
        mean: Posterior means, batch-first, with at least two dimensions.
        logvar: Posterior log-variances, same shape as ``mean``.

    Returns:
        A scalar tensor.
    """
    logvar = torch.clamp(logvar, -30, 20)
    per_element = 1 + logvar - mean.pow(2) - logvar.exp()
    # Flatten instead of hard-coding dim=[1,2,3] so vector latents (B, D)
    # and volumetric latents (B, C, D, H, W) are reduced correctly as well.
    per_sample = per_element.flatten(start_dim=1).sum(dim=1)
    return -0.5 * per_sample.mean()

def recon_loss(pred, target):
    """Return the element-wise mean of ``|pred - target|`` (L1 loss)."""
    mean_abs_error = F.l1_loss(pred, target, reduction="mean")
    return mean_abs_error

# ============================================================
# PREVIEW STRETCH
# ============================================================

def stretch(x):
    """Min-max rescale ``x`` to span [-1, 1] using its global extrema.

    The minimum and maximum are taken over the whole tensor, not per image.
    A small epsilon in the denominator keeps a constant tensor from dividing
    by zero; such a tensor maps to -1 everywhere.
    """
    shifted = x - x.min()
    peak = shifted.max()
    unit_range = shifted / (peak + 1e-8)
    return unit_range * 2 - 1

# ============================================================
# DISPLAY GRID
# ============================================================

def show_grid(tensor, title=""):
    """Show an image batch as a tiled matplotlib figure, four tiles per row.

    Values are remapped with ``(x + 1) / 2`` and clipped to [0, 1]
    (presumably the inputs live in [-1, 1] -- confirm against the data
    pipeline) before tiling with ``make_grid``.

    Args:
        tensor: Image batch accepted by ``torchvision.utils.make_grid``.
        title: Text placed above the figure.
    """
    scaled = torch.clamp((tensor + 1) / 2, 0, 1)

    # Channels-last host array, as matplotlib expects.
    tiles = make_grid(scaled, nrow=4).cpu().permute(1,2,0).numpy()
    picture = tiles.squeeze()

    plt.figure(figsize=(6,6))
    plt.imshow(picture, cmap='gray')
    plt.title(title)
    plt.axis('off')
    plt.show()

# ============================================================
# TRAIN FUNCTION
# ============================================================

def train_vae(vae, train_loader, val_loader,
              epochs=150, lr=1e-4,
              out_dir="./vae_outputs"):
    """Train a VAE with L1, LPIPS and KL terms, keeping the best weights.

    The training loss is ``L1 + perc_weight * LPIPS + 1e-4 * KL`` where
    ``perc_weight`` is ``min(0.02, epoch * 0.005)``. After every epoch the
    mean validation L1 (decoding the posterior mean) is computed, a preview
    grid of originals and min-max stretched reconstructions is saved and
    shown, and the weights are written to ``vae_best.pth`` whenever the
    validation L1 improves. Training returns early, after printing a
    message, as soon as a non-finite loss is encountered.

    Args:
        vae: Model whose forward pass returns ``(recon, mean, logvar)`` and
            which provides ``encode`` (returning ``(mean, logvar)``) and
            ``decode``.
        train_loader: Iterable of batches indexed by the key ``"image"``.
        val_loader: Same batch format; the first 4 images of its first batch
            form the fixed preview set.
        epochs: Number of epochs to run (1-based, inclusive).
        lr: AdamW learning rate.
        out_dir: Directory receiving ``vae_best.pth`` and ``previews/``.
    """
    preview_dir = os.path.join(out_dir, "previews")
    os.makedirs(out_dir, exist_ok=True)
    os.makedirs(preview_dir, exist_ok=True)

    best_path = os.path.join(out_dir, "vae_best.pth")
    amp_on = device.type == "cuda"

    vae = vae.to(device)
    optimizer = torch.optim.AdamW(vae.parameters(), lr=lr)
    scaler = torch.amp.GradScaler("cuda", enabled=amp_on)

    # Fixed preview images so the per-epoch grids are comparable.
    first_val_batch = next(iter(val_loader))
    preview = normalize_batch(first_val_batch["image"][:4].to(device))

    best_val = float("inf")

    # Constant KL weight (no warmup in this revision).
    beta = 1e-4

    for epoch in range(1, epochs + 1):

        torch.cuda.empty_cache()

        # Perceptual weight ramps by 0.005 per epoch up to a 0.02 ceiling.
        perc_weight = min(0.02, epoch * 0.005)

        # ----- training pass -----

        vae.train()
        l1_sum = 0

        progress = tqdm(train_loader, desc=f"[TRAIN] Epoch {epoch}/{epochs}")

        for batch in progress:
            images = normalize_batch(batch["image"].to(device))
            optimizer.zero_grad()

            with torch.amp.autocast("cuda", enabled=amp_on):
                recon, mean, logvar = vae(images)

                loss_l1 = recon_loss(recon, images)
                loss_p = perceptual_loss(recon, images)
                loss_k = kl_loss(mean, logvar)

                loss = loss_l1 + perc_weight * loss_p + beta * loss_k

            # Abort the whole run on a NaN/Inf loss.
            if not torch.isfinite(loss):
                print("NaN detected — stopping")
                return

            scaler.scale(loss).backward()
            # Unscale first so the clipping threshold applies to true gradients.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(vae.parameters(), 1.0)
            scaler.step(optimizer)
            scaler.update()

            l1_value = loss_l1.item()
            l1_sum += l1_value

            progress.set_postfix(
                L1=f"{l1_value:.4f}",
                KL=f"{loss_k.item():.4f}",
                pw=f"{perc_weight:.3f}",
            )

        train_l1 = l1_sum / len(train_loader)

        # ----- validation pass -----

        vae.eval()
        val_sum = 0

        with torch.no_grad():
            for batch in val_loader:
                images = normalize_batch(batch["image"].to(device))
                # Decode the posterior mean rather than a random sample.
                mean, logvar = vae.encode(images)
                val_sum += recon_loss(vae.decode(mean), images).item()

        val_l1 = val_sum / len(val_loader)

        print(f"\nEpoch {epoch}")
        print(f"Train L1: {train_l1:.4f} | Val L1: {val_l1:.4f}")

        # ----- preview -----

        with torch.no_grad():
            mean, logvar = vae.encode(preview)
            recon = vae.decode(mean)

        # Originals first, contrast-stretched reconstructions after them.
        vis = torch.cat([preview, stretch(recon)], dim=0)

        preview_file = os.path.join(preview_dir, f"epoch_{epoch}.png")
        save_image((vis + 1) / 2, preview_file, nrow=4)

        show_grid(vis, title=f"Epoch {epoch}")

        # ----- keep the best weights -----

        if val_l1 < best_val:
            best_val = val_l1
            torch.save(vae.state_dict(), best_path)
            print("New best model saved")

    print("\nTraining complete")

# ============================================================
# RUN
# ============================================================

# Build the model (constructed with z_channels=16), place it on `device`,
# and train it with the default epochs, learning rate and output directory.
vae = VAE(z_channels=16)
vae = vae.to(device)
train_vae(vae=vae, train_loader=train_loader, val_loader=val_loader)
