In [13]:
"""
@Description: VAE
@Author: Dezhan Tu
@LastEditTime: 08/23/2020
"""

import numpy as np
import os

import torch
import torch.nn as nn
from torch import optim
from torch.utils.data import DataLoader

from torchvision import transforms
from torchvision.datasets import MNIST
from torchvision.utils import save_image
from torch.nn import functional as F
import matplotlib.pyplot as plt


# Define hyperparameters
NUM_EPOCHS = 100        # number of full passes over the training loader
LEARNING_RATE = 1e-3    # learning rate handed to the Adam optimizer
BATCH_SIZE = 128        # mini-batch size used by both the train and test loaders

In [14]:
# Data utility

# ToTensor converts each image to a torch.FloatTensor
img_transform = transforms.Compose([transforms.ToTensor()])

# MNIST train / test splits, downloaded into ./dataset when not already present
train_dataset, test_dataset = (
    MNIST(root="./dataset", train=is_train, transform=img_transform, download=True)
    for is_train in (True, False)
)

# Mini-batch loaders: training batches are shuffled, test batches keep dataset order
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)


def save_epoch_img(img, epoch):
    """
    write a batch of images produced during training to ./VAE_Result

    img is reshaped to (batch, 1, 28, 28) before saving; epoch is only used
    to build the output file name.
    """
    batch_size = img.size(0)
    out_path = './VAE_Result/VAE_{}.png'.format(epoch)
    save_image(img.view(batch_size, 1, 28, 28), out_path)


def make_dir():
    """
    make directory to store result

    Creates 'VAE_Result' relative to the current working directory. Calling it
    again when the directory already exists is a no-op.

    Raises:
        FileExistsError: if 'VAE_Result' exists but is not a directory.
    """
    # exist_ok=True replaces the separate exists() check, which left a window
    # between the check and the creation.
    os.makedirs('VAE_Result', exist_ok=True)

make_dir()    #Create result directory

In [15]:
class VAE(nn.Module):
    """
    class for variational autoencoder(VAE)

    A fully connected VAE: the encoder maps a flattened input to the mean and
    log-variance of a diagonal Gaussian over the latent space, and the decoder
    maps a latent sample back to input space through a sigmoid.

    Args:
        input_dim: length of each flattened input vector (default 784 = 28 * 28).
        hidden_dim: width of the single hidden layer in encoder and decoder.
        latent_dim: dimensionality of the latent vector.
    """
    def __init__(self, input_dim=784, hidden_dim=256, latent_dim=32):
        """
        init VAE network structure
        """
        super().__init__()

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.latent_dim = latent_dim

        # Layer creation order is kept as fc1, fc21, fc22, fc3, fc4 so that a
        # seeded default construction initializes weights exactly as before.
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc21 = nn.Linear(hidden_dim, latent_dim)   # mean
        self.fc22 = nn.Linear(hidden_dim, latent_dim)   # log-variance (not the variance itself)
        self.fc3 = nn.Linear(latent_dim, hidden_dim)
        self.fc4 = nn.Linear(hidden_dim, input_dim)

    def encode(self, x):
        """
        encoding process

        Returns:
            (mu, log_var) of the approximate posterior, each with last
            dimension latent_dim.
        """
        h1 = torch.relu(self.fc1(x))
        return self.fc21(h1), self.fc22(h1)

    def reparameterize(self, mu, log_var):
        """
        generate latent vector randomly

        Draws z = mu + eps * std with eps ~ N(0, I), so the sample stays
        differentiable with respect to mu and log_var.
        """
        std = torch.exp(log_var * 0.5)   # sigma = exp(0.5 * log(sigma^2))
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """
        decoding process

        Returns:
            reconstruction with last dimension input_dim; the final sigmoid
            keeps every value in [0, 1].
        """
        h3 = torch.relu(self.fc3(z))
        return torch.sigmoid(self.fc4(h3))

    def forward(self, x):
        """
        encode x, sample a latent vector, and decode it

        Returns:
            (reconstruction, mu, log_var)
        """
        mu, log_var = self.encode(x)
        z = self.reparameterize(mu, log_var)
        x = self.decode(z)
        return x, mu, log_var
    
# Reconstruction + KL divergence losses
def loss_function(recon_x, x, mu, logvar):
    """
    VAE loss: binary cross-entropy reconstruction term plus KL divergence.

    Both terms are summed (not averaged) over every element of the batch.

    Args:
        recon_x: decoder output with values in [0, 1].
        x: target batch with the same number of elements as recon_x; it is
           reshaped to recon_x's shape, so image-shaped input is accepted.
        mu: latent means from the encoder.
        logvar: latent log-variances from the encoder.

    Returns:
        scalar tensor BCE + KLD.
    """
    # Reshape the target to match the reconstruction instead of assuming a
    # fixed width of 784; reshape also accepts non-contiguous targets.
    target = x.reshape(recon_x.shape)
    BCE = F.binary_cross_entropy(recon_x, target, reduction='sum')

    # KL(q(z|x) || N(0, I)) = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

    return BCE + KLD

# Training Process

In [18]:
# Use gpu if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Instantiate a VAE and move its parameters to the selected device
model = VAE().to(device)

# NOTE: this MSE criterion is not used by the loop below, which optimizes
# loss_function (BCE + KL divergence). It is kept only so the name stays defined.
criterion = nn.MSELoss()

# Adam optimizer with learning rate 1e-3
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

# Training Process
loss_output = []    # mean mini-batch loss of each epoch, in epoch order
loss = 0            # defined up front so the name exists even when NUM_EPOCHS is 0
model.train()
for epoch in range(NUM_EPOCHS):
    # Reset the accumulator every epoch; carrying it over would fold the
    # previous epoch's average into this epoch's reported loss.
    running_loss = 0.0
    for img, _ in train_loader:
        img = img.view(-1, 784).to(device)    # Reshape mini-batch data
        optimizer.zero_grad()                 # Reset the gradients back to zero
        recon_batch, mu, logvar = model(img)  # Reconstruction
        train_loss = loss_function(recon_batch, img, mu, logvar)   # Reconstruction + KL divergence losses
        train_loss.backward()                 # Compute accumulated gradients
        optimizer.step()                      # Parameters update
        running_loss += train_loss.item()     # Add the mini-batch training loss to epoch loss

    # loss_function sums over the batch, so this is the mean summed loss per mini-batch
    loss = running_loss / len(train_loader)
    loss_output.append(loss)

    # Display the training loss
    print("Epoch : {}/{}, Training loss = {:.3f}".format(epoch + 1, NUM_EPOCHS, loss))

Epoch : 1/100, Training loss = 22682.541
Epoch : 2/100, Training loss = 16648.973
Epoch : 3/100, Training loss = 15372.358
Epoch : 4/100, Training loss = 14789.311
Epoch : 5/100, Training loss = 14455.871
Epoch : 6/100, Training loss = 14246.921
Epoch : 7/100, Training loss = 14091.546
Epoch : 8/100, Training loss = 13980.863
Epoch : 9/100, Training loss = 13891.288
Epoch : 10/100, Training loss = 13822.959
Epoch : 11/100, Training loss = 13758.233
Epoch : 12/100, Training loss = 13715.204
Epoch : 13/100, Training loss = 13676.213
Epoch : 14/100, Training loss = 13636.315
Epoch : 15/100, Training loss = 13606.857
Epoch : 16/100, Training loss = 13582.937
Epoch : 17/100, Training loss = 13560.845
Epoch : 18/100, Training loss = 13533.290
Epoch : 19/100, Training loss = 13523.884
Epoch : 20/100, Training loss = 13498.326
Epoch : 21/100, Training loss = 13485.002
Epoch : 22/100, Training loss = 13463.412
Epoch : 23/100, Training loss = 13453.164
Epoch : 24/100, Training loss = 13445.775
E

# Testing Process

In [20]:
# Reconstruct every test mini-batch and save each one as an image grid.
# The model still samples its latent vector in forward(), so the saved
# reconstructions vary from run to run.
model.eval()
i = 0    # stays 0 if the loader yields no batches
with torch.no_grad():    # inference only: skip building the autograd graph
    for i, (img, _) in enumerate(test_loader, start=1):
        img = img.view(-1, 784).to(device)    # Reshape mini-batch data
        recon_batch, _, _ = model(img)        # Testing phase
        recon_batch = recon_batch.view(recon_batch.size(0), 1, 28, 28).cpu()
        save_image(recon_batch, 'VAE_Result/reconstruction_{}.png'.format(i))