# Variational Autoencoder with Normalizing Flows

In [None]:
# Import packages
import numpy as np
import torch
from torch import nn, optim
from torchvision import datasets, transforms

from tqdm import tqdm
from matplotlib import pyplot as plt

import normflows as nf

In [None]:
# Hyperparameters
# Seed torch's global RNG before anything random is created.
torch.manual_seed(0)

# Batching, sampling and schedule
batch_size = 64
num_samples = 32
n_epochs = 15

# Flow configuration: which flow family to stack and how many layers of it
flow_type = 'Planar'
n_flows = 40

# Layer widths. Images are 28 x 28, flattened to 784 inputs. The encoder's
# last layer is twice the bottleneck width (presumably mean and scale
# parameters for each latent dimension -- confirm against NNDiagGaussian).
n_bottleneck = 40
hidden_units_encoder = np.array([28 * 28, 512, 256, 2 * n_bottleneck])
hidden_units_decoder = np.array([n_bottleneck, 256, 512, 28 * 28])

# Use the GPU only when it is both requested and actually available.
enable_cuda = True
if enable_cuda and torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [None]:
# Get dataloader
class BinaryTransform:
    """Callable transform that binarises a tensor against a fixed threshold."""

    def __init__(self, thresh=0.5):
        # Elements strictly greater than this value map to 1, all others to 0.
        self.thresh = thresh

    def __call__(self, x):
        """Return a tensor with x's dtype that is 1 where x > thresh, else 0."""
        above_threshold = x > self.thresh
        # The comparison yields a boolean mask; cast it back to the input dtype.
        return above_threshold.to(dtype=x.dtype)
# Convert each image to a tensor, then threshold it with BinaryTransform.
transform = transforms.Compose([
    transforms.ToTensor(),
    BinaryTransform(),
])

# MNIST train / test splits, downloaded into ../datasets when missing.
mnist_train = datasets.MNIST(
    '../datasets', train=True, download=True, transform=transform)
mnist_test = datasets.MNIST(
    '../datasets', train=False, download=True, transform=transform)

# Both loaders shuffle and use the same batch size.
train_loader = torch.utils.data.DataLoader(
    mnist_train, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(
    mnist_test, batch_size=batch_size, shuffle=True)

In [None]:
# Set up model
# Prior over the latent space: zero mean, identity covariance, on `device`.
prior = torch.distributions.MultivariateNormal(
    torch.zeros(n_bottleneck, device=device),
    torch.eye(n_bottleneck, device=device),
)

# Encoder and decoder are MLPs wrapped in normflows distribution classes.
encoder_nn = nf.nets.MLP(hidden_units_encoder)
encoder = nf.distributions.NNDiagGaussian(encoder_nn)
decoder_nn = nf.nets.MLP(hidden_units_decoder)
decoder = nf.distributions.NNBernoulliDecoder(decoder_nn)

# Flow families that are built the same way: one layer per step, each
# constructed from the latent shape only.
shape_only_flows = {
    'Planar': nf.flows.Planar,
    'Radial': nf.flows.Radial,
}

if flow_type in shape_only_flows:
    flow_cls = shape_only_flows[flow_type]
    flows = [flow_cls((n_bottleneck,)) for _ in range(n_flows)]
elif flow_type == 'RealNVP':
    # Alternating 0/1 mask over the latent dimensions, starting with 0.
    b = torch.tensor([idx % 2 for idx in range(n_bottleneck)])
    flows = []
    for i in range(n_flows):
        s = nf.nets.MLP([n_bottleneck, n_bottleneck])
        t = nf.nets.MLP([n_bottleneck, n_bottleneck])
        # Even layers use the mask as is, odd layers use its complement.
        mask = b if i % 2 == 0 else 1 - b
        flows.append(nf.flows.MaskedAffineFlow(mask, t, s))
else:
    raise NotImplementedError

# Assemble the flow-based VAE and move its parameters to the chosen device.
nfm = nf.NormalizingFlowVAE(prior, encoder, flows, decoder)
nfm.to(device)

In [None]:
# Train model
# NOTE: log_intv is not read anywhere in this cell; it is kept in case other
# cells rely on it.
log_intv = 100
optimizer = optim.Adam(nfm.parameters(), lr=1e-4, weight_decay=1e-4)
for epoch in range(n_epochs):
    # Iterating over the tqdm wrapper already advances the bar once per batch,
    # so there is no manual progressbar.update() in the loop body (an extra
    # call would count every batch twice and overshoot `total`).
    progressbar = tqdm(enumerate(train_loader), total=len(train_loader))
    for batch_n, (x, _labels) in progressbar:
        # The second element of each batch (the targets) is not used here.
        x = x.to(device)
        optimizer.zero_grad()
        # Flatten every image to a 784-vector before passing it to the model.
        z, log_q, log_p = nfm(x.view(-1, 28 ** 2), num_samples)
        mean_log_q = torch.mean(log_q)
        mean_log_p = torch.mean(log_p)
        # Minimise mean(log_q) - mean(log_p); presumably a negative-ELBO
        # estimate -- confirm against NormalizingFlowVAE's return values.
        loss = mean_log_q - mean_log_p
        loss.backward()
        optimizer.step()
    progressbar.close()
    # Report the loss of the last batch processed in this epoch.
    print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_n * len(x), len(train_loader.dataset),
                       100. * batch_n / len(train_loader),
                       loss.item()))

In [None]:
# Draw one latent vector from a standard normal and pass it through the decoder.
z_sample = torch.randn((1, n_bottleneck), device=device)
x_out = nfm.decoder(z_sample)

# Detach from the graph, bring to host memory and reshape to a 28 x 28 image.
x_np = x_out.detach().cpu().view(28, 28).numpy()

plt.imshow(x_np)
plt.show()