In [1]:
import os
import sys
sys.path.append('./')
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from torchvision.utils import save_image


# Device configuration: run on the GPU when CUDA is available, else on the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Directory that receives the per-epoch sample / reconstruction images.
# exist_ok=True makes creation idempotent and avoids the check-then-create
# race of a separate os.path.exists() test.
sample_dir = 'samples'
os.makedirs(sample_dir, exist_ok=True)

# Hyper-parameters
image_size = 28 * 28    # length of one flattened image (training views x as (-1, image_size))
h_dim = 400             # hidden-layer width
z_dim = 20              # latent dimensionality, also used when sampling z for decoding
num_epochs = 15         # number of passes over the data loader
batch_size = 128        # samples per mini-batch handed to the DataLoader
learning_rate = 0.001   # step size given to the Adam optimizer

# MNIST dataset: training split stored under ../../data (download=True),
# with transforms.ToTensor() applied to every image
dataset = torchvision.datasets.MNIST(
    root='../../data',
    train=True,
    transform=transforms.ToTensor(),
    download=True,
)
print("data : {} images".format(len(dataset)))


# Data loader: shuffled mini-batches of `batch_size` samples
data_loader = torch.utils.data.DataLoader(
    dataset=dataset,
    batch_size=batch_size,
    shuffle=True,
)


# VAE model
class VAE(nn.Module):
    """Fully-connected variational autoencoder with one hidden layer per side.

    The encoder maps a flattened input to the mean and log-variance of a
    diagonal Gaussian over the latent code; the decoder maps a latent code
    back to a vector of per-element values in (0, 1).

    Args:
        image_size: Number of features in one flattened input vector.
        h_dim: Width of the hidden layer used by both encoder and decoder.
        z_dim: Dimensionality of the latent code.
    """

    def __init__(self, image_size=784, h_dim=400, z_dim=20):
        super(VAE, self).__init__()
        self.fc1 = nn.Linear(image_size, h_dim)  # encoder hidden layer
        self.fc2 = nn.Linear(h_dim, z_dim)       # encoder head: mu
        self.fc3 = nn.Linear(h_dim, z_dim)       # encoder head: log_var
        self.fc4 = nn.Linear(z_dim, h_dim)       # decoder hidden layer
        self.fc5 = nn.Linear(h_dim, image_size)  # decoder output layer

    def encode(self, x):
        """Return ``(mu, log_var)`` for inputs ``x`` of shape ``(N, image_size)``."""
        h = F.relu(self.fc1(x))
        return self.fc2(h), self.fc3(h)

    def reparameterize(self, mu, log_var):
        """Draw ``z = mu + eps * std`` with ``eps ~ N(0, I)``.

        The noise is sampled separately from the parameters, so gradients can
        flow through ``mu`` and ``log_var``.
        """
        std = torch.exp(log_var / 2)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Map latent codes ``z`` of shape ``(N, z_dim)`` to values in (0, 1)."""
        h = F.relu(self.fc4(z))
        # torch.sigmoid is the supported spelling; F.sigmoid was deprecated in
        # its favour and emits a warning on a range of PyTorch releases.
        return torch.sigmoid(self.fc5(h))

    def forward(self, x):
        """Encode ``x``, sample a latent code and decode it.

        Returns:
            Tuple ``(x_reconst, mu, log_var)``.
        """
        mu, log_var = self.encode(x)
        z = self.reparameterize(mu, log_var)
        x_reconst = self.decode(z)
        return x_reconst, mu, log_var

# Build the model from the module-level hyper-parameters so the network, the
# flattening in the training loop and the latent sampling below stay in sync
# when any of image_size / h_dim / z_dim is edited.
model = VAE(image_size=image_size, h_dim=h_dim, z_dim=z_dim).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Start training
for epoch in range(num_epochs):
    for i, (x, _) in enumerate(data_loader):
        # Forward pass: flatten every image to a vector of image_size values
        x = x.to(device).view(-1, image_size)
        x_reconst, mu, log_var = model(x)

        # Compute reconstruction loss and kl divergence
        # For KL divergence, see Appendix B in VAE paper or http://yunjey47.tistory.com/43
        # reduction='sum' replaces the deprecated size_average=False: both
        # terms are summed over the whole mini-batch, not averaged.
        reconst_loss = F.binary_cross_entropy(x_reconst, x, reduction='sum')
        kl_div = - 0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())

        # Backprop and optimize
        loss = reconst_loss + kl_div
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (i+1) % 10 == 0:
            print ("Epoch[{}/{}], Step [{}/{}], Reconst Loss: {:.4f}, KL Div: {:.4f}" 
                   .format(epoch+1, num_epochs, i+1, len(data_loader), reconst_loss.item(), kl_div.item()))

    # End-of-epoch image dumps; no gradients are needed for these
    with torch.no_grad():
        # Save the sampled images: decode 10 latent codes drawn from N(0, I)
        z = torch.randn(10, z_dim).to(device)
        out = model.decode(z).view(-1, 1, 28, 28)
        save_image(out, os.path.join(sample_dir, 'sampled-{}.png'.format(epoch+1)))

        # Save the reconstructed images: x is still the last mini-batch of the
        # epoch; each original is concatenated with its reconstruction along
        # the width axis (dim=3)
        out, _, _ = model(x)
        x_concat = torch.cat([x.view(-1, 1, 28, 28), out.view(-1, 1, 28, 28)], dim=3)
        save_image(x_concat, os.path.join(sample_dir, 'reconst-{}.png'.format(epoch+1)))

data : 60000 images




Epoch[1/15], Step [10/469], Reconst Loss: 35318.1641, KL Div: 3802.9487
Epoch[1/15], Step [20/469], Reconst Loss: 29716.2695, KL Div: 1105.4612
Epoch[1/15], Step [30/469], Reconst Loss: 26501.5625, KL Div: 1300.5052
Epoch[1/15], Step [40/469], Reconst Loss: 27156.9336, KL Div: 674.8234
Epoch[1/15], Step [50/469], Reconst Loss: 26209.9609, KL Div: 813.1757
Epoch[1/15], Step [60/469], Reconst Loss: 25954.8965, KL Div: 866.1167
Epoch[1/15], Step [70/469], Reconst Loss: 25533.9551, KL Div: 851.5728
Epoch[1/15], Step [80/469], Reconst Loss: 24946.6641, KL Div: 927.5524
Epoch[1/15], Step [90/469], Reconst Loss: 22866.5234, KL Div: 1067.9318
Epoch[1/15], Step [100/469], Reconst Loss: 21587.6445, KL Div: 1395.4412
Epoch[1/15], Step [110/469], Reconst Loss: 22389.6934, KL Div: 1339.3346
Epoch[1/15], Step [120/469], Reconst Loss: 21160.7891, KL Div: 1617.3932
Epoch[1/15], Step [130/469], Reconst Loss: 19963.0586, KL Div: 1761.6270
Epoch[1/15], Step [140/469], Reconst Loss: 19130.1074, KL Div: 17

Epoch[3/15], Step [230/469], Reconst Loss: 11413.7021, KL Div: 3069.2979
Epoch[3/15], Step [240/469], Reconst Loss: 11618.3340, KL Div: 3177.3115
Epoch[3/15], Step [250/469], Reconst Loss: 11701.7666, KL Div: 3117.1934
Epoch[3/15], Step [260/469], Reconst Loss: 11297.5850, KL Div: 3178.6416
Epoch[3/15], Step [270/469], Reconst Loss: 11520.6934, KL Div: 3109.1050
Epoch[3/15], Step [280/469], Reconst Loss: 11804.8027, KL Div: 3077.6218
Epoch[3/15], Step [290/469], Reconst Loss: 12029.3457, KL Div: 3201.5845
Epoch[3/15], Step [300/469], Reconst Loss: 11657.3389, KL Div: 3086.0437
Epoch[3/15], Step [310/469], Reconst Loss: 11570.6328, KL Div: 3121.7651
Epoch[3/15], Step [320/469], Reconst Loss: 11508.0430, KL Div: 3056.6558
Epoch[3/15], Step [330/469], Reconst Loss: 11467.7314, KL Div: 3084.4023
Epoch[3/15], Step [340/469], Reconst Loss: 11223.3740, KL Div: 3099.8782
Epoch[3/15], Step [350/469], Reconst Loss: 11285.4277, KL Div: 3089.9790
Epoch[3/15], Step [360/469], Reconst Loss: 11818.23

Epoch[5/15], Step [450/469], Reconst Loss: 10269.6680, KL Div: 3262.7031
Epoch[5/15], Step [460/469], Reconst Loss: 10856.7510, KL Div: 3211.4934
Epoch[6/15], Step [10/469], Reconst Loss: 10605.3652, KL Div: 3067.2959
Epoch[6/15], Step [20/469], Reconst Loss: 10681.3408, KL Div: 3253.9158
Epoch[6/15], Step [30/469], Reconst Loss: 11003.3672, KL Div: 3251.9829
Epoch[6/15], Step [40/469], Reconst Loss: 10562.5146, KL Div: 3280.1973
Epoch[6/15], Step [50/469], Reconst Loss: 10597.6934, KL Div: 3212.3723
Epoch[6/15], Step [60/469], Reconst Loss: 10827.9902, KL Div: 3145.1206
Epoch[6/15], Step [70/469], Reconst Loss: 10770.3965, KL Div: 3219.1982
Epoch[6/15], Step [80/469], Reconst Loss: 10681.6035, KL Div: 3099.6875
Epoch[6/15], Step [90/469], Reconst Loss: 10953.9717, KL Div: 3238.1262
Epoch[6/15], Step [100/469], Reconst Loss: 10937.5527, KL Div: 3342.9473
Epoch[6/15], Step [110/469], Reconst Loss: 10658.8027, KL Div: 3300.8662
Epoch[6/15], Step [120/469], Reconst Loss: 11305.5547, KL Di

Epoch[8/15], Step [210/469], Reconst Loss: 10223.4180, KL Div: 3145.8887
Epoch[8/15], Step [220/469], Reconst Loss: 10435.1963, KL Div: 3190.4531
Epoch[8/15], Step [230/469], Reconst Loss: 10607.6113, KL Div: 3269.5840
Epoch[8/15], Step [240/469], Reconst Loss: 10774.9365, KL Div: 3278.8530
Epoch[8/15], Step [250/469], Reconst Loss: 10547.1064, KL Div: 3196.9976
Epoch[8/15], Step [260/469], Reconst Loss: 10654.3193, KL Div: 3302.2734
Epoch[8/15], Step [270/469], Reconst Loss: 10829.7686, KL Div: 3210.8096
Epoch[8/15], Step [280/469], Reconst Loss: 10657.5156, KL Div: 3170.4146
Epoch[8/15], Step [290/469], Reconst Loss: 10590.9277, KL Div: 3230.7173
Epoch[8/15], Step [300/469], Reconst Loss: 10153.7168, KL Div: 3153.4141
Epoch[8/15], Step [310/469], Reconst Loss: 10239.4766, KL Div: 3212.6157
Epoch[8/15], Step [320/469], Reconst Loss: 10764.2900, KL Div: 3394.9800
Epoch[8/15], Step [330/469], Reconst Loss: 10616.6768, KL Div: 3152.8242
Epoch[8/15], Step [340/469], Reconst Loss: 10405.51

Epoch[10/15], Step [430/469], Reconst Loss: 10603.2900, KL Div: 3212.3818
Epoch[10/15], Step [440/469], Reconst Loss: 9917.1660, KL Div: 3141.9326
Epoch[10/15], Step [450/469], Reconst Loss: 10478.7051, KL Div: 3236.8784
Epoch[10/15], Step [460/469], Reconst Loss: 9953.6016, KL Div: 3262.5095
Epoch[11/15], Step [10/469], Reconst Loss: 10306.1475, KL Div: 3179.4863
Epoch[11/15], Step [20/469], Reconst Loss: 10404.5166, KL Div: 3157.5874
Epoch[11/15], Step [30/469], Reconst Loss: 9834.0117, KL Div: 3228.7622
Epoch[11/15], Step [40/469], Reconst Loss: 10007.8799, KL Div: 3089.1714
Epoch[11/15], Step [50/469], Reconst Loss: 10320.7656, KL Div: 3211.1165
Epoch[11/15], Step [60/469], Reconst Loss: 10193.2383, KL Div: 3179.3208
Epoch[11/15], Step [70/469], Reconst Loss: 10420.4912, KL Div: 3271.5549
Epoch[11/15], Step [80/469], Reconst Loss: 10495.7949, KL Div: 3336.1421
Epoch[11/15], Step [90/469], Reconst Loss: 10337.5156, KL Div: 3209.5127
Epoch[11/15], Step [100/469], Reconst Loss: 10098.

Epoch[13/15], Step [170/469], Reconst Loss: 10203.2715, KL Div: 3261.6396
Epoch[13/15], Step [180/469], Reconst Loss: 10144.8242, KL Div: 3158.9033
Epoch[13/15], Step [190/469], Reconst Loss: 10921.6016, KL Div: 3313.9661
Epoch[13/15], Step [200/469], Reconst Loss: 10501.3506, KL Div: 3225.2344
Epoch[13/15], Step [210/469], Reconst Loss: 10507.0234, KL Div: 3349.1211
Epoch[13/15], Step [220/469], Reconst Loss: 10586.5664, KL Div: 3330.4104
Epoch[13/15], Step [230/469], Reconst Loss: 10435.7363, KL Div: 3299.5688
Epoch[13/15], Step [240/469], Reconst Loss: 10150.3145, KL Div: 3199.0957
Epoch[13/15], Step [250/469], Reconst Loss: 10203.2383, KL Div: 3119.7700
Epoch[13/15], Step [260/469], Reconst Loss: 10131.3535, KL Div: 3319.7878
Epoch[13/15], Step [270/469], Reconst Loss: 10270.8330, KL Div: 3176.0291
Epoch[13/15], Step [280/469], Reconst Loss: 10661.6328, KL Div: 3183.1987
Epoch[13/15], Step [290/469], Reconst Loss: 10101.1094, KL Div: 3351.9463
Epoch[13/15], Step [300/469], Reconst 

Epoch[15/15], Step [370/469], Reconst Loss: 9842.3027, KL Div: 3237.9448
Epoch[15/15], Step [380/469], Reconst Loss: 10065.6494, KL Div: 3364.7686
Epoch[15/15], Step [390/469], Reconst Loss: 10405.9160, KL Div: 3285.6794
Epoch[15/15], Step [400/469], Reconst Loss: 9962.2441, KL Div: 3145.6050
Epoch[15/15], Step [410/469], Reconst Loss: 9827.7695, KL Div: 3243.9824
Epoch[15/15], Step [420/469], Reconst Loss: 10238.7061, KL Div: 3214.7231
Epoch[15/15], Step [430/469], Reconst Loss: 9938.4424, KL Div: 3316.8372
Epoch[15/15], Step [440/469], Reconst Loss: 9795.4902, KL Div: 3232.6538
Epoch[15/15], Step [450/469], Reconst Loss: 10114.7959, KL Div: 3274.0791
Epoch[15/15], Step [460/469], Reconst Loss: 9888.8770, KL Div: 3240.9622
