In [1]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from torchvision.utils import save_image


In [2]:
# Device configuration: use a CUDA GPU when PyTorch can see one, otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [3]:
# Create a directory if not exists
sample_dir = 'samples'
if not os.path.exists(sample_dir):
    os.makedirs(sample_dir)

In [4]:
# Hyper-parameters
image_size = 28 * 28    # = 784: each 28x28 MNIST image is flattened into one vector
h_dim = 400             # hidden-layer width of encoder and decoder
z_dim = 20              # latent-space dimensionality
num_epochs = 15
batch_size = 128
learning_rate = 1e-3

# Seed PyTorch's RNGs so weight init, data shuffling, the reparameterization
# noise and the sampled images are reproducible on Restart & Run All.
seed = 0
torch.manual_seed(seed)

In [5]:
# MNIST training split, downloaded into ../../data on first use.
# ToTensor converts each image to a tensor; the training loop later flattens it.
dataset = torchvision.datasets.MNIST(
    root='../../data',
    download=True,
    train=True,
    transform=transforms.ToTensor(),
)

In [6]:
# Data loader: yields shuffled mini-batches of `batch_size` (image, label) pairs.
data_loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
)

In [7]:
# VAE model
class VAE(nn.Module):
    """Fully-connected variational autoencoder.

    The encoder maps a flattened input through one ReLU hidden layer to two
    ``z_dim``-wide heads, ``mu`` and ``log_var``. These are the mean and
    log-variance used by ``reparameterize``. The decoder maps a latent vector
    through one ReLU hidden layer back to ``image_size`` sigmoid outputs in
    [0, 1].

    Args:
        image_size: number of input/output features per example
            (784 = 28 * 28 for flattened MNIST).
        h_dim: width of the encoder and decoder hidden layers.
        z_dim: dimensionality of the latent space.
    """

    def __init__(self, image_size=784, h_dim=400, z_dim=20):
        super().__init__()
        self.fc1 = nn.Linear(image_size, h_dim)  # encoder hidden layer
        self.fc2 = nn.Linear(h_dim, z_dim)       # encoder head -> mu
        self.fc3 = nn.Linear(h_dim, z_dim)       # encoder head -> log_var
        self.fc4 = nn.Linear(z_dim, h_dim)       # decoder hidden layer
        self.fc5 = nn.Linear(h_dim, image_size)  # decoder output layer

    def encode(self, x):
        """Encode ``x`` (last dim ``image_size``) into ``(mu, log_var)``, each with last dim ``z_dim``."""
        h = F.relu(self.fc1(x))
        return self.fc2(h), self.fc3(h)

    def reparameterize(self, mu, log_var):
        """Sample ``z = mu + eps * std`` with ``eps ~ N(0, I)`` (reparameterization trick).

        The randomness lives in the parameter-free ``eps``, so ``z`` stays
        differentiable with respect to ``mu`` and ``log_var``.
        """
        std = torch.exp(log_var / 2)  # std = exp(0.5 * log(variance))
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Decode latents (last dim ``z_dim``) into reconstructions (last dim ``image_size``) in [0, 1]."""
        h = F.relu(self.fc4(z))
        # torch.sigmoid replaces the deprecated F.sigmoid; the result is identical.
        return torch.sigmoid(self.fc5(h))

    def forward(self, x):
        """Encode ``x``, sample a latent, and decode it.

        Returns:
            ``(x_reconst, mu, log_var)``: the reconstruction plus the encoder
            outputs that the training loop uses for the KL-divergence term.
        """
        mu, log_var = self.encode(x)
        z = self.reparameterize(mu, log_var)
        x_reconst = self.decode(z)
        return x_reconst, mu, log_var

In [8]:
# Build the model from the hyper-parameter cell rather than VAE's own defaults,
# so that editing image_size / h_dim / z_dim above actually changes the network.
model = VAE(image_size=image_size, h_dim=h_dim, z_dim=z_dim).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

In [9]:
# Start training
for epoch in range(num_epochs):
    for i, (x, _) in enumerate(data_loader):
        # Forward pass: flatten each image into a vector of length image_size
        x = x.to(device).view(-1, image_size)
        x_reconst, mu, log_var = model(x)

        # Compute reconstruction loss and KL divergence, both summed over the batch.
        # reduction='sum' replaces the deprecated size_average=False (same value).
        # For KL divergence, see Appendix B in VAE paper or http://yunjey47.tistory.com/43
        reconst_loss = F.binary_cross_entropy(x_reconst, x, reduction='sum')
        kl_div = - 0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())

        # Backprop and optimize on the sum of both terms
        loss = reconst_loss + kl_div
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (i+1) % 10 == 0:
            print("Epoch[{}/{}], Step [{}/{}], Reconst Loss: {:.4f}, KL Div: {:.4f}"
                  .format(epoch+1, num_epochs, i+1, len(data_loader), reconst_loss.item(), kl_div.item()))

    # End of epoch: write image grids without tracking gradients.
    with torch.no_grad():
        # Save the sampled images: decode batch_size latents drawn from N(0, I)
        z = torch.randn(batch_size, z_dim).to(device)
        out = model.decode(z).view(-1, 1, 28, 28)
        save_image(out, os.path.join(sample_dir, 'sampled-{}.png'.format(epoch+1)))

        # Save the reconstructed images. `x` is the last batch from the inner loop
        # above. Originals and reconstructions are placed side by side along dim=3.
        out, _, _ = model(x)
        x_concat = torch.cat([x.view(-1, 1, 28, 28), out.view(-1, 1, 28, 28)], dim=3)
        save_image(x_concat, os.path.join(sample_dir, 'reconst-{}.png'.format(epoch+1)))



Epoch[1/15], Step [10/469], Reconst Loss: 34991.5391, KL Div: 2843.4480
Epoch[1/15], Step [20/469], Reconst Loss: 29807.9141, KL Div: 973.3571
Epoch[1/15], Step [30/469], Reconst Loss: 26545.2480, KL Div: 1200.6675
Epoch[1/15], Step [40/469], Reconst Loss: 26738.6016, KL Div: 581.5673
Epoch[1/15], Step [50/469], Reconst Loss: 25813.3027, KL Div: 751.3923
Epoch[1/15], Step [60/469], Reconst Loss: 24863.9102, KL Div: 687.7689
Epoch[1/15], Step [70/469], Reconst Loss: 25941.3809, KL Div: 817.6965
Epoch[1/15], Step [80/469], Reconst Loss: 25080.5742, KL Div: 798.7938
Epoch[1/15], Step [90/469], Reconst Loss: 23489.7676, KL Div: 1120.5010
Epoch[1/15], Step [100/469], Reconst Loss: 23756.9219, KL Div: 1261.2566
Epoch[1/15], Step [110/469], Reconst Loss: 22593.3477, KL Div: 1412.3912
Epoch[1/15], Step [120/469], Reconst Loss: 20466.6328, KL Div: 1511.1536
Epoch[1/15], Step [130/469], Reconst Loss: 19810.3203, KL Div: 1774.1831
Epoch[1/15], Step [140/469], Reconst Loss: 19415.9199, KL Div: 170

Epoch[3/15], Step [220/469], Reconst Loss: 11700.6621, KL Div: 3036.6392
Epoch[3/15], Step [230/469], Reconst Loss: 11607.0498, KL Div: 2963.3672
Epoch[3/15], Step [240/469], Reconst Loss: 11409.3086, KL Div: 3146.4500
Epoch[3/15], Step [250/469], Reconst Loss: 11474.8779, KL Div: 3033.5789
Epoch[3/15], Step [260/469], Reconst Loss: 11187.1836, KL Div: 3004.5249
Epoch[3/15], Step [270/469], Reconst Loss: 11989.0498, KL Div: 3050.4292
Epoch[3/15], Step [280/469], Reconst Loss: 11747.5000, KL Div: 3030.7771
Epoch[3/15], Step [290/469], Reconst Loss: 11646.5078, KL Div: 3204.7668
Epoch[3/15], Step [300/469], Reconst Loss: 11617.2578, KL Div: 3086.5381
Epoch[3/15], Step [310/469], Reconst Loss: 11256.4629, KL Div: 3072.4797
Epoch[3/15], Step [320/469], Reconst Loss: 11495.5156, KL Div: 3047.2241
Epoch[3/15], Step [330/469], Reconst Loss: 11158.7686, KL Div: 3060.0325
Epoch[3/15], Step [340/469], Reconst Loss: 11672.8027, KL Div: 3077.6992
Epoch[3/15], Step [350/469], Reconst Loss: 11379.70

Epoch[5/15], Step [450/469], Reconst Loss: 10484.1807, KL Div: 3114.2612
Epoch[5/15], Step [460/469], Reconst Loss: 10339.2959, KL Div: 3127.2749
Epoch[6/15], Step [10/469], Reconst Loss: 11047.0205, KL Div: 3228.0884
Epoch[6/15], Step [20/469], Reconst Loss: 10319.2314, KL Div: 3122.6074
Epoch[6/15], Step [30/469], Reconst Loss: 10570.1494, KL Div: 3113.0796
Epoch[6/15], Step [40/469], Reconst Loss: 11039.0645, KL Div: 3133.2402
Epoch[6/15], Step [50/469], Reconst Loss: 11135.5137, KL Div: 3271.9707
Epoch[6/15], Step [60/469], Reconst Loss: 10914.9111, KL Div: 3215.7097
Epoch[6/15], Step [70/469], Reconst Loss: 11176.8496, KL Div: 3185.8130
Epoch[6/15], Step [80/469], Reconst Loss: 11107.1494, KL Div: 3204.2788
Epoch[6/15], Step [90/469], Reconst Loss: 10778.5537, KL Div: 3228.1541
Epoch[6/15], Step [100/469], Reconst Loss: 10815.6396, KL Div: 3082.5100
Epoch[6/15], Step [110/469], Reconst Loss: 10977.6348, KL Div: 3303.9810
Epoch[6/15], Step [120/469], Reconst Loss: 10905.3643, KL Di

Epoch[8/15], Step [220/469], Reconst Loss: 10122.5820, KL Div: 3241.7998
Epoch[8/15], Step [230/469], Reconst Loss: 10303.4980, KL Div: 3159.5046
Epoch[8/15], Step [240/469], Reconst Loss: 10790.2549, KL Div: 3191.4248
Epoch[8/15], Step [250/469], Reconst Loss: 10921.6631, KL Div: 3274.9714
Epoch[8/15], Step [260/469], Reconst Loss: 10137.8145, KL Div: 3261.5862
Epoch[8/15], Step [270/469], Reconst Loss: 10782.9209, KL Div: 3296.4204
Epoch[8/15], Step [280/469], Reconst Loss: 10126.2852, KL Div: 3223.1914
Epoch[8/15], Step [290/469], Reconst Loss: 10455.3643, KL Div: 3170.4868
Epoch[8/15], Step [300/469], Reconst Loss: 10898.7881, KL Div: 3312.6621
Epoch[8/15], Step [310/469], Reconst Loss: 10866.1719, KL Div: 3226.1077
Epoch[8/15], Step [320/469], Reconst Loss: 10617.4502, KL Div: 3329.8540
Epoch[8/15], Step [330/469], Reconst Loss: 10496.7822, KL Div: 3263.2429
Epoch[8/15], Step [340/469], Reconst Loss: 10612.5107, KL Div: 3198.1318
Epoch[8/15], Step [350/469], Reconst Loss: 10066.70

Epoch[10/15], Step [430/469], Reconst Loss: 9928.5723, KL Div: 3269.9419
Epoch[10/15], Step [440/469], Reconst Loss: 10598.8691, KL Div: 3218.8525
Epoch[10/15], Step [450/469], Reconst Loss: 10895.0254, KL Div: 3381.2458
Epoch[10/15], Step [460/469], Reconst Loss: 10846.8242, KL Div: 3215.2686
Epoch[11/15], Step [10/469], Reconst Loss: 10348.5420, KL Div: 3223.1887
Epoch[11/15], Step [20/469], Reconst Loss: 10445.0859, KL Div: 3301.6973
Epoch[11/15], Step [30/469], Reconst Loss: 10353.0098, KL Div: 3119.5469
Epoch[11/15], Step [40/469], Reconst Loss: 10275.1348, KL Div: 3173.6807
Epoch[11/15], Step [50/469], Reconst Loss: 10250.1055, KL Div: 3250.5698
Epoch[11/15], Step [60/469], Reconst Loss: 10215.8291, KL Div: 3189.8086
Epoch[11/15], Step [70/469], Reconst Loss: 10269.6650, KL Div: 3242.7698
Epoch[11/15], Step [80/469], Reconst Loss: 10584.1895, KL Div: 3273.4136
Epoch[11/15], Step [90/469], Reconst Loss: 10503.8857, KL Div: 3299.7249
Epoch[11/15], Step [100/469], Reconst Loss: 1053

Epoch[13/15], Step [190/469], Reconst Loss: 10265.4980, KL Div: 3310.5820
Epoch[13/15], Step [200/469], Reconst Loss: 10384.3252, KL Div: 3334.6938
Epoch[13/15], Step [210/469], Reconst Loss: 10032.3867, KL Div: 3248.2815
Epoch[13/15], Step [220/469], Reconst Loss: 9954.8984, KL Div: 3203.5073
Epoch[13/15], Step [230/469], Reconst Loss: 10454.3906, KL Div: 3413.2659
Epoch[13/15], Step [240/469], Reconst Loss: 10353.9238, KL Div: 3304.2456
Epoch[13/15], Step [250/469], Reconst Loss: 10054.6709, KL Div: 3200.4141
Epoch[13/15], Step [260/469], Reconst Loss: 10474.6572, KL Div: 3265.6951
Epoch[13/15], Step [270/469], Reconst Loss: 10365.7607, KL Div: 3293.8564
Epoch[13/15], Step [280/469], Reconst Loss: 10447.0518, KL Div: 3205.6377
Epoch[13/15], Step [290/469], Reconst Loss: 10296.3379, KL Div: 3268.6201
Epoch[13/15], Step [300/469], Reconst Loss: 10201.3408, KL Div: 3310.9790
Epoch[13/15], Step [310/469], Reconst Loss: 9929.4502, KL Div: 3173.7883
Epoch[13/15], Step [320/469], Reconst Lo

Epoch[15/15], Step [390/469], Reconst Loss: 10398.4590, KL Div: 3281.9751
Epoch[15/15], Step [400/469], Reconst Loss: 10013.5781, KL Div: 3154.8486
Epoch[15/15], Step [410/469], Reconst Loss: 10103.8857, KL Div: 3260.9517
Epoch[15/15], Step [420/469], Reconst Loss: 10115.5000, KL Div: 3252.0024
Epoch[15/15], Step [430/469], Reconst Loss: 9815.9463, KL Div: 3210.5952
Epoch[15/15], Step [440/469], Reconst Loss: 9961.9756, KL Div: 3264.4841
Epoch[15/15], Step [450/469], Reconst Loss: 10143.5977, KL Div: 3259.7280
Epoch[15/15], Step [460/469], Reconst Loss: 10205.9180, KL Div: 3162.0981
