In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import cv2
import numpy as np
SEED = 41  # fixed global seed so torch-driven randomness (e.g. weight init) is reproducible
torch.manual_seed(SEED)

### 1. Load MNIST dataset

In [None]:
# 1. Load MNIST dataset
BATCH_SIZE = 4096
# ToTensor converts each PIL image to a float tensor scaled to [0, 1],
# matching the decoder's sigmoid output range.
transform = transforms.Compose([transforms.ToTensor()])
train_dataset = datasets.MNIST(
    root="./data", train=True, download=True, transform=transform
)
# Shuffle so each epoch draws different batch compositions (shuffle=False would
# feed identical batches every epoch); order stays reproducible under the
# global torch seed.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)


### 2. Define VQVAE model

In [None]:

class VQVAE(nn.Module):
    """Minimal VQ-VAE for single-channel images.

    The encoder downsamples by 4 with one strided convolution. Each spatial
    latent vector is snapped to its nearest codebook entry. A two-stage
    transposed-convolution decoder upsamples by 4 again and ends in a sigmoid.
    For inputs whose height and width are multiples of 4 (e.g. 28x28 MNIST),
    the reconstruction has the same spatial size as the input.

    Args:
        num_embeddings: Number of codebook vectors.
        embedding_dim: Size of each codebook vector (channels of the latent
            that is quantized).
        beta: Weight of the commitment loss in the returned quantization loss.
    """

    def __init__(self, num_embeddings=10, embedding_dim=2, beta=0.2):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 4, 4, stride=4, padding=0),
            nn.BatchNorm2d(4),
            nn.ReLU(),
        )

        # 1x1 convs project the encoder features into / out of codebook space.
        self.pre_quant_conv = nn.Conv2d(4, embedding_dim, kernel_size=1)
        self.embedding = nn.Embedding(
            num_embeddings=num_embeddings, embedding_dim=embedding_dim
        )
        self.post_quant_conv = nn.Conv2d(embedding_dim, 4, kernel_size=1)

        # Commitment loss weight; used in forward() to scale the commitment term.
        self.beta = beta

        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(4, 16, 4, stride=2, padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.ConvTranspose2d(16, 1, 4, stride=2, padding=1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Encode, vector-quantize and decode ``x``.

        Args:
            x: Image batch of shape (B, 1, H, W).

        Returns:
            Tuple ``(output, quantize_losses)``:
              - ``output``: the sigmoid reconstruction.
              - ``quantize_losses``: the scalar
                ``codebook_loss + beta * commitment_loss``.
        """
        encoded_output = self.encoder(x)
        quant_input = self.pre_quant_conv(encoded_output)

        # (B, C, H, W) -> (B*H*W, C): one row per latent vector.
        B, C, H, W = quant_input.shape
        flat_input = quant_input.permute(0, 2, 3, 1).reshape(-1, C)

        # Distance from every latent vector to every codebook entry, computed
        # against the (K, C) codebook directly (no per-batch copy).
        dist = torch.cdist(flat_input, self.embedding.weight)
        min_encoding_indices = torch.argmin(dist, dim=-1)
        quant_out = self.embedding(min_encoding_indices)

        # Commitment term pulls encoder outputs toward their codes; codebook
        # term pulls codes toward the (frozen) encoder outputs.
        commitment_loss = torch.mean((quant_out.detach() - flat_input) ** 2)
        codebook_loss = torch.mean((quant_out - flat_input.detach()) ** 2)
        quantize_losses = codebook_loss + self.beta * commitment_loss

        # Straight-through estimator: the forward pass uses the codebook
        # vectors, while the gradient w.r.t. quant_out flows unchanged to flat_input.
        quant_out = flat_input + (quant_out - flat_input).detach()

        # Back to (B, C, H, W) for the convolutional decoder.
        quant_out = quant_out.reshape(B, H, W, C).permute(0, 3, 1, 2)

        decoder_input = self.post_quant_conv(quant_out)
        output = self.decoder(decoder_input)
        return output, quantize_losses

### 3. Initialize model, optimizer, scheduler

In [None]:
# Fall back to CPU when no GPU is available instead of crashing on .cuda().
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = VQVAE().to(DEVICE)
optimizer = optim.AdamW(model.parameters(), lr=2e-3)
# Multiply the learning rate by 0.95 on every scheduler.step() call.
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.95)

### 4. Training loop

In [None]:
NUM_EPOCHS = 20
# Use whatever device the model lives on, so this cell also runs on CPU-only machines.
device = next(model.parameters()).device
model.train()
for epoch in range(NUM_EPOCHS):
    for batch_idx, (data, _) in enumerate(train_loader):
        data = data.to(device)
        optimizer.zero_grad()
        out, quantize_loss = model(data)
        recon_loss = torch.nn.functional.mse_loss(out, data)
        loss = recon_loss + quantize_loss
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        # '\r' overwrites the same console line with the latest batch's losses.
        print(
            f"\rEpoch {epoch}, Batch {batch_idx:03d}, Loss: {loss.item():.4f} = {recon_loss.item():.4f} + {quantize_loss.item():.4f}",
            end="",
        )
    # Learning-rate decay is applied once per epoch.
    scheduler.step()
    print()

# torch.save(model.state_dict(), "vqvae.ckpt")

### 5. Inference: reconstruct the first 10 training images and save them as PNGs

In [None]:
def save(data, idx):
    """Write a single-channel image with values in [0, 1] to ``img_{idx}.png``.

    Args:
        data: numpy array or torch tensor (any device) holding the image.
            Singleton dimensions such as (1, 1, H, W) are squeezed away.
        idx: Identifier inserted into the output file name.

    Raises:
        IOError: if ``cv2.imwrite`` reports that the file was not written.
    """
    if isinstance(data, torch.Tensor):
        data = data.detach().cpu().numpy()
    data = np.squeeze(np.asarray(data))
    # Clip before the uint8 cast so out-of-range values don't wrap around.
    data = (np.clip(data, 0.0, 1.0) * 255).astype(np.uint8)
    path = f"img_{idx}.png"
    if not cv2.imwrite(path, data):
        raise IOError(f"cv2.imwrite failed to write {path}")


model.eval()
# Use the model's own device so this cell also runs on CPU-only machines.
device = next(model.parameters()).device
with torch.no_grad():
    for i in range(10):
        data, _ = train_dataset[i]
        x_recon, _ = model(data.unsqueeze(0).to(device))
        # (1, 1, H, W) tensor -> (H, W) numpy array with sigmoid values in [0, 1].
        recon_img = x_recon.cpu().squeeze().numpy()
        save(recon_img, i)
