# **Image Preprocessing: Aspect-Preserving Resize, Reflect Padding and Min-Max Normalization**

In [None]:
import os
import cv2
import numpy as np

# Paths are resolved relative to the notebook's working directory.
BASE_DIR = os.getcwd()

# Raw input: one digit-named sub-folder of PNG frames per patient
# (e.g. raw_png/1001/*.png).
PNG_ROOT = os.path.join(BASE_DIR, "raw_png")

# Processed output mirrors the input layout and is created up front.
OUT_ROOT = os.path.join(BASE_DIR, "processed_png")
os.makedirs(OUT_ROOT, exist_ok=True)

print(f"Sir, input folder: {PNG_ROOT}")
print(f"Sir, output folder: {OUT_ROOT}")

# Resize + Pad

def pad_and_resize(img, target_size=512):
    """
    Fit a 2-D image inside a ``target_size`` square and reflect-pad it.

    The image is scaled, with its aspect ratio kept, so that its longer side
    equals ``target_size``. The shorter side is then padded symmetrically by
    mirroring (the edge pixel is not repeated) to make a square.

    Args:
        img: 2-D array of shape (H, W), e.g. a grayscale ``cv2.imread`` result.
        target_size: side length of the square output.

    Returns:
        Array of shape (target_size, target_size).
    """
    H, W = img.shape

    scale = target_size / max(H, W)
    # round() rather than int(): int() truncates e.g. 511.9999 down to 511.
    # max(1, ...) stops very thin images collapsing to a 0-pixel side, which
    # cv2.resize rejects. min(target_size, ...) guards against rounding past
    # the target.
    new_H = min(target_size, max(1, round(H * scale)))
    new_W = min(target_size, max(1, round(W * scale)))

    if (new_H, new_W) == (H, W):
        # Already the right size (e.g. re-running on processed data): skip
        # the resampling pass entirely.
        resized = img
    else:
        # OpenCV recommends INTER_AREA for shrinking. When enlarging, it
        # behaves close to nearest-neighbour, so use INTER_CUBIC instead.
        interpolation = cv2.INTER_AREA if scale < 1.0 else cv2.INTER_CUBIC
        resized = cv2.resize(img, (new_W, new_H), interpolation=interpolation)

    pad_H = target_size - new_H
    pad_W = target_size - new_W

    top = pad_H // 2
    bottom = pad_H - top
    left = pad_W // 2
    right = pad_W - left

    # numpy's "reflect" mode mirrors without duplicating the border pixel.
    return np.pad(resized, ((top, bottom), (left, right)), mode="reflect")


# Normalize


def normalize_to_uint8(img):
    """
    Min-max rescale an image into the full 0-255 range as uint8.

    The image is converted to float32 and shifted so its minimum becomes 0.
    It is then divided by its value range, multiplied by 255 and truncated
    to uint8. An all-zero uint8 image of the same shape is returned when the
    range is not larger than 1e-6. This covers a constant image, or one
    containing NaN, where the comparison is False.
    """
    pixels = img.astype(np.float32)

    lo = pixels.min()
    span = pixels.max() - lo

    if span > 1e-6:
        unit = (pixels - lo) / span
    else:
        unit = np.zeros_like(pixels)

    return (unit * 255).astype(np.uint8)


# Processing Function

def process_png_root(PNG_ROOT, OUT_ROOT,
                     start_pid=0, end_pid=999999):
    """
    Preprocess every patient folder in a patient-id range.

    Every patient folder under ``PNG_ROOT`` whose numeric name lies in
    [start_pid, end_pid] is processed. Each PNG is read as grayscale,
    resized and reflect-padded to 512x512 with ``pad_and_resize``, min-max
    normalized with ``normalize_to_uint8``, and written to
    ``OUT_ROOT/<patient_id>/<same file name>``.

    Args:
        PNG_ROOT: folder containing one digit-named sub-folder per patient.
        OUT_ROOT: destination root; a patient's output folder is created
            only when that patient has at least one PNG.
        start_pid, end_pid: inclusive patient-id range to process.

    Returns:
        None. Progress, warnings and a final report are printed.
    """
    if not os.path.isdir(PNG_ROOT):
        print("[ERROR] Sir, PNG folder not found:", PNG_ROOT)
        return

    # Only digit-named *directories* are patients. A stray file named e.g.
    # "1234" would otherwise make os.listdir(in_dir) below raise.
    patients = sorted(
        (p for p in os.listdir(PNG_ROOT)
         if p.isdigit()
         and start_pid <= int(p) <= end_pid
         and os.path.isdir(os.path.join(PNG_ROOT, p))),
        key=int,
    )

    print(f"Sir, found {len(patients)} patient folders")

    total_frames = 0
    patients_done = 0

    for pid in patients:
        in_dir = os.path.join(PNG_ROOT, pid)
        out_dir = os.path.join(OUT_ROOT, pid)

        # Sorted so processing order does not depend on the filesystem.
        frames = sorted(f for f in os.listdir(in_dir)
                        if f.lower().endswith(".png"))

        if not frames:
            continue

        os.makedirs(out_dir, exist_ok=True)
        print(f"Sir, processing patient {pid}: {len(frames)} images")

        for frame in frames:
            img_path = os.path.join(in_dir, frame)

            img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)

            if img is None:
                print("[WARNING] Could not read:", img_path)
                continue

            img = pad_and_resize(img, 512)
            img = normalize_to_uint8(img)

            out_path = os.path.join(out_dir, frame)
            # cv2.imwrite reports failure by returning False, not by raising.
            if not cv2.imwrite(out_path, img):
                print("[WARNING] Could not write:", out_path)
                continue

            total_frames += 1

        patients_done += 1

    print("\n========== FINAL REPORT ==========")
    print("Patients processed:", patients_done)
    print("Total images saved:", total_frames)
    print("==================================\n")

# RUN PROCESSING over the configured (inclusive) patient-id range.

START_PATIENT_ID = 1001
END_PATIENT_ID = 7411

process_png_root(
    PNG_ROOT,
    OUT_ROOT,
    start_pid=START_PATIENT_ID,
    end_pid=END_PATIENT_ID,
)


# **Whole-Dataset Patient-Level Split (Train / Val / Test)**

In [None]:
import os
import shutil
import random

# Paths are resolved relative to the notebook's working directory.
BASE_DIR = os.getcwd()

# Input of this step is the output of the preprocessing step.
DATA_ROOT = os.path.join(BASE_DIR, "processed_png")

# Destination for the train/val/test split; created up front.
OUT_ROOT = os.path.join(BASE_DIR, "dataset")
os.makedirs(OUT_ROOT, exist_ok=True)

print(f"Sir, source folder: {DATA_ROOT}")
print(f"Sir, dataset output: {OUT_ROOT}")


def split_full_dataset(
    data_root,
    out_root,
    train_ratio=0.85,
    val_ratio=0.10,
    test_ratio=0.05,
    seed=42
):
    """
    Split patient folders into train/val/test copies at the patient level.

    Every patient sub-folder of ``data_root`` that contains at least one PNG
    is copied into ``out_root/train``, ``out_root/val`` or ``out_root/test``.
    All images of one patient land in the same split.

    Args:
        data_root: folder with one sub-folder of PNG images per patient.
        out_root: destination; ``train``, ``val`` and ``test`` are created in it.
        train_ratio, val_ratio: fractions of patients for train and val
            (each rounded down). All remaining patients go to test.
        test_ratio: expected test fraction. It is only used to warn when the
            three ratios do not sum to 1; the test split is always the
            remainder.
        seed: seed for the patient shuffle. The split is reproducible for a
            given set of patient folders.

    Side effects:
        Copies patient folders into the split directories and replaces
        existing copies of the same patient. Patient folders inside a split
        directory that are not assigned to that split (e.g. left over from
        an earlier run) are removed, so no patient appears in two splits.
        Prints a summary and per-split PNG counts.
    """
    import math  # only needed for the ratio sanity check

    split_names = ("train", "val", "test")

    if not math.isclose(train_ratio + val_ratio + test_ratio, 1.0,
                        abs_tol=1e-6):
        print("[WARNING] Sir, split ratios do not sum to 1; "
              "the test split takes all remaining patients")

    # -------------------------------
    # Scan patient folders
    # -------------------------------
    # Sorted so the seeded shuffle yields the same split on every machine:
    # os.listdir order depends on the filesystem.
    patients = []
    for pid in sorted(os.listdir(data_root)):
        pid_path = os.path.join(data_root, pid)
        if not os.path.isdir(pid_path):
            continue
        png_count = sum(1 for f in os.listdir(pid_path)
                        if f.lower().endswith(".png"))
        if png_count > 0:
            patients.append((pid, png_count))

    print(f"Sir, total patients found: {len(patients)}")

    # A private RNG gives the same permutation as random.seed(seed) followed
    # by random.shuffle(...), without resetting the notebook-wide random state.
    random.Random(seed).shuffle(patients)

    # -------------------------------
    # Patient-level split (test takes the remainder)
    # -------------------------------
    total_p = len(patients)
    n_train = int(total_p * train_ratio)
    n_val = int(total_p * val_ratio)

    assignment = {
        "train": patients[:n_train],
        "val": patients[n_train:n_train + n_val],
        "test": patients[n_train + n_val:],
    }

    print("\nSir, split summary:")
    print("Train patients:", len(assignment["train"]))
    print("Val patients:", len(assignment["val"]))
    print("Test patients:", len(assignment["test"]))
    print()

    # -------------------------------
    # Create split folders and copy patient folders
    # -------------------------------
    for split_name in split_names:
        split_path = os.path.join(out_root, split_name)
        os.makedirs(split_path, exist_ok=True)
        assigned = {pid for pid, _ in assignment[split_name]}

        # Drop patient folders from an earlier run that no longer belong to
        # this split. Otherwise one patient could sit in two splits and leak
        # between training and evaluation.
        for existing in os.listdir(split_path):
            existing_path = os.path.join(split_path, existing)
            if existing not in assigned and os.path.isdir(existing_path):
                shutil.rmtree(existing_path)

        for pid, _ in assignment[split_name]:
            src = os.path.join(data_root, pid)
            dst = os.path.join(split_path, pid)

            if os.path.exists(dst):
                shutil.rmtree(dst)

            shutil.copytree(src, dst)

    # -------------------------------
    # Count images
    # -------------------------------
    print("Sir, image counts per split:")
    total_all = 0

    for split_name in split_names:
        split_path = os.path.join(out_root, split_name)
        count = 0

        for pid in os.listdir(split_path):
            pid_path = os.path.join(split_path, pid)
            if os.path.isdir(pid_path):
                count += sum(1 for f in os.listdir(pid_path)
                             if f.lower().endswith(".png"))

        print(f"{split_name.upper()}: {count}")
        total_all += count

    print("\nSir, TOTAL PNG FILES:", total_all)
    print("Sir, dataset split complete ✓\n")


# ============================================================
# RUN SPLIT (defaults: 85% / 10% / 5% of patients, seed 42)
# ============================================================

split_full_dataset(data_root=DATA_ROOT, out_root=OUT_ROOT)


# **Environment Setup + Paths + Device + Helpers**

In [None]:
import os
import cv2
import glob
import math
import random
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision.utils import make_grid
import torchvision.transforms as T


# ============================================================
# Random seeding for reproducibility
# ============================================================

def seed_everything(seed=42, deterministic=False):
    """
    Seed Python's ``random``, NumPy and PyTorch (CPU and all CUDA devices).

    Args:
        seed: seed value applied to every RNG.
        deterministic: if False (the default), cuDNN autotuning
            (``cudnn.benchmark = True``) stays on for speed. Autotuning may
            choose different convolution algorithms between runs, so results
            are not bit-exact. Pass True to turn autotuning off and request
            deterministic cuDNN kernels instead.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    if deterministic:
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
    else:
        torch.backends.cudnn.benchmark = True

seed_everything(42)


# ============================================================
# Dataset split folders (relative to the notebook's directory)
# ============================================================

BASE_DIR = os.getcwd()

# dataset/{train,val,test} is where the split step writes its output.
ROOT = os.path.join(BASE_DIR, "dataset")

TRAIN_DIR, VAL_DIR, TEST_DIR = (
    os.path.join(ROOT, split) for split in ("train", "val", "test")
)

print(f"Sir, train folder: {TRAIN_DIR}")
print(f"Sir, val folder: {VAL_DIR}")
print(f"Sir, test folder: {TEST_DIR}")

# Fail fast, checking train -> val -> test, if any split folder is missing.
if not os.path.exists(TRAIN_DIR):
    raise FileNotFoundError("Sir, train directory not found")
if not os.path.exists(VAL_DIR):
    raise FileNotFoundError("Sir, validation directory not found")
if not os.path.exists(TEST_DIR):
    raise FileNotFoundError("Sir, test directory not found")

print("Sir, dataset folders verified ✓")


# ============================================================
# Output folders and compute device
# ============================================================

OUT_ROOT = os.path.join(BASE_DIR, "oct_ldm_output")

DIR_SAMPLES = os.path.join(OUT_ROOT, "samples")
DIR_CHECKPOINTS = os.path.join(OUT_ROOT, "checkpoints")
DIR_PLOTS = os.path.join(OUT_ROOT, "plots")
DIR_JSON = os.path.join(OUT_ROOT, "json")
DIR_CSV = os.path.join(OUT_ROOT, "csv")

# exist_ok=True makes re-running this cell harmless.
os.makedirs(DIR_SAMPLES, exist_ok=True)
os.makedirs(DIR_CHECKPOINTS, exist_ok=True)
os.makedirs(DIR_PLOTS, exist_ok=True)
os.makedirs(DIR_JSON, exist_ok=True)
os.makedirs(DIR_CSV, exist_ok=True)

print(f"Sir, output folders created at: {OUT_ROOT}")

# Prefer the first visible GPU; fall back to CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

if torch.cuda.is_available():
    print(f"Sir, GPU: {torch.cuda.get_device_name(0)}")


# ============================================================
# VISUALIZATION FUNCTION
# ============================================================

def show_grid(tensor, nrow=8, title=""):
    """
    Display a batch of images with values in [-1, 1] as one grid.

    Args:
        tensor: batch accepted by ``torchvision.utils.make_grid``
            (assumed (B, C, H, W) with C = 1 or 3; confirm at call sites).
        nrow: number of images per grid row.
        title: optional figure title.
    """
    tensor = tensor.detach().float()

    # make_grid copies a single channel into 3 (RGB) channels. Remember the
    # original channel count so grayscale batches are still drawn with a
    # gray colormap.
    single_channel = tensor.dim() >= 3 and tensor.shape[-3] == 1

    # Shift [-1, 1] -> [0, 1] *before* building the grid. matplotlib clips
    # RGB float data to [0, 1], so an unshifted grid showed every negative
    # pixel as pure black.
    images = (tensor.clamp(-1, 1) + 1) / 2
    grid = make_grid(images, nrow=nrow, normalize=False)
    grid = grid.permute(1, 2, 0).cpu().numpy()

    plt.figure(figsize=(8, 8))

    if single_channel or grid.shape[-1] == 1:
        plt.imshow(grid[..., 0], cmap="gray", vmin=0.0, vmax=1.0)
    else:
        plt.imshow(grid)

    if title:
        plt.title(title)

    plt.axis("off")
    plt.tight_layout()
    plt.show()


print("Sir, local OCT LDM environment ready ✓")


# **Dataset Construction & DataLoader Interface**

In [None]:
import os
import glob
import cv2
import torch
from torch.utils.data import Dataset, DataLoader

# ============================================================
# DATASET CLASS
# ============================================================

class OCTDataset(Dataset):
    """
    PNG dataset of OCT scans found under ``root``.

    Each item is read as grayscale, resized to fit an ``image_size`` square
    with its aspect ratio kept, reflect-padded to that square, and scaled
    from [0, 255] to [-1, 1].

    Items are dicts:
        "image":   float tensor of shape (1, image_size, image_size)
        "patient": name of the PNG's parent folder
        "path":    path of the source PNG
    """

    def __init__(self, root, image_size=512, recursive=True):
        self.root = root
        self.image_size = image_size

        if not os.path.exists(root):
            raise FileNotFoundError(f"Sir, dataset folder not found: {root}")

        pattern = "**/*.png" if recursive else "*.png"
        self.paths = sorted(glob.glob(os.path.join(root, pattern), recursive=recursive))

        if len(self.paths) == 0:
            raise RuntimeError(f"Sir, no PNG files found in {root}")

        print(f"Sir, loaded {len(self.paths)} OCT scans from {root}")

    def __len__(self):
        return len(self.paths)

    def _load_oct_image(self, path):
        """Read ``path`` as a 2-D grayscale array; raise RuntimeError if unreadable."""
        img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
        if img is None:
            raise RuntimeError(f"Sir, failed to load image: {path}")
        return img

    @staticmethod
    def _fit_size(h, w, target):
        """Return (new_h, new_w) fitting (h, w) inside a target square with the aspect ratio kept."""
        scale = min(target / h, target / w)
        # round() avoids int() truncation (e.g. 511.999 -> 511). max(1, ...)
        # keeps thin images from collapsing to a 0-pixel side, which
        # cv2.resize rejects. min(target, ...) guards rounding past target.
        new_h = min(target, max(1, round(h * scale)))
        new_w = min(target, max(1, round(w * scale)))
        return new_h, new_w

    def _pad_keep_aspect(self, img):
        """Resize ``img`` to fit the square, then pad it with BORDER_REFLECT_101."""
        h, w = img.shape
        target = self.image_size

        new_h, new_w = self._fit_size(h, w, target)

        if (new_h, new_w) != (h, w):
            # INTER_AREA suits shrinking; INTER_CUBIC suits enlarging.
            shrinking = new_h <= h and new_w <= w
            interp = cv2.INTER_AREA if shrinking else cv2.INTER_CUBIC
            img = cv2.resize(img, (new_w, new_h), interpolation=interp)

        pad_h = target - new_h
        pad_w = target - new_w

        top = pad_h // 2
        bottom = pad_h - top
        left = pad_w // 2
        right = pad_w - left

        return cv2.copyMakeBorder(
            img,
            top,
            bottom,
            left,
            right,
            borderType=cv2.BORDER_REFLECT_101
        )

    def __getitem__(self, idx):
        path = self.paths[idx]

        img = self._load_oct_image(path)
        img = self._pad_keep_aspect(img)

        img = torch.from_numpy(img).float()
        img = img.unsqueeze(0)
        img = (img / 255.0) * 2.0 - 1.0  # [0, 255] -> [-1, 1]

        patient = os.path.basename(os.path.dirname(path))

        return {
            "image": img,
            "patient": patient,
            "path": path
        }


# ============================================================
# Split folders produced by the dataset-split step
# ============================================================

BASE_DIR = os.getcwd()
DATA_ROOT = os.path.join(BASE_DIR, "dataset")

TRAIN_DIR = os.path.join(DATA_ROOT, "train")
VAL_DIR = os.path.join(DATA_ROOT, "val")
TEST_DIR = os.path.join(DATA_ROOT, "test")

print(f"Sir, using dataset folder: {DATA_ROOT}")

# Placeholders until build_loaders() is called below.
train_loader = val_loader = test_loader = None

def build_loaders(batch_size=4, num_workers=2, image_size=512):
    """
    Build DataLoaders over TRAIN_DIR, VAL_DIR and TEST_DIR.

    Only the training loader shuffles and drops its last incomplete batch.
    Host-memory pinning is enabled only when CUDA is available.

    Returns:
        (train_loader, val_loader, test_loader)
    """
    pin = torch.cuda.is_available()

    def make_loader(folder, is_train):
        return DataLoader(
            OCTDataset(folder, image_size=image_size),
            batch_size=batch_size,
            shuffle=is_train,
            num_workers=num_workers,
            pin_memory=pin,
            drop_last=is_train,
        )

    return (
        make_loader(TRAIN_DIR, True),
        make_loader(VAL_DIR, False),
        make_loader(TEST_DIR, False),
    )


# ============================================================
# Create the loaders and run a one-batch sanity check
# ============================================================

train_loader, val_loader, test_loader = build_loaders()

print("Sir, data loaders ready ✓")
print(f"Train batches: {len(train_loader)}")
print(f"Val batches: {len(val_loader)}")
print(f"Test batches: {len(test_loader)}")

# Pull one (shuffled) training batch to confirm shapes and metadata.
batch = next(iter(train_loader))
print(f"Sample batch image shape: {batch['image'].shape}")
print(f"Sample patients: {batch['patient'][:4]}")


# **Patient-wise OCT Cohort Analysis**

In [None]:
import os
import glob
from collections import defaultdict
import cv2

# ============================================================
# Split folders to inspect (output of the dataset-split step)
# ============================================================

BASE_DIR = os.getcwd()
DATA_ROOT = os.path.join(BASE_DIR, "dataset")

TRAIN_DIR = os.path.join(DATA_ROOT, "train")
VAL_DIR = os.path.join(DATA_ROOT, "val")
TEST_DIR = os.path.join(DATA_ROOT, "test")

print(f"Sir, inspecting dataset at: {DATA_ROOT}")


# ============================================================
# INSPECTION FUNCTION
# ============================================================

def inspect_oct_dataset(root):
    """
    Print per-patient image statistics for every PNG found under ``root``.

    A PNG's patient is the name of its immediate parent folder.

    Returns:
        (num_patients, total_images)

    Raises:
        FileNotFoundError: if ``root`` does not exist.
    """
    if not os.path.exists(root):
        raise FileNotFoundError(f"Sir, directory not found: {root}")

    print(f"\nInspecting directory: {root}")

    png_files = sorted(glob.glob(os.path.join(root, "**", "*.png"), recursive=True))
    image_total = len(png_files)
    print(f"Total PNG images: {image_total}")

    per_patient = defaultdict(int)
    for owner in (os.path.basename(os.path.dirname(p)) for p in png_files):
        per_patient[owner] += 1

    if not per_patient:
        print("Sir, no patient folders detected")
        return 0, image_total

    sizes = list(per_patient.values())
    patient_total = len(sizes)
    print(f"Patients: {patient_total}")
    print(f"Minimum images per patient: {min(sizes)}")
    print(f"Maximum images per patient: {max(sizes)}")
    print(f"Average images per patient: {sum(sizes) / patient_total:.2f}")
    return patient_total, image_total


# ============================================================
# Per-split statistics
# ============================================================

print("\nTRAIN DATASET")
train_patients, train_images = inspect_oct_dataset(TRAIN_DIR)

print("\nVALIDATION DATASET")
val_patients, val_images = inspect_oct_dataset(VAL_DIR)

print("\nTEST DATASET")
test_patients, test_images = inspect_oct_dataset(TEST_DIR)


print("\nDATASET SUMMARY")
print(f"Train  -> {train_patients} patients, {train_images} images")
print(f"Val    -> {val_patients} patients, {val_images} images")
print(f"Test   -> {test_patients} patients, {test_images} images")


# ============================================================
# Resolution check on one training image
# ============================================================

sample_files = glob.glob(os.path.join(TRAIN_DIR, "**", "*.png"), recursive=True)

if not sample_files:
    print("\nSir, no PNG files found in training directory")
else:
    img = cv2.imread(sample_files[0], cv2.IMREAD_GRAYSCALE)
    if img is None:
        print("\nSir, failed to load sample image")
    else:
        print(f"\nSir, sample image resolution: {img.shape} (H, W)")


# **Variational Autoencoder (VAE) Architecture for OCT Image Reconstruction**

In [None]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.utils import save_image

# Paths are resolved relative to the notebook's working directory.
BASE_DIR = os.getcwd()

# Folder for the image written by the VAE smoke test in this cell.
OUTPUT_DIR = os.path.join(BASE_DIR, "outputs")
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Prefer the GPU when PyTorch can see one.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")


# GroupNorm factory shared by every block of the VAE.
def norm_layer(ch):
    """
    Return a ``GroupNorm`` over ``ch`` channels with as many groups as possible, up to 32.

    GroupNorm requires ``ch`` to be divisible by the number of groups. The
    largest divisor of ``ch`` that is <= 32 is used. This equals
    ``min(32, ch)`` whenever ``ch`` < 32 or ``ch`` is a multiple of 32, and
    also works for widths like 48 (24 groups) or 100 (25 groups).

    Raises:
        ValueError: if ``ch`` < 1.
    """
    if ch < 1:
        raise ValueError(f"norm_layer needs at least one channel, got {ch}")
    groups = next(g for g in range(min(32, ch), 0, -1) if ch % g == 0)
    return nn.GroupNorm(groups, ch)


# Residual block used throughout the VAE encoder and decoder.
class ResBlock(nn.Module):
    """
    Pre-activation residual block.

    The input passes through two (GroupNorm -> SiLU -> 3x3 conv) stages and
    the result is added back onto the input. Channel count and spatial size
    are unchanged.
    """

    def __init__(self, ch):
        super().__init__()
        # Submodules are created in a fixed order; this order determines the
        # state_dict keys and seeded parameter initialization.
        self.norm1 = norm_layer(ch)
        self.conv1 = nn.Conv2d(ch, ch, kernel_size=3, padding=1)
        self.norm2 = norm_layer(ch)
        self.conv2 = nn.Conv2d(ch, ch, kernel_size=3, padding=1)
        self.act = nn.SiLU()

    def forward(self, x):
        out = x
        for norm, conv in ((self.norm1, self.conv1), (self.norm2, self.conv2)):
            out = conv(self.act(norm(out)))
        return x + out  # skip connection


# Downsampling stage of the VAE encoder.
class DownBlock(nn.Module):
    """
    Halve spatial size (for even inputs) and change width ``in_ch -> out_ch``.

    Layers: 4x4 stride-2 conv (padding 1), then GroupNorm, then SiLU.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1)
        self.norm = norm_layer(out_ch)
        self.act = nn.SiLU()

    def forward(self, x):
        x = self.conv(x)
        x = self.norm(x)
        return self.act(x)


# Upsampling stage of the VAE decoder.
class UpBlock(nn.Module):
    """
    Double spatial size and change width ``in_ch -> out_ch``.

    Layers: 4x4 stride-2 transposed conv (padding 1), then GroupNorm, then SiLU.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.ConvTranspose2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1)
        self.norm = norm_layer(out_ch)
        self.act = nn.SiLU()

    def forward(self, x):
        x = self.conv(x)
        x = self.norm(x)
        return self.act(x)


# Convolutional variational autoencoder. The encoder applies three stride-2
# DownBlocks, so each spatial side shrinks by 8 (for sides divisible by 8,
# e.g. 512x512 -> 64x64 latent). The decoder mirrors it with three UpBlocks.
class VAE(nn.Module):
    """
    Convolutional VAE for images with ``im_channels`` channels.

    Args:
        im_channels: channels of the input and reconstructed image.
        z_channels: channels of the latent map; ``mean`` and ``logvar`` each
            have this many.
        base_ch: width after the stem conv; the encoder widens to 2x and then
            4x this value.

    ``forward(x)`` returns ``(recon, mean, logvar)``. ``recon`` goes through
    ``tanh`` and therefore lies in (-1, 1).

    NOTE: submodule creation order in ``__init__`` determines state_dict keys
    and seeded initialization; keep it stable for checkpoint compatibility.
    """
    def __init__(self, im_channels=1, z_channels=8, base_ch=64):
        super().__init__()

        # Encoder: stem conv, then three (downsample -> residual) stages.
        self.conv_in = nn.Conv2d(im_channels, base_ch, 3, padding=1)

        self.down1 = DownBlock(base_ch, base_ch * 2)
        self.res1  = ResBlock(base_ch * 2)

        self.down2 = DownBlock(base_ch * 2, base_ch * 4)
        self.res2  = ResBlock(base_ch * 4)

        self.down3 = DownBlock(base_ch * 4, base_ch * 4)
        self.res3  = ResBlock(base_ch * 4)

        mid_ch = base_ch * 4

        # One conv predicts mean and log-variance together; encode() splits
        # them along the channel dimension.
        self.to_stats = nn.Conv2d(mid_ch, z_channels * 2, 3, padding=1)

        # Decoder: project latent back to mid_ch, then three
        # (upsample -> residual) stages back to the input resolution.
        self.from_latent = nn.Conv2d(z_channels, mid_ch, 3, padding=1)

        self.res4 = ResBlock(mid_ch)

        self.up1 = UpBlock(mid_ch, base_ch * 4)
        self.res5 = ResBlock(base_ch * 4)

        self.up2 = UpBlock(base_ch * 4, base_ch * 2)
        self.res6 = ResBlock(base_ch * 2)

        self.up3 = UpBlock(base_ch * 2, base_ch)
        self.res7 = ResBlock(base_ch)

        self.norm_out = norm_layer(base_ch)
        self.conv_out = nn.Conv2d(base_ch, im_channels, 3, padding=1)

    # Encoding step
    def encode(self, x):
        """Return ``(mean, logvar)`` of the approximate posterior; logvar is clamped to [-10, 10]."""
        x = self.conv_in(x)
        x = self.res1(self.down1(x))
        x = self.res2(self.down2(x))
        x = self.res3(self.down3(x))

        stats = self.to_stats(x)
        mean, logvar = torch.chunk(stats, 2, dim=1)

        # Clamping bounds exp(logvar) (and exp(0.5 * logvar) in
        # reparameterize) so the variance cannot overflow or vanish entirely.
        logvar = torch.clamp(logvar, -10, 10)
        return mean, logvar

    # Reparameterization trick
    def reparameterize(self, mean, logvar):
        """Draw ``z = mean + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, I)``; gradients flow through mean/logvar."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mean + eps * std

    # Decoding step
    def decode(self, z):
        """Map a latent ``z`` back to image space; output is squashed by tanh into (-1, 1)."""
        x = self.from_latent(z)
        x = self.res4(x)
        x = self.res5(self.up1(x))
        x = self.res6(self.up2(x))
        x = self.res7(self.up3(x))
        x = self.conv_out(F.silu(self.norm_out(x)))
        return torch.tanh(x)

    # Full forward pass
    def forward(self, x):
        """Encode, sample a latent (always stochastic, also in eval mode), decode; return ``(recon, mean, logvar)``."""
        mean, logvar = self.encode(x)
        z = self.reparameterize(mean, logvar)
        recon = self.decode(z)
        return recon, mean, logvar


# Smoke test: push random noise through an untrained VAE and save the output.
vae = VAE().to(device)
vae.eval()

with torch.no_grad():
    x = torch.randn(1, 1, 512, 512).to(device)
    recon, mean, logvar = vae(x)

print(f"Input shape: {x.shape}")
print(f"Latent shape: {mean.shape}")
print(f"Reconstruction shape: {recon.shape}")

# The decoder ends in tanh, i.e. (-1, 1); save_image expects [0, 1].
out_path = os.path.join(OUTPUT_DIR, "reconstruction.png")
save_image((recon + 1) / 2, out_path)
print(f"Saved image to: {out_path}")


# **OCT Dataset Loader for Training, Validation, and Testing**

In [None]:
from torch.utils.data import Dataset, DataLoader
import glob
import cv2
import os
import torch


# Dataset that reads every PNG under a root folder (recursively) and returns
# fixed-size square images in [-1, 1].
class OCTDataset(Dataset):
    """
    Recursive PNG dataset returning ``{"image": (1, S, S) float tensor}``.

    ``S`` is ``image_size`` and values lie in [-1, 1]. Each image is
    resized with its aspect ratio kept and then reflect-padded to the square
    (BORDER_REFLECT_101). Non-square scans are therefore not stretched,
    which keeps them consistent with the other preprocessing in this
    notebook.

    Raises:
        RuntimeError: from ``__init__`` if no PNG is found under ``root``, and
            from ``__getitem__`` if an image cannot be read.
    """

    def __init__(self, root, image_size=512):
        self.paths = sorted(glob.glob(os.path.join(root, "**/*.png"), recursive=True))
        self.image_size = image_size

        if len(self.paths) == 0:
            raise RuntimeError(f"Sir, no images found in {root}")

    def __len__(self):
        return len(self.paths)

    @staticmethod
    def _fit_size(h, w, target):
        """Return (new_h, new_w) fitting (h, w) inside a target square with the aspect ratio kept."""
        scale = min(target / h, target / w)
        # round() avoids int() truncation; max(1, ...) keeps thin images from
        # collapsing to a 0-pixel side (rejected by cv2.resize).
        new_h = min(target, max(1, round(h * scale)))
        new_w = min(target, max(1, round(w * scale)))
        return new_h, new_w

    def _resize_keep_aspect(self, img):
        """Fit ``img`` inside the image_size square, then reflect-pad to exactly that square."""
        h, w = img.shape
        target = self.image_size
        new_h, new_w = self._fit_size(h, w, target)

        if (new_h, new_w) != (h, w):
            # INTER_AREA suits shrinking; INTER_CUBIC suits enlarging.
            shrinking = new_h <= h and new_w <= w
            interp = cv2.INTER_AREA if shrinking else cv2.INTER_CUBIC
            img = cv2.resize(img, (new_w, new_h), interpolation=interp)

        top = (target - new_h) // 2
        left = (target - new_w) // 2
        return cv2.copyMakeBorder(
            img,
            top,
            target - new_h - top,
            left,
            target - new_w - left,
            borderType=cv2.BORDER_REFLECT_101,
        )

    def __getitem__(self, idx):
        path = self.paths[idx]

        img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)

        if img is None:
            raise RuntimeError(f"Sir, failed to load image: {path}")

        img = self._resize_keep_aspect(img)

        # uint8 [0, 255] -> float [-1, 1], with a leading channel axis.
        img = torch.from_numpy(img).float().contiguous()
        img = img / 255.0
        img = img.unsqueeze(0) * 2 - 1

        return {"image": img}

# LOCAL PATH SETUP

BASE_DIR = os.getcwd()

# Dataset folder (with train/val/test sub-folders) next to the notebook.
DATA_DIR = os.path.join(BASE_DIR, "dataset")

TRAIN_DIR = os.path.join(DATA_DIR, "train")
VAL_DIR   = os.path.join(DATA_DIR, "val")
TEST_DIR  = os.path.join(DATA_DIR, "test")

print("Sir, dataset folder:", DATA_DIR)


# LOAD DATASETS

train_ds = OCTDataset(TRAIN_DIR)
val_ds   = OCTDataset(VAL_DIR)
test_ds  = OCTDataset(TEST_DIR)

# Page-locked host memory only speeds up host->GPU copies. Without CUDA,
# pin_memory=True just produces a warning, so enable it only with a GPU.
PIN_MEMORY = torch.cuda.is_available()

# num_workers=0 loads batches in the main process.
train_loader = DataLoader(train_ds, batch_size=4, shuffle=True,
                          num_workers=0, pin_memory=PIN_MEMORY)

val_loader   = DataLoader(val_ds, batch_size=4,
                          num_workers=0, pin_memory=PIN_MEMORY)

test_loader  = DataLoader(test_ds, batch_size=4,
                          num_workers=0, pin_memory=PIN_MEMORY)

print("Sir, loaders ready ✓")
print("Train images:", len(train_ds))
print("Val images:", len(val_ds))
print("Test images:", len(test_ds))


# **VAE Training and Evaluation Pipeline for OCT Images**

In [None]:
import os
import torch
import torch.nn.functional as F
from tqdm import tqdm
import matplotlib.pyplot as plt
from torchvision.utils import make_grid, save_image

# Paths are resolved relative to the notebook's working directory.
BASE_DIR = os.getcwd()

# Checkpoints are written to vae_outputs/, per-epoch preview grids to
# vae_outputs/previews/.
OUT_DIR = os.path.join(BASE_DIR, "vae_outputs")
PREVIEW_DIR = os.path.join(OUT_DIR, "previews")

os.makedirs(OUT_DIR, exist_ok=True)
os.makedirs(PREVIEW_DIR, exist_ok=True)

print(f"Sir, outputs will be saved in: {OUT_DIR}")

# Prefer the GPU when PyTorch can see one.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# LOSSES

# KL divergence term of the VAE objective.
def kl_loss(mean, logvar):
    """
    KL divergence between N(mean, exp(logvar)) and the standard normal N(0, I).

    The element-wise term is summed over dims 1-3 (presumably C, H, W of a
    (B, C, H, W) latent) and then averaged over the batch dimension.
    """
    per_element = 1 + logvar - mean.pow(2) - logvar.exp()
    per_sample = per_element.sum(dim=[1, 2, 3])
    return -0.5 * per_sample.mean()

# Reconstruction term of the VAE objective.
def recon_loss(pred, target):
    """Mean absolute error (L1) between ``pred`` and ``target``, averaged over all elements."""
    return F.l1_loss(pred, target, reduction="mean")

# VISUALIZATION

# Plot a batch of images stored in [-1, 1] as an 8-column grid.
def show_grid(tensor, title=""):
    """Map ``tensor`` from [-1, 1] to [0, 1], tile it 8 per row and show it with matplotlib."""
    images = ((tensor + 1) / 2).clamp(0, 1)

    grid = make_grid(images, nrow=8).cpu().permute(1, 2, 0).numpy()

    plt.figure(figsize=(8, 8))
    plt.imshow(grid.squeeze(), cmap="gray")
    plt.title(title)
    plt.axis("off")
    plt.show()

# STOCHASTIC CHECK

@torch.no_grad()
def check_latent_stochasticity(vae, loader):
    """
    Print how much two reparameterized samples of the same batch differ.

    The first batch from ``loader`` is encoded once. Two latents are then
    drawn with ``vae.reparameterize`` and their mean absolute difference is
    printed. A value close to 0 means the posterior variance is tiny, so
    sampling barely adds noise.
    """
    vae.eval()

    images = next(iter(loader))["image"].to(device)
    mean, logvar = vae.encode(images)

    first = vae.reparameterize(mean, logvar)
    second = vae.reparameterize(mean, logvar)

    gap = (first - second).abs().mean()
    print("Sir, latent stochastic difference:", gap.item())

# TRAIN FUNCTION

def _atomic_torch_save(obj, path):
    """
    Save ``obj`` to ``path`` with ``torch.save`` via a temp file and ``os.replace``.

    An interrupted save then cannot leave a truncated file in place of the
    previous good checkpoint.
    """
    tmp_path = f"{path}.tmp"
    torch.save(obj, tmp_path)
    os.replace(tmp_path, path)


@torch.no_grad()
def _val_recon_l1(vae, loader):
    """Mean per-batch L1 of reconstructions decoded from the posterior mean (no sampling)."""
    vae.eval()
    total = 0.0
    for batch in loader:
        img = batch["image"].to(device)
        mean, _ = vae.encode(img)
        total += recon_loss(vae.decode(mean), img).item()
    return total / len(loader)


def train_vae(vae, train_loader, val_loader,
              epochs=70, lr=1e-4, beta_max=0.01, kl_warmup_epochs=100):
    """
    Train ``vae`` with an L1 reconstruction loss plus a warmed-up KL term.

    The per-step loss is ``10 * L1(recon, x) + beta * KL(mean, logvar)``,
    where ``beta = beta_max * min(1, epoch / kl_warmup_epochs)``. With the
    defaults (70 epochs, 100 warm-up epochs) beta only reaches 0.7 * beta_max.
    Mixed precision and gradient scaling are enabled on CUDA, and gradients
    are clipped to norm 1.

    After every epoch:
      * validation L1 is computed from the posterior mean (no sampling);
      * a preview grid (inputs over reconstructions) is saved to PREVIEW_DIR
        and displayed;
      * the weights are saved to ``OUT_DIR/vae_best.pth`` when validation L1
        improves;
      * a resume checkpoint (model, optimizer, scaler, epoch, best_val) is
        written to ``OUT_DIR/resume_checkpoint.pt``. It is picked up
        automatically on the next call.

    Args:
        vae: model with ``forward -> (recon, mean, logvar)``, ``encode`` and
            ``decode``.
        train_loader, val_loader: loaders yielding dicts with an "image" key.
        epochs: last (1-based) epoch to train.
        lr: AdamW learning rate.
        beta_max: KL weight once warm-up is complete.
        kl_warmup_epochs: epochs over which beta ramps linearly from 0
            (<= 0 disables warm-up).
    """
    vae = vae.to(device)
    use_amp = device.type == "cuda"

    ckpt_file = os.path.join(OUT_DIR, "resume_checkpoint.pt")
    best_model = os.path.join(OUT_DIR, "vae_best.pth")

    optimizer = torch.optim.AdamW(vae.parameters(), lr=lr)
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

    start_epoch = 1
    best_val = float("inf")

    # Resume training if a checkpoint exists.
    if os.path.exists(ckpt_file):
        ckpt = torch.load(ckpt_file, map_location=device)
        vae.load_state_dict(ckpt["model"])
        optimizer.load_state_dict(ckpt["optimizer"])
        # A scaler saved while disabled (e.g. from a CPU run) has an empty
        # state dict, and loading that into an enabled scaler raises.
        if ckpt.get("scaler"):
            scaler.load_state_dict(ckpt["scaler"])
        start_epoch = ckpt["epoch"] + 1
        best_val = ckpt["best_val"]
        print("Sir, resuming from epoch", start_epoch)

    # Fixed validation images reused for every epoch's preview.
    preview = next(iter(val_loader))["image"]
    preview = preview[:min(8, preview.size(0))].to(device)

    # TRAIN LOOP

    for epoch in range(start_epoch, epochs + 1):

        if kl_warmup_epochs > 0:
            warmup = min(1.0, epoch / kl_warmup_epochs)
        else:
            warmup = 1.0
        beta = beta_max * warmup

        vae.train()
        total_recon = 0

        pbar = tqdm(train_loader,
                    desc=f"[TRAIN] Epoch {epoch}/{epochs}")

        for batch in pbar:
            img = batch["image"].to(device)

            optimizer.zero_grad()

            with torch.cuda.amp.autocast(enabled=use_amp):
                recon, mean, logvar = vae(img)
                loss_r = recon_loss(recon, img)
                loss_k = kl_loss(mean, logvar)
                loss = 10 * loss_r + beta * loss_k

            scaler.scale(loss).backward()
            # Unscale first so clipping acts on the true gradient norm.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(vae.parameters(), 1.0)
            scaler.step(optimizer)
            scaler.update()

            total_recon += loss_r.item()

            pbar.set_postfix(L1=f"{loss_r.item():.4f}",
                             KL=f"{loss_k.item():.4f}")

        total_recon /= len(train_loader)

        # VALIDATION (leaves vae in eval mode for the preview below)

        val_recon = _val_recon_l1(vae, val_loader)

        print(f"\nSir, Epoch {epoch}")
        print(f"Train L1: {total_recon:.4f} | Val L1: {val_recon:.4f}")

        # PREVIEW SAVE

        with torch.no_grad():
            mean, logvar = vae.encode(preview)
            recon = vae.decode(mean)

        vis = torch.cat([preview, recon], dim=0)

        save_path = os.path.join(PREVIEW_DIR,
                                 f"epoch_{epoch}.png")
        save_image((vis + 1) / 2, save_path, nrow=8)

        show_grid(vis, title=f"Epoch {epoch}")

        # SAVE MODEL

        if val_recon < best_val:
            best_val = val_recon
            _atomic_torch_save(vae.state_dict(), best_model)
            print("Sir, new best model saved")

        _atomic_torch_save({
            "epoch": epoch,
            "model": vae.state_dict(),
            "optimizer": optimizer.state_dict(),
            "scaler": scaler.state_dict(),
            "best_val": best_val
        }, ckpt_file)

        print("Sir, checkpoint saved")

    print("\nSir, training complete")

# TEST FUNCTION

@torch.no_grad()
def test_vae(vae, test_loader):
    """
    Load the best weights into ``vae`` and show test reconstructions.

    The weights come from ``OUT_DIR/vae_best.pth``. Up to 8 images from the
    first test batch are shown above their reconstructions, which are
    decoded from the posterior mean.
    """
    weights = torch.load(os.path.join(OUT_DIR, "vae_best.pth"),
                         map_location=device)
    vae.load_state_dict(weights)
    vae.eval()

    originals = next(iter(test_loader))["image"][:8].to(device)
    posterior_mean = vae.encode(originals)[0]
    reconstructions = vae.decode(posterior_mean)

    show_grid(torch.cat([originals, reconstructions], dim=0),
              title="Test Reconstructions")

# RUN: train (or resume), check that latent sampling is stochastic, then show
# test-set reconstructions from the best checkpoint.

vae = VAE(z_channels=8).to(device)
train_vae(vae, train_loader, val_loader, epochs=70, lr=1e-4)
check_latent_stochasticity(vae, train_loader)
test_vae(vae, test_loader)
