# Convolutional VAE 

In [None]:
import torch, torchvision
import torch.nn as nn
import torchvision.transforms as transforms
import matplotlib.pyplot as plt

from torch.utils.data import DataLoader
from torch.autograd import Variable

### Settings

In [None]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Data settings.
image_size = 28    # MNIST digits are 28x28 pixels
batch_size = 100   # samples per mini-batch for both data loaders

# Model widths: fully connected heads go hidden_size -> hidden_size // 2 -> latent_size.
hidden_size = 800
latent_size = 2

# Optimisation settings (Adam step size and number of passes over the training set).
learning_rate = 2e-4
num_epochs = 50

In [None]:
# Both MNIST splits share the same storage location, tensor conversion and
# download-on-demand behaviour; only the `train` flag differs.
_mnist_kwargs = dict(root='./../data/MNIST/', transform=transforms.ToTensor(), download=True)
train_data = torchvision.datasets.MNIST(train=True, **_mnist_kwargs)
test_data = torchvision.datasets.MNIST(train=False, **_mnist_kwargs)

# Mini-batch iterators; both splits are reshuffled on every pass.
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=True)

### Model

In [None]:
class CVEncoder(nn.Module):
    """Convolutional encoder of the VAE.

    Maps a batch of single-channel 28x28 images to the mean and
    log-variance of a diagonal Gaussian over the latent space, and draws
    one latent sample from it with the reparameterization trick.

    Args:
        hidden_dim: Width of the first fully connected layer of each head.
            Defaults to the module-level ``hidden_size``.
        latent_dim: Size of the latent vector. Defaults to the module-level
            ``latent_size``.
    """

    def __init__(self, hidden_dim=None, latent_dim=None):
        super(CVEncoder, self).__init__()

        # Fall back to the notebook-level settings so `CVEncoder()` keeps working.
        if hidden_dim is None:
            hidden_dim = hidden_size
        if latent_dim is None:
            latent_dim = latent_size

        # Two 2x2 max-pools shrink 28x28 inputs to 7x7; the last conv emits 32 channels.
        self.conv = nn.Sequential(
            nn.Conv2d(1, 8, 3, padding=1),
            nn.BatchNorm2d(8),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(8, 16, 3, padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(16, 32, 3, padding=1),
            nn.ReLU(),
        )

        # Head producing the mean of the latent Gaussian.
        self.i2mu = nn.Sequential(
            nn.Linear(32 * 7 * 7, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(hidden_dim // 2, latent_dim),
        )

        # Head producing the log-variance of the latent Gaussian.
        self.i2log_var = nn.Sequential(
            nn.Linear(32 * 7 * 7, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(hidden_dim // 2, latent_dim),
        )

    def forward(self, input):
        """Encode a batch of images.

        Args:
            input: Tensor of shape ``(N, 1, 28, 28)``.

        Returns:
            Tuple ``(mu, log_var, reparam)``, each of shape
            ``(N, latent_dim)``; ``reparam`` is a sample drawn from
            ``N(mu, exp(log_var))``.
        """
        features = self.conv(input)
        # Flatten per sample using the actual batch size, so a partial final
        # batch or single-image inference works. `self.conv` already ends in
        # a ReLU, so no further activation is needed here.
        features = features.view(input.size(0), -1)

        mu = self.i2mu(features)
        log_var = self.i2log_var(features)

        return mu, log_var, self.reparameterize(mu, log_var)

    def reparameterize(self, mu, log_var):
        """Sample ``mu + std * eps`` with ``eps ~ N(0, I)``.

        Keeping the randomness in ``eps`` lets gradients flow through
        ``mu`` and ``log_var``.
        """
        std = torch.exp(log_var / 2)
        # Draw the noise with the same shape, dtype and device as `std`
        # instead of relying on a global device setting.
        eps = torch.randn_like(std)

        return mu + std * eps

In [None]:
class CVDecoder(nn.Module):
    """Convolutional decoder of the VAE.

    Maps latent vectors back to single-channel 28x28 images with pixel
    values in ``[0, 1]``.

    Args:
        hidden_dim: Width of the second fully connected layer. Defaults to
            the module-level ``hidden_size``.
        latent_dim: Size of the latent vector. Defaults to the module-level
            ``latent_size``.
    """

    def __init__(self, hidden_dim=None, latent_dim=None):
        super(CVDecoder, self).__init__()

        # Fall back to the notebook-level settings so `CVDecoder()` keeps working.
        if hidden_dim is None:
            hidden_dim = hidden_size
        if latent_dim is None:
            latent_dim = latent_size

        # Expand the latent vector to a flattened 32x7x7 feature map.
        self.fc = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(hidden_dim // 2, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 32 * 7 * 7),
            nn.ReLU(),
        )

        # The two stride-2 transposed convolutions upsample 7x7 -> 14x14 -> 28x28;
        # the last layer keeps the spatial size and reduces to one channel.
        self.deconv = nn.Sequential(
            nn.ConvTranspose2d(32, 16, 3, 2, 1, 1),
            nn.ReLU(),
            nn.BatchNorm2d(16),
            nn.ConvTranspose2d(16, 8, 3, 2, 1, 1),
            nn.ReLU(),
            nn.BatchNorm2d(8),
            nn.ConvTranspose2d(8, 1, 3, 1, 1),
            nn.BatchNorm2d(1),
        )

        self.sigmoid = nn.Sigmoid()

    def forward(self, input):
        """Decode a batch of latent vectors.

        Args:
            input: Tensor of shape ``(N, latent_dim)``.

        Returns:
            Tensor of shape ``(N, 1, 28, 28)`` with values in ``[0, 1]``.
        """
        # `self.fc` already ends in a ReLU, so no extra activation is needed.
        features = self.fc(input)
        # Reshape with the actual batch size rather than a global constant,
        # so any number of latent vectors can be decoded.
        features = features.view(input.size(0), 32, 7, 7)

        # The transposed convolutions already yield (N, 1, 28, 28).
        return self.sigmoid(self.deconv(features))

In [None]:
class CVAE(nn.Module):
    """Convolutional VAE: an encoder/decoder pair plus its training loss."""

    def __init__(self):
        super().__init__()

        self.encoder = CVEncoder()
        self.decoder = CVDecoder()

        # Reconstruction term, summed (not averaged) over every pixel in the batch.
        self.BCELoss = nn.BCELoss(reduction='sum')

    def forward(self, input):
        """Reconstruct `input` and compute the loss for this batch.

        Returns:
            Tuple ``(reconstruction, loss)`` where ``loss`` is the summed
            binary cross-entropy between reconstruction and input plus the
            KL divergence of the encoder's Gaussian from a standard normal.
        """
        mean, log_variance, latent = self.encoder(input)
        reconstruction = self.decoder(latent)

        reconstruction_term = self.BCELoss(reconstruction, input)

        # Closed-form KL(N(mean, exp(log_variance)) || N(0, I)), summed over the batch.
        kl_inner = 1 + log_variance - mean.pow(2) - torch.exp(log_variance)
        kl_term = -0.5 * kl_inner.sum()

        return reconstruction, reconstruction_term + kl_term

In [None]:
# Build the model on the selected device and optimise all of its parameters with Adam.
model = CVAE().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

total_step = len(train_loader)
# Make sure the BatchNorm layers use batch statistics while training.
model.train()
for epoch in range(num_epochs):
    for i, (images, _) in enumerate(train_loader):

        # Labels are ignored. Tensors take part in autograd directly, so the
        # deprecated `Variable` wrapper is not needed.
        images = images.to(device)
        # The model returns the reconstruction and the loss summed over the batch.
        output, loss = model(images)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Report progress every 100 mini-batches.
        if (i + 1) % 100 == 0:
            print('Epoch [{}/{}], Step [{}/{}], loss [{:.4f}]'.format(epoch+1, num_epochs, i+1, total_step, loss.item()))

### Result

In [None]:
# Show each image of the last training batch next to its reconstruction.
# Both tensors may live on the GPU, and matplotlib can only draw CPU data,
# so detach them from the autograd graph and copy them to the CPU first.
# squeeze(1) removes only the channel axis, so a batch of one still works.
originals = images.detach().cpu().squeeze(1)
out_img = output.detach().cpu().squeeze(1)

for i in range(out_img.size(0)):

    fig = plt.figure()
    origin = fig.add_subplot(1, 2, 1)      # left: input image
    generated = fig.add_subplot(1, 2, 2)   # right: reconstruction

    origin.imshow(originals[i], cmap='gray')
    generated.imshow(out_img[i], cmap='gray')

    fig.show()