In [1]:
import os
import torch
import torch.nn as nn
from torch.autograd import Variable
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torch.nn.functional as F
import torch.optim as optim

# Data preparation

In [2]:
# Train and run inference on the GPU when one is available.
use_cuda = torch.cuda.is_available()

# Directory where torchvision stores (and looks for) the MNIST files.
root = './data'
# exist_ok=True makes this a no-op when the directory is already there.
os.makedirs(root, exist_ok=True)

# Keep pixel intensities in [0, 1], which is what ToTensor produces.
# The VAE decoder ends in a sigmoid and is scored with binary cross-entropy,
# so the reconstruction targets must lie in [0, 1]. Do not add
# Normalize(mean=0.5, std=0.5) here: it would shift the targets to [-1, 1].
# MNIST images are also single-channel, so per-channel statistics must have
# exactly one entry if normalization is ever reintroduced.
trans = transforms.ToTensor()

# Download MNIST into `root` if it is not already present.
train_set = dset.MNIST(root=root, train=True, transform=trans, download=True)
test_set = dset.MNIST(root=root, train=False, transform=trans, download=True)

In [3]:
# Number of samples per mini-batch, used by both loaders.
batch_size = 128

# Training samples are reshuffled on every pass over the data.
train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True)
# Test samples are served in dataset order.
test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=False)

# len() of a DataLoader is its number of mini-batches.
n_train_batches = len(train_loader)
n_test_batches = len(test_loader)
print('==>>> total trainning batch number: {}'.format(n_train_batches))
print('==>>> total testing batch number: {}'.format(n_test_batches))

==>>> total trainning batch number: 469
==>>> total testing batch number: 79


In [4]:
import matplotlib.pyplot as plt

# Pull a single mini-batch from the training loader; the labels are not used.
inputs, _ = next(iter(train_loader))
# Sample at index 4 of the batch, first channel -> a 2-D image.
img = inputs[4, 0]

# Left as the trailing expression so the cell displays the returned image object.
plt.imshow(img, cmap='gray', interpolation='nearest')

<matplotlib.image.AxesImage at 0x1d3f0b1d860>

# MODEL

In [5]:
class VAE(nn.Module):
    """Fully connected variational autoencoder with a diagonal Gaussian latent.

    Encoder: ``fc1`` + ReLU, followed by two linear heads (``fc2_1`` for the
    latent mean, ``fc2_2`` for the latent log-variance).
    Decoder: ``fc3`` + ReLU, then ``fc4`` + sigmoid, so every output value
    lies in (0, 1).

    Args:
        input_size: Number of features per flattened input sample.
        output_size: Number of features per reconstructed sample.
        hidden_size: Width of the hidden layer in both encoder and decoder.
        latent_size: Dimensionality of the latent code.
    """

    def __init__(self, input_size = 28*28, output_size = 28*28, hidden_size = 400, latent_size = 20):
        super(VAE, self).__init__()
        # Remembered so forward() can flatten inputs to the size fc1 expects.
        self.input_size = input_size
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2_1 = nn.Linear(hidden_size, latent_size)   # latent mean head
        self.fc2_2 = nn.Linear(hidden_size, latent_size)   # latent log-variance head
        self.fc3 = nn.Linear(latent_size, hidden_size)
        self.fc4 = nn.Linear(hidden_size, output_size)

    def encode(self, input):
        """Map flattened inputs to the pair ``(mu, log-variance)``."""
        x = F.relu(self.fc1(input))
        return self.fc2_1(x), self.fc2_2(x)

    def reparameterize(self, mu, var):
        """Draw ``z = mu + eps * std`` with ``eps ~ N(0, I)``.

        ``var`` is treated as the log-variance: the standard deviation is
        computed as ``exp(0.5 * var)``. Writing the sample this way keeps it
        differentiable with respect to ``mu`` and ``var``.
        """
        std = torch.exp(0.5*var)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Map latent codes to reconstructions with values in (0, 1)."""
        x = F.relu(self.fc3(z))
        return torch.sigmoid(self.fc4(x))

    def forward(self, x):
        """Encode ``x``, sample a latent code and decode it.

        ``x`` may have any shape whose elements regroup into rows of
        ``input_size`` features (e.g. ``(B, 1, 28, 28)`` for the defaults).

        Returns:
            Tuple ``(reconstruction, mu, log-variance)``.
        """
        # Flatten using the configured input size rather than a fixed 784,
        # so non-default input sizes work.
        mu, logvar = self.encode(x.reshape(-1, self.input_size))
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar

def loss_function(recon_x, x, mu, logvar):
    """Negative ELBO for a Bernoulli decoder, summed over the whole batch.

    Args:
        recon_x: Decoder output of shape ``(B, D)`` with values in [0, 1].
        x: Target batch with the same number of elements as ``recon_x``
            (e.g. ``(B, 1, 28, 28)`` images for ``D = 784``); values must lie
            in [0, 1] for binary cross-entropy to be meaningful.
        mu: Latent means.
        logvar: Latent log-variances, same shape as ``mu``.

    Returns:
        Scalar tensor: summed reconstruction BCE plus the KL divergence.
    """
    # Match the targets to the decoder output instead of assuming 784 features.
    target = x.reshape(recon_x.shape)
    # reduction='sum' replaces the deprecated size_average=False.
    BCE = F.binary_cross_entropy(recon_x, target, reduction='sum')

    # see Appendix B from VAE paper:
    # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
    # https://arxiv.org/abs/1312.6114
    # -KL = 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

    return BCE + KLD

# Training

In [1]:
# Build a fresh model and move it to the GPU when one is available.
model = VAE()

if use_cuda:
    model = model.cuda()

optimizer = optim.Adam(model.parameters(), lr=0.001)

# Epochs are numbered 1..num_epochs.
num_epochs = 19

model.train()
for epoch in range(1, num_epochs + 1):
    # Sum of the per-batch losses; loss_function already sums over samples.
    epoch_loss = 0.0

    for batch_idx, (x, _) in enumerate(train_loader):
        if use_cuda:
            x = x.cuda()

        optimizer.zero_grad()
        result, mu, logvar = model(x)
        loss = loss_function(result, x, mu, logvar)
        loss.backward()
        optimizer.step()

        # .item() extracts a plain Python float from the scalar loss tensor.
        epoch_loss += loss.item()

    # Report the mean loss per training sample over the whole epoch rather
    # than the summed loss of whichever batch happened to come last.
    print('==>>> epoch: {}, loss: {:.6f}'.format(epoch, epoch_loss / len(train_loader.dataset)))

print("===================Finished!=================== ")

NameError: name 'VAE' is not defined

# Generated (fake) images

In [None]:
# Run on whichever device the model's parameters live on, so this cell works
# both with and without a GPU.
device = next(model.parameters()).device
# The decoder's first layer fixes the latent dimensionality it accepts.
latent_size = model.fc3.in_features

# Draw 10 distinct latent codes from the standard normal prior N(0, I).
z = torch.randn(10, latent_size, device=device)
# Generation needs no gradients.
with torch.no_grad():
    fake_img = model.decode(z)

# Show every generated sample, starting at index 0.
# `plt` is the matplotlib.pyplot module imported earlier in the notebook.
for i in range(fake_img.size(0)):
    img = fake_img[i].cpu().numpy()
    img = img.reshape(28, 28)
    plt.figure()
    plt.imshow(img, cmap='gray', interpolation='nearest')

# Reconstructed images

In [None]:
# Run on whichever device the model's parameters live on, so this cell works
# both with and without a GPU.
device = next(model.parameters()).device

# Take the first 10 samples of one training mini-batch; labels are not used.
inputs, _ = next(iter(train_loader))
inputs = inputs[0:10].to(device)

# Inference only: no gradients needed.
with torch.no_grad():
    # The model's third output is the latent log-variance, not the variance.
    recon_img, mean, logvar = model(inputs)

# Show every pair, starting at index 0.
# `plt` is the matplotlib.pyplot module imported earlier in the notebook.
for i in range(inputs.size(0)):
    img = recon_img[i].cpu().numpy().reshape(28, 28)
    img2 = inputs[i].cpu().numpy().reshape(28, 28)
    print('mean: {}, logvar: {}'.format(mean[i].cpu().numpy(), logvar[i].cpu().numpy()))

    # Reconstruction and original side by side for direct comparison.
    fig, (ax_recon, ax_input) = plt.subplots(1, 2)
    ax_recon.imshow(img, cmap='gray', interpolation='nearest')
    ax_recon.set_title('reconstruction')
    ax_input.imshow(img2, cmap='gray', interpolation='nearest')
    ax_input.set_title('input')