In [None]:
import torch
from torch import nn

from torch.utils.data import DataLoader

from torchvision import datasets
from torchvision.transforms import ToTensor

import matplotlib.pyplot as plt
import numpy as np

from vae import VAE

In [None]:
### DATA LOADING
# FashionMNIST test split only; with download=True the files are fetched into
# ./data on first use. ToTensor converts every image into a torch tensor.
testing_data = datasets.FashionMNIST(
    root="data", train=False, download=True, transform=ToTensor(),
)

In [None]:
# Class index -> human-readable name for the ten FashionMNIST categories.
labels_map = dict(enumerate([
    "T-Shirt",
    "Trouser",
    "Pullover",
    "Dress",
    "Coat",
    "Sandal",
    "Shirt",
    "Sneaker",
    "Bag",
    "Ankle Boot",
]))

# Preview a 3x3 grid of randomly drawn test images, each titled with its class.
cols, rows = 3, 3
figure, axes = plt.subplots(rows, cols, figsize=(8, 8))
for ax in axes.flat:  # row-major, same panel order as add_subplot(rows, cols, 1..9)
    sample_idx = torch.randint(len(testing_data), size=(1,)).item()
    img, label = testing_data[sample_idx]
    ax.set_title(labels_map[label])
    ax.axis("off")
    ax.imshow(img.squeeze(), cmap="gray")
plt.show()

In [None]:
# Default to the CPU and switch to the GPU only when PyTorch can see one.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
device

In [None]:
# Rebuild the model on `device` and restore its trained weights from the checkpoint.
# map_location remaps tensors that were saved on another device (e.g. a CUDA
# checkpoint opened on a CPU-only machine) onto `device`; without it that load fails.
vae = VAE().to(device)
vae.load_state_dict(
    torch.load("VAE_checkpoint.pth", map_location=device, weights_only=True)
)
# Evaluation mode: this notebook only runs inference, never training.
vae.eval()

In [None]:
import random

# Reconstruct one randomly chosen test image with the VAE and show it beside the input.
# randrange excludes its upper bound, so the index is always valid. The previous
# random.randint(0, len(testing_data)) included the bound and could index past the end.
# Python's `random` is already seeded from OS entropy, so no explicit time-based seed is needed.
sample_index = random.randrange(len(testing_data))

img, label = testing_data[sample_index]
img = img.to(device)

with torch.no_grad():
    # Add a leading batch dimension of 1. Elsewhere the model is fed batches
    # (DataLoader batches to the encoder, a (1, latent) tensor to the decoder).
    pred_img = vae(img.unsqueeze(0))

figure, (ax_true, ax_pred) = plt.subplots(1, 2, figsize=(8, 8))

ax_true.set_title(f"True: {labels_map[label]}")
ax_true.axis("off")
ax_true.imshow(img.cpu().squeeze(), cmap="gray")

ax_pred.set_title(f"VAE: {labels_map[label]}")
ax_pred.axis("off")
ax_pred.imshow(pred_img.cpu().squeeze(), cmap="gray")

plt.show()

In [None]:
# Embed a subset of the test set with the encoder and scatter the first two
# latent coordinates, coloured by true class.
N_LATENT_BATCHES = 64  # batches of 64 images to embed (the old `i > 64` check drew 66)

# Reuse the test split loaded above rather than constructing a second copy.
testing_dataloader = DataLoader(testing_data, batch_size=64, shuffle=True)

latent_batches, label_batches = [], []
with torch.no_grad():
    for i, (X, y) in enumerate(testing_dataloader):
        if i >= N_LATENT_BATCHES:
            break
        latent_batches.append(vae.encoder(X.to(device)).to("cpu").numpy())
        label_batches.append(y.numpy())

z = np.concatenate(latent_batches)
z_labels = np.concatenate(label_batches)

# One scatter call with a fixed colour range. Labels 0..9 then always map to the
# same tab10 colour and the colorbar is valid for every point. Per-batch scatter
# calls autoscaled colours to each batch's own label range.
fig, ax = plt.subplots(figsize=(8, 8))
points = ax.scatter(z[:, 0], z[:, 1], c=z_labels, cmap="tab10", vmin=0, vmax=9)
fig.colorbar(points, ax=ax, label="class")
ax.set(xlabel="z[0]", ylabel="z[1]", title="VAE latent space (first two dimensions)")

plt.show()

In [None]:
# Draw a latent vector from a standard-normal prior and decode it into a new image.
# The prior's parameters are created on `device` instead of a hard-coded .cuda(),
# so this cell also runs on CPU-only machines.
LATENT_DIM = 10  # must match the decoder's expected input width; TODO confirm against VAE
N = torch.distributions.Normal(
    torch.tensor(0.0, device=device), torch.tensor(1.0, device=device)
)

with torch.no_grad():  # inference only, no autograd graph needed
    gen_img = vae.decoder(N.sample((1, LATENT_DIM)))

plt.figure(figsize=(4, 4))
plt.axis("off")
plt.imshow(gen_img.cpu().squeeze(), cmap="gray")

plt.show()
