In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import torchvision.datasets as datasets
from torchvision.utils import make_grid
import torchvision.transforms as transforms

import matplotlib.pyplot as plt

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [2]:
mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transforms.ToTensor())
mnist_testset = datasets.MNIST(root='./data', train=False, download=True, transform=transforms.ToTensor())

In [3]:
class Encoder(nn.Module):

    def __init__(self, in_channels, out_channels):
        super(Encoder, self).__init__()
        
        self.conv1 = nn.Conv2d(in_channels  ,  in_channels*2, 5)
        self.conv2 = nn.Conv2d(in_channels*2, out_channels  , 5)
        
        self.activation = nn.ReLU()
        
    def forward(self, x):

        x = self.conv1(x)
        x = self.activation(x)

        x = self.conv2(x)
        x = self.activation(x)

        return x

In [4]:
class VectorQuantizer(nn.Module):
    def __init__(self, num_codewords, codewords_dim, commitment_cost):
        super(VectorQuantizer, self).__init__()
        
        self.num_codewords = num_codewords
        self.codewords_dim = codewords_dim
        self.codewords = nn.Parameter(torch.rand(self.num_codewords, self.codewords_dim),requires_grad=True)

        self.commitment_cost = commitment_cost
        
    def forward(self, inputs):
        # Reshape: B, C, H, W -> B*H*W, C
        inputs = inputs.permute(0, 2, 3, 1).contiguous()
        flat_inputs = inputs.view(-1, self.codewords_dim)

        # Calculating distances:
        distances = torch.pow(flat_inputs.unsqueeze(1) - self.codewords.unsqueeze(0),2).sum(2)

        # Argmin:
        encoding_indices = torch.argmin(distances, dim=1)

        # Index from dictionary:
        # quantized[i,j] = self.codewords[encoding_indices[i,j], j]
        quantized = torch.gather(self.codewords,0,encoding_indices.unsqueeze(1).expand(-1, self.codewords_dim)).view(inputs.shape)
        encoding_indices = encoding_indices.view(inputs.shape[:-1])

        # quantization loss
        quantizing_loss = F.mse_loss(quantized.detach(), inputs)
        commitment_loss = F.mse_loss(quantized, inputs.detach())
        loss = quantizing_loss + self.commitment_cost * commitment_loss

        # magic trick to copy gradients from inputs
        quantized = inputs + (quantized - inputs).detach()
        
        #Reshape:
        return quantized.permute(0, 3, 1, 2), encoding_indices, loss

In [5]:
class Decoder(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(Decoder, self).__init__()

        self.conv1 = nn.ConvTranspose2d(in_channels, in_channels//2, kernel_size=5)
        self.conv2 = nn.ConvTranspose2d(in_channels//2, out_channels, kernel_size=5)

        self.activation = nn.ReLU()

    def forward(self, inputs):
        x = self.conv1(inputs)
        x = self.activation(x)
        x = self.conv2(x)
        return x

In [6]:
class VQVAE(nn.Module):
    def __init__(self,x_dim,z_dim,num_codewords,commitment_cost):
        super().__init__()
        self.encoder = Encoder(x_dim,z_dim)
        self.quantizer = VectorQuantizer(num_codewords, z_dim, commitment_cost)
        self.decoder = Decoder(z_dim, x_dim)

    def forward(self, x):
        z_e = self.encoder(x)
        z_q, codes, q_loss = self.quantizer(z_e)
        x_hat = self.decoder(z_q)
        
        return x_hat, codes, q_loss

In [7]:
#initialize input (limited to the first 1000 values)
input = mnist_trainset.data.unsqueeze(1).float()[:10000]

#construction target
target = torch.zeros(len(input),10)
value = mnist_trainset.targets[:len(input)]
for ind in range(len(target)):
    target[ind,value[ind]] = 1
    
#data loader
dataloader = torch.utils.data.DataLoader(input, batch_size = 32)

In [8]:
learning_rate = 0.01

model = VQVAE(1,16,10,0.15).to(device)

optimizer = torch.optim.Adam(model.parameters() , lr=learning_rate, amsgrad=True)

In [9]:
for epoch in range(100):
    for x in dataloader:
        x = x.to(device)
        x_hat, codes, q_loss = model.forward(x)
        
        #calcul loss
        rec_loss = F.mse_loss(x_hat, x)
        loss = rec_loss + q_loss

        #backprop
        model.zero_grad()
        loss.backward()
        
        #optimizer.step()
        
        for f in (model.parameters()):
            f.data.sub_(f.grad.data * learning_rate)
        

    print('Epoch {}: loss = {}'.format(epoch,loss.item()))

Epoch 0: loss = 6335.32568359375
Epoch 1: loss = 6334.9462890625
Epoch 2: loss = 6334.92919921875
Epoch 3: loss = 6334.91552734375
Epoch 4: loss = 6334.9033203125
Epoch 5: loss = 6334.89208984375
Epoch 6: loss = 6334.88232421875
Epoch 7: loss = 6334.8740234375
Epoch 8: loss = 6334.8662109375
Epoch 9: loss = 6334.859375
Epoch 10: loss = 6334.853515625
Epoch 11: loss = 6334.84814453125
Epoch 12: loss = 6334.84326171875
Epoch 13: loss = 6334.8388671875
Epoch 14: loss = 6334.8349609375
Epoch 15: loss = 6334.83154296875
Epoch 16: loss = 6334.82861328125
Epoch 17: loss = 6334.826171875
Epoch 18: loss = 6334.82373046875
Epoch 19: loss = 6334.82177734375
Epoch 20: loss = 6334.81982421875
Epoch 21: loss = 6334.81787109375
Epoch 22: loss = 6334.81640625
Epoch 23: loss = 6334.8154296875
Epoch 24: loss = 6334.81396484375
Epoch 25: loss = 6334.81298828125


KeyboardInterrupt: ignored