# Image Generation with Variational Autoencoders

We'll build a variational autoencoder (VAE) to generate images of handwritten digits inspired by the MNIST dataset.

## The MNIST dataset

The MNIST dataset consists of grayscale images of handwritten digits from 0 to 9. Each image is 28 pixels by 28 pixels. There are 60,000 training images and 10,000 test images.

We've organized these images in two folders named `train` and `test` in the GitHub repository https://github.com/DeepTrackAI/MNIST_dataset:

In [None]:
import os

if not os.path.exists("MNIST_dataset"):
    os.system("git clone https://github.com/DeepTrackAI/MNIST_dataset")

root_dir = os.path.join("MNIST_dataset", "mnist")

In [None]:
import deeptrack as dt

train_files = dt.sources.ImageFolder(root=os.path.join(root_dir, "train"))
test_files = dt.sources.ImageFolder(root=os.path.join(root_dir, "test"))
files = dt.sources.Join(train_files, test_files)

print(f"{len(train_files)} training images")
print(f"{len(test_files)} test images")

We load the images using DeepTrack 2.0, normalize them to the range `[0, 1]`, move the channel axis to the first position, and convert them to PyTorch tensors.

In [None]:
import torch

image_pipeline = (
    dt.LoadImage(files.path)
    >> dt.NormalizeMinMax()
    >> dt.MoveAxis(2, 0)
    >> dt.pytorch.ToTensor(dtype=torch.float)
)
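To make the two array operations in this pipeline concrete, here is a minimal NumPy sketch. It assumes that min-max normalization rescales each image linearly so that its smallest value maps to 0 and its largest to 1. The helper `normalize_min_max` and the tiny 2 x 2 test image are illustrative and not part of DeepTrack.

```python
import numpy as np


def normalize_min_max(image):
    """Linearly rescale an array so that its values span [0, 1]."""
    image = np.asarray(image, dtype=float)
    lowest, highest = image.min(), image.max()
    if highest == lowest:
        # A constant image has no contrast; map it to all zeros.
        return np.zeros_like(image)
    return (image - lowest) / (highest - lowest)


# Illustrative 2 x 2 grayscale image in (height, width, channel) layout.
raw_image = np.array([[[0], [128]], [[255], [64]]], dtype=np.uint8)

normalized = normalize_min_max(raw_image)

# Move the channel axis (index 2) to the front: (H, W, C) -> (C, H, W).
channel_first = np.moveaxis(normalized, 2, 0)

print(normalized.min(), normalized.max())
print(raw_image.shape, "->", channel_first.shape)
```

The channel-first layout is the one PyTorch convolutional layers expect, which is why the pipeline moves the axis before converting to a tensor.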

We randomly select some of the training images and visualize them.

In [None]:
import matplotlib.pyplot as plt
import numpy as np

fig, axs = plt.subplots(3, 10, figsize=(10, 3))
for ax, train_file in zip(axs.ravel(), np.random.choice(train_files, axs.size)):
    image = image_pipeline(train_file)
    ax.imshow(image.squeeze(), cmap="gray")
    ax.set_axis_off()

## Variational autoencoder
We define the VAE architecture with a two-dimensional latent space.

In [None]:
import deeplay as dl

vae = dl.VariationalAutoEncoder(latent_dim=2).create()
print(vae)
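For reference, the two ingredients that distinguish a VAE from a plain autoencoder can be sketched in a few lines of plain Python. This follows the standard VAE formulation (a diagonal Gaussian posterior and a standard normal prior). It is not taken from the `deeplay` source, and the function names are illustrative.

```python
import math
import random


def reparameterize(mu, log_var, rng):
    """Sample z = mu + sigma * eps with eps ~ N(0, 1), per latent dimension."""
    return [
        m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0)
        for m, lv in zip(mu, log_var)
    ]


def kl_to_standard_normal(mu, log_var):
    """KL divergence between N(mu, diag(exp(log_var))) and N(0, I)."""
    return -0.5 * sum(
        1.0 + lv - m * m - math.exp(lv) for m, lv in zip(mu, log_var)
    )


rng = random.Random(0)  # seeded so the sketch is reproducible

# Illustrative encoder outputs for one image and a two-dimensional latent space.
mu, log_var = [0.5, -1.0], [0.0, -2.0]

z = reparameterize(mu, log_var, rng)
print("sampled z:", z)
print("KL term:", kl_to_standard_normal(mu, log_var))
```

Writing the sample as `mu + sigma * eps` keeps the randomness in `eps`, so gradients can flow through `mu` and `log_var` during training. The KL term vanishes only when the posterior equals the prior, which is what pulls the latent codes toward a standard normal distribution.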

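We define the dataset, the data loader, and the trainer. Because an autoencoder learns to reconstruct its input, each image serves as both input and target. We train the autoencoder for `100` epochs.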

In [None]:
from torch.utils.data import DataLoader

train_dataset = dt.pytorch.Dataset(image_pipeline & image_pipeline, 
                                   inputs=train_files)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
vae_trainer = dl.Trainer(max_epochs=100, accelerator="auto")
vae_trainer.fit(vae, train_loader)
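As we read it, the expression `image_pipeline & image_pipeline` makes the dataset yield each image twice, once as the model input and once as the reconstruction target. The plain-Python sketch below illustrates that pairing. `PairedDataset` and `load` are hypothetical stand-ins, not DeepTrack classes, and `load` replaces the image pipeline with a trivial scaling.

```python
class PairedDataset:
    """Hypothetical stand-in for a dataset that yields (input, target) pairs."""

    def __init__(self, items, make_input, make_target):
        self.items = list(items)
        self.make_input = make_input
        self.make_target = make_target

    def __len__(self):
        return len(self.items)

    def __getitem__(self, index):
        item = self.items[index]
        # Both elements of the pair are derived from the same source item.
        return self.make_input(item), self.make_target(item)


def load(item):
    # Placeholder for the image pipeline: it only scales a number to [0, 1].
    return item / 255


# For an autoencoder, the same function produces the input and the target.
autoencoder_data = PairedDataset([0, 51, 255], make_input=load, make_target=load)

model_input, target = autoencoder_data[1]
print(len(autoencoder_data), model_input, target)
```

A classifier would use the same structure with a different `make_target`, for example one that returns the digit label.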

## Image Generation
We generate images by sampling latent representations on a grid of standard normal quantiles and decoding them with the trained VAE’s decoder.

In [None]:
from torch.distributions.normal import Normal

img_num, img_size = 21, 28

z0_grid = Normal(0, 1).icdf(torch.linspace(0.001, 0.999, img_num))
z1_grid = Normal(0, 1).icdf(torch.linspace(0.001, 0.999, img_num))

image = np.zeros((img_num * img_size, img_num * img_size))

for i0, z0 in enumerate(z0_grid):
    for i1, z1 in enumerate(z1_grid):
        z = torch.stack((z0, z1)).unsqueeze(0)
        generated_image = vae.decode(z).clone().detach()
        image[
            i1 * img_size : (i1 + 1) * img_size,
            i0 * img_size : (i0 + 1) * img_size,
        ] = generated_image.numpy().squeeze()

plt.figure(figsize=(10, 10))
plt.imshow(image, cmap="gray")
plt.xticks(np.arange(0.5 * img_size, (0.5 + img_num) * img_size, img_size), 
           np.round(z0_grid.numpy(), 1))
plt.yticks(np.arange(0.5 * img_size, (0.5 + img_num) * img_size, img_size), 
           np.round(z1_grid.numpy(), 1))
plt.xlabel("z0", fontsize=20)
plt.ylabel("z1", fontsize=20)
plt.show()
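The grid coordinates in the cell above come from passing evenly spaced probabilities through the inverse cumulative distribution function of a standard normal distribution. The sketch below reproduces that construction with the standard library's `statistics.NormalDist` instead of PyTorch; the function name `normal_quantile_grid` is illustrative.

```python
from statistics import NormalDist


def normal_quantile_grid(n, p_min=0.001, p_max=0.999):
    """Return n standard-normal quantiles at evenly spaced probabilities."""
    if n < 2:
        raise ValueError("n must be at least 2")
    std_normal = NormalDist(0.0, 1.0)
    step = (p_max - p_min) / (n - 1)
    return [std_normal.inv_cdf(p_min + i * step) for i in range(n)]


grid = normal_quantile_grid(21)

# Spacing between neighboring grid points at the center and at the edge.
center_spacing = grid[11] - grid[10]
edge_spacing = grid[1] - grid[0]

print([round(value, 1) for value in grid])
print(round(center_spacing, 3), round(edge_spacing, 3))
```

Because the probabilities are evenly spaced, the grid points are closest together near the origin and spread out toward the tails, so the generated mosaic devotes most of its tiles to the region where a standard normal prior places most of its probability mass.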

## Clustering in the latent space
We now use the VAE to cluster the input images, i.e., to group them into classes according to their latent-space encodings.

We define a pipeline to get images and labels of the test dataset.

In [None]:
label_pipeline = dt.Value(files.label_name[0]) >> int

In [None]:
test_dataset = dt.pytorch.Dataset(image_pipeline & label_pipeline, inputs=test_files)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

We encode the images of the test dataset into their latent-space representations.

In [None]:
z_list, test_labels = [], []
for image, label in test_loader:
    z, _ = vae.encode(image)
    z_list.append(z)
    test_labels.append(label)
z_tensor = torch.cat(z_list, dim=0).detach().numpy()
test_labels = torch.cat(test_labels, dim=0).numpy()

We plot the latent-space representations, color-coded according to the image labels.

In [None]:
from matplotlib.patches import Rectangle

plt.figure(figsize=(12, 10))
plt.scatter(z_tensor[:, 0], z_tensor[:, 1], s=3, c=test_labels, cmap="tab10")
plt.xlabel("z0")
plt.ylabel("z1")
plt.colorbar()
plt.axis('equal')
# Outline the latent-space region covered by the grid of generated images.
plt.gca().add_patch(Rectangle((-3.1, -3.1), 6.2, 6.2, fc="none", ec="k", lw=1))
plt.show()
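One simple way to quantify how well the classes separate in the latent space is to assign each point to the nearest class centroid and measure the fraction of correct assignments. The sketch below does this on synthetic two-dimensional points that stand in for `z_tensor` and `test_labels`. The cluster centers, the noise level, and the helper names are all made up for illustration, so the printed accuracy says nothing about the trained VAE.

```python
import numpy as np


def class_centroids(z, labels):
    """Return the sorted class labels and the mean latent vector of each class."""
    classes = np.unique(labels)
    centroids = np.stack([z[labels == c].mean(axis=0) for c in classes])
    return classes, centroids


def nearest_centroid(z, classes, centroids):
    """Assign each latent vector to the class with the closest centroid."""
    distances = np.linalg.norm(z[:, None, :] - centroids[None, :, :], axis=2)
    return classes[distances.argmin(axis=1)]


# Synthetic stand-in data: three well-separated clusters of 100 points each.
rng = np.random.default_rng(0)
demo_labels = np.repeat(np.arange(3), 100)
demo_centers = np.array([[-2.0, 0.0], [2.0, 0.0], [0.0, 2.0]])
demo_z = demo_centers[demo_labels] + 0.3 * rng.standard_normal((300, 2))

classes, centroids = class_centroids(demo_z, demo_labels)
predicted = nearest_centroid(demo_z, classes, centroids)
accuracy = (predicted == demo_labels).mean()

print(f"nearest-centroid accuracy: {accuracy:.3f}")
```

To apply the same idea to the notebook's data, `demo_z` and `demo_labels` would be replaced with `z_tensor` and `test_labels`. With only two latent dimensions, some digit classes are likely to overlap, so the accuracy on real data can be expected to be lower than on these synthetic clusters.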