# Very Deep Variational Autoencoder

### Code taken and adapted from: https://github.com/vvvm23/vdvae

In [None]:
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

import torchvision
import torchvision.transforms as transforms
from torchvision.utils import save_image

import numpy as np
import itertools
import datetime
from tqdm.notebook import tqdm, trange

In [None]:
import locale
import os

# Force the reported preferred encoding to UTF-8.
# NOTE(review): presumably a workaround for hosted notebooks (e.g. Colab) whose
# locale is not UTF-8 — confirm it is still needed in your environment.
# The replacement keeps the optional `do_setlocale` argument of the real
# function, so calls such as `locale.getpreferredencoding(False)` keep working.
locale.getpreferredencoding = lambda do_setlocale=True: "UTF-8"

# Output directories for saved sample/reconstruction images and checkpoints.
# `exist_ok=True` makes re-running this cell safe (a plain `mkdir` errors when the
# directory already exists) and works outside IPython as well.
os.makedirs("./imgs/", exist_ok=True)
os.makedirs("./saved_checkpoints/", exist_ok=True)

### Checkpoint | Helper | HyperParameters

In [None]:
class Checkpoint(dict):
    """Dictionary whose entries can also be read, written and deleted as attributes.

    Reading a missing (non-dunder) attribute returns ``None`` instead of raising,
    so optional entries can be probed with ``ckpt.name``. Missing dunder names
    (``__foo__``) raise ``AttributeError`` as usual. Otherwise ``hasattr`` and
    ``getattr(obj, name, default)`` probes made by ``copy``, ``pickle`` or NumPy
    would receive ``None`` and mistake it for an implementation.
    """

    def __getattr__(self, attr):
        # Only reached when normal attribute lookup fails.
        try:
            return self[attr]
        except KeyError:
            if attr.startswith("__") and attr.endswith("__"):
                raise AttributeError(attr) from None
            return None

    def __setattr__(self, attr, value):
        self[attr] = value

    def __delattr__(self, attr):
        try:
            del self[attr]
        except KeyError:
            raise AttributeError(attr) from None

In [None]:
def info(s):
    """Print *s* as a green informational line."""
    print(f"\33[92m> {s}\33[0m")


def error(s):
    """Print *s* as a red error line."""
    print(f"\33[31m! {s}\33[0m")


def warning(s):
    """Print *s* as a blue warning line."""
    print(f"\33[94m$ {s}\33[0m")


def get_device(try_cuda):
    """Choose the torch device to run on.

    Returns CPU when ``try_cuda`` compares equal to ``False``. Otherwise returns
    CUDA if it is available, or falls back to CPU (printing two error lines)
    when it is not.
    """
    # Equality (not truthiness) on purpose: an unset hyperparameter (None)
    # still attempts CUDA.
    if try_cuda == False:  # noqa: E712
        info("CUDA disabled by hyperparameters.")
        return torch.device('cpu')

    if not torch.cuda.is_available():
        for message in ("CUDA is unavailable but selected in hyperparameters.",
                        "Falling back to default device."):
            error(message)
        return torch.device('cpu')

    info("CUDA is available.")
    return torch.device('cuda')

In [None]:
class Hyperparameters(dict):
    """Dictionary of hyperparameters that also supports attribute access.

    ``hps.lr`` reads ``hps['lr']``, and a missing (non-dunder) name reads as
    ``None``, so optional settings can be probed directly. Missing dunder names
    (``__foo__``) raise ``AttributeError`` as usual. Otherwise ``hasattr`` and
    ``getattr(obj, name, default)`` protocol probes (``copy``, ``pickle``, NumPy)
    would receive ``None`` and mistake it for an implementation.
    """

    def __getattr__(self, attr):
        # Only reached when normal attribute lookup fails.
        try:
            return self[attr]
        except KeyError:
            if attr.startswith("__") and attr.endswith("__"):
                raise AttributeError(attr) from None
            return None

    def __setattr__(self, attr, value):
        self[attr] = value

    def __delattr__(self, attr):
        try:
            del self[attr]
        except KeyError:
            raise AttributeError(attr) from None

# Task-specific hyperparameters, merged over the shared run settings by
# `get_parameters`. h_width is the residual-stream width, m_width the bottleneck
# width inside residual/top-down blocks.
_TASK_HPS = {
    'cifar10': {
        'dataset': 'cifar10',
        'batch_size': 16,

        'in_channels': 3,
        'h_width': 64,
        'm_width': 32,
        'z_dim': 16,
        'nb_blocks': 4,
        'nb_res_blocks': 3,
        'scale_rate': 2,

        'nb_iterations': 110_000,
        'lr': 2e-4,
        'decay': 1e-2,
    },
    'stl10': {
        'dataset': 'stl10',
        'batch_size': 16,

        'in_channels': 3,
        'h_width': 128,
        'm_width': 64,
        'z_dim': 32,
        'nb_blocks': 5,
        'nb_res_blocks': 3,
        'scale_rate': 2,

        'nb_iterations': 110_000,
        'lr': 2e-4,
        'decay': 1e-2,
    },
    'mnist': {
        'dataset': 'mnist',
        'batch_size': 32,

        'in_channels': 1,
        'h_width': 32,
        'm_width': 16,
        'z_dim': 8,
        'nb_blocks': 2,
        'nb_res_blocks': 3,
        'scale_rate': 2,

        'nb_iterations': 100_000,
        'lr': 2e-4,
        'decay': 1e-2,
    },
}


def get_parameters(task):
    """Build the hyperparameter set for a named experiment.

    Args:
        task: one of ``'cifar10'``, ``'stl10'`` or ``'mnist'``.

    Returns:
        A ``Hyperparameters`` mapping. It holds the shared run settings
        (``cuda``, ``checkpoint`` interval in epochs, ``tqdm``) followed by the
        task's dataset, architecture and optimisation settings.

    Raises:
        ValueError: if ``task`` is not recognised (raised instead of calling
            ``exit()``, which would end the process/kernel with a success status).
    """
    try:
        task_hps = _TASK_HPS[task]
    except KeyError:
        raise ValueError(
            f"Unrecognized HPS task '{task}'; expected one of {sorted(_TASK_HPS)}"
        ) from None

    HPS = Hyperparameters()
    HPS['cuda'] = True
    HPS['checkpoint'] = 5
    HPS['tqdm'] = True
    HPS.update(task_hps)
    return HPS

### VD-VAE

Helper functions and classes

In [None]:
"""
    Encoder Components:
        - Encoder, contains all the EncoderBlocks and manages data flow through them.
        - EncoderBlock, contains sub-blocks of residual units and a pooling layer.
        - ResidualBlock, contains a block of residual connections, as described in the paper (1x1,3x3,3x3,1x1)
            - We could slightly adapt, and make it a ReZero connection. Needs some testing.

    Decoder Components:
        - Decoder, contains all DecoderBlocks and manages data flow through them.
        - DecoderBlock, contains sub-blocks of top-down units and an unpool layer.
        - TopDownBlock, implements the topdown block from the original paper.

    All is encapsulated in the main VAE class.

"""

class ConvBuilder:
    """Factory for the 2-D convolutions used throughout the model.

    All builders are static methods. They can be called on the class
    (``ConvBuilder.b1x1(...)``) or on an instance.
    """

    @staticmethod
    def _bconv(in_dim, out_dim, kernel_size, stride, padding):
        """Return an ``nn.Conv2d`` with the given geometry (bias and init are PyTorch defaults)."""
        return nn.Conv2d(in_dim, out_dim, kernel_size, stride=stride, padding=padding)

    @staticmethod
    def b1x1(in_dim, out_dim):
        """1x1 convolution, stride 1, no padding: mixes channels, keeps spatial size."""
        return ConvBuilder._bconv(in_dim, out_dim, 1, 1, 0)

    @staticmethod
    def b3x3(in_dim, out_dim):
        """3x3 convolution, stride 1, padding 1: keeps spatial size."""
        return ConvBuilder._bconv(in_dim, out_dim, 3, 1, 1)

"""
    Diagonal Gaussian Distribution and loss.
    Taken directly from OpenAI implementation
    Decorators means these functions will be compiled as TorchScript
"""
@torch.jit.script
def gaussian_analytical_kl(mu1, mu2, logsigma1, logsigma2):
    # Element-wise KL( N(mu1, sigma1^2) || N(mu2, sigma2^2) ) with sigmas given as logs:
    #   log(sigma2 / sigma1) + (sigma1^2 + (mu1 - mu2)^2) / (2 * sigma2^2) - 1/2
    mean_gap_sq = (mu1 - mu2) ** 2
    numerator = logsigma1.exp() ** 2 + mean_gap_sq
    return -0.5 + logsigma2 - logsigma1 + 0.5 * numerator / (logsigma2.exp() ** 2)

@torch.jit.script
def draw_gaussian_diag_samples(mu, logsigma):
    # Reparameterised sample z = mu + sigma * eps with eps ~ N(0, 1) element-wise,
    # so gradients flow into mu and logsigma.
    noise = torch.empty_like(mu).normal_(0., 1.)
    return mu + logsigma.exp() * noise

"""
    Helper module to call super().__init__() for us
"""
class HelperModule(nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__()
        self.build(*args, **kwargs)

    def build(self, *args, **kwargs):
        raise NotImplementedError

Encoder Components

In [None]:
class ResidualBlock(HelperModule):
    """Bottleneck residual unit computing ``x + gate * branch(x)``.

    ``branch`` applies four convolutions (1x1, 3x3, 3x3, 1x1), each preceded by a
    GELU. ``hidden_width`` is the bottleneck width; the output keeps ``in_width``
    channels. With ``rezero=True`` the branch is scaled by a learnable scalar that
    starts at 0, so the block initially acts as the identity. Otherwise the scale
    is the constant 1.0.
    """

    def build(self, in_width, hidden_width, rezero=False):
        squeeze = ConvBuilder.b1x1(in_width, hidden_width)
        mid_a = ConvBuilder.b3x3(hidden_width, hidden_width)
        mid_b = ConvBuilder.b3x3(hidden_width, hidden_width)
        expand = ConvBuilder.b1x1(hidden_width, in_width)
        # The attribute name `conv` matters: state_dict keys use it, and the
        # Encoder/Decoder rescale `conv[-1]` at construction.
        self.conv = nn.ModuleList([squeeze, mid_a, mid_b, expand])
        self.gate = nn.Parameter(torch.tensor(0.0)) if rezero else 1.0

    def forward(self, x):
        branch = x
        for conv in self.conv:
            branch = conv(F.gelu(branch))
        return x + self.gate * branch

class EncoderBlock(HelperModule):
    """A stack of ``nb_r_blocks`` residual units followed by average-pool downsampling.

    ``forward`` returns ``(pooled, features)``. ``pooled`` feeds the next encoder
    block; ``features`` holds the pre-pool activations consumed by the top-down
    (decoder) path.
    """

    def build(self, in_dim, middle_width, nb_r_blocks, downscale_rate):
        self.downscale_rate = downscale_rate
        units = []
        for _ in range(nb_r_blocks):
            units.append(ResidualBlock(in_dim, middle_width))
        self.res_blocks = nn.ModuleList(units)

    def forward(self, x):
        features = x
        for unit in self.res_blocks:
            features = unit(features)
        rate = self.downscale_rate
        pooled = F.avg_pool2d(features, kernel_size=rate, stride=rate)
        return pooled, features

class Encoder(HelperModule):
    """Bottom-up encoder: a 3x3 input conv followed by ``nb_encoder_blocks`` EncoderBlocks.

    Every block except the last downsamples by ``downscale_rate``; the last keeps
    its resolution. ``forward`` returns a list: the input-conv output, then each
    block's pre-pool activations, from finest to coarsest.
    """

    def build(self, in_dim, hidden_width, middle_width, nb_encoder_blocks, nb_res_blocks=3, downscale_rate=2):
        self.in_conv = ConvBuilder.b3x3(in_dim, hidden_width)
        last = nb_encoder_blocks - 1
        self.enc_blocks = nn.ModuleList([
            EncoderBlock(hidden_width, middle_width, nb_res_blocks,
                         downscale_rate if i != last else 1)
            for i in range(nb_encoder_blocks)
        ])

        # Shrink the last conv of every residual branch by 1/sqrt(total residual
        # units). NOTE(review): presumably the VDVAE init trick that keeps the sum
        # of residual contributions small at initialisation.
        scale = np.sqrt(1 / (nb_encoder_blocks * nb_res_blocks))
        for enc_block in self.enc_blocks:
            for res_block in enc_block.res_blocks:
                res_block.conv[-1].weight.data *= scale

    def forward(self, x):
        h = self.in_conv(x)
        activations = [h]
        for enc_block in self.enc_blocks:
            h, skip = enc_block(h)
            activations.append(skip)
        return activations

Decoder Components

In [None]:
class Block(HelperModule):
    """Bottleneck conv stack without a skip connection.

    Four convolutions (1x1, 3x3, 3x3, 1x1), each preceded by a GELU, map
    ``in_width`` channels to ``out_width`` through a ``hidden_width`` bottleneck.
    """

    def build(self, in_width, hidden_width, out_width):
        specs = (
            (ConvBuilder.b1x1, in_width, hidden_width),
            (ConvBuilder.b3x3, hidden_width, hidden_width),
            (ConvBuilder.b3x3, hidden_width, hidden_width),
            (ConvBuilder.b1x1, hidden_width, out_width),
        )
        self.conv = nn.ModuleList([make(c_in, c_out) for make, c_in, c_out in specs])

    def forward(self, x):
        h = x
        for conv in self.conv:
            h = conv(F.gelu(h))
        return h

class TopDownBlock(HelperModule):
    """One top-down unit of the decoder.

    ``prior`` maps the decoder state ``x`` to the prior mean and log-std of ``z``
    plus a residual update for ``x``. ``cat_conv`` maps ``[x, a]`` (decoder state
    concatenated with encoder activations) to the posterior mean and log-std.
    A sampled ``z`` is projected to ``in_width`` channels by ``z_conv``, added to
    ``x`` and refined by ``out_res``.
    """

    def build(self, in_width, middle_width, z_dim):
        # Posterior q(z | x, a): z_dim means followed by z_dim log-stds.
        self.cat_conv = Block(in_width*2, middle_width, z_dim*2)
        # Prior p(z | x): z_dim means, z_dim log-stds, then an in_width residual for x.
        self.prior = Block(in_width, middle_width, z_dim*2 + in_width)
        self.out_res = ResidualBlock(in_width, middle_width)
        self.z_conv = ConvBuilder.b1x1(z_dim, in_width)
        self.z_dim = z_dim

    def _prior_split(self, x):
        """Run the prior net and split its channels into (mean, log-std, x-residual)."""
        feats = self.prior(x)
        d = self.z_dim
        return feats[:, :d], feats[:, d:2 * d], feats[:, 2 * d:]

    def _merge_z(self, x, z):
        """Inject latent ``z`` into ``x`` and apply the output residual block."""
        return self.out_res(x + self.z_conv(z))

    def forward(self, x, a):
        q_mean, q_logstd = self.cat_conv(torch.cat([x, a], dim=1)).chunk(2, dim=1)
        p_mean, p_logstd, x_res = self._prior_split(x)
        x = x + x_res

        z = draw_gaussian_diag_samples(q_mean, q_logstd)
        kl = gaussian_analytical_kl(q_mean, p_mean, q_logstd, p_logstd)
        return self._merge_z(x, z), kl

    def sample(self, x):
        p_mean, p_logstd, x_res = self._prior_split(x)
        x = x + x_res
        z = draw_gaussian_diag_samples(p_mean, p_logstd)
        return self._merge_z(x, z)

class DecoderBlock(HelperModule):
    """Upsampling by ``upscale_rate`` followed by ``nb_td_blocks`` TopDownBlocks.

    Every top-down unit in the block conditions on the same encoder activations ``a``.
    """

    def build(self, in_dim, middle_width, z_dim, nb_td_blocks, upscale_rate):
        self.upscale_rate = upscale_rate
        self.td_blocks = nn.ModuleList(
            TopDownBlock(in_dim, middle_width, z_dim) for _ in range(nb_td_blocks)
        )

    def _upsample(self, x):
        # F.interpolate's default mode is nearest-neighbour.
        return F.interpolate(x, scale_factor=self.upscale_rate)

    def forward(self, x, a):
        x = self._upsample(x)
        kls = []
        for td in self.td_blocks:
            x, kl = td(x, a)
            kls.append(kl)
        return x, kls

    def sample(self, x):
        x = self._upsample(x)
        for td in self.td_blocks:
            x = td.sample(x)
        return x

class Decoder(HelperModule):
    """Top-down decoder: a stack of DecoderBlocks followed by a 3x3 output conv.

    Block 0 runs at the coarsest encoder resolution (upscale rate 1). Every later
    block upsamples by ``upscale_rate``.
    """

    def build(self, in_dim, middle_width, out_dim, z_dim, nb_decoder_blocks, nb_td_blocks=3, upscale_rate=2):
        self.dec_blocks = nn.ModuleList([
            DecoderBlock(in_dim, middle_width, z_dim, nb_td_blocks, 1 if i == 0 else upscale_rate)
            for i in range(nb_decoder_blocks)])
        self.in_dim = in_dim
        self.out_conv = ConvBuilder.b3x3(in_dim, out_dim)
        # (H, W) of the coarsest activations seen by the most recent forward pass.
        # `sample` uses it to start at the matching resolution; None until forward runs.
        self.top_shape = None

        # Down-scale each top-down unit's output convs by 1/sqrt(#units) at init
        # (same scheme as the Encoder's residual scaling).
        scale = np.sqrt(1 / (nb_decoder_blocks * nb_td_blocks))
        for dec_block in self.dec_blocks:
            for td_block in dec_block.td_blocks:
                td_block.z_conv.weight.data *= scale
                td_block.out_res.conv[-1].weight.data *= scale

    def forward(self, activations):
        """Decode encoder ``activations`` (finest first) into ``(output, decoder_kl)``.

        ``decoder_kl`` lists the KL tensor of every TopDownBlock, coarsest level first.
        """
        activations = activations[::-1]  # coarsest level first
        top = activations[0]
        self.top_shape = tuple(top.shape[-2:])
        x = torch.zeros_like(top)
        decoder_kl = []
        for i, block in enumerate(self.dec_blocks):
            x, block_kl = block(x, activations[i])
            decoder_kl.extend(block_kl)
        return self.out_conv(x), decoder_kl

    def sample(self, nb_samples, resolution=None):
        """Draw ``nb_samples`` outputs from the prior.

        Args:
            nb_samples: number of samples (batch size).
            resolution: optional ``(H, W)`` or int size of the coarsest level.
                Defaults to the size seen by the last ``forward`` call, or 4x4 if
                ``forward`` has not run yet.

        The starting tensor is created on the device/dtype of the model's weights
        instead of a hard-coded CUDA device.
        """
        if resolution is None:
            resolution = self.top_shape if self.top_shape is not None else (4, 4)
        if isinstance(resolution, int):
            resolution = (resolution, resolution)
        ref = self.out_conv.weight
        x = torch.zeros(nb_samples, self.in_dim, *resolution, device=ref.device, dtype=ref.dtype)
        for block in self.dec_blocks:
            x = block.sample(x)
        return self.out_conv(x)

VDVAE

In [None]:
class VAE(HelperModule):
    """Hierarchical VAE: a bottom-up Encoder feeding a top-down Decoder.

    Both halves use ``nb_blocks`` resolution levels with ``nb_res_blocks`` units per
    level and a resampling factor of ``scale_rate`` between levels.
    """

    def build(self, in_dim, hidden_width, middle_width, z_dim, nb_blocks=4, nb_res_blocks=3, scale_rate=2):
        self.encoder = Encoder(
            in_dim, hidden_width, middle_width, nb_blocks,
            nb_res_blocks=nb_res_blocks, downscale_rate=scale_rate,
        )
        self.decoder = Decoder(
            hidden_width, middle_width, in_dim, z_dim, nb_blocks,
            nb_td_blocks=nb_res_blocks, upscale_rate=scale_rate,
        )

    def forward(self, x):
        """Return ``(reconstruction, list_of_kl_tensors)`` for input ``x``."""
        return self.decoder(self.encoder(x))

    def sample(self, nb_samples):
        """Draw ``nb_samples`` outputs from the prior via the decoder."""
        return self.decoder.sample(nb_samples)

### Training and Evaluation

In [None]:
# Experiment configuration read by the cells below ('cifar10', 'stl10' or 'mnist').
HPS = get_parameters(task='mnist')


def load_dataset(dataset, batch_size):
    """Create shuffled train/test DataLoaders for ``'cifar10'``, ``'stl10'`` or ``'mnist'``.

    Images are converted to tensors and normalised to [-1, 1], which matches the
    ``value_range=(-1, 1)`` used when saving samples. This includes single-channel
    MNIST, whose 3-channel Normalize could not be applied. A random rotation of up
    to 5 degrees augments the training split only; the test split is left
    un-augmented so evaluation measures the real data.

    Raises:
        ValueError: if ``dataset`` is not a supported name.
    """
    supported = ('cifar10', 'stl10', 'mnist')
    if dataset not in supported:
        raise ValueError(f"Unrecognized dataset '{dataset}'! Expected one of {supported}.")

    nb_channels = 1 if dataset == 'mnist' else 3
    normalize = transforms.Normalize((0.5,) * nb_channels, (0.5,) * nb_channels)
    train_transforms = transforms.Compose([
        transforms.RandomRotation(5),
        transforms.ToTensor(),
        normalize,
    ])
    test_transforms = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])

    if dataset == 'cifar10':
        train_dataset = torchvision.datasets.CIFAR10('data', train=True, download=True, transform=train_transforms)
        test_dataset = torchvision.datasets.CIFAR10('data', train=False, download=True, transform=test_transforms)
    elif dataset == 'stl10':
        train_dataset = torchvision.datasets.STL10('data', split='train+unlabeled', download=True, transform=train_transforms)
        test_dataset = torchvision.datasets.STL10('data', split='test', download=True, transform=test_transforms)
    else:  # 'mnist'
        train_dataset = torchvision.datasets.MNIST('data', train=True, download=True, transform=train_transforms)
        test_dataset = torchvision.datasets.MNIST('data', train=False, download=True, transform=test_transforms)

    train_loader = torch.utils.data.DataLoader(train_dataset, shuffle=True, batch_size=batch_size)
    test_loader = torch.utils.data.DataLoader(test_dataset, shuffle=True, batch_size=batch_size)

    return train_loader, test_loader

def vae_loss(x, model, crit, device=None, recon_weight=10.0):
    """Run ``model`` on a batch and compute the training objective.

    Args:
        x: input batch, moved to ``device`` before the forward pass.
        model: module returning ``(reconstruction, list_of_kl_tensors)``; each KL
            tensor is summed over dims (1, 2, 3).
        crit: element-wise criterion (``reduction='none'``) called as
            ``crit(prediction, target)``.
        device: where to run. Defaults to the device of ``model``'s first
            parameter (or ``x``'s device for a parameterless model) instead of
            relying on a global ``device`` variable.
        recon_weight: weight of the reconstruction term relative to the KL term.

    Returns:
        ``(y, loss, rl, kl)``: the reconstruction, the scalar loss
        ``mean(kl_per_dim + recon_weight * rl)``, and the batch means of the
        per-sample reconstruction error and per-input-dimension KL.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else x.device
    x = x.to(device)
    y, decoder_kl = model(x)
    # PyTorch losses take (input, target): prediction first, data second.
    rl = crit(y, x).mean(dim=(1, 2, 3))  # per-sample reconstruction error
    rpp = torch.zeros_like(rl)
    for kl in decoder_kl:
        rpp += kl.sum(dim=(1, 2, 3))  # total KL per sample
    rpp /= np.prod(x.shape[1:])  # normalise KL per input dimension
    elbo = (rpp + rl * recon_weight).mean()
    return y, elbo, rl.mean(), rpp.mean()

def train(model, loader, optim, crit, device):
    """Run one training epoch over ``loader``.

    Returns the mean (loss, reconstruction loss, KL) over batches as Python floats.
    ``device`` is currently unused here; ``vae_loss`` moves each batch.
    """
    model.train()
    total_loss = r_loss = kl_loss = 0.0
    for x, _ in tqdm(loader):
        optim.zero_grad()
        _, elbo, rl, kl = vae_loss(x, model, crit)
        elbo.backward()
        optim.step()

        # .item() detaches the running totals from autograd. Summing the loss
        # tensors directly would chain every step's graph onto the accumulator
        # for the whole epoch.
        total_loss += elbo.item()
        r_loss += rl.item()
        kl_loss += kl.item()
    n = len(loader)
    return total_loss / n, r_loss / n, kl_loss / n

def evaluate(model, loader, optim, crit, device, img_id=None):
    """Evaluate on ``loader`` and save prior samples / reconstructions as images.

    Always saves 4 prior samples to ``imgs/vdvae-sample-{img_id}.png``. When
    ``img_id`` is given, also saves the first batch's reconstructions to
    ``imgs/eval-recon-{img_id}.png``. ``optim`` and ``device`` are unused and
    kept for call compatibility.

    Returns the mean (loss, reconstruction loss, KL) over batches as Python floats.
    """
    model.eval()
    total_loss = r_loss = kl_loss = 0.0
    with torch.no_grad():  # sampling and evaluation never need gradients
        sample = model.sample(4)
        save_image(sample, f"imgs/vdvae-sample-{img_id}.png", normalize=True, value_range=(-1, 1))

        for i, (x, _) in enumerate(tqdm(loader)):
            y, elbo, rl, kl = vae_loss(x, model, crit)

            total_loss += elbo.item()
            r_loss += rl.item()
            kl_loss += kl.item()

            if img_id is not None and i == 0:
                save_image(y, f"imgs/eval-recon-{img_id}.png", normalize=True, value_range=(-1, 1))

    n = len(loader)
    return total_loss / n, r_loss / n, kl_loss / n

In [None]:
# torch.autograd.set_detect_anomaly(True)
# Resolve the device from HPS.cuda and build the train/test loaders for HPS.dataset.
device = get_device(HPS.cuda)
train_loader, test_loader = load_dataset(HPS.dataset, HPS.batch_size)

# Build the VDVAE from the selected hyperparameters and move it to the device.
model = VAE(HPS.in_channels,
            HPS.h_width,
            HPS.m_width,
            HPS.z_dim,
            nb_blocks=HPS.nb_blocks,
            nb_res_blocks=HPS.nb_res_blocks,
            scale_rate=HPS.scale_rate
).to(device)

info(f"Number of trainable parameters {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
# NOTE(review): Adam's `weight_decay` is a coupled L2 penalty; AdamW applies
# decoupled weight decay — confirm which one HPS.decay is meant for.
optim = torch.optim.Adam(model.parameters(), lr=HPS.lr, weight_decay=HPS.decay)
# reduction='none' because vae_loss averages the per-element errors itself.
crit = torch.nn.MSELoss(reduction='none')

# Timestamp prefix shared by every checkpoint file written by this run.
save_id = str(datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"))

# Train epoch by epoch (itertools.count() is unbounded) until the cumulative
# number of optimiser steps exceeds HPS.nb_iterations. The limit is only checked
# at the end of an epoch, so training can overshoot by up to one epoch.
# NOTE(review): HPS.tqdm is not consulted; progress bars are always shown.
nb_iterations = 0
for ei in tqdm(itertools.count(), desc="Epoch"):
    train_loss, r_loss, kl_loss = train(model, train_loader, optim, crit, device)
    nb_iterations += len(train_loader)  # one optimiser step per training batch

    info(f"training, epoch {ei+1} \t iter: {nb_iterations} \t loss: {train_loss} \t r_loss {r_loss} \t kl_loss {kl_loss}")
    # Also writes sample/reconstruction images tagged with the zero-padded epoch index.
    eval_loss, r_loss, kl_loss = evaluate(model, test_loader, optim, crit, device, img_id=str(ei).zfill(4))
    info(f"evaluate, epoch {ei+1} \t iter: {nb_iterations} \t loss: {eval_loss} \t r_loss {r_loss} \t kl_loss {kl_loss}")

    # Periodic checkpoint every HPS.checkpoint epochs (epoch 0 skipped);
    # a non-positive HPS.checkpoint disables it.
    if HPS.checkpoint > 0 and ei > 0 and ei % HPS.checkpoint == 0:
        torch.save(model.state_dict(), f"saved_checkpoints/{save_id}-vdvae-{str(ei).zfill(4)}.pt")

    if nb_iterations > HPS.nb_iterations:
        info("Maximum iterations reached. Exiting..")
        break

# Always save the final weights.
torch.save(model.state_dict(), f"saved_checkpoints/{save_id}-vdvae-final.pt")

### Testing site

In [None]:
# Smoke test: push a random 256x256 RGB image through a 6-level VDVAE and check shapes.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
vae = VAE(3, 64, 32, 32, nb_blocks=6).to(device)
x = torch.randn(1, 3, 256, 256).to(device)
with torch.no_grad():  # forward only: no need to keep activations for backward
    y, kls = vae(x)
assert y.shape == x.shape, f"reconstruction shape {tuple(y.shape)} != input shape {tuple(x.shape)}"
# One KL tensor per top-down unit: nb_blocks (6) * nb_res_blocks (default 3).
assert len(kls) == 6 * 3, f"expected 18 KL tensors, got {len(kls)}"