In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import numpy as np

from tqdm import tqdm
from torchvision.utils import save_image, make_grid

In [2]:
# Model Hyperparameters

dataset_path = '~/datasets'

cuda = True
DEVICE = torch.device("cuda" if cuda else "cpu")


batch_size = 128
img_size = (32, 32) # (width, height)

input_dim = 3
hidden_dim = 128
n_embeddings= 768
output_dim = 3

lr = 2e-4

epochs = 50

print_step = 50

###    Step 1. Load (or download) Dataset

In [3]:
from torchvision.datasets import CIFAR10
import torchvision.transforms as transforms
from torch.utils.data import DataLoader


mnist_transform = transforms.Compose([
        transforms.ToTensor(),
])

kwargs = {'num_workers': 1, 'pin_memory': True} 

train_dataset = CIFAR10(dataset_path, transform=mnist_transform, train=True, download=True)
test_dataset  = CIFAR10(dataset_path, transform=mnist_transform, train=False, download=True)

train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, **kwargs)
test_loader  = DataLoader(dataset=test_dataset,  batch_size=batch_size, shuffle=False,  **kwargs)

Files already downloaded and verified
Files already downloaded and verified


### Step 2. Define our model: Vector Quantized Variational AutoEncoder (VQ-VAE)

In [4]:
class Encoder(nn.Module):
    
    def __init__(self, input_dim, hidden_dim, output_dim, kernel_size=(4, 4, 3, 1), stride=2):
        super(Encoder, self).__init__()
        
        kernel_1, kernel_2, kernel_3, kernel_4 = kernel_size
        
        self.strided_conv_1 = nn.Conv2d(input_dim, hidden_dim, kernel_1, stride, padding=1)
        self.strided_conv_2 = nn.Conv2d(hidden_dim, hidden_dim, kernel_2, stride, padding=1)
        
        self.residual_conv_1 = nn.Conv2d(hidden_dim, hidden_dim, kernel_3, padding=1)
        self.residual_conv_2 = nn.Conv2d(hidden_dim, output_dim, kernel_4, padding=0)
        
    def forward(self, x):
        
        x = self.strided_conv_1(x)
        x = self.strided_conv_2(x)
        
        x = F.relu(x)
        y = self.residual_conv_1(x)
        y = y+x
        
        x = F.relu(y)
        y = self.residual_conv_2(x)
        y = y+x
        
        return y

In [5]:
class VQEmbeddingEMA(nn.Module):
    def __init__(self, n_embeddings, embedding_dim, commitment_cost=0.25, decay=0.999, epsilon=1e-5):
        super(VQEmbeddingEMA, self).__init__()
        self.commitment_cost = commitment_cost
        self.decay = decay
        self.epsilon = epsilon
        
        init_bound = 1 / n_embeddings
        embedding = torch.Tensor(n_embeddings, embedding_dim)
        embedding.uniform_(-init_bound, init_bound)
        self.register_buffer("embedding", embedding)
        self.register_buffer("ema_count", torch.zeros(n_embeddings))
        self.register_buffer("ema_weight", self.embedding.clone())

    def encode(self, x):
        M, D = self.embedding.size()
        x_flat = x.detach().reshape(-1, D)

        distances = torch.addmm(torch.sum(self.embedding ** 2, dim=1) +
                    torch.sum(x_flat ** 2, dim=1, keepdim=True),
                                x_flat, self.embedding.t(),
                                alpha=-2.0, beta=1.0)

        indices = torch.argmin(distances.float(), dim=-1)
        quantized = F.embedding(indices, self.embedding)
        quantized = quantized.view_as(x)
        return quantized, indices.view(x.size(0), x.size(1))
    
    def retrieve_random_codebook(self, random_indices):
        quantized = F.embedding(random_indices, self.embedding)
        quantized = quantized.transpose(1, 3)
        
        return quantized

    def forward(self, x):
        M, D = self.embedding.size()
        x_flat = x.detach().reshape(-1, D)
        
        distances = torch.addmm(torch.sum(self.embedding ** 2, dim=1) +
                                torch.sum(x_flat ** 2, dim=1, keepdim=True),
                                x_flat, self.embedding.t(),
                                alpha=-2.0, beta=1.0)

        indices = torch.argmin(distances.float(), dim=-1)
        encodings = F.one_hot(indices, M).float()
        quantized = F.embedding(indices, self.embedding)
        quantized = quantized.view_as(x)
        
        if self.training:
            self.ema_count = self.decay * self.ema_count + (1 - self.decay) * torch.sum(encodings, dim=0)
            n = torch.sum(self.ema_count)
            self.ema_count = (self.ema_count + self.epsilon) / (n + M * self.epsilon) * n

            dw = torch.matmul(encodings.t(), x_flat)
            self.ema_weight = self.decay * self.ema_weight + (1 - self.decay) * dw
            self.embedding = self.ema_weight / self.ema_count.unsqueeze(-1)

        codebook_loss = F.mse_loss(x.detach(), quantized)
        e_latent_loss = F.mse_loss(x, quantized.detach())
        commitment_loss = self.commitment_cost * e_latent_loss

        quantized = x + (quantized - x).detach()

        avg_probs = torch.mean(encodings, dim=0)
        perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))

        return quantized, commitment_loss, codebook_loss, perplexity

In [6]:
class Decoder(nn.Module):
    
    def __init__(self, input_dim, hidden_dim, output_dim, kernel_sizes=(1, 3, 2, 2), stride=2):
        super(Decoder, self).__init__()
        
        kernel_1, kernel_2, kernel_3, kernel_4 = kernel_sizes
        
        self.residual_conv_1 = nn.Conv2d(input_dim, hidden_dim, kernel_1, padding=0)
        self.residual_conv_2 = nn.Conv2d(hidden_dim, hidden_dim, kernel_2, padding=1)
        
        self.strided_t_conv_1 = nn.ConvTranspose2d(hidden_dim, hidden_dim, kernel_3, stride, padding=0)
        self.strided_t_conv_2 = nn.ConvTranspose2d(hidden_dim, output_dim, kernel_4, stride, padding=0)
        
    def forward(self, x):
        
        y = self.residual_conv_1(x)
        y = y+x
        x = F.relu(y)
        
        y = self.residual_conv_2(x)
        y = y+x
        y = F.relu(y)
        
        y = self.strided_t_conv_1(y)
        y = self.strided_t_conv_2(y)
        
        return y

In [7]:
class Model(nn.Module):
    def __init__(self, Encoder, Codebook, Decoder):
        super(Model, self).__init__()
        self.encoder = Encoder
        self.codebook = Codebook
        self.decoder = Decoder
                
    def forward(self, x):
        z = self.encoder(x)
        z_quantized, commitment_loss, codebook_loss, perplexity = self.codebook(z)
        x_hat = self.decoder(z_quantized)
        
        return x_hat, commitment_loss, codebook_loss, perplexity
    
    

In [8]:
encoder = Encoder(input_dim=input_dim, hidden_dim=hidden_dim, output_dim=hidden_dim)
codebook = VQEmbeddingEMA(n_embeddings=n_embeddings, embedding_dim=hidden_dim)
decoder = Decoder(input_dim=hidden_dim, hidden_dim=hidden_dim, output_dim=output_dim)

model = Model(Encoder=encoder, Codebook=codebook, Decoder=decoder).to(DEVICE)

### Step 3. Define Loss function (reprod. loss) and optimizer

In [9]:
from torch.optim import Adam

mse_loss = nn.MSELoss()

optimizer = Adam(model.parameters(), lr=lr)

### Step 4. Train Vector Quantized Variational AutoEncoder (VQ-VAE)

In [None]:
print("Start training VQ-VAE...")
model.train()

for epoch in range(epochs):
    overall_loss = 0
    for batch_idx, (x, _) in enumerate(train_loader):
        x = x.to(DEVICE)

        optimizer.zero_grad()

        x_hat, commitment_loss, codebook_loss, perplexity = model(x)
        recon_loss = mse_loss(x_hat, x)
        
        loss =  recon_loss + commitment_loss + codebook_loss
                
        loss.backward()
        optimizer.step()
        
        if batch_idx % print_step ==0: 
            print("epoch:", epoch + 1, "  step:", batch_idx + 1, "  recon_loss:", recon_loss.item(), "  perplexity: ", perplexity.item(), 
              "\n\t\tcommit_loss: ", commitment_loss.item(), "  codebook loss: ", codebook_loss.item(), "  total_loss: ", loss.item())
    
print("Finish!!")

Start training VQ-VAE...
epoch: 1   step: 1   recon_loss: 0.12986251711845398   perplexity:  47.204471588134766 
		commit_loss:  0.006001153960824013   codebook loss:  0.02400461584329605   total_loss:  0.1598682999610901
epoch: 1   step: 51   recon_loss: 0.04567872732877731   perplexity:  42.20817184448242 
		commit_loss:  0.034285638481378555   codebook loss:  0.13714255392551422   total_loss:  0.2171069234609604
epoch: 1   step: 101   recon_loss: 0.03384900838136673   perplexity:  89.93128204345703 
		commit_loss:  0.02993432804942131   codebook loss:  0.11973731219768524   total_loss:  0.18352064490318298
epoch: 1   step: 151   recon_loss: 0.027672773227095604   perplexity:  124.88264465332031 
		commit_loss:  0.03181006759405136   codebook loss:  0.12724027037620544   total_loss:  0.18672311305999756
epoch: 1   step: 201   recon_loss: 0.026118455454707146   perplexity:  154.0612335205078 
		commit_loss:  0.03789319470524788   codebook loss:  0.15157277882099152   total_loss:  0.21

epoch: 6   step: 101   recon_loss: 0.010875707492232323   perplexity:  560.1865844726562 
		commit_loss:  0.02711101993918419   codebook loss:  0.10844407975673676   total_loss:  0.14643080532550812
epoch: 6   step: 151   recon_loss: 0.01098853349685669   perplexity:  591.0968627929688 
		commit_loss:  0.02788493037223816   codebook loss:  0.11153972148895264   total_loss:  0.15041318535804749
epoch: 6   step: 201   recon_loss: 0.010916754603385925   perplexity:  581.0225219726562 
		commit_loss:  0.028284473344683647   codebook loss:  0.11313789337873459   total_loss:  0.1523391306400299
epoch: 6   step: 251   recon_loss: 0.010993203148245811   perplexity:  594.9034423828125 
		commit_loss:  0.0285346582531929   codebook loss:  0.1141386330127716   total_loss:  0.15366649627685547
epoch: 6   step: 301   recon_loss: 0.010561942122876644   perplexity:  587.4264526367188 
		commit_loss:  0.02696807123720646   codebook loss:  0.10787228494882584   total_loss:  0.14540229737758636
epoch: 6

epoch: 11   step: 201   recon_loss: 0.009423065930604935   perplexity:  622.7198486328125 
		commit_loss:  0.027028506621718407   codebook loss:  0.10811402648687363   total_loss:  0.14456559717655182
epoch: 11   step: 251   recon_loss: 0.009687854908406734   perplexity:  622.070556640625 
		commit_loss:  0.02676006406545639   codebook loss:  0.10704025626182556   total_loss:  0.14348816871643066
epoch: 11   step: 301   recon_loss: 0.009648793376982212   perplexity:  630.8403930664062 
		commit_loss:  0.0264489334076643   codebook loss:  0.1057957336306572   total_loss:  0.14189346134662628
epoch: 11   step: 351   recon_loss: 0.009450634010136127   perplexity:  638.7750244140625 
		commit_loss:  0.027651259675621986   codebook loss:  0.11060503870248795   total_loss:  0.14770692586898804
epoch: 12   step: 1   recon_loss: 0.009017271921038628   perplexity:  605.3821411132812 
		commit_loss:  0.02586526609957218   codebook loss:  0.10346106439828873   total_loss:  0.13834360241889954
epo

epoch: 16   step: 301   recon_loss: 0.00853740330785513   perplexity:  626.6061401367188 
		commit_loss:  0.02685229852795601   codebook loss:  0.10740919411182404   total_loss:  0.14279890060424805
epoch: 16   step: 351   recon_loss: 0.00862729363143444   perplexity:  635.871826171875 
		commit_loss:  0.026677733287215233   codebook loss:  0.10671093314886093   total_loss:  0.1420159637928009
epoch: 17   step: 1   recon_loss: 0.008372170850634575   perplexity:  621.5608520507812 
		commit_loss:  0.025481699034571648   codebook loss:  0.10192679613828659   total_loss:  0.13578066229820251
epoch: 17   step: 51   recon_loss: 0.008781087584793568   perplexity:  648.1851806640625 
		commit_loss:  0.028656475245952606   codebook loss:  0.11462590098381042   total_loss:  0.15206345915794373
epoch: 17   step: 101   recon_loss: 0.009105941280722618   perplexity:  647.7300415039062 
		commit_loss:  0.02814292535185814   codebook loss:  0.11257170140743256   total_loss:  0.14982056617736816
epoc

epoch: 22   step: 1   recon_loss: 0.008469371125102043   perplexity:  644.9039306640625 
		commit_loss:  0.02729184553027153   codebook loss:  0.10916738212108612   total_loss:  0.14492860436439514
epoch: 22   step: 51   recon_loss: 0.007823950611054897   perplexity:  632.595947265625 
		commit_loss:  0.02530590258538723   codebook loss:  0.10122361034154892   total_loss:  0.13435345888137817
epoch: 22   step: 101   recon_loss: 0.007742434274405241   perplexity:  635.242431640625 
		commit_loss:  0.025445587933063507   codebook loss:  0.10178235173225403   total_loss:  0.13497036695480347
epoch: 22   step: 151   recon_loss: 0.008194840513169765   perplexity:  639.1213989257812 
		commit_loss:  0.02676723524928093   codebook loss:  0.10706894099712372   total_loss:  0.1420310139656067
epoch: 22   step: 201   recon_loss: 0.00791831873357296   perplexity:  632.6032104492188 
		commit_loss:  0.025595862418413162   codebook loss:  0.10238344967365265   total_loss:  0.13589763641357422
epoch

epoch: 27   step: 101   recon_loss: 0.007737928535789251   perplexity:  644.4732666015625 
		commit_loss:  0.02489336207509041   codebook loss:  0.09957344830036163   total_loss:  0.13220474123954773
epoch: 27   step: 151   recon_loss: 0.007930881343781948   perplexity:  638.597412109375 
		commit_loss:  0.02539108879864216   codebook loss:  0.10156435519456863   total_loss:  0.13488632440567017
epoch: 27   step: 201   recon_loss: 0.008079132996499538   perplexity:  649.9195556640625 
		commit_loss:  0.02673795446753502   codebook loss:  0.10695181787014008   total_loss:  0.1417689025402069
epoch: 27   step: 251   recon_loss: 0.00820082612335682   perplexity:  641.2607421875 
		commit_loss:  0.026279479265213013   codebook loss:  0.10511791706085205   total_loss:  0.13959822058677673
epoch: 27   step: 301   recon_loss: 0.008126657456159592   perplexity:  641.4671630859375 
		commit_loss:  0.02623705193400383   codebook loss:  0.10494820773601532   total_loss:  0.13931190967559814
epoch

epoch: 32   step: 201   recon_loss: 0.008173206821084023   perplexity:  648.2000122070312 
		commit_loss:  0.027290374040603638   codebook loss:  0.10916149616241455   total_loss:  0.14462506771087646
epoch: 32   step: 251   recon_loss: 0.00721497368067503   perplexity:  620.0120849609375 
		commit_loss:  0.023767128586769104   codebook loss:  0.09506851434707642   total_loss:  0.12605062127113342
epoch: 32   step: 301   recon_loss: 0.007658489979803562   perplexity:  652.1416015625 
		commit_loss:  0.025979500263929367   codebook loss:  0.10391800105571747   total_loss:  0.13755598664283752
epoch: 32   step: 351   recon_loss: 0.008209575898945332   perplexity:  668.3692626953125 
		commit_loss:  0.027482468634843826   codebook loss:  0.1099298745393753   total_loss:  0.14562192559242249
epoch: 33   step: 1   recon_loss: 0.007825877517461777   perplexity:  644.1414794921875 
		commit_loss:  0.02651306800544262   codebook loss:  0.10605227202177048   total_loss:  0.14039121568202972
epo

epoch: 37   step: 301   recon_loss: 0.007290534675121307   perplexity:  656.4033203125 
		commit_loss:  0.02561289817094803   codebook loss:  0.10245159268379211   total_loss:  0.13535502552986145
epoch: 37   step: 351   recon_loss: 0.007722970098257065   perplexity:  659.7182006835938 
		commit_loss:  0.027155006304383278   codebook loss:  0.10862002521753311   total_loss:  0.1434980034828186
epoch: 38   step: 1   recon_loss: 0.0070112599059939384   perplexity:  637.5927124023438 
		commit_loss:  0.02407214045524597   codebook loss:  0.09628856182098389   total_loss:  0.12737196683883667
epoch: 38   step: 51   recon_loss: 0.007578765973448753   perplexity:  649.732421875 
		commit_loss:  0.026304941624403   codebook loss:  0.105219766497612   total_loss:  0.1391034722328186
epoch: 38   step: 101   recon_loss: 0.007442460395395756   perplexity:  650.87353515625 
		commit_loss:  0.024920472875237465   codebook loss:  0.09968189150094986   total_loss:  0.13204482197761536
epoch: 38   ste

epoch: 43   step: 1   recon_loss: 0.007287102751433849   perplexity:  651.2134399414062 
		commit_loss:  0.025362657383084297   codebook loss:  0.10145062953233719   total_loss:  0.13410039246082306
epoch: 43   step: 51   recon_loss: 0.007483947090804577   perplexity:  658.1427612304688 
		commit_loss:  0.0265862625092268   codebook loss:  0.1063450500369072   total_loss:  0.1404152512550354
epoch: 43   step: 101   recon_loss: 0.008311585523188114   perplexity:  664.3536376953125 
		commit_loss:  0.0280187726020813   codebook loss:  0.1120750904083252   total_loss:  0.14840544760227203
epoch: 43   step: 151   recon_loss: 0.007555972784757614   perplexity:  650.3740234375 
		commit_loss:  0.025769811123609543   codebook loss:  0.10307924449443817   total_loss:  0.13640502095222473
epoch: 43   step: 201   recon_loss: 0.007470604032278061   perplexity:  659.9151611328125 
		commit_loss:  0.025891004130244255   codebook loss:  0.10356401652097702   total_loss:  0.1369256228208542
epoch: 43

### Step 5. Evaluate the model

In [None]:
import matplotlib.pyplot as plt

In [None]:
def draw_sample_image(x, postfix):
  
    plt.figure(figsize=(8,8))
    plt.axis("off")
    plt.title("Visualization of {}".format(postfix))
    plt.imshow(np.transpose(make_grid(x.detach().cpu(), padding=2, normalize=True), (1, 2, 0)))

In [None]:
model.eval()

with torch.no_grad():

    for batch_idx, (x, _) in enumerate(tqdm(test_loader)):

        x = x.to(DEVICE)
        x_hat, commitment_loss, codebook_loss, perplexity = model(x)
 
        print("perplexity: ", perplexity.item(),"commit_loss: ", commitment_loss.item(), "  codebook loss: ", codebook_loss.item())
        break

In [None]:
draw_sample_image(x[:batch_size//2], "Ground-truth images")

In [None]:
draw_sample_image(x_hat[:batch_size//2], "Reconstructed images")

### Step 6. Generate samples via random codes

In [None]:
def draw_random_sample_image(codebook, decoder, indices_shape):
    
    random_indices = torch.floor(torch.rand(indices_shape) * n_embeddings).long().to(DEVICE)
    codes = codebook.retrieve_random_codebook(random_indices)
    x_hat = decoder(codes.to(DEVICE))
    
    plt.figure(figsize=(8,8))
    plt.axis("off")
    plt.title("Visualization of Random Codes")
    plt.imshow(np.transpose(make_grid(x_hat.detach().cpu(), padding=2, normalize=True), (1, 2, 0)))

In [None]:
draw_random_sample_image(codebook, decoder, indices_shape=(batch_size//2, img_size[0]//4, img_size[1]//4))