In [None]:
import os

import einops as ein
from einops.layers.torch import Rearrange

import torch
from torch import nn, optim
from torch.nn import functional as F
from torchvision.utils import save_image
from torchvision import datasets, transforms

import matplotlib.pyplot as plt
import numpy as np

In [None]:
# Define VAE model
class VAE(nn.Module):
    """Convolutional variational autoencoder for single-channel 28x28 images.

    Encoder: three unpadded 5x5 convolutions (1 -> 4 -> 8 -> 16 channels)
    shrink a 28x28 image to 16x16. The result is flattened to a
    16*16*16 = 4096 vector, and two linear heads map it to the mean and
    log-variance of a diagonal Gaussian over the latent space.

    Decoder: mirrors the encoder with transposed convolutions
    (16x16 -> 28x28). It ends in a sigmoid, so outputs lie in (0, 1).

    Args:
        latent_dim: Size of the latent code ``z`` (default 2).
    """

    def __init__(self, latent_dim=2):
        super().__init__()

        # Flattened encoder output size: 16 channels x 16 x 16 pixels.
        self.depth = 4096
        self.latent_dim = latent_dim

        # NOTE: submodules are created in a fixed order (encoder, mu head,
        # log-var head, decoder). This keeps parameter registration and the
        # random-initialisation sequence unchanged.
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 4, 5),
            nn.ReLU(),
            nn.Conv2d(4, 8, 5),
            nn.ReLU(),
            nn.Conv2d(8, 16, 5),
            nn.ReLU(),
            Rearrange("batch a b c -> batch (a b c)"),
        )
        self.mu_encoder = nn.Linear(self.depth, self.latent_dim)
        self.log_var_encoder = nn.Linear(self.depth, self.latent_dim)

        self.decoder = nn.Sequential(
            nn.Linear(self.latent_dim, self.depth),
            nn.ReLU(),
            Rearrange("batch (a b c) -> batch a b c", a=16, b=16, c=16),
            nn.ConvTranspose2d(16, 8, 5),
            nn.ReLU(),
            nn.ConvTranspose2d(8, 4, 5),
            nn.ReLU(),
            nn.ConvTranspose2d(4, 1, 5),
            nn.Sigmoid(),
        )

    def encode(self, x):
        """Return ``(mu, logvar)`` of the approximate posterior for batch ``x``."""
        features = self.encoder(x)
        mu = self.mu_encoder(features)
        logvar = self.log_var_encoder(features)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        """Draw ``z = mu + eps * sigma`` with ``eps ~ N(0, I)``.

        Sampling this way keeps the draw differentiable with respect to
        ``mu`` and ``logvar``. It samples in every mode, including eval().
        """
        std = (0.5 * logvar).exp()
        noise = torch.randn_like(std)
        return mu + noise * std

    def decode(self, z):
        """Map latent codes ``z`` back to image space."""
        return self.decoder(z)

    def forward(self, x):
        """Encode, sample a latent, and decode.

        Calling ``model(data)`` dispatches here.

        Returns:
            Tuple ``(reconstruction, mu, logvar)``.
        """
        mu, logvar = self.encode(x)
        reconstruction = self.decode(self.reparameterize(mu, logvar))
        return reconstruction, mu, logvar

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

model = VAE(2).to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

batch_size2 = 128
log_interval2 = 10
epochs2 = 10

#torch.manual_seed(1) # args.seed -- uncomment for reproducible runs

# Worker processes and pinned memory only help when batches are copied to a
# GPU. Compare `device.type` rather than the device itself: a torch.device
# never compares equal to a plain string, so `device == "cuda"` is always
# False and these options were never enabled.
kwargs = {'num_workers': 4, 'pin_memory': True} if device.type == "cuda" else {} # args.cuda

# Get train and test data (downloaded into ../data on first use)
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=True, download=True,
                   transform=transforms.ToTensor()),
    batch_size=batch_size2, shuffle=True, **kwargs)
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False, transform=transforms.ToTensor()),
    batch_size=batch_size2, shuffle=True, **kwargs)


In [None]:
def loss_function(recon_x, x, mu, logvar):
    """Negative ELBO: reconstruction BCE plus KL divergence.

    Both terms are summed over every element and over the batch.

    Args:
        recon_x: Decoder output; must lie in [0, 1] for binary cross-entropy.
        x: Target images, same shape as ``recon_x``.
        mu: Posterior means from the encoder.
        logvar: Posterior log-variances from the encoder.

    Returns:
        Scalar tensor ``BCE + KL(q(z|x) || N(0, I))``.
    """
    reconstruction = F.binary_cross_entropy(recon_x, x, reduction='sum')
    # Closed-form KL divergence between N(mu, exp(logvar)) and N(0, I).
    kl_divergence = -0.5 * (1 + logvar - mu ** 2 - logvar.exp()).sum()
    return reconstruction + kl_divergence


def train(epoch, log_interval=None):
    """Run one training epoch and report the mean per-sample loss.

    Uses the notebook-level ``model``, ``optimizer``, ``train_loader``,
    ``device`` and ``loss_function``.

    Args:
        epoch: Epoch number; used only in the printed messages.
        log_interval: If a positive int, also print the current batch's
            per-sample loss every ``log_interval`` batches. ``None`` (the
            default) prints only the end-of-epoch summary, as before.

    Returns:
        The summed epoch loss divided by the number of training samples.
        This is the same value that is printed.
    """
    # Put layers such as dropout/batch-norm into training behaviour.
    # Gradient tracking is independent of this.
    model.train()
    train_loss = 0
    for batch_idx, (data, _) in enumerate(train_loader):
        data = data.to(device)
        optimizer.zero_grad()  # clear gradients accumulated by the previous step
        recon_batch, mu, logvar = model(data)
        loss = loss_function(recon_batch, data, mu, logvar)
        loss.backward()  # backpropagation: compute gradients of the loss
        train_loss += loss.item()
        optimizer.step()  # update the parameters using those gradients
        if log_interval and batch_idx % log_interval == 0:
            print('Train Epoch: {} [{}/{}]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                loss.item() / len(data)))

    avg_loss = train_loss / len(train_loader.dataset)
    print('====> Epoch: {} Average loss: {:.4f}'.format(epoch, avg_loss))
    return avg_loss

def test(epoch):
    """Evaluate the model on the test set and report the mean per-sample loss.

    Uses the notebook-level ``model``, ``test_loader``, ``device`` and
    ``loss_function``.

    Args:
        epoch: Unused. It is kept so ``test`` has the same call shape as
            ``train``.

    Returns:
        The summed test loss divided by the number of test samples. This is
        the same value that is printed. The value is stochastic, because the
        model's forward pass samples ``z`` even in eval mode.
    """
    model.eval()
    test_loss = 0
    with torch.no_grad():  # no gradients needed for evaluation; saves memory and time
        for data, _ in test_loader:
            data = data.to(device)
            recon_batch, mu, logvar = model(data)
            test_loss += loss_function(recon_batch, data, mu, logvar).item()

    test_loss /= len(test_loader.dataset)
    print('====> Test set loss: {:.4f}'.format(test_loss))
    return test_loss


In [None]:
# Regular grid of 2-D latent points used to visualise the decoder.
# It has num_rows x num_rows points spanning [-8, 8] on both axes.
# Point k has x = axis[k % num_rows] (increasing left to right) and
# y = axis[num_rows - 1 - k // num_rows] (decreasing top to bottom).
# This lays the decoded digits out like a Cartesian plane.
num_rows = 20
grid_axis = torch.linspace(-8., 8., num_rows)
xs = grid_axis.repeat(num_rows)
ys = grid_axis.flip(0).repeat_interleave(num_rows)
art_nums = torch.stack((xs, ys)).t().to(device)
print(art_nums.size())

In [None]:
# Train and evaluate for `epochs2` epochs. When the latent space is 2-D, also
# decode the fixed `art_nums` grid after each epoch and save it as an image.
# save_image does not create missing directories, so make sure the output
# directory exists; otherwise the first save fails on a fresh checkout.
os.makedirs('results', exist_ok=True)
for epoch in range(1, epochs2 + 1):
    train(epoch)
    test(epoch)
    if model.latent_dim == 2:
        with torch.no_grad():
            sample = model.decode(art_nums).cpu()
            save_image(sample.view(-1, 1, 28, 28),
                       'results/sample_' + str(epoch) + '.png', nrow=num_rows)