In [30]:
import torch
from torch import nn, optim
from torch.autograd import Variable
from torch.nn import functional as F
from torchvision import datasets, transforms
import numpy as np
from matplotlib import pyplot as plt
import math, os
from torchvision.utils import save_image
from torchvision.transforms import ToPILImage

In [2]:
z_dim = 3
epochs = 40

In [3]:
if torch.cuda.is_available():
    dev = "cuda:0"
else:
    dev = "cpu"
device = torch.device(dev)

In [4]:
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('./data', train=True, download=True,
                   transform=transforms.ToTensor()),
    batch_size=200, shuffle=True, **kwargs)
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('./data', train=False,download=True, transform=transforms.ToTensor()),
    batch_size=200, shuffle=True, **kwargs)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz



HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw
Processing...
Done!




In [45]:
class encoder(nn.Module):
  def __init__(self, input_dim, output_dim):
    super(encoder, self).__init__()

    self.conv1 = nn.Conv2d(1, 64, (4,4), 2, 1)
    self.conv2 = nn.Conv2d(64, 128, (4,4), 2, 1)
    self.linear1 = nn.Linear(input_dim, 1024)
    self.linear2 = nn.Linear(self.linear1.out_features, output_dim)

  def forward(self, x):
    
    x = F.leaky_relu(self.conv1(x), negative_slope=0.1)
    x = F.leaky_relu(self.conv2(x), negative_slope=0.1)
    num_examples = x.shape[0]
    x = x.view(num_examples, -1)
    x_linear = F.leaky_relu(self.linear1(x), negative_slope=0.1)
    x_linear = self.linear2(x_linear)

    return x_linear

In [46]:
enc = encoder(128*7*7, z_dim).to(device)

In [47]:
from torchsummary import summary
summary(enc, (1,28,28))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 64, 14, 14]           1,088
            Conv2d-2            [-1, 128, 7, 7]         131,200
            Linear-3                 [-1, 1024]       6,423,552
            Linear-4                    [-1, 3]           3,075
Total params: 6,558,915
Trainable params: 6,558,915
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.15
Params size (MB): 25.02
Estimated Total Size (MB): 25.17
----------------------------------------------------------------


In [48]:
class decoder(nn.Module):
  def __init__(self):
    super(decoder, self).__init__()

    self.linear1 = nn.Linear(z_dim, 1024)
    self.linear2 = nn.Linear(1024, 128*7*7)
    self.transpose_conv1 = nn.ConvTranspose2d(128, 64, (4,4), 2 ,1)
    self.transpose_conv2 = nn.ConvTranspose2d(64, 1, (4,4), 2, 1)

  def forward(self, z):
    
    z = F.relu(self.linear1(z))
    z = F.relu(self.linear2(z))
    num_examples = z.shape[0]
    z = z.view(num_examples, -1, 7, 7)
    x = F.relu(self.transpose_conv1(z))
    x = F.sigmoid(self.transpose_conv2(x))

    return x 

In [49]:
dec = decoder().to(device)

In [50]:
from torchsummary import summary
summary(dec, (1,3))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Linear-1              [-1, 1, 1024]           4,096
            Linear-2              [-1, 1, 6272]       6,428,800
   ConvTranspose2d-3           [-1, 64, 14, 14]         131,136
   ConvTranspose2d-4            [-1, 1, 28, 28]           1,025
Total params: 6,565,057
Trainable params: 6,565,057
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.16
Params size (MB): 25.04
Estimated Total Size (MB): 25.20
----------------------------------------------------------------




In [51]:
def compute_kernel(x, y):
    x_size = x.shape[0]
    y_size = y.shape[0]
    dim = x.shape[1]

    tiled_x = x.view(x_size,1,dim).repeat(1, y_size,1)
    tiled_y = y.view(1,y_size,dim).repeat(x_size, 1,1)

    return torch.exp(-torch.mean((tiled_x - tiled_y)**2,dim=2)/dim*1.0)


def compute_mmd(x, y):
    x_kernel = compute_kernel(x, x)
    y_kernel = compute_kernel(y, y)
    xy_kernel = compute_kernel(x, y)
    return torch.mean(x_kernel) + torch.mean(y_kernel) - 2*torch.mean(xy_kernel)

In [52]:
optim_encoder = optim.Adam(enc.parameters(), lr=1e-3)
optim_decoder = optim.Adam(dec.parameters(), lr=1e-3)

In [53]:
for epoch in range(epochs):
  for bidx, (train_data, _) in enumerate(train_loader):

    x_pd = Variable(train_data.to(device))

    z_pz = Variable(torch.randn(len(x_pd), z_dim))

    x_pd = x_pd.to(device)
    z_pz = z_pz.to(device)

    optim_encoder.zero_grad()
    optim_decoder.zero_grad()

    z_qzx = enc(x_pd)
    x_pxz = dec(z_qzx)

    mmd_loss = compute_mmd(z_qzx, z_pz)
    rest_loss = torch.mean((x_pxz - x_pd)**2)

    total_loss = mmd_loss + rest_loss

    total_loss.backward()
    
    optim_encoder.step()
    optim_decoder.step()

  test_z = Variable(torch.randn(100, z_dim).to(device))
  test_x_rec = dec(test_z)
  test_x_rec =test_x_rec.view(100,28,28)
  save_image(test_x_rec.view(100, 1, 28, 28), '/samples/sample_' + str(epoch) + '.png')
 
  print('[%d/%d] Total Loss: %0.3f' % (epoch, epochs,  total_loss))



[0/40] Total Loss: 0.161
[1/40] Total Loss: 0.050
[2/40] Total Loss: 0.049
[3/40] Total Loss: 0.044
[4/40] Total Loss: 0.040
[5/40] Total Loss: 0.040
[6/40] Total Loss: 0.036
[7/40] Total Loss: 0.039
[8/40] Total Loss: 0.039
[9/40] Total Loss: 0.036
[10/40] Total Loss: 0.038
[11/40] Total Loss: 0.037
[12/40] Total Loss: 0.045
[13/40] Total Loss: 0.042
[14/40] Total Loss: 0.038
[15/40] Total Loss: 0.039
[16/40] Total Loss: 0.037
[17/40] Total Loss: 0.036
[18/40] Total Loss: 0.035
[19/40] Total Loss: 0.036
[20/40] Total Loss: 0.035
[21/40] Total Loss: 0.033
[22/40] Total Loss: 0.036
[23/40] Total Loss: 0.034
[24/40] Total Loss: 0.034
[25/40] Total Loss: 0.034
[26/40] Total Loss: 0.035
[27/40] Total Loss: 0.033
[28/40] Total Loss: 0.032
[29/40] Total Loss: 0.035
[30/40] Total Loss: 0.036
[31/40] Total Loss: 0.031
[32/40] Total Loss: 0.033
[33/40] Total Loss: 0.034
[34/40] Total Loss: 0.042
[35/40] Total Loss: 0.031
[36/40] Total Loss: 0.031
[37/40] Total Loss: 0.033
[38/40] Total Loss: 0.

In [60]:
# GIF of reconstructed images vs epoch 

import os 
from PIL import Image 

images_dir = '/samples'
image_list = []
for i in range(len(os.listdir(images_dir))):
  temp_img = Image.open(images_dir + '/sample_' + str(i) + '.png' )
  image_list.append(temp_img)

image_list[0].save('/samples/reconstructed_images.gif', save_all=True, append_images=image_list[1:])

In [56]:
i

40