In [None]:
import os

# Matplotlib
%matplotlib inline
import matplotlib.pyplot as plt

import torch
from torchvision import transforms as T
import lightning.pytorch as pl
from torch.utils.data import random_split
# conflicts with fastai: DataLoader
        
# Note - you must have torchvision installed for this example
from torchvision.datasets import CIFAR10
from torchvision import transforms

import numpy as np
        
from PIL import Image

from torch import nn
from torch.nn import functional as F
from fastai.vision.all import *

from typing import List, Callable, Union, Any, TypeVar, Tuple, Dict
Tensor = TypeVar('torch.tensor')

In [None]:
def show_results( originals, reconstructed, samples ):
    """Plot a 3 x n image grid and show it.

    Top row: ``originals``; middle row: ``reconstructed``; bottom row: ``samples``.
    ``n`` is ``len(originals)``; each image is an array whose first axis is
    moved last (``np.transpose(img, (1, 2, 0))``) before ``imshow``, i.e. the
    images are expected channel-first. Axis ticks are hidden.
    """
    n_cols = len(originals)
    grid_rows = (originals, reconstructed, samples)

    plt.figure(figsize=(20, 4))
    for col in range(n_cols):
        # Draw column by column (original, reconstruction, sample) so the
        # subplots are created in the same order as before.
        for row_idx, images in enumerate(grid_rows):
            axis = plt.subplot(3, n_cols, row_idx * n_cols + col + 1)
            plt.imshow(np.transpose(images[col], (1, 2, 0)))
            axis.get_xaxis().set_visible(False)
            axis.get_yaxis().set_visible(False)
    plt.show()

def evaluate( val_dataloader, model, n, device ):
    """Reconstruct and sample images from ``model`` for a quick visual check.

    Takes the first ``n`` images of the first batch of ``val_dataloader``,
    runs them through ``model`` and draws ``n`` samples with
    ``model.sample(n, device)``.

    The model is temporarily put in eval mode and run under ``torch.no_grad()``
    so layers such as BatchNorm use, and do not overwrite, their running
    statistics; the model's previous train/eval mode is restored afterwards.

    :param val_dataloader: Iterable whose first item is a ``(features, targets)`` pair.
    :param model: Module providing ``model(x)`` and ``model.sample(n, device)``.
    :param n: Number of images to reconstruct and to sample.
    :param device: Device the model lives on; the input images are moved there.
    :return: ``(originals, reconstructed, samples)`` where ``reconstructed`` is
        the model's full output (for ``VanillaVAE``: ``[recons, input, mu, log_var]``).
    """
    batch_features, _ = next(iter(val_dataloader))
    # Inputs must live on the same device as the model's weights.
    originals = batch_features[:n].to(device)

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            reconstructed = model(originals)
            samples = model.sample(n, device)
    finally:
        # Don't leave a model that is still being trained stuck in eval mode.
        model.train(was_training)
    return originals, reconstructed, samples

In [None]:
class VanillaVAE(nn.Module):
    r"""Convolutional variational autoencoder for square images.

    Encoder: one ``Conv2d(kernel_size=3, stride=2, padding=1)`` + ``LeakyReLU``
    per entry of ``hidden_dims`` (each halves H and W), then two linear heads
    giving the mean and log-variance of a diagonal Gaussian over the latent.
    Decoder: a linear layer back to the bottleneck feature map, stride-2
    ``ConvTranspose2d`` layers mirroring the encoder, and a final block that
    ends in ``Sigmoid`` so reconstructions lie in ``[0, 1]``.
    Loss: ``MSE(recons, input) + beta * KL(q(z|x) || N(0, I))``.
    """

    def __init__(self,
                 input_size: int,
                 latent_dim: int,
                 in_channels: int = 3,
                 hidden_dims: List = None,
                 beta = 4.0,
                 **kwargs) -> None:
        """
        :param input_size: (Int) Height and width of the square input images.
            Must be divisible by ``2 ** len(hidden_dims)``.
        :param latent_dim: (Int) Dimensionality of the latent space.
        :param in_channels: (Int) Channels of the input images; the decoder
            outputs the same number of channels.
        :param hidden_dims: (List) Encoder channel widths, one stride-2 conv per
            entry. Defaults to ``[32, 64, 128, 256, 512]``. Not modified.
        :param beta: Weight of the KL-divergence term in the loss.
        :param kwargs: Accepted and ignored.
        :raises ValueError: If ``input_size`` is not divisible by
            ``2 ** len(hidden_dims)`` (encoder and dense layer sizes would disagree).
        """
        super(VanillaVAE, self).__init__()

        self.latent_dim = latent_dim
        self.kld_weight = beta
        self.img_dim = (in_channels, input_size, input_size) # For use by the TensorboardGenerativeModelImageSampler
        self.meta = {}

        if hidden_dims is None:
            hidden_dims = [32, 64, 128, 256, 512]
        # Work on a copy: the decoder construction below reverses the list and
        # must not mutate the caller's argument.
        hidden_dims = list(hidden_dims)
        # `in_channels` is reused as the running channel count while building
        # the encoder, so remember the image channel count for the output layer.
        out_channels = in_channels

        # Each encoder conv halves H and W; the dense layers assume the halving
        # is exact at every step (otherwise the conv output rounds up).
        downsample = 2 ** len(hidden_dims)
        if input_size % downsample != 0:
            raise ValueError(
                f"input_size ({input_size}) must be divisible by "
                f"2 ** len(hidden_dims) ({downsample})")
        self.final_img = int(input_size // downsample)  # spatial size after the encoder
        # Channels of the encoder's last feature map; decode() reshapes to it.
        self.bottleneck_channels = hidden_dims[-1]
        dense_calc = int(self.final_img * self.final_img * self.bottleneck_channels)

        # Build Encoder
        modules = []
        for h_dim in hidden_dims:
            modules.append(
                nn.Sequential(
                    nn.Conv2d(in_channels, out_channels=h_dim,
                              kernel_size=3, stride=2, padding=1),
                    #nn.BatchNorm2d(h_dim),
                    nn.LeakyReLU())
            )
            in_channels = h_dim

        self.encoder = nn.Sequential(*modules)
        self.fc_mu = nn.Linear(dense_calc, latent_dim)
        self.fc_var = nn.Linear(dense_calc, latent_dim)

        # Build Decoder
        modules = []

        self.decoder_input = nn.Linear(latent_dim, dense_calc)

        hidden_dims.reverse()  # safe: this is our private copy

        for i in range(len(hidden_dims) - 1):
            modules.append(
                nn.Sequential(
                    nn.ConvTranspose2d(hidden_dims[i],
                                       hidden_dims[i + 1],
                                       kernel_size=3,
                                       stride=2,
                                       padding=1,
                                       output_padding=1),
                    #nn.BatchNorm2d(hidden_dims[i + 1]),
                    nn.LeakyReLU())
            )

        self.decoder = nn.Sequential(*modules)

        # Last upsampling step plus a projection back to the image channels.
        # NOTE: this is the only BatchNorm in the network.
        self.final_layer = nn.Sequential(
                            nn.ConvTranspose2d(hidden_dims[-1],
                                               hidden_dims[-1],
                                               kernel_size=3,
                                               stride=2,
                                               padding=1,
                                               output_padding=1),
                            nn.BatchNorm2d(hidden_dims[-1]),
                            nn.LeakyReLU(),
                            nn.Conv2d(hidden_dims[-1], out_channels=out_channels,
                                      kernel_size=3, stride=1, padding=1),
                            nn.Sigmoid())

    def encode(self, input: Tensor) -> List[Tensor]:
        """
        Encodes the input by passing through the encoder network
        and returns the parameters of the latent Gaussian.
        :param input: (Tensor) Input tensor to encoder [N x C x H x W], or a
            tuple/list whose first element is that tensor (e.g. a dataloader batch).
        :return: (List[Tensor]) ``[mu, log_var]``, each [N x latent_dim]
        """
        if isinstance(input, (tuple, list)):
            input = input[0]
        result = self.encoder(input)
        result = torch.flatten(result, start_dim=1)

        # Split the result into mu and log-variance components
        # of the latent Gaussian distribution
        mu = self.fc_mu(result)
        log_var = self.fc_var(result)

        return [mu, log_var]

    def decode(self, z: Tensor) -> Tensor:
        """
        Maps the given latent codes onto the image space.
        :param z: (Tensor) [B x D]
        :return: (Tensor) [B x C x H x W]
        """
        result = self.decoder_input(z)
        # Un-flatten to the encoder's last feature-map shape.
        result = result.view(-1, self.bottleneck_channels, self.final_img, self.final_img)
        result = self.decoder(result)
        result = self.final_layer(result)
        return result

    def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
        """
        Reparameterization trick to sample from N(mu, var) from N(0,1).
        :param mu: (Tensor) Mean of the latent Gaussian [B x D]
        :param logvar: (Tensor) Log-variance of the latent Gaussian [B x D]
        :return: (Tensor) [B x D]
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return eps * std + mu

    def forward(self, input: Tensor, **kwargs) -> List[Tensor]:
        """
        Encode, sample a latent code and decode it.
        The sampled code is also stored on ``self.z``.
        :param input: (Tensor) [B x C x H x W], or a tuple/list whose first element is it
        :return: (List) ``[reconstruction, input, mu, log_var]`` (``input`` as passed in)
        """
        mu, log_var = self.encode(input)
        self.z = self.reparameterize(mu, log_var)
        return [self.decode(self.z), input, mu, log_var]

    def loss_function(self,
                      *args,
                      **kwargs) -> Tensor:
        r"""
        Computes the VAE loss from ``(forward_output, target)`` positional
        arguments (the form used with the fastai ``Learner`` in this notebook).

        KL(N(\mu, \sigma), N(0, 1)) = \log \frac{1}{\sigma} + \frac{\sigma^2 + \mu^2}{2} - \frac{1}{2}
        :param args: ``args[0]`` is the list returned by :meth:`forward`
            (indices 0, 2, 3 = reconstruction, mu, log-variance); ``args[1]``
            is the target image batch.
        :param kwargs: Ignored.
        :return: (Tensor) Scalar loss, identical to :meth:`loss_function_exp`.
        """
        forward_output, target = args[0], args[1]
        return self.loss_function_exp(target, forward_output[0], forward_output[2], forward_output[3])

    def loss_function_exp(self, target, recons, mu, log_var ) -> Tensor:
        r"""
        Computes ``MSE(recons, target) + kld_weight * KLD``.

        KL(N(\mu, \sigma), N(0, 1)) = \log \frac{1}{\sigma} + \frac{\sigma^2 + \mu^2}{2} - \frac{1}{2}
        The KL term is summed over latent dimensions and averaged over the batch.
        :param target: (Tensor) Images the reconstruction is compared to.
        :param recons: (Tensor) Reconstruction, same shape as ``target``.
        :param mu: (Tensor) Latent means [B x D]
        :param log_var: (Tensor) Latent log-variances [B x D]
        :return: (Tensor) Scalar loss.
        """
        recons_loss = F.mse_loss(recons, target)
        #recons_loss =F.binary_cross_entropy(recons, target)

        kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim=1), dim=0)

        return recons_loss + (self.kld_weight * kld_loss)

    def sample(self,
               num_samples:int,
               current_device: int, **kwargs) -> Tensor:
        """
        Samples latent codes from N(0, I) and returns the decoded images.
        :param num_samples: (Int) Number of samples
        :param current_device: Device (index, string or torch.device) to run the decoder on
        :return: (Tensor) [num_samples x C x H x W]
        """
        z = torch.randn(num_samples,
                        self.latent_dim)

        z = z.to(current_device)

        samples = self.decode(z)
        return samples

    def generate(self, x: Tensor, **kwargs) -> Tensor:
        """
        Given an input image x, returns the reconstructed image
        :param x: (Tensor) [B x C x H x W]
        :return: (Tensor) [B x C x H x W]
        """

        return self.forward(x)[0]

In [None]:
# Get CIFAR10 data
class AEDataset(torch.utils.data.Dataset):
    """Adapt a labelled ``(image, label)`` dataset for autoencoder training.

    Each item of the wrapped dataset is unpacked into exactly two values; the
    label is discarded and the image is returned as both input and target.
    """

    def __init__(self, dataset):
        # Underlying dataset yielding 2-element (image, label) items.
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        image, _label = self.dataset[index]
        pair = (image, image)
        return pair
    
class CIFAR10DataModule(pl.LightningDataModule):
    """LightningDataModule serving CIFAR-10 for autoencoder training.

    Train/val datasets are wrapped in ``AEDataset`` so each item is
    ``(image, image)``; test/predict datasets yield torchvision's
    ``(image, label)`` items. Images are only converted with ``ToTensor()``
    (no normalization).

    NOTE: the dataset attributes keep their original ``mnist_*`` names even
    though they hold CIFAR-10 data.
    """
    def __init__(self, data_dir: str = "./", batch_size=128):
        """
        :param data_dir: Directory CIFAR-10 is downloaded to and read from.
        :param batch_size: Batch size used by every dataloader.
        """
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        #self.transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
        self.transform = transforms.Compose([transforms.ToTensor()])

    def prepare_data(self):
        """Download the CIFAR-10 train and test splits into ``data_dir``."""
        CIFAR10(self.data_dir, train=True, download=True)
        CIFAR10(self.data_dir, train=False, download=True)

    def setup(self, stage: str = None):
        """Create the datasets needed for ``stage``.

        :param stage: ``"fit"``, ``"validate"``, ``"test"``, ``"predict"``, or
            ``None`` to create all of them.
        """
        # Train/val datasets: needed for fitting, and the val set is also what
        # val_dataloader() serves for a stand-alone "validate" run.
        if stage in (None, "fit", "validate"):
            self.mnist_train = AEDataset( CIFAR10(self.data_dir, train=True, transform=self.transform) )
            self.mnist_val = AEDataset( CIFAR10(self.data_dir, train=False, transform=self.transform) )
            print( f"Train dataset: {len(self.mnist_train)}" )
            print( f"Val dataset: {len(self.mnist_val)}" )

        # Assign test dataset for use in dataloader(s)
        if stage in (None, "test"):
            self.mnist_test = CIFAR10(self.data_dir, train=False, transform=self.transform)

        if stage in (None, "predict"):
            self.mnist_predict = CIFAR10(self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        # Reshuffle every epoch so batches are not drawn in a fixed order.
        return torch.utils.data.DataLoader(self.mnist_train, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return torch.utils.data.DataLoader(self.mnist_val, batch_size=self.batch_size)

    def test_dataloader(self):
        return torch.utils.data.DataLoader(self.mnist_test, batch_size=self.batch_size)

    def predict_dataloader(self):
        return torch.utils.data.DataLoader(self.mnist_predict, batch_size=self.batch_size)


In [None]:
# Hyperparameters shared by the fastai and Lightning runs below.
seed = 42
# Seed torch's global RNG: weight init, reparameterization noise, sampling and
# DataLoader shuffling all draw from it, so reruns become reproducible.
torch.manual_seed(seed)
epochs=30
lr=4.7e-4
z_dim=128          # latent dimension of the VAE
beta=0.00001       # weight of the KL term (VanillaVAE.kld_weight)
image_size=32      # VAE input_size; CIFAR-10 images are 32x32
batch_size=128

In [None]:
# Fastai's packaged copy of CIFAR-10.
# NOTE(review): `path` is not used by any later cell; training and evaluation
# read CIFAR-10 through CIFAR10DataModule (torchvision) instead.
cifar_url = URLs.CIFAR
path = untar_data(url=cifar_url)

In [None]:
# One data module feeds both the fastai Learner and the Lightning Trainer.
cifar10_lit = CIFAR10DataModule(data_dir="./", batch_size=batch_size)

In [None]:
# Download (if needed) and build the train/val datasets now: the fastai cell
# below calls train_dataloader()/val_dataloader() directly, before any
# Lightning Trainer has run setup().
cifar10_lit.prepare_data()
cifar10_lit.setup(stage="fit")

In [None]:
fastai_vae = VanillaVAE(input_size=image_size, latent_dim=z_dim, beta=beta)
fastai_data = DataLoaders(
    cifar10_lit.train_dataloader(),
    cifar10_lit.val_dataloader(),
)
learn = Learner(fastai_data, fastai_vae, loss_func=fastai_vae.loss_function)

# NOTE(review): training is disabled in this cell. To train the fastai model run
#   learn.fit_one_cycle(epochs, lr, cbs=callbacks)
# optionally with
#   callbacks = [EarlyStoppingCallback(monitor='valid_loss', min_delta=0.0, patience=5)]
callbacks = []

# Inspect the Learner's optimizer, callbacks and training loop.
print(f"{learn.opt}")
if learn.opt is not None:
    print(f"{learn.opt.hypers}")
print(f"{learn.cbs}")
learn.show_training_loop()

In [None]:
# Visual check of the fastai VAE on CPU.
# NOTE(review): the fit call in the Learner cell is commented out, so unless it
# is re-enabled this shows an untrained model.
vae = fastai_vae
device = torch.device("cpu")
vae.to(device)
o, r, s = evaluate(cifar10_lit.val_dataloader(), vae, 10, device)
# evaluate() returns the model's full output list; keep only the reconstruction.
o, r, s = [t.detach().numpy() for t in (o, r[0], s)]
show_results(o, r, s)

In [None]:
class LitVAE(pl.LightningModule):
    """Lightning wrapper that trains a ``VanillaVAE`` with Adam.

    The train/val/test steps log the VAE loss (reconstruction MSE plus
    ``beta`` times the KL term) as ``train_loss`` / ``val_loss`` / ``test_loss``.
    """
    def __init__(self, lr:float=1e-3, image_size: int=128, latent_dim: int=128, beta:float =4.0, notes: str = None):
        """
        :param lr: Adam learning rate.
        :param image_size: Height/width of the square input images.
        :param latent_dim: Size of the VAE latent space.
        :param beta: Weight of the KL term in the loss.
        :param notes: Free-form text, kept only in the saved hyperparameters.
        """
        super().__init__()
        # Store the constructor arguments in checkpoints so load_from_checkpoint()
        # rebuilds the same architecture: layer sizes depend on image_size and
        # latent_dim, so falling back to the defaults would not match the weights.
        self.save_hyperparameters()
        self.model = VanillaVAE(input_size=image_size, latent_dim=latent_dim, beta=beta)
        self.lr = lr
        self.latent_dim = latent_dim
        self.img_dim = self.model.img_dim
        self.printed = False

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr)
        #optimizer = torch.optim.SGD(self.parameters(), lr=self.lr)
        return optimizer

    def forward(self, x):
        return self.model(x)

    def decode(self, z):
        # For use by TensorboardGenerativeModelImageSampler
        return self.model.decode(z)

    def _run_one_batch(self, batch, batch_idx):
        """Run the VAE on ``batch`` and compute its loss.

        ``batch`` may be a tensor or a list/tuple whose first element is the
        image tensor; that tensor is both the model input and the target.

        :return: ``(reconstruction, loss)``
        """
        # Call the module rather than .forward() so nn.Module hooks run.
        recons, _, mu, log_var = self.model(batch)

        if isinstance(batch, (tuple, list)):
            batch = batch[0]

        loss_vae = self.model.loss_function_exp(batch, recons, mu, log_var)
        return recons, loss_vae

    def training_step(self, batch, batch_idx):
        _, train_loss = self._run_one_batch(batch, batch_idx)
        self.log('train_loss', train_loss)
        return train_loss

    def validation_step(self, batch, batch_idx):
        _, val_loss = self._run_one_batch(batch, batch_idx)
        self.log("val_loss", val_loss)
        return val_loss

    def test_step(self, batch, batch_idx):
        _, test_loss = self._run_one_batch(batch, batch_idx)
        self.log("test_loss", test_loss)


In [None]:
# Pass latent_dim explicitly so the Lightning model uses the same z_dim as the
# fastai model instead of relying on LitVAE's default happening to match.
lit_vae = LitVAE(image_size=image_size, latent_dim=z_dim, lr=lr, beta=beta)
trainer = pl.Trainer(
    max_epochs=epochs,
    accelerator="auto",
    # One GPU when available; otherwise "auto" (Trainer's documented default)
    # rather than None, which is not a documented value for `devices`.
    devices=1 if torch.cuda.is_available() else "auto",
)

trainer.fit(lit_vae, datamodule=cifar10_lit)

In [None]:
# Visual check of the Lightning-trained VAE. Use the device the model's weights
# are actually on rather than the `device` variable from the fastai cell, so
# this cell does not depend on that cell having run.
o, r, s = evaluate(cifar10_lit.val_dataloader(), lit_vae.model, 10,
                   next(lit_vae.model.parameters()).device)
# Keep only the reconstruction from the model's output list; .cpu() is a no-op
# for CPU tensors and makes .numpy() work if the model is on a GPU.
o, r, s = [t.detach().cpu().numpy() for t in (o, r[0], s)]
show_results(o, r, s)

|Hparam|Value|Fastai|Lit|
|-----|-|------|-----|
|Batch Size|128|Good|Bad|
|SGD||||
|Adam||Good|Good|