# Vector Quantized VAE (VQ-VAE) implementation from scratch

In [None]:
import torch
import torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt; plt.rcParams['figure.dpi'] = 200
from PIL import Image
Image.LOAD_TRUNCATED_IMAGES = True
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
from tqdm.notebook import tqdm
import zipfile

In [None]:
# Use the CUDA device when one is available, otherwise fall back to the CPU.
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
DEVICE

device(type='cuda')

In [None]:
class Encoder(nn.Module):
    """Convolutional encoder: two strided convolutions, then two residual convolutions.

    Args:
        input_dim: Number of channels of the input tensor.
        hidden_dim: Number of channels used by the intermediate convolutions.
        output_dim: Number of channels produced by the last convolution. Its
            output is added in place to a ``hidden_dim``-channel tensor, so the
            two channel counts have to be compatible for that addition.
        kernel_size: Kernel sizes of the four convolutions, in layer order.
        stride: Stride shared by the two strided convolutions.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, kernel_size=(4, 4, 3, 1), stride=2):
        super().__init__()

        k_down_1, k_down_2, k_res_1, k_res_2 = kernel_size

        # Strided stage (each layer uses ``stride`` and a padding of 1).
        self.strided_conv_1 = nn.Conv2d(input_dim, hidden_dim, k_down_1, stride, padding=1)
        self.strided_conv_2 = nn.Conv2d(hidden_dim, hidden_dim, k_down_2, stride, padding=1)

        # Residual stage (default stride of 1).
        self.residual_conv_1 = nn.Conv2d(hidden_dim, hidden_dim, k_res_1, padding=1)
        self.residual_conv_2 = nn.Conv2d(hidden_dim, output_dim, k_res_2, padding=0)

    def forward(self, x):
        """Encode ``x`` and return the un-activated output of the last residual step."""
        # No activation is applied between the two strided convolutions.
        downsampled = self.strided_conv_2(self.strided_conv_1(x))

        hidden = nn.functional.relu(downsampled)
        # First residual step: conv output + its own input, then ReLU.
        hidden = nn.functional.relu(self.residual_conv_1(hidden).add_(hidden))

        # Second residual step; the skip input is the activated tensor.
        return self.residual_conv_2(hidden).add_(hidden)

In [None]:
class VQEmbeddingEMA(nn.Module):
    """Vector-quantization codebook whose entries are updated by exponential moving averages.

    Inputs are flattened with ``reshape(-1, embedding_dim)`` before the
    nearest-code search, so callers should pass channel-last tensors of shape
    ``(..., embedding_dim)``; then every quantized vector is exactly one
    feature vector of the input.

    The codebook and its EMA statistics are stored as buffers (``embedding``,
    ``ema_count``, ``ema_weight``), not parameters: they are saved in the
    ``state_dict`` but receive no gradient.

    Args:
        n_embeddings: Number of codebook entries.
        embedding_dim: Dimensionality of each codebook entry.
        commitment_cost: Weight applied to the encoder commitment loss.
        decay: EMA decay factor used for the count and weight statistics.
        epsilon: Smoothing constant added to the EMA counts before the
            codebook is re-normalized.
    """

    def __init__(self, n_embeddings, embedding_dim, commitment_cost=0.25, decay=0.999, epsilon=1e-5):
        super(VQEmbeddingEMA, self).__init__()
        self.commitment_cost = commitment_cost
        self.decay = decay
        self.epsilon = epsilon

        # Codebook entries start uniformly in [-1/n_embeddings, 1/n_embeddings].
        init_bound = 1 / n_embeddings
        embedding = torch.empty(n_embeddings, embedding_dim)
        embedding.uniform_(-init_bound, init_bound)
        self.register_buffer("embedding", embedding)
        self.register_buffer("ema_count", torch.zeros(n_embeddings))
        self.register_buffer("ema_weight", self.embedding.clone())

    def _nearest_indices(self, x_flat):
        """Return, for every row of ``x_flat``, the index of the closest codebook entry."""
        # Squared Euclidean distance expanded as |e|^2 + |x|^2 - 2 * x . e,
        # evaluated for all (row, code) pairs with a single addmm call.
        distances = torch.addmm(torch.sum(self.embedding ** 2, dim=1) +
                                torch.sum(x_flat ** 2, dim=1, keepdim=True),
                                x_flat, self.embedding.t(),
                                alpha=-2.0, beta=1.0)
        return torch.argmin(distances.float(), dim=-1)

    def encode(self, x):
        """Quantize ``x`` without touching the EMA statistics.

        Returns:
            A tuple ``(quantized, indices)``. ``quantized`` has the shape of
            ``x``. For channel-last input (last dimension equal to
            ``embedding_dim``) ``indices`` has shape ``x.shape[:-1]``; any
            other layout keeps the historical ``(x.size(0), x.size(1))`` shape.
        """
        M, D = self.embedding.size()
        x_flat = x.detach().reshape(-1, D)

        indices = self._nearest_indices(x_flat)
        quantized = nn.functional.embedding(indices, self.embedding)
        quantized = quantized.view_as(x)

        # One index per quantized vector: works for (B, T, D) as well as for
        # (B, H, W, D) inputs instead of assuming exactly three dimensions.
        if x.size(-1) == D:
            index_shape = x.shape[:-1]
        else:
            index_shape = (x.size(0), x.size(1))
        return quantized, indices.view(index_shape)

    def retrieve_random_codebook(self, random_indices):
        """Look up the code vectors for ``random_indices`` and swap dimensions 1 and 3.

        For ``(B, H, W)`` indices the lookup yields ``(B, H, W, embedding_dim)``
        and the returned tensor has shape ``(B, embedding_dim, W, H)``; note
        that the two middle axes come out swapped as well.
        """
        quantized = nn.functional.embedding(random_indices, self.embedding)
        quantized = quantized.transpose(1, 3)

        return quantized

    def forward(self, x):
        """Quantize ``x`` and, in training mode, update the codebook by EMA.

        Returns:
            A tuple ``(quantized, commitment_loss, codebook_loss, perplexity)``:

            * ``quantized`` - straight-through quantized tensor with the shape
              of ``x``; its forward value is the selected code and its gradient
              is passed unchanged to ``x``.
            * ``commitment_loss`` - ``commitment_cost`` times the MSE between
              ``x`` and the detached codes.
            * ``codebook_loss`` - MSE between the detached ``x`` and the codes.
              The codes come from a buffer, so this term carries no gradient.
            * ``perplexity`` - exponential of the entropy of the average code
              usage over this call's vectors.
        """
        M, D = self.embedding.size()
        x_flat = x.detach().reshape(-1, D)

        indices = self._nearest_indices(x_flat)
        encodings = nn.functional.one_hot(indices, M).float()
        # The codes returned for this call are read before the EMA update below.
        quantized = nn.functional.embedding(indices, self.embedding)
        quantized = quantized.view_as(x)

        if self.training:
            # EMA of how often each code was selected ...
            self.ema_count = self.decay * self.ema_count + (1 - self.decay) * torch.sum(encodings, dim=0)
            # ... smoothed with epsilon so that no count is exactly zero,
            # while the total count ``n`` is preserved.
            n = torch.sum(self.ema_count)
            self.ema_count = (self.ema_count + self.epsilon) / (n + M * self.epsilon) * n

            # EMA of the sum of inputs assigned to each code; the codebook is
            # the ratio of the two running statistics.
            dw = torch.matmul(encodings.t(), x_flat)
            self.ema_weight = self.decay * self.ema_weight + (1 - self.decay) * dw
            self.embedding = self.ema_weight / self.ema_count.unsqueeze(-1)

        codebook_loss = nn.functional.mse_loss(x.detach(), quantized)
        e_latent_loss = nn.functional.mse_loss(x, quantized.detach())
        commitment_loss = self.commitment_cost * e_latent_loss

        # Straight-through estimator: forward the codes, back-propagate into x.
        quantized = x + (quantized - x).detach()

        avg_probs = torch.mean(encodings, dim=0)
        perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))

        return quantized, commitment_loss, codebook_loss, perplexity

In [None]:
class Decoder(nn.Module):
    """Convolutional decoder: two residual convolutions, then two transposed convolutions.

    Args:
        input_dim: Number of channels of the input tensor. The first residual
            step adds the input to a ``hidden_dim``-channel tensor in place, so
            the two channel counts have to be compatible for that addition.
        hidden_dim: Number of channels used by the intermediate layers.
        output_dim: Number of channels of the returned tensor.
        kernel_sizes: Kernel sizes of the four layers, in layer order.
        stride: Stride shared by the two transposed convolutions.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, kernel_sizes=(1, 3, 2, 2), stride=2):
        super().__init__()

        k_res_1, k_res_2, k_up_1, k_up_2 = kernel_sizes

        # Residual stage (default stride of 1).
        self.residual_conv_1 = nn.Conv2d(input_dim, hidden_dim, k_res_1, padding=0)
        self.residual_conv_2 = nn.Conv2d(hidden_dim, hidden_dim, k_res_2, padding=1)

        # Transposed-convolution stage (each layer uses ``stride``, no padding).
        self.strided_t_conv_1 = nn.ConvTranspose2d(hidden_dim, hidden_dim, k_up_1, stride, padding=0)
        self.strided_t_conv_2 = nn.ConvTranspose2d(hidden_dim, output_dim, k_up_2, stride, padding=0)

    def forward(self, x):
        """Decode ``x``; no activation is applied after the last layer."""
        # Each residual step: conv output + its own input, then ReLU.
        hidden = nn.functional.relu(self.residual_conv_1(x).add_(x))
        hidden = nn.functional.relu(self.residual_conv_2(hidden).add_(hidden))

        # No activation between the two transposed convolutions.
        return self.strided_t_conv_2(self.strided_t_conv_1(hidden))

In [None]:
class VQVAE(nn.Module):
    """Encoder -> vector-quantization codebook -> decoder.

    Args:
        Encoder: Module mapping an input batch to a channel-first latent map
            (channels in dimension 1).
        Codebook: Callable returning ``(quantized, commitment_loss,
            codebook_loss, perplexity)`` for a channel-last latent tensor.
        Decoder: Module mapping the channel-first quantized latents back to
            the input space.
    """

    def __init__(self, Encoder, Codebook, Decoder):
        super(VQVAE, self).__init__()
        self.encoder = Encoder
        self.codebook = Codebook
        self.decoder = Decoder

    def forward(self, x):
        """Return ``(x_hat, commitment_loss, codebook_loss, perplexity)`` for the batch ``x``."""
        z = self.encoder(x)

        # The codebook flattens its input into rows of ``embedding_dim``
        # consecutive elements. The convolutional encoder emits channel-first
        # maps, so without this move each row would be a run of neighbouring
        # activations of a single channel instead of the channel vector of one
        # spatial position. Moving channels last makes every quantized vector
        # one per-position feature vector.
        channels_first = z.dim() > 2
        if channels_first:
            z = z.movedim(1, -1).contiguous()

        z_quantized, commitment_loss, codebook_loss, perplexity = self.codebook(z)

        # Restore the channel-first layout expected by the decoder.
        if channels_first:
            z_quantized = z_quantized.movedim(-1, 1).contiguous()
        x_hat = self.decoder(z_quantized)

        return x_hat, commitment_loss, codebook_loss, perplexity


In [None]:
# Model hyper-parameters.
input_dim = 3  # channels of the model input / output (presumably RGB images - confirm against the dataset)
hidden_dim = 64  # channel width of encoder/decoder; also used as the codebook embedding size
#latent_dim = 64
n_embeddings = 512  # number of codebook entries


In [None]:
# Build the three sub-modules. The encoder output width, the codebook
# embedding size and the decoder input width are all ``hidden_dim``.
encode = Encoder(input_dim, hidden_dim, hidden_dim)
codebook = VQEmbeddingEMA(n_embeddings, hidden_dim)
decode = Decoder(hidden_dim, hidden_dim, input_dim)

# Assemble the full model and move its parameters and buffers to DEVICE.
model = VQVAE(encode, codebook, decode).to(DEVICE)

In [None]:
# Adam over the model parameters. The codebook tensors are registered as
# buffers in VQEmbeddingEMA, so they are not part of ``model.parameters()``.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# Reconstruction criterion used by the training loop.
mse_loss = nn.MSELoss()


In [None]:
epochs = 20  # number of passes over train_loader
print_step = 100  # metrics are printed every ``print_step`` batches

In [None]:
# Pre-processing applied to every image: resize to 112x112, then convert to a tensor.
transform = transforms.Compose([
    transforms.Resize((112,112)),
    transforms.ToTensor()
])

In [None]:
# Google Colab only: mount the user's Google Drive at /content/drive.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Manual toggle: set to False to skip re-extracting the archive.
if True:
  # Extract the archive into the current working directory; the ``with``
  # block closes the zip file handle even if extraction fails.
  with zipfile.ZipFile('/content/drive/MyDrive/trial.zip', 'r') as zip_ref:
    zip_ref.extractall()

In [None]:
# Images are read from ./trial/ and passed through ``transform``.
dataset = datasets.ImageFolder("./trial/", transform=transform)

In [None]:
# Shuffled mini-batches of 128 samples.
train_loader = DataLoader(dataset, batch_size=128, shuffle=True)

In [None]:
# Notebook inspection only: number of batches per epoch divided by 4 (the value is not stored).
len(train_loader)/4

123.0

In [None]:
def save_checkpoint(model, loss, perplexity, recon_loss, codebook_loss, filename):
    """Write the model weights and the given training metrics to ``filename``.

    The saved object is a dict with the keys ``state_dict``, ``loss``,
    ``perplexity``, ``recon_loss`` and ``codebook_loss``; the metric values
    are stored exactly as passed in.

    Args:
        model: Module whose ``state_dict()`` is stored.
        loss: Total-loss history (or value) to store.
        perplexity: Perplexity history (or value) to store.
        recon_loss: Reconstruction-loss history (or value) to store.
        codebook_loss: Codebook-loss history (or value) to store.
        filename: Destination handed to ``torch.save``.
    """
    checkpoint = dict(
        state_dict=model.state_dict(),
        loss=loss,
        perplexity=perplexity,
        recon_loss=recon_loss,
        codebook_loss=codebook_loss,
    )
    torch.save(checkpoint, filename)

In [None]:
# Training loop: optimizes the model on train_loader for ``epochs`` epochs,
# prints metrics every ``print_step`` batches and writes a checkpoint after
# every epoch.
print("Start training VQ-VAE...")
# Training mode also enables the EMA codebook update in VQEmbeddingEMA.forward.
model.train()

# Per-epoch metric histories (one entry is appended per epoch, see below).
tracked_loss = []
tracked_perplexity = []
tracked_recon_loss = []
tracked_codebook_loss = []


for epoch in tqdm(range(epochs)):
    # NOTE(review): overall_loss is reset every epoch but never accumulated or read in this cell.
    overall_loss = 0
    for batch_idx, (x, _) in enumerate(train_loader):
        # The labels yielded by the loader are ignored.
        x = x.to(DEVICE)

        optimizer.zero_grad()

        x_hat, commitment_loss, codebook_loss, perplexity = model(x)
        recon_loss = mse_loss(x_hat, x)

        # codebook_loss is computed in VQEmbeddingEMA.forward from the codebook
        # buffer and a detached encoder output, so it raises the reported total
        # but contributes no gradient.
        loss =  recon_loss + commitment_loss + codebook_loss

        loss.backward()
        optimizer.step()

        if batch_idx % print_step == 0:
            print("epoch:", epoch + 1, "  step:", batch_idx + 1, "  recon_loss:", recon_loss.item(), "  perplexity: ", perplexity.item(),
              "\n\t\tcommit_loss: ", commitment_loss.item(), "  codebook loss: ", codebook_loss.item(), "  total_loss: ", loss.item())

    # These are the values of the LAST batch of the epoch, not epoch averages.
    tracked_loss.append(loss.item())
    tracked_perplexity.append(perplexity.item())
    tracked_recon_loss.append(recon_loss.item())
    tracked_codebook_loss.append(codebook_loss.item())

    # The same file name is used every epoch, so each checkpoint overwrites the previous one.
    save_checkpoint(model, tracked_loss, tracked_perplexity, tracked_recon_loss, tracked_codebook_loss, filename="VQ_VAE_128.pth.tar")

print("Finish!!")

Start training VQ-VAE...


  0%|          | 0/20 [00:00<?, ?it/s]

epoch: 1   step: 1   recon_loss: 0.1866087019443512   perplexity:  36.73640060424805 
		commit_loss:  0.0044831084087491035   codebook loss:  0.017932433634996414   total_loss:  0.20902423560619354
epoch: 1   step: 101   recon_loss: 0.015706485137343407   perplexity:  182.4307098388672 
		commit_loss:  0.01364827249199152   codebook loss:  0.05459308996796608   total_loss:  0.08394785225391388
epoch: 1   step: 201   recon_loss: 0.014088070951402187   perplexity:  289.3202819824219 
		commit_loss:  0.015485777519643307   codebook loss:  0.06194311007857323   total_loss:  0.09151695668697357
epoch: 1   step: 301   recon_loss: 0.011030398309230804   perplexity:  334.3252258300781 
		commit_loss:  0.016161199659109116   codebook loss:  0.06464479863643646   total_loss:  0.09183639287948608
epoch: 1   step: 401   recon_loss: 0.010862873867154121   perplexity:  373.29534912109375 
		commit_loss:  0.01920594833791256   codebook loss:  0.07682379335165024   total_loss:  0.10689261555671692
epo

In [None]:
# Google Colab only: download the checkpoint file written by the training loop.
from google.colab import files
files.download('VQ_VAE_128.pth.tar')
