# Image Generation with Variational Autoencoders

We'll build a variational autoencoder (VAE) that learns to generate images of handwritten digits similar to those in the MNIST dataset.

## The MNIST dataset

The MNIST dataset consists of grayscale images of handwritten digits from 0 to 9. Each image is 28 pixels by 28 pixels. There are 60,000 training images and 10,000 test images.

We've organized these images in two folders named `train` and `test` in the GitHub repository https://github.com/DeepTrackAI/MNIST_dataset:

In [None]:
import os

if not os.path.exists("MNIST_dataset"):
    os.system("git clone https://github.com/DeepTrackAI/MNIST_dataset")

dir = os.path.join("MNIST_dataset", "mnist", "train")

In [None]:
print(f"{len(os.listdir(dir))} training images")

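We load the data with DeepTrack2 (`deeptrack`). From the training folder, we randomly select 10% of the images (`6000`) for training and another 10% (`6000`) for testing; the remaining 80% are not used. We normalize the images to the range `[0, 1]`.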

In [None]:
import deeptrack as dt
import torch

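dt.config.disable_image_wrapper()  # have features return plain arrays instead of dt.Image objects

paths = dt.sources.ImageFolder(root=dir)
# Split the training folder: 10% for training, 10% for testing, 80% unused
train_paths, test_paths, _ = dt.sources.random_split(paths, [0.1, 0.1, 0.8])

sources = dt.sources.Sources(train_paths, test_paths)

pipeline = (
    dt.LoadImage(sources.path)                 # load the image file at the given path
    >> dt.NormalizeMinMax()                    # rescale pixel values to [0, 1]
    >> dt.MoveAxis(2, 0)                       # (height, width, channel) -> (channel, height, width)
    >> dt.pytorch.ToTensor(dtype=torch.float)  # convert to a float PyTorch tensor
)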
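For readers unfamiliar with DeepTrack features, the sketch below mimics in plain NumPy what we assume `dt.NormalizeMinMax()` and `dt.MoveAxis(2, 0)` do with their default settings: rescale each image linearly so that its minimum maps to 0 and its maximum to 1, then move the channel axis from last to first, the (channels, height, width) layout that PyTorch convolutions expect. The function `preprocess` is hypothetical and is not DeepTrack's implementation; it also assumes the loaded image has shape (28, 28, 1).

```python
import numpy as np

def preprocess(image: np.ndarray) -> np.ndarray:
    """Hypothetical stand-in for NormalizeMinMax() >> MoveAxis(2, 0).

    Assumes `image` has shape (height, width, channels).
    """
    image = image.astype(np.float32)
    lo, hi = image.min(), image.max()
    # Linear min-max rescaling to [0, 1]; a constant image is mapped to all zeros.
    scaled = (image - lo) / (hi - lo) if hi > lo else np.zeros_like(image)
    # Move the channel axis from position 2 to position 0: (H, W, C) -> (C, H, W).
    return np.moveaxis(scaled, 2, 0)

# Usage on a fake 8-bit grayscale image with values in [0, 255].
rng = np.random.default_rng(0)
fake = rng.integers(0, 256, size=(28, 28, 1), dtype=np.uint8)
out = preprocess(fake)
print(out.shape, out.min(), out.max())  # (1, 28, 28) 0.0 1.0
```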

We randomly select some of the training images and visualize them.

In [None]:
import matplotlib.pyplot as plt
from numpy import random, squeeze

# Show 30 randomly chosen training images in a 3 x 10 grid
fig, axs = plt.subplots(3, 10, figsize=(10, 3))
for ax, path in zip(axs.ravel(), random.choice(train_paths, axs.size)):
    image = pipeline(path)
    ax.imshow(squeeze(image), cmap="gray")  # drop the channel axis for display
    ax.set_axis_off()

## Variational autoencoder
We define a VAE with a two-dimensional latent space, so that the latent representations can be plotted directly.

In [None]:
import deeplay as dl

# Variational autoencoder with a 2D latent space; .create() builds the model
vae = dl.VariationalAutoEncoder(latent_dim=2).create()

print(vae)
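Unlike a plain autoencoder, a VAE encoder outputs, for each image, the parameters of a Gaussian distribution over the latent space rather than a single point. To sample from that Gaussian while keeping the computation differentiable, VAEs typically use the reparameterization trick, z = mu + sigma * eps with eps drawn from N(0, I). The sketch below assumes the encoder returns a mean and a log-variance; this is the textbook formulation, and deeplay's `VariationalAutoEncoder` may parameterize it differently internally. The function name `reparameterize` is hypothetical.

```python
import torch

def reparameterize(mu: torch.Tensor, log_var: torch.Tensor) -> torch.Tensor:
    """Sample z ~ N(mu, exp(log_var)) so that gradients flow back to mu and log_var."""
    std = torch.exp(0.5 * log_var)  # sigma = exp(log_var / 2)
    eps = torch.randn_like(std)     # eps ~ N(0, I), independent of the parameters
    return mu + std * eps

# Usage: a batch of 4 latent Gaussians in a 2D latent space.
mu_demo = torch.zeros(4, 2, requires_grad=True)
log_var_demo = torch.zeros(4, 2, requires_grad=True)
z_demo = reparameterize(mu_demo, log_var_demo)
z_demo.sum().backward()
print(z_demo.shape)  # torch.Size([4, 2])
print(mu_demo.grad)  # all ones, since dz/dmu = 1
```

Because the randomness is isolated in `eps`, the sample is a deterministic, differentiable function of `mu` and `log_var`, which is what allows the encoder to be trained by backpropagation.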

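We define the training dataset, the data loader, and the trainer, and we train the autoencoder for `100` epochs. Each sample pairs an image with itself (`pipeline & pipeline`), because the autoencoder is trained to reconstruct its input.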

In [None]:
from torch.utils.data import DataLoader

train_dataset = dt.pytorch.Dataset(pipeline & pipeline, inputs=train_paths)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
vae_trainer = dl.Trainer(max_epochs=100, accelerator="auto")
vae_trainer.fit(vae, train_loader)
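Training a VAE typically minimizes the negative evidence lower bound (ELBO), which adds two terms: a reconstruction loss comparing the decoded image with the input, and a Kullback-Leibler (KL) divergence that pulls each encoded Gaussian N(mu, sigma^2) toward the standard normal prior N(0, I). For a diagonal Gaussian, the KL term has the closed form KL = 1/2 * sum(mu^2 + sigma^2 - 1 - log sigma^2) over the latent dimensions. The sketch below computes only this term with a hypothetical helper, `kl_divergence`; the reconstruction loss and any weighting used by deeplay's `VariationalAutoEncoder` are not shown and may differ.

```python
import torch

def kl_divergence(mu: torch.Tensor, log_var: torch.Tensor) -> torch.Tensor:
    """KL(N(mu, exp(log_var)) || N(0, I)) for each sample, summed over latent dimensions."""
    return 0.5 * torch.sum(mu.pow(2) + log_var.exp() - 1.0 - log_var, dim=1)

# An encoding that already matches the prior costs nothing...
print(kl_divergence(torch.zeros(1, 2), torch.zeros(1, 2)))           # tensor([0.])
# ...while shifting the mean by 1 along one latent axis costs 1**2 / 2.
print(kl_divergence(torch.tensor([[1.0, 0.0]]), torch.zeros(1, 2)))  # tensor([0.5000])
```

This term is what keeps the latent codes concentrated around the origin, which is why decoding points drawn from the standard normal prior produces plausible digits.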

## Image Generation
We generate new images by decoding a regular grid of latent points with the trained VAE's decoder. The grid coordinates are quantiles of the standard normal prior of the latent space.
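The grid is built with `Normal(0, 1).icdf` because the VAE's prior is a standard normal, so most latent codes lie within a few units of the origin. Applying the inverse cumulative distribution function to evenly spaced probabilities places the points at equally spaced quantiles of the prior: they are packed densely near 0 and spread out toward the tails, and the extreme probabilities 0.001 and 0.999 map to roughly -3.09 and 3.09. The short check below illustrates this with the same 20 probabilities.

```python
import torch
from torch.distributions.normal import Normal

probs = torch.linspace(0.001, 0.999, 20)  # evenly spaced probabilities
coords = Normal(0, 1).icdf(probs)         # corresponding standard-normal quantiles

print(coords[0].item(), coords[-1].item())  # approximately -3.09 and 3.09
spacing = coords[1:] - coords[:-1]
# Neighboring points are closest at the center of the grid and farthest apart at its edges.
print(spacing.min().item(), spacing.max().item())
```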

In [None]:
from torch.distributions.normal import Normal
import numpy as np

NUM_OF_IMAGES, IMG_SIZE = 20, 28
# Latent coordinates at evenly spaced quantiles of the standard normal prior
grid_x = Normal(0, 1).icdf(torch.linspace(0.001, 0.999, NUM_OF_IMAGES))
grid_y = Normal(0, 1).icdf(torch.linspace(0.001, 0.999, NUM_OF_IMAGES))

vae.eval()
image = np.zeros((IMG_SIZE * NUM_OF_IMAGES,) * 2)
with torch.no_grad():  # no gradients are needed for generation
    for i, yi in enumerate(grid_y):  # row i corresponds to z_1 = grid_y[i]
        for j, xi in enumerate(grid_x):  # column j corresponds to z_0 = grid_x[j]
            z = torch.stack((xi, yi)).unsqueeze(0)  # latent point of shape (1, 2)
            gimg = vae.decode(z)
            # Paste the decoded 28 x 28 image into its tile of the mosaic
            image[
                i * IMG_SIZE : (i + 1) * IMG_SIZE,
                j * IMG_SIZE : (j + 1) * IMG_SIZE,
            ] = gimg.numpy().squeeze()

plt.figure(figsize=(10, 10))
plt.imshow(image, cmap="gray")
# Place a tick at the center of each tile, labeled with its latent coordinate
start = IMG_SIZE // 2
end = start + NUM_OF_IMAGES * IMG_SIZE
pixel_range = np.arange(start, end, IMG_SIZE)
sample_range_x = np.round(grid_x.numpy(), 1)
sample_range_y = np.round(grid_y.numpy(), 1)
plt.xticks(pixel_range, sample_range_x)
plt.yticks(pixel_range, sample_range_y)
plt.xlabel("z_0")
plt.ylabel("z_1")
plt.show()
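The nested loop above copies each 28-by-28 decoded image into its slot of a larger canvas. If the decoded images are first collected in one array of shape (rows, cols, 28, 28), the same mosaic can be assembled in a single step with a transpose and a reshape. The helper `tile_images` below is a hypothetical NumPy sketch of this alternative, not part of DeepTrack or deeplay.

```python
import numpy as np

def tile_images(images: np.ndarray) -> np.ndarray:
    """Arrange an array of shape (rows, cols, h, w) into one (rows * h, cols * w) mosaic."""
    rows, cols, h, w = images.shape
    # Reorder to (rows, h, cols, w) so that pixel (r, c, y, x) lands at (r * h + y, c * w + x).
    return images.transpose(0, 2, 1, 3).reshape(rows * h, cols * w)

# Usage: a 2 x 3 grid of 4 x 4 "images", each filled with its own index.
demo = np.arange(6, dtype=float).reshape(2, 3, 1, 1) * np.ones((1, 1, 4, 4))
mosaic = tile_images(demo)
print(mosaic.shape)      # (8, 12)
print(mosaic[::4, ::4])  # top-left pixel of each tile: the indices 0 to 5 in grid order
```

In the notebook, this would require decoding all grid points as one batch and reshaping the decoder output to (20, 20, 28, 28) first, assuming the decoder returns one single-channel image per latent point.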

## Clustering in the latent space
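We now use the latent space to cluster the images: if the VAE has learned meaningful representations, images of the same digit should be encoded close to each other. The VAE never sees the labels during training; we use them only to check the result.

We define a pipeline that returns the label of each image and combine it with the image pipeline to build the test dataset.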

In [None]:
def get_label(label_name):
    """Get the image label (the digit) from the first character of its label name."""
    return int(label_name[0])

label = (
    dt.Value(get_label, label_name=sources.label_name)
    >> dt.Unsqueeze(0)
    >> dt.pytorch.ToTensor(dtype=torch.float)
)

# Each sample is an (image, label) pair; the order is kept fixed for evaluation
test_dataset = dt.pytorch.Dataset(pipeline & label, inputs=test_paths)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

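We encode the test images into the latent space. For each image, we keep the mean `mu` returned by `vae.encode` as its latent representation and discard the second output, which describes the spread of the encoded distribution.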

In [None]:
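mu, test_labels = [], []
with torch.no_grad():  # inference only, no gradients needed
    for x, y in test_loader:
        m, _ = vae.encode(x)  # keep the latent mean, discard the second output
        mu.append(m)
        test_labels.append(y)
mu = torch.cat(mu, dim=0).numpy()
test_labels = torch.cat(test_labels, dim=0).numpy().ravel()  # flatten to shape (N,)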

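We plot the latent representations, color-coded by digit label. The black square spans -3.1 to 3.1 on both axes, which roughly matches the region covered by the image-generation grid above.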

In [None]:
from matplotlib.patches import Rectangle

plt.figure(figsize=(12, 10))
plt.scatter(mu[:, 0], mu[:, 1], c=test_labels, cmap="tab10")
plt.xlabel("mu[0]")
plt.ylabel("mu[1]")
plt.colorbar(label="digit")
plt.axis("equal")
# Square from -3.1 to 3.1 on both axes
plt.gca().add_patch(Rectangle((-3.1, -3.1), 6.2, 6.2, facecolor="none", ec="k", lw=2))
plt.show()
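The scatter plot shows the clusters only visually. To actually assign each test image to a cluster, one option, which is not part of the original notebook, is to run k-means on the latent means `mu` without using the labels and then compare the resulting clusters with the digit labels. Below is a minimal NumPy sketch of Lloyd's k-means algorithm with a simple farthest-point initialization; `kmeans` and `_sq_dists` are hypothetical helpers, and a library implementation such as scikit-learn's `KMeans` would be a natural choice in practice.

```python
import numpy as np

def _sq_dists(points: np.ndarray, centroids: np.ndarray) -> np.ndarray:
    """Squared Euclidean distances between n points and k centroids, shape (n, k)."""
    return ((points[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=2)

def kmeans(points: np.ndarray, k: int, n_iter: int = 100, seed: int = 0):
    """Lloyd's k-means. Returns centroids of shape (k, d) and cluster indices of shape (n,)."""
    points = np.asarray(points, dtype=float)
    rng = np.random.default_rng(seed)
    # Farthest-point initialization: start from a random point, then repeatedly
    # add the point that is farthest from all centroids chosen so far.
    centroids = points[[rng.integers(len(points))]]
    while len(centroids) < k:
        far = _sq_dists(points, centroids).min(axis=1).argmax()
        centroids = np.vstack([centroids, points[far]])
    for _ in range(n_iter):
        # Assignment step: attach each point to its nearest centroid.
        labels = _sq_dists(points, centroids).argmin(axis=1)
        # Update step: move each centroid to the mean of its points (keep it if its cluster is empty).
        new_centroids = np.array([
            points[labels == c].mean(axis=0) if np.any(labels == c) else centroids[c]
            for c in range(k)
        ])
        converged = np.allclose(new_centroids, centroids)
        centroids = new_centroids
        if converged:
            break
    # Final assignment with the final centroids.
    return centroids, _sq_dists(points, centroids).argmin(axis=1)

# Usage on synthetic data: three well-separated 2D blobs of 50 points each.
rng = np.random.default_rng(0)
true_centers = np.array([[-4.0, 0.0], [0.0, 4.0], [4.0, 0.0]])
blobs = np.concatenate([c + 0.3 * rng.standard_normal((50, 2)) for c in true_centers])
centroids, clusters = kmeans(blobs, k=3)
print(np.round(centroids, 1))
```

With the notebook's variables, `kmeans(mu, k=10)` would cluster the encoded test images. Cluster indices are arbitrary, so comparing them with `test_labels` requires, for example, mapping each cluster to its most frequent digit. Farthest-point initialization is sensitive to outliers; k-means++ is a common, more robust alternative.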