In [None]:
import torch
import torch.optim as optim
import torch.nn.functional as F

import torchvision
import torchvision.transforms as transforms

import matplotlib.pyplot as plt

from models.vqvae import VQVAE

In [None]:
# Use the GPU when one is visible to PyTorch; otherwise run everything on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

### Load Fashion MNIST dataset


In [None]:
# Build the training and test splits with identical settings; the only
# difference between the two is the `train` flag. Images are converted to
# tensors by ToTensor() and the files are downloaded into ./data if missing.
train_set, test_set = (
    torchvision.datasets.FashionMNIST(
        root="./data",
        train=is_train_split,
        download=True,
        transform=transforms.Compose([transforms.ToTensor()]),
    )
    for is_train_split in (True, False)
)

In [None]:
# Batches of 100 samples for both splits.
# The training loader reshuffles the samples every epoch so that batches are
# not presented in the same fixed order each time; the test loader keeps the
# dataset order so evaluation batches are stable between runs.
train_loader = torch.utils.data.DataLoader(train_set, batch_size=100, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=100)

We display one example image from the training set and print its label to understand the structure of the data.


In [None]:
# Take the first (image, label) pair of the training set, draw the image in
# grayscale and print the label that goes with it.
first_sample = train_set[0]
image, label = first_sample
plt.imshow(image.squeeze(), cmap="gray")
print(label)

### Training loop


In [None]:
# Training hyper-parameters: number of passes over the data and learning rate.
train_args = {"epochs": 1, "lr": 1e-3}

# Instantiate the model on the selected device and put it in training mode
# (both .to() and .train() return the module itself, so they can be chained).
model = VQVAE().to(device).train()

# AdamW over all model parameters, using the configured learning rate.
optimizer = optim.AdamW(model.parameters(), lr=train_args["lr"])


def train():
    """Run one pass over ``train_loader`` and return the mean per-batch loss.

    Uses the notebook-level ``model``, ``optimizer``, ``train_loader`` and
    ``device``. For every batch the model is called on the inputs, the MSE
    between the model's reconstruction and the inputs is added to the loss
    term returned by the model, and one optimizer step is taken on the sum.
    Progress is printed every 100 steps.

    Returns:
        float: the average, over all batches seen, of the combined loss
        (reconstruction error + model-provided loss term). Returns 0.0 when
        the loader yields no batches.
    """
    model.train()
    train_loss = 0.0
    num_batches = 0

    for i, (data, _) in enumerate(train_loader):
        data = data.to(device)
        optimizer.zero_grad()

        # The model returns its own loss term alongside the reconstruction
        # (presumably the vector-quantization loss -- confirm in models.vqvae).
        x_recon, aux_loss, perplexity = model(data)
        recon_error = F.mse_loss(x_recon, data)
        loss = recon_error + aux_loss

        loss.backward()
        optimizer.step()

        # Accumulate as a Python float so no autograd graph is kept alive.
        train_loss += loss.item()
        num_batches += 1

        if i % 100 == 0:
            print(
                f"Step: {i}, Loss: {loss.item()}, Recon Error: {recon_error.item()}, Perplexity: {perplexity.item()}"
            )

    # Guard against an empty loader to avoid dividing by zero.
    return train_loss / num_batches if num_batches else 0.0


# Run the configured number of training epochs, announcing each one.
epoch_count = train_args["epochs"]
for epoch in range(epoch_count):
    print(f"Epoch: {epoch}")
    train()

### Test the model


In [None]:
model.eval()

# Only the first test batch is needed for the visual comparison below, so
# fetch it directly instead of looping and breaking. Gradients are disabled
# because no backward pass happens during evaluation.
with torch.no_grad():
    data, _ = next(iter(test_loader))
    data = data.to(device)
    # The model's loss term and perplexity are not needed here.
    x_recon, _, _ = model(data)

In [None]:
# Compare five random test images (top row) with their reconstructions
# (bottom row).
fig, axs = plt.subplots(2, 5)
fig.suptitle("Originals (top) vs. reconstructions (bottom)")

# Sample indices from the actual size of the evaluated batch rather than a
# hard-coded 100, so this keeps working if the batch size changes or the
# batch is smaller than expected.
indices = torch.randint(0, data.size(0), (5,))
for i, idx in enumerate(indices):
    axs[0, i].imshow(data[idx].squeeze().cpu(), cmap="gray")
    axs[1, i].imshow(x_recon[idx].squeeze().cpu(), cmap="gray")

plt.show()