In [18]:
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.distributions as distr

from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, transforms

### Dataset definition

In [13]:
batch_size = 100
num_workers = 4

# Transforms: turn every MNIST image into a tensor.
transform = transforms.ToTensor()

# Train
# download=True fetches the files into ./data when they are missing, so the
# cell also works on a fresh checkout; it is a no-op once the data is present.
train_mnist = datasets.MNIST("./data", train=True, download=True, transform=transform)
train_loader = DataLoader(train_mnist, shuffle=True, batch_size=batch_size, num_workers=num_workers)

# Test
# NOTE(review): the test loader is shuffled as well; kept as-is, but
# shuffle=False would make evaluation order reproducible.
test_mnist = datasets.MNIST("./data", train=False, download=True, transform=transform)
test_loader = DataLoader(test_mnist, shuffle=True, batch_size=batch_size, num_workers=num_workers)

### Network definition

In [19]:
class Encoder(nn.Module):
    """Two-layer fully connected encoder with a mean head and a spread head.

    Args:
        l0: width of the hidden layer (28*28 -> l0).
        l1: width of each of the two output heads (l0 -> l1).
    """

    def __init__(self, l0, l1):
        super().__init__()

        # Creation order determines how the RNG is consumed during weight init.
        self.fc0 = nn.Linear(28 * 28, l0)
        self.fc1_mu = nn.Linear(l0, l1)
        self.fc1_sigma = nn.Linear(l0, l1)

    def forward(self, x):
        """Return ``(mu, sigma)`` for a batch ``x``.

        ``x`` is flattened to ``(batch, -1)`` before the first layer. The
        ``fc1_sigma`` head is read as a log-variance, so the returned
        ``sigma`` is ``exp(0.5 * logvar)`` and therefore strictly positive.
        """
        flat = x.view(x.shape[0], -1)
        hidden = F.relu(self.fc0(flat))

        mu = self.fc1_mu(hidden)
        log_variance = self.fc1_sigma(hidden)
        return mu, (0.5 * log_variance).exp()
        
class Decoder(nn.Module):
    """Two-layer fully connected decoder producing 28*28 values in [0, 1].

    Args:
        l0: size of the input vector (l0 -> l1).
        l1: width of the hidden layer (l1 -> 28*28).
    """

    def __init__(self, l0, l1):
        super().__init__()

        self.fc0 = nn.Linear(l0, l1)
        self.fc1 = nn.Linear(l1, 28 * 28)

    def forward(self, x):
        """Map ``x`` through ReLU(fc0) and then sigmoid(fc1)."""
        hidden = F.relu(self.fc0(x))
        return F.sigmoid(self.fc1(hidden))
        
class VAE(nn.Module):
    """Variational auto-encoder built from ``Encoder`` and ``Decoder``.

    Args:
        l0: passed as the encoder's output size and the decoder's input size.
        l1: passed as the hidden width of both encoder and decoder.

    As wired below, the data path is 784 -> l1 -> l0 -> l1 -> 784, i.e. ``l0``
    is the size of the latent code and ``l1`` the hidden width.
    NOTE(review): with ``VAE(400, 20)`` this yields a 400-d latent and a
    20-unit hidden layer; confirm that this is the intended order.
    """

    def __init__(self, l0, l1):
        super(VAE, self).__init__()
        self.encoder = Encoder(l1, l0)
        self.decoder = Decoder(l0, l1)

    def forward(self, x):
        """Encode ``x``, draw a latent sample, and decode it.

        Returns:
            ``(y, z_mu, z_sigma)``: the decoder output and the encoder's mean
            and standard deviation for the batch.
        """
        # Generate deterministic mean and std
        z_mu, z_sigma = self.encoder(x)

        # Reparameterisation: z = mu + sigma * eps with eps ~ N(0, I), so the
        # sample stays differentiable with respect to mu and sigma.
        eps = torch.randn_like(z_sigma)
        z = z_mu + z_sigma * eps

        # Perform decoding, or sample from X
        y = self.decoder(z)

        return y, z_mu, z_sigma

### Loss definition

In [15]:
def sqr(x):
    """Return the element-wise square of ``x`` via ``torch.pow``."""
    return torch.pow(x, 2)

def elbo_kld(mu, sigma):
    """KL divergence KL(N(mu, sigma^2) || N(0, I)), summed over all elements.

    Closed form: ``-0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)``.
    The result is non-negative and is zero exactly when ``mu == 0`` and
    ``sigma == 1`` everywhere.

    Args:
        mu: tensor of means.
        sigma: tensor of standard deviations, same shape as ``mu``.

    Returns:
        A scalar tensor holding the summed divergence.
    """
    var = torch.pow(sigma, 2)
    # The leading minus sign matters: without it the function returns -KL and
    # a loss that adds this term would reward drifting away from the prior.
    kld = -0.5 * (1 + torch.log(var) - torch.pow(mu, 2) - var).sum()
    return kld

def ELBO(mu, sigma, output, input):
    """Training loss: summed reconstruction error plus the ``elbo_kld`` term.

    Args:
        mu: latent means, forwarded to ``elbo_kld``.
        sigma: latent standard deviations, forwarded to ``elbo_kld``.
        output: reconstruction with values in [0, 1].
        input: target with values in [0, 1]; expected to have the same shape
            as ``output`` (flatten images before calling).

    Returns:
        Scalar tensor ``binary_cross_entropy(output, input, sum) + elbo_kld(mu, sigma)``.
    """
    reconstruction_error = F.binary_cross_entropy(output, input, reduction="sum")
    kld = elbo_kld(mu, sigma)
    return reconstruction_error + kld

### Training

In [17]:
model = VAE(400, 20)
# Build the optimizer through the fully qualified module path. The result is
# stored under the name `optim`, which hides the `torch.optim` alias; going
# through `torch.optim` keeps this cell re-runnable after that rebinding.
optim = torch.optim.Adam(model.parameters())

In [None]:
epochs = 100

for epoch in range(epochs):
    for b_idx, batch in enumerate(train_loader):
        # Labels are not needed for the reconstruction objective.
        x, _ = batch

        # The model returns the reconstruction together with the latent mean and std.
        y, z_mu, z_sigma = model(x)

        # The reconstruction is flat (batch, 784); flatten the images to match.
        target = x.view(x.shape[0], -1)
        loss = ELBO(z_mu, z_sigma, y, target)

        # Standard optimisation step: clear old gradients, backpropagate, update.
        optim.zero_grad()
        loss.backward()
        optim.step()