# β-VAE trained on Oxford Flower Image Dataset

The purpose of this notebook is simply to generate colorful images based on images of flowers
contained in the well-known Oxford Flower Image Dataset.

### Disclaimer regarding the code

The code has been written in a relatively verbose, free style for study purposes.
It is not production-ready and should not be regarded as research or blog-article code.
Inline comments verbosely describe aspects that came up during coding.
Assertions have been added to confirm certain conditions along the way.

The author is aware that production-ready code, on the other hand, would usually, for example:

- follow principles of clean code, with automated formatting, linting, spell-checking, etc.,
- contain Python type hints in a much more consistent manner,
- be automatically tested via test cases written with pytest or similar libraries,
- be modularized and packaged properly, with selected functions moved to 1st-party libraries,
- allow for proper packaging, export, import, etc. of readily trained models,
- be compatible with open source or proprietary machine learning pipelines (think MLOps, DevOps),
- be integrated with respective tooling to visualize metrics about model performance,
- ship together with benchmarks, examples, a cookbook or user guide, etc.,
- be optimized to some extent.

In [None]:
# in this code block we have all our imports

# PyTorch
import torch
torch.manual_seed(42)

# to load PyTorch image dataset
from torch.utils.data import Dataset
from torchvision import datasets
from torch.utils.data import DataLoader

# to transform images during loading of the PyTorch image dataset
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF

# to store resulting images
#from torchvision.utils import save_image

# PyTorch neural network stuff
import torch.nn as nn

# PyTorch functions
import torch.nn.functional as F

# to generate filenames for storing trained models or resulting images
from datetime import datetime

# to plot images or graphs
import matplotlib.pyplot as plt

# misc
import numpy as np
#import random



In [None]:
# see if GPUs are available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device_str = str(device).upper()
print(f"This computation is running on {device_str}.")


In [None]:
# in this code block we define variables mostly regarding dimensionalities

# root path of the dataset
# to where it is downloaded via the PyTorch dataset facility
datasets_root_path = '../../../Datasets/PyTorch'

# size of an input image in pixels
# as it goes as an input image to the input of the VAE's encoder
# note that in this case the images are square
# but the VAE code could also work with non-square images
n_input_image_pixels_height = 136
# for now we use smaller images for faster training during development
#n_input_image_pixels_height = 96
n_input_image_pixels_width = n_input_image_pixels_height

# number of channels in an image
# in this case one channel each for red, green, blue
n_image_channels = 3

# number of feature images obtained from one image after the convolutional layers in the encoder
n_dim_feature_images = 32

# number of pixels the convolutions cut away from the image height and width
n_pixels_conv_cutaway = 8

# size of an output image in pixels
# as it comes as a generated image from the output of the VAE's decoder
n_output_image_pixels_height = n_input_image_pixels_height - n_pixels_conv_cutaway
n_output_image_pixels_width = n_input_image_pixels_width - n_pixels_conv_cutaway

# number of features before mapping to the latent space
n_dim_features = n_dim_feature_images * n_output_image_pixels_height * n_output_image_pixels_width

# the number of dimensions of the latent space
# aka. number of latent variables
# which is then also the number of values in a vector z in the latent space
n_dim_latent_space = 64

# should a DataLoader shuffle the input images
# or rather always yield them in the same order
shuffle = False
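
The value 8 for `n_pixels_conv_cutaway` can be derived rather than hard-coded: a convolution with stride 1 and no padding removes `kernel_size - 1` pixels per spatial dimension, and the encoder applies two such layers. The following is a minimal sketch of that arithmetic, assuming a kernel size of 5 (the `conv_kernel_size` passed to the VAE further down); `conv_output_size` is a hypothetical helper and not part of the notebook.

```python
def conv_output_size(size: int, kernel_size: int, stride: int = 1, padding: int = 0) -> int:
    """Length of one spatial dimension after a convolution with dilation 1."""
    return (size + 2 * padding - kernel_size) // stride + 1


input_height = 136
kernel_size = 5     # assumption: equals the conv_kernel_size passed to the VAE later on
n_conv_layers = 2   # the encoder applies two convolutional layers

height = input_height
for _ in range(n_conv_layers):
    height = conv_output_size(height, kernel_size)

cutaway = input_height - height
n_features = 32 * height * height  # 32 feature images of height x height pixels

print(f"height after convolutions: {height}")
print(f"pixels cut away: {cutaway}")
print(f"features before the latent space: {n_features}")
```

If the kernel size or the number of convolutional layers changes, `n_pixels_conv_cutaway` has to change accordingly, otherwise the reshaping in the encoder no longer matches `n_dim_features`.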


In [None]:
# in this code block we define a custom transformation regarding gamma, brightness, saturation

class CustomImageTransform:
    """Adjust an image in a custom way."""

    def __call__(self, img):
        img = TF.adjust_gamma(img, 1.6, 1.1)
        img = TF.adjust_brightness(img, 0.7)
        img = TF.adjust_saturation(img, 1.5)
        return img
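
For reference, torchvision documents the gamma adjustment of an 8-bit channel value as `255 * gain * (value / 255) ** gamma`, with the result clipped to the valid range. The sketch below applies that formula to single values in plain Python. `adjust_gamma_value` is a hypothetical helper for illustration only, not a torchvision function, and it ignores the rounding details of the real implementation.

```python
def adjust_gamma_value(value: float, gamma: float, gain: float = 1.0) -> float:
    """Apply power-law gamma correction to one channel value in [0, 255]."""
    adjusted = 255.0 * gain * (value / 255.0) ** gamma
    # clip to the valid 8-bit range
    return min(max(adjusted, 0.0), 255.0)


# same gamma and gain as used in CustomImageTransform
for value in (0, 32, 128, 224, 255):
    result = adjust_gamma_value(value, gamma=1.6, gain=1.1)
    print(f"{value:3d} -> {result:.1f}")
```

With these settings, shadows and mid-tones become darker, while the gain above 1 lifts the brightest values up to the clipping limit.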


In [None]:
# in this code block we define our image transformation pipeline treating the images a bit

# note that some magic numbers in here
# are just temporary sizes of input images as they go through the pre-processing pipeline
# and that we could extract these magic numbers into variables or even calculate them
# but for the sake of simplicity we just use magic numbers here

# pre-processing pipeline to pimp these images from 2009 and earlier a bit
transformations_enhance = transforms.Compose([
                                      transforms.Resize(384, transforms.InterpolationMode.BICUBIC, antialias=True),
                                      transforms.CenterCrop([320, 320]),
                                      transforms.GaussianBlur(kernel_size=3, sigma=2.0),
                                      CustomImageTransform(),
                                      transforms.RandomAutocontrast(p=1.0),
                                      transforms.RandomAdjustSharpness(sharpness_factor=4, p=1.0),
                                      # kernel size must be an odd integer; a 1x1 kernel leaves the image unchanged
                                      transforms.GaussianBlur(kernel_size=1, sigma=1.5),
                                      transforms.Resize(256, transforms.InterpolationMode.BICUBIC, antialias=True),
                                      transforms.CenterCrop([n_input_image_pixels_height, n_input_image_pixels_width]),
                                      transforms.ToTensor()
                                    ])

# for faster training switch to this much simpler pre-processing pipeline
transformations_justcrop = transforms.Compose([
                                      transforms.Resize(256, transforms.InterpolationMode.BICUBIC, antialias=True),
                                      transforms.CenterCrop([n_input_image_pixels_height, n_input_image_pixels_width]),
                                      transforms.ToTensor()
                                    ])

transformations = transformations_enhance


In [None]:
# in this code block we define the train, val, test dataset

dataset_train = datasets.Flowers102(
    root=datasets_root_path,
    split='train',
    download=True,
    transform=transformations
)

dataset_val = datasets.Flowers102(
    root=datasets_root_path,
    split='val',
    download=True,
    transform=transformations
)

dataset_test = datasets.Flowers102(
    root=datasets_root_path,
    split='test',
    download=True,
    transform=transformations
)


In [None]:
# in this code block we just run such a custom data loader once to view an example input image

view_example_image = False

if view_example_image:
    dataloader_example1 = DataLoader(dataset_train, batch_size=1, shuffle=shuffle)
    images_example1, labels_example1 = next(iter(dataloader_example1))
    print(f"images_example1.size()={images_example1.size()}")
    print(f"labels_example1.size()={labels_example1.size()}")
    image_example1 = images_example1[0]
    label_example1 = labels_example1[0]
    print(f"image_example1.size()={image_example1.size()}")
    print(f"image_example1={image_example1}")
    print(f"label_example1.size()={label_example1.size()}")
    print(f"label_example1={label_example1}")
    # use .permute(1, 2, 0) to move the channel dimension from first to last
    # since matplotlib expects images in (height, width, channels) layout
    plt.imshow(image_example1.permute(1, 2, 0))


In [None]:
# in this code block we define a helper function to generate filenames for storing trained models

def create_filename(prefix: str, extension: str):
    now = datetime.now()
    now_str = now.strftime("%Y%m%d_%H%M%S")
    filename = f"{prefix}_{now_str}.{extension}"
    return filename


In [None]:
# in this code block we define our Convolutional Beta Variational Auto-Encoder

# convolutional β-VAE
class VAE(nn.Module):
    def __init__(self,
                 n_image_channels,
                 n_dim_feature_images,
                 conv_kernel_size,
                 n_dim_features,
                 n_dim_latent_space,
                 n_output_image_pixels_height,
                 n_output_image_pixels_width,
                 beta: torch.Tensor,
                 debug: bool,
                 trace: bool):
        super(VAE, self).__init__()
        
        self.n_image_channels = n_image_channels
        self.n_dim_feature_images = n_dim_feature_images
        self.n_dim_inner_conv = int(self.n_dim_feature_images / 2)
        print(f"self.n_dim_inner_conv={self.n_dim_inner_conv}")
        self.conv_kernel_size = conv_kernel_size
        self.n_dim_features = n_dim_features
        self.n_dim_latent_space = n_dim_latent_space
        self.n_output_image_pixels_height = n_output_image_pixels_height
        self.n_output_image_pixels_width = n_output_image_pixels_width
        self.beta = beta
        self.debug = debug
        self.trace = trace

        print(f"Creating new Beta-VAE instance with beta={beta.cpu()}.")

        # encoder layers

        # input to the encoder passes through
        # two convolutional layers to work out important features in each input image
        self.encoder_layer_1_conv = nn.Conv2d(in_channels=self.n_image_channels, out_channels=self.n_dim_inner_conv, kernel_size=self.conv_kernel_size)
        torch.nn.init.xavier_uniform_(self.encoder_layer_1_conv.weight)
        self.encoder_layer_2_conv = nn.Conv2d(in_channels=self.n_dim_inner_conv, out_channels=self.n_dim_feature_images, kernel_size=self.conv_kernel_size)
        torch.nn.init.xavier_uniform_(self.encoder_layer_2_conv.weight)
        # note that in the forward pass of the encoder
        # the output of the last convolutional layer
        # is first reshaped into a vector before it is then passed to the first linear layer

        # and two linear layers to compute the parameters of distributions in the latent space where
        # computed vectors mu and sigma_log parameterize these distributions of latent variables
        # in fact there is one linear layer to compute the vector of means mu
        self.encoder_layer_3_linear = nn.Linear(in_features=self.n_dim_features, out_features=self.n_dim_latent_space)
        torch.nn.init.xavier_uniform_(self.encoder_layer_3_linear.weight)
        # and one linear layer to compute the vector sigma_log
        # note that the output of this layer is interpreted/used as
        # the natural logarithm of each variance (sigma squared)
        # since the re-parametrization trick computes sigma = exp(sigma_log / 2)
        # and the KL term in the training loss uses exp(sigma_log) as the variance
        # the reason why it is interpreted as a log variance instead of a plain variance is that
        # the network output can then be any real number while the variance stays positive
        # and in this way the network can more easily cover a broad range of values
        self.encoder_layer_4_linear = nn.Linear(in_features=self.n_dim_features, out_features=self.n_dim_latent_space)
        torch.nn.init.xavier_uniform_(self.encoder_layer_4_linear.weight)

        # decoder layers

        # z is first passed through a linear layer to go from the latent space to the feature space
        # note that in the forward pass of the decoder
        # the output of this layer needs to be reshaped
        self.decoder_layer_1_linear = nn.Linear(in_features=self.n_dim_latent_space, out_features=self.n_dim_features)
        torch.nn.init.xavier_uniform_(self.decoder_layer_1_linear.weight)
        # before it then is passed to the first "de-convolutional" layer
        self.decoder_layer_2_deconv = nn.ConvTranspose2d(in_channels=self.n_dim_feature_images, out_channels=self.n_dim_inner_conv, kernel_size=self.conv_kernel_size)
        torch.nn.init.xavier_uniform_(self.decoder_layer_2_deconv.weight)
        self.decoder_layer_3_deconv = nn.ConvTranspose2d(in_channels=self.n_dim_inner_conv, out_channels=self.n_image_channels, kernel_size=self.conv_kernel_size)
        torch.nn.init.xavier_uniform_(self.decoder_layer_3_deconv.weight)


    def encoder(self, x):
        assert device_str == 'CPU' or x.is_cuda
        # the convolutional layers are traversed one after the other
        # thereby the activation function is mainly used to add non-linearity as usual
        # in this case GELU (Gaussian Error Linear Units) is used
        # in fact the tanh-based approximation for increased computational speed

        if self.debug:
            print(f"x.shape={x.shape} before convolutional layers.")

        x = F.gelu(input=self.encoder_layer_1_conv(x), approximate='tanh')
        assert device_str == 'CPU' or x.is_cuda
        if self.debug:
            print(f"x.shape={x.shape} after 1st convolutional layer.")

        x = F.gelu(input=self.encoder_layer_2_conv(x), approximate='tanh')
        assert device_str == 'CPU' or x.is_cuda
        if self.debug:
            print(f"x.shape={x.shape} after 2nd convolutional layer.")

        # now make a vector out of what we currently have computed
        x = x.view(-1, self.n_dim_features)
        assert device_str == 'CPU' or x.is_cuda
        if self.debug:
            print(f"x.shape={x.shape} after reshaping.")

        # compute vector of means
        mu = self.encoder_layer_3_linear(x)
        assert device_str == 'CPU' or mu.is_cuda
        assert self.n_dim_latent_space == mu.shape[1]
        if self.debug:
            print(f"mu.shape={mu.shape}")

        # compute vector of log variances
        sigma_log = self.encoder_layer_4_linear(x)
        assert device_str == 'CPU' or sigma_log.is_cuda
        assert self.n_dim_latent_space == sigma_log.shape[1]
        if self.debug:
            print(f"sigma_log.shape={sigma_log.shape}")
    
        return mu, sigma_log


    def re_parametrization_trick(self, mu, sigma_log):
        # do the VAE re-parametrization trick
        # z = mu + (sigma * epsilon)
        # where epsilon will be sampled from a standard normal distribution

        # first convert the log variances into standard deviations
        # sigma = sqrt(exp(sigma_log)) = exp(sigma_log / 2)
        sigma = torch.exp(sigma_log / 2)

        # note that the vector sigma is only passed to the PyTorch function randn_like
        # to tell it which shape, dtype and device the result should have
        # so epsilon already lives on the same device as sigma
        epsilon = torch.randn_like(sigma)

        assert self.n_dim_latent_space == mu.shape[1]
        assert self.n_dim_latent_space == sigma_log.shape[1]
        assert sigma.shape == sigma_log.shape
        assert epsilon.shape == sigma.shape

        assert device_str == 'CPU' or mu.is_cuda
        assert device_str == 'CPU' or sigma_log.is_cuda
        assert device_str == 'CPU' or sigma.is_cuda
        assert device_str == 'CPU' or epsilon.is_cuda

        # so here the vector of learned means is shifted by
        # the corresponding vector of learned standard deviations
        # scaled element-wise by the noise epsilon
        # drawn from the standard normal distribution N(0,1)
        # to obtain the vector z
        random_deviation = sigma * epsilon
        assert device_str == 'CPU' or random_deviation.is_cuda
        z = mu + random_deviation
        assert device_str == 'CPU' or z.is_cuda

        if self.trace:
            print(f"sigma=\n{sigma.cpu()}")
            print(f"epsilon=\n{epsilon.cpu()}")
            print(f"random_deviation=\n{random_deviation.cpu()}")
            print(f"mu=\n{mu.cpu()}")
            print(f"z=\n{z.cpu()}")

        if self.debug:
            print(f"z.shape={z.cpu().shape}")

        # to obtain values so to say sampled from the learned distributions in the latent space
        return z


    def decoder(self, z):
        # decoding
        x = F.gelu(input=self.decoder_layer_1_linear(z), approximate='tanh')
        assert device_str == 'CPU' or x.is_cuda
        x = x.view(-1, self.n_dim_feature_images, self.n_output_image_pixels_height, self.n_output_image_pixels_width)
        assert device_str == 'CPU' or x.is_cuda
        x = F.gelu(input=self.decoder_layer_2_deconv(x), approximate='tanh')
        assert device_str == 'CPU' or x.is_cuda
        x = torch.sigmoid(self.decoder_layer_3_deconv(x))
        assert device_str == 'CPU' or x.is_cuda
        return x


    def forward(self, x):
        # encoding -> re-parametrization -> decoding
        # this is a full forward pass mostly only used in training of the VAE
        mu, log_sigma = self.encoder(x)
        assert device_str == 'CPU' or mu.is_cuda
        assert device_str == 'CPU' or log_sigma.is_cuda
        z = self.re_parametrization_trick(mu, log_sigma)
        assert device_str == 'CPU' or z.is_cuda
        output = self.decoder(z)
        assert device_str == 'CPU' or output.is_cuda
        # returns decoder output, means, log variances for subsequent computation of loss
        return output, mu, log_sigma
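
Because the class computes `sigma = exp(sigma_log / 2)`, the second encoder output acts as a log variance. The re-parametrization step can be checked in isolation with the standard library alone. The sketch below is not the notebook's PyTorch code: `reparameterize` is a hypothetical plain-Python stand-in, and the means and variances are made-up values.

```python
import math
import random


def reparameterize(mu, log_var, rng):
    """Sample z = mu + sigma * epsilon per latent dimension, with epsilon ~ N(0, 1)."""
    z = []
    for m, lv in zip(mu, log_var):
        sigma = math.exp(lv / 2)        # log variance -> standard deviation
        epsilon = rng.gauss(0.0, 1.0)   # noise that does not depend on mu or sigma
        z.append(m + sigma * epsilon)
    return z


rng = random.Random(42)
mu = [0.5, -1.0]                            # made-up means
log_var = [math.log(0.25), math.log(4.0)]   # variances 0.25 and 4.0, so sigma is 0.5 and 2.0

samples = [reparameterize(mu, log_var, rng) for _ in range(20000)]

stats = []
for d in range(len(mu)):
    values = [s[d] for s in samples]
    mean = sum(values) / len(values)
    std = math.sqrt(sum((v - mean) ** 2 for v in values) / len(values))
    stats.append((mean, std))
    print(f"dim {d}: mean={mean:.2f} std={std:.2f}")
```

The empirical means and standard deviations should land close to `mu` and `exp(log_var / 2)`. The randomness sits entirely in `epsilon`, which is why gradients can flow through `mu` and `sigma` in the PyTorch version.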


In [None]:
# in this code block we initialize a β-VAE for subsequent training

print(f"n_image_channels={n_image_channels}")
print(f"n_dim_feature_images={n_dim_feature_images}")
print(f"n_dim_features={n_dim_features}")
print(f"n_dim_latent_space={n_dim_latent_space}")

#torch.cuda.empty_cache()

beta = torch.tensor([4.0])
beta = beta.to(device)
assert beta.shape == (1,)
assert device_str == 'CPU' or beta.is_cuda

vae = VAE(n_image_channels=n_image_channels,
          n_dim_feature_images=n_dim_feature_images,
          conv_kernel_size=5,
          n_dim_features=n_dim_features,
          n_dim_latent_space=n_dim_latent_space,
          n_output_image_pixels_height=n_output_image_pixels_height,
          n_output_image_pixels_width=n_output_image_pixels_width,
          beta=beta,
          debug=False,
          trace=False)

vae.to(device)
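
It can help to estimate how large this model is before training. The sketch below counts the learnable parameters per layer from the dimensions configured above, assuming 128 x 128 feature images and the default bias terms. `conv_params` and `linear_params` are hypothetical helpers, not PyTorch functions.

```python
def conv_params(in_channels: int, out_channels: int, kernel_size: int) -> int:
    """Weights plus biases of a 2D (transposed) convolution with a square kernel."""
    return in_channels * out_channels * kernel_size * kernel_size + out_channels


def linear_params(in_features: int, out_features: int) -> int:
    """Weights plus biases of a fully connected layer."""
    return in_features * out_features + out_features


# values taken from the configuration cells of this notebook
n_image_channels = 3
n_dim_feature_images = 32
n_dim_inner_conv = n_dim_feature_images // 2
conv_kernel_size = 5
n_dim_features = n_dim_feature_images * 128 * 128
n_dim_latent_space = 64

layers = {
    "encoder_layer_1_conv": conv_params(n_image_channels, n_dim_inner_conv, conv_kernel_size),
    "encoder_layer_2_conv": conv_params(n_dim_inner_conv, n_dim_feature_images, conv_kernel_size),
    "encoder_layer_3_linear": linear_params(n_dim_features, n_dim_latent_space),
    "encoder_layer_4_linear": linear_params(n_dim_features, n_dim_latent_space),
    "decoder_layer_1_linear": linear_params(n_dim_latent_space, n_dim_features),
    "decoder_layer_2_deconv": conv_params(n_dim_feature_images, n_dim_inner_conv, conv_kernel_size),
    "decoder_layer_3_deconv": conv_params(n_dim_inner_conv, n_image_channels, conv_kernel_size),
}

total = sum(layers.values())
for name, count in layers.items():
    print(f"{name}: {count:,}")
print(f"total: {total:,} parameters, about {total * 4 / 1e6:.0f} MB as float32")
```

Almost all parameters sit in the three linear layers, because the convolutions do not downsample and the flattened feature vector is therefore very long.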


In [None]:
# in this code block we define a procedure to train a β-VAE

def train_vae_variant_1(vae: VAE,
                        lr: float,
                        n_epochs: int,
                        n_epochs_max: int,
                        dataloader_train: DataLoader,
                        n_batches_max: int,
                        debug: bool):
    # initialize the optimizer
    optimizer = torch.optim.Adam(vae.parameters(), lr=lr)

    # note that the KL divergence is computed in closed form inside the training loop
    # so no separate loss module such as nn.KLDivLoss is needed here

    # iterate
    for epoch in range(n_epochs):
        if debug:
            print(f"train epoch={epoch} starting.")    

        # iterate over the batches of images
        dataloader_train_iterator = iter(dataloader_train)
        batch_ctr = 0
        batch_ctr_max_reached = False
        output_images = None
        # note that the underscore discards the labels the DataLoader provides
        # since auto-encoders are generally unsupervised and do not need labels
        # even though the dataset does contain them
        for batch_of_input_images, _ in dataloader_train_iterator:
            #if debug and epoch % 10 == 0:
            #    print(f"train epoch={epoch} batch={batch_ctr} starting.")    
            
            # move the batch to GPU if available
            batch_of_input_images = batch_of_input_images.to(device)
                
            # forward pass to generate image and obtain the vectors mu and sigma_log
            output_images, mu, sigma_log = vae(batch_of_input_images)

            #print(f"output_images.is_cuda={output_images.is_cuda}")
            assert device_str == 'CPU' or output_images.is_cuda
            #print(f"mu.is_cuda={mu.is_cuda}")
            assert device_str == 'CPU' or mu.is_cuda
            assert device_str == 'CPU' or vae.beta.is_cuda
            #print(f"sigma_log.is_cuda={sigma_log.is_cuda}")
            assert device_str == 'CPU' or sigma_log.is_cuda

            # VAE loss is binary cross-entropy loss as reconstruction term
            bce_loss = F.binary_cross_entropy(output_images, batch_of_input_images, reduction='sum')
            #print(f"bce_loss.is_cuda={bce_loss.is_cuda}")
            assert device_str == 'CPU' or bce_loss.is_cuda
            # combined with Kullback–Leibler divergence as regularization
            # note that this form of the KL-divergence is simplified
            # based on the facts that the distributions are normal distributions
            # and the prior distributions are standard normal distributions
            kl_divergence = 0.5 * torch.sum(-1 - sigma_log + mu.pow(2) + sigma_log.exp())
            #print(f"kl_divergence.is_cuda={kl_divergence.is_cuda}")
            assert device_str == 'CPU' or kl_divergence.is_cuda
            # now β-VAE loss is just that with the constant beta > 1 emphasizing regularization
            # in order to force better disentanglement of the features
            # by improving the properties of the learned distributions in the latent space
            loss = bce_loss + (vae.beta * kl_divergence)
            #print(f"loss.is_cuda={loss.is_cuda}")
            assert device_str == 'CPU' or loss.is_cuda

            #print(f"train epoch={epoch} batch={batch_ctr} loss={loss.item()}"
            #      f" = bce_loss={bce_loss.item()} + (beta={vae.beta.item()} * kl_divergence={kl_divergence.item()})")    
    
            # back-propagation
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            batch_ctr = batch_ctr + 1
            if(n_batches_max <= batch_ctr):
                batch_ctr_max_reached = True
                break

        print(f"train epoch={epoch} loss = {loss.item()} = {bce_loss.item()} + {vae.beta.item()} * {kl_divergence.item()}")
        if(batch_ctr_max_reached):
            print(f"train reached maximum number of batches {n_batches_max}.")
            break
        if(n_epochs_max <= epoch + 1):
            print(f"train reached maximum number of epochs {n_epochs_max}.")
            break

    # ensure output Tensor is detached from autograd before returning
    output_images = output_images.detach()
    return output_images
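
The simplified KL term used in the training loop can be checked against a direct numerical integration. The following sketch does this for a single latent dimension using only the standard library. `kl_to_standard_normal` and `kl_numeric_1d` are hypothetical helpers, and the mean, log variance and reconstruction loss are made-up values.

```python
import math


def kl_to_standard_normal(mu, log_var):
    """Closed-form KL(N(mu, diag(exp(log_var))) || N(0, I)), summed over dimensions."""
    return 0.5 * sum(-1.0 - lv + m * m + math.exp(lv) for m, lv in zip(mu, log_var))


def kl_numeric_1d(mu, log_var, n_steps=20000):
    """Midpoint-rule integral of q(x) * (log q(x) - log p(x)) for one dimension."""
    sigma = math.exp(log_var / 2)
    lo, hi = mu - 10 * sigma, mu + 10 * sigma
    dx = (hi - lo) / n_steps
    total = 0.0
    for i in range(n_steps):
        x = lo + (i + 0.5) * dx
        log_q = -((x - mu) ** 2) / (2 * sigma ** 2) - math.log(sigma) - 0.5 * math.log(2 * math.pi)
        log_p = -(x ** 2) / 2 - 0.5 * math.log(2 * math.pi)
        total += math.exp(log_q) * (log_q - log_p) * dx
    return total


mu, log_var = 1.0, math.log(0.25)   # made-up values: mean 1.0, variance 0.25
closed_form = kl_to_standard_normal([mu], [log_var])
numeric = kl_numeric_1d(mu, log_var)
print(f"closed form: {closed_form:.5f}, numeric: {numeric:.5f}")

# the beta-VAE objective weights the KL term with beta
beta = 4.0
reconstruction_loss = 10.0  # made-up number, only for illustration
loss = reconstruction_loss + beta * closed_form
print(f"loss: {loss:.5f}")
```

The term is zero exactly when the mean is 0 and the log variance is 0, that is, when the learned distribution equals the standard normal prior.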


In [None]:
lr = 0.0003
n_epochs = 30
n_epochs_max = 30
batch_size_train = 32
n_batches_max = 100000000
debug = False

assert 2 <= batch_size_train
assert batch_size_train <= 128

dataloader_train = DataLoader(dataset=dataset_train, batch_size=batch_size_train, shuffle=True, pin_memory=True)

output_images = train_vae_variant_1(vae=vae,
                                    lr=lr,
                                    n_epochs=n_epochs,
                                    n_epochs_max=n_epochs_max,
                                    dataloader_train=dataloader_train,
                                    n_batches_max=n_batches_max,
                                    debug=debug)


In [None]:
output_images_iterator = iter(output_images)
output_image_0 = output_images[0].cpu()
print(f"output_image_0.is_cuda={output_image_0.is_cuda}")
print(f"output_image_0.shape={output_image_0.shape}")
print(f"output_image_0={output_image_0}")
plt.imshow(output_image_0.permute(1, 2, 0))


In [None]:
plt.imshow(next(output_images_iterator).cpu().permute(1, 2, 0))


In [None]:
# in this code block we prepare a forward pass with test images

dataloader_test = DataLoader(dataset_test, batch_size=128, shuffle=True, pin_memory=True)
dataloader_test_iterator = iter(dataloader_test)


In [None]:
# in this code block we do a forward pass with test images

batch_of_input_images_test, _ = next(dataloader_test_iterator)
batch_of_input_images_test = batch_of_input_images_test.to(device)
batch_of_input_images_test_clone = batch_of_input_images_test.clone()
batch_of_input_images_test_clone_iterator = iter(batch_of_input_images_test_clone)
# no gradients are needed for this forward pass so autograd tracking is disabled
with torch.no_grad():
    output_images_test, mu_test, log_sigma_test = vae(batch_of_input_images_test)
output_images_test_iterator = iter(output_images_test)


In [None]:
# in this code block we show the next test input image

# given all above code block ran
# the user can run this code block to see the next input image
# and subsequently run the next code block to see the corresponding result image
# so the user may alternately run this code block and the next code block
# to get this experience

input_image_test_clone = next(batch_of_input_images_test_clone_iterator).detach().cpu()
plt.imshow(input_image_test_clone.permute(1, 2, 0))

In [None]:
# in this code block we show the next test output image on each run of this cell

output_image_test = next(output_images_test_iterator).detach().cpu()
plt.imshow(output_image_test.permute(1, 2, 0))