<a href="https://colab.research.google.com/github/LuhanMikaelson/ARENA_3.0/blob/main/ARENA_AE_%26_VAEs_.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Introduction

These are my implementations of an autoencoder (AE) and a variational autoencoder (VAE), written as part of the ARENA 3.0 curriculum. A portion of this notebook is boilerplate code provided by the course.  

## Setup (don't read, just run!)


In [57]:
try:
    import google.colab
    IN_COLAB = True
except:
    IN_COLAB = False

if IN_COLAB:
    # Install packages
    %pip install einops
    %pip install jaxtyping
    %pip install wandb
    %pip install torchinfo
    %pip install gdown
    %pip install datasets

    # Code to make sure output widgets display
    from google.colab import output
    output.enable_custom_widget_manager()

    # Code to download the necessary files (e.g. solutions, test funcs)
    import os, sys
    if not os.path.exists("chapter0_fundamentals"):
        !wget https://github.com/callummcdougall/ARENA_3.0/archive/refs/heads/main.zip
        !unzip /content/main.zip 'ARENA_3.0-main/chapter0_fundamentals/exercises/*'
        os.remove("/content/main.zip")
        os.rename("ARENA_3.0-main/chapter0_fundamentals", "chapter0_fundamentals")
        os.rmdir("ARENA_3.0-main")
        sys.path.insert(0, "chapter0_fundamentals/exercises")

    # Clear output
    from IPython.display import clear_output
    clear_output()
    print("Imports & installations complete!")

else:
    from IPython import get_ipython
    ipython = get_ipython()
    ipython.run_line_magic("load_ext", "autoreload")
    ipython.run_line_magic("autoreload", "2")

Imports & installations complete!


In [58]:
import sys
import torch as t
from torch import nn, optim
import einops
from einops.layers.torch import Rearrange
from tqdm import tqdm
from datasets import load_dataset
from dataclasses import dataclass, field
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Dataset, Subset
import numpy as np
from typing import Optional, Tuple, List, Literal, Union
import plotly.express as px
import torchinfo
import time
import wandb
from PIL import Image
import pandas as pd
from pathlib import Path

# Get file paths to this set of exercises
exercises_dir = Path("chapter0_fundamentals/exercises")
section_dir = exercises_dir / "part5_gans_and_vaes"

from plotly_utils import imshow
from part2_cnns.utils import print_param_count
import part5_gans_and_vaes.tests as tests
import part5_gans_and_vaes.solutions as solutions

from part2_cnns.solutions import (
    Linear,
    ReLU,
    Sequential,
    BatchNorm2d,
)
from part2_cnns.solutions_bonus import (
    pad1d,
    pad2d,
    conv1d_minimal,
    conv2d_minimal,
    Conv2d,
    Pair,
    IntOrPair,
    force_pair,
)

device = t.device("cuda" if t.cuda.is_available() else "cpu")

<details>
<summary>Help - I get a NumPy-related error</summary>

This is an annoying colab-related issue which I haven't been able to find a satisfying fix for. If you restart runtime (but don't delete runtime), and run just the imports cell above again (but not the `%pip install` cell), the problem should go away.
</details>


In [59]:
class ConvTranspose2d(nn.Module):
    '''
    Transposed 2D convolution without a bias term
    (same as torch.nn.ConvTranspose2d with bias=False).
    '''

    def __init__(
        self, in_channels: int, out_channels: int, kernel_size: IntOrPair, stride: IntOrPair = 1, padding: IntOrPair = 0
    ):
        '''
        Stores the layer hyperparameters and creates the weight parameter of shape
        (in_channels, out_channels, kernel_height, kernel_width).
        '''
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding

        # Weights are drawn uniformly from [-bound, bound], where
        # bound = 1 / sqrt(out_channels * kernel_height * kernel_width)
        kernel_pair = force_pair(kernel_size)
        bound = 1 / (self.out_channels * kernel_pair[0] * kernel_pair[1]) ** 0.5
        uniform_pm_one = 2 * t.rand(in_channels, out_channels, *kernel_pair) - 1

        self.weight = nn.Parameter(bound * uniform_pm_one)

    def forward(self, x: t.Tensor) -> t.Tensor:
        '''Applies the transposed convolution to `x` using the stored weight, stride and padding.'''
        return solutions.conv_transpose2d(x, self.weight, self.stride, self.padding)

    def extra_repr(self) -> str:
        '''Lists the layer hyperparameters as comma-separated `name=value` pairs.'''
        shown = ("in_channels", "out_channels", "kernel_size", "stride", "padding")
        return ", ".join(f"{name}={getattr(self, name)}" for name in shown)

[These visualisations](https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md) (linked in the reading material) may also help build intuition for the transposed convolution module.

In [60]:
class Tanh(nn.Module):
    '''Elementwise hyperbolic tangent, built from exponentials.'''

    def forward(self, x: t.Tensor) -> t.Tensor:
        '''
        Returns tanh(x) for every element of `x`.

        Uses tanh(x) = sign(x) * (1 - e^(-2|x|)) / (1 + e^(-2|x|)). The exponent is never
        positive, so it cannot overflow. The direct form (e^x - e^-x) / (e^x + e^-x)
        evaluates to inf / inf = nan once |x| is large enough for exp to overflow.
        '''
        # expm1 keeps precision for small |x|, where 1 - e^(-2|x|) would suffer cancellation
        exp_minus_one = t.expm1(-2 * x.abs())  # e^(-2|x|) - 1, lies in [-1, 0]
        return t.sign(x) * (-exp_minus_one) / (exp_minus_one + 2)

tests.test_Tanh(Tanh)

All tests in `test_Tanh` passed.


In [61]:
class LeakyReLU(nn.Module):
    '''Leaky rectifier: positive inputs pass through, negative inputs are scaled by `negative_slope`.'''

    def __init__(self, negative_slope: float = 0.01):
        super().__init__()
        self.negative_slope = negative_slope

    def forward(self, x: t.Tensor) -> t.Tensor:
        '''Returns max(0, x) + negative_slope * min(0, x), computed elementwise.'''
        zero = t.tensor(0)
        positive_part = t.max(zero, x)
        negative_part = t.min(zero, x)
        return positive_part + self.negative_slope * negative_part

    def extra_repr(self) -> str:
        '''Describes the configured slope for the module repr.'''
        return f"LeakyReLU negative_slope: {self.negative_slope}"

tests.test_LeakyReLU(LeakyReLU)

All tests in `test_LeakyReLU` passed.


In [62]:
class Sigmoid(nn.Module):
    '''Elementwise logistic sigmoid.'''

    def forward(self, x: t.Tensor) -> t.Tensor:
        '''Returns 1 / (1 + e^(-x)) for every element of `x`.'''
        denominator = 1 + t.exp(-x)
        return 1 / denominator

tests.test_Sigmoid(Sigmoid)

All tests in `test_Sigmoid` passed.


## Loading data


In [63]:
# %pip install datasets
from datasets import load_dataset

# Load the dataset
dataset = load_dataset("nielsr/CelebA-faces")
print("Dataset loaded.")

# Create path to save the data
celeb_data_dir = section_dir / "data" / "celeba" / "img_align_celeba"
if not celeb_data_dir.exists():
    # `Path.mkdir` is used rather than `os.makedirs`: `os` is only imported inside the
    # Colab branch of the setup cell, so it is not guaranteed to be in scope here.
    celeb_data_dir.mkdir(parents=True)

    # Iterate over the dataset and save each image as a zero-padded, numbered JPEG
    train_split = dataset["train"]
    for idx, item in tqdm(enumerate(train_split), total=len(train_split), desc="Saving individual images..."):
        # The image is already a JpegImageFile, so we can directly save it
        item["image"].save(celeb_data_dir / f"{idx:06}.jpg")

    print("All images have been saved.")



The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.



Downloading metadata:   0%|          | 0.00/667 [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/462M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/463M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/463M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/202599 [00:00<?, ? examples/s]

Dataset loaded.


Saving individual images...: 100%|██████████| 202599/202599 [03:50<00:00, 877.41it/s] 

All images have been saved.





The first two arguments of `ImageFolder` are `root` (specifying the filepath for the root directory containing the data), and `transform` (which is a transform object that gets applied to each image).

The function below allows you to load in either the Celeb-A or MNIST data.

In [64]:
def get_dataset(dataset: Literal["MNIST", "CELEB"], train: bool = True) -> Dataset:
    '''
    Returns the requested image dataset with its preprocessing transform attached.

    Args:
        dataset: which dataset to load, either "MNIST" or "CELEB".
        train: for MNIST, selects the training split (True) or the test split (False).
            CelebA only has a training set, so this must be True for "CELEB".

    Returns:
        A torchvision dataset. CelebA images are resized and centre-cropped to 64x64 and
        normalised with mean 0.5 and std 0.5 per channel. MNIST images are resized to
        28x28 and normalised with mean 0.1307 and std 0.3081.
    '''
    assert dataset in ["MNIST", "CELEB"]

    if dataset == "CELEB":
        image_size = 64
        assert train, "CelebA dataset only has a training set"
        transform = transforms.Compose([
            transforms.Resize(image_size),
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])
        data = datasets.ImageFolder(
            root = exercises_dir / "part5_gans_and_vaes" / "data" / "celeba",
            transform = transform
        )

    elif dataset == "MNIST":
        image_size = 28
        transform = transforms.Compose([
            transforms.Resize(image_size),
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])
        data = datasets.MNIST(
            root = exercises_dir / "part5_gans_and_vaes" / "data",
            # Forward the requested split; without this argument torchvision always
            # returns the training split, whatever the caller asked for.
            train = train,
            transform = transform,
            download = True,
        )

    return data

We've also given you some code for visualising your data. You should run this code to make sure your data is correctly loaded in.

In [65]:
def display_data(x: t.Tensor, nrows: int, title: str):
    '''Shows a batch of images as one tiled picture with `nrows` rows of images, using plotly.'''
    # Tile the batch into a single image; squeeze drops the channel axis when the data is monochrome
    tiled = einops.rearrange(x, "(b1 b2) c h w -> (b1 h) (b2 w) c", b1=nrows).squeeze()

    # Rescale pixel values into the 0-1 range
    lowest = tiled.min()
    tiled = (tiled - lowest) / (tiled.max() - lowest)

    # A 2D array means a single-channel image, which is plotted as a binary string
    is_monochrome = tiled.ndim == 2
    imshow(
        tiled,
        binary_string=is_monochrome,
        height=50 * (nrows + 5),
        title=title + f"<br>single input shape = {x[0].shape}",
    )

In [66]:
# Load in MNIST, get first batch from dataloader, and display
# Sanity check: load MNIST, take the first batch of 64 images, and show them in 8 rows
trainset_mnist = get_dataset("MNIST")
mnist_loader = DataLoader(trainset_mnist, batch_size=64)
first_batch = next(iter(mnist_loader))
x = first_batch[0]  # index 0 holds the images; the labels are not needed here
display_data(x, nrows=8, title="MNIST data")

# Autoencoders & VAEs

## Autoencoders



<img src="https://raw.githubusercontent.com/callummcdougall/computational-thread-art/master/example_images/misc/gan_images/ae-diagram-l.png" width="700">
                


### Autoencoder architecture


<img src="https://raw.githubusercontent.com/callummcdougall/computational-thread-art/master/example_images/misc/gan_images/ae-help-10.png" width="1100">


In [67]:
class Autoencoder(nn.Module):
    '''
    Convolutional autoencoder for single-channel images.

    The encoder applies two stride-2 convolutions (1 -> 16 -> 32 channels), flattens the
    result, and maps it through a hidden layer to a `latent_dim_size` vector. The decoder
    mirrors this with two linear layers followed by two transposed convolutions.
    The flattened size is fixed at 32 * 7 * 7, which corresponds to 28x28 inputs such as MNIST.
    '''

    def __init__(self, latent_dim_size: int, hidden_dim_size: int):
        '''
        Args:
            latent_dim_size: size of the bottleneck vector produced by the encoder.
            hidden_dim_size: width of the fully connected hidden layers.
        '''
        super().__init__()
        self.latent_dim_size = latent_dim_size
        self.hidden_dim_size = hidden_dim_size

        # Plain Python ints, not tensor elements: the layers store these values as
        # attributes and use them in their weight initialisation and repr.
        in_channels, mid_channels, out_channels = 1, 16, 32
        feature_map_size = 7
        flat_size = out_channels * feature_map_size * feature_map_size

        self.encoder = nn.Sequential(
            Conv2d(in_channels, mid_channels, 4, 2, 1),
            ReLU(),
            Conv2d(mid_channels, out_channels, 4, 2, 1),
            ReLU(),
            Rearrange('b c h w -> b (c h w)'),
            Linear(flat_size, self.hidden_dim_size),
            ReLU(),
            Linear(self.hidden_dim_size, self.latent_dim_size),
        )
        self.decoder = nn.Sequential(
            Linear(self.latent_dim_size, self.hidden_dim_size),
            ReLU(),
            Linear(self.hidden_dim_size, flat_size),
            Rearrange('b (c h w) -> b c h w', c=out_channels, h=feature_map_size, w=feature_map_size),
            ReLU(),
            ConvTranspose2d(out_channels, mid_channels, 4, 2, 1),
            ReLU(),
            ConvTranspose2d(mid_channels, in_channels, 4, 2, 1),
        )

    def forward(self, x: t.Tensor) -> t.Tensor:
        '''Encodes `x` to a latent vector and decodes it back, returning the reconstruction.'''
        latent = self.encoder(x)
        return self.decoder(latent)

In [68]:
# Build the reference solution and my implementation with identical hyperparameters,
# then print a side-by-side comparison of their parameters.
soln_Autoencoder = solutions.Autoencoder(latent_dim_size=5, hidden_dim_size=128)
my_Autoencoder = Autoencoder(latent_dim_size=5, hidden_dim_size=128)

print_param_count(my_Autoencoder, soln_Autoencoder)

Model 1, total params = 421413
Model 2, total params = 421413
All parameter counts match!


Unnamed: 0,name_1,shape_1,num_params_1,num_params_2,shape_2,name_2
0,encoder.0.weight,"(16, 1, 4, 4)",256,256,"(16, 1, 4, 4)",encoder.0.weight
1,encoder.2.weight,"(32, 16, 4, 4)",8192,8192,"(32, 16, 4, 4)",encoder.2.weight
2,encoder.5.weight,"(128, 1568)",200704,200704,"(128, 1568)",encoder.5.weight
3,encoder.5.bias,"(128,)",128,128,"(128,)",encoder.5.bias
4,encoder.7.weight,"(5, 128)",640,640,"(5, 128)",encoder.7.weight
5,encoder.7.bias,"(5,)",5,5,"(5,)",encoder.7.bias
6,decoder.0.weight,"(128, 5)",640,640,"(128, 5)",decoder.0.weight
7,decoder.0.bias,"(128,)",128,128,"(128,)",decoder.0.bias
8,decoder.2.weight,"(1568, 128)",200704,200704,"(1568, 128)",decoder.2.weight
9,decoder.2.bias,"(1568,)",1568,1568,"(1568,)",decoder.2.bias


In [69]:
# Collect one example image per digit (0-9) from the dataset returned by
# get_dataset("MNIST", train=False); these fixed images are reused at every evaluation.
testset = get_dataset("MNIST", train=False)
first_image_per_digit = {}
for data, target in DataLoader(testset, batch_size=1):
    digit = target.item()
    if digit in first_image_per_digit:
        continue
    first_image_per_digit[digit] = data.squeeze()
    if len(first_image_per_digit) == 10:
        break
# Stack in digit order and add a channel axis, giving one single-channel image per digit
HOLDOUT_DATA = t.stack([first_image_per_digit[digit] for digit in range(10)])
HOLDOUT_DATA = HOLDOUT_DATA.to(dtype=t.float, device=device).unsqueeze(1)

In [70]:
@dataclass
class AutoencoderArgs():
    '''Hyperparameters and logging settings for an autoencoder training run.'''
    # Size of the latent (bottleneck) vector and of the fully connected hidden layers
    latent_dim_size: int = 5
    hidden_dim_size: int = 128
    # Which dataset to train on
    dataset: Literal["MNIST", "CELEB"] = "MNIST"
    batch_size: int = 512
    epochs: int = 10
    # Adam learning rate and (beta1, beta2) coefficients
    lr: float = 1e-3
    betas: Tuple[float, float] = (0.5, 0.999)
    # Minimum wall-clock time between two evaluations on the holdout images
    seconds_between_eval: int = 5
    # Project and run name passed to `wandb.init`
    wandb_project: Optional[str] = 'day5-ae-mnist'
    wandb_name: Optional[str] = None

In [71]:
class AutoencoderTrainer:
    '''Trains an `Autoencoder` with Adam on a mean-squared reconstruction loss, logging to wandb.'''

    def __init__(self, args: AutoencoderArgs):
        '''Builds the dataset, dataloader, model, optimizer and loss function from `args`.'''
        self.args = args
        self.trainset = get_dataset(args.dataset)
        self.trainloader = DataLoader(self.trainset, batch_size=args.batch_size, shuffle=True)
        self.model = Autoencoder(
            latent_dim_size = args.latent_dim_size,
            hidden_dim_size = args.hidden_dim_size,
        ).to(device)
        self.optimizer = t.optim.Adam(self.model.parameters(), lr=args.lr, betas=args.betas)
        self.criterion = nn.MSELoss()
        # Number of training examples seen so far, used as the wandb step. Initialised
        # here so that `evaluate` does not fail if it is called before `train`.
        self.step = 0

    def training_step(self, img: t.Tensor) -> t.Tensor:
        '''
        Performs a training step on the batch of images in `img`. Returns the loss.
        '''
        # Call the module rather than `.forward` so that any registered hooks still run
        reconstruction = self.model(img)
        loss = self.criterion(reconstruction, img)
        loss.backward()
        self.optimizer.step()
        self.optimizer.zero_grad()
        return loss

    @t.inference_mode()
    def evaluate(self) -> None:
        '''
        Evaluates model on holdout data, logs to weights & biases.
        '''
        reconstructions = self.model(HOLDOUT_DATA).cpu().numpy()
        images = [wandb.Image(reconstruction) for reconstruction in reconstructions]
        wandb.log({"images": images}, step=self.step)

    def train(self) -> None:
        '''
        Performs a full training run, logging to wandb.
        '''
        self.step = 0
        last_log_time = time.time()
        wandb.init(project=self.args.wandb_project, name=self.args.wandb_name)

        # `finally` closes the wandb run even when training is interrupted or raises
        try:
            for epoch in range(self.args.epochs):

                progress_bar = tqdm(self.trainloader, total=int(len(self.trainloader)))

                for img, _ in progress_bar:  # the labels are not used

                    img = img.to(device)
                    loss = self.training_step(img)
                    wandb.log(dict(loss=loss), step=self.step)

                    # Update progress bar
                    self.step += img.shape[0]
                    progress_bar.set_description(f"{epoch=}, {loss=:.4f}, examples_seen={self.step}")

                    # Evaluate model on the same holdout data
                    if time.time() - last_log_time > self.args.seconds_between_eval:
                        last_log_time = time.time()
                        self.evaluate()
        finally:
            wandb.finish()


# Train the autoencoder with the default hyperparameters.
# NOTE: `args` and `trainer` are notebook-level globals; later cells rebind them.
args = AutoencoderArgs()
trainer = AutoencoderTrainer(args)
trainer.train()

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

epoch=0, loss=0.4004, examples_seen=60000: 100%|██████████| 118/118 [10:09<00:00,  5.17s/it]
epoch=1, loss=0.3853, examples_seen=120000: 100%|██████████| 118/118 [10:22<00:00,  5.27s/it]


VBox(children=(Label(value='1.087 MB of 1.087 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
loss,█▇▅▅▄▄▄▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
loss,0.38526


After ten epochs, you should be able to get output of the following quality:

<img src="https://raw.githubusercontent.com/callummcdougall/computational-thread-art/master/example_images/misc/gan_images/autoencoder_2.png" width="700">

This is a pretty faithful representation. Note how it's mixing up features for some of the numbers - for instance, the 5 seems to have been partly reproduced as a 9. But overall, it seems pretty accurate!

## Generating images from an autoencoder

We'll now return to the issue we mentioned briefly earlier - how to generate output? This was easy for our GAN; the only way we ever produced output was by putting random noise into the generator. But how should we interpret the latent space between our encoder and decoder?

We can try and plot the outputs produced by the decoder over a range. The code below does this for the model in the solutions (you might have to make some small adjustments depending on exactly how you've implemented your autoencoder).

In [None]:
@t.inference_mode()
def visualise_output(
    model: Autoencoder,
    n_points: int = 11,
    interpolation_range: Tuple[float, float] = (-3, 3),
) -> None:
    '''
    Plots a grid of decoder outputs obtained by sweeping the first two latent dimensions
    over `interpolation_range` (`n_points` values each), with all other dimensions at zero.
    '''
    # Values taken by each of the two swept latent dimensions
    sweep = t.linspace(*interpolation_range, n_points).to(device)

    # One latent vector per grid cell: dimension 0 follows the grid's first axis,
    # dimension 1 follows its second axis
    latents = t.zeros(n_points**2, model.latent_dim_size).to(device)
    latents[:, 0] = einops.repeat(sweep, "dim1 -> (dim1 dim2)", dim2=n_points)
    latents[:, 1] = einops.repeat(sweep, "dim2 -> (dim1 dim2)", dim1=n_points)

    # Decode every grid cell
    decoded = model.decoder(latents).cpu().numpy()

    # Undo the MNIST normalisation (std 0.3081, mean 0.1307) and clamp to valid pixel values
    pixels = np.clip((decoded * 0.3081) + 0.1307, 0, 1)

    # Lay the individual images out as one large 2D image
    grid_image = einops.rearrange(
        pixels,
        "(dim1 dim2) 1 height width -> (dim1 height) (dim2 width)",
        dim1=n_points
    )

    # One tick at the centre of each 28-pixel tile, labelled with the latent value
    tick_positions = list(range(14, 14+28*n_points, 28))
    tick_labels = [f"{value:.2f}" for value in sweep]

    figure = px.imshow(
        grid_image,
        color_continuous_scale="greys_r",
        title="Decoder output from varying first principal components of latent space"
    )
    figure.update_layout(
        xaxis=dict(title_text="dim1", tickmode="array", tickvals=tick_positions, ticktext=tick_labels),
        yaxis=dict(title_text="dim2", tickmode="array", tickvals=list(tick_positions), ticktext=list(tick_labels)),
    )
    figure.show()


visualise_output(trainer.model)

### Reparameterisation trick



<img src="https://raw.githubusercontent.com/callummcdougall/computational-thread-art/master/example_images/misc/gan_images/vae-reparam-l.png" width="800">



## Building a VAE



<img src="https://raw.githubusercontent.com/callummcdougall/computational-thread-art/master/example_images/misc/gan_images/ae-before-after.png" width="750">

In [None]:
class VAE(nn.Module):
    '''
    Convolutional variational autoencoder for single-channel images.

    Same layout as the plain autoencoder, except that the encoder emits
    2 * `latent_dim_size` values per input: a log standard deviation and a mean for every
    latent dimension. The flattened size is fixed at 32 * 7 * 7, which corresponds to
    28x28 inputs such as MNIST.
    '''
    encoder: nn.Module
    decoder: nn.Module

    def __init__(self, latent_dim_size: int, hidden_dim_size: int):
        '''
        Args:
            latent_dim_size: number of latent dimensions.
            hidden_dim_size: width of the fully connected hidden layers.
        '''
        super().__init__()
        self.latent_dim_size = latent_dim_size
        self.hidden_dim_size = hidden_dim_size

        # Plain Python ints, not tensor elements: the layers store these values as
        # attributes and use them in their weight initialisation and repr.
        in_channels, mid_channels, out_channels = 1, 16, 32
        feature_map_size = 7
        flat_size = out_channels * feature_map_size * feature_map_size

        self.encoder = nn.Sequential(
            Conv2d(in_channels, mid_channels, 4, 2, 1),
            ReLU(),
            Conv2d(mid_channels, out_channels, 4, 2, 1),
            ReLU(),
            Rearrange('b c h w -> b (c h w)'),
            Linear(flat_size, self.hidden_dim_size),
            ReLU(),
            Linear(self.hidden_dim_size, self.latent_dim_size * 2),
        )
        self.decoder = nn.Sequential(
            Linear(self.latent_dim_size, self.hidden_dim_size),
            ReLU(),
            Linear(self.hidden_dim_size, flat_size),
            Rearrange('b (c h w) -> b c h w', c=out_channels, h=feature_map_size, w=feature_map_size),
            ReLU(),
            ConvTranspose2d(out_channels, mid_channels, 4, 2, 1),
            ReLU(),
            ConvTranspose2d(mid_channels, in_channels, 4, 2, 1),
        )

    def sample_latent_vector(self, x: t.Tensor) -> Tuple[t.Tensor, t.Tensor, t.Tensor]:
        '''
        Passes `x` through the encoder and samples a latent vector with the
        reparameterisation trick. Returns `(z, mu, logsigma)`: the sampled latent vector,
        then the mean and the log std dev of the latent distribution. This function can be
        used in `forward`, but also used on its own to generate samples for evaluation.
        '''
        raw_output = self.encoder(x)
        # Split the encoder output into its two halves; the first half is read as
        # log(sigma) and the second as mu
        stats = einops.rearrange(
            raw_output, 'b (distribution latent_dim) -> b distribution latent_dim', distribution=2
        )
        logsigma = stats[:, 0]
        mu = stats[:, 1]
        sigma = logsigma.exp()

        # Independent standard-normal noise for every sample and latent dimension.
        # `randn_like` also matches the dtype and device of `mu`.
        eps = t.randn_like(mu)
        z = mu + sigma * eps
        return (z, mu, logsigma)

    def forward(self, x: t.Tensor) -> Tuple[t.Tensor, t.Tensor, t.Tensor]:
        '''
        Passes `x` through the encoder and decoder. Returns the reconstructed input, as well
        as mu and logsigma.
        '''
        z, mu, logsigma = self.sample_latent_vector(x)
        x_prime = self.decoder(z)
        return (x_prime, mu, logsigma)

In [None]:
# Instantiate a small VAE and print a layer-by-layer summary for a single MNIST image
trainset = get_dataset("MNIST")
model = VAE(latent_dim_size=5, hidden_dim_size=100).to(device)

# trainset[0][0] is the first image; unsqueeze(0) adds the batch dimension
print(torchinfo.summary(model, input_data=trainset[0][0].unsqueeze(0).to(device)))

Layer (type:depth-idx)                   Output Shape              Param #
VAE                                      [5, 1, 28, 28]            --
├─Sequential: 1-1                        [1, 10]                   --
│    └─Conv2d: 2-1                       [1, 16, 14, 14]           256
│    └─ReLU: 2-2                         [1, 16, 14, 14]           --
│    └─Conv2d: 2-3                       [1, 32, 7, 7]             8,192
│    └─ReLU: 2-4                         [1, 32, 7, 7]             --
│    └─Rearrange: 2-5                    [1, 1568]                 --
│    └─Linear: 2-6                       [1, 100]                  156,900
│    └─ReLU: 2-7                         [1, 100]                  --
│    └─Linear: 2-8                       [1, 10]                   1,010
├─Sequential: 1-2                        [5, 1, 28, 28]            --
│    └─Linear: 2-9                       [5, 100]                  600
│    └─ReLU: 2-10                        [5, 100]                  --
│ 

You can also compare your architecture to the solution with `print_param_count`, as was done for the autoencoder above, but this might be less informative if your model doesn't implement the 2-variables approach in the same way as the solution does.


## New loss function

### Beta-VAEs

The Beta-VAE is a very simple extension of the VAE, with a different loss function: we multiply the KL Divergence term by a constant $\beta$. This helps us balance the two different loss terms. For instance, I found using $\beta = 0.1$ gave better results than the default $\beta = 1$.

In [None]:
@dataclass
class VAEArgs(AutoencoderArgs):
    '''Training arguments for the VAE: the autoencoder arguments plus a weight for the KL term.'''
    # Overrides the wandb project name inherited from AutoencoderArgs
    wandb_project: Optional[str] = 'day5-vae-mnist'
    # Multiplier applied to the KL-divergence term of the loss (the beta of a Beta-VAE)
    beta_kl: float = 0.1


class VAETrainer:
    '''
    Trains a `VAE` with Adam, logging to wandb. The loss is the mean-squared reconstruction
    error plus `beta_kl` times the KL divergence between the encoder's latent distribution
    and a standard normal.
    '''

    def __init__(self, args: VAEArgs):
        '''Builds the dataset, dataloader, model, optimizer and reconstruction loss from `args`.'''
        self.args = args
        self.trainset = get_dataset(args.dataset)
        self.trainloader = DataLoader(self.trainset, batch_size=args.batch_size, shuffle=True)
        self.model = VAE(
            latent_dim_size = args.latent_dim_size,
            hidden_dim_size = args.hidden_dim_size,
        ).to(device)
        self.optimizer = t.optim.Adam(self.model.parameters(), lr=args.lr, betas=args.betas)
        self.criterion = nn.MSELoss()
        # Number of training examples seen so far, used as the wandb step. Initialised
        # here so that `evaluate` does not fail if it is called before `train`.
        self.step = 0

    def training_step(self, img: t.Tensor) -> t.Tensor:
        '''
        Performs a training step on the batch of images in `img`. Returns the loss.
        '''
        # Call the module rather than `.forward` so that any registered hooks still run
        img_prime, mu, logsigma = self.model(img)
        reconstruction_loss = self.criterion(img_prime, img)

        # KL(N(mu, sigma^2) || N(0, 1)) = (sigma^2 + mu^2 - 1) / 2 - log(sigma), averaged over
        # the batch and latent dimensions. The weight is read from this trainer's own
        # arguments, not from whatever global `args` happens to be defined in the notebook.
        kl_div = 0.5 * (mu ** 2 + t.exp(2 * logsigma) - 1) - logsigma
        kl_div_loss = kl_div.mean() * self.args.beta_kl

        loss = reconstruction_loss + kl_div_loss
        loss.backward()
        self.optimizer.step()
        self.optimizer.zero_grad()
        return loss

    @t.inference_mode()
    def evaluate(self) -> None:
        '''
        Evaluates model on holdout data, logs to weights & biases.
        '''
        # The model returns (reconstruction, mu, logsigma); only the reconstruction is logged
        reconstructions = self.model(HOLDOUT_DATA)[0].cpu().numpy()
        images = [wandb.Image(reconstruction) for reconstruction in reconstructions]
        wandb.log({"images": images}, step=self.step)

    def train(self) -> None:
        '''
        Performs a full training run, logging to wandb.
        '''
        self.step = 0
        last_log_time = time.time()
        wandb.init(project=self.args.wandb_project, name=self.args.wandb_name)

        # `finally` closes the wandb run even when training is interrupted or raises
        try:
            for epoch in range(self.args.epochs):

                progress_bar = tqdm(self.trainloader, total=int(len(self.trainloader)))

                for img, _ in progress_bar:  # the labels are not used

                    img = img.to(device)
                    loss = self.training_step(img)
                    wandb.log(dict(loss=loss), step=self.step)

                    # Update progress bar
                    self.step += img.shape[0]
                    progress_bar.set_description(f"{epoch=}, {loss=:.4f}, examples_seen={self.step}")

                    # Evaluate model on the same holdout data
                    if time.time() - last_log_time > self.args.seconds_between_eval:
                        last_log_time = time.time()
                        self.evaluate()
        finally:
            wandb.finish()


# Train a VAE with a 10-dimensional latent space and 100-unit hidden layers
# (all other settings keep their VAEArgs defaults)
args = VAEArgs(latent_dim_size=10, hidden_dim_size=100)
trainer = VAETrainer(args)
trainer.train()

VBox(children=(Label(value='0.002 MB of 0.010 MB uploaded\r'), FloatProgress(value=0.15143531976744187, max=1.…

0,1
loss,▁

0,1
loss,0.99538


epoch=0, loss=0.7012, examples_seen=60000: 100%|██████████| 118/118 [10:01<00:00,  5.10s/it]
epoch=1, loss=0.6547, examples_seen=120000: 100%|██████████| 118/118 [10:16<00:00,  5.22s/it]
epoch=2, loss=0.5775, examples_seen=180000: 100%|██████████| 118/118 [10:21<00:00,  5.27s/it]
epoch=3, loss=0.5622, examples_seen=240000: 100%|██████████| 118/118 [10:26<00:00,  5.31s/it]
epoch=4, loss=0.5790, examples_seen=300000: 100%|██████████| 118/118 [10:21<00:00,  5.27s/it]
epoch=5, loss=0.5125, examples_seen=360000: 100%|██████████| 118/118 [10:19<00:00,  5.25s/it]
epoch=6, loss=0.4691, examples_seen=420000: 100%|██████████| 118/118 [10:23<00:00,  5.29s/it]
epoch=7, loss=0.4075, examples_seen=480000: 100%|██████████| 118/118 [10:18<00:00,  5.24s/it]
epoch=8, loss=0.4810, examples_seen=540000: 100%|██████████| 118/118 [10:15<00:00,  5.21s/it]
epoch=9, loss=0.4837, examples_seen=600000: 100%|██████████| 118/118 [10:12<00:00,  5.19s/it]


VBox(children=(Label(value='4.970 MB of 4.970 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
loss,█▆▆▅▅▅▆▅▅▄▄▃▄▃▃▃▂▃▂▃▂▂▃▂▃▃▂▂▁▁▂▂▂▁▁▁▁▂▁▁

0,1
loss,0.48367


In [None]:
#t.save(model.state_dict(), 'vae_model_weights.pth')

In [None]:
@t.inference_mode()
def visualise_output(
    model: VAE,
    n_points: int = 11,
    interpolation_range: Tuple[float, float] = (-3, 3),
) -> None:
    '''
    Plots a grid of decoder outputs obtained by sweeping the first two latent dimensions
    over `interpolation_range` (`n_points` values each), with all other dimensions at zero.
    '''
    # Values taken by each of the two swept latent dimensions
    sweep = t.linspace(*interpolation_range, n_points).to(device)

    # One latent vector per grid cell: dimension 0 follows the grid's first axis,
    # dimension 1 follows its second axis
    latents = t.zeros(n_points**2, model.latent_dim_size).to(device)
    latents[:, 0] = einops.repeat(sweep, "dim1 -> (dim1 dim2)", dim2=n_points)
    latents[:, 1] = einops.repeat(sweep, "dim2 -> (dim1 dim2)", dim1=n_points)

    # Decode every grid cell
    decoded = model.decoder(latents).cpu().numpy()

    # Undo the MNIST normalisation (std 0.3081, mean 0.1307) and clamp to valid pixel values
    pixels = np.clip((decoded * 0.3081) + 0.1307, 0, 1)

    # Lay the individual images out as one large 2D image
    grid_image = einops.rearrange(
        pixels,
        "(dim1 dim2) 1 height width -> (dim1 height) (dim2 width)",
        dim1=n_points
    )

    # One tick at the centre of each 28-pixel tile, labelled with the latent value
    tick_positions = list(range(14, 14+28*n_points, 28))
    tick_labels = [f"{value:.2f}" for value in sweep]

    figure = px.imshow(
        grid_image,
        color_continuous_scale="greys_r",
        title="Decoder output from varying first principal components of latent space"
    )
    figure.update_layout(
        xaxis=dict(title_text="dim1", tickmode="array", tickvals=tick_positions, ticktext=tick_labels),
        yaxis=dict(title_text="dim2", tickmode="array", tickvals=list(tick_positions), ticktext=list(tick_labels)),
    )
    figure.show()


visualise_output(trainer.model)