# Hierarchical VAE


This practice session walks you through training a top-down hierarchical VAE.

This is a toy implementation of the NVAE model with a simple MSE reconstruction loss and only three stages.

![Toy NVAE framework](https://raw.githubusercontent.com/generativemodelingmva/generativemodelingmva.github.io/main/tp2324/toynvae_framework_full.png)


Sources:
* Toy implementation from: https://github.com/GlassyWing/nvae
* NVAE official implementation: https://github.com/NVlabs/NVAE
* NVAE paper: "NVAE: A Deep Hierarchical Variational Autoencoder", Arash Vahdat and Jan Kautz (NeurIPS 2020 Spotlight Paper) https://arxiv.org/abs/2007.03898
* CelebA validation set (used for training): https://www.kaggle.com/datasets/jessicali9530/celeba-dataset




# Download files

In [1]:
# do just once:
!wget -nc -O celeba64png_val.zip 'https://www.dropbox.com/scl/fi/3d2le2wlu61nzkbxymfm3/celeba64png_val.zip?rlkey=ckesud01kwb8tualsd3s8zg5d'
!unzip -nq celeba64png_val.zip
!ls val | wc -l # there should be 19867 files

--2025-12-08 09:37:12--  https://www.dropbox.com/scl/fi/3d2le2wlu61nzkbxymfm3/celeba64png_val.zip?rlkey=ckesud01kwb8tualsd3s8zg5d
Resolving www.dropbox.com (www.dropbox.com)... 162.125.81.18, 2620:100:6031:18::a27d:5112
Connecting to www.dropbox.com (www.dropbox.com)|162.125.81.18|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://uc87e89067ba8486f1c245e80b35.dl.dropboxusercontent.com/cd/0/inline/C2oYsEDSsnnYTAt9dbQE_a7wzGg9DE__2-pBK47iO_QbGTsbJbeeOr4CUrmNm6DDAeOuQSW3YhV8-P8Y8-7sO2czvqUaO-tjEh-jYzAN5_9LjeV7yPHGnWl_EjRoXAYQjKA/file# [following]
--2025-12-08 09:37:13--  https://uc87e89067ba8486f1c245e80b35.dl.dropboxusercontent.com/cd/0/inline/C2oYsEDSsnnYTAt9dbQE_a7wzGg9DE__2-pBK47iO_QbGTsbJbeeOr4CUrmNm6DDAeOuQSW3YhV8-P8Y8-7sO2czvqUaO-tjEh-jYzAN5_9LjeV7yPHGnWl_EjRoXAYQjKA/file
Resolving uc87e89067ba8486f1c245e80b35.dl.dropboxusercontent.com (uc87e89067ba8486f1c245e80b35.dl.dropboxusercontent.com)... 162.125.81.15, 2620:100:6030:15::a27d:500f
Connectin

In [2]:
import argparse
import os
import time
from glob import glob
from datetime import datetime

from IPython.display import display
from PIL import Image

import matplotlib.pyplot as plt
import numpy as np

import torch
import torchvision
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
import torch.nn as nn
import torch.nn.functional as F
# Use the first CUDA GPU when PyTorch can see one, otherwise fall back to the CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
print(device)


cuda:0


In [3]:
from PIL import Image
class PngImageFolderDataset(Dataset):
    """Map-style dataset over the ``*.png`` files located directly in a folder.

    The file list is built once at construction time and sorted by path.
    Each item is the image at that position, opened with PIL and converted
    by ``torchvision.transforms.ToTensor``.
    """

    def __init__(self, image_dir):
        pattern = os.path.join(image_dir, "*.png")
        self.img_paths = sorted(glob(pattern))

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, idx):
        to_tensor = torchvision.transforms.ToTensor()
        picture = Image.open(self.img_paths[idx])
        return to_tensor(picture)

# The folder is called 'val' because this notebook trains on the CelebA validation split.
train_ds = PngImageFolderDataset('val')

# Model




## Losses for training


In [4]:

# utils functions for training and sampling:

def kl(mu, log_var):
    """Batch-averaged KL(q || N(0, I)) for a diagonal Gaussian q = N(mu, exp(log_var)).

    Dimensions 1, 2 and 3 are summed for every sample; the per-sample values
    are then averaged over dimension 0, giving a scalar tensor.
    """
    integrand = 1 + log_var - mu ** 2 - torch.exp(log_var)
    per_sample = -0.5 * integrand.sum(dim=[1, 2, 3])
    return per_sample.mean(dim=0)


def kl_delta(delta_mu, delta_log_var, mu, log_var):
    """
    KL(q || p) between two diagonal Gaussians where
      p = N(mu, exp(log_var))
      q = N(mu + delta_mu, exp(log_var + delta_log_var)).

    Written in terms of the deltas, the divergence per element is
      0.5 * (delta_mu**2 / exp(log_var) + exp(delta_log_var) - delta_log_var - 1).
    The deltas are used directly instead of being recovered as
    ``(mu + delta_mu) - mu`` / ``(log_var + delta_log_var) - log_var``:
    that round trip cancels in floating point and silently drops small
    deltas when the prior parameters are large in magnitude.

    Args:
        delta_mu: shift of the posterior mean relative to the prior mean.
        delta_log_var: shift of the posterior log-variance relative to the prior.
        mu: prior mean. The divergence does not depend on it; the argument
            is kept so existing callers keep working.
        log_var: prior log-variance.

    Returns:
        Scalar tensor: KL summed over dimensions 1-3 and averaged over the batch.
        All four inputs are expected to share one 4-D shape.
    """
    kl_term = 0.5 * (
        delta_mu ** 2 * torch.exp(-log_var)
        + torch.exp(delta_log_var)
        - delta_log_var
        - 1.0
    )
    return kl_term.sum(dim=[1, 2, 3]).mean()


def reparameterize(mu, std):
    """Return ``eps * std + mu`` with ``eps`` drawn from a standard normal shaped like ``mu``."""
    eps = torch.randn_like(mu)
    return eps * std + mu


from torch.nn.utils import spectral_norm
def add_sn(m):
    """Apply spectral normalization to 2-D (transposed) convolutions.

    Any other module is returned untouched. The function is used as the
    callback of ``model.apply(...)`` in this notebook.
    """
    conv_types = (nn.Conv2d, nn.ConvTranspose2d)
    if not isinstance(m, conv_types):
        return m
    return spectral_norm(m)


IndentationError: expected an indented block after function definition on line 14 (ipython-input-2056890792.py, line 18)

### Common layers

In [None]:
#from nvae.common import Swish, DecoderResidualBlock, ResidualBlock

class Swish(nn.Module):
    """Parameter-free activation computing ``x * sigmoid(x)`` element-wise."""

    def forward(self, x):
        gate = torch.sigmoid(x)
        return x * gate


class SELayer(nn.Module):
    """Channel gate: global average pool -> two bias-free linear layers -> sigmoid.

    The resulting per-channel weights in (0, 1) rescale the input feature map.

    Args:
        channel: number of channels of the input feature map.
        reduction: divisor applied to ``channel`` for the hidden linear width.
    """

    def __init__(self, channel, reduction=16):
        super().__init__()
        hidden = channel // reduction
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, hidden, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, channel, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        batch, channels, _height, _width = x.size()
        pooled = self.avg_pool(x).view(batch, channels)
        gate = self.fc(pooled).view(batch, channels, 1, 1)
        return x * gate.expand_as(x)


class DecoderResidualBlock(nn.Module):
    """Residual block: 1x1 expand, 5x5 grouped conv, 1x1 project, then a channel gate.

    The branch output is scaled by 0.1 before being added to the input.
    The same class is defined again, with the same body, in the decoder
    cell further down this notebook.

    Args:
        dim: number of input/output channels.
        n_group: channel expansion factor, also used as ``groups`` of the 5x5 conv.
    """

    def __init__(self, dim, n_group):
        super().__init__()
        wide = n_group * dim
        layers = [
            nn.Conv2d(dim, wide, kernel_size=1),
            nn.BatchNorm2d(wide),
            Swish(),
            nn.Conv2d(wide, wide, kernel_size=5, padding=2, groups=n_group),
            nn.BatchNorm2d(wide),
            Swish(),
            nn.Conv2d(wide, dim, kernel_size=1),
            nn.BatchNorm2d(dim),
            SELayer(dim),
        ]
        self._seq = nn.Sequential(*layers)

    def forward(self, x):
        branch = self._seq(x)
        return x + 0.1 * branch


class ResidualBlock(nn.Module):
    """Shape-preserving residual block: 5x5 conv, 1x1 conv, BN + Swish, 3x3 conv, channel gate.

    The branch output is scaled by 0.1 before being added to the input.
    """

    def __init__(self, dim):
        super().__init__()
        layers = [
            nn.Conv2d(dim, dim, kernel_size=5, padding=2),
            nn.Conv2d(dim, dim, kernel_size=1),
            nn.BatchNorm2d(dim),
            Swish(),
            nn.Conv2d(dim, dim, kernel_size=3, padding=1),
            SELayer(dim),
        ]
        self._seq = nn.Sequential(*layers)

    def forward(self, x):
        branch = self._seq(x)
        return x + 0.1 * branch



### Encoder architecture

In [None]:
# Encoder:
class ConvBlock(nn.Module):
    """Encoder stage: 3x3 conv, 1x1 bottleneck to half width, then a stride-2 3x3 conv.

    The stride-2 convolution is the only layer that changes the spatial size.
    """

    def __init__(self, in_channel, out_channel):
        super().__init__()
        half = out_channel // 2
        layers = [
            nn.Conv2d(in_channel, out_channel, kernel_size=3, padding=1),
            nn.Conv2d(out_channel, half, kernel_size=1),
            nn.BatchNorm2d(half),
            Swish(),
            nn.Conv2d(half, out_channel, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(out_channel),
            Swish(),
        ]
        self._seq = nn.Sequential(*layers)

    def forward(self, x):
        return self._seq(x)


class EncoderBlock(nn.Module):
    """Chain of ``ConvBlock`` stages, one per consecutive pair in ``channels``.

    ``channels = [c0, c1, ..., ck]`` builds ``ConvBlock(c0, c1)``, ...,
    ``ConvBlock(c(k-1), ck)`` and applies them in that order.
    """

    def __init__(self, channels):
        super().__init__()
        self.channels = channels
        stages = [
            ConvBlock(c_in, c_out)
            for c_in, c_out in zip(channels, channels[1:])
        ]
        self.modules_list = nn.ModuleList(stages)

    def forward(self, x):
        out = x
        for stage in self.modules_list:
            out = stage(out)
        return out


class EncoderResidualBlock(nn.Module):
    """Shape-preserving residual block used after each encoder stage.

    Layers: 5x5 conv, 1x1 conv, BN + Swish, 3x3 conv, channel gate. The
    branch output is scaled by 0.1 before being added to the input.
    """

    def __init__(self, dim):
        super().__init__()
        layers = [
            nn.Conv2d(dim, dim, kernel_size=5, padding=2),
            nn.Conv2d(dim, dim, kernel_size=1),
            nn.BatchNorm2d(dim),
            Swish(),
            nn.Conv2d(dim, dim, kernel_size=3, padding=1),
            SELayer(dim),
        ]
        self.seq = nn.Sequential(*layers)

    def forward(self, x):
        branch = self.seq(x)
        return x + 0.1 * branch


class Encoder(nn.Module):
    """Bottom-up encoder producing the top-latent statistics and skip features.

    Three downsampling groups are each followed by a residual block. For the
    64x64 images used in this notebook the feature maps are 16x16, 4x4 and
    2x2 (every ``ConvBlock`` halves the spatial size).
    """

    def __init__(self, z_dim):
        super().__init__()
        c16, c8, c4, c2 = z_dim // 16, z_dim // 8, z_dim // 4, z_dim // 2

        self.encoder_blocks = nn.ModuleList([
            EncoderBlock([3, c16, c8]),   # two stages  -> (16, 16)
            EncoderBlock([c8, c4, c2]),   # two stages  -> (4, 4)
            EncoderBlock([c2, z_dim]),    # one stage   -> (2, 2)
        ])

        self.encoder_residual_blocks = nn.ModuleList([
            EncoderResidualBlock(c8),
            EncoderResidualBlock(c2),
            EncoderResidualBlock(z_dim),
        ])

        # Maps the deepest feature map to 2 * z_dim channels: mean and log-variance.
        self.condition_x = nn.Sequential(
            Swish(),
            nn.Conv2d(z_dim, z_dim * 2, kernel_size=1),
        )

    def forward(self, x):
        """Return ``(mu, log_var, skips)``.

        ``mu`` and ``log_var`` are the two channel halves of ``condition_x``
        applied to the deepest feature map. ``skips`` holds the feature maps
        of all groups except the deepest one, ordered from deep to shallow.
        """
        feats = []
        h = x
        for down, refine in zip(self.encoder_blocks, self.encoder_residual_blocks):
            h = refine(down(h))
            feats.append(h)

        stats = self.condition_x(h)
        mu, log_var = stats.chunk(2, dim=1)

        skips = feats[:-1]
        skips.reverse()
        return mu, log_var, skips



### Decoder architecture with shared top-down encoder/decoder

In [None]:

# Decoder:
class UpsampleBlock(nn.Module):
    """Decoder stage: stride-2 transposed 3x3 convolution followed by BN + Swish.

    With ``kernel_size=3, stride=2, padding=1, output_padding=1`` the
    transposed convolution doubles the spatial size.
    """

    def __init__(self, in_channel, out_channel):
        super().__init__()
        # Alternative left unused by the original author:
        # nn.UpsamplingBilinear2d(scale_factor=2) followed by a 3x3 nn.Conv2d.
        deconv = nn.ConvTranspose2d(
            in_channel,
            out_channel,
            kernel_size=3,
            stride=2,
            padding=1,
            output_padding=1,
        )
        self._seq = nn.Sequential(deconv, nn.BatchNorm2d(out_channel), Swish())

    def forward(self, x):
        return self._seq(x)


class DecoderResidualBlock(nn.Module):
    """Decoder residual block (re-declares the class from the common-layers cell).

    Three groups of layers run in sequence and the result, scaled by 0.1,
    is added to the input:
      * expand:  1x1 conv to ``n_group * dim`` channels, BN, Swish
      * mix:     5x5 conv with ``groups=n_group``, BN, Swish
      * project: 1x1 conv back to ``dim`` channels, BN, channel gate
    """

    def __init__(self, dim, n_group):
        super().__init__()
        hidden = n_group * dim

        expand = (
            nn.Conv2d(dim, hidden, kernel_size=1),
            nn.BatchNorm2d(hidden),
            Swish(),
        )
        mix = (
            nn.Conv2d(hidden, hidden, kernel_size=5, padding=2, groups=n_group),
            nn.BatchNorm2d(hidden),
            Swish(),
        )
        project = (
            nn.Conv2d(hidden, dim, kernel_size=1),
            nn.BatchNorm2d(dim),
            SELayer(dim),
        )
        self._seq = nn.Sequential(*expand, *mix, *project)

    def forward(self, x):
        update = self._seq(x)
        return x + 0.1 * update


class DecoderBlock(nn.Module):
    """Chain of ``UpsampleBlock`` stages, one per consecutive pair in ``channels``."""

    def __init__(self, channels):
        super().__init__()
        self.channels = channels
        stages = [
            UpsampleBlock(c_in, c_out)
            for c_in, c_out in zip(channels, channels[1:])
        ]
        self.module_list = nn.ModuleList(stages)

    def forward(self, x):
        out = x
        for stage in self.module_list:
            out = stage(out)
        return out


class Decoder(nn.Module):
    """Top-down decoder of the toy hierarchical VAE.

    Three upsampling stages turn the top latent into an image. After each
    of the first two stages a new latent is sampled: from the prior
    ``condition_z`` alone when no encoder features are given, or from the
    prior shifted by the ``condition_xz`` residual when they are.
    """

    def __init__(self, z_dim):
        super().__init__()
        self.z_dim = z_dim

        # Every stage consumes cat([decoder_out, z]) along the channel axis,
        # so its input width is twice the latent width of that level.
        self.decoder_blocks = nn.ModuleList([
            DecoderBlock([z_dim * 2, z_dim // 2]),  # 2x upsample
            DecoderBlock([z_dim, z_dim // 4, z_dim // 8]),  # 4x upsample
            DecoderBlock([z_dim // 4, z_dim // 16, z_dim // 32])  # 4x upsample
        ])
        self.decoder_residual_blocks = nn.ModuleList([
            DecoderResidualBlock(z_dim // 2, n_group=4),
            DecoderResidualBlock(z_dim // 8, n_group=2),
            DecoderResidualBlock(z_dim // 32, n_group=1)
        ])

        # p(z_l | z_(l-1)): outputs mean and log-variance stacked on the channel axis
        self.condition_z = nn.ModuleList([
            nn.Sequential(
                ResidualBlock(z_dim // 2),
                Swish(),
                nn.Conv2d(z_dim // 2, z_dim, kernel_size=1)
            ),
            nn.Sequential(
                ResidualBlock(z_dim // 8),
                Swish(),
                nn.Conv2d(z_dim // 8, z_dim // 4, kernel_size=1)
            )
        ])

        # q(z_l | x, z_(l-1)): outputs the residual (delta) mean and log-variance
        self.condition_xz = nn.ModuleList([
            nn.Sequential(
                ResidualBlock(z_dim),
                nn.Conv2d(z_dim, z_dim // 2, kernel_size=1),
                Swish(),
                nn.Conv2d(z_dim // 2, z_dim, kernel_size=1)
            ),
            nn.Sequential(
                ResidualBlock(z_dim // 4),
                nn.Conv2d(z_dim // 4, z_dim // 8, kernel_size=1),
                Swish(),
                nn.Conv2d(z_dim // 8, z_dim // 4, kernel_size=1)
            )
        ])

        # Final projection of the last decoder feature map to 3 image channels.
        self.recon = nn.Sequential(
            ResidualBlock(z_dim // 32),
            nn.Conv2d(z_dim // 32, 3, kernel_size=1),
        )


    def forward(self, z, xs=None):
        """
        Decode the top latent ``z`` into an image.

        :param z: top latent, shape = (B, z_dim, map_h, map_w)
        :param xs: ``None`` for sample mode (lower latents drawn from the
            prior); otherwise the list of intermediate encoder features,
            deepest first, used to shift the prior into the posterior.
        :return: ``(x_hat, kl_losses)`` where ``x_hat`` is the sigmoid of the
            reconstruction head and ``kl_losses`` holds one ``kl_delta`` value
            per lower latent level (empty when ``xs`` is ``None``).
        """

        B, D, map_h, map_w = z.shape

        # The init h (hidden state), can be replaced with a learned param, but it didn't help much
        decoder_out = torch.zeros(B, D, map_h, map_w, device=z.device, dtype=z.dtype)

        kl_losses = []

        for i in range(len(self.decoder_residual_blocks)):

            z_sample = torch.cat([decoder_out, z], dim=1)
            decoder_out = self.decoder_residual_blocks[i](self.decoder_blocks[i](z_sample))

            if i == len(self.decoder_residual_blocks) - 1: # stop if last block
                break

            mu, log_var = self.condition_z[i](decoder_out).chunk(2, dim=1) # parameter for sampling next z

            if xs is not None:
                delta_mu, delta_log_var = self.condition_xz[i](
                                torch.cat([xs[i], decoder_out], dim=1)).chunk(2, dim=1)
                kl_losses.append(kl_delta(delta_mu, delta_log_var, mu, log_var))
                mu = mu + delta_mu
                log_var = log_var + delta_log_var

            z = reparameterize(mu, torch.exp(0.5 * log_var))

        x_hat = torch.sigmoid(self.recon(decoder_out))

        return x_hat, kl_losses


    def sample(self, n_samples=32, fix_level=-1):
        """
        Sample from the hierarchical prior.

        Sampling runs in eval mode under ``torch.no_grad()``; the
        train/eval state every submodule had before the call is restored
        afterwards, so calling this in the middle of training does not leave
        the decoder stuck in eval mode.

        Args:
            n_samples: number of images to sample.
            fix_level: -1 = all latents random;
                       0 = share top latent across samples;
                       1 = share top two latents across samples;
                       2 (or more) = share all three latents, so every
                       sample is the same image.

        Returns:
            Tensor of ``n_samples`` images (sigmoid of the reconstruction head).
        """
        previous_modes = [(module, module.training) for module in self.modules()]
        self.eval()
        device = next(self.parameters()).device
        try:
            with torch.no_grad():
                B = n_samples
                # top latent (z0); the 2x2 map matches the shape used for the
                # top latent elsewhere in this notebook (64x64 images)
                if fix_level >= 0:
                    z0_single = torch.randn(1, self.z_dim, 2, 2, device=device)
                    z = z0_single.expand(B, -1, -1, -1)
                else:
                    z = torch.randn(B, self.z_dim, 2, 2, device=device)

                decoder_out = torch.zeros_like(z)

                for i in range(len(self.decoder_residual_blocks)):
                    z_sample = torch.cat([decoder_out, z], dim=1)
                    decoder_out = self.decoder_residual_blocks[i](self.decoder_blocks[i](z_sample))

                    if i == len(self.decoder_residual_blocks) - 1:
                        break

                    mu, log_var = self.condition_z[i](decoder_out).chunk(2, dim=1)
                    std = torch.exp(0.5 * log_var)

                    level = i + 1  # next latent level being sampled
                    if fix_level >= level:
                        # All higher latents are shared, so in eval mode every
                        # row of mu/std is identical: one draw stands for all.
                        z = reparameterize(mu, std)[:1].expand(B, -1, -1, -1)
                    else:
                        z = reparameterize(mu, std)

                x_hat = torch.sigmoid(self.recon(decoder_out))
                return x_hat
        finally:
            for module, was_training in previous_modes:
                module.training = was_training


In [None]:

class NVAE(nn.Module):
    """Encoder/decoder pair of the toy hierarchical VAE, with losses computed in ``forward``."""

    def __init__(self, z_dim, img_dim):
        super().__init__()
        # img_dim is accepted for the caller's convenience; it is not read here.
        self.encoder = Encoder(z_dim)
        self.decoder = Decoder(z_dim)

    def forward(self, x):
        """
        Encode ``x``, sample the latents and reconstruct.

        :param x: Tensor. shape = (B, C, H, W)
        :return: ``(decoder_output, recon_loss, kl_losses)``: the
            reconstruction, the summed squared error divided by the batch
            size, and the list of KL terms (top latent first).
        """
        mu, log_var, xs = self.encoder(x)

        # Sample the top latent variable from q(z_0 | x).
        top_std = torch.exp(0.5 * log_var)
        z = reparameterize(mu, top_std)

        decoder_output, lower_kls = self.decoder(z, xs)

        kl_losses = [kl(mu, log_var)] + lower_kls

        criterion = nn.MSELoss(reduction='sum')
        n_images = decoder_output.shape[0]
        recon_loss = criterion(decoder_output, x) / n_images

        return decoder_output, recon_loss, kl_losses



### Visualization functions

In [None]:
import torchvision

def imshow(img):
    """Convert an image tensor to a PIL image, display it in the notebook and return it.

    No un-normalization is applied to ``img`` before the conversion.
    """
    to_pil = torchvision.transforms.functional.to_pil_image
    pil_img = to_pil(img)
    display(pil_img)
    return pil_img

# Also read by show_decoder_output() as the number of latent codes it draws.
batch_size = 128

# Freshly initialised (untrained) model, moved to the selected device.
model = NVAE(z_dim=512, img_dim=(64, 64)).to(device)

def show_decoder_output(z=None):
    """Decode latent codes with the global ``model`` and display an 8-column grid.

    Args:
        z: optional top-latent tensor. Passing the same tensor on every call
            makes it possible to follow how the outputs evolve during
            training. When omitted, ``batch_size`` random codes of shape
            (512, 2, 2) are drawn.

    Returns:
        The PIL image of the grid built from the first 32 decoded images.
    """
    with torch.no_grad():
        # Identity test: `z == None` is not a reliable check on a tensor.
        if z is None:
            z = torch.randn((batch_size, 512, 2, 2)).to(device)
        # The full batch is decoded and only the first 32 images are shown.
        genimages = model.decoder(z)[0]
        grid = torchvision.utils.make_grid(genimages[:32, :, :, :].to('cpu'), nrow=8)
        pil_img = imshow(grid)
    return pil_img

# Smoke test with the untrained model: checks that decoding and display run end to end.
show_decoder_output();



# Training

### Exercise 2:
1. Display an image of 4x8 portraits from the training dataset.
1. Read the model architecture and train for 5 epochs.
1. How do you explain the difference of images generated with and without model.eval()?


In [None]:

# Display an image grid (4x8) from the training dataset.
from torchvision.utils import make_grid

# A dedicated loader with batch_size=32 yields the (up to) 32 images of the 4x8 grid.
preview_loader = DataLoader(train_ds, batch_size=32, shuffle=True)
images = next(iter(preview_loader))

plt.figure(figsize=(8, 4))
plt.axis('off')
# permute moves the channel axis last, the layout plt.imshow expects for colour images.
plt.imshow(make_grid(images[:32], nrow=8).permute(1, 2, 0))
plt.show()


In [None]:
# Training hyper-parameters.
epochs = 5
batch_size = 128
n_cpu = 2  # number of DataLoader worker processes

train_dataloader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=n_cpu)

# Start again from a freshly initialised model (replaces the global `model`
# that show_decoder_output() reads).
model = NVAE(z_dim=512, img_dim=(64, 64)).to(device)
# apply Spectral Normalization
model.apply(add_sn)


# Fixed latent codes, reused for every visualization so the images are comparable across steps.
zshow = torch.randn((batch_size,512,2,2)).to(device)


# folder for checkpoints and visualization:
now = datetime.now()
dt_string = now.strftime("%Y%m%d_%H%M%S")
checkpoints_dir = dt_string+"_checkpoints"
os.makedirs(checkpoints_dir, exist_ok=True)
outputs_dir = dt_string+"_outputs"
os.makedirs(outputs_dir, exist_ok=True)


optimizer = torch.optim.Adamax(model.parameters(), lr=0.01)
# NOTE(review): T_max=15 is larger than epochs=5 and the scheduler is stepped once
# per epoch, so only the first third of the cosine cycle is used — confirm this is intended.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=15, eta_min=1e-4)

for epoch in range(epochs):
    model.train()

    for i, image in enumerate(train_dataloader):
        optimizer.zero_grad()
        image = image.to(device)
        image_recon, recon_loss, kl_losses = model(image)

        # Weight of the summed KL terms relative to the reconstruction loss.
        kl_f = 1.
        loss = recon_loss + kl_f*sum(kl_losses)
        loss.backward()
        optimizer.step()

        # Every 50 steps: log the loss and save two sample grids from the same latents.
        if i % 50 == 0:
            log_str = "\r---- [Epoch %d/%d, Step %d/%d] loss: %.6f----" % (
            epoch, epochs, i, len(train_dataloader), loss.item())
            print(log_str)
            with torch.no_grad():
                # First grid: decoded while the model is still in train mode.
                pil_img = show_decoder_output(zshow)
                imgpath = os.path.join(outputs_dir, "nvae_simple_loss_epoch_"+str(epoch).zfill(3)+"_step_"+str(i).zfill(4)+".png")
                pil_img.save(imgpath)
                # Second grid: same latents decoded in eval mode (file name ends in "_eval").
                model.eval()
                pil_img = show_decoder_output(zshow)
                imgpath = os.path.join(outputs_dir, "nvae_simple_loss_epoch_"+str(epoch).zfill(3)+"_step_"+str(i).zfill(4)+"_eval.png")
                pil_img.save(imgpath)
                # Back to train mode before the next optimisation step.
                model.train()

    # end epoch: save checkpoint:
    scheduler.step()
    # A checkpoint is written every fifth epoch (epoch indices 4, 9, ...).
    if epoch%5==4:
      checkpoint_path = os.path.join(checkpoints_dir, "nvae_simple_loss_epoch_"+str(epoch).zfill(3)+".pth")
      torch.save(model.state_dict(), checkpoint_path)



### Load pretrained model

In [None]:
# Load pretrained model:
!wget -nc -O nvae_simple_loss_epoch_199.pth 'https://www.dropbox.com/scl/fi/0rsjcx78w338nj4ie6lun/nvae_simple_loss_epoch_199.pth?rlkey=8ok3htl9gp3ywmpqu02xdfo87'
checkpoint_path = 'nvae_simple_loss_epoch_199.pth'
model = NVAE(z_dim=512, img_dim=(64, 64)).to(device)
# Spectral norm is applied before loading, as in the training cell; presumably the
# checkpoint was saved from a spectrally-normalised model, so the parameter names match — verify.
model.apply(add_sn)
# NOTE(review): strict=False silently ignores missing/unexpected keys; inspect the value
# returned by load_state_dict if the samples look wrong.
# NOTE(review): torch.load unpickles the file, which here comes from an external download;
# consider weights_only=True if the installed PyTorch version supports it.
model.load_state_dict(torch.load(checkpoint_path, map_location=device), strict=False)
model.eval()

# New fixed latent codes for visualising the pretrained decoder.
zshow = torch.randn((batch_size,512,2,2)).to(device)
show_decoder_output(zshow);



### Exercise 3:
1. Complete the method `sample(self, n_samples=32, fix_level=-1)` of the `Decoder` class so that it samples images that share the same latent realizations up to `fix_level` (begin by implementing and testing independent sampling).
1. Test this function and verify the hierarchical VAE encodes the images hierarchically.
1. (Bonus question) Sample 20k images and compute FID against the celeba test set (available here https://www.dropbox.com/scl/fi/in8hqobto2p2k2baiwi5x/celeba64png_test.zip?rlkey=jmisq9swucjwjyv69ftwwks06)





In [None]:

# Sampling with the (pre)trained hierarchical VAE
# Eval mode and no_grad for pure inference with the current global model.
model.eval()
with torch.no_grad():
    # change fix_level to 0 or 1 to share higher-level latents across samples
    samples = model.decoder.sample(n_samples=32, fix_level=-1)
    # 32 samples laid out 8 per row, i.e. a 4x8 grid.
    grid = torchvision.utils.make_grid(samples.cpu(), nrow=8)
    display(torchvision.transforms.functional.to_pil_image(grid))
