# Variational Autoencoder

In [None]:
import sys,os
import torch, torchvision
import torch.nn as nn
import torchvision.transforms as transforms
import matplotlib.pyplot as plt

from torch.utils.data import DataLoader
from torch.autograd import Variable

sys.path.append(os.pardir)
from utils import *

### Settings

In [None]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Data / architecture settings.
batch_size = 100   # samples per mini-batch for both DataLoaders
image_size = 28    # side length of an image; the networks use image_size**2 features
hidden_size = 400  # width of the hidden fully connected layers
latent_size = 2    # dimensionality of the latent code (mu / log_var)

# Optimization settings.
learning_rate = 0.001  # passed to Adam as lr
num_epochs = 5         # number of passes over train_loader

In [None]:
# MNIST train / test splits, downloaded into ./../data/MNIST/ if missing and
# converted to tensors by ToTensor.
train_data = torchvision.datasets.MNIST(root='./../data/MNIST/', train=True, transform=transforms.ToTensor(), download=True)
test_data = torchvision.datasets.MNIST(root='./../data/MNIST/', train=False, transform=transforms.ToTensor(), download=True)

# Both loaders yield shuffled mini-batches of batch_size samples
# (note that the test loader is shuffled as well).
train_loader = DataLoader(dataset=train_data, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=test_data, batch_size=batch_size, shuffle=True)

### Model

In [None]:
class VEncoder(nn.Module):
    """Fully connected VAE encoder.

    Flattens each input sample, passes it through two ReLU hidden layers and
    produces the mean and log-variance of a diagonal Gaussian over the latent
    space, together with a sample drawn via the reparameterization trick.
    Layer widths are taken from the module-level ``image_size``,
    ``hidden_size`` and ``latent_size`` settings.
    """

    def __init__(self):
        super().__init__()

        # Shared feature extractor feeding both output heads.
        self.fc = nn.Sequential(
            nn.Linear(image_size**2, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
        )

        # Separate linear heads for the mean and the log-variance.
        self.i2mu = nn.Linear(hidden_size, latent_size)
        self.i2log_var = nn.Linear(hidden_size, latent_size)

    def forward(self, input):
        """Encode a batch of inputs.

        Args:
            input: Tensor whose first dimension is the batch and whose
                remaining dimensions flatten to ``image_size**2`` features.

        Returns:
            Tuple ``(mu, log_var, reparam)``; each has shape
            ``(batch, latent_size)``.
        """
        # Flatten with the tensor's own batch dimension instead of the global
        # batch_size so partial or differently sized batches also work.
        hidden = self.fc(input.view(input.size(0), -1))
        mu = self.i2mu(hidden)
        log_var = self.i2log_var(hidden)

        reparam = self.reparameterize(mu, log_var)

        return mu, log_var, reparam

    def reparameterize(self, mu, log_var):
        """Sample ``mu + eps * std`` with ``eps`` drawn from N(0, I).

        Args:
            mu: Mean tensor.
            log_var: Log-variance tensor, same shape as ``mu``.

        Returns:
            A sample with the same shape as ``mu``; the randomness is isolated
            in ``eps`` so gradients flow through ``mu`` and ``log_var``.
        """
        std = torch.exp(log_var / 2)
        # randn_like matches std's device and dtype, so the noise never
        # depends on the global `device` setting.
        eps = torch.randn_like(std)

        return mu + eps * std


encoder = VEncoder().to(device)

In [None]:
class VDecoder(nn.Module):
    """Fully connected VAE decoder.

    Maps latent vectors of size ``latent_size`` through two ReLU hidden layers
    to ``image_size**2`` sigmoid activations, reshaped into single-channel
    images. Layer widths come from the module-level settings.
    """

    def __init__(self):
        super().__init__()

        self.fc = nn.Sequential(
            nn.Linear(latent_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, image_size**2),
            # Sigmoid keeps outputs in (0, 1).
            nn.Sigmoid()
        )

    def forward(self, input):
        """Decode latent vectors into images.

        Args:
            input: Tensor of shape ``(batch, latent_size)``.

        Returns:
            Tensor of shape ``(batch, 1, image_size, image_size)``.
        """
        output = self.fc(input)
        # Infer the batch dimension rather than relying on the global
        # batch_size, so any number of latent vectors can be decoded.
        output = output.view(-1, 1, image_size, image_size)

        return output


decoder = VDecoder().to(device)

### Loss Function & Optimizer

In [None]:
# Binary cross-entropy summed (not averaged) over every element of the batch.
Reconstuct_Error = nn.BCELoss(reduction='sum')


def criterion(input, output, mu, log_var):
    """Return the VAE loss: reconstruction error plus KL regularizer.

    Args:
        input: Target tensor passed to the BCE loss.
        output: Reconstruction passed to the BCE loss.
        mu: Latent means produced by the encoder.
        log_var: Latent log-variances produced by the encoder.

    Returns:
        Scalar tensor: summed BCE plus
        ``-0.5 * sum(1 + log_var - mu^2 - exp(log_var))``.
    """
    kl_terms = 1 + log_var - mu.pow(2) - log_var.exp()
    kl_divergence = -0.5 * torch.sum(kl_terms)
    bce = Reconstuct_Error(output, input)

    return bce + kl_divergence


# A single Adam optimizer updates encoder and decoder weights together.
parameters = [*encoder.parameters(), *decoder.parameters()]
optimizer = torch.optim.Adam(parameters, lr=learning_rate)

### Train

In [None]:
# Loss values sampled every 100 steps, used for plotting afterwards.
all_losses = []

# torch.save does not create missing parent directories; make sure the
# checkpoint folder exists up front so a full training run is not lost
# at the very last step.
os.makedirs('./models', exist_ok=True)

total_step = len(train_loader)
for epoch in range(num_epochs):
    for i, (images, _) in enumerate(train_loader):

        # Tensors carry autograd state themselves; the deprecated Variable
        # wrapper is not needed.
        input = images.to(device)

        # Encode to (mu, log_var, sample), then decode the sample.
        mu, log_var, reparam = encoder(input)
        output = decoder(reparam)

        loss = criterion(input, output, mu, log_var)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (i+1) % 100 == 0:
            print('Epoch [{}/{}], Step [{}/{}], loss [{:.4f}]'.format(epoch+1, num_epochs, i+1, total_step, loss.item()))
            all_losses.append(loss.item())

    # `output` holds the reconstructions of the last batch of this epoch.
    # save_images comes from utils; presumably it writes them as an image
    # file under ./images/VAE -- confirm against utils.
    save_images(output, './images/VAE','VAE_{}.png'.format(epoch+1))

# Persist only the learned weights of each network.
torch.save(encoder.state_dict(), './models/VAE_Encoder.ckpt')
torch.save(decoder.state_dict(), './models/VAE_Decoder.ckpt')

In [None]:
# Hand the recorded training losses to drawLoss under the label 'VAE'.
# drawLoss is provided by utils; presumably it plots one curve per dict
# entry -- confirm against utils.
drawLoss({'VAE':all_losses})

### Test

In [None]:
# Reconstruct every test batch and display the results as a 10-column grid.
# No parameters are updated here, so autograd tracking is disabled to avoid
# building computation graphs during inference.
with torch.no_grad():
    for images, _ in test_loader:

        # Plain tensors suffice; the deprecated Variable wrapper is dropped.
        input = images.to(device)

        _, _, reparam = encoder(input)
        output = decoder(reparam)

        test = output.cpu()
        grid_test = torchvision.utils.make_grid(test, nrow=10)

        # make_grid returns (C, H, W); imshow expects (H, W, C).
        plt.imshow(grid_test.numpy().transpose(1, 2, 0))
        plt.show()