In [9]:
import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F

import torchvision
import torchvision.transforms as transforms

import matplotlib.pyplot as plt

import numpy as np

In [10]:
# Run on the GPU whenever CUDA is usable; otherwise stay on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

### Load Fashion MNIST dataset


In [11]:
# Training and test splits are stored under ./data (downloaded on first use)
# and share the same ToTensor-only preprocessing.
train_set = torchvision.datasets.FashionMNIST(
    root="./data",
    train=True,
    download=True,
    transform=transforms.Compose([transforms.ToTensor()]),
)
test_set = torchvision.datasets.FashionMNIST(
    root="./data",
    train=False,
    download=True,
    transform=transforms.Compose([transforms.ToTensor()]),
)

In [12]:
# Batches of 100 images; only the training loader reshuffles its samples.
train_loader = torch.utils.data.DataLoader(
    dataset=train_set, batch_size=100, shuffle=True
)
test_loader = torch.utils.data.DataLoader(
    dataset=test_set, batch_size=100, shuffle=False
)

We display one example image from the dataset and print its label to understand the structure of the data.


In [None]:
# Take the first training sample, show it in grayscale and print its label.
image, label = train_set[0]
plt.imshow(image.squeeze(), cmap="gray")
print(label)

### Exploratory Data Analysis


In [None]:
from sklearn.decomposition import PCA

# Draw a fixed random subset of 2000 training images (seeded so the same
# indices are chosen on every run).
np.random.seed(0)
indices = np.random.choice(len(train_set), 2000, replace=False)
train_subset = torch.utils.data.Subset(train_set, indices)

# Walk the subset once and collect (flattened pixels, label) pairs, so each
# image is loaded and transformed a single time instead of twice.
_subset_samples = [(image.flatten().numpy(), label) for image, label in train_subset]
images_flat = np.array([pixels for pixels, _ in _subset_samples], dtype=np.float32)
labels = np.array([label for _, label in _subset_samples])

# Project the flattened images onto their first two principal components.
pca = PCA(n_components=2)
pca_result = pca.fit_transform(images_flat)

# Scatter plot coloured by class, with the class names on the colorbar.
scatter = plt.scatter(
    pca_result[:, 0], pca_result[:, 1], c=labels, cmap="tab10", alpha=0.7, s=5
)
colorbar = plt.colorbar(scatter)
colorbar.set_ticks(range(10))
colorbar.set_ticklabels(train_set.classes)
plt.title("2-dim PCA of Fashion MNIST")
plt.show()

In [None]:
# Same class-coloured 2-D scatter plot as for PCA, this time with a UMAP embedding
# of the flattened image subset.
import umap

reducer = umap.UMAP()
umap_result = reducer.fit_transform(images_flat)

scatter = plt.scatter(
    x=umap_result[:, 0],
    y=umap_result[:, 1],
    s=5,
    c=labels,
    alpha=0.7,
    cmap="tab10",
)
# One colorbar tick per class, labelled with the class name.
colorbar = plt.colorbar(mappable=scatter)
colorbar.set_ticks(range(10))
colorbar.set_ticklabels(train_set.classes)
plt.title("2-dim UMAP of Fashion MNIST")
plt.show()

In [None]:
from sklearn.manifold import TSNE

# Embed the flattened image subset into two dimensions with t-SNE.
tsne = TSNE(n_components=2)
tsne_result = tsne.fit_transform(images_flat)

scatter = plt.scatter(
    x=tsne_result[:, 0],
    y=tsne_result[:, 1],
    s=5,
    c=labels,
    alpha=0.7,
    cmap="tab10",
)
# One colorbar tick per class, labelled with the class name.
colorbar = plt.colorbar(mappable=scatter)
colorbar.set_ticks(range(10))
colorbar.set_ticklabels(train_set.classes)
plt.title("2-dim t-SNE of Fashion MNIST")
plt.show()

### Define the model


In [36]:
class ResidualLayer(nn.Module):
    """
    A single residual layer: ReLU -> 3x3 conv -> ReLU -> 1x1 conv, with the
    result added back onto the layer's input.

    Inputs:
    - in_dim : number of input channels (also the number of output channels)
    - res_h_dim : number of channels produced by the 3x3 convolution
    """

    def __init__(self, in_dim, res_h_dim):
        super().__init__()

        # The 3x3 conv uses padding=1 and the 1x1 conv maps back to in_dim
        # channels, so the branch output can be summed with the input.
        branch_layers = [
            nn.ReLU(),
            nn.Conv2d(
                in_dim, res_h_dim, kernel_size=3, stride=1, padding=1, bias=False
            ),
            nn.ReLU(),
            nn.Conv2d(res_h_dim, in_dim, kernel_size=1, stride=1, bias=False),
        ]
        self.res_block = nn.Sequential(*branch_layers)

    def forward(self, x):
        """Return the input plus the output of the residual branch."""
        residual = self.res_block(x)
        return x + residual


class ResidualStack(nn.Module):
    """
    Several residual layers applied one after another, followed by a final ReLU.

    Inputs:
    - in_dim : the input dimension
    - res_h_dim : the hidden dimension of each residual block
    - n_res_layers : how many residual layers are stacked
    """

    def __init__(self, in_dim, res_h_dim, n_res_layers):
        super().__init__()

        self.n_res_layers = n_res_layers
        blocks = []
        for _ in range(n_res_layers):
            blocks.append(ResidualLayer(in_dim, res_h_dim))
        self.stack = nn.ModuleList(blocks)

    def forward(self, x):
        """Pass x through every residual layer in order, then apply ReLU."""
        out = x
        for block in self.stack:
            out = block(out)
        return F.relu(out)

In [37]:
class Encoder(nn.Module):
    """
    Convolutional encoder: two stride-2 convolutions, one stride-1 convolution,
    a residual stack and a final convolution producing `output_dim` channels.

    Inputs:
    - n_channels : number of channels of the input images
    - hidden_dim : number of channels of the hidden feature maps
    - output_dim : number of channels of the encoder output

    The shape comments below are for 28x28 inputs.
    """

    def __init__(self, n_channels, hidden_dim, output_dim):
        super(Encoder, self).__init__()

        self.encoder = nn.Sequential(
            nn.Conv2d(
                n_channels, hidden_dim // 2, kernel_size=4, stride=2, padding=1
            ),  # Input: [n_channels, 28, 28] Output: [hidden_dim // 2, 14, 14]
            nn.BatchNorm2d(hidden_dim // 2),
            nn.ReLU(),
            nn.Conv2d(
                hidden_dim // 2, hidden_dim, kernel_size=4, stride=2, padding=1
            ),  # Input: [hidden_dim // 2, 14, 14] Output: [hidden_dim, 7, 7]
            nn.BatchNorm2d(hidden_dim),
            nn.ReLU(),
            nn.Conv2d(
                hidden_dim, hidden_dim, kernel_size=3, stride=1, padding=1
            ),  # Input: [hidden_dim, 7, 7] Output: [hidden_dim, 7, 7]
            nn.BatchNorm2d(hidden_dim),
            nn.ReLU(),
            # ResidualStack takes (in_dim, res_h_dim, n_res_layers): two residual
            # layers whose hidden width equals hidden_dim.
            ResidualStack(
                hidden_dim, hidden_dim, 2
            ),  # Input: [hidden_dim, 7, 7] Output: [hidden_dim, 7, 7]
            nn.Conv2d(
                hidden_dim, output_dim, kernel_size=3, stride=1, padding=1
            ),  # Input: [hidden_dim, 7, 7] Output: [output_dim, 7, 7]
        )

    def forward(self, x):
        """Encode a batch of images into feature maps with `output_dim` channels."""
        return self.encoder(x)

In [38]:
class VectorQuantizer(nn.Module):
    """
    Vector-quantization layer: replaces every spatial feature vector of the input
    by its nearest codebook embedding (squared Euclidean distance).

    Inputs:
    - num_embeddings : number of vectors in the codebook
    - embedding_dim : dimension of each codebook vector; must equal the channel
      dimension of the tensors passed to `forward`
    - commitment_cost : weight of the commitment term in the returned loss
    """

    def __init__(self, num_embeddings, embedding_dim, commitment_cost):
        super(VectorQuantizer, self).__init__()

        self.embedding_dim = embedding_dim
        self.num_embeddings = num_embeddings
        # the dictionary of embeddings, initialised uniformly in
        # [-1 / num_embeddings, 1 / num_embeddings]
        self.embedding = nn.Embedding(self.num_embeddings, self.embedding_dim)
        self.embedding.weight.data.uniform_(
            -1 / self.num_embeddings, 1 / self.num_embeddings
        )
        self.commitment_cost = commitment_cost

    def forward(self, inputs: torch.Tensor):
        """
        Quantize a BCHW tensor whose channel dimension equals `embedding_dim`.

        Returns a tuple `(loss, quantized, perplexity, encodings)`:
        - loss : codebook loss + commitment_cost * commitment loss
        - quantized : BCHW tensor of the selected embeddings; gradients are
          passed straight through to `inputs`
        - perplexity : exp of the entropy of the average code usage in the batch
        - encodings : one-hot code assignments, shape [B * H * W, num_embeddings]

        Raises ValueError when `inputs` is not 4-dimensional or its channel
        dimension differs from `embedding_dim`.
        """
        # Without this check, view(-1, embedding_dim) below would silently regroup
        # values from different spatial positions when the channel count differs.
        if inputs.dim() != 4 or inputs.shape[1] != self.embedding_dim:
            raise ValueError(
                f"expected input of shape [B, {self.embedding_dim}, H, W], "
                f"got {tuple(inputs.shape)}"
            )

        # convert inputs from BCHW -> BHWC
        inputs = inputs.permute(0, 2, 3, 1).contiguous()
        input_shape = inputs.shape

        # Flatten input
        flat_input = inputs.view(-1, self.embedding_dim)

        # Squared Euclidean distance between every input vector and every
        # embedding, expanded as |x|^2 + |e|^2 - 2 x.e
        distances = (
            torch.sum(flat_input**2, dim=1, keepdim=True)
            + torch.sum(self.embedding.weight**2, dim=1)
            - 2 * torch.matmul(flat_input, self.embedding.weight.t())
        )

        # Encoding: index of the closest embedding for every input vector
        encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1)
        # one-hot encoding; created in the codebook's dtype so the matmul below
        # also works when the module has been cast (e.g. with .double())
        encodings = torch.zeros(
            encoding_indices.shape[0],
            self.num_embeddings,
            device=inputs.device,
            dtype=self.embedding.weight.dtype,
        )
        encodings.scatter_(1, encoding_indices, 1)

        # Quantize and unflatten
        quantized = torch.matmul(encodings, self.embedding.weight).view(input_shape)

        # Loss: the first term only updates the encoder output (codebook detached),
        # the second only updates the codebook (encoder output detached)
        e_latent_loss = F.mse_loss(quantized.detach(), inputs)
        q_latent_loss = F.mse_loss(quantized, inputs.detach())
        loss = q_latent_loss + self.commitment_cost * e_latent_loss

        # Straight-through estimator: forward value is the quantized tensor,
        # gradient flows to `inputs` as if quantization were the identity
        quantized = inputs + (quantized - inputs).detach()

        # 1e-10 keeps the log finite for codes that are never selected
        avg_probs = torch.mean(encodings, dim=0)
        perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))

        # convert quantized from BHWC -> BCHW
        return loss, quantized.permute(0, 3, 1, 2).contiguous(), perplexity, encodings

In [39]:
class Decoder(nn.Module):
    """
    Convolutional decoder: a stride-1 transposed convolution, a residual stack
    and two stride-2 transposed convolutions, ending in a sigmoid.

    Inputs:
    - input_dim : number of channels of the incoming feature maps
    - hidden_dim : number of channels of the hidden feature maps
    - n_channels : number of channels of the reconstructed images

    The shape comments below are for 7x7 input feature maps.
    """

    def __init__(self, input_dim, hidden_dim, n_channels):
        super(Decoder, self).__init__()

        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(
                input_dim, hidden_dim, kernel_size=3, stride=1, padding=1
            ),  # Input: [input_dim, 7, 7] Output: [hidden_dim, 7, 7]
            nn.BatchNorm2d(hidden_dim),
            nn.ReLU(),
            # ResidualStack takes (in_dim, res_h_dim, n_res_layers): two residual
            # layers whose hidden width equals hidden_dim.
            ResidualStack(
                hidden_dim, hidden_dim, 2
            ),  # Input: [hidden_dim, 7, 7] Output: [hidden_dim, 7, 7]
            nn.ConvTranspose2d(
                hidden_dim, hidden_dim // 2, kernel_size=4, stride=2, padding=1
            ),  # Input: [hidden_dim, 7, 7] Output: [hidden_dim // 2, 14, 14]
            nn.BatchNorm2d(hidden_dim // 2),
            nn.ReLU(),
            nn.ConvTranspose2d(
                hidden_dim // 2, n_channels, kernel_size=4, stride=2, padding=1
            ),  # Input: [hidden_dim // 2, 14, 14] Output: [n_channels, 28, 28]
            nn.Sigmoid(),  # squashes the reconstruction into (0, 1)
        )

    def forward(self, x):
        """Decode feature maps into images with `n_channels` channels."""
        return self.decoder(x)

In [40]:
class VQVAE(nn.Module):
    """
    VQ-VAE model: encoder -> vector quantizer -> decoder.

    Inputs:
    - n_channels : number of channels of the input (and reconstructed) images
    - hidden_dim : number of channels of the encoder/decoder hidden feature maps
    - num_embeddings : number of vectors in the quantizer's codebook
    - embedding_dim : dimension of each codebook vector (encoder output channels)
    - commitment_cost : weight of the commitment term in the quantizer loss
      (defaults to 0.25, the value that was previously hard-coded)
    """

    def __init__(
        self, n_channels, hidden_dim, num_embeddings, embedding_dim, commitment_cost=0.25
    ):
        super(VQVAE, self).__init__()

        self.encoder = Encoder(
            n_channels, hidden_dim, embedding_dim
        )  # Input: [n_channels, 28, 28], Output: [embedding_dim, 7, 7]
        self.quantizer = VectorQuantizer(
            num_embeddings, embedding_dim, commitment_cost
        )  # Input: [embedding_dim, 7, 7], Output: [embedding_dim, 7, 7]
        self.decoder = Decoder(
            embedding_dim, hidden_dim, n_channels
        )  # Input: [embedding_dim, 7, 7], Output: [n_channels, 28, 28]

    def forward(self, x):
        """
        Encode, quantize and decode a batch of images.

        Returns `(x_recon, loss, perplexity)`: the reconstruction plus the loss
        and perplexity reported by the quantizer. The reconstruction error is
        not part of `loss`; the caller adds it.
        """
        z = self.encoder(x)
        # The one-hot encodings returned by the quantizer are not needed here.
        loss, quantized, perplexity, _ = self.quantizer(z)
        x_recon = self.decoder(quantized)

        return x_recon, loss, perplexity

### Training loop


In [None]:
# The channel count is read from the first training image (dimension 0 of its tensor).
n_channels = train_set[0][0].shape[0]
HIDDEN_DIM = 64
NUM_EMBEDDINGS = 512
EMBEDDING_DIM = 64

# Build the VQ-VAE, move it to the selected device and print its layer structure.
model = VQVAE(
    n_channels=n_channels,
    hidden_dim=HIDDEN_DIM,
    num_embeddings=NUM_EMBEDDINGS,
    embedding_dim=EMBEDDING_DIM,
)
model = model.to(device)
print(model)

In [None]:
def train(train_args):
    """
    Train the global `model` for one pass over `train_loader`.

    Inputs:
    - train_args : dict with the keys
        - "lr" : learning rate of the AdamW optimizer built when no optimizer
          is supplied
        - "verbose" (optional, default False) : print the metrics every 100 steps
        - "optimizer" (optional) : an existing optimizer to reuse, so that its
          state carries over between calls

    Returns the mean total loss (reconstruction error + quantizer loss) per
    batch, or 0.0 when the loader yields no batch.
    """
    model.train()

    # Unless the caller supplies an optimizer, a fresh AdamW is built on every
    # call, so its internal state starts from scratch each epoch.
    optimizer = train_args.get("optimizer")
    if optimizer is None:
        optimizer = optim.AdamW(model.parameters(), lr=train_args["lr"])
    verbose = bool(train_args.get("verbose", False))

    train_loss = 0.0
    n_batches = 0
    for i, (data, _) in enumerate(train_loader):
        data = data.to(device)
        optimizer.zero_grad()

        x_recon, vq_loss, perplexity = model(data)
        recon_error = F.mse_loss(x_recon, data)
        loss = recon_error + vq_loss

        loss.backward()
        optimizer.step()

        # Accumulate the batch loss so the caller gets a meaningful epoch loss.
        train_loss += loss.item()
        n_batches += 1

        if i % 100 == 0 and verbose:
            print(
                f"Step: {i}, Loss: {loss.item()}, Recon Error: {recon_error.item()}, Perplexity: {perplexity.item()}"
            )
    return train_loss / max(n_batches, 1)

In [None]:
# Sweep settings: number of epochs per learning rate and the candidate rates.
epochs = 10
learning_rates = [1e-4]  # [1e-3, 1e-4, 1e-5]
best_lr, best_loss = None, float("inf")

# NOTE(review): the global `model` is not re-created between learning rates, so
# with several candidates each one would continue from the previous one's weights.
for lr in learning_rates:
    print(f"Training with learning rate: {lr}")
    for e in range(epochs):
        print(f"Epoch: {e}")
        train_loss = train(dict(lr=lr, verbose=False))

        # Keep the learning rate that produced the lowest epoch loss so far.
        if not train_loss < best_loss:
            continue
        best_loss, best_lr = train_loss, lr

### Evaluate the model


In [103]:
# Switch to evaluation mode and reconstruct only the first test batch,
# without tracking gradients.
model.eval()

with torch.no_grad():
    for data, _ in test_loader:
        data = data.to(device)
        x_recon, _, _ = model(data)
        # One batch is enough for the visual check and the metrics.
        break

In [None]:
# Show 8 random test images (top row) above their reconstructions (bottom row).
fig, axs = plt.subplots(2, 8, figsize=(16, 4))

# Sample positions from the actual batch size instead of assuming 100 images.
indices = torch.randint(0, data.shape[0], (8,))
for i, idx in enumerate(indices):
    axs[0, i].imshow(data[idx].squeeze().cpu(), cmap="gray")
    axs[1, i].imshow(x_recon[idx].squeeze().cpu(), cmap="gray")

axs[0, 0].set_title("Original")
axs[1, 0].set_title("Reconstruction")

# Adjust the layout once the images and titles are in place.
fig.tight_layout()
plt.show()

In [None]:
from skimage.metrics import structural_similarity as ssim
from sklearn.metrics import mean_squared_error

# Drop the channel axis (dimension 1) so that each entry is one 2-D image.
originals_np = data.cpu().numpy().squeeze(1)
recons_np = x_recon.cpu().numpy().squeeze(1)

# SSIM is computed image by image and then averaged. Passing the whole
# (batch, height, width) array at once would treat the batch as one 3-D volume
# and let the sliding window span unrelated images.
ssim_val = np.mean(
    [
        ssim(original, recon, data_range=1.0)
        for original, recon in zip(originals_np, recons_np)
    ]
)
mse_val = mean_squared_error(originals_np.flatten(), recons_np.flatten())

print(f"SSIM: {ssim_val}, MSE: {mse_val}")