# VQVAE for Image Generation - FashionMNIST Dataset
**Author:** Jeanne Malécot 

## Getting Started

In [1]:
#useful imports
import os
import copy

import numpy as np

import torch
import torchvision
import torchvision.transforms as transforms
import torchinfo
import random

import matplotlib.pyplot as plt

from scripts.train import train_model
from scripts.reconstruct import reconstruct, show_recon
from models.vqvae import VQVAE

##### Load FashionMNIST dataset

In [2]:
# Download FashionMNIST into ./data (skipped when already present) and build both
# splits; each split gets its own ToTensor transform so samples come out as tensors.
train_set, test_set = (
    torchvision.datasets.FashionMNIST(
        root="data",
        train=split_is_train,
        download=True,
        transform=transforms.ToTensor(),
    )
    for split_is_train in (True, False)  # train split first, then test split
)

# Sanity check: number of samples per split and the spatial size of one image
# (first channel of the first training sample).
print(f"len train set: {len(train_set)}\nlen test set: {len(test_set)}")
print(f"image shape: {train_set[0][0][0].shape}")

len train set: 60000
len test set: 10000
image shape: torch.Size([28, 28])


## Training VQVAE for reconstruction

In [3]:
# Run on the GPU whenever torch can see one; otherwise stay on the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
print('device:', device)

device: cuda


In [4]:
#config
# Seed every RNG the notebook imports so that model initialisation and training
# give the same result on Restart & Run All.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)  # seeds the CPU generator and all CUDA devices

# Hyper-parameters: the top level is passed to train_model, the 'model' sub-dict
# is passed to VQVAE.
config = {
    'n_epochs': 10,
    'lr': 0.001,
    'alpha': 0.25,  # NOTE(review): loss weighting used by train_model — confirm meaning
    'model': {
        'batch_size': 100,
        'n_channels': 1,            # FashionMNIST images are single-channel
        'channels': [32, 64, 128],  # encoder feature widths (see model summary below)
        'latent_dim': 32,
        'n_embedding': 128,         # presumably codebook size (summary: 4,096 = 128 x 32 VQ params)
        'beta': 1.2,                # NOTE(review): VQ loss weighting — confirm in models/vqvae.py
    },
}

Model architecture:

In [5]:
# Build an (untrained) VQ-VAE from the model sub-config purely to inspect its
# architecture; this instance is not the one passed to train_model below.
vqvae = VQVAE(config['model']).to(device)
# Summarise with a batch of one real-sized sample: (1, *image_shape) where the
# image shape (channels, height, width) is taken from the dataset itself.
# NOTE(review): the recorded summary lists DecoderBlock output as [1, 1, 56, 56]
# while VQVAE's output is [1, 1, 28, 28] — confirm how models/vqvae.py maps the
# decoder output back to 28x28.
torchinfo.summary(vqvae, (1, *train_set[0][0].shape), device=str(device))


Layer (type:depth-idx)                        Output Shape              Param #
VQVAE                                         [1, 1, 28, 28]            --
├─EncoderBlock: 1-1                           [1, 32, 7, 7]             --
│    └─ModuleList: 2-1                        --                        --
│    │    └─Sequential: 3-1                   [1, 32, 28, 28]           320
│    │    └─Sequential: 3-2                   [1, 64, 14, 14]           18,496
│    │    └─Sequential: 3-3                   [1, 128, 7, 7]            73,856
│    │    └─Sequential: 3-4                   [1, 32, 7, 7]             4,128
├─VectorQuantizer: 1-2                        [1, 32, 7, 7]             4,096
├─DecoderBlock: 1-3                           [1, 1, 56, 56]            --
│    └─ModuleList: 2-2                        --                        --
│    │    └─Sequential: 3-5                   [1, 128, 7, 7]            4,224
│    │    └─Sequential: 3-6                   [1, 64, 14, 14]           32,83

#### Fine-tuning the VQ-VAE

In [None]:
# Train the VQ-VAE described by `config` on the training split.
# model_dict carries the trained model ('model') and its final training loss
# ('train_loss'); loss_dict carries the recorded 'loss', 'vq_loss' and
# 'reconstruction_loss' curves plotted further down.
# NOTE(review): train_model presumably builds its own model from `config`
# (the `vqvae` instance above is not passed in) — confirm in scripts/train.py.
model_dict, loss_dict = train_model("vqvae", train_set, config, device)
print(f"Loss: {model_dict['train_loss']:.4f}")

Output()

In [None]:
# FashionMNIST class index -> readable class name.
labels_map = {
    0: "T-Shirt",
    1: "Trouser",
    2: "Pullover",
    3: "Dress",
    4: "Coat",
    5: "Sandal",
    6: "Shirt",
    7: "Sneaker",
    8: "Bag",
    9: "Ankle Boot",
}
model = model_dict['model']

# Run the trained model over the test split with reconstruct() and plot the
# returned reconstructions with show_recon().
# NOTE(review): `label_dist` is later indexed by class id (label_dist[7] for
# sneakers) and fed to to_prob as a score vector — confirm its exact contents
# in scripts/reconstruct.py.
image_dicts, label_dist = reconstruct(model, test_set, device)
show_recon(image_dicts)

In [None]:
# Training curves recorded by train_model: the overall loss plus the VQ and
# reconstruction losses.
loss = loss_dict['loss']
vq_loss = loss_dict['vq_loss']
reconstruction_loss = loss_dict['reconstruction_loss']

# One x position per recorded value; the logging granularity is set by train_model.
steps = np.arange(len(loss))

fig, ax = plt.subplots()
ax.plot(steps, loss, label='final loss')
ax.plot(steps, vq_loss, label='vq_loss')
ax.plot(steps, reconstruction_loss, label='reconstruction loss')
ax.set(title='VQ-VAE training losses', xlabel='recorded step', ylabel='loss')
ax.set_ylim(bottom=0)  # losses are non-negative; `bottom` is the documented name for `ymin`
ax.legend()

plt.show()

In [None]:
#convert vector into p_dist
def to_prob(vector, temperature=1):
    if temperature <= 0:
        raise ValueError("Carreful ! Temperature must be greater than 0.")
    
    scaled_vector = vector / temperature
    prob_dist = torch.softmax(scaled_vector, dim=0)
    return prob_dist

In [None]:
#generate sneakers
# Class id of "Sneaker" in labels_map; named instead of a bare 7.
SNEAKER_LABEL = 7

# Turn the per-class vector that reconstruct() returned for sneakers into a
# probability distribution, using the vector's own variance as the softmax
# temperature (to_prob raises if that variance is 0 or NaN).
sneaker_scores = label_dist[SNEAKER_LABEL]
sneakers_dist = to_prob(sneaker_scores, torch.var(sneaker_scores))

# Sample a new 28x28 image from that distribution and show it.
generated_image = model.generate(sneakers_dist, (28, 28))
print(generated_image.shape)
plt.imshow(generated_image.detach().cpu().numpy(), cmap='gray')
plt.title(f"Generated {labels_map[SNEAKER_LABEL]}")
plt.axis('off')
plt.show()