In [15]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
from torchvision.utils import save_image

In [18]:
batch_size = 100
learning_rate = 0.0005
num_epoch = 1
hidden_size = 10

In [16]:
# MNIST Dataset
train_dataset = datasets.MNIST(root='./mnist_data/', train=True, transform=transforms.ToTensor(), download=True)
test_dataset = datasets.MNIST(root='./mnist_data/', train=False, transform=transforms.ToTensor(), download=False)

# Data Loader (Input Pipeline)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

In [48]:
class VAE(nn.Module):
    def __init__(self):
        super(VAE, self).__init__()
        
        self.fc1 = nn.Sequential(
                        nn.Conv2d(1,8,3,padding=1),   # batch x 8 x 28 x 28
                        nn.BatchNorm2d(8),
                        nn.ReLU(),
                        nn.MaxPool2d(2,2), 
                        nn.Conv2d(8,16,3,padding=1),  # batch x 16 x 14 x 14
                        nn.BatchNorm2d(16),
                        nn.ReLU(),
                        nn.MaxPool2d(2,2),
                        nn.Conv2d(16,32,3,padding=1),  # batch x 32 x 7 x 7
                        nn.ReLU(),)
        
        self.fc2_1 = nn.Sequential(
                        nn.Linear(32*7*7, 800),
                        nn.Linear(800, hidden_size),)
        
        self.fc2_2 = nn.Sequential(
                        nn.Linear(32*7*7, 800),
                        nn.Linear(800, hidden_size),)
        
        
        self.fc3 = nn.Sequential(
                        nn.Linear(hidden_size,800),
                        nn.BatchNorm1d(800),
                        nn.ReLU(),
                        nn.Linear(800,1568),
                        nn.ReLU(),)
        
        self.fc4 = nn.Sequential(
                        nn.ConvTranspose2d(32,16,3,2,1,1),
                        nn.ReLU(),
                        nn.BatchNorm2d(16),
                        nn.ConvTranspose2d(16,8,3,2,1,1),
                        nn.ReLU(),
                        nn.BatchNorm2d(8),
                        nn.ConvTranspose2d(8,1,3,1,1),
                        nn.BatchNorm2d(1),)
        
        self.sigmoid = nn.Sigmoid()
        self.relu = nn.ReLU()
        
    def encoder(self, x):
        #x = x.view(batch_size,-1)
        out = self.fc1(x)
        out = out.view(batch_size,-1)
        out = self.relu(out)
        mu = self.fc2_1(out)
        log_var = self.fc2_2(out)
                
        return mu,log_var
    
    def sampling(self, mu, log_var):
        std = torch.exp(0.5*log_var)
        eps = torch.randn_like(std)
        return eps.mul(std).add_(mu) # return z sample
        
    def decoder(self, z):
        out = self.fc3(z)
        out = self.relu(out)
        out = out.view(batch_size,32,7,7)
        out = self.fc4(out)
        out = self.sigmoid(out)
        out = out.view(batch_size,28,28,1)
        
        return out 
    
    def forward(self, x):
        mu, log_var = self.encoder(x.view(-1, 1, 28, 28))
        z = self.sampling(mu, log_var)
        return self.decoder(z), mu, log_var

# build model
vae = VAE()

In [34]:
use_cuda = torch.cuda.is_available()                   # check if GPU exists
device = torch.device("cuda" if use_cuda else "cpu")   # use CPU or GPU

In [35]:
vae

VAE(
  (fc1): Sequential(
    (0): Conv2d(1, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (4): Conv2d(8, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (5): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): ReLU()
    (7): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (8): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU()
  )
  (fc2_1): Sequential(
    (0): Linear(in_features=1568, out_features=800, bias=True)
    (1): Linear(in_features=800, out_features=10, bias=True)
  )
  (fc2_2): Sequential(
    (0): Linear(in_features=1568, out_features=800, bias=True)
    (1): Linear(in_features=800, out_features=10, bias=True)
  )
  (fc3): Sequential(
    (0): Linear(in_features=10, 

In [47]:
optimizer = optim.Adam(vae.parameters())
# return reconstruction error + KL divergence losses
def loss_function(recon_x, x, mu, log_var):
    BCE = F.binary_cross_entropy(recon_x, x.view(-1, 1, 28, 28), reduction='sum')
    KLD = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
    return BCE + KLD

In [41]:
def train(epoch):
    vae.train()
    train_loss = 0
    for batch_idx, (data, _) in enumerate(train_loader):
        data = data
        optimizer.zero_grad()
        
        recon_batch, mu, log_var = vae(data)
        loss = loss_function(recon_batch, data, mu, log_var)
        
        loss.backward()
        train_loss += loss.item()
        optimizer.step()
        
        if batch_idx % 100 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item() / len(data)))
    print('====> Epoch: {} Average loss: {:.4f}'.format(epoch, train_loss / len(train_loader.dataset)))

In [42]:
def test():
    vae.eval()
    test_loss= 0
    with torch.no_grad():
        for data, _ in test_loader:
            data = data
            recon, mu, log_var = vae(data)
            
            # sum up batch loss
            test_loss += loss_function(recon, data, mu, log_var).item()
        
    test_loss /= len(test_loader.dataset)
    print('====> Test set loss: {:.4f}'.format(test_loss))

In [49]:
for epoch in range(1, 20):
    train(epoch)
    test()

  after removing the cwd from sys.path.


====> Epoch: 1 Average loss: 642.9926
====> Test set loss: 643.2455
====> Epoch: 2 Average loss: 642.9936
====> Test set loss: 641.9545
====> Epoch: 3 Average loss: 642.9802
====> Test set loss: 643.3165
====> Epoch: 4 Average loss: 642.9280
====> Test set loss: 643.0377
====> Epoch: 5 Average loss: 642.8713
====> Test set loss: 644.0555
====> Epoch: 6 Average loss: 642.9995
====> Test set loss: 642.4959
====> Epoch: 7 Average loss: 643.0260
====> Test set loss: 642.8975
====> Epoch: 8 Average loss: 642.8473
====> Test set loss: 645.5510
====> Epoch: 9 Average loss: 642.9770
====> Test set loss: 641.4386
====> Epoch: 10 Average loss: 642.8620
====> Test set loss: 643.1968
====> Epoch: 11 Average loss: 642.9317
====> Test set loss: 644.4400
====> Epoch: 12 Average loss: 642.9573
====> Test set loss: 643.4304
====> Epoch: 13 Average loss: 642.8581
====> Test set loss: 643.0271
====> Epoch: 14 Average loss: 642.8774
====> Test set loss: 644.6484
====> Epoch: 15 Average loss: 642.9489
====