# Variational Autoencoders 

## Architecture

Architecturally, the difference between a VAE and an AE is that while the encoder in the AE directly outputs the encoded image, in a VAE it outputs mean and stdv values from which we can sample 

In [61]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class ConvolutionalVAE(nn.Module):
    def __init__(
        self,
        input_dim: int,
        n_channels: int,
        conv_dim: int,
        latent_dim: int,
    ):
        super().__init__()
        
        self.latent_dim = latent_dim
        
        self.encoder = nn.Sequential(
            nn.Conv2d(n_channels, conv_dim, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(conv_dim),
            nn.Conv2d(conv_dim, 2 * conv_dim, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(2 * conv_dim),
            nn.Conv2d(2 * conv_dim, 4 * conv_dim, kernel_size=3, stride=2, padding=1), 
            nn.ReLU(),
            nn.BatchNorm2d(4 * conv_dim),
            nn.Flatten(),
            nn.Linear(conv_dim * 4 * input_dim[1] // 8 * input_dim[2] // 8, 2 * latent_dim),
        )
        
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 4 * conv_dim * input_dim[1] // 8 * input_dim[2] // 8),
            nn.Unflatten(1, (4 * conv_dim, input_dim[1] // 8, input_dim[2] // 8)),
            nn.ConvTranspose2d(4 * conv_dim, 2 * conv_dim, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(2 * conv_dim),
            nn.ConvTranspose2d(2 * conv_dim, conv_dim, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(conv_dim),
            nn.ConvTranspose2d(conv_dim, n_channels, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.Sigmoid(),
        )
        
    def sample(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
        """
        Sample from the latent space using the reparameterization trick.

        Args:
            mu: mean of the latent space
            logvar: log variance of the latent space

        Returns:
            z: sampled latent space
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass through the VAE.

        Args:
            x: input data

        Returns:
            x_hat: reconstructed data
        """
        encoded = self.encoder(x)
        mu, logvar = encoded.split(self.latent_dim, dim=1)
        z = self.sample(mu, logvar)
        return self.decoder(z), z

In [62]:
# Checking that dimenionality is correct
vae = ConvolutionalVAE(
    input_dim=(3, 32, 32),
    n_channels=3,
    conv_dim=96,
    latent_dim=128,
)

random_data = torch.randn(4, 3, 32, 32)
x_hat = vae(random_data)

print("x_hat shape", x_hat.shape)

x_hat shape torch.Size([4, 3, 32, 32])


## Objective function

So we've adjusted the architecture and added in the sampling/reparameterization trick to allow the flow of gradients. What's left?

the other difference between the VAE and the AE is the loss function. our loss term consists of two parts now, reconstruction and KL
divergence of the latent distribution from a standard normal distribution

In [63]:
def validate(model, reconstruction_loss_func, kl_loss_func, valid_dl):
    model.eval()
    with torch.no_grad():
        tot_loss = 0.
        count = 0
        for xb, _ in valid_dl:
            pred = model(xb)
            recon_loss = reconstruction_loss_func(pred, xb)
            kl_loss = kl_loss_func(pred)
            loss = recon_loss + kl_loss
            tot_loss += loss.item()
            count += len(xb)

    return tot_loss / count, recon_loss / count, kl_loss / count

def fit(
    epochs, 
    model,
    reconstruction_loss_func, 
    kl_loss_func, 
    opt, 
    train_dl, 
    valid_dl, 
):
    for epoch in range(epochs):
        print(f"Epoch {epoch+1}/{epochs}")
        model.train()
        for xb,yb in train_dl:
            x_hat, z = model(xb)
            recon_loss = reconstruction_loss_func(x_hat, xb)
            kl_loss = kl_loss_func(z, torch.randn_like(z))
            loss = recon_loss + kl_loss
            loss.backward()
            opt.step()
            opt.zero_grad()

        tot_loss_count = validate(model, reconstruction_loss_func, valid_dl)
        print(f"Validation loss: {tot_loss_count}")

## Data boiler plate

In [64]:
import pickle
import numpy as np
import torch

device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

all_batches_data = []
all_batches_labels = []

for i in range(1, 6):
    with open(f'data/cifar-10-batches-py/data_batch_{i}', 'rb') as f:
        dataset_dict = pickle.load(f, encoding='bytes')
        all_batches_data.append(dataset_dict[b'data'])
        all_batches_labels.append(dataset_dict[b'labels'])

stacked_data = np.vstack(all_batches_data)
stacked_labels = np.hstack(all_batches_labels)
data = torch.tensor(stacked_data, dtype=torch.float32).view(-1, 3, 32, 32).to(device) / 255.
labels = torch.tensor(stacked_labels, dtype=torch.long).to(device)

split_idx = int(0.8 * len(data))

x_train, x_valid = data[:split_idx], data[split_idx:]
y_train, y_valid = labels[:split_idx], labels[split_idx:]


In [65]:
from torch.utils.data import Dataset, DataLoader

class CIFARCustomDataset(Dataset):
    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        return self.x[idx], self.y[idx]

train_ds = CIFARCustomDataset(x_train, y_train)
valid_ds = CIFARCustomDataset(x_valid, y_valid)

train_dl = DataLoader(train_ds, batch_size=32, shuffle=True)
valid_dl = DataLoader(valid_ds, batch_size=32, shuffle=True)

## Train

In [66]:
model = ConvolutionalVAE(
    input_dim=(3, 32, 32),
    n_channels=3,
    conv_dim=96,
    latent_dim=128,
)

reconstruction_loss_func = nn.MSELoss()
kl_loss_func = nn.KLDivLoss(reduction='batchmean')
opt = torch.optim.Adam(model.parameters(), lr=5e-4)

model.to(device)

fit(10, model, reconstruction_loss_func, kl_loss_func, opt, train_dl, valid_dl)

Epoch 1/10


TypeError: KLDivLoss.forward() missing 1 required positional argument: 'target'