# VAE Image Enhancement (Degradation -> Enhance)

This notebook trains a small convolutional VAE to 
 images: the model takes a degraded (blurred / downscaled-upscaled) image and reconstructs a higher-quality version at the same resolution.

Key parts:
- Degradation pipeline: random Gaussian blur + downscale/upscale to mimic low quality.
- Dataset returns (low_quality_input, high_quality_target) normalized to [-1, 1].
- Simple Conv VAE (encoder -> mu/logvar -> reparam -> decoder).
- Loss = L1(recon, target) + beta * KL. For enhancement, beta is small so reconstruction dominates.

Edit `DATA_DIR`, `OUT_DIR` and training params below to your needs. Run the cells in order.

In [12]:
# Cell 1: imports and paths
import os
from pathlib import Path
import random
import math
import time
import numpy as np
import cv2
from PIL import Image
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
from torchvision.utils import save_image, make_grid

# Paths (edit these)
ROOT = Path('.')
DATA_DIR = ROOT / '../data' / 'Face-Swap-M2-Dataset' / 'dataset' / 'smaller'  # adjust to your dataset
OUT_DIR = ROOT / 'outputs' / 'vae_enhance'
OUT_DIR.mkdir(parents=True, exist_ok=True)
CHECKPOINT_DIR = OUT_DIR / 'checkpoints'
CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)
SAMPLES_DIR = OUT_DIR / 'samples'
SAMPLES_DIR.mkdir(parents=True, exist_ok=True)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
import dataloader
print('Device:', device)
print('DATA_DIR:', DATA_DIR)

Device: cuda
DATA_DIR: ../data/Face-Swap-M2-Dataset/dataset/smaller


In [13]:
# Cell 2: dataset + degradation utilities
class DegradeTransforms:
    def __init__(self, min_scale=0.5, max_scale=0.9, blur_prob=0.5, max_blur_ks=7):
        self.min_scale = min_scale
        self.max_scale = max_scale
        self.blur_prob = blur_prob
        self.max_blur_ks = max_blur_ks

    def random_down_up(self, img, scale):
        # img: HxWxC uint8 RGB
        h, w = img.shape[:2]
        new_h = max(2, int(h * scale))
        new_w = max(2, int(w * scale))
        # downscale then upscale
        small = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_AREA)
        up = cv2.resize(small, (w, h), interpolation=cv2.INTER_LINEAR)
        return up

    def random_blur(self, img):
        # random odd kernel size
        k = random.randrange(1, self.max_blur_ks, 2) if random.random() < self.blur_prob else 1
        if k > 1:
            return cv2.GaussianBlur(img, (k, k), 0)
        return img

    def degrade(self, img):
        # img: HxWxC uint8 RGB
        scale = random.uniform(self.min_scale, self.max_scale)
        out = self.random_down_up(img, scale)
        out = self.random_blur(out)
        return out

class VaeImageDataset(Dataset):
    def __init__(self, root_dir, size=256, transform=None, degrade=None, extensions=['.jpg', '.png', '.jpeg']):
        self.root_dir = Path(root_dir)
        self.files = [p for p in self.root_dir.rglob('*') if p.suffix.lower() in extensions]
        self.size = size
        self.transform = transform
        self.degrade = degrade
        if len(self.files) == 0:
            raise RuntimeError(f'No images found in {root_dir}')

    def __len__(self):
        return len(self.files)

    def load_image(self, p):
        im = cv2.imread(str(p))
        if im is None:
            raise RuntimeError(f'Could not read {p}')
        im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
        im = cv2.resize(im, (self.size, self.size), interpolation=cv2.INTER_AREA)
        return im

    def to_tensor(self, img):
        # img: HxWxC uint8 RGB -> Tensor in [-1,1] CxHxW
        t = torch.from_numpy(img.astype('float32') / 255.0).permute(2,0,1)
        t = (t - 0.5) / 0.5
        return t

    def __getitem__(self, idx):
        p = self.files[idx]
        tgt = self.load_image(p)
        inp_img = tgt.copy()
        if self.degrade is not None:
            inp_img = self.degrade.degrade(inp_img)
        if self.transform is not None:
            # allow extra transforms if provided (PIL or numpy aware)
            inp_img_pil = Image.fromarray(inp_img)
            tgt_pil = Image.fromarray(tgt)
            inp_t = self.transform(inp_img_pil)
            tgt_t = self.transform(tgt_pil)
            # transform expected to keep range in [0,1] or [-1,1] depending on implementation
            return inp_t, tgt_t
        inp_t = self.to_tensor(inp_img)
        tgt_t = self.to_tensor(tgt)
        return inp_t, tgt_t

# quick test: create dataset instance (not executed until user runs cell)
degrader = DegradeTransforms(min_scale=0.25, max_scale=0.5, blur_prob=0.0, max_blur_ks=7)
print('Degrader ready')

Degrader ready


In [14]:
# Cell 3: VAE model definition
class ConvEncoder(nn.Module):
    def __init__(self, in_ch=3, base_ch=64, latent_dim=256):
        super().__init__()
        self.conv1 = nn.Conv2d(in_ch, base_ch, 4, 2, 1)
        self.conv2 = nn.Conv2d(base_ch, base_ch*2, 4, 2, 1)
        self.conv3 = nn.Conv2d(base_ch*2, base_ch*4, 4, 2, 1)
        self.conv4 = nn.Conv2d(base_ch*4, base_ch*8, 4, 2, 1)
        self.act = nn.LeakyReLU(0.2, inplace=True)
        # pool to 16x16 so decoder with 4 upsampling layers returns to 256x256
        self.adaptive_pool = nn.AdaptiveAvgPool2d((16,16))
        self.flatten = nn.Flatten()
        self.fc_mu = nn.Linear(base_ch*8*16*16, latent_dim)
        self.fc_logvar = nn.Linear(base_ch*8*16*16, latent_dim)

    def forward(self, x):
        x = self.act(self.conv1(x))
        x = self.act(self.conv2(x))
        x = self.act(self.conv3(x))
        x = self.act(self.conv4(x))
        x = self.adaptive_pool(x)
        x = self.flatten(x)
        mu = self.fc_mu(x)
        logvar = self.fc_logvar(x)
        return mu, logvar

class ConvDecoder(nn.Module):
    def __init__(self, out_ch=3, base_ch=64, latent_dim=256):
        super().__init__()
        # match encoder's pooled feature size
        self.fc = nn.Linear(latent_dim, base_ch*8*16*16)
        self.deconv1 = nn.ConvTranspose2d(base_ch*8, base_ch*4, 4, 2, 1)
        self.deconv2 = nn.ConvTranspose2d(base_ch*4, base_ch*2, 4, 2, 1)
        self.deconv3 = nn.ConvTranspose2d(base_ch*2, base_ch, 4, 2, 1)
        self.deconv4 = nn.ConvTranspose2d(base_ch, out_ch, 4, 2, 1)
        self.act = nn.LeakyReLU(0.2, inplace=True)
        self.tanh = nn.Tanh()

    def forward(self, z):
        x = self.fc(z)
        # reshape to (B, base_ch*8, 16, 16)
        x = x.view(x.size(0), -1, 16, 16)
        x = self.act(self.deconv1(x))   # -> 32x32
        x = self.act(self.deconv2(x))   # -> 64x64
        x = self.act(self.deconv3(x))   # -> 128x128
        x = self.tanh(self.deconv4(x))  # -> 256x256
        return x

class SimpleVAE(nn.Module):
    def __init__(self, in_ch=3, base_ch=64, latent_dim=256):
        super().__init__()
        self.encoder = ConvEncoder(in_ch, base_ch, latent_dim)
        self.decoder = ConvDecoder(in_ch, base_ch, latent_dim)

    def reparameterize(self, mu, logvar):
        std = (0.5 * logvar).exp()  # logvar/2 -> std (approx)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x):
        mu, logvar = self.encoder(x)
        z = self.reparameterize(mu, logvar)
        recon = self.decoder(z)
        return recon, mu, logvar

print('Model classes defined')

Model classes defined


In [15]:
# Cell 3b: GAN models and training loop (conditional GAN: generator + PatchGAN discriminator)
class Generator(nn.Module):
    def __init__(self, in_ch=3, out_ch=3, base_ch=64):
        super().__init__()
        # encoder
        self.e1 = nn.Conv2d(in_ch, base_ch, 4, 2, 1)       # 128
        self.e2 = nn.Conv2d(base_ch, base_ch*2, 4, 2, 1)   # 64
        self.e3 = nn.Conv2d(base_ch*2, base_ch*4, 4, 2, 1) # 32
        self.e4 = nn.Conv2d(base_ch*4, base_ch*8, 4, 2, 1) # 16
        self.act = nn.LeakyReLU(0.2, inplace=True)
        # decoder
        self.d1 = nn.ConvTranspose2d(base_ch*8, base_ch*4, 4, 2, 1)   # 32
        self.d2 = nn.ConvTranspose2d(base_ch*8, base_ch*2, 4, 2, 1)   # 64 (concat)
        self.d3 = nn.ConvTranspose2d(base_ch*4, base_ch, 4, 2, 1)     # 128
        self.d4 = nn.ConvTranspose2d(base_ch*2, out_ch, 4, 2, 1)      # 256
        self.tanh = nn.Tanh()

    def forward(self, x):
        e1 = self.act(self.e1(x))
        e2 = self.act(self.e2(e1))
        e3 = self.act(self.e3(e2))
        e4 = self.act(self.e4(e3))
        d1 = self.act(self.d1(e4))
        d1_cat = torch.cat([d1, e3], dim=1)
        d2 = self.act(self.d2(d1_cat))
        d2_cat = torch.cat([d2, e2], dim=1)
        d3 = self.act(self.d3(d2_cat))
        d3_cat = torch.cat([d3, e1], dim=1)
        out = self.tanh(self.d4(d3_cat))
        return out

class PatchDiscriminator(nn.Module):
    def __init__(self, in_ch=6, base_ch=64):
        super().__init__()
        # input is concat(input, target) -> 6 channels
        self.conv1 = nn.Conv2d(in_ch, base_ch, 4, 2, 1)   # 128
        self.conv2 = nn.Conv2d(base_ch, base_ch*2, 4, 2, 1) #64
        self.conv3 = nn.Conv2d(base_ch*2, base_ch*4, 4, 2, 1) #32
        self.conv4 = nn.Conv2d(base_ch*4, base_ch*8, 4, 1, 1) #31->30-ish
        self.conv5 = nn.Conv2d(base_ch*8, 1, 4, 1, 1)
        self.act = nn.LeakyReLU(0.2, inplace=True)

    def forward(self, x):
        x = self.act(self.conv1(x))
        x = self.act(self.conv2(x))
        x = self.act(self.conv3(x))
        x = self.act(self.conv4(x))
        x = self.conv5(x)
        return x


def train_gan(dataset_root=None, out_dir=OUT_DIR, epochs=100, batch_size=16, lr=2e-4, base_ch=64,
              adv_weight=1.0, lambda_l1=100.0, num_workers=4, save_every=1, 
              train_loader=None, use_make_dataset=True, nb_images=10, trainsplit=0.8):
    """
    Simple conditional GAN training loop. Generator maps degraded input -> enhanced output.
    Discriminator is a PatchGAN that sees (input, real) or (input, fake).
    
    Args:
        train_loader: If provided, use this DataLoader instead of creating one. Expected to return (target, label) batches.
        dataset_root: Path to dataset (only used if train_loader is None).
        use_make_dataset: If True and train_loader is None, use dataloader.make_dataset; else VaeImageDataset.
    """
    if train_loader is not None:
        loader = train_loader
        use_external_ds = True
    elif use_make_dataset:
        train_ds, _, _ = dataloader.make_dataset(dataset_root, nb_images, image_size=256, trainsplit=trainsplit, crop_faces=False)
        loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=True)
        use_external_ds = True
    else:
        ds = VaeImageDataset(dataset_root, size=256, transform=None, degrade=degrader)
        loader = DataLoader(ds, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=True)
        use_external_ds = False

    G = Generator(in_ch=3, out_ch=3, base_ch=base_ch).to(device)
    D = PatchDiscriminator(in_ch=6, base_ch=base_ch).to(device)
    opt_G = torch.optim.Adam(G.parameters(), lr=lr, betas=(0.5, 0.999))
    opt_D = torch.optim.Adam(D.parameters(), lr=lr, betas=(0.5, 0.999))

    bce = nn.BCEWithLogitsLoss()
    l1 = nn.L1Loss()

    iters = 0
    for epoch in range(1, epochs+1):
        G.train(); D.train()
        t0 = time.time()
        epoch_loss_G = 0.0
        epoch_loss_D = 0.0
        for batch in loader:
            if use_external_ds:
                tgt, _ = batch
                tgt = tgt.to(device)
                # convert to uint8 and degrade
                tgt_uint8 = ((tgt + 1.0) * 127.5).clamp(0,255).permute(0,2,3,1).detach().cpu().numpy().astype('uint8')
                inp_imgs = [degrader.degrade(img) for img in tgt_uint8]
                inp = torch.stack([((torch.from_numpy(img.astype('float32')/255.0).permute(2,0,1)-0.5)/0.5) for img in inp_imgs]).to(device)
            else:
                inp, tgt = batch
                inp = inp.to(device); tgt = tgt.to(device)

            # ------------------ update D ------------------
            with torch.no_grad():
                fake = G(inp)
            real_pair = torch.cat([inp, tgt], dim=1)
            fake_pair = torch.cat([inp, fake], dim=1)

            opt_D.zero_grad()
            pred_real = D(real_pair)
            pred_fake = D(fake_pair)
            real_labels = torch.ones_like(pred_real, device=device)
            fake_labels = torch.zeros_like(pred_fake, device=device)
            loss_D_real = bce(pred_real, real_labels)
            loss_D_fake = bce(pred_fake, fake_labels)
            loss_D = (loss_D_real + loss_D_fake) * 0.5
            loss_D.backward()
            opt_D.step()

            # ------------------ update G ------------------
            opt_G.zero_grad()
            fake = G(inp)
            fake_pair = torch.cat([inp, fake], dim=1)
            pred_fake_for_G = D(fake_pair)
            adv_loss = bce(pred_fake_for_G, real_labels)
            recon_loss = l1(fake, tgt)
            loss_G = adv_weight * adv_loss + lambda_l1 * recon_loss
            loss_G.backward()
            opt_G.step()

            epoch_loss_G += loss_G.item()
            epoch_loss_D += loss_D.item()
            iters += 1

            if iters % 200 == 0:
                # save sample
                sample_and_save(G, inp.detach().cpu(), tgt.detach().cpu(), iters, SAMPLES_DIR)

        dt = time.time() - t0
        print(f'Epoch {epoch}/{epochs} G_loss={epoch_loss_G/len(loader):.6f} D_loss={epoch_loss_D/len(loader):.6f} time={dt:.1f}s')
        if epoch % save_every == 0:
            torch.save({'epoch': epoch, 'G': G.state_dict(), 'D': D.state_dict(), 'optG': opt_G.state_dict(), 'optD': opt_D.state_dict()}, CHECKPOINT_DIR / f'gan_epoch_{epoch}.pth')

    return G, D


In [16]:
# Cell 4: training utilities and loop
def loss_function(recon, target, mu, logvar, l1_weight=1.0, beta=1e-4):
    # recon, target in [-1,1]
    recon_loss = F.l1_loss(recon, target) * l1_weight
    # KL per batch (sum over latent dims, mean over batch)
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1)
    kl = kl.mean()
    return recon_loss + beta * kl, recon_loss.item(), kl.item()

def save_checkpoint(model, optimizer, epoch, path):
    torch.save({'epoch': epoch, 'model_state': model.state_dict(), 'optim_state': optimizer.state_dict()}, str(path))

def sample_and_save(model, inp, tgt, step, out_dir):
    model.eval()
    with torch.no_grad():
        recon, _, _ = model(inp.to(device))
    # recon, inp, tgt ranges: model uses [-1,1] -> save_image will map to [0,1] if we transform back
    # Expect inp and tgt to be CPU tensors in [-1,1]
    grid = make_grid(torch.cat([inp.cpu(), recon.cpu(), tgt.cpu()], dim=0), nrow=inp.size(0))
    save_image((grid + 1) / 2.0, out_dir / f'sample_step_{step}.png')
    model.train()


def train_vae(dataset_root=None, out_dir=OUT_DIR, epochs=10, batch_size=16, lr=1e-4, latent_dim=512,
              l1_weight=1.0, beta=1e-4, num_workers=4, save_every=1, 
              train_loader=None, use_make_dataset=True, nb_images=10, trainsplit=0.8):
    """
    Train VAE. If train_loader is provided, use it; otherwise build dataset from dataset_root.
    
    Args:
        train_loader: If provided, use this DataLoader instead of creating one. Expected to return (target, label) batches.
        dataset_root: Path to dataset (only used if train_loader is None).
        use_make_dataset: If True and train_loader is None, use dataloader.make_dataset; else VaeImageDataset.
    """
    if train_loader is not None:
        loader = train_loader
        use_external_ds = True
    elif use_make_dataset:
        # use the project's dataloader to build train/test datasets. We set crop_faces=False so tensors are raw
        # resized images normalized to [-1,1] by the dataloader; we will convert them back to uint8 to apply our degrader
        train_ds, test_ds, _ = dataloader.make_dataset(dataset_root, nb_images, image_size=256, trainsplit=trainsplit, crop_faces=False)
        loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=True)
        use_external_ds = True
    else:
        ds = VaeImageDataset(dataset_root, size=256, transform=None, degrade=degrader)
        loader = DataLoader(ds, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=True)
        use_external_ds = False

    model = SimpleVAE(in_ch=3, base_ch=64, latent_dim=latent_dim).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    n_iter = 0
    for epoch in range(1, epochs+1):
        epoch_loss = 0.0
        recon_l_sum = 0.0
        kl_sum = 0.0
        start = time.time()
        for batch_idx, batch in enumerate(loader, 1):
            if use_external_ds:
                # dataloader.make_dataset -> TensorDataset of (faces, labels)
                tgt, _labels = batch
                tgt = tgt.to(device)
                # convert target tensors (assumed in [-1,1]) back to uint8 HWC for degradation
                tgt_uint8 = ((tgt + 1.0) * 127.5).clamp(0, 255).permute(0, 2, 3, 1).detach().cpu().numpy().astype('uint8')
                # apply degrader per-sample
                inp_imgs = [degrader.degrade(img) for img in tgt_uint8]
                # convert degraded inputs back to tensors in [-1,1]
                inp = torch.stack([((torch.from_numpy(img.astype('float32') / 255.0).permute(2, 0, 1) - 0.5) / 0.5) for img in inp_imgs]).to(device)
            else:
                inp, tgt = batch
                inp = inp.to(device)
                tgt = tgt.to(device)

            optimizer.zero_grad()
            recon, mu, logvar = model(inp)
            loss, recon_l, kl_l = loss_function(recon, tgt, mu, logvar, l1_weight=l1_weight, beta=beta)
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()
            recon_l_sum += recon_l
            kl_sum += kl_l
            n_iter += 1
            if n_iter % 200 == 0:
                # save a sample (use a small minibatch from current batch)
                # prepare CPU copies of inp/tgt consistent with sample_and_save expectations
                sample_inp = inp.detach().cpu()
                sample_tgt = tgt.detach().cpu()
                sample_and_save(model, sample_inp, sample_tgt, n_iter, SAMPLES_DIR)

        elapsed = time.time() - start
        print(f'Epoch {epoch}/{epochs} avg_loss={epoch_loss/len(loader):.6f} recon={recon_l_sum/len(loader):.6f} kl={kl_sum/len(loader):.6f} time={elapsed:.1f}s')

        if epoch % save_every == 0:
            ckpt_path = CHECKPOINT_DIR / f'vae_epoch_{epoch}.pth'
            save_checkpoint(model, optimizer, epoch, ckpt_path)
            print('Saved checkpoint to', ckpt_path)

    return model

print('Training utilities defined')

Training utilities defined


In [17]:
# Cell 6: quick smoke-run / example usage (adjust params and run)
# NOTE: this cell will run training; set epochs small for a quick smoke test.
# Make sure you've run Cell 7 first to create train_loader and test_loader!

EPOCHS = 1000
BATCH_SIZE = 8
LR = 2e-4
LATENT_DIM = 256

# Use the pre-created train_loader from Cell 7 (much cleaner!)
if 'train_loader' not in dir():
    print('ERROR: train_loader not found. Please run Cell 7 first to load the dataset!')
else:
    # run GAN training using the pre-loaded train_loader
    G, D = train_gan(train_loader=train_loader, out_dir=OUT_DIR, epochs=EPOCHS, lr=LR, 
                     base_ch=64, adv_weight=1.0, lambda_l1=100.0, save_every=200)
    print('GAN training finished (smoke run)')

Epoch 1/1000 G_loss=45.015999 D_loss=0.667705 time=0.5s
Epoch 2/1000 G_loss=39.550528 D_loss=0.482376 time=0.5s
Epoch 2/1000 G_loss=39.550528 D_loss=0.482376 time=0.5s
Epoch 3/1000 G_loss=23.253866 D_loss=0.842958 time=0.5s
Epoch 3/1000 G_loss=23.253866 D_loss=0.842958 time=0.5s
Epoch 4/1000 G_loss=15.226611 D_loss=0.709538 time=0.5s
Epoch 4/1000 G_loss=15.226611 D_loss=0.709538 time=0.5s
Epoch 5/1000 G_loss=13.547769 D_loss=0.686390 time=0.5s
Epoch 5/1000 G_loss=13.547769 D_loss=0.686390 time=0.5s
Epoch 6/1000 G_loss=11.229188 D_loss=0.694160 time=0.5s
Epoch 6/1000 G_loss=11.229188 D_loss=0.694160 time=0.5s
Epoch 7/1000 G_loss=10.704336 D_loss=0.681595 time=0.5s
Epoch 7/1000 G_loss=10.704336 D_loss=0.681595 time=0.5s
Epoch 8/1000 G_loss=9.936435 D_loss=0.684250 time=0.5s
Epoch 8/1000 G_loss=9.936435 D_loss=0.684250 time=0.5s
Epoch 9/1000 G_loss=9.417183 D_loss=0.691413 time=0.5s
Epoch 9/1000 G_loss=9.417183 D_loss=0.691413 time=0.5s
Epoch 10/1000 G_loss=10.007603 D_loss=0.674935 time=

KeyboardInterrupt: 

In [18]:
# Cell 7: Load datasets using project dataloader
IMAGE_SIZE = 256
NB_IMAGES = 10
BATCH_SIZE = 32

train_dataset, test_dataset, nb_classes = dataloader.make_dataset("../data/Face-Swap-M2-Dataset/dataset/smaller", NB_IMAGES, IMAGE_SIZE, 0.8, crop_faces=False)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=True)

print('Loaded dataloaders: train_loader, test_loader')

Building dataset 

Loaded dataloaders: train_loader, test_loader
Loaded dataloaders: train_loader, test_loader


In [None]:
# Cell 8: Plot GAN results
import matplotlib.pyplot as plt

def plot_gan_results(generator, loader, num_samples=5):
    generator.eval()
    inp, tgt = next(iter(loader))
    inp, tgt = inp.to(device), tgt.to(device)
    
    # Ensure num_samples does not exceed the batch size
    num_samples = min(num_samples, inp.size(0))
    
    with torch.no_grad():
        fake = generator(inp)
    
    # tensors
    fig, axes = plt.subplots(num_samples, 3, figsize=(12, num_samples * 4))
    for i in range(num_samples):
        axes[i, 0].imshow(inp[i])
        axes[i, 0].set_title("Input (Degraded)")
        axes[i, 0].axis("off")
        
        axes[i, 1].imshow(fake[i])
        axes[i, 1].set_title("Generated (Enhanced)")
        axes[i, 1].axis("off")
        
        axes[i, 2].imshow(tgt[i])
        axes[i, 2].set_title("Target (Original)")
        axes[i, 2].axis("off")
    
    plt.tight_layout()
    plt.show()

# Example usage (run after training completes)
plot_gan_results(G, test_loader)

RuntimeError: permute(sparse_coo): number of dimensions in the tensor input does not match the length of the desired ordering of dimensions i.e. input.dim() = 1 is not equal to len(dims) = 4

## Notes and next steps
- For better perceptual quality try adding a perceptual loss (VGG features) or L1 in Lab space.
- You can increase `latent_dim`, `base_ch`, and number of epochs for higher capacity.
- Save intermediate sample images to `outputs/vae_enhance/samples` and checkpoints to `outputs/vae_enhance/checkpoints`.
- If training on many images, increase `num_workers` and `batch_size` and use an LR schedule.
- To evaluate perceptual similarity, compute LPIPS or use a pre-trained face embedding model (you already have FaceNet code in `main.ipynb`).