# Generative Vision Modeling - How to Actually Train Your Variational Auto Encoder
Oğuzhan Ercan - x.com/oguzhannercan

In this chapter, we will study variational auto encoders. Our motivation for investigating them is the need to go from an image to a latent space. The "latent space" will be discussed in more detail in the DDPM and DDIM sections. In this notebook, we will start by building an Auto Encoder, then build a Variational Auto Encoder, and finally introduce vector quantization to these models.

### What is an Auto Encoder

_A standard Autoencoder (AE) is a neural network architecture composed of two main parts: an encoder and a decoder. The encoder takes an input and compresses it into a lower-dimensional representation called the latent space or code. This compressed representation is then fed into the decoder, which attempts to reconstruct the original input from it. The primary goal of an AE is to learn efficient and useful representations of the input data by minimizing the reconstruction error between the input and the output. - Gemini 2.0_ 

![Auto Encoder Architecture](media/ae.png)

_Figure 1: Auto Encoder Architecture_
### What is a Variational Auto Encoder

_A Variational Autoencoder (VAE) is a generative model that learns a probabilistic distribution over the latent space of the input data. Unlike standard autoencoders that learn a fixed encoding, VAEs encode the input into parameters of a probability distribution, typically a Gaussian. This allows for generating new data points by sampling from this learned latent distribution and decoding it back to the input space. VAEs are particularly useful for tasks like generating realistic images and other complex data. - Gemini 2.0_

![Variational Auto Encoder Architecture](media/vae-gaussian.png)

_Figure 2: Variational Auto Encoder Architecture_

As seen in Figure 2, a VAE differs from an AE in how it produces the latent z. Instead of predicting z directly, the encoder predicts the mean and (log-)variance of a Gaussian distribution for each input, and z is sampled from that distribution. A KL divergence term pushes these per-input distributions toward a standard Gaussian prior $p(z) = \mathcal{N}(0, I)$. Below, we summarize the differences between AE and VAE, then implement both.
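The sampling step described above is usually implemented with the reparameterization trick: rather than sampling z directly (which would block gradients), we sample noise $\epsilon \sim \mathcal{N}(0, I)$ and compute $z = \mu + \sigma \epsilon$. The sketch below assumes PyTorch and an encoder that outputs `mu` and `logvar` tensors of the same shape; the helper name `sample_latent` is ours. It shows that gradients reach both encoder outputs.

```python
import torch

def sample_latent(mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
    """Draw z ~ N(mu, sigma^2) with the reparameterization trick (hypothetical helper)."""
    std = torch.exp(0.5 * logvar)   # sigma = exp(log(sigma^2) / 2)
    eps = torch.randn_like(std)     # noise from N(0, I); carries no gradient
    return mu + eps * std           # differentiable w.r.t. mu and logvar

# Usage: a batch of 4 latents with 256 dimensions each
mu = torch.zeros(4, 256, requires_grad=True)
logvar = torch.zeros(4, 256, requires_grad=True)
z = sample_latent(mu, logvar)
z.sum().backward()
print(z.shape)               # torch.Size([4, 256])
print(mu.grad is not None)   # True: gradients reach the encoder outputs
```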


| **Aspect**            | **Autoencoder (AE)**                                   | **Variational Autoencoder (VAE)**                       |
|-----------------------|-------------------------------------------------------|--------------------------------------------------------|
| **Objective**         | Minimize reconstruction error $\lVert x - \hat{x} \rVert^2$ | Minimize reconstruction error + KL divergence $D_{\text{KL}}(q(z \mid x) \,\Vert\, p(z))$ |
| **Latent Space**      | Deterministic: $z = f(x)$                         | Probabilistic: $z \sim \mathcal{N}(\mu, \sigma^2)$ |
| **Loss Function**     | $\mathcal{L} = \lVert x - \hat{x} \rVert^2$                | $\mathcal{L} = \lVert x - \hat{x} \rVert^2 - \frac{1}{2} \sum (1 + \log \sigma^2 - \mu^2 - e^{\log \sigma^2})$ |
| **Reconstruction Fidelity** | High                                           | Moderate, due to the KL regularization               |
| **Use Cases**         | Data compression, denoising, feature extraction      | Generative modeling, data synthesis, anomaly detection |
| **Generative Ability**| No                                                   | Yes, via sampling $z \sim p(z)$                    |
| **Complexity**        | Simpler optimization                                 | Increased complexity with KL term                     |
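The KL term in the VAE loss row is the closed-form KL divergence between a diagonal Gaussian $\mathcal{N}(\mu, \sigma^2)$ and the standard normal prior. As a quick numerical sanity check (a sketch assuming PyTorch; the helper `kl_closed_form` is ours), we compare the formula against `torch.distributions.kl_divergence`:

```python
import torch
from torch.distributions import Normal, kl_divergence

def kl_closed_form(mu, logvar):
    # -1/2 * sum(1 + log sigma^2 - mu^2 - sigma^2), summed over latent dimensions
    return -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=-1)

torch.manual_seed(0)
mu = torch.randn(8, 16)
logvar = torch.randn(8, 16)

ours = kl_closed_form(mu, logvar)
posterior = Normal(mu, torch.exp(0.5 * logvar))          # N(mu, sigma^2), per dimension
prior = Normal(torch.zeros_like(mu), torch.ones_like(mu))  # N(0, 1), per dimension
reference = kl_divergence(posterior, prior).sum(dim=-1)

print(torch.allclose(ours, reference, atol=1e-4))  # True
```

The KL is zero exactly when $\mu = 0$ and $\log \sigma^2 = 0$, i.e. when the posterior equals the prior.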

### Architectural Design Choices

In the following sections, we will briefly describe the encoder-decoder architectures.

For simplicity, we will first build an encoder-decoder architecture with basic convolution layers; after that, we will build stronger architectures that can capture finer features and details.

As seen in Figure 1, an auto encoder takes some data as input; in our case this will be an image, more specifically from the ImageNet dataset. The encoder applies a series of transformations, and after each one the channel count increases while the spatial size (height x width) decreases, which is typical of strided Conv2d layers with appropriate hyperparameters. So we need to decide on the shape of the encoder's output, which we call the latent vector. A typical latent for a basic autoencoder has shape Batch_Size x 512 x 16 x 16. Our encoder downsamples by a factor of 16 (four stride-2 convolutions), so it takes an image of shape 3 x 64 x 64 (channels x height x width) and converts it to Batch_Size x 512 x 4 x 4; the same encoder would produce 512 x 16 x 16 from a 256 x 256 image.
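The shape bookkeeping above follows from the output-size formulas of `nn.Conv2d` and `nn.ConvTranspose2d` (assuming dilation 1 and no output padding, as in the layers below). This small sketch, with helper names of our own, reproduces the 64 → 32 → 16 → 8 → 4 chain for kernel size 4, stride 2, padding 1, and the inverse chain used by the decoder:

```python
def conv_out(size: int, kernel: int = 4, stride: int = 2, padding: int = 1) -> int:
    # nn.Conv2d: floor((H + 2p - k) / s) + 1
    return (size + 2 * padding - kernel) // stride + 1

def conv_transpose_out(size: int, kernel: int = 4, stride: int = 2, padding: int = 1) -> int:
    # nn.ConvTranspose2d (dilation=1, output_padding=0): (H - 1) * s - 2p + k
    return (size - 1) * stride - 2 * padding + kernel

down = [64]
for _ in range(4):
    down.append(conv_out(down[-1]))
print(down)  # [64, 32, 16, 8, 4]

up = [4]
for _ in range(4):
    up.append(conv_transpose_out(up[-1]))
print(up)  # [4, 8, 16, 32, 64]
```

With k=4, s=2, p=1, each convolution exactly halves the spatial size and each transposed convolution exactly doubles it, which is why the encoder and decoder mirror each other.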




In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class Encoder(nn.Module):
    def __init__(self,):
        super(Encoder, self).__init__()
        
        # bx3x64x64 -> bx64x32x32 
        self.conv1 = nn.Conv2d(3, 64, 4, 2, 1)
        self.bn1 = nn.BatchNorm2d(64)
        # bx64x32x32 -> bx128x16x16
        self.conv2 = nn.Conv2d(64, 128, 4, 2, 1)
        self.bn2 = nn.BatchNorm2d(128)
        # bx128x16x16 -> bx256x8x8
        self.conv3 = nn.Conv2d(128, 256, 4, 2, 1)
        self.bn3 = nn.BatchNorm2d(256)
        # bx256x8x8 -> bx512x4x4
        self.conv4 = nn.Conv2d(256, 512, 4, 2, 1)
        self.bn4 = nn.BatchNorm2d(512)
        

    def forward(self, x):
        x = F.leaky_relu(self.bn1(self.conv1(x)), 0.2)
        x = F.leaky_relu(self.bn2(self.conv2(x)), 0.2)
        x = F.leaky_relu(self.bn3(self.conv3(x)), 0.2)
        x = F.leaky_relu(self.bn4(self.conv4(x)), 0.2)
        return x

In [2]:
tensor = torch.randn(1, 3, 64, 64)
encoder = Encoder()
print(encoder(tensor).shape)



In [3]:
class Decoder(nn.Module):
    def __init__(self,):
        super(Decoder, self).__init__()
        
        # bx512x4x4 -> bx256x8x8
        self.conv1 = nn.ConvTranspose2d(512, 256, 4, 2, 1)
        self.bn1 = nn.BatchNorm2d(256)
        # bx256x8x8 -> bx128x16x16
        self.conv2 = nn.ConvTranspose2d(256, 128, 4, 2, 1)
        self.bn2 = nn.BatchNorm2d(128)
        # bx128x16x16 -> bx64x32x32
        self.conv3 = nn.ConvTranspose2d(128, 64, 4, 2, 1)
        self.bn3 = nn.BatchNorm2d(64)
        # bx64x32x32 -> bx3x64x64
        self.conv4 = nn.ConvTranspose2d(64, 3, 4, 2, 1)
        

    def forward(self, x):
        x = F.leaky_relu(self.bn1(self.conv1(x)), 0.2)
        x = F.leaky_relu(self.bn2(self.conv2(x)), 0.2)
        x = F.leaky_relu(self.bn3(self.conv3(x)), 0.2)
        x = torch.tanh(self.conv4(x))
        return x

In [4]:
tensor = torch.randn(1, 512, 4, 4)
decoder = Decoder()
print(decoder(tensor).shape)



In [5]:
class AE(nn.Module):
    def __init__(self,):
        super(AE, self).__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()
        
    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x

In [6]:
tensor = torch.randn(1, 3, 64, 64)
ae = AE()
print(ae(tensor).shape)



Above, we built an auto encoder that first encodes an image of shape 3x64x64 into a latent of shape 512x4x4 (a choice that does not make much sense, as we will discuss later), then decodes it back to image space. Now we will train it on a few samples.
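One thing the shapes alone reveal is how little this latent compresses the input. Plain arithmetic on the two shapes used above:

```python
import math

image_shape = (3, 64, 64)    # C x H x W input image
latent_shape = (512, 4, 4)   # C x H x W latent produced by Encoder

image_values = math.prod(image_shape)
latent_values = math.prod(latent_shape)
print(image_values, latent_values)      # 12288 8192
print(image_values / latent_values)     # 1.5
```

The latent holds two thirds as many values as the image, a compression factor of only 1.5.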

In [7]:
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms
from torch.optim import AdamW
from torch.nn import L1Loss
def basic_train_config(model, num_train, num_val, batch_size):
    # Resize to 64x64 and normalize each channel to [-1, 1] to match the decoder's tanh output
    transform = transforms.Compose([
        transforms.Resize((64, 64)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    dataset = datasets.ImageFolder("/home/oguzhan/Downloads/ImageNetDataset/", transform=transform)
    # First num_train images for training, the next num_val images for validation
    subset = Subset(dataset, range(num_train))
    subset_val = Subset(dataset, range(num_train, num_train + num_val))
    dataloader = DataLoader(subset, batch_size=batch_size, shuffle=True)
    valloader = DataLoader(subset_val, batch_size=10, shuffle=False)
    optimizer = AdamW(model.parameters(), lr=0.0002)
    criterion = L1Loss()
    return dataloader, valloader, optimizer, criterion

In [None]:
from tqdm import tqdm
def train(model, dataloader, valloader, optimizer, criterion, num_epochs=20, device="cuda"):
    outs = []  # validation predictions: one array per validation batch, every 2 epochs
    for epoch in tqdm(range(num_epochs)):
        model.train()
        for images, _ in dataloader:
            images = images.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, images)
            loss.backward()
            optimizer.step()

        # Validate every 2 epochs with the model being trained
        if (epoch + 1) % 2 == 0:
            model.eval()
            with torch.no_grad():
                for images, _ in valloader:
                    images = images.to(device)
                    outputs = model(images)
                    outs.append(outputs.cpu().numpy())
    return outs

In [None]:
from tqdm import tqdm
ae = AE().to("cuda")
dataloader,valloader,optimizer,criterion = basic_train_config(ae,1000,10,32)


In [None]:
outs = train(ae, dataloader, valloader, optimizer, criterion)

In [10]:
import numpy as np
import matplotlib.pyplot as plt
from ipywidgets import interact
import torch
from torchvision.transforms import ToPILImage
import os
def denormalize(tensor):
    return (tensor.transpose(1, 2, 0) + 1) * 127.5

def get_ground_truth_images(dataloader):
    gt_images = []
    
    for batch in dataloader:
        imgs, _ = batch  # Extract images from batch
        imgs = imgs.numpy()
        imgs = [denormalize(img) for img in imgs]  # Apply denormalization to each image
        
        gt_images.extend(imgs)
        if len(gt_images) >= 10:  # Stop once we have 10 images
            break
    
    return np.array(gt_images[:10])  

ground_truth = get_ground_truth_images(valloader)  # needed by visualize_epoch below


def visualize_epoch(epoch_idx):
    fig, axes = plt.subplots(2, 5, figsize=(15, 6))
    plt.suptitle(f'Epoch {(epoch_idx + 1) * 2}')  # validation runs every 2 epochs
    
    predictions = outs[epoch_idx]
    
    # Verify predictions change
    if epoch_idx > 0:
        diff = np.mean(np.abs(outs[epoch_idx] - outs[epoch_idx-1]))
        print(f"Mean difference from previous epoch: {diff}")
    
    for i in range(5):
        gt_img = ground_truth[i]
        axes[0, i].imshow(gt_img.astype(np.uint8))
        axes[0, i].axis('off')
        if i == 0:
            axes[0, i].set_title('Ground Truth')
            
        pred_img = denormalize(predictions[i])
        axes[1, i].imshow(pred_img.astype(np.uint8))
        axes[1, i].axis('off')
        if i == 0:
            axes[1, i].set_title('Prediction')
    
    plt.tight_layout()
    plt.show()


In [None]:
interact(visualize_epoch, 
         epoch_idx=(0, len(outs)-1, 1))

As we can see above, the results are blurry. The number of training samples and epochs is limited, but you can trust me: even if we scaled the training data a million times, the results would still be blurry (I already tried it :). The autoencoder architecture and the training strategy we used cause this problem. The solution for blurriness depends on the task. If we want to train an image-to-image model, for example one that takes an image of a person and changes facial attributes (from black hair to blonde hair), we can introduce skip connections from encoder layers to decoder layers. These connections help the model retain features relevant to fine details. Architectural choices such as batch normalization can also contribute to blurriness, but we will not discuss that now. In our case, "Generative Vision Modeling", our main focus is sampling a latent vector from a specific distribution and generating images from it. If you do not understand the last sentence, do not worry; we will discuss it later. In this setting, we will not use the encoder at all when generating, so skip connections from the encoder are not an option and we need different solutions.
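For the image-to-image case mentioned above, skip connections can be sketched as follows. This is a minimal, hypothetical two-level model (`SkipAE` is our name, not a model used later in this notebook): each decoder stage concatenates the matching encoder feature map along the channel axis, so fine spatial detail bypasses the bottleneck. It also shows why this trick does not help for sampling: the decoder needs encoder features that do not exist when we start from a random latent.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SkipAE(nn.Module):
    """Hypothetical two-level autoencoder with U-Net style skip connections."""
    def __init__(self):
        super().__init__()
        self.enc1 = nn.Conv2d(3, 64, 4, 2, 1)                 # 3x64x64 -> 64x32x32
        self.enc2 = nn.Conv2d(64, 128, 4, 2, 1)               # 64x32x32 -> 128x16x16
        self.dec2 = nn.ConvTranspose2d(128, 64, 4, 2, 1)      # 128x16x16 -> 64x32x32
        # Input channels double because enc1 features are concatenated in
        self.dec1 = nn.ConvTranspose2d(64 + 64, 3, 4, 2, 1)   # 128x32x32 -> 3x64x64

    def forward(self, x):
        e1 = F.leaky_relu(self.enc1(x), 0.2)
        e2 = F.leaky_relu(self.enc2(e1), 0.2)
        d2 = F.leaky_relu(self.dec2(e2), 0.2)
        d2 = torch.cat([d2, e1], dim=1)   # skip connection from enc1 to the decoder
        return torch.tanh(self.dec1(d2))

model = SkipAE()
out = model(torch.randn(2, 3, 64, 64))
print(out.shape)  # torch.Size([2, 3, 64, 64])
```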

As you can guess, we will now work on variational auto encoders. We have already discussed the difference between AE and VAE, but here is my most intuitive explanation of it:

### _AE's encoder is a compression algorithm and its decoder is a reconstruction model, whereas VAE's encoder is an image-to-latent-space mapper and its decoder is a conditional image generation model._

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class VAE_Encoder(nn.Module):
    def __init__(self):
        super(VAE_Encoder, self).__init__()
        
        # bx3x64x64 -> bx64x32x32 
        self.conv1 = nn.Conv2d(3, 64, 4, 2, 1)
        self.bn1 = nn.BatchNorm2d(64)
        # bx64x32x32 -> bx128x16x16
        self.conv2 = nn.Conv2d(64, 128, 4, 2, 1)
        self.bn2 = nn.BatchNorm2d(128)
        # bx128x16x16 -> bx256x8x8
        self.conv3 = nn.Conv2d(128, 256, 4, 2, 1)
        self.bn3 = nn.BatchNorm2d(256)
        # bx256x8x8 -> bx512x4x4
        self.conv4 = nn.Conv2d(256, 512, 4, 2, 1)
        # mean predictor: flattened features -> 256-dim latent
        self.fc1 = nn.Linear(512*4*4, 256)
        # log-variance predictor: flattened features -> 256-dim latent
        self.fc2 = nn.Linear(512*4*4, 256)

    def reparameterize(self, mu, logvar):
        # z = mu + sigma * eps with eps ~ N(0, I), so gradients flow to mu and logvar
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x):
        x = F.leaky_relu(self.bn1(self.conv1(x)), 0.2)
        x = F.leaky_relu(self.bn2(self.conv2(x)), 0.2)
        x = F.leaky_relu(self.bn3(self.conv3(x)), 0.2)
        x = F.leaky_relu(self.conv4(x), 0.2)
        x = x.view(x.size(0), -1)           # bx512x4x4 -> bx8192
        mu = self.fc1(x)                    # bx256
        logvar = self.fc2(x)                # bx256
        z = self.reparameterize(mu, logvar) # bx256 sampled latent (expanded by VAE_Decoder.fc)
        return z, mu, logvar

class VAE_Decoder(nn.Module):
    def __init__(self):
        super(VAE_Decoder, self).__init__()

        self.fc = nn.Linear(256, 512 * 4 * 4)  # Expanding latent vector

        self.conv1 = nn.ConvTranspose2d(512, 256, 4, 2, 1)
        self.bn1 = nn.BatchNorm2d(256)

        self.conv2 = nn.ConvTranspose2d(256, 128, 4, 2, 1)
        self.bn2 = nn.BatchNorm2d(128)

        self.conv3 = nn.ConvTranspose2d(128, 64, 4, 2, 1)
        self.bn3 = nn.BatchNorm2d(64)

        self.conv4 = nn.ConvTranspose2d(64, 3, 4, 2, 1)

    def forward(self, z):
        z = self.fc(z)  # Convert (batch_size, 256) -> (batch_size, 512 * 4 * 4)
        z = z.view(z.size(0), 512, 4, 4)  # Reshape to 4D tensor for ConvTranspose

        z = F.leaky_relu(self.bn1(self.conv1(z)), 0.2)
        z = F.leaky_relu(self.bn2(self.conv2(z)), 0.2)
        z = F.leaky_relu(self.bn3(self.conv3(z)), 0.2)
        z = torch.tanh(self.conv4(z))  # map to [-1, 1] to match the normalized images

        return z


In [9]:
encoder = VAE_Encoder()
decoder = VAE_Decoder()

In [10]:
tensor = torch.randn(1, 3, 64, 64)
vae_encoder = VAE_Encoder()
x, mu, logvar = vae_encoder(tensor)
print(x.shape)
print(mu.shape)
print(logvar.shape)



In [11]:
import torch.nn.functional as F
class VAE(nn.Module):
    def __init__(self, ):
        super().__init__()
        self.encoder = VAE_Encoder()
        self.decoder = VAE_Decoder()
    def forward(self, x):
        x, mu, logvar = self.encoder(x)
        x = self.decoder(x)
        return x, mu, logvar
vae = VAE()
tensor = torch.randn(1, 3, 64, 64)
x, mu, logvar = vae(tensor)
print(x.shape)






In [None]:
def vae_loss_function(recon_x, x, mu, logvar):
    # Reconstruction loss (MSE), summed over all pixels in the batch
    reconstruction_loss = F.mse_loss(recon_x, x, reduction='sum')
    # KL divergence between N(mu, sigma^2) and N(0, I), summed over the batch
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return reconstruction_loss + KLD

In [None]:
def train_vae(model, dataloader, valloader, optimizer, criterion, num_epochs=20, device="cuda"):
    outs = []
    for epoch in tqdm(range(num_epochs)):
        model.train()
        for images, _ in dataloader:
            images = images.to(device)
            optimizer.zero_grad()
            outputs, mu, logvar = model(images)
            loss = criterion(outputs, images, mu, logvar)
            loss.backward()
            optimizer.step()

        # Validate every 2 epochs with the VAE being trained
        if (epoch + 1) % 2 == 0:
            model.eval()
            with torch.no_grad():
                for images, _ in valloader:
                    images = images.to(device)
                    outputs, _, _ = model(images)  # VAE returns (reconstruction, mu, logvar)
                    outs.append(outputs.cpu().numpy())
    return outs

In [None]:
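vae.to("cuda")
# The `optimizer` returned by basic_train_config was built for ae.parameters(), so create one for the VAE
vae_optimizer = AdamW(vae.parameters(), lr=0.0002)
outs = train_vae(vae, dataloader, valloader, vae_optimizer, vae_loss_function)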

In [None]:
interact(visualize_epoch, 
         epoch_idx=(0, len(outs)-1, 1))

As we can see above, the resulting images are still blurry, and even if we scaled the training data a million times, they would still not be sharp. The reasons are listed below:

- The KLD term in VAEs encourages the latent distribution to be close to a simple prior (usually a Gaussian). If the weight of the KLD term is too high, it can force the model to learn a very smooth and compressed latent space that discards fine details present in the sharp input.

- The model might learn a latent representation that captures the general structure but averages out the high-frequency details to minimize the overall pixel error when using L1 or MSE losses. This averaging effect results in a blurry output.

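- KLD can over-regularize the encoder: when each per-input posterior is pushed too close to the prior, the latent carries less information about the input, and the decoder falls back on average-looking, blurry reconstructions.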
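The averaging effect in the second point can be seen with a toy example. Suppose the same latent could correspond to two equally likely targets for one pixel, -1 (dark) and +1 (bright), because an edge may fall on either side of it. The single prediction that minimizes the expected MSE is their mean, 0, a gray value that matches neither sharp target. The sketch below is a NumPy illustration under this assumption, not part of the training code; it searches a grid of constant predictions:

```python
import numpy as np

targets = np.array([-1.0, 1.0])            # two equally likely sharp values for one pixel
candidates = np.linspace(-1.0, 1.0, 201)   # constant predictions to try (step 0.01)

# Expected squared error of each candidate over the two targets
mse = ((candidates[:, None] - targets[None, :]) ** 2).mean(axis=1)
best = candidates[np.argmin(mse)]
print(round(float(best), 2))  # 0.0 -> the average of the targets, neither sharp value
```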


For this reason, we will improve our loss function. The table below lists these loss terms and their purpose.

| Loss Function          | Applicable to | Description                                                                                                                              | How it Improves Sharpness                                                                                                                                                                                                                            |
| :--------------------- | :------------ | :--------------------------------------------------------------------------------------------------------------------------------------- | :--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| **Perceptual Loss** | AE, VAE       | Compares high-level features from a pre-trained network. Usually we use a VGG network pretrained on the ImageNet classification task.                                                                                   | Encourages sharper outputs by focusing on perceptually important features like edges and textures.                                                                                                                                                  |
| **LPIPS** | AE, VAE       | Learned metric for perceptual similarity using pre-trained CNN features.                                                                             | Directly promotes perceptually similar sharp outputs by aligning with human visual perception of image quality.                                                                                                                                  |
| **SSIM** | AE, VAE       | Measures structural similarity based on luminance, contrast, and structure.                                                                           | Helps preserve edges and structural details, leading to less blurry and more structurally accurate reconstructions.                                                                                                                               |
| **Reconstruction Loss** | AE, VAE       | Pixel-wise difference (e.g., MSE, L1) between output and target.                                                                         | Ensures basic similarity but often needs other losses for sharpness, as it can lead to blurry results by averaging details.                                                                                                                            |
| **KLD** | VAE           | Measures the difference between the learned latent distribution and a prior.                                                                        | When balanced, helps learn a meaningful latent space. Over-regularization can cause blurriness; reducing its weight can sometimes improve sharpness.                                                                                              |
| **GAN Loss** | AE (hybrid), VAE (hybrid), GAN | Based on a discriminator network that tries to distinguish between real and generated images. The generator tries to fool the discriminator. | Encourages the generator to produce more realistic and sharper images that are indistinguishable from real sharp images to the discriminator. This adversarial process pushes the generator to learn fine details and textures, leading to sharper outputs. |
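To make the SSIM row concrete: SSIM multiplies a luminance comparison by a combined contrast/structure comparison,

$$\text{SSIM}(x, y) = \frac{(2\mu_x\mu_y + C_1)(2\sigma_{xy} + C_2)}{(\mu_x^2 + \mu_y^2 + C_1)(\sigma_x^2 + \sigma_y^2 + C_2)}, \quad C_1 = (0.01L)^2,\; C_2 = (0.03L)^2,$$

with data range $L$. The sketch below computes one global SSIM over whole images for illustration only (the function name `global_ssim` is ours); `pytorch_msssim`, used later in this notebook, instead averages the formula over local Gaussian windows.

```python
import torch
import torch.nn.functional as F

def global_ssim(x: torch.Tensor, y: torch.Tensor, data_range: float = 1.0) -> torch.Tensor:
    """Single-window SSIM over entire images (illustration only, no sliding window)."""
    c1 = (0.01 * data_range) ** 2
    c2 = (0.03 * data_range) ** 2
    mu_x, mu_y = x.mean(), y.mean()
    var_x = ((x - mu_x) ** 2).mean()
    var_y = ((y - mu_y) ** 2).mean()
    cov_xy = ((x - mu_x) * (y - mu_y)).mean()
    luminance = (2 * mu_x * mu_y + c1) / (mu_x ** 2 + mu_y ** 2 + c1)
    contrast_structure = (2 * cov_xy + c2) / (var_x + var_y + c2)
    return luminance * contrast_structure

torch.manual_seed(0)
img = torch.rand(3, 64, 64)   # image in [0, 1]
blurred = F.avg_pool2d(img.unsqueeze(0), 3, stride=1, padding=1).squeeze(0)  # 3x3 box blur

print(round(global_ssim(img, img).item(), 4))   # 1.0 for identical images
print(global_ssim(img, blurred).item() < 1.0)   # True: blurring lowers SSIM
```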


In the loss table above, LPIPS and perceptual loss look almost the same. Here is the difference between them:

| Aspect          | Perceptual Loss                               | LPIPS (Learned Perceptual Image Patch Similarity)                                  |
| :--------------- | :-------------------------------------------- | :---------------------------------------------------------------------------------- |
| **Definition** | General category of loss functions using pre-trained CNN features. | Specific, learned metric for perceptual similarity using pre-trained CNN features. |
| **Learning** | Typically uses a fixed, pre-trained network. | Uses a pre-trained network, but **learns weights** for its feature maps based on human perception data. |
| **Feature Weights** | Implicit through the pre-trained network's learned weights (fixed during loss calculation). | Explicitly learns channel-wise weights for feature maps to align with human similarity judgments. |
| **Standardization** | Implementation can vary (network, layers, distance). | Provides a more standardized metric with pre-defined (learned) weights.             |
| **Output** | A loss value representing the distance between feature maps. | A perceptual distance (lower means more similar) that correlates with human similarity judgments. |
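The key difference in the table, learned channel-wise weights, can be sketched as follows. As we understand LPIPS, features from each layer are unit-normalized along the channel axis, their squared difference is weighted per channel (a 1x1 convolution in the `lpips` package), then averaged over space and summed over layers. In this hypothetical sketch, random tensors stand in for VGG activations and `weights` for the learned weights, which the `lpips` package loads from its own checkpoint and which we do not reproduce. With all weights set to one, it reduces to a plain normalized perceptual distance.

```python
import torch

def lpips_style_distance(feats_x, feats_y, weights, eps=1e-10):
    """LPIPS-style distance over lists of feature maps of shape (B, C, H, W)."""
    total = 0.0
    for fx, fy, w in zip(feats_x, feats_y, weights):
        # Unit-normalize each spatial feature vector along the channel axis
        fx = fx / (fx.norm(dim=1, keepdim=True) + eps)
        fy = fy / (fy.norm(dim=1, keepdim=True) + eps)
        diff = (fx - fy) ** 2                                 # (B, C, H, W)
        weighted = (diff * w.view(1, -1, 1, 1)).sum(dim=1)    # per-channel weights, sum over C
        total = total + weighted.mean(dim=(1, 2))             # average over H, W -> (B,)
    return total

torch.manual_seed(0)
feats_a = [torch.randn(2, 64, 32, 32), torch.randn(2, 128, 16, 16)]
feats_b = [f + 0.1 * torch.randn_like(f) for f in feats_a]
weights = [torch.rand(64), torch.rand(128)]   # stand-ins for learned, non-negative weights

print(lpips_style_distance(feats_a, feats_a, weights))  # tensor([0., 0.])
print(lpips_style_distance(feats_a, feats_b, weights))  # small positive values
```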


In [None]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
from torchvision import transforms
from torchvision.transforms import ToPILImage
import lpips
from pytorch_msssim import SSIM
from tqdm import tqdm

# Check if CUDA is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- 1. Implement Individual Loss Functions ---

# Reconstruction Loss 
def reconstruction_loss(outputs, images, reduction='mean'):
    return F.mse_loss(outputs, images, reduction=reduction)

# KLD Loss
def kld_loss(mu, logvar, reduction='mean'):
    # Per-sample KL divergence between N(mu, sigma^2) and N(0, I), summed over latent dims
    kld_per_sample = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1)
    return kld_per_sample.sum() if reduction == 'sum' else kld_per_sample.mean()

# Perceptual Loss (using VGG)

class PerceptualLoss(nn.Module):
    def __init__(self, layer='relu2_2'):
        super(PerceptualLoss, self).__init__()
        vgg16 = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1).features.to(device).eval()
        # VGG is a fixed feature extractor: freeze it so no gradients are stored for its weights
        for param in vgg16.parameters():
            param.requires_grad_(False)
        self.selected_layer = layer
        # ImageNet statistics; VGG expects inputs in [0, 1] before this normalization
        self.transform = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

        # Indices of the ReLU layers inside torchvision's vgg16().features
        self.layer_indices = {
            'relu1_1': 1,
            'relu1_2': 3,
            'relu2_1': 6,
            'relu2_2': 8,
            'relu3_1': 11,
            'relu3_2': 13,
            'relu3_3': 15,
            'relu4_1': 18,
            'relu4_2': 20,
            'relu4_3': 22,
            'relu5_1': 25,
            'relu5_2': 27,
            'relu5_3': 29,
        }

        if layer not in self.layer_indices:
            raise ValueError(f"Layer {layer} not found in supported VGG layers.")

        self.selected_layer_index = self.layer_indices[layer]
        self.model = vgg16

    def get_features(self, image):
        # Images are normalized to [-1, 1]; map to [0, 1] before ImageNet normalization
        x = self.transform((image + 1) / 2)
        for i, layer in enumerate(self.model):
            x = layer(x)
            if i == self.selected_layer_index:
                break
        return x

    def forward(self, generated, target):
        generated_features = self.get_features(generated)
        target_features = self.get_features(target)
        return torch.mean((generated_features - target_features) ** 2)

# LPIPS Loss
class LPIPSLoss(nn.Module):
    def __init__(self, net='vgg', version='0.1'):
        super(LPIPSLoss, self).__init__()
        self.lpips_fn = lpips.LPIPS(net=net, version=version).to(device)

    def forward(self, generated, target):
        return torch.mean(self.lpips_fn(generated, target))

# SSIM Loss


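class SSIMLoss(nn.Module):
    def __init__(self, data_range=1.0, size_average=True, win_size=11, win_sigma=1.5, channel=3):
        super(SSIMLoss, self).__init__()
        self.ssim = SSIM(data_range=data_range, size_average=size_average,
                         win_size=win_size, win_sigma=win_sigma, channel=channel).to(device)

    def forward(self, img1, img2):
        # Images are normalized to [-1, 1]; map them to [0, 1] to match data_range=1.0
        return 1 - self.ssim((img1 + 1) / 2, (img2 + 1) / 2)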
    
# GAN Loss (for the Generator - VAE's Decoder)
class LSGANLoss(nn.Module):
    def __init__(self):
        super(LSGANLoss, self).__init__()

    def discriminator_loss(self, real_output, fake_output):
        # LSGAN uses MSE loss instead of BCE
        real_loss = torch.mean((real_output - 1) ** 2)
        fake_loss = torch.mean(fake_output ** 2)
        return (real_loss + fake_loss) * 0.5

    def generator_loss(self, fake_output):
        # Generator tries to make fake output close to 1
        return torch.mean((fake_output - 1) ** 2)


# --- 2. Build a Combined Criterion Function ---

def combined_criterion(outputs, images, mu, logvar, loss_config,
                      perceptual_criterion=None, lpips_criterion=None, ssim_criterion=None):
    total_loss = 0

    if 'reconstruction' in loss_config and loss_config['reconstruction']['weight'] > 0:
        recon_loss = reconstruction_loss(outputs, images)
        total_loss += loss_config['reconstruction']['weight'] * recon_loss

    if 'kld' in loss_config and loss_config['kld']['weight'] > 0:
        kld = kld_loss(mu, logvar)
        total_loss += loss_config['kld']['weight'] * kld

    if 'perceptual' in loss_config and loss_config['perceptual']['weight'] > 0:
        perc_loss = perceptual_criterion(outputs, images)
        total_loss += loss_config['perceptual']['weight'] * perc_loss

    if 'lpips' in loss_config and loss_config['lpips']['weight'] > 0:
        lpips_loss = lpips_criterion(outputs, images)
        total_loss += loss_config['lpips']['weight'] * lpips_loss

    if 'ssim' in loss_config and loss_config['ssim']['weight'] > 0:
        ssim_loss = ssim_criterion(outputs, images)
        total_loss += loss_config['ssim']['weight'] * ssim_loss

    return total_loss


# Define a simple Discriminator for GAN loss
class PatchDiscriminator(nn.Module):
    def __init__(self, input_channels=3):
        super(PatchDiscriminator, self).__init__()
        
        def conv_block(in_channels, out_channels, kernel_size=4, stride=2, padding=1, use_bn=True):
            layers = [
                nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding),
                nn.LeakyReLU(0.2, inplace=True)
            ]
            if use_bn:
                layers.insert(1, nn.InstanceNorm2d(out_channels))
            return nn.Sequential(*layers)

        # Convolutional discriminator; shape comments assume a 3x64x64 input
        self.conv1 = conv_block(input_channels, 64, use_bn=False)  # 64x32x32
        self.conv2 = conv_block(64, 128)                         # 128x16x16
        self.conv3 = conv_block(128, 256)                        # 256x8x8
        self.conv4 = conv_block(256, 512)                        # 512x4x4
        # For 64x64 inputs this yields a single 1x1 score; larger inputs give a grid of patch scores
        self.conv5 = nn.Conv2d(512, 1, 4, 1, 0)                 # 1x1x1

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        # Raw (unbounded) scores: the LSGAN least-squares loss is applied to these directly
        return self.conv5(x)


# --- 3. Modify the Training Function ---
def train_vae(model, dataloader, valloader, optimizer, loss_config, discriminator=None, 
              discriminator_optimizer=None, num_epochs=20,
              perceptual_criterion=None, lpips_criterion=None, ssim_criterion=None, 
              run_name="default_run"):
    
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    gan_loss_fn = LSGANLoss()
    
    output_root_dir = f"jn_outs/{run_name}"
    gt_dir = os.path.join(output_root_dir, "ground_truth")
    os.makedirs(gt_dir, exist_ok=True)
    os.makedirs(output_root_dir, exist_ok=True)
    to_pil = ToPILImage()
    ground_truth_saved = False

    # Save ground truth images
    with torch.no_grad():
        for batch_idx, (images, _) in enumerate(valloader):
            if not ground_truth_saved:
                images_denorm = ((images + 1) / 2).clamp(0, 1).cpu()  # [-1, 1] -> [0, 1] for ToPILImage
                for i in range(min(10, images_denorm.size(0))):
                    img_tensor = images_denorm[i]
                    img_pil = to_pil(img_tensor)
                    img_pil.save(os.path.join(gt_dir, f"gt_image_{batch_idx * valloader.batch_size + i}.png"))
                if (batch_idx + 1) * valloader.batch_size >= 10:
                    ground_truth_saved = True
                    break

    # Training loop
    for epoch in tqdm(range(num_epochs)):
        model.train()
        for i, (images, _) in enumerate(dataloader):
            images = images.to(device)
            
            # VAE/Generator forward pass
            optimizer.zero_grad()
            outputs, mu, logvar = model(images)
            
            # Calculate combined loss excluding GAN
            total_loss = combined_criterion(outputs, images, mu, logvar, loss_config,
                                          perceptual_criterion=perceptual_criterion,
                                          lpips_criterion=lpips_criterion,
                                          ssim_criterion=ssim_criterion)

            # GAN training if enabled
            if 'gan' in loss_config and loss_config['gan']['weight'] > 0 and discriminator is not None:
                # Discriminator training
                discriminator_optimizer.zero_grad()
                real_output = discriminator(images)
                fake_output = discriminator(outputs.detach())
                d_loss = gan_loss_fn.discriminator_loss(real_output, fake_output)
                d_loss.backward()
                discriminator_optimizer.step()

                # Generator GAN loss
                fake_output = discriminator(outputs)
                g_gan_loss = gan_loss_fn.generator_loss(fake_output)
                total_loss += loss_config['gan']['weight'] * g_gan_loss

            total_loss.backward()
            optimizer.step()

        # Validation and saving
        model.eval()
        if epoch % 2 == 0:
            with torch.no_grad():
                epoch_output_dir = os.path.join(output_root_dir, f"epoch_{epoch}")
                os.makedirs(epoch_output_dir, exist_ok=True)
                for i, (images, _) in enumerate(valloader):
                    images = images.to(device)
                    outputs, _, _ = model(images)
                    outputs_denorm = ((outputs + 1) / 2).clamp(0, 1).cpu()  # [-1, 1] -> [0, 1] for ToPILImage
                    for j in range(outputs_denorm.size(0)):
                        img_tensor = outputs_denorm[j]
                        img_pil = to_pil(img_tensor)
                        img_pil.save(os.path.join(epoch_output_dir, f"prediction_{i * valloader.batch_size + j}.png"))

    return None

In [None]:
import os 
input_channels = 3
image_size = 64
input_dim = input_channels * image_size * image_size
vae = VAE().to(device)
vae_optimizer = torch.optim.Adam(vae.parameters(), lr=0.001)

discriminator = PatchDiscriminator(input_channels=3).to(device)
discriminator_optimizer = torch.optim.Adam(discriminator.parameters(), lr=0.0002)
perceptual_criterion = PerceptualLoss(layer='relu3_3').to(device)
lpips_criterion = LPIPSLoss(net='vgg').to(device)
ssim_criterion = SSIMLoss(data_range=1.0).to(device)
dataloader,valloader,_,_ = basic_train_config(vae,5000,10,256)

# Define the loss configuration
loss_config = {
    'reconstruction': {'weight': 1.0},
    'kld': {'weight': 0.01},
    'perceptual': {'weight': 0.1, 'layer': 'relu3_3'},
    'lpips': {'weight': 0.5, 'net': 'vgg'},
    'ssim': {'weight': 0.2},
    'gan': {'weight': 0.05},
}

# Train the VAE with the combined criterion
train_vae(vae, dataloader, valloader, vae_optimizer, loss_config, discriminator, discriminator_optimizer,
          num_epochs=10,
          perceptual_criterion=perceptual_criterion,  # Pass the initialized objects
          lpips_criterion=lpips_criterion,
          ssim_criterion=ssim_criterion)

In [10]:
import os
import matplotlib.pyplot as plt
from PIL import Image as PILImage
from torchvision.transforms import ToPILImage

# Updated visualization helper: reads the saved predictions from the output directory
def denormalize(tensor):
    return (tensor + 1) / 2  # Map [-1, 1] back to [0, 1] (assumes inputs were normalized to [-1, 1])

def visualize_epoch(epoch_idx, run_name="default_run"):
    print(epoch_idx)
    output_dir = f"jn_outs/{run_name}/epoch_{epoch_idx * 2}" # Outputs saved every 2 epochs
    gt_dir = f"jn_outs/{run_name}/ground_truth"
    if not os.path.exists(output_dir):
        print(f"Error: Prediction directory not found at {output_dir}. Please run training first.")
        return
    if not os.path.exists(gt_dir):
        print(f"Error: Ground truth directory not found at {gt_dir}. Please run training first.")
        return

    fig, axes = plt.subplots(2, 5, figsize=(15, 6))
    plt.suptitle(f'Epoch {epoch_idx * 2}')

    for i in range(5):
        gt_filename = os.path.join(gt_dir, f"gt_image_{i}.png")
        if os.path.exists(gt_filename):
            gt_img = PILImage.open(gt_filename)
            axes[0, i].imshow(gt_img)
            axes[0, i].axis('off')
            if i == 0:
                axes[0, i].set_title('Ground Truth')
        else:
            axes[0, i].axis('off')
            axes[0, i].text(0.5, 0.5, 'GT\nNot Found', ha='center', va='center')

        pred_filename = os.path.join(output_dir, f"prediction_{i}.png")
        if os.path.exists(pred_filename):
            pred_img = PILImage.open(pred_filename)
            axes[1, i].imshow(pred_img)
            axes[1, i].axis('off')
            if i == 0:
                axes[1, i].set_title('Prediction')
        else:
            axes[1, i].axis('off')
            axes[1, i].text(0.5, 0.5, 'Prediction\nNot Found', ha='center', va='center')

    plt.tight_layout()
    plt.show()

In [19]:
import numpy as np
import matplotlib.pyplot as plt
from ipywidgets import interact
import torch
from torchvision.transforms import ToPILImage
import os
num_epochs = 8  # Example value; set to the number of epochs the run was trained for

# Predictions are saved every 2 epochs (0, 2, 4, ...), i.e. ceil(num_epochs / 2) snapshots
num_saved_epochs = (num_epochs + 1) // 2
print(num_saved_epochs)
# Valid slider indices are 0 .. num_saved_epochs - 1 (index i shows epoch_{2 * i})
interact(visualize_epoch, epoch_idx=(0, num_saved_epochs - 1, 1), run_name="25epoch_basic_vae")

The training results above show a common failure mode caused by the GAN loss: the output is a noisy, grid-like pattern with no discernible structure, which suggests that the generator (here, the VAE's decoder) is failing to produce meaningful images. This is often a symptom of poor loss balancing, mode collapse, or unstable training dynamics between the generator and the discriminator. The loss configuration used was:

- Reconstruction Loss (MSE): Weight = 1.0
- KL Divergence (KLD): Weight = 0.01
- Perceptual Loss: Weight = 0.1
- LPIPS Loss: Weight = 0.5
- SSIM Loss: Weight = 0.2
- GAN Loss: Weight = 0.05


To fix this issue, we will rebalance the loss weights and train again. Here is an important insight:

### _One of the most challenging parts of training a generative vision model is balancing the loss function._

In [None]:
import os 
input_channels = 3
image_size = 64
input_dim = input_channels * image_size * image_size
vae = VAE().to(device)
vae_optimizer = torch.optim.Adam(vae.parameters(), lr=0.001)

discriminator = PatchDiscriminator(input_channels=3).to(device)
discriminator_optimizer = torch.optim.Adam(discriminator.parameters(), lr=0.0002)
perceptual_criterion = PerceptualLoss(layer='relu3_3').to(device)
lpips_criterion = LPIPSLoss(net='vgg').to(device)
ssim_criterion = SSIMLoss(data_range=1.0).to(device)
dataloader,valloader,_,_ = basic_train_config(vae,5000,10,256)

# Define the loss configuration
loss_config = {
    'reconstruction': {'weight': 0.1},  # Reduce the weight of MSE
    'kld': {'weight': 0.05},           # Slightly increase KLD for better latent space regularization
    'perceptual': {'weight': 0.5},     # Increase perceptual loss to focus on high-level features
    'lpips': {'weight': 0.5},          # Keep LPIPS as is
    'ssim': {'weight': 0.2},           # Keep SSIM as is
    'gan': {'weight': 1.0},            # Significantly increase GAN loss to encourage realistic outputs
}

# Train the VAE with the combined criterion
train_vae(vae, dataloader, valloader, vae_optimizer, loss_config, discriminator, discriminator_optimizer,
          num_epochs=10,
          perceptual_criterion=perceptual_criterion,  # Pass the initialized objects
          lpips_criterion=lpips_criterion,
          ssim_criterion=ssim_criterion)


In [20]:
import numpy as np
import matplotlib.pyplot as plt
from ipywidgets import interact
import torch
from torchvision.transforms import ToPILImage
import os
num_epochs = 25  # Example value; set to the number of epochs the run was trained for

# Predictions are saved every 2 epochs (0, 2, 4, ...), i.e. ceil(num_epochs / 2) snapshots
num_saved_epochs = (num_epochs + 1) // 2
print(num_saved_epochs)
# Valid slider indices are 0 .. num_saved_epochs - 1 (index i shows epoch_{2 * i})
interact(visualize_epoch, epoch_idx=(0, num_saved_epochs - 1, 1), run_name="25epoch_basic_vae_loss_balanced")

This time, all of the images are black, and we need to fix this. The black outputs (all pixels close to 0 after denormalization) appear from the first epoch after the loss-balancing update, whereas before the update they only appeared after the 10th epoch. This suggests a problem in the training dynamics or the loss computation that makes the generator (the VAE's decoder) collapse its output to a trivial solution (e.g., all zeros after denormalization).

We use tanh as the last layer of the decoder (this is just a design choice; we could also choose sigmoid, etc.), so the output range is [-1, 1], and our denormalization function is denormalized = (tensor + 1) / 2. A black image therefore means that the denormalized values are all near 0, i.e. the tanh outputs are saturated near -1. Looking back at our first experiment, the last saved epoch is also black, which suggests a gradual collapse of the generator's output. After the update, the outputs are black from the first epoch, indicating that the changes we made (e.g., rebalancing the loss weights, adding gradient penalty, label smoothing, etc.) have accelerated this collapse. Let's analyze our changes:

- The reconstruction loss weight was reduced from 1.0 to 0.1, meaning the generator is less incentivized to match the input pixel-wise. While this was intended to reduce blurriness, it might have removed a stabilizing force that was keeping the generator’s output from collapsing.
- The GAN loss weight was increased from 0.05 to 1.0, giving the adversarial loss a much stronger influence. If the discriminator becomes too strong too quickly, it can overpower the generator, causing the generator to fail to produce meaningful outputs.
- The KLD weight was increased from 0.01 to 0.05, which might impose stronger regularization on the latent space, potentially constraining the generator’s ability to explore diverse outputs.
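One way to check the saturation explanation is to inspect the decoder's pre-tanh values (the decoder defined later in this notebook prints their min/max/mean). The sketch below is a hypothetical helper, `saturation_report`, which is not part of the original code: it reports the fraction of pixels pinned near -1 (black after denormalization) or +1, plus the mean tanh gradient. Once the outputs sit at -1, the gradient through tanh is essentially zero, which is one reason such a collapse is hard to escape.

```python
import torch

def denormalize(tensor):
    # Same mapping as the notebook: tanh range [-1, 1] -> image range [0, 1]
    return (tensor + 1) / 2

def saturation_report(pre_activation, threshold=0.99):
    """Summarize how saturated the decoder's final tanh is (hypothetical helper)."""
    out = torch.tanh(pre_activation)
    grad = 1 - out.pow(2)  # derivative of tanh at each pixel
    return {
        "frac_black": (out < -threshold).float().mean().item(),  # ~0 after denormalize
        "frac_white": (out > threshold).float().mean().item(),   # ~1 after denormalize
        "mean_grad": grad.mean().item(),
    }

# Simulated collapsed decoder: every pre-activation is strongly negative
collapsed = torch.full((4, 3, 64, 64), -10.0)
report = saturation_report(collapsed)
print(report)
print("max pixel after denormalize:", denormalize(torch.tanh(collapsed)).max().item())
```

In a real run, you would pass the decoder's `pre_activation` tensor instead of the simulated `collapsed` one.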

As mentioned before, loss balancing is an important problem. Here are the current (unweighted) values of each loss term:

Epoch 1/25:  12%| | 45/390 [00:40<05:17,  1.09it/s, Recon=0.6670, KLD=2566.3564, Perceptual=606.7762, LPIPS=0.7647, SSIM=0.9983, GAN_G=0.8666,


and after a few epochs:

Epoch 4/25:  47%|█▉  | 183/390 [02:34<02:52,  1.20it/s, Recon=nan, KLD=nan, Perceptual=nan, LPIPS=nan, SSIM=nan, GAN_G=nan, GAN_D=nan, Total=nan]
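To see the imbalance in numbers, we can multiply the epoch-1 values from the first progress bar by their weights. This sketch assumes the weights of the rebalanced configuration (reconstruction 0.1, KLD 0.05, perceptual 0.5, LPIPS 0.5, SSIM 0.2, GAN 1.0); `weighted_breakdown` is a hypothetical helper, not part of the training code.

```python
# Raw per-term values copied from the epoch-1 progress bar
raw = {"recon": 0.6670, "kld": 2566.3564, "perceptual": 606.7762,
       "lpips": 0.7647, "ssim": 0.9983, "gan_g": 0.8666}
# Assumed weights: the rebalanced loss configuration
weights = {"recon": 0.1, "kld": 0.05, "perceptual": 0.5,
           "lpips": 0.5, "ssim": 0.2, "gan_g": 1.0}

def weighted_breakdown(raw, weights):
    """Return {term: (weighted value, share of total)} and the weighted total."""
    weighted = {k: raw[k] * weights[k] for k in raw}
    total = sum(weighted.values())
    return {k: (v, v / total) for k, v in weighted.items()}, total

breakdown, total = weighted_breakdown(raw, weights)
# Print terms from largest to smallest weighted contribution
for name, (value, share) in sorted(breakdown.items(), key=lambda kv: -kv[1][0]):
    print(f"{name:>10}: {value:10.4f}  ({share:6.1%})")
print(f"{'total':>10}: {total:10.4f}")
```

With these weights, the perceptual and KLD terms together make up more than 99% of the total, so the reconstruction and adversarial terms have almost no say in the objective at this point.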


The KLD loss gradually increased and became NaN, after which the whole training run collapsed. The reason is that the logvar term keeps growing, and since the KLD formula includes exp(logvar), the value no longer fits into float32 (exp overflows to inf once its argument exceeds about 88.7). We will make sure that logvar does not get too big.
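The overflow is easy to reproduce in isolation. The sketch below re-implements the notebook's KLD formula as a standalone function (its `clamp` argument is an illustration, not the notebook's signature) and shows that `exp` overflows float32 just above 88.7, so a log-variance of 100 turns the KLD into `inf`. Once `inf` reaches the backward pass, the gradients and then the weights typically become NaN, which matches the log above.

```python
import torch

def kld_loss(mu, logvar, clamp=None):
    """KL(N(mu, exp(logvar)) || N(0, I)), summed over latent dims, averaged over the batch.

    If `clamp` is given, mu and logvar are clamped to [-clamp, clamp] first.
    """
    if clamp is not None:
        mu = mu.clamp(-clamp, clamp)
        logvar = logvar.clamp(-clamp, clamp)
    kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1)
    return kld.mean()

# exp() overflows float32 once its argument exceeds ln(float32 max), about 88.72
print(torch.log(torch.tensor(torch.finfo(torch.float32).max)))
print(torch.exp(torch.tensor([88.0, 89.0])))  # the second entry is inf

mu = torch.zeros(2, 4)
logvar = torch.full((2, 4), 100.0)       # a runaway log-variance
print(kld_loss(mu, logvar))              # inf, because exp(100) overflows
print(kld_loss(mu, logvar, clamp=20.0))  # finite, though still very large
```

Note that clamping stops the overflow, but `torch.clamp` passes zero gradient to values outside the range, so it does not by itself push a runaway `logvar` back down; that is one reason the list below also includes architectural changes.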

We will apply the following changes:

- Lowering the discriminator's learning rate: this gives the generator time to produce somewhat realistic images, so it can compete with the discriminator.
- Increasing label smoothing: as discussed, we want to help the generator, and softer targets make the discriminator less confident.
- Increasing the reconstruction weight: reconstruction is important because we want an output image identical to the input image. Since lowering its weight gave worse results, a weak reconstruction term seems to let the model collapse into a local minimum close to its initial point.
- KLD loss warm-up: to prevent the KLD loss from dominating early in training, the KLD weight starts near 0 and increases linearly to its target value over the first epochs (20 in the training code below).
- Architectural changes for stability: for reasons I have not pinned down, mu and logvar become extremely large in this training script. Huge logvar values are a common issue in VAE training, but they usually affect only the loss calculation; here they break numerical stability earlier, at the reparameterization stage. In my experiments, replacing batch norm with group norm and changing the conv kernel size from 4 to 3 fixed this issue.
- Logvar and mu clamping: to prevent a huge KLD loss, we clamp logvar and mu to the range [-20, 20] in the KLD computation.
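The KLD warm-up in the training loop below reduces to a small schedule. This is a hypothetical helper (`kld_weight_at` is not in the original code) that reproduces the one-line expression used in `train_vae`: a linear ramp over 20 epochs with a floor of 1e-5.

```python
def kld_weight_at(epoch, target_weight, warmup_epochs=20, floor=1e-5):
    """Linear KLD warm-up from `floor` to `target_weight` over `warmup_epochs` epochs.

    Same as: max(floor, min(target_weight * (epoch / warmup_epochs), target_weight))
    """
    ramp = target_weight * (epoch / warmup_epochs)
    return max(floor, min(ramp, target_weight))

# Weight at a few epochs for the target KLD weight of 0.05
for epoch in (0, 5, 10, 20, 24):
    print(epoch, kld_weight_at(epoch, target_weight=0.05))
```

The weight is updated once per epoch, so it is constant within an epoch; ramping per optimizer step is a common alternative when epochs are long.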


Below, some code is duplicated on purpose so that the updates are easy to see.


In [8]:
#### Whole imports
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models, datasets, transforms
import lpips
from pytorch_msssim import SSIM
from tqdm import tqdm
from torchvision.transforms import ToPILImage  # `transforms` is already imported from torchvision above
from PIL import Image as PILImage
import os 
from torch.utils.data import DataLoader, Subset
from torch.optim import AdamW
from torch.nn import L1Loss
import matplotlib.pyplot as plt
from ipywidgets import interact

In [3]:
class Downsample(nn.Module):
    """Downsampling layer using convolution with stride 2."""
    def __init__(self, in_channels: int, out_channels: int = None):
        super().__init__()
        out_channels = out_channels or in_channels  # Default: keep the channel count
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        pad = (0, 1, 0, 1)  # Pad right/bottom by 1 so the stride-2 conv exactly halves H and W
        x = nn.functional.pad(x, pad, mode="constant", value=0)
        x = self.conv(x)
        return x
class VAE_Encoder(nn.Module):
    def __init__(self):
        super(VAE_Encoder, self).__init__()
        
        # bx3x64x64 -> bx64x32x32 
        self.conv1 = Downsample(3,64)
        self.gn1 = nn.GroupNorm(num_groups=32, num_channels=64)  # 64 channels, 32 groups
        # bx64x32x32 -> bx128x16x16
        self.conv2 = Downsample(64,128)
        self.gn2 = nn.GroupNorm(num_groups=32, num_channels=128)  # 128 channels, 32 groups
        # bx128x16x16 -> bx256x8x8
        self.conv3 = Downsample(128,256)
        self.gn3 = nn.GroupNorm(num_groups=32, num_channels=256)  # 256 channels, 32 groups
        # bx256x8x8 -> bx512x4x4
        self.conv4 = Downsample(256,512)
        # mean predictor
        self.fc1 = nn.Linear(512*4*4, 256)
        # logvar predictor
        self.fc2 = nn.Linear(512*4*4, 256)

        # Initialize weights with smaller variance
        nn.init.normal_(self.fc1.weight, mean=0.0, std=0.02)
        nn.init.normal_(self.fc2.weight, mean=0.0, std=0.02)
        nn.init.zeros_(self.fc1.bias)
        nn.init.zeros_(self.fc2.bias)

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x):
        x = F.leaky_relu(self.gn1(self.conv1(x)), 0.2)
        x = F.leaky_relu(self.gn2(self.conv2(x)), 0.2)
        x = F.leaky_relu(self.gn3(self.conv3(x)), 0.2)
        x = F.leaky_relu(self.conv4(x), 0.2)
        x = x.view(x.size(0), -1)  # Shape: (B, 512*4*4)
        mu = self.fc1(x)           # Shape: (B, 256)
        logvar = self.fc2(x)       # Shape: (B, 256)
        # Optional: clamp before reparameterization (disabled here; clamping is applied in kld_loss)
        # mu = torch.clamp(mu, min=-10, max=10)
        # logvar = torch.clamp(logvar, min=-10, max=10)
        z = self.reparameterize(mu, logvar)  # Shape: (B, 256)
        return z, mu, logvar

class VAE_Decoder(nn.Module):
    def __init__(self):
        super(VAE_Decoder, self).__init__()

        self.fc = nn.Linear(256, 512 * 4 * 4)  # Expanding latent vector

        self.conv1 = nn.ConvTranspose2d(512, 256, 4, 2, 1)
        self.gn1 = nn.GroupNorm(num_groups=32, num_channels=256)  # 256 channels, 32 groups

        self.conv2 = nn.ConvTranspose2d(256, 128, 4, 2, 1)
        self.gn2 = nn.GroupNorm(num_groups=32, num_channels=128)  # 128 channels, 32 groups

        self.conv3 = nn.ConvTranspose2d(128, 64, 4, 2, 1)
        self.gn3 = nn.GroupNorm(num_groups=32, num_channels=64)  # 64 channels, 32 groups

        self.conv4 = nn.ConvTranspose2d(64, 3, 4, 2, 1)

    def forward(self, z):
        z = self.fc(z)  # Convert (batch_size, 256) -> (batch_size, 512 * 4 * 4)
        z = z.view(z.size(0), 512, 4, 4)  # Reshape to 4D tensor for ConvTranspose

        z = F.leaky_relu(self.gn1(self.conv1(z)), 0.2)
        z = F.leaky_relu(self.gn2(self.conv2(z)), 0.2)
        z = F.leaky_relu(self.gn3(self.conv3(z)), 0.2)
        pre_activation = self.conv4(z)
        print("Pre-activation (min, max, mean):", pre_activation.min().item(), pre_activation.max().item(), pre_activation.mean().item())
        z = torch.tanh(pre_activation)

        return z

class VAE(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = VAE_Encoder()
        self.decoder = VAE_Decoder()
    def forward(self, x):
        x, mu, logvar = self.encoder(x)
        x = self.decoder(x)
        return x, mu, logvar

In [5]:
def reconstruction_loss(outputs, images, reduction='mean'):
    """
    Compute the reconstruction loss (MSE) between outputs and images.
    Args:
        outputs: Tensor of shape (B, C, H, W)
        images: Tensor of shape (B, C, H, W)
        reduction: 'mean' (average over all elements) or 'sum' (sum over all elements)
    Returns:
        Scalar loss tensor
    """
    return F.mse_loss(outputs, images, reduction=reduction)

# KLD Loss
def kld_loss(mu, logvar, reduction='mean'):
    """
    Compute the KL divergence loss.
    Args:
        mu: Tensor of shape (B, latent_dim)
        logvar: Tensor of shape (B, latent_dim)
        reduction: 'mean' (per-sample) or 'sum' (batch-level)
    Returns:
        Per-sample loss if reduction='mean', batch-level loss if reduction='sum'
    """
    mu = torch.clamp(mu, min=-20, max=20)
    logvar = torch.clamp(logvar, min=-20, max=20)  # Prevent numerical instability
    kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1)
    return kld.sum() if reduction == 'sum' else kld.mean()

# Perceptual Loss (with proper normalization)
class PerceptualLoss(nn.Module):
    def __init__(self, layer='relu3_3'):
        super(PerceptualLoss, self).__init__()
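        # `weights=` replaces the deprecated `pretrained=True` argument (torchvision >= 0.13)
        vgg16 = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1).features.to(device).eval()
        for p in vgg16.parameters():
            p.requires_grad_(False)  # VGG is a frozen feature extractor
        self.selected_layer = layer
        self.transform = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

        # Indices of the ReLU modules in torchvision's vgg16().features
        # (each conv at index i is followed by its ReLU at index i + 1)
        self.layer_indices = {
            'relu1_1': 1, 'relu1_2': 3, 'relu2_1': 6, 'relu2_2': 8,
            'relu3_1': 11, 'relu3_2': 13, 'relu3_3': 15,
            'relu4_1': 18, 'relu4_2': 20, 'relu4_3': 22,
            'relu5_1': 25, 'relu5_2': 27, 'relu5_3': 29,
        }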

        if layer not in self.layer_indices:
            raise ValueError(f"Layer {layer} not found in supported VGG layers.")

        self.selected_layer_index = self.layer_indices[layer]
        self.model = vgg16

    def forward(self, generated, target):
        # Denormalize from [-1, 1] to [0, 1]
        generated = (generated + 1) / 2
        target = (target + 1) / 2
        # Apply VGG normalization
        generated = self.transform(generated)
        target = self.transform(target)

        def get_features(image, model, layer_index):
            x = image
            for i, layer in enumerate(model):
                x = layer(x)
                if i == layer_index:
                    return x
            return x

        generated_features = get_features(generated, self.model, self.selected_layer_index)
        target_features = get_features(target, self.model, self.selected_layer_index)

        loss = torch.mean((generated_features - target_features) ** 2)
        return loss

# LPIPS Loss
class LPIPSLoss(nn.Module):
    def __init__(self, net='vgg', version='0.1'):
        super(LPIPSLoss, self).__init__()
        self.lpips_fn = lpips.LPIPS(net=net, version=version).to(device)

    def forward(self, generated, target):
        return torch.mean(self.lpips_fn(generated, target))

# SSIM Loss


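class SSIMLoss(nn.Module):
    def __init__(self, data_range=1.0, size_average=True, win_size=11, win_sigma=1.5, channel=3):
        super(SSIMLoss, self).__init__()
        # Use the constructor arguments instead of hard-coded values
        self.ssim = SSIM(data_range=data_range, size_average=size_average, win_size=win_size,
                         win_sigma=win_sigma, channel=channel).to(device)
    def forward(self, img1, img2):
        # Images are in [-1, 1]; map them to [0, 1] so they match data_range=1.0
        return 1 - self.ssim((img1 + 1) / 2, (img2 + 1) / 2)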
    
# GAN Loss (for the Generator - VAE's Decoder)
class LSGANLoss(nn.Module):
    def __init__(self, label_smoothing=0.2):  # Increased label smoothing
        super(LSGANLoss, self).__init__()
        self.label_smoothing = label_smoothing

    def discriminator_loss(self, real_output, fake_output):
        real_target = 1.0 - self.label_smoothing  # e.g., 0.8
        fake_target = 0.0 + self.label_smoothing  # e.g., 0.2
        real_loss = torch.mean((real_output - real_target) ** 2)
        fake_loss = torch.mean((fake_output - fake_target) ** 2)
        return (real_loss + fake_loss) * 0.5

    def generator_loss(self, fake_output):
        return torch.mean((fake_output - (1.0 - self.label_smoothing)) ** 2)


# --- 2. Build a Combined Criterion Function ---

def combined_criterion(outputs, images, mu, logvar, loss_config, kld_weight,
                      perceptual_criterion=None, lpips_criterion=None, ssim_criterion=None):
    total_loss = 0

    if 'reconstruction' in loss_config and loss_config['reconstruction']['weight'] > 0:
        recon_loss = reconstruction_loss(outputs, images)
        total_loss += loss_config['reconstruction']['weight'] * recon_loss

    if 'kld' in loss_config and loss_config['kld']['weight'] > 0:
        kld = kld_loss(mu, logvar)
        total_loss += kld_weight * kld  # Use dynamic KLD weight

    if 'perceptual' in loss_config and loss_config['perceptual']['weight'] > 0:
        perc_loss = perceptual_criterion(outputs, images)
        total_loss += loss_config['perceptual']['weight'] * perc_loss

    if 'lpips' in loss_config and loss_config['lpips']['weight'] > 0:
        lpips_loss = lpips_criterion(outputs, images)
        total_loss += loss_config['lpips']['weight'] * lpips_loss

    if 'ssim' in loss_config and loss_config['ssim']['weight'] > 0:
        ssim_loss = ssim_criterion(outputs, images)
        total_loss += loss_config['ssim']['weight'] * ssim_loss

    return total_loss

class PatchDiscriminator(nn.Module):
    def __init__(self, input_channels=3):
        super(PatchDiscriminator, self).__init__()
        
        def conv_block(in_channels, out_channels, kernel_size=4, stride=2, padding=1, use_bn=True):
            layers = [
                nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding),
                nn.LeakyReLU(0.2, inplace=True)
            ]
            if use_bn:  # Despite the name, this adds InstanceNorm, not BatchNorm
                layers.insert(1, nn.InstanceNorm2d(out_channels))
            return nn.Sequential(*layers)

        # PatchGAN-style conv stack; for 64x64 inputs the final 4x4 conv reduces it to one score per image
        self.conv1 = conv_block(input_channels, 64, use_bn=False)  # 64x32x32
        self.conv2 = conv_block(64, 128)                         # 128x16x16
        self.conv3 = conv_block(128, 256)                        # 256x8x8
        self.conv4 = conv_block(256, 512)                        # 512x4x4
        self.conv5 = nn.Conv2d(512, 1, 4, 1, 0)                 # 1x1x1 (a single score, not a patch map)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.conv5(x)
        return torch.sigmoid(x)  # Probability that the input is real, shape (B, 1, 1, 1)

In [6]:
# --- 3. Modify the Training Function ---
def train_vae(model, dataloader, valloader, optimizer, loss_config, discriminator=None, 
              discriminator_optimizer=None, num_epochs=20,
              perceptual_criterion=None, lpips_criterion=None, ssim_criterion=None, 
              run_name="default_run"):
    
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    gan_loss_fn = LSGANLoss(label_smoothing=0.2)  # Increased label smoothing
    
    output_root_dir = f"jn_outs/{run_name}"
    gt_dir = os.path.join(output_root_dir, "ground_truth")
    os.makedirs(gt_dir, exist_ok=True)
    os.makedirs(output_root_dir, exist_ok=True)
    to_pil = ToPILImage()
    ground_truth_saved = False

    # Save ground truth images
    with torch.no_grad():
        for batch_idx, (images, _) in enumerate(valloader):
            if not ground_truth_saved:
                images_denorm = denormalize(images).cpu()
                for i in range(min(10, images_denorm.size(0))):
                    img_tensor = images_denorm[i]
                    img_pil = to_pil(img_tensor)
                    img_pil.save(os.path.join(gt_dir, f"gt_image_{batch_idx * valloader.batch_size + i}.png"))
                if (batch_idx + 1) * valloader.batch_size >= 10:
                    ground_truth_saved = True
                    break

    # Training loop
    for epoch in range(num_epochs):
        model.train()
        train_loop = tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}", leave=True)
        
        # Compute dynamic KLD weight for warm-up (linear ramp over the first 20 epochs)
        kld_weight = max(0.00001, min(loss_config['kld']['weight'] * (epoch / 20.0), loss_config['kld']['weight']))  # Minimum KLD weight of 0.00001
        
        # Accumulate losses for averaging over the epoch
        epoch_recon_loss = 0.0
        epoch_kld_loss = 0.0
        epoch_perceptual_loss = 0.0
        epoch_lpips_loss = 0.0
        epoch_ssim_loss = 0.0
        epoch_gan_g_loss = 0.0
        epoch_gan_d_loss = 0.0
        num_batches = 0

        for i, (images, _) in enumerate(train_loop):
            images = images.to(device)
            
            # VAE/Generator forward pass
            optimizer.zero_grad()
            outputs, mu, logvar = model(images)
            print(f"Batch {i+1} - mu (min, max, mean): {mu.min().item():.4f}, {mu.max().item():.4f}, {mu.mean().item():.4f}")
            print(f"Batch {i+1} - logvar (min, max, mean): {logvar.min().item():.4f}, {logvar.max().item():.4f}, {logvar.mean().item():.4f}")
            # Compute individual losses
            recon_loss = reconstruction_loss(outputs, images) if 'reconstruction' in loss_config else 0.0
            kld = kld_loss(mu, logvar) if 'kld' in loss_config else 0.0
            perc_loss = perceptual_criterion(outputs, images) if 'perceptual' in loss_config and perceptual_criterion else 0.0
            lpips_loss = lpips_criterion(outputs, images) if 'lpips' in loss_config and lpips_criterion else 0.0
            ssim_loss = ssim_criterion(outputs, images) if 'ssim' in loss_config and ssim_criterion else 0.0

            # Calculate combined loss excluding GAN
            total_loss = combined_criterion(outputs, images, mu, logvar, loss_config, kld_weight,
                                            perceptual_criterion=perceptual_criterion,
                                            lpips_criterion=lpips_criterion,
                                            ssim_criterion=ssim_criterion)

            # GAN training if enabled
            n_discriminator_steps = 1
            g_gan_loss = 0.0
            d_loss = 0.0
            if 'gan' in loss_config and loss_config['gan']['weight'] > 0 and discriminator is not None:
                for _ in range(n_discriminator_steps):
                    discriminator_optimizer.zero_grad()
                    real_output = discriminator(images)
                    fake_output = discriminator(outputs.detach())
                    d_loss = gan_loss_fn.discriminator_loss(real_output, fake_output)
                    d_loss.backward()
                    torch.nn.utils.clip_grad_norm_(discriminator.parameters(), max_norm=1.0)  # Add gradient clipping
                    discriminator_optimizer.step()

                # Generator GAN loss
                fake_output = discriminator(outputs)
                g_gan_loss = gan_loss_fn.generator_loss(fake_output)
                total_loss += loss_config['gan']['weight'] * g_gan_loss

            total_loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # Add gradient clipping
            optimizer.step()

            # Accumulate losses for epoch averaging
            epoch_recon_loss += recon_loss.item()
            epoch_kld_loss += kld.item()
            epoch_perceptual_loss += perc_loss.item() if perc_loss != 0 else 0.0
            epoch_lpips_loss += lpips_loss.item() if lpips_loss != 0 else 0.0
            epoch_ssim_loss += ssim_loss.item() if ssim_loss != 0 else 0.0
            epoch_gan_g_loss += g_gan_loss.item() if g_gan_loss != 0 else 0.0
            epoch_gan_d_loss += d_loss.item() if d_loss != 0 else 0.0
            num_batches += 1

            # Update tqdm progress bar with current batch losses (all losses are per-sample, normalized by batch size)
            train_loop.set_postfix({
                'Recon': f'{recon_loss.item():.4f}',
                'KLD': f'{kld.item():.4f}',
                'Perceptual': f'{perc_loss.item():.4f}' if perc_loss != 0 else 'N/A',
                'LPIPS': f'{lpips_loss.item():.4f}' if lpips_loss != 0 else 'N/A',
                'SSIM': f'{ssim_loss.item():.4f}' if ssim_loss != 0 else 'N/A',
                'GAN_G': f'{g_gan_loss.item():.4f}' if g_gan_loss != 0 else 'N/A',
                'GAN_D': f'{d_loss.item():.4f}' if d_loss != 0 else 'N/A',
                'Total': f'{total_loss.item():.4f}',
                'KLD_Weight': f'{kld_weight:.4f}'  # Display the current KLD weight
            })

        # Compute average losses for the epoch (all averages are per-sample)
        avg_recon_loss = epoch_recon_loss / num_batches
        avg_kld_loss = epoch_kld_loss / num_batches
        avg_perceptual_loss = epoch_perceptual_loss / num_batches
        avg_lpips_loss = epoch_lpips_loss / num_batches
        avg_ssim_loss = epoch_ssim_loss / num_batches
        avg_gan_g_loss = epoch_gan_g_loss / num_batches
        avg_gan_d_loss = epoch_gan_d_loss / num_batches

        # Print average losses for the epoch
        print(f"\nEpoch {epoch+1}/{num_epochs} - Average Losses:")
        print(f"  Reconstruction: {avg_recon_loss:.4f}")
        print(f"  KLD: {avg_kld_loss:.4f}")
        print(f"  Perceptual: {avg_perceptual_loss:.4f}")
        print(f"  LPIPS: {avg_lpips_loss:.4f}")
        print(f"  SSIM: {avg_ssim_loss:.4f}")
        print(f"  GAN (Generator): {avg_gan_g_loss:.4f}")
        print(f"  GAN (Discriminator): {avg_gan_d_loss:.4f}\n")

        # Validation and saving
        model.eval()
        if epoch % 2 == 0:
            with torch.no_grad():
                epoch_output_dir = os.path.join(output_root_dir, f"epoch_{epoch}")
                os.makedirs(epoch_output_dir, exist_ok=True)
                for i, (images, _) in enumerate(valloader):
                    images = images.to(device)
                    outputs, _, _ = model(images)
                    outputs_denorm = denormalize(outputs).cpu()
                    for j in range(outputs_denorm.size(0)):
                        img_tensor = outputs_denorm[j]
                        img_pil = to_pil(img_tensor)
                        img_pil.save(os.path.join(epoch_output_dir, f"prediction_{i * valloader.batch_size + j}.png"))

    return None

In [None]:
device = "cuda"

input_channels = 3
image_size = 64
input_dim = input_channels * image_size * image_size
vae = VAE().to(device)
vae_optimizer = torch.optim.Adam(vae.parameters(), lr=0.001)

discriminator = PatchDiscriminator(input_channels=3).to(device)
discriminator_optimizer = torch.optim.Adam(discriminator.parameters(), lr=5e-6)  # Lowered learning rate

perceptual_criterion = PerceptualLoss(layer='relu3_3').to(device)
lpips_criterion = LPIPSLoss(net='vgg').to(device)
ssim_criterion = SSIMLoss(data_range=1.0).to(device)
dataloader, valloader, _, _ = basic_train_config(vae, 100000, 10, 256)

# Define the loss configuration with increased reconstruction weight
loss_config = {
    'reconstruction': {'weight': 0.5},  # Increased from 0.1
    'kld': {'weight': 0.05},
    'perceptual': {'weight': 0.5},
    'lpips': {'weight': 0.5},
    'ssim': {'weight': 0.2},
    'gan': {'weight': 1.0},
}


# Train the VAE with the combined criterion
train_vae(vae, dataloader, valloader, vae_optimizer, loss_config, discriminator, discriminator_optimizer,
          num_epochs=25,
          perceptual_criterion=perceptual_criterion,
          lpips_criterion=lpips_criterion,
          ssim_criterion=ssim_criterion,
          run_name="25epoch_basic_vae_loss_balanced_stable_kld_schedule")

In [11]:
num_epochs = 25  # Example value; set to the number of epochs the run was trained for

# Predictions are saved every 2 epochs (0, 2, 4, ...), i.e. ceil(num_epochs / 2) snapshots
num_saved_epochs = (num_epochs + 1) // 2
print(num_saved_epochs)
# Valid slider indices are 0 .. num_saved_epochs - 1 (index i shows epoch_{2 * i})
interact(visualize_epoch, epoch_idx=(0, num_saved_epochs - 1, 1), run_name="25epoch_basic_vae_loss_balanced_stable_kld_schedule")


As we can see, there is still a problem: grid-like artifacts. This is a common issue known as the "checkerboard pattern", and it is typically caused by transposed convolutions (nn.ConvTranspose2d) in the decoder, combined with the adversarial training dynamics introduced by the GAN loss. As the logs above show, the discriminator loss is very low, which means the discriminator can easily distinguish real images from generated ones. This imbalance makes the generator (decoder) work harder to fool the discriminator, which often amplifies artifacts such as the checkerboard pattern. First, we will replace nn.ConvTranspose2d with nearest-neighbor upsampling (nn.Upsample) followed by a regular convolution (nn.Conv2d), a common fix, to see whether the transposed convolutions are the main cause. After that, we will try to balance the adversarial training.
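A common explanation for checkerboard artifacts (popularized by the Distill article "Deconvolution and Checkerboard Artifacts") is uneven overlap: when the kernel size is not divisible by the stride, some output pixels receive more kernel contributions than others. The hypothetical helpers below feed an all-ones image through layers with all-ones weights and count the contributions each output pixel receives.

```python
import torch
import torch.nn as nn

def overlap_map(kernel_size, stride=2, padding=1, size=8):
    """Contributions per output pixel of a ConvTranspose2d with all-ones weights."""
    deconv = nn.ConvTranspose2d(1, 1, kernel_size, stride, padding, bias=False)
    nn.init.ones_(deconv.weight)  # every kernel tap contributes exactly 1
    with torch.no_grad():
        return deconv(torch.ones(1, 1, size, size))[0, 0]

def upsample_conv_map(size=8):
    """Same count for nearest-neighbor upsampling followed by a 3x3 conv."""
    up = nn.Upsample(scale_factor=2, mode="nearest")
    conv = nn.Conv2d(1, 1, kernel_size=3, stride=1, padding=1, bias=False)
    nn.init.ones_(conv.weight)
    with torch.no_grad():
        return conv(up(torch.ones(1, 1, size, size)))[0, 0]

# Look only at the interior to ignore border effects
print("k=3, s=2:", torch.unique(overlap_map(3)[1:-1, 1:-1]))
print("k=4, s=2:", torch.unique(overlap_map(4)[1:-1, 1:-1]))
print("upsample+conv:", torch.unique(upsample_conv_map()[1:-1, 1:-1]))
```

With kernel_size=3 and stride=2 the interior alternates between 1, 2, and 4 contributions (a checkerboard), while the decoder's kernel_size=4, stride=2 and the upsample-plus-conv variant are both uniform in the interior. This suggests that with kernel 4, any checkerboard has to come from the learned weights and the training dynamics rather than from overlap alone, which the Upsample experiment below helps to test.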

In [12]:
class VAE_Decoder(nn.Module):
    def __init__(self):
        super(VAE_Decoder, self).__init__()

        self.fc = nn.Linear(256, 512 * 4 * 4)  # Expanding latent vector

        # Replace ConvTranspose2d with Upsample + Conv2d
        self.up1 = nn.Upsample(scale_factor=2, mode='nearest')  # 4x4 -> 8x8
        self.conv1 = nn.Conv2d(512, 256, kernel_size=3, stride=1, padding=1)
        self.gn1 = nn.GroupNorm(num_groups=32, num_channels=256)

        self.up2 = nn.Upsample(scale_factor=2, mode='nearest')  # 8x8 -> 16x16
        self.conv2 = nn.Conv2d(256, 128, kernel_size=3, stride=1, padding=1)
        self.gn2 = nn.GroupNorm(num_groups=32, num_channels=128)

        self.up3 = nn.Upsample(scale_factor=2, mode='nearest')  # 16x16 -> 32x32
        self.conv3 = nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1)
        self.gn3 = nn.GroupNorm(num_groups=32, num_channels=64)

        self.up4 = nn.Upsample(scale_factor=2, mode='nearest')  # 32x32 -> 64x64
        self.conv4 = nn.Conv2d(64, 3, kernel_size=3, stride=1, padding=1)

    def forward(self, z):
        z = self.fc(z)  # (batch_size, 256) -> (batch_size, 512 * 4 * 4)
        z = z.view(z.size(0), 512, 4, 4)  # Reshape to a 4D feature map for the upsampling stages

        z = F.leaky_relu(self.gn1(self.conv1(self.up1(z))), 0.2)
        z = F.leaky_relu(self.gn2(self.conv2(self.up2(z))), 0.2)
        z = F.leaky_relu(self.gn3(self.conv3(self.up3(z))), 0.2)
        pre_activation = self.conv4(self.up4(z))
        print("Pre-activation (min, max, mean):", pre_activation.min().item(), pre_activation.max().item(), pre_activation.mean().item())
        z = torch.tanh(pre_activation)

        return z

In [None]:
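# Redefining VAE_Decoder does not change the existing `vae` object, so rebuild the model and its optimizer.
# Assumption: the VAE class defined earlier builds VAE_Decoder() and takes no required arguments;
# lr=0.001 is the pre-reduction value mentioned in a later optimizer cell.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
vae = VAE().to(device)
vae_optimizer = torch.optim.Adam(vae.parameters(), lr=0.001)
train_vae(vae, dataloader, valloader, vae_optimizer, loss_config, discriminator, discriminator_optimizer,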
          num_epochs=25,
          perceptual_criterion=perceptual_criterion,
          lpips_criterion=lpips_criterion,
          ssim_criterion=ssim_criterion,
          run_name="25epoch_basic_vae_loss_balanced_stable_kld_schedule_updated_vae")

In [None]:
num_epochs = 6  # number of trained epochs to visualize

# train_vae saves predictions every 2 epochs (epoch % 2 == 0), i.e. ceil(num_epochs / 2) snapshots
num_saved_epochs = (num_epochs + 1) // 2
print(num_saved_epochs)
interact(visualize_epoch, epoch_idx=(0, num_saved_epochs, 1), run_name="25epoch_basic_vae_loss_balanced_stable_kld_schedule_updated_vae")

3



From the first two saved epochs, the artifacts appear reduced but are still visible. Next, we switch the decoder's GroupNorm layers to LayerNorm and add a final convolution after the last upsampling step to smooth the output. This extra layer is a small 3x3 convolution with stride 1, which may reduce residual high-frequency artifacts like the ones we are seeing.
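
The switch from GroupNorm to LayerNorm is smaller than it may look. `nn.LayerNorm([C, H, W])` normalizes each sample over all of (C, H, W), which is the same statistic as `nn.GroupNorm` with a single group; the two differ in their affine parameters (one per element versus one per channel), and LayerNorm's shape hard-codes the spatial size. A small check on a random 64-channel 8x8 tensor (sizes chosen for illustration):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(2, 64, 8, 8)

ln = nn.LayerNorm([64, 8, 8])  # stats over (C, H, W) per sample; affine weight per element
gn_one = nn.GroupNorm(1, 64)   # one group: also stats over (C, H, W); affine weight per channel
gn_32 = nn.GroupNorm(32, 64)   # 32 groups of 2 channels: separate stats per group

with torch.no_grad():
    y_ln, y_gn_one, y_gn_32 = ln(x), gn_one(x), gn_32(x)

def n_params(module):
    return sum(p.numel() for p in module.parameters())

# At initialization (weight=1, bias=0) LayerNorm and single-group GroupNorm coincide
print(torch.allclose(y_ln, y_gn_one, atol=1e-5))
print(torch.allclose(y_ln, y_gn_32, atol=1e-5))
print(n_params(ln), n_params(gn_one))
```

By the same count, `nn.LayerNorm([64, 64, 64])` in the decoder above carries 2 x 64 x 64 x 64 = 524,288 affine parameters, and the fixed shapes mean this decoder only works for 64x64 outputs.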

In [None]:
class VAE_Decoder(nn.Module):
    def __init__(self):
        super(VAE_Decoder, self).__init__()

        self.fc = nn.Linear(256, 512 * 4 * 4)  # Expanding latent vector

        self.up1 = nn.Upsample(scale_factor=2, mode='nearest')  # 4x4 -> 8x8
        self.conv1 = nn.Conv2d(512, 256, kernel_size=3, stride=1, padding=1)
        self.ln1 = nn.LayerNorm([256, 8, 8])  # Normalize over (C, H, W)

        self.up2 = nn.Upsample(scale_factor=2, mode='nearest')  # 8x8 -> 16x16
        self.conv2 = nn.Conv2d(256, 128, kernel_size=3, stride=1, padding=1)
        self.ln2 = nn.LayerNorm([128, 16, 16])

        self.up3 = nn.Upsample(scale_factor=2, mode='nearest')  # 16x16 -> 32x32
        self.conv3 = nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1)
        self.ln3 = nn.LayerNorm([64, 32, 32])

        self.up4 = nn.Upsample(scale_factor=2, mode='nearest')  # 32x32 -> 64x64
        self.conv4 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.ln4 = nn.LayerNorm([64, 64, 64])
        self.final_conv = nn.Conv2d(64, 3, kernel_size=3, stride=1, padding=1)

    def forward(self, z):
        z = self.fc(z)
        z = z.view(z.size(0), 512, 4, 4)

        z = F.leaky_relu(self.ln1(self.conv1(self.up1(z))), 0.2)
        z = F.leaky_relu(self.ln2(self.conv2(self.up2(z))), 0.2)
        z = F.leaky_relu(self.ln3(self.conv3(self.up3(z))), 0.2)
        z = F.leaky_relu(self.ln4(self.conv4(self.up4(z))), 0.2)
        pre_activation = self.final_conv(z)
        print("Pre-activation (min, max, mean):", pre_activation.min().item(), pre_activation.max().item(), pre_activation.mean().item())
        z = torch.tanh(pre_activation)

        return z

In [None]:
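# Redefining VAE_Decoder does not change the existing `vae` object, so rebuild the model and its optimizer.
# Assumption: the VAE class defined earlier builds VAE_Decoder() and takes no required arguments;
# lr=0.001 is the pre-reduction value mentioned in a later optimizer cell.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
vae = VAE().to(device)
vae_optimizer = torch.optim.Adam(vae.parameters(), lr=0.001)
train_vae(vae, dataloader, valloader, vae_optimizer, loss_config, discriminator, discriminator_optimizer,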
          num_epochs=25,
          perceptual_criterion=perceptual_criterion,
          lpips_criterion=lpips_criterion,
          ssim_criterion=ssim_criterion,
          run_name="25epoch_basic_vae_loss_balanced_stable_kld_schedule_updated_vae_decoder_wlayernorm")

In [20]:
num_epochs = 6  # number of trained epochs to visualize

# train_vae saves predictions every 2 epochs (epoch % 2 == 0), i.e. ceil(num_epochs / 2) snapshots
num_saved_epochs = (num_epochs + 1) // 2
print(num_saved_epochs)
interact(visualize_epoch, epoch_idx=(0, num_saved_epochs, 1), run_name="25epoch_basic_vae_loss_balanced_stable_kld_schedule_updated_vae_decoder_wlayernorm")

3



Let's rebalance the adversarial (GAN) part of the training:

- Increase label smoothing.
- Add a gradient penalty to the discriminator. Quick note: in other resources, the discriminator is often called the critic.
- Decrease the discriminator's learning rate again (the VAE learning rate is lowered as well).
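
The gradient penalty (from WGAN-GP) pushes the norm of the discriminator's input gradient, measured at random interpolations between real and reconstructed images, toward 1. A quick sanity check, assuming a standalone hypothetical helper `gradient_penalty` that follows the same logic as `compute_gradient_penalty` below, and a linear critic whose gradient norm we control directly:

```python
import torch
import torch.nn as nn

def gradient_penalty(critic, real, fake):
    # One interpolation weight per sample, broadcast over (C, H, W)
    alpha = torch.rand(real.size(0), 1, 1, 1, device=real.device)
    x_hat = (alpha * real + (1 - alpha) * fake).requires_grad_(True)
    out = critic(x_hat)
    # Gradient of the summed outputs = each sample's gradient w.r.t. its own input
    (grads,) = torch.autograd.grad(outputs=out, inputs=x_hat,
                                   grad_outputs=torch.ones_like(out), create_graph=True)
    grads = grads.view(grads.size(0), -1)
    # Penalize deviation of each per-sample gradient norm from 1
    return ((grads.norm(2, dim=1) - 1) ** 2).mean()

torch.manual_seed(0)
real = torch.randn(4, 3, 8, 8)
fake = torch.randn(4, 3, 8, 8)

# Linear critic D(x) = w . x: its input gradient is w for every x
critic = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 1, bias=False))
with torch.no_grad():
    critic[1].weight /= critic[1].weight.norm()  # ||w|| = 1
gp_unit = gradient_penalty(critic, real, fake).item()

with torch.no_grad():
    critic[1].weight *= 3.0  # ||w|| = 3
gp_scaled = gradient_penalty(critic, real, fake).item()

print(f"||w||=1: {gp_unit:.6f}   ||w||=3: {gp_scaled:.6f}")
```

A unit-norm critic is left alone, while a critic with gradient norm 3 pays (3 - 1)^2 = 4; in `train_vae` this term is scaled by 10.0 and added to the LSGAN discriminator loss.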

In [15]:
class LSGANLoss(nn.Module):
    def __init__(self, label_smoothing=0.3):  # Increased from 0.2
        super(LSGANLoss, self).__init__()
        self.label_smoothing = label_smoothing

    def discriminator_loss(self, real_output, fake_output):
        real_target = 1.0 - self.label_smoothing  # e.g., 0.7
        fake_target = 0.0 + self.label_smoothing  # e.g., 0.3
        real_loss = torch.mean((real_output - real_target) ** 2)
        fake_loss = torch.mean((fake_output - fake_target) ** 2)
        return (real_loss + fake_loss) * 0.5

    def generator_loss(self, fake_output):
        return torch.mean((fake_output - (1.0 - self.label_smoothing)) ** 2)
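
With `label_smoothing=0.3`, the real and fake targets become 0.7 and 0.3, so a discriminator that is fully confident (real -> 1, fake -> 0) is no longer at its optimum. A small worked check, re-implementing the same formula as `LSGANLoss.discriminator_loss` in a standalone hypothetical helper `lsgan_d_loss`:

```python
import torch

def lsgan_d_loss(real_out, fake_out, smoothing):
    # Same formula as LSGANLoss.discriminator_loss above
    real_target = 1.0 - smoothing
    fake_target = 0.0 + smoothing
    return 0.5 * (torch.mean((real_out - real_target) ** 2) + torch.mean((fake_out - fake_target) ** 2))

confident_real, confident_fake = torch.ones(8), torch.zeros(8)        # D outputs 1 / 0
smoothed_real, smoothed_fake = torch.full((8,), 0.7), torch.full((8,), 0.3)

loss_confident = lsgan_d_loss(confident_real, confident_fake, smoothing=0.3).item()
loss_smoothed = lsgan_d_loss(smoothed_real, smoothed_fake, smoothing=0.3).item()
loss_no_smoothing = lsgan_d_loss(confident_real, confident_fake, smoothing=0.0).item()
print(loss_confident, loss_smoothed, loss_no_smoothing)
```

Without smoothing the confident discriminator has zero loss; with 0.3 smoothing it pays 0.5 x (0.3^2 + 0.3^2) = 0.09, so the loss itself discourages overconfident outputs.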

In [None]:
vae_optimizer = torch.optim.Adam(vae.parameters(), lr=0.0005)  # Reduced from 0.001
discriminator_optimizer = torch.optim.Adam(discriminator.parameters(), lr=1e-6)  # Reduced from 5e-6

In [16]:
def compute_gradient_penalty(discriminator, real_samples, fake_samples):
    batch_size = real_samples.size(0)
    alpha = torch.rand(batch_size, 1, 1, 1, device=real_samples.device)
    interpolates = (alpha * real_samples + (1 - alpha) * fake_samples).requires_grad_(True)
    d_interpolates = discriminator(interpolates)
    fake = torch.ones(d_interpolates.size(), device=real_samples.device, requires_grad=False)
    gradients = torch.autograd.grad(
        outputs=d_interpolates,
        inputs=interpolates,
        grad_outputs=fake,
        create_graph=True,
        retain_graph=True,
        only_inputs=True,
    )[0]
    gradients = gradients.view(batch_size, -1)
    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
    return gradient_penalty

# Update the training loop to include gradient penalty
def train_vae(model, dataloader, valloader, optimizer, loss_config, discriminator=None, 
              discriminator_optimizer=None, num_epochs=20,
              perceptual_criterion=None, lpips_criterion=None, ssim_criterion=None, 
              run_name="default_run"):
    
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    gan_loss_fn = LSGANLoss(label_smoothing=0.3)  # Increased label smoothing
    
    output_root_dir = f"jn_outs/{run_name}"
    gt_dir = os.path.join(output_root_dir, "ground_truth")
    os.makedirs(gt_dir, exist_ok=True)
    os.makedirs(output_root_dir, exist_ok=True)
    to_pil = ToPILImage()
    ground_truth_saved = False

    with torch.no_grad():
        for batch_idx, (images, _) in enumerate(valloader):
            if not ground_truth_saved:
                images_denorm = denormalize(images).cpu()
                for i in range(min(10, images_denorm.size(0))):
                    img_tensor = images_denorm[i]
                    img_pil = to_pil(img_tensor)
                    img_pil.save(os.path.join(gt_dir, f"gt_image_{batch_idx * valloader.batch_size + i}.png"))
                if (batch_idx + 1) * valloader.batch_size >= 10:
                    ground_truth_saved = True
                    break

    for epoch in range(num_epochs):
        model.train()
        train_loop = tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}", leave=True)
        
        kld_weight = max(0.00001, min(loss_config['kld']['weight'] * (epoch / 20.0), loss_config['kld']['weight']))
        
        epoch_recon_loss = 0.0
        epoch_kld_loss = 0.0
        epoch_perceptual_loss = 0.0
        epoch_lpips_loss = 0.0
        epoch_ssim_loss = 0.0
        epoch_gan_g_loss = 0.0
        epoch_gan_d_loss = 0.0
        num_batches = 0

        for i, (images, _) in enumerate(train_loop):
            images = images.to(device)
            
            optimizer.zero_grad()
            outputs, mu, logvar = model(images)
            print(f"Batch {i+1} - mu (min, max, mean): {mu.min().item():.4f}, {mu.max().item():.4f}, {mu.mean().item():.4f}")
            print(f"Batch {i+1} - logvar (min, max, mean): {logvar.min().item():.4f}, {logvar.max().item():.4f}, {logvar.mean().item():.4f}")

            recon_loss = reconstruction_loss(outputs, images) if 'reconstruction' in loss_config else 0.0
            kld = kld_loss(mu, logvar) if 'kld' in loss_config else 0.0
            perc_loss = perceptual_criterion(outputs, images) if 'perceptual' in loss_config and perceptual_criterion else 0.0
            lpips_loss = lpips_criterion(outputs, images) if 'lpips' in loss_config and lpips_criterion else 0.0
            ssim_loss = ssim_criterion(outputs, images) if 'ssim' in loss_config and ssim_criterion else 0.0

            total_loss = combined_criterion(outputs, images, mu, logvar, loss_config, kld_weight,
                                            perceptual_criterion=perceptual_criterion,
                                            lpips_criterion=lpips_criterion,
                                            ssim_criterion=ssim_criterion)

            n_discriminator_steps = 1
            g_gan_loss = 0.0
            d_loss = 0.0
            if 'gan' in loss_config and loss_config['gan']['weight'] > 0 and discriminator is not None:
                for _ in range(n_discriminator_steps):
                    discriminator_optimizer.zero_grad()
                    real_output = discriminator(images)
                    fake_output = discriminator(outputs.detach())
                    d_loss = gan_loss_fn.discriminator_loss(real_output, fake_output)
                    # Add gradient penalty
                    gp = compute_gradient_penalty(discriminator, images, outputs.detach())
                    d_loss += 10.0 * gp  # Weight of 10.0 for gradient penalty
                    d_loss.backward()
                    torch.nn.utils.clip_grad_norm_(discriminator.parameters(), max_norm=1.0)
                    discriminator_optimizer.step()

                fake_output = discriminator(outputs)
                g_gan_loss = gan_loss_fn.generator_loss(fake_output)
                total_loss += loss_config['gan']['weight'] * g_gan_loss

            total_loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            epoch_recon_loss += recon_loss.item()
            epoch_kld_loss += kld.item()
            epoch_perceptual_loss += perc_loss.item() if perc_loss != 0 else 0.0
            epoch_lpips_loss += lpips_loss.item() if lpips_loss != 0 else 0.0
            epoch_ssim_loss += ssim_loss.item() if ssim_loss != 0 else 0.0
            epoch_gan_g_loss += g_gan_loss.item() if g_gan_loss != 0 else 0.0
            epoch_gan_d_loss += d_loss.item() if d_loss != 0 else 0.0
            num_batches += 1

            train_loop.set_postfix({
                'Recon': f'{recon_loss.item():.4f}',
                'KLD': f'{kld.item():.4f}',
                'Perceptual': f'{perc_loss.item():.4f}' if perc_loss != 0 else 'N/A',
                'LPIPS': f'{lpips_loss.item():.4f}' if lpips_loss != 0 else 'N/A',
                'SSIM': f'{ssim_loss.item():.4f}' if ssim_loss != 0 else 'N/A',
                'GAN_G': f'{g_gan_loss.item():.4f}' if g_gan_loss != 0 else 'N/A',
                'GAN_D': f'{d_loss.item():.4f}' if d_loss != 0 else 'N/A',
                'Total': f'{total_loss.item():.4f}',
                'KLD_Weight': f'{kld_weight:.4f}'
            })

        avg_recon_loss = epoch_recon_loss / num_batches
        avg_kld_loss = epoch_kld_loss / num_batches
        avg_perceptual_loss = epoch_perceptual_loss / num_batches
        avg_lpips_loss = epoch_lpips_loss / num_batches
        avg_ssim_loss = epoch_ssim_loss / num_batches
        avg_gan_g_loss = epoch_gan_g_loss / num_batches
        avg_gan_d_loss = epoch_gan_d_loss / num_batches

        print(f"\nEpoch {epoch+1}/{num_epochs} - Average Losses:")
        print(f"  Reconstruction: {avg_recon_loss:.4f}")
        print(f"  KLD: {avg_kld_loss:.4f}")
        print(f"  Perceptual: {avg_perceptual_loss:.4f}")
        print(f"  LPIPS: {avg_lpips_loss:.4f}")
        print(f"  SSIM: {avg_ssim_loss:.4f}")
        print(f"  GAN (Generator): {avg_gan_g_loss:.4f}")
        print(f"  GAN (Discriminator): {avg_gan_d_loss:.4f}\n")

        model.eval()
        if epoch % 2 == 0:
            with torch.no_grad():
                epoch_output_dir = os.path.join(output_root_dir, f"epoch_{epoch}")
                os.makedirs(epoch_output_dir, exist_ok=True)
                for i, (images, _) in enumerate(valloader):
                    images = images.to(device)
                    outputs, _, _ = model(images)
                    outputs_denorm = denormalize(outputs).cpu()
                    for j in range(outputs_denorm.size(0)):
                        img_tensor = outputs_denorm[j]
                        img_pil = to_pil(img_tensor)
                        img_pil.save(os.path.join(epoch_output_dir, f"prediction_{i * valloader.batch_size + j}.png"))

    return None
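
The KL weight in `train_vae` follows a linear warm-up: it grows with the epoch index, is capped at the configured weight, and never drops below 1e-5. The same expression, pulled out into a hypothetical helper `kld_weight_at` for inspection:

```python
def kld_weight_at(epoch, target_weight, warmup_epochs=20.0, floor=1e-5):
    # Same expression as in train_vae: linear ramp to target_weight over warmup_epochs, with a floor
    return max(floor, min(target_weight * (epoch / warmup_epochs), target_weight))

for epoch in (0, 5, 10, 20, 24):
    print(epoch, kld_weight_at(epoch, 0.05))
```

With the notebook's `'kld': {'weight': 0.05}`, the weight is 1e-5 at epoch 0, 0.025 at epoch 10, and stays at 0.05 from epoch 20 onward, so early epochs are dominated by the reconstruction and perceptual terms.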

In [None]:
train_vae(vae, dataloader, valloader, vae_optimizer, loss_config, discriminator, discriminator_optimizer,
          num_epochs=25,
          perceptual_criterion=perceptual_criterion,
          lpips_criterion=lpips_criterion,
          ssim_criterion=ssim_criterion,
          run_name="25epoch_basic_vae_loss_balanced_stable_kld_schedule_no_checkerboard")

In [None]:
num_epochs = 20  # number of trained epochs to visualize

# train_vae saves predictions every 2 epochs (epoch % 2 == 0), i.e. ceil(num_epochs / 2) snapshots
num_saved_epochs = (num_epochs + 1) // 2
print(num_saved_epochs)
interact(visualize_epoch, epoch_idx=(0, num_saved_epochs, 1), run_name="25epoch_basic_vae_loss_balanced_stable_kld_schedule_no_checkerboard")

10



As the results above show, the grid-like artifacts are still present. Let's drop the GAN loss and see what happens.

In [None]:
loss_config = {
    'reconstruction': {'weight': 0.5},
    'kld': {'weight': 0.05},
    'perceptual': {'weight': 0.5},
    'lpips': {'weight': 0.5},
    'ssim': {'weight': 0.2},
    #'gan': {'weight': 0.0},
}

train_vae(vae, dataloader, valloader, vae_optimizer, loss_config, discriminator, discriminator_optimizer,
          num_epochs=25,
          perceptual_criterion=perceptual_criterion,
          lpips_criterion=lpips_criterion,
          ssim_criterion=ssim_criterion,
          run_name="25epoch_basic_vae_loss_balanced_stable_kld_schedule_no_checkerboard_no_gan")

In [None]:
num_epochs = 6  # number of trained epochs to visualize

# train_vae saves predictions every 2 epochs (epoch % 2 == 0), i.e. ceil(num_epochs / 2) snapshots
num_saved_epochs = (num_epochs + 1) // 2
print(num_saved_epochs)
interact(visualize_epoch, epoch_idx=(0, num_saved_epochs, 1), run_name="25epoch_basic_vae_loss_balanced_stable_kld_schedule_no_checkerboard_no_gan")

3



The problem persists. It is probably caused by the limited capacity of the model and by information lost as features pass through the network. For this reason, we will scale the model up and add skip (residual) connections.
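
One intuition for why skip connections help: an identity path carries the signal, and its gradient, past each block even when the block's own weights are small. The toy comparison below uses hypothetical 20-layer tanh MLPs with deliberately small weights, not the notebook's convolutional blocks, and measures the gradient that reaches the input.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
depth, width = 20, 16

def make_layers():
    layers = nn.ModuleList([nn.Linear(width, width) for _ in range(depth)])
    for layer in layers:
        nn.init.normal_(layer.weight, std=0.05)  # deliberately small weights
        nn.init.zeros_(layer.bias)
    return layers

def plain_forward(layers, x):
    for layer in layers:
        x = torch.tanh(layer(x))  # signal must pass through every layer
    return x

def residual_forward(layers, x):
    for layer in layers:
        x = x + torch.tanh(layer(x))  # identity skip: x is carried forward unchanged
    return x

def input_grad_norm(forward, layers):
    x = torch.randn(1, width, requires_grad=True)
    forward(layers, x).sum().backward()
    return x.grad.norm().item()

layers = make_layers()
g_plain = input_grad_norm(plain_forward, layers)
g_res = input_grad_norm(residual_forward, layers)
print(f"plain: {g_plain:.2e}  residual: {g_res:.2e}")
```

The plain stack's input gradient shrinks by many orders of magnitude, while the residual stack keeps it on the order of one, because each residual block's Jacobian is the identity plus a small correction. `ResnetBlock` below uses the same `x + h` form, with a 1x1 convolution (`nin_shortcut`) on the skip path when the channel count changes.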

In [None]:
def swish(x):
    return x * torch.sigmoid(x)
class ResnetBlock(nn.Module):
    def __init__(self, in_channels: int, out_channels: int = None, num_groups: int = 1):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=out_channels, eps=1e-6, affine=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        if self.in_channels != self.out_channels:
            self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
        else:
            self.nin_shortcut = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = self.norm1(x)
        h = swish(h)
        h = self.conv1(h)
        h = self.norm2(h)
        h = swish(h)
        h = self.conv2(h)
        if self.nin_shortcut is not None:
            x = self.nin_shortcut(x)
        return x + h

class Downsample(nn.Module):
    def __init__(self, in_channels: int):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        pad = (0, 1, 0, 1)
        x = nn.functional.pad(x, pad, mode="constant", value=0)
        x = self.conv(x)
        return x

class Upsample(nn.Module):
    def __init__(self, in_channels: int):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
        x = self.conv(x)
        return x
class VAE_Encoder(nn.Module):
    def __init__(self, base_channels=64, latent_dim=256):
        super(VAE_Encoder, self).__init__()
        
        # Initial convolution to map input channels to base_channels
        self.init_conv = nn.Conv2d(3, base_channels, kernel_size=3, stride=1, padding=1)
        
        # Stage 1: 64x64 -> 32x32, 64 channels
        self.res1 = ResnetBlock(base_channels, base_channels, num_groups=32)
        self.down1 = Downsample(base_channels)
        
        # Stage 2: 32x32 -> 16x16, 128 channels
        self.res2 = ResnetBlock(base_channels, base_channels * 2, num_groups=32)
        self.down2 = Downsample(base_channels * 2)
        
        # Stage 3: 16x16 -> 8x8, 256 channels
        self.res3 = ResnetBlock(base_channels * 2, base_channels * 4, num_groups=32)
        self.down3 = Downsample(base_channels * 4)
        
        # Stage 4: 8x8 -> 4x4, 512 channels
        self.res4 = ResnetBlock(base_channels * 4, base_channels * 8, num_groups=32)
        self.down4 = Downsample(base_channels * 8)
        
        # Final ResnetBlock at 4x4 resolution
        self.res5 = ResnetBlock(base_channels * 8, base_channels * 8, num_groups=32)
        
        # Flatten and map to latent space
        self.fc_mu = nn.Linear((base_channels * 8) * 4 * 4, latent_dim)
        self.fc_logvar = nn.Linear((base_channels * 8) * 4 * 4, latent_dim)

        # Initialize weights
        nn.init.normal_(self.fc_mu.weight, mean=0.0, std=0.02)
        nn.init.normal_(self.fc_logvar.weight, mean=0.0, std=0.02)
        nn.init.zeros_(self.fc_mu.bias)
        nn.init.zeros_(self.fc_logvar.bias)

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x):
        # Input: (B, 3, 64, 64)
        x = self.init_conv(x)  # (B, 64, 64, 64)
        
        x = self.res1(x)  # (B, 64, 64, 64)
        x = self.down1(x)  # (B, 64, 32, 32)
        
        x = self.res2(x)  # (B, 128, 32, 32)
        x = self.down2(x)  # (B, 128, 16, 16)
        
        x = self.res3(x)  # (B, 256, 16, 16)
        x = self.down3(x)  # (B, 256, 8, 8)
        
        x = self.res4(x)  # (B, 512, 8, 8)
        x = self.down4(x)  # (B, 512, 4, 4)
        
        x = self.res5(x)  # (B, 512, 4, 4)
        
        # Flatten
        x = x.view(x.size(0), -1)  # (B, 512*4*4)
        
        # Predict mu and logvar
        mu = self.fc_mu(x)  # (B, 256)
        logvar = self.fc_logvar(x)  # (B, 256)
        z = self.reparameterize(mu, logvar)  # (B, 256)
        
        return z, mu, logvar
    
class VAE_Decoder(nn.Module):
    def __init__(self, base_channels=64, latent_dim=256):
        super(VAE_Decoder, self).__init__()

        # Map latent space to initial feature map
        self.fc = nn.Linear(latent_dim, (base_channels * 8) * 4 * 4)
        
        # Stage 1: 4x4 -> 4x4, 512 channels
        self.res1 = ResnetBlock(base_channels * 8, base_channels * 8, num_groups=32)
        
        # Stage 2: 4x4 -> 8x8, 512 -> 256 channels
        self.up2 = Upsample(base_channels * 8)
        self.res2 = ResnetBlock(base_channels * 8, base_channels * 4, num_groups=32)
        
        # Stage 3: 8x8 -> 16x16, 256 -> 128 channels
        self.up3 = Upsample(base_channels * 4)
        self.res3 = ResnetBlock(base_channels * 4, base_channels * 2, num_groups=32)
        
        # Stage 4: 16x16 -> 32x32, 128 -> 64 channels
        self.up4 = Upsample(base_channels * 2)
        self.res4 = ResnetBlock(base_channels * 2, base_channels, num_groups=32)
        
        # Stage 5: 32x32 -> 64x64, 64 channels
        self.up5 = Upsample(base_channels)
        self.res5 = ResnetBlock(base_channels, base_channels, num_groups=32)
        
        # Final convolution to map to 3 channels
        self.final_conv = nn.Conv2d(base_channels, 3, kernel_size=3, stride=1, padding=1)

    def forward(self, z):
        # Input: (B, 256)
        z = self.fc(z)  # (B, 512*4*4)
        z = z.view(z.size(0), -1, 4, 4)  # (B, 512, 4, 4)
        
        z = self.res1(z)  # (B, 512, 4, 4)
        
        z = self.up2(z)  # (B, 512, 8, 8)
        z = self.res2(z)  # (B, 256, 8, 8)
        
        z = self.up3(z)  # (B, 256, 16, 16)
        z = self.res3(z)  # (B, 128, 16, 16)
        
        z = self.up4(z)  # (B, 128, 32, 32)
        z = self.res4(z)  # (B, 64, 32, 32)
        
        z = self.up5(z)  # (B, 64, 64, 64)
        z = self.res5(z)  # (B, 64, 64, 64)
        
        pre_activation = self.final_conv(z)  # (B, 3, 64, 64)
        print("Pre-activation (min, max, mean):", pre_activation.min().item(), pre_activation.max().item(), pre_activation.mean().item())
        z = torch.tanh(pre_activation)
        
        return z
    
class VAE(nn.Module):
    def __init__(self, base_channels=64, latent_dim=256):
        super().__init__()
        self.encoder = VAE_Encoder(base_channels=base_channels, latent_dim=latent_dim)
        self.decoder = VAE_Decoder(base_channels=base_channels, latent_dim=latent_dim)

    def forward(self, x):
        z, mu, logvar = self.encoder(x)
        x = self.decoder(z)
        return x, mu, logvar

In [None]:
loss_config = {
    'reconstruction': {'weight': 0.5},
    'kld': {'weight': 0.05},
    'perceptual': {'weight': 0.5},
    'lpips': {'weight': 0.5},
    'ssim': {'weight': 0.2},
    #'gan': {'weight': 0.0},
}

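# Redefining the classes does not change the existing `vae` object, so rebuild the model with the
# new ResNet-based architecture and recreate its optimizer (same lr as the earlier optimizer cell).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
vae = VAE(base_channels=64, latent_dim=256).to(device)
vae_optimizer = torch.optim.Adam(vae.parameters(), lr=0.0005)
train_vae(vae, dataloader, valloader, vae_optimizer, loss_config, discriminator, discriminator_optimizer,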
          num_epochs=25,
          perceptual_criterion=perceptual_criterion,
          lpips_criterion=lpips_criterion,
          ssim_criterion=ssim_criterion,
          run_name="25epoch_basic_vae_loss_balanced_stable_kld_schedule_no_checkerboard_no_gan_scaledresnet")

In [None]:
num_epochs = 6  # number of trained epochs to visualize

# train_vae saves predictions every 2 epochs (epoch % 2 == 0), i.e. ceil(num_epochs / 2) snapshots
num_saved_epochs = (num_epochs + 1) // 2
print(num_saved_epochs)
interact(visualize_epoch, epoch_idx=(0, num_saved_epochs, 1), run_name="25epoch_basic_vae_loss_balanced_stable_kld_schedule_no_checkerboard_no_gan_scaledresnet")

### For the most advanced training setup, use main.py. This notebook will be completed in the future, but I believe it is already enough to understand the whole project.