In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import torchvision.datasets as datasets
from torchvision.utils import make_grid
import torchvision.transforms as transforms

import matplotlib.pyplot as plt

In [2]:
mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transforms.ToTensor())
mnist_testset = datasets.MNIST(root='./data', train=False, download=True, transform=transforms.ToTensor())

In [3]:
class Encoder(nn.Module):

    def __init__(self, in_channels, out_channels):
        super(Encoder, self).__init__()
        
        self.conv1 = nn.Conv2d(in_channels  ,  in_channels*2, 5)
        self.conv2 = nn.Conv2d(in_channels*2, out_channels  , 5)
        
        self.activation = nn.ReLU()
        
    def forward(self, x):

        x = self.conv1(x)
        x = self.activation(x)

        x = self.conv2(x)
        x = self.activation(x)

        return x

In [4]:
class Decoder(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(Decoder, self).__init__()

        self.conv1 = nn.ConvTranspose2d(in_channels, in_channels//2, kernel_size=5)
        self.conv2 = nn.ConvTranspose2d(in_channels//2, out_channels, kernel_size=5)

        self.activation = nn.ReLU()

    def forward(self, inputs):
        x = self.conv1(inputs)
        x = self.activation(x)
        x = self.conv2(x)
        return x

In [5]:
class VectorQuantizer(nn.Module):
    def __init__(self, num_codewords, codewords_dim, commitment_cost):
        super(VectorQuantizer, self).__init__()
        
        self.num_codewords = num_codewords
        self.codewords_dim = codewords_dim
        self.codewords = nn.Parameter(torch.rand(self.num_codewords, self.codewords_dim),requires_grad=True)

        self.commitment_cost = commitment_cost
        
    def forward(self, inputs):
        # Reshape: B, C, H, W -> B*H*W, C
        inputs = inputs.permute(0, 2, 3, 1).contiguous()
        inputs_shape = inputs.size()
        flat_inputs = inputs.view(-1, self.codewords_dim)

        # Calculating distances:
        distances = torch.pow(flat_inputs.unsqueeze(1) - self.codewords.unsqueeze(0),2).sum(2)
        
        # Argmin:
        encoding_indices = torch.argmin(distances, dim=1)

        # quantized[i,j] = self.codewords[encoding_indices[i,j], j]
        quantized = torch.gather(self.codewords,0,encoding_indices.unsqueeze(1).expand(-1, self.codewords_dim)).view(inputs_shape)
        encoding_indices = encoding_indices.view(inputs_shape[:-1])

        # quantization loss
        quantizing_loss = F.mse_loss(quantized.detach(), inputs)
        commitment_loss = F.mse_loss(quantized, inputs.detach())
        loss = quantizing_loss + self.commitment_cost * commitment_loss

        # magic trick to copy gradients from inputs
        quantized = inputs + (quantized - inputs).detach()
        
        #Reshape:
        return quantized.permute(0, 3, 1, 2), encoding_indices, loss

In [6]:
#initialize input (limited to the first 1000 values)
input = F.interpolate(mnist_trainset.data.unsqueeze(1).float(), size=(32,32))[:1000]

#construction target
target = torch.zeros(len(input),10)
value = mnist_trainset.targets[:len(input)]
for ind in range(len(target)):
    target[ind,value[ind]] = 1

In [7]:
encoder = Encoder(1, 16)
decoder = Decoder(16, 1)
quantizer = VectorQuantizer(10, 16, 0.15)

In [15]:
for ind in range(5):
    #forward
    x = input
    
    z_e = encoder(x)
    z_q, codes, q_loss = quantizer(z_e)
    x_hat = decoder(z_q)
    
    #calcul loss
    rec_loss = F.mse_loss(x_hat, x)
    loss = rec_loss + q_loss
    print('{}: {}'.format(ind,loss.item()))
    
    #backprop
    decoder.zero_grad()
    
    loss.backward()
    
    learning_rate = 0.02
    
    for f in encoder.parameters():
        f.data.sub_(f.grad.data * learning_rate)

0: 6655.23486328125
1: 6655.23486328125
2: 6655.23486328125
3: 6655.23486328125
4: 6655.23486328125
5: 6655.23486328125
6: 6655.23486328125
7: 6655.23486328125
8: 6655.23486328125
9: 6655.23486328125


In [44]:
input = F.interpolate(mnist_testset.data.unsqueeze(1).float(), size=(32,32))
output = net(input)

In [45]:
comp = torch.argmax(output,axis=1) == mnist_testset.targets

In [46]:
torch.sum(comp).item()/len(comp)*100

91.8

In [13]:
print(output.shape)

torch.Size([5000, 2, 24, 24])
