In [1]:
import torch
import torchvision

from torchvision import datasets, transforms
import torch.nn.functional as F
from torch import optim
from torch import nn
from torchvision.utils import save_image

First let's download mnist and make a generator from it

In [2]:
batch_size = 32


train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=True, download=True, transform=transforms.ToTensor()),
    batch_size=batch_size, shuffle=True)

test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False, download=True, transform=transforms.ToTensor()),
    batch_size=batch_size, shuffle=True)

I'm not actually going to test on GPU but it should work or at least give you an idea of how to make it work on CPU and GPU (which I'm not a fan of in oytorch)

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

Let's define an encoder and a decoder:

In [47]:
class Encoder(nn.Module):
    def __init__(self, x_size, h_size, z_size):
        super(Encoder, self).__init__()
        self.fc1 = nn.Linear(x_size, h_size)
        self.mu_gen = nn.Linear(h_size, z_size)
        # make the output to be the logarithm 
        # i.e will have to take the exponent
        # which forces variance to be positive
        # not that this is the diagonal of the covariance
        self.log_var_gen = nn.Linear(h_size, z_size)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        mu = self.mu_gen(x)
        log_var = self.log_var_gen(x)
        return mu, log_var

In [48]:
class Decoder(nn.Module):
    def __init__(self, x_size, h_size, z_size):
        super(Decoder, self).__init__()
        self.fc1 = nn.Linear(z_size, h_size)
        self.fc3 = nn.Linear(h_size, x_size)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc3(x))
        # black and white MNIST => sigmoid for each pixel
        x = torch.sigmoid(x) 
        return x

A VAE is simply a container of the encoder + decoder + reparametrization trick:

In [49]:
class VAE(nn.Module):
    def __init__(self, x_size, h_size, z_size):
        super(VAE, self).__init__()
        self.x_size = x_size
        self.z_size = z_size
        self.encoder = Encoder(x_size, h_size, z_size)
        self.decoder = Decoder(x_size, h_size, z_size)

    def reparameterize(self, mu, log_var):
        std = torch.exp(0.5 * log_var) # square root in exponent => std
        eps = torch.randn_like(std)
        z = std * eps + mu
        return z

    def forward(self, x):
        # make image linear (i.e vector form)
        x = x.view(-1, self.x_size)
        mu, log_var = self.encoder(x)
        z = self.reparameterize(mu, log_var)
        x_hat = self.decoder(z)
        return x_hat, mu, log_var

Let's define the model:


In [50]:
x_size = 28 * 28 # mnist image
h_size = 256
z_size = 16
model = VAE(x_size, h_size, z_size).to(device) # migrates to CUDA if you can

In [51]:
model

VAE(
  (encoder): Encoder(
    (fc1): Linear(in_features=784, out_features=256, bias=True)
    (mu_gen): Linear(in_features=256, out_features=16, bias=True)
    (log_var_gen): Linear(in_features=256, out_features=16, bias=True)
  )
  (decoder): Decoder(
    (fc1): Linear(in_features=16, out_features=256, bias=True)
    (fc3): Linear(in_features=256, out_features=784, bias=True)
  )
)

In [52]:
def loss_function(x_hat, x, mu, log_var, beta=1):
    """Compute the ELBO loss"""
    x_size = x_hat.size(-1)
    # black or white image => use sigmoid for each pixel
    rec_loss = F.binary_cross_entropy(x_hat, x.view(-1, x_size), reduction='sum')
    # closed form solution for gaussian prior and posterior
    kl_div = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
    vae_loss = rec_loss + beta * kl_div
    return vae_loss

In [53]:
class Trainer:
    def __init__(self, model, optimizer=optim.Adam, loss_function=loss_function):
        self.model = model 
        self.optimizer = optimizer(self.model.parameters())
        self.loss_function = loss_function
        self.epoch = 0
        
    def __call__(self, train, test, n_epochs=10):
        self.epoch = 0
        for _ in range(n_epochs):
            self._train_epoch(train)
            self._test_epoch(test)
            with torch.no_grad():
                sample = torch.randn(64, self.model.z_size).to(device)
                sample = model.decoder(sample).cpu()  # make sure on cpu
                save_image(sample.view(64, 1, 28, 28),
                           '../results/sample_' + str(self.epoch) + '.png')
        
    def _train_epoch(self, train):
        self.epoch += 1
        model.train() # make sure train mode (e.g. dropout)
        train_loss = 0
        for i, (x, _) in enumerate(train):
            x = x.to(device) # data on GPU 
            self.optimizer.zero_grad() # reset all previous gradients
            x_hat, mu, log_var = model(x)
            loss = self.loss_function(x_hat, x, mu, log_var)
            loss.backward() # backpropagate (i.e store gradients)
            train_loss += loss.item() # compute loss (.item because only the value)
            self.optimizer.step() # take optimizing step (~gradient descent)

        print('Epoch: {} Train loss: {:.4f}'.format(
              self.epoch, train_loss / len(train.dataset)))
        
    def _test_epoch(self, test):
        model.eval() # make sure evaluate mode (e.g. dropout)
        test_loss = 0
        with torch.no_grad():  # stop gradients computation
            for i, (x, _) in enumerate(test):
                x = x.to(device)
                x_hat, mu, log_var = model(x)
                test_loss += loss_function(x_hat, x, mu, log_var).item()

        print('Test loss: {:.4f}'.format(test_loss/len(test.dataset)))


run all:

In [54]:
trainer = Trainer(model)

In [55]:
%%time
trainer(train_loader, test_loader)

Epoch: 1 Train loss: 543.5655
Test loss: 543.4254
Epoch: 2 Train loss: 543.4148
Test loss: 543.3198
Epoch: 3 Train loss: 543.3404
Test loss: 543.2177
Epoch: 4 Train loss: 543.2585
Test loss: 543.1621
Epoch: 5 Train loss: 543.2243
Test loss: 543.1746
Epoch: 6 Train loss: 543.2243
Test loss: 543.1531
Epoch: 7 Train loss: 543.2111
Test loss: 543.1439
Epoch: 8 Train loss: 543.2008
Test loss: 543.1417
Epoch: 9 Train loss: 543.1977
Test loss: 543.1113
Epoch: 10 Train loss: 543.1822
Test loss: 543.1046
CPU times: user 8min 17s, sys: 8.42 s, total: 8min 25s
Wall time: 15min 43s


If you're interested in more details I actually wrote a quick post about VAE in the last days: https://yanndubs.github.io/machine-learning-glossary/#variational-autoencoders 