In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import torchvision.datasets as datasets
from torchvision.utils import make_grid
import torchvision.transforms as transforms

import matplotlib.pyplot as plt

In [2]:
mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transforms.ToTensor())
mnist_testset = datasets.MNIST(root='./data', train=False, download=True, transform=transforms.ToTensor())

In [3]:
class Encoder(nn.Module):

    def __init__(self, in_channels, out_channels):
        super(Encoder, self).__init__()
        
        self.conv1 = nn.Conv2d(in_channels  ,  in_channels*2, 5)
        self.conv2 = nn.Conv2d(in_channels*2, out_channels  , 5)
        
        self.activation = nn.ReLU()
        
    def forward(self, x):

        x = self.conv1(x)
        x = self.activation(x)

        x = self.conv2(x)
        x = self.activation(x)

        return x

In [4]:
class Decoder(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(Decoder, self).__init__()

        self.conv1 = nn.ConvTranspose2d(in_channels, in_channels//2, kernel_size=5)
        self.conv2 = nn.ConvTranspose2d(in_channels//2, out_channels, kernel_size=5)

        self.activation = nn.ReLU()

    def forward(self, inputs):
        x = self.conv1(inputs)
        x = self.activation(x)
        x = self.conv2(x)
        return x

In [5]:
class VectorQuantizer(nn.Module):
    def __init__(self, num_codewords, codewords_dim, commitment_cost):
        super(VectorQuantizer, self).__init__()
        
        self.num_codewords = num_codewords
        self.codewords_dim = codewords_dim
        self.codewords = nn.Parameter(torch.rand(self.num_codewords, self.codewords_dim),requires_grad=True)

        self.commitment_cost = commitment_cost
        
    def forward(self, inputs):
        # Reshape: B, C, H, W -> B*H*W, C
        inputs = inputs.permute(0, 2, 3, 1).contiguous()
        flat_inputs = inputs.view(-1, self.codewords_dim)

        # Calculating distances:
        distances = torch.pow(flat_inputs.unsqueeze(1) - self.codewords.unsqueeze(0),2).sum(2)

        # Argmin:
        encoding_indices = torch.argmin(distances, dim=1)

        # Index from dictionary:
        # quantized[i,j] = self.codewords[encoding_indices[i,j], j]
        quantized = torch.gather(self.codewords,0,encoding_indices.unsqueeze(1).expand(-1, self.codewords_dim)).view(inputs.shape)
        encoding_indices = encoding_indices.view(inputs.shape[:-1])

        # quantization loss
        quantizing_loss = F.mse_loss(quantized.detach(), inputs)
        commitment_loss = F.mse_loss(quantized, inputs.detach())
        loss = quantizing_loss + self.commitment_cost * commitment_loss

        # magic trick to copy gradients from inputs
        quantized = inputs + (quantized - inputs).detach()
        
        #Reshape:
        return quantized.permute(0, 3, 1, 2), encoding_indices, loss

In [25]:
#initialize input (limited to the first 1000 values)
input = F.interpolate(mnist_trainset.data.unsqueeze(1).float(), size=(32,32))

#construction target
target = torch.zeros(len(input),10)
value = mnist_trainset.targets[:len(input)]
for ind in range(len(target)):
    target[ind,value[ind]] = 1
    
#data loader
dataloader = torch.utils.data.DataLoader(input, batch_size = 32)

In [26]:
encoder = Encoder(1, 16)
quantizer = VectorQuantizer(10, 16, 0.15)
decoder = Decoder(16, 1)

criterion = nn.MSELoss()

learning_rate = 0.01

In [None]:
epoch = 0
for x in dataloader:
    
    z_e = encoder(x)
    z_q, codes, q_loss = quantizer(z_e)
    x_hat = decoder(z_q)
    
    #calcul loss
    rec_loss = criterion(x_hat, x)
    loss = rec_loss + q_loss
    print('Epoch {}: loss = {}'.format(epoch,loss.item()))
    
    #backprop
    decoder.zero_grad()
    loss.backward()
    
    for f in encoder.parameters():
        f.data.sub_(f.grad.data * learning_rate)
        
    for f in decoder.parameters():
        f.data.sub_(f.grad.data * learning_rate)
        
    for f in quantizer.parameters():
        f.data.sub_(f.grad.data * learning_rate)
    
    epoch+=1

Epoch 0: loss = 553371270381568.0
Epoch 1: loss = 531466400301056.0
Epoch 2: loss = 510415322742784.0
Epoch 3: loss = inf
Epoch 4: loss = 1.5259345067387967e+32
Epoch 5: loss = 1.4655176168339983e+32
Epoch 6: loss = 1.4074827943641685e+32
Epoch 7: loss = 1.351732938407476e+32
Epoch 8: loss = 1.2982229802093654e+32
Epoch 9: loss = 1.2467859912926644e+32
Epoch 10: loss = 1.197432997060848e+32
Epoch 11: loss = 1.149999003318055e+32
Epoch 12: loss = 1.1044683423856632e+32
Epoch 13: loss = 1.0607354025040712e+32
Epoch 14: loss = 1.0187357721056099e+32
Epoch 15: loss = 9.783769925435951e+31
Epoch 16: loss = 9.39645330420716e+31
Epoch 17: loss = 9.024196997268801e+31
Epoch 18: loss = 8.667004873183496e+31
Epoch 19: loss = 8.323811142948672e+31
Epoch 20: loss = 7.994121114118944e+31
Epoch 21: loss = 7.677599672457114e+31
Epoch 22: loss = 7.373561598808624e+31
Epoch 23: loss = 7.081513167868745e+31
Epoch 24: loss = 6.801129903947491e+31
Epoch 25: loss = 6.5318353912140705e+31
Epoch 26: loss = 6

In [44]:
input = F.interpolate(mnist_testset.data.unsqueeze(1).float(), size=(32,32))
output = net(input)

In [45]:
comp = torch.argmax(output,axis=1) == mnist_testset.targets

In [46]:
torch.sum(comp).item()/len(comp)*100

91.8

In [13]:
print(output.shape)

torch.Size([5000, 2, 24, 24])


In [36]:
torch.zeros(2,3,4).permute(0,2,1).view(-1,16)

RuntimeError: shape '[-1, 16]' is invalid for input of size 24

In [17]:
32*32

1024