In [4]:
import os
import torch
from torch import nn, optim
from torch.nn import functional as F
from torch.autograd import Variable

In [163]:
class VAE(nn.Module):
    def __init__(self, inp_size, hid_size, z_size):
        super(VAE, self).__init__()
        # Encoder
        self.fc1 = nn.Linear(inp_size, hid_size)
        self.drop1 = nn.Dropout(0.2)
        
        self.fc2 = nn.Linear(hid_size, hid_size)
        self.drop2 = nn.Dropout(0.2)
        
        self.fc31 = nn.Linear(hid_size, z_size)
        self.fc32 = nn.Linear(hid_size, z_size)
        
        # Decoder
        self.fc4 = nn.Linear(z_size, hid_size)
        self.drop4 = nn.Dropout(0.2)
        
        self.fc5 = nn.Linear(hid_size, hid_size)
        self.drop5 = nn.Dropout(0.2)
        
        self.fc6 = nn.Linear(hid_size, inp_size)
        
    def encode(self, x):
        h = self.drop1(F.relu(self.fc1(x)))
        # h = self.drop2(F.relu(self.fc2(h)))
        return self.fc31(h), self.fc32(h)
    
    def reparametrize(self, mu, logvar):
        if self.training:
            # logvar = log(sigma**2)
            # logvar = 2 * log(sigma)
            # sigma = exp(logvar/2)
            std = logvar.mul(0.5).exp()
            x = Variable(std.data.new(std.size()).normal_())
            return x.mul(std).add(mu)
        else:
            return mu
    
    def decode(self, z):
        h = self.drop4(F.relu(self.fc4(z)))
        # h = self.drop5(F.relu(self.fc5(h)))
        return F.sigmoid(self.fc6(h))
    
    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparametrize(mu, logvar)
        return self.decode(z), mu, logvar            

In [164]:
model = VAE(inp_size=784, hid_size=400, z_size=20)
x_, mu, logvar = model(Variable(torch.randn(784)))

In [165]:
def loss_function(x_, x, mu, logvar):
    BCE = F.binary_cross_entropy(x_, x)
    # KLD = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return BCE + KLD / (len(x) * len(mu))

In [172]:
import sys

def train(epoch, model, optimizer, train_loader, feat_size):
    model.train()
    train_loss = 0
    for batch_idx, (x, _) in enumerate(train_loader):
        x = Variable(x.view(-1, feat_size))
        optimizer.zero_grad()
        x_, mu, logvar = model(x)
        loss = loss_function(x_, x, mu, logvar)
        loss.backward()
        train_loss += loss.data[0]
        optimizer.step()
        
        if batch_idx % 10 == 0:
            sys.stdout.write(f'\rTrain Epoch: {epoch + 1} '
                f'[{(batch_idx) * len(x)}/{len(train_loader.dataset)} '
                f'({100. * (batch_idx) / len(train_loader):.0f}%)]\t'
                f'Loss: {loss.data[0] / len(x):.6f}')
            sys.stdout.flush()
    print()
    print(f'=====> Epoch: {epoch + 1} '
          f'Average loss: {train_loss / len(train_loader.dataset):.4f}')

In [173]:
def test(epoch, model, test_loader, feat_size):
    model.eval()
    test_loss = 0
    for batch_idx, (x, _) in enumerate(test_loader):
        x = Variable(x.view(-1, feat_size), volatile=True)
        x_, mu, logvar = model(x)
        loss = loss_function(x_, x, mu, logvar)
        test_loss += loss.data[0]
    print(f'====> Test set loss: {test_loss / len(test_loader.dataset):.4f}')

In [174]:
from torchvision.utils import save_image
from torchvision import datasets, transforms

In [175]:
BATCH_SIZE = 128

In [176]:
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('data', train=True, download=True,
                   transform=transforms.ToTensor()),
    batch_size=BATCH_SIZE, shuffle=True)

test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('data', train=False, 
                   transform=transforms.ToTensor()),
    batch_size=BATCH_SIZE, shuffle=True)

In [177]:
model = VAE(inp_size=784, hid_size=800, z_size=20)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

In [178]:
for epoch in range(10):
    train(epoch, model, optimizer, train_loader, 784)
    test(epoch, model, test_loader, 784)
    sample = Variable(torch.randn(BATCH_SIZE, 20))
    sample = model.decode(sample)
    save_image(sample.data.view(BATCH_SIZE, 1, 28, 28),
               f'../../results/robert/sample_mnist_{epoch + 1}.png')

=====> Epoch: 1 Average loss: 0.0020
====> Test set loss: 0.0017
=====> Epoch: 2 Average loss: 0.0018
====> Test set loss: 0.0016
=====> Epoch: 3 Average loss: 0.0018
====> Test set loss: 0.0016
=====> Epoch: 4 Average loss: 0.0017
====> Test set loss: 0.0016
=====> Epoch: 5 Average loss: 0.0017
====> Test set loss: 0.0016
=====> Epoch: 6 Average loss: 0.0017
====> Test set loss: 0.0016
=====> Epoch: 7 Average loss: 0.0017
====> Test set loss: 0.0016
=====> Epoch: 8 Average loss: 0.0017
====> Test set loss: 0.0015
=====> Epoch: 9 Average loss: 0.0017
====> Test set loss: 0.0015
=====> Epoch: 10 Average loss: 0.0017
====> Test set loss: 0.0015
