# Deep Learning

# Tutorial 19: Variational Autoencoder 

In this tutorial, we will cover:

- Creating images with a variational autoencoder

Prerequisites:

- Python, PyTorch, Deep Learning Training, Stochastic Gradient Descent

My contact:

- Niklas Beuter (niklas.beuter@th-luebeck.de)

Course:

- Slides and notebooks will be available at https://lernraum.th-luebeck.de/course/view.php?id=5383

## Expected Outcomes
* Understand the feature generation process of the encoder
* Understand the role of the decoder and why the encoder is not needed when generating new images

# Introduction to Variational Autoencoders (VAE)

Variational Autoencoders (VAEs) are a class of generative models that are widely used for tasks such as image generation, data compression, and anomaly detection. Introduced by Kingma and Welling in 2013, VAEs combine techniques from deep learning and Bayesian inference to model complex data distributions.

## Key Concepts

### Autoencoders
An autoencoder is a type of neural network designed to learn a compressed representation of data. It consists of two main parts:
- **Encoder**: This part of the network compresses the input data into a lower-dimensional latent representation.
- **Decoder**: This part reconstructs the input data from the latent representation.

### Variational Inference
Variational inference is a technique in Bayesian statistics that approximates probability densities through optimization. Instead of computing exact posterior distributions, which can be computationally expensive, variational inference approximates these distributions with a simpler, parameterized distribution.

### Latent Space
In a VAE, the latent space is a continuous, multidimensional space where each point represents a possible compressed representation of the input data. The VAE is trained to ensure that similar inputs are mapped to nearby points in the latent space.

## How VAE Works

### Encoder
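The encoder network takes an input $ x $ and maps it to the mean $ \mu(x) $ and standard deviation $ \sigma(x) $ of a Gaussian distribution over the latent space. The latent vector $ z $ is then distributed as:
$$ z \sim \mathcal{N}\big(\mu(x), \operatorname{diag}(\sigma^2(x))\big) $$
In practice the encoder outputs the log-variance $ \log \sigma^2 $ instead of $ \sigma $ (the `logvar` layer in the code below).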

### Reparameterization Trick
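To enable backpropagation through the sampling step, VAEs use the reparameterization trick. Instead of sampling $ z $ directly from $ \mathcal{N}(\mu, \sigma^2) $, we sample $ \epsilon $ from a standard normal distribution and compute $ z $ deterministically from it:
$$ z = \mu + \sigma \odot \epsilon $$
where $ \epsilon \sim \mathcal{N}(0, I) $ and $ \odot $ denotes element-wise multiplication. All randomness is now contained in $ \epsilon $, so gradients can flow through $ \mu $ and $ \sigma $.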
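
A minimal sketch of the trick in PyTorch, using random stand-ins for the encoder outputs `mu` and `logvar` (in the notebook below they come from the layers `fc2` and `fc3`). Because the noise `eps` is drawn independently of the parameters, calling `backward()` on any function of `z` yields gradients for both `mu` and `logvar`:

```python
import torch

torch.manual_seed(0)

# Stand-ins for the encoder outputs: a batch of 4 inputs, 2 latent dimensions
mu = torch.randn(4, 2, requires_grad=True)
logvar = torch.randn(4, 2, requires_grad=True)

# Reparameterization: the randomness lives in eps, which does not depend on mu or logvar
std = torch.exp(0.5 * logvar)   # sigma = exp(0.5 * log(sigma^2))
eps = torch.randn_like(std)     # eps ~ N(0, I), same shape as std
z = mu + std * eps              # z ~ N(mu, diag(sigma^2))

# Any scalar function of z can now be backpropagated to the encoder outputs
z.sum().backward()
print(mu.grad)      # all ones, since dz/dmu = 1
print(logvar.grad)  # equals 0.5 * std * eps
```

For `z.sum()`, the gradient with respect to `mu` is 1 and with respect to `logvar` it is $ \tfrac{1}{2}\sigma\epsilon $, because $ \partial\sigma / \partial \log\sigma^2 = \sigma / 2 $. Sampling with `torch.distributions.Normal(mu, std).sample()` instead would cut this gradient path.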

### Decoder
The decoder network takes the latent vector $ z $ and reconstructs the input data $ \hat{x} $. The goal is for $ \hat{x} $ to be as close as possible to the original input $ x $.

### Loss Function
The VAE loss function consists of two parts:
- **Reconstruction Loss**: Measures how well the decoder reconstructs the input data.
- **KL Divergence**: Regularizes the latent space by ensuring that the learned distribution $ q(z|x) $ is close to the prior $ p(z) $, typically a standard normal distribution.

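The combined loss is the negative evidence lower bound (ELBO):
$$ \mathcal{L}(x) = -\mathbb{E}_{q(z|x)}\left[\log p(x|z)\right] + D_{\mathrm{KL}}\big(q(z|x)\,\|\,p(z)\big) $$
The first term is the reconstruction loss (binary cross-entropy in the code below), the second is the KL divergence. Minimizing $ \mathcal{L} $ maximizes the ELBO, which is a lower bound on $ \log p(x) $.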
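
For a diagonal Gaussian $ q(z|x) = \mathcal{N}(\mu, \operatorname{diag}(\sigma^2)) $ and the prior $ p(z) = \mathcal{N}(0, I) $, the KL term has a closed form over the $ J $ latent dimensions:
$$ D_{\mathrm{KL}}\big(q(z|x)\,\|\,p(z)\big) = -\frac{1}{2}\sum_{j=1}^{J}\left(1 + \log\sigma_j^2 - \mu_j^2 - \sigma_j^2\right) $$
This is the expression computed as `KLD` in the loss functions of this notebook. The following sketch, using random values for `mu` and `logvar`, checks it numerically against `torch.distributions.kl_divergence`:

```python
import torch
from torch.distributions import Normal, kl_divergence

torch.manual_seed(0)
mu = torch.randn(8, 20)      # batch of 8, 20 latent dimensions
logvar = torch.randn(8, 20)

# Closed form, summed over batch and latent dimensions (same expression as KLD in the notebook)
kld_closed = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

# Reference: element-wise KL between N(mu, sigma^2) and N(0, 1), then summed
q = Normal(mu, torch.exp(0.5 * logvar))
p = Normal(torch.zeros_like(mu), torch.ones_like(mu))
kld_ref = kl_divergence(q, p).sum()

print(kld_closed.item(), kld_ref.item())
```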

## Applications of VAE
- **Image Generation**: VAEs can generate new images by sampling latent vectors from the prior $ \mathcal{N}(0, I) $ and passing them through the decoder.
- **Data Compression**: VAEs can learn efficient representations of data, useful for compression.
- **Anomaly Detection**: Inputs that differ from the training data tend to be reconstructed poorly, so a high reconstruction error can flag an anomaly.
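
A minimal sketch of reconstruction-based anomaly scoring, assuming a model with the same interface as the MNIST `VAE` defined below (`model(x)` returns the reconstruction, `mu` and `logvar`, with reconstructions in [0, 1]). The helper names `anomaly_scores` and `flag_anomalies` and the threshold rule are illustrative assumptions, not part of the notebook:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def anomaly_scores(model: nn.Module, x: torch.Tensor) -> torch.Tensor:
    """Per-sample reconstruction error (summed BCE); higher means less typical."""
    model.eval()
    with torch.no_grad():
        recon, _mu, _logvar = model(x)
        # reduction='none' keeps one loss value per pixel; sum over pixels per sample
        per_pixel = F.binary_cross_entropy(recon, x, reduction='none')
        return per_pixel.view(x.size(0), -1).sum(dim=1)


def flag_anomalies(scores: torch.Tensor, threshold: float) -> torch.Tensor:
    # True where the reconstruction error exceeds the (hypothetical) threshold
    return scores > threshold
```

One possible threshold is a high quantile of the scores on normal data, e.g. `torch.quantile(normal_scores, 0.99)`. Since the notebook's `forward` samples `z`, the scores are slightly noisy; decoding `mu` instead would make them deterministic.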

## Conclusion
Variational Autoencoders are powerful tools for learning compact, meaningful representations of data. By combining deep learning with probabilistic inference, VAEs enable various applications in generative modeling and beyond.

For more details and advanced topics, see the original paper by Kingma and Welling (2013), "Auto-Encoding Variational Bayes", and the literature that followed it.


## Demonstration

The following code loads the [MNIST](https://github.com/cvdfoundation/mnist) dataset, a collection of 70,000 handwritten digits (0-9) stored as 28×28 grayscale images (60,000 for training, 10,000 for testing).

![MNIST Image](https://upload.wikimedia.org/wikipedia/commons/2/27/MnistExamples.png)

The idea is to learn an encoder that compresses the input data into a low-dimensional space (the latent space). In this example we use only 20 latent dimensions, from which the decoder reconstructs the input.

At the end, we demonstrate that the latent space is continuous: we can draw a random latent vector, decode it into a digit, and even interpolate between digits.

In [None]:
!pip install torch torchvision ipywidgets torchviz gdown

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torchvision.utils import save_image
import os
import matplotlib.pyplot as plt
import numpy as np

import ipywidgets as widgets
from IPython.display import display

## Setup Device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

# Hyperparameters
batch_size = 128
learning_rate = 1e-3
num_epochs = 1
latent_dim = 20
image_size = 784  # 28*28 for MNIST

# Make sure a result directory exists
output_dir = os.path.join(os.getcwd(), 'results')
os.makedirs(output_dir, exist_ok=True)

# Prepare data: load MNIST as is (digits 0-9) and flatten each 28x28 image into a 784-dimensional vector
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(lambda x: x.view(-1))
])

train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# VAE model
class VAE(nn.Module):
    def __init__(self, image_size, h_dim, z_dim):
        super(VAE, self).__init__()
        self.fc1 = nn.Linear(image_size, h_dim)
        # the following layers fc2 and fc3 both take the hidden features and output mu and logvar of the latent distribution
        self.fc2 = nn.Linear(h_dim, z_dim)  # mu layer
        self.fc3 = nn.Linear(h_dim, z_dim)  # logvar layer (the log variance is used instead of the standard deviation for numerical stability; it can take any real value)
        self.fc4 = nn.Linear(z_dim, h_dim)
        self.fc5 = nn.Linear(h_dim, image_size)

    def encode(self, x):
        h = torch.relu(self.fc1(x))
        return self.fc2(h), self.fc3(h)

    # Reparameterization trick: sampling is not differentiable, so we draw eps ~ N(0, I) separately
    # and compute z = mu + eps * std, which lets gradients flow back to mu and logvar
    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar) # std = exp(0.5 * log(sigma^2)) = sigma
        eps = torch.randn_like(std) # same shape as std, drawn from a standard normal distribution (mean=0, std=1)
        return mu + eps * std

    def decode(self, z):
        h = torch.relu(self.fc4(z))
        return torch.sigmoid(self.fc5(h))

    def forward(self, x):
        mu, logvar = self.encode(x) # mean and log-variance of q(z|x)
        z = self.reparameterize(mu, logvar) # sample z with the reparameterization trick
        return self.decode(z), mu, logvar

# Loss function: reconstruction (BCE) + KL divergence
def loss_function(recon_x, x, mu, logvar):
    BCE = nn.functional.binary_cross_entropy(recon_x, x, reduction='sum')
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return BCE + KLD

# Model and optimizer
vae = VAE(image_size=image_size, h_dim=400, z_dim=latent_dim).to(device)
optimizer = optim.Adam(vae.parameters(), lr=learning_rate)

def saveImgs(model, latent_dim, epoch, output_dir):
    with torch.no_grad():
        sample = torch.randn(64, latent_dim).to(device)
        results = model.decode(sample).cpu()

        # Create a collage of each random vector and the image generated from it
        for i in range(64):
            fig, ax = plt.subplots(1, 2, figsize=(8, 4))
            ax[0].imshow(sample[i].cpu().numpy().reshape(1, -1), cmap='viridis', aspect='auto')
            ax[0].set_title('Random Vector')
            ax[0].axis('off')

            ax[1].imshow(results[i].view(28, 28), cmap='gray')
            ax[1].set_title('Generated Image')
            ax[1].axis('off')

            plt.savefig(os.path.join(output_dir, f'sample_{epoch}_{i}.png'))
            plt.close(fig)
        #save_image(results.view(64, 1, 28, 28), os.path.join(output_dir, f'sample_{epoch}.png'))

# Generate new images
def generate_images(model, num_images=64):
    with torch.no_grad():
        sample = torch.randn(num_images, latent_dim).to(device)
        results = model.decode(sample).cpu()

        for i in range(num_images):
            fig, ax = plt.subplots(1, 2, figsize=(8, 4))
            ax[0].imshow(sample[i].cpu().numpy().reshape(1, -1), cmap='viridis', aspect='auto')  # show the random latent vector, not the image
            ax[0].set_title('Random Vector')
            ax[0].axis('off')
            
            ax[1].imshow(results[i].view(28, 28), cmap='gray')
            ax[1].set_title('Generated Image')
            ax[1].axis('off')

            plt.savefig(os.path.join(output_dir, f'generated_images_{i}.png'))
            plt.close(fig)

        #save_image(results.view(num_images, 1, 28, 28), os.path.join(output_dir, 'generated_images.png'))


In [None]:
# Save latent representations (for interpolation later)
latent_representations = {i: [] for i in range(10)}

# Training
for epoch in range(num_epochs):
    vae.train()
    train_loss = 0
    for batch_idx, (data, labels) in enumerate(train_loader):
        data = data.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        recon_batch, mu, logvar = vae(data)
        loss = loss_function(recon_batch, data, mu, logvar)
        loss.backward()
        train_loss += loss.item()
        optimizer.step()

        # Save latent representations (for interpolation later)
        z = vae.reparameterize(mu, logvar)
        for i, label in enumerate(labels):
            latent_representations[label.item()].append(z[i].detach().cpu().numpy())

    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {train_loss / len(train_loader.dataset):.4f}')

    # Save example images
    #saveImgs(vae, latent_dim, epoch, output_dir)


#generate_images(vae)

In [None]:
# Function for linear interpolation between two latent vectors
def interpolate_vectors(v1, v2, num_steps=10):
    vectors = []
    for alpha in np.linspace(0, 1, num_steps):
        interpolated = (1 - alpha) * v1 + alpha * v2
        vectors.append(interpolated)
    return torch.stack(vectors)

# Generate interpolated images between two digits
def generate_interpolated_images_between_digits(model, digit1, digit2, num_interpolations=10):
    with torch.no_grad():
        z1 = torch.tensor(latent_representations[digit1][-1]).to(device)  # choose a latent vector for the first digit (the last one stored)
        z2 = torch.tensor(latent_representations[digit2][-1]).to(device)  # choose a latent vector for the second digit
        interpolated_vectors = interpolate_vectors(z1, z2, num_interpolations).to(device)
        recon_images = model.decode(interpolated_vectors).cpu()

        if not os.path.exists('./results'):
            os.makedirs('./results')

        for i in range(num_interpolations):
            fig, ax = plt.subplots(1, 2, figsize=(8, 4))
            random_vector = interpolated_vectors[i].cpu().numpy().reshape(1, -1)
            ax[0].imshow(random_vector, cmap='viridis', aspect='auto')
            ax[0].set_title('Interpolated Vector')
            ax[0].axis('off')
            
            ax[1].imshow(recon_images[i].view(28, 28), cmap='gray')
            ax[1].set_title('Generated Image')
            ax[1].axis('off')
            
            plt.savefig(f'./results/interpolated_image_{digit1}_to_{digit2}_{i}.png')
            plt.close(fig)

generate_interpolated_images_between_digits(vae, digit1=4, digit2=7, num_interpolations=10)
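
The function above uses the most recently stored latent vector of each digit, which depends on the last training batch. A possible alternative, sketched here as an assumption rather than the notebook's method, is to use the mean latent vector of each digit as the interpolation endpoint; `digit_centroids` is a hypothetical helper that works on the `latent_representations` dictionary filled during training:

```python
import numpy as np
import torch


def digit_centroids(latent_representations: dict) -> dict:
    """Mean latent vector per digit label, returned as float32 tensors."""
    return {
        digit: torch.from_numpy(np.mean(np.stack(vectors), axis=0).astype(np.float32))
        for digit, vectors in latent_representations.items()
        if len(vectors) > 0  # skip labels that were never seen
    }
```

For example, `centroids = digit_centroids(latent_representations)`, then pass `centroids[4].to(device)` and `centroids[7].to(device)` to `interpolate_vectors` and `vae.decode`. The mean is taken over posterior samples, so it is a smoothed endpoint rather than the encoding of any particular image.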

In [None]:
# Function for linear interpolation between two latent vectors (single alpha)
def interpolate_vectors2(v1, v2, alpha):
    return (1 - alpha) * v1 + alpha * v2

# Initialize two random latent vectors
z1 = torch.randn(latent_dim).to(device)
z2 = torch.randn(latent_dim).to(device)

# Create the interpolation plotting function
def plot_interpolation(alpha):
    z_interpolated = interpolate_vectors2(z1, z2, alpha)
    recon_image = vae.decode(z_interpolated).cpu().detach().numpy().reshape(28, 28)
    
    plt.imshow(recon_image, cmap='gray')
    plt.title(f'Interpolation: alpha={alpha:.2f}')
    plt.axis('off')
    plt.show()

# Create the slider
alpha_slider = widgets.FloatSlider(min=0.0, max=1.0, step=0.01, value=0.5, description='Alpha:')

widgets.interact(plot_interpolation, alpha=alpha_slider)
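
Linear interpolation between two random vectors drawn from $ \mathcal{N}(0, I) $ passes through points with a smaller norm than typical prior samples (in high dimensions, such samples concentrate near a sphere of radius about $ \sqrt{J} $). A commonly used alternative is spherical linear interpolation (slerp). This is a sketch of that swapped-in technique, not part of the original notebook; `slerp` could replace `interpolate_vectors2` inside `plot_interpolation`:

```python
import torch


def slerp(v1: torch.Tensor, v2: torch.Tensor, alpha: float, eps: float = 1e-7) -> torch.Tensor:
    """Spherical linear interpolation between two 1-D latent vectors."""
    cos_omega = torch.dot(v1, v2) / (v1.norm() * v2.norm()).clamp_min(eps)
    omega = torch.acos(cos_omega.clamp(-1.0, 1.0))  # angle between v1 and v2
    sin_omega = torch.sin(omega)
    if sin_omega.abs() < eps:
        # (Anti-)parallel vectors: the great circle is undefined, fall back to linear interpolation
        return (1 - alpha) * v1 + alpha * v2
    return (torch.sin((1 - alpha) * omega) / sin_omega) * v1 + (torch.sin(alpha * omega) / sin_omega) * v2
```

At `alpha=0` and `alpha=1` it returns the endpoints, like the linear version; in between it follows the arc between the two vectors instead of the chord.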


In [None]:
# Plot the image decoded from the latent vector set by the sliders
def plot_latent_vector(**latent_values):
    z = torch.tensor([latent_values[f'z{i}'] for i in range(latent_dim)], dtype=torch.float32).to(device)
    recon_image = vae.decode(z).cpu().detach().numpy().reshape(28, 28)

    plt.imshow(recon_image, cmap='gray')
    plt.title(f'Latent Vector: {latent_values}')
    plt.axis('off')
    plt.show()

# Create one slider per latent dimension
sliders = [widgets.FloatSlider(min=-3.0, max=3.0, step=0.1, value=0.0, description=f'z{i}') for i in range(latent_dim)]

# Connect the sliders to the plotting function
ui = widgets.VBox(sliders)
out = widgets.interactive_output(plot_latent_vector, {f'z{i}': sliders[i] for i in range(latent_dim)})

display(ui, out)
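
The sliders change one latent dimension at a time by hand. The same idea can be captured as a static latent traversal: keep all dimensions fixed at a base vector (here the prior mean, zero) and sweep a single dimension. A minimal sketch; `latent_traversal` is a hypothetical helper that accepts any decoder such as `vae.decode`:

```python
import torch


def latent_traversal(decode, latent_dim: int, dim: int, values, base=None, device='cpu'):
    """Decode a batch in which only latent dimension `dim` varies."""
    if base is None:
        base = torch.zeros(latent_dim, device=device)  # prior mean
    values = torch.as_tensor(values, dtype=torch.float32, device=base.device)
    z = base.repeat(len(values), 1)  # (num_values, latent_dim), one copy of base per row
    z[:, dim] = values               # overwrite only the swept coordinate
    with torch.no_grad():
        return decode(z)
```

For the MNIST model, `latent_traversal(vae.decode, latent_dim, dim=0, values=torch.linspace(-3, 3, 9), device=device).cpu().view(-1, 28, 28)` would give nine images along the first latent axis, matching the slider range.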

## Save and load the model to save time

In [None]:
torch.save({
        'epoch': epoch,
        'model_state_dict': vae.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': train_loss,
    }, f'vae_checkpoint_epoch_{epoch}.pth')

In [None]:
# Load the model
# Adjust the file name to an existing checkpoint; with num_epochs = 1 the save cell above writes vae_checkpoint_epoch_0.pth
checkpoint = torch.load('vae_checkpoint_epoch_99.pth', map_location=device)  # map_location allows loading a GPU checkpoint on a CPU-only machine
vae.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
train_loss = checkpoint['loss']

#vae.eval()  # Put the model into evaluation mode if required
# or
vae.train()  # Put the model into training mode if required
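
The load cell hard-codes an epoch number in the file name. A small sketch that instead picks the checkpoint with the highest epoch number among files named like those written by the save cell; `latest_checkpoint` is a hypothetical helper. The numeric sort matters, because plain string sorting would rank `epoch_9` above `epoch_10`:

```python
import glob
import os
import re


def latest_checkpoint(pattern: str = 'vae_checkpoint_epoch_*.pth'):
    """Return the checkpoint path with the highest epoch number, or None if none exist."""
    candidates = []
    for path in glob.glob(pattern):
        match = re.search(r'epoch_(\d+)\.pth$', os.path.basename(path))
        if match:
            candidates.append((int(match.group(1)), path))
    return max(candidates)[1] if candidates else None
```

For example, `path = latest_checkpoint()` followed by `checkpoint = torch.load(path, map_location=device)` when `path` is not `None`.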

In [None]:
print(vae)

## Generate faces

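We use the CelebA dataset. Downloading it automatically does not always work, because it is hosted on Google Drive. The code below therefore expects the aligned images to be extracted to `./data/celeba/img_align_celeba`.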

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Dataset
from PIL import Image
import matplotlib.pyplot as plt
import ipywidgets as widgets
from IPython.display import display
import os

# Hyperparameters
batch_size = 128
learning_rate = 1e-3
num_epochs = 1
latent_dim = 20
image_size = 64  # 64x64 for CelebA images

# Data preparation: resize to 64x64 and normalize each channel from [0, 1] to [-1, 1]
transform = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

class CelebADataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.image_paths = [os.path.join(root_dir, img) for img in os.listdir(root_dir) if img.endswith('.jpg')]

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        image = Image.open(image_path).convert('RGB')  # ensure 3 channels
        if self.transform:
            image = self.transform(image)
        return image, 0

# Directory into which the images were extracted
data_dir = './data/celeba/img_align_celeba'  # replace with the actual path if you extracted the images elsewhere

train_dataset = CelebADataset(root_dir=data_dir, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# VAE model (convolutional)
class VAE(nn.Module):
    def __init__(self, image_channels=3, h_dim=256, z_dim=20):
        super(VAE, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(image_channels, 32, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(128, h_dim, kernel_size=4, stride=2, padding=1),  # h_dim output channels, matching fc1/fc2 and the decoder input
            nn.ReLU(),
        )
        self.fc1 = nn.Linear(h_dim * 4 * 4, z_dim)
        self.fc2 = nn.Linear(h_dim * 4 * 4, z_dim)
        self.fc3 = nn.Linear(z_dim, h_dim * 4 * 4)
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(h_dim, 128, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(32, image_channels, kernel_size=4, stride=2, padding=1),
            nn.Sigmoid(),
        )

    def encode(self, x):
        h = self.encoder(x)
        h = h.view(h.size(0), -1)
        return self.fc1(h), self.fc2(h)

    def reparameterize(self, mu, logvar):
        # PyTorch's distributions package offers the same operation:
        # std = torch.exp(0.5 * logvar)
        # return torch.distributions.Normal(loc=mu, scale=std).rsample()  # rsample() allows pathwise derivatives, i.e. it implements the reparameterization trick
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std) # same shape as std, drawn from a standard normal distribution (mean=0, std=1)
        return mu + eps * std

    def decode(self, z):
        h = self.fc3(z)
        h = h.view(h.size(0), -1, 4, 4)  # (batch, h_dim, 4, 4); channel count inferred instead of hard-coding 256
        return self.decoder(h)

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar

# Loss function: reconstruction (BCE) + KL divergence
def loss_function(recon_x, x, mu, logvar):
    BCE = nn.functional.binary_cross_entropy(recon_x, x, reduction='sum')
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return BCE + KLD

# Model and optimizer
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
vae = VAE(image_channels=3, h_dim=256, z_dim=latent_dim).to(device)
optimizer = optim.Adam(vae.parameters(), lr=learning_rate)
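
The size `h_dim * 4 * 4` of the input to `fc1`/`fc2` follows from the convolution arithmetic: a `Conv2d` with `kernel_size=4, stride=2, padding=1` halves the spatial size, $ \lfloor (n + 2p - k)/s \rfloor + 1 $, and the matching `ConvTranspose2d` doubles it, $ (n-1)s - 2p + k $. A short check of that arithmetic, with a real layer pair as a cross-check (the channel counts in the cross-check are arbitrary):

```python
import torch
import torch.nn as nn


def conv_out(n, k=4, s=2, p=1):
    # Conv2d output size: floor((n + 2p - k) / s) + 1
    return (n + 2 * p - k) // s + 1


def deconv_out(n, k=4, s=2, p=1):
    # ConvTranspose2d output size (no output_padding, no dilation): (n - 1) * s - 2p + k
    return (n - 1) * s - 2 * p + k


sizes = [64]
for _ in range(4):  # four downsampling layers in the encoder
    sizes.append(conv_out(sizes[-1]))
print(sizes)  # [64, 32, 16, 8, 4]

# Cross-check against real layers with the same kernel/stride/padding
x = torch.zeros(1, 3, 64, 64)
down = nn.Conv2d(3, 8, kernel_size=4, stride=2, padding=1)(x)
up = nn.ConvTranspose2d(8, 3, kernel_size=4, stride=2, padding=1)(down)
print(down.shape, up.shape)  # torch.Size([1, 8, 32, 32]) torch.Size([1, 3, 64, 64])
```

With 256 channels at 4×4, the flattened encoder output has 256 · 16 = 4096 features, which is why `decode` reshapes the output of `fc3` back to `(batch, 256, 4, 4)` before the four transposed convolutions restore 64×64.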

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torchvision.utils as vutils

# Function for displaying training images
def imshow(img):
    img = img / 2 + 0.5  # Unnormalize
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()

# Show training images before training
real_batch = next(iter(train_loader))
images, labels = real_batch  # unpack images and labels
plt.figure(figsize=(8, 16))
plt.axis("off")
plt.title("Training Images")
imshow(vutils.make_grid(images[:32], padding=2, normalize=False))

In [None]:
# Training
for epoch in range(num_epochs):
    vae.train()
    train_loss = 0
    for batch_idx, (data, _) in enumerate(train_loader):
        data = data.to(device)
        optimizer.zero_grad()
        recon_batch, mu, logvar = vae(data)
        # Map the targets from [-1, 1] (after Normalize) back to [0, 1] so they match the sigmoid output required by BCE
        loss = loss_function(recon_batch, (data + 1) / 2, mu, logvar)
        loss.backward()
        train_loss += loss.item()
        optimizer.step()

    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {train_loss / len(train_loader.dataset):.4f}')

In [None]:
# Interactive sliders for controlling the latent vector
def plot_latent_vector(**latent_values):
    z = torch.tensor([latent_values[f'z{i}'] for i in range(latent_dim)], dtype=torch.float32).to(device)
    with torch.no_grad():
        recon_image = vae.decode(z.unsqueeze(0)).cpu().detach().numpy().transpose(0, 2, 3, 1).squeeze()
        # The decoder ends with a sigmoid and was trained on targets in [0, 1], so no rescaling is needed for display

    plt.imshow(recon_image)
    plt.title(f'Latent Vector: {latent_values}')
    plt.axis('off')
    plt.show()

# Create one slider per latent dimension
sliders = [widgets.FloatSlider(min=-3.0, max=3.0, step=0.1, value=0.0, description=f'z{i}') for i in range(latent_dim)]

# Connect the sliders to the plotting function
ui = widgets.VBox(sliders)
out = widgets.interactive_output(plot_latent_vector, {f'z{i}': sliders[i] for i in range(latent_dim)})

display(ui, out)

In [None]:
torch.save({
        'epoch': epoch,
        'model_state_dict': vae.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': train_loss,
    }, f'vae_celeb_checkpoint_epoch_{epoch}.pth')

In [None]:
# Load the model
checkpoint = torch.load('vae_celeb_checkpoint_epoch_9.pth', map_location=device)  # adjust the file name to an existing checkpoint
vae.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
# train_loss = checkpoint['loss']

#vae.eval()  # Put the model into evaluation mode if required
# or
vae.train()  # Put the model into training mode if required

# References

This notebook is adapted from or uses the following sources:
* 