In [1]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms as tf
from torchvision.utils import save_image

In [6]:
# Define hyperparameters
image_size = 784   # one MNIST image flattened: 1 x 28 x 28
hidden_dim = 400   # width of the hidden layer in the encoder and the decoder
latent_dim = 20    # dimensionality of the latent code z
batch_size = 128
epochs = 10

# Run on the GPU when one is available, otherwise fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# MNIST Dataset
train_dataset = torchvision.datasets.MNIST(root='../data',
                                           train=True,
                                           transform=tf.ToTensor(),
                                           download=True)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=batch_size,
                                           shuffle=True)

test_dataset = torchvision.datasets.MNIST(root='../data',
                                          train=False,
                                          transform=tf.ToTensor())
# Evaluation does not need shuffling. A fixed order also means the first test
# batch (the one saved as a reconstruction image) shows the same digits every
# epoch, so the saved images are comparable across epochs.
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                          batch_size=batch_size,
                                          shuffle=False)

# Create the directory for reconstructed and sampled images.
# exist_ok=True makes this a no-op when the directory is already present and
# avoids the check-then-create race of a separate os.path.exists() test.
sample_dir = 'results'
os.makedirs(sample_dir, exist_ok=True)

![vae](https://user-images.githubusercontent.com/30661597/78418103-a2047200-766b-11ea-8205-c7e5712715f4.png)

In [7]:
# VAE model
class VAE(nn.Module):
    """Fully connected variational autoencoder.

    The encoder maps a flattened input to the mean and log-variance of a
    diagonal Gaussian over the latent code. The decoder maps a latent code
    back to the input space and squashes every output with a sigmoid.

    Args:
        input_dim: Number of features of one flattened input. When ``None``
            the notebook-level ``image_size`` is used.
        hidden_size: Width of the hidden layer used in both the encoder and
            the decoder. When ``None`` the notebook-level ``hidden_dim`` is
            used.
        latent_size: Dimensionality of the latent code. When ``None`` the
            notebook-level ``latent_dim`` is used.
    """

    def __init__(self, input_dim=None, hidden_size=None, latent_size=None):
        super().__init__()

        # Defaults are resolved at construction time so that a plain ``VAE()``
        # keeps picking up the notebook hyperparameters.
        self.input_dim = image_size if input_dim is None else input_dim
        hidden_size = hidden_dim if hidden_size is None else hidden_size
        self.latent_size = latent_dim if latent_size is None else latent_size

        # Encoder: input -> hidden -> (mean, log-variance)
        self.fc1 = nn.Linear(self.input_dim, hidden_size)
        self.fc2_mean = nn.Linear(hidden_size, self.latent_size)
        self.fc2_logvar = nn.Linear(hidden_size, self.latent_size)
        # Decoder: latent -> hidden -> input
        self.fc3 = nn.Linear(self.latent_size, hidden_size)
        self.fc4 = nn.Linear(hidden_size, self.input_dim)

    def encode(self, x):
        """Return ``(mu, log_var)`` of the approximate posterior for flattened ``x``."""
        h = F.relu(self.fc1(x))
        mu = self.fc2_mean(h)
        log_var = self.fc2_logvar(h)
        return mu, log_var

    def reparameterize(self, mu, logvar):
        """Draw ``z = mu + eps * std`` with ``eps ~ N(0, I)``.

        Sampling ``eps`` rather than ``z`` directly keeps the result
        differentiable with respect to ``mu`` and ``logvar``.
        """
        std = torch.exp(logvar/2)  # exp(log(sigma^2) / 2) == sigma
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Map latent codes ``z`` to reconstructions with values in [0, 1]."""
        h = F.relu(self.fc3(z))
        out = torch.sigmoid(self.fc4(h))
        return out

    def forward(self, x):
        """Encode ``x``, sample a latent code and decode it.

        Returns:
            ``(reconstructed, mu, logvar)`` where ``reconstructed`` is the
            flattened reconstruction and ``mu`` / ``logvar`` are the encoder
            outputs.
        """
        # x: (batch_size, 1, 28,28) --> (batch_size, 784) for the default sizes
        mu, logvar = self.encode(x.view(-1, self.input_dim))
        z = self.reparameterize(mu, logvar)
        reconstructed = self.decode(z)
        return reconstructed, mu, logvar

# Build the VAE, move it to the selected device and attach an Adam optimizer
model = VAE()
model = model.to(device)
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)

### $Loss = -E[\log P(X | z)]+D_{K L}[N(\mu(X), \Sigma(X)) \| N(0,1)]$

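#### $D_{K L}[N(\mu(X), \Sigma(X)) \| N(0,1)]=\frac{1}{2} \sum_{k}\left(\sigma_{k}^{2}(X)+\mu_{k}^{2}(X)-1-\log \sigma_{k}^{2}(X)\right)$, where $\Sigma(X)=\operatorname{diag}\left(\sigma^{2}(X)\right)$ and the encoder outputs $\log \sigma^{2}(X)$ (`logvar` in the code), so $\sigma_{k}^{2}(X)=\exp(\texttt{logvar}_k)$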

In [10]:
# Define loss
def loss_function(reconstructed_image,original_image,mu,logvar):
    """Negative ELBO of a VAE, summed over the whole batch.

    Args:
        reconstructed_image: Decoder output with values in [0, 1].
        original_image: Target images with the same number of elements as
            ``reconstructed_image``; any shape is accepted and it is reshaped
            to match the reconstruction.
        mu: Mean of the approximate posterior, one row per sample.
        logvar: Log-variance of the approximate posterior, same shape as ``mu``.

    Returns:
        Scalar tensor: summed binary cross-entropy plus the closed-form KL
        divergence between ``N(mu, exp(logvar))`` and ``N(0, I)``.
    """
    # Reshape the target to the reconstruction's shape rather than hard-coding
    # 784, so the loss keeps working if the input size changes.
    target = original_image.reshape(reconstructed_image.shape)
    bce = F.binary_cross_entropy(reconstructed_image, target, reduction='sum')
    # Summing over every element at once is the same as summing the per-sample
    # KL over the latent axis and then summing those values over the batch.
    kld = 0.5 * torch.sum(logvar.exp() + mu.pow(2) - 1 - logvar)
    return bce + kld

# train function
def train(epoch):
    """Train ``model`` for one pass over ``train_loader`` and print progress.

    Args:
        epoch: Epoch number, used only in the printed messages.

    Prints the per-sample loss of every 100th batch and, at the end, the
    loss averaged over all training samples.
    """
    model.train()
    train_loss=0
    for i, (images,_) in enumerate(train_loader):
        images = images.to(device)
        reconstructed,mu,logvar=model(images)
        loss = loss_function(reconstructed,images,mu,logvar)
        optimizer.zero_grad()
        loss.backward()
        # loss_function sums over the batch, so this accumulates a per-dataset sum
        train_loss += loss.item()
        optimizer.step()

        if i % 100 ==0:
            # '{:.3f}' (three decimals) matches the other log lines; the previous
            # '{:3f}' only set a minimum field width of 3.
            print('Train Epoch {}  [Batch: {}/{}]\tloss: {:.3f}'.format(epoch,i,len(train_loader),loss.item()/len(images)))
    print('=====> Epoch {}/{}, Average Loss: {:.3f}'.format(epoch,epochs,train_loss/len(train_loader.dataset)))

# Test function
def test(epoch):
    """Evaluate ``model`` on ``test_loader`` and save a reconstruction image.

    Args:
        epoch: Epoch number, used in the name of the saved image.

    For the first batch, the first five inputs and their reconstructions are
    written to ``<sample_dir>/reconstruction_<epoch>.png``. The loss averaged
    over all test samples is printed at the end.
    """
    model.eval()
    test_loss=0
    with torch.no_grad():
        for batch_idx, (images,_) in enumerate(test_loader):
            images = images.to(device)
            reconstructed,mu,logvar = model(images)
            test_loss += loss_function(reconstructed,images,mu,logvar).item()
            if batch_idx ==0:
                # view_as(images) restores the image shape from the actual batch,
                # so this also works when the batch is smaller than batch_size.
                comparison = torch.cat([images[:5],reconstructed.view_as(images)[:5]])
                save_image(comparison.cpu(),
                           os.path.join(sample_dir,'reconstruction_'+str(epoch)+'.png'),
                           nrow=5)

    print('=====> Average Test Loss: {:.3f}'.format(test_loss/len(test_loader.dataset)))

In [9]:
# main function
# Each epoch: train, evaluate, then save a grid of images generated from the prior.
for epoch in range(1,epochs+1):
    train(epoch)
    test(epoch)
    with torch.no_grad():
        # Drop the encoder: sample z from the standard Gaussian prior and feed
        # it to the decoder to generate new samples.
        num_samples = 64
        # Use latent_dim instead of a literal 20 so the sample always matches
        # the decoder's input size; create it directly on the target device.
        sample = torch.randn(num_samples, latent_dim, device=device)
        generated = model.decode(sample).cpu()
        save_image(generated.view(num_samples,1,28,28),
                   os.path.join(sample_dir,'sample_'+str(epoch)+'.png'))

Train Epoch 1  [Batch0/469]	loss: 548.405579
Train Epoch 1  [Batch100/469]	loss: 187.809143
Train Epoch 1  [Batch200/469]	loss: 150.541779
Train Epoch 1  [Batch300/469]	loss: 138.330200
Train Epoch 1  [Batch400/469]	loss: 130.807556
=====> Epoch 1, Average Loss: 165.204
=====> Average Test Loss: 128.362
Train Epoch 2  [Batch0/469]	loss: 126.273087
Train Epoch 2  [Batch100/469]	loss: 124.899948
Train Epoch 2  [Batch200/469]	loss: 118.156609
Train Epoch 2  [Batch300/469]	loss: 119.573822
Train Epoch 2  [Batch400/469]	loss: 119.003311
=====> Epoch 2, Average Loss: 122.165
=====> Average Test Loss: 116.622
Train Epoch 3  [Batch0/469]	loss: 119.288666
Train Epoch 3  [Batch100/469]	loss: 113.637123
Train Epoch 3  [Batch200/469]	loss: 115.232651
Train Epoch 3  [Batch300/469]	loss: 114.510597
Train Epoch 3  [Batch400/469]	loss: 111.098389
=====> Epoch 3, Average Loss: 114.907
=====> Average Test Loss: 112.358
Train Epoch 4  [Batch0/469]	loss: 114.943321
Train Epoch 4  [Batch100/469]	loss: 115.