# CNN Playbook

This notebook introduces convolutional neural networks (CNNs), building on the fully connected (DNN) layers we have already used, and examines how they work.

We will then build a variational auto-encoder (VAE) for the MNIST handwritten-digit dataset and examine different aspects of this dataset up close.

## Part 1 2D Sobel filters

This section is about building the 2D Sobel filters and applying them to an input image.

## Part 2 PyTorch 2D CNN filters

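This section covers building 2D convolutional filters with PyTorch and examining how kernel size, stride and padding change their output.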

## Part 3 Building a VAE model

This section is about building a VAE model using the PyTorch API.

## Part 4 Training our VAE with Test/Validate

This section is about training the VAE model, again using PyTorch, but now with a validation dataset.

## Part 5 Load/Save model

This section introduces saving a trained model to disk and loading it back.

## Part 6 Examining the VAE Latent-Space

This section is about taking the trained VAE model and examining the structure of its latent space.

## Bonus 7 What is required to train over the CIFAR10 dataset

Once you have built a VAE model that can be trained on the black-and-white MNIST digit dataset, a common "next step up" is the CIFAR10 or CIFAR100 dataset, which contains 3x32x32 RGB images.

The approach is the same as before, but what reconstruction quality do you think you can achieve by training on this dataset?

## Marking

You will get marks for completing the different tasks within this notebook.

Any code you are expected to complete contains `## FINISH_ME ##`, indicating that the code isn't expected to run until you have completed it.

I would recommend tackling the playbook in order: Part 1 -> Part 2 -> Part 3 -> Part 4, and so on, since later parts build on earlier ones.


| Title | Parts | Number of marks |
| :------------------------------------ | :---- | :-- |
| 1. Construct 2D Sobel Filters and apply | 2 | 1 |
| 2. Constructing different sized 2D filters and examine output | 1 | 1 |
| 3. Building the VAE model | 3 | 3 |
| 4. Train the VAE model | 2 | 2 |
| 5. Load/Save a model to disk | 1 | 1 |
| 6. Examine the trained VAE latent space | 1 | 2 |
| **Bonus 1:** Training a VAE over the CIFAR10 dataset | 4 | 1 |
| **Total** | | max **10** |

# Part 0

Import the libraries needed to run the notebook.

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split
import matplotlib.pyplot as plt
import random
import itertools
import numpy as np
from mpl_toolkits.mplot3d import Axes3D

In [None]:
# Reproducibility in science is critical; in computing it's often just a convenience
_FIXED_SEED=12345
random.seed(_FIXED_SEED)
np.random.seed(_FIXED_SEED)

# This is the collection of global seeds needed to make everything reproducible
torch.manual_seed(_FIXED_SEED)  # PyTorch CPU

# Ensure reproducibility on Metal (MPS)
if torch.backends.mps.is_available():
    torch.mps.manual_seed(_FIXED_SEED)  # Fix seed for MPS backend

if torch.cuda.is_available():
    torch.cuda.manual_seed(_FIXED_SEED)  # PyTorch GPU (if used)
    torch.cuda.manual_seed_all(_FIXED_SEED)  # If using multi-GPU

    # Ensure deterministic behavior in CUDA operations (if available)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False  # Disable auto-tuner for determinism

In [None]:
# Load the image using matplotlib's imread
image_path = 'img.png'
image = plt.imread(image_path)  # Reads image as (H, W, C) or (H, W) if grayscale

In [None]:
# These aren't needed until later, but it's good practice to define globals at the top of any file/playbook

# Hyperparameters
batch_size = 128
latent_dim = 10
epochs = 30
learning_rate = 1e-4

# Part 1 Sobel Filters

This short section will walk through using Sobel filters to perform "edge detection" on an input image.

For this you will need to complete the Sobel filters themselves, then apply them to the input image provided and analyze the output.

## 1.1 Let's analyze our input image and convert it to grayscale

In [None]:
# First plot our input image for comparison.
plt.figure(figsize=(12, 4))
plt.title('Original Image')
plt.imshow(image.squeeze())
plt.axis('off')
plt.show()

In [None]:
# Convert to grayscale if it's an RGB(A) image
if image.ndim == 3:
    # Simple average over the R, G, B channels (any alpha channel is dropped first)
    image = image[..., :3].mean(axis=2)  # Shape: (H, W)

In [None]:
# Normalize to [0, 1] if needed (conversion to a PyTorch tensor happens in the next step)
# Some images may already be normalized (e.g. PNGs are read as floats in [0, 1]), so we check
if image.max() > 1:
    image = image / 255.0

## 1.2 Construct the Sobel filters

```
The kernels Gx and Gy as covered in the last lecture:
      _               _                   _                _
     |                 |                 |                  |
     | 1.0   0.0  -1.0 |                 |  1.0   2.0   1.0 |
Gx = | 2.0   0.0  -2.0 |    and     Gy = |  0.0   0.0   0.0 |
     | 1.0   0.0  -1.0 |                 | -1.0  -2.0  -1.0 |
     |_               _|                 |_                _|
```

In [None]:
# Add batch and channel dimensions (1, 1, H, W)
image_tensor = torch.tensor(image, dtype=torch.float32).unsqueeze(0).unsqueeze(0)

Gx = [ [,,], [], [] ] ## FINISH_ME ##
Gy = [ [,,], [], [] ] ## FINISH_ME ##

# Define Sobel kernels
sobel_kernel_x = torch.tensor(Gx, dtype=torch.float32).view(1, 1, 3, 3)

sobel_kernel_y = torch.tensor(Gy, dtype=torch.float32).view(1, 1, 3, 3)

## 1.3 Apply Sobel and plot the result

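NB:
Filter responses are signed (the same edge can give a large positive or a large negative value depending on its direction), so when plotting the result of applying a filter it's usually best to plot its magnitude: the absolute value of each response, or the root of the sum of squares $\sqrt{G_x^2 + G_y^2}$ for the combined gradient.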
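A small standalone sketch makes both points concrete. It uses a hypothetical 8x8 synthetic image (dark left half, bright right half) rather than `img.png`, so the names `img`, `gx`, `gy` and `magnitude` are illustrative only. `Gx` responds strongly, and with a negative sign, to the vertical edge while `Gy` gives no response. Note that `F.conv2d` computes a cross-correlation (the kernel is not flipped), which only changes the sign of these responses.

```python
import torch
import torch.nn.functional as F

# Sobel kernels from the diagram above, shaped (out_channels, in_channels, kH, kW)
gx_kernel = torch.tensor([[1.0, 0.0, -1.0],
                          [2.0, 0.0, -2.0],
                          [1.0, 0.0, -1.0]]).view(1, 1, 3, 3)
gy_kernel = torch.tensor([[1.0, 2.0, 1.0],
                          [0.0, 0.0, 0.0],
                          [-1.0, -2.0, -1.0]]).view(1, 1, 3, 3)

# Hypothetical 8x8 image with a single vertical edge between columns 3 and 4
img = torch.zeros(1, 1, 8, 8)
img[..., 4:] = 1.0

# padding=0 avoids spurious responses caused by the zero-padded border
gx = F.conv2d(img, gx_kernel)  # shape (1, 1, 6, 6)
gy = F.conv2d(img, gy_kernel)

# Gx is negative along the edge, so plot a magnitude instead of raw values
magnitude = torch.sqrt(gx**2 + gy**2)
print(gx[0, 0])
print(magnitude.max().item(), gy.abs().max().item())
```

Swapping the bright and dark halves flips the sign of `gx` but leaves `magnitude` unchanged, which is why the magnitude is the safer thing to display.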

In [None]:
# Apply Sobel filters using convolution
edges_x = F.conv2d(image_tensor, sobel_kernel_x, padding=1)
edges_y = F.conv2d(image_tensor, sobel_kernel_y, padding=1)

In [None]:
# Combine edges to get the gradient magnitude
edges = torch.sqrt(edges_x**2 + edges_y**2)

In [None]:
# Plot the original and edge-detected images
plt.figure(figsize=(12, 4))

plt.subplot(1, 3, 1)
plt.title('Original Image')
plt.imshow(image_tensor.squeeze().numpy(), cmap='gray')
plt.axis('off')

plt.subplot(1, 3, 2)
plt.title('Sobel X (Vertical Edges)')
plt.imshow( ## FINISH_ME ##
plt.axis('off')

plt.subplot(1, 3, 3)
plt.title('Sobel Y (Horizontal Edges)')
plt.imshow( ## FINISH_ME ##
plt.axis('off')

plt.tight_layout()
plt.show()

# Part 2 Understanding the PyTorch 2D CNN

First we'll load the MNIST digit dataset, construct some 2D convolutional layers and check that we understand their outputs.

## Part 2.1 Load the MNIST dataset, this time using transforms

In [None]:
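# Data Preparation
# ToTensor() scales pixels to [0, 1]. The VAE decoder below ends in a sigmoid (outputs in [0, 1]),
# so we deliberately skip Normalize((0.5,), (0.5,)), which would map the inputs to [-1, 1]
transform = transforms.Compose([
    transforms.ToTensor(),
])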

In [None]:
# Load full training MNIST dataset
dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
# Load Test Set
test_dataset = datasets.MNIST(root='./data', train=False, transform=transform, download=True)

# Split into 90% train, 10% validation
train_size = int(0.9 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

## Part 2.2 Create some Conv2d layers and apply them to an example input
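Before plotting, it helps to predict the output sizes. For `nn.Conv2d` with dilation 1, each spatial axis becomes $H_{out} = \lfloor (H_{in} + 2p - k) / s \rfloor + 1$ for kernel size $k$, stride $s$ and padding $p$. A standalone sketch (the helper `conv2d_out_size` is a hypothetical name, not part of PyTorch) compares this formula against real layers for the same configurations used in the cells below:

```python
import torch
import torch.nn as nn

def conv2d_out_size(size, kernel_size, stride, padding):
    # Output size of nn.Conv2d along one spatial axis (dilation=1)
    return (size + 2 * padding - kernel_size) // stride + 1

configs = [
    {"kernel_size": 3, "stride": 1, "padding": 0},
    {"kernel_size": 3, "stride": 1, "padding": 2},
    {"kernel_size": 5, "stride": 1, "padding": 0},
    {"kernel_size": 5, "stride": 2, "padding": 2},
    {"kernel_size": 3, "stride": 2, "padding": 1},
]

dummy = torch.zeros(1, 1, 28, 28)  # dummy single-channel 28x28 input
predicted, actual = [], []
for cfg in configs:
    predicted.append(conv2d_out_size(28, **cfg))
    conv = nn.Conv2d(1, 1, **cfg)  # randomly initialized; only the shape matters here
    actual.append(conv(dummy).shape[-1])

print(predicted)  # [26, 30, 24, 14, 14]
```

Padding grows the output, stride shrinks it roughly by a factor of `stride`, and larger kernels trim more from the borders.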

In [None]:
random_idx = random.randint(0, len(train_dataset) - 1)  # Select a random image
image, label = train_dataset[random_idx]  # Random MNIST image

# ToTensor() already gives (1, 28, 28), so we only add the batch dimension for Conv2d
input_image = image.unsqueeze(0)  # Shape: (1, 1, 28, 28)

In [None]:
# Function to Apply Convolution with Different Configs
def apply_conv(input_image, kernel_size, stride, padding):
    # We want in_channels=1, out_channels=1 and the kernel_size, stride and padding passed in
    # https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html
    conv = nn.Conv2d(in_channels=1, out_channels=1, ## FINISH_ME ##
    output = conv(input_image)
    return output

# Different Configurations
configs = [
    {"kernel_size": 3, "stride": 1, "padding": 0},
    {"kernel_size": 3, "stride": 1, "padding": 2},
    {"kernel_size": 5, "stride": 1, "padding": 0},
    {"kernel_size": 5, "stride": 2, "padding": 2},
    {"kernel_size": 3, "stride": 2, "padding": 1}
]

# Plot Input Image
plt.figure(figsize=(10, 6))
plt.subplot(2, 3, 1)
plt.imshow(image.squeeze(), cmap='gray')
plt.title(f"Original Image\n(28x28)")
plt.axis('off')

# Apply Configurations and Plot Outputs
for i, cfg in enumerate(configs):

    output = apply_conv(input_image, cfg['kernel_size'], cfg['stride'], cfg['padding'])
    output_shape = output.shape  # Shape: (1, 1, H_out, W_out)

    plt.subplot(2, 3, i+2)
    plt.imshow( ## FINISH_ME ##
    plt.title(f"Kernel: {cfg['kernel_size']}, Stride: {cfg['stride']}, Pad: {cfg['padding']}\nShape: {output_shape[2]}x{output_shape[3]}")
    plt.axis('off')

plt.tight_layout()
plt.show()

# Part 3 Construct our data loaders and Variational Auto-Encoder model

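Our VAE model is designed to encode information from our 1x28x28 MNIST images into a low-dimensional latent space, and to decode points in that space back into images.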

## Part 3.1 Construct our data loaders

In [None]:
# Create DataLoaders for train and validate
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

## Part 3.2 Construct our Variational Auto-Encoder class

Our Auto-Encoder class has several key features which we want to demonstrate.

1. Decreasing dimension when Encoding information into latent-space.
2. Increasing dimension when Decoding from latent-space.
3. Symmetry between our encoder and decoder to help model training.
4. Fixed Latent-Space dimension
5. Weights initialized using 'sensible' defaults
6. Use of the 'SiLU' https://pytorch.org/docs/stable/generated/torch.nn.SiLU.html activation function

For each input image, our encoder needs to output the parameters of a Gaussian distribution in this latent space (a mean `mu` and a log-variance `logvar`) rather than a single point.


### 3.2.0 The reparameterization trick is needed to make the VAE trainable, but is only used when assembling the full model
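A plain sampling step gives no gradient with respect to the distribution's parameters. The trick rewrites the sample as $z = \mu + \epsilon \cdot \sigma$ with $\sigma = e^{0.5 \cdot \text{logvar}}$ and independent noise $\epsilon \sim \mathcal{N}(0, I)$, so the randomness no longer depends on `mu` or `logvar`. A standalone sketch (the batch of 4 and the 2-dim latent space are arbitrary illustrative choices) confirms that gradients reach both parameters:

```python
import torch

torch.manual_seed(0)

# Hypothetical encoder outputs: a batch of 4 samples in a 2-dim latent space
mu = torch.zeros(4, 2, requires_grad=True)
logvar = torch.zeros(4, 2, requires_grad=True)

# z = mu + eps * sigma, where eps does not depend on mu or logvar
std = torch.exp(0.5 * logvar)
eps = torch.randn_like(std)
z = mu + eps * std

z.sum().backward()
print(mu.grad)      # dz/dmu = 1 for every element
print(logvar.grad)  # dz/dlogvar = 0.5 * eps * std
```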

In [None]:
vae_base_model_dim = 16

In [None]:
# Reparameterization Trick
def reparameterize(mu, logvar):
    std = torch.exp(0.5 * logvar)
    eps = torch.randn_like(std)
    return mu + eps * std

### 3.2.1 Build our Encoder

We want to start from 1x28x28 and apply a filter that takes us to 16x28x28, then down-sample to 32x14x14 and then 64x7x7. Finally, we flatten this output and project it onto our latent-space parameters `mu` and `logvar`.

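It's common to use kernels of size 3 when not changing the resolution (stride 1), and larger kernels such as 4 with stride 2 when down- or up-sampling, so that each output sees more of the input. With padding 1, a kernel of 4 and a stride of 2 halve the size exactly: 28 -> 14 -> 7.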

With all that in mind you should get something similar to:

1. input=1, output=16, kernel=3, stride=1, padding=1
2. input=16, output=32, kernel=4, stride=2, padding=1
3. input=32, output=64, kernel=4, stride=2, padding=1
4. flatten 64x7x7 -> Linear -> 128
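One way to check these numbers is to push a dummy batch through an equivalent `nn.Sequential` stack and print the shape after each layer. This standalone sketch leaves out the BatchNorm layers (they do not change shapes) and hard-codes the channel counts for `vae_base_model_dim = 16`:

```python
import torch
import torch.nn as nn

# Equivalent conv stack using the listed parameters (BatchNorm omitted)
stack = nn.Sequential(
    nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1),   # 1x28x28  -> 16x28x28
    nn.SiLU(),
    nn.Conv2d(16, 32, kernel_size=4, stride=2, padding=1),  # 16x28x28 -> 32x14x14
    nn.SiLU(),
    nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1),  # 32x14x14 -> 64x7x7
    nn.SiLU(),
    nn.Flatten(),                                           # 64x7x7   -> 3136
    nn.Linear(64 * 7 * 7, 128),                             # 3136     -> 128
)

x = torch.randn(2, 1, 28, 28)  # dummy batch of two MNIST-sized images
shapes = []
for layer in stack:
    x = layer(x)
    shapes.append(tuple(x.shape))

print(shapes[0], shapes[2], shapes[4], shapes[6], shapes[7])
```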

In [None]:
# CNN Encoder with Swish (SiLU) and Batch Normalization
class Encoder(nn.Module):
    def __init__(self, latent_dim):
        super(Encoder, self).__init__()

        # conv0 and bn0 are used by layer1
        self.conv0 = nn.Conv2d( ## FINISH_ME )
        self.bn0 = nn.BatchNorm2d(vae_base_model_dim)

        # conv1 and bn1 are used by layer2
        self.conv1 = nn.Conv2d( ## FINISH_ME )
        self.bn1 = nn.BatchNorm2d(vae_base_model_dim*2)

        # conv2 and bn2 are used by layer 3
        self.conv2 = nn.Conv2d( ## FINISH_ME ##
        self.bn2 = nn.BatchNorm2d(vae_base_model_dim*4)

        # fc1 'projects' from filtered data to Latent-Space
        self.fc1 = nn.Linear(vae_base_model_dim*4 * 7 * 7, 128)

        # fc_mu and fc_logvar produce mu and logvar, which the reparameterization trick uses later
        self.fc_mu = nn.Linear(128, latent_dim)
        self.fc_logvar = nn.Linear(128, latent_dim)

        # Let's make sure the weights start from a sensible initialization rather than PyTorch's defaults
        self._initialize_weights()

    def forward(self, x):
        # Layer 1: no down-sampling, just feature extraction (1x28x28 -> 16x28x28)
        x = F.silu(self.bn0(self.conv0(x)))
        # Layer 2: down-sample, halving each spatial dimension (16x28x28 -> 32x14x14)
        x = F.silu(self.bn1(self.conv1(x)))
        # Layer 3: down-sample again (32x14x14 -> 64x7x7)
        x = ## FINISH_ME ##

        # Flatten (B, 64, 7, 7) -> (B, 64*7*7)
        # Keeping x.size(0) preserves the batch dimension
        x = x.view(x.size(0), -1)

        # This now projects the flattened data to 128-dim
        x = F.silu(self.fc1(x))

        # This finally projects down from 128 -> latent-dim
        mu = self.fc_mu(x)
        logvar = self.fc_logvar(x)
        return mu, logvar

    def _initialize_weights(self):
        # self.modules() iterates through this module and every sub-module registered on it
        # (nn.Module records any layer assigned as an attribute, e.g. self.conv0)
        for m in self.modules():
            # Conv2d layers get Kaiming (He) initialization; 'relu' is used here as an approximation for SiLU
            # https://pytorch.org/docs/stable/nn.init.html#torch.nn.init.kaiming_normal_
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            # Linear layers get Xavier (Glorot) uniform initialization
            # https://pytorch.org/docs/stable/nn.init.html#torch.nn.init.xavier_uniform_
            elif isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                nn.init.constant_(m.bias, 0)

### 3.2.2 Build our Decoder

Starting from our latent space, we first project up to 64x7x7, then use transposed-convolution filters to up-sample to 32x14x14 and 16x28x28, and finally use an additional stride-1 filter to reduce the channels back down to a 1x28x28 image.

We want to make our decoder symmetrical to our encoder but in reverse. This helps with model stability during training.
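For `nn.ConvTranspose2d` with dilation 1, each spatial axis becomes $H_{out} = (H_{in} - 1)\,s - 2p + k + \text{output\_padding}$. This means kernel 4, stride 2, padding 1 exactly doubles the size, mirroring the encoder. The standalone sketch below checks that path. The layer arguments are one choice that satisfies the shapes described above, not necessarily the intended solution, and the helper `conv_transpose2d_out_size` is a hypothetical name:

```python
import torch
import torch.nn as nn

def conv_transpose2d_out_size(size, kernel_size, stride, padding, output_padding=0):
    # Output size of nn.ConvTranspose2d along one axis (dilation=1)
    return (size - 1) * stride - 2 * padding + kernel_size + output_padding

# 64x7x7 -> 32x14x14 -> 16x28x28 -> 1x28x28
layers = [
    nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),
    nn.ConvTranspose2d(32, 16, kernel_size=4, stride=2, padding=1),
    nn.ConvTranspose2d(16, 1, kernel_size=3, stride=1, padding=1),
]

x = torch.randn(2, 64, 7, 7)  # dummy batch shaped like the reshaped fc2 output
for layer in layers:
    # kernel_size, stride and padding are stored as tuples on the layer
    expected = conv_transpose2d_out_size(
        x.shape[-1], layer.kernel_size[0], layer.stride[0], layer.padding[0]
    )
    x = layer(x)
    assert x.shape[-1] == expected
    print(tuple(x.shape))
```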


In [None]:
# CNN Decoder with SiLU and Batch Normalization
class Decoder(nn.Module):
    def __init__(self, latent_dim):
        super(Decoder, self).__init__()

        # fc1 & bn1 project up from the latent space to 128-dim
        self.fc1 = nn.Linear(latent_dim, 128)
        self.bn1 = nn.BatchNorm1d(128)

        # fc2 & bn2 project from 128-dim to the flattened 64x7x7 input of the transposed convolutions
        self.fc2 = nn.Linear(128, vae_base_model_dim*4 * 7 * 7)
        self.bn2 = nn.BatchNorm1d(vae_base_model_dim*4 * 7 * 7)

        # deconv0 & bn3 up-sample 64x7x7 -> 32x14x14
        # We use ConvTranspose2d rather than Conv2d because, with stride 2, it increases the spatial size
        self.deconv0 = nn.ConvTranspose2d( ## FINISH_ME ## )
        self.bn3 = nn.BatchNorm2d(vae_base_model_dim*2)

        # deconv1 and bn4 up-sample further, 32x14x14 -> 16x28x28
        self.deconv1 = nn.ConvTranspose2d( ## FINISH_ME ## )
        self.bn4 = nn.BatchNorm2d(vae_base_model_dim)

        # deconv2 mirrors conv0 in the encoder: a stride-1 layer mapping 16 channels back to a 1-channel 28x28 image
        self.deconv2 = nn.ConvTranspose2d(vae_base_model_dim, 1, kernel_size=3, stride=1, padding=1)

        # Let's make sure the weights start from a sensible initialization rather than PyTorch's defaults
        self._initialize_weights()

    def forward(self, z):
        # Project from the latent space to 128-dim
        x = F.silu(self.bn1(self.fc1(z)))
        # Project from 128-dim to the flattened 64x7x7 feature map
        x = F.silu(self.bn2(self.fc2(x)))

        # Reshape to (B, 64, 7, 7), the opposite of the 'flatten' resize in the encoder
        x = x.view(x.size(0), vae_base_model_dim*4, 7, 7)

        # Up-sample the 64x7x7 projection to 32x14x14
        x = F.silu(self.bn3(self.deconv0(x)))
        # Up-sample again to 16x28x28, the size of the output image
        x = ## FINISH_ME ##

        # Final layer: combine the 16 feature maps into the output image; sigmoid keeps pixels in [0, 1]
        x = torch.sigmoid(self.deconv2(x))
        return x

    def _initialize_weights(self):
        # self.modules() iterates through this module and every sub-module registered on it
        # (nn.Module records any layer assigned as an attribute, e.g. self.deconv0)
        for m in self.modules():
            # ConvTranspose2d is not a subclass of Conv2d, so check for both explicitly
            # https://pytorch.org/docs/stable/nn.init.html#torch.nn.init.kaiming_normal_
            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            # Linear layers get Xavier (Glorot) uniform initialization
            # https://pytorch.org/docs/stable/nn.init.html#torch.nn.init.xavier_uniform_
            elif isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                nn.init.constant_(m.bias, 0)

### 3.2.3 Build the final VAE model

The VAE model itself is quite short: we just programmatically 'connect' the Encoder and Decoder graphs through their latent space and return the outputs needed by our loss function(s).

In [None]:
# Variational Autoencoder (VAE)
class VAE(nn.Module):
    def __init__(self, latent_dim):
        super(VAE, self).__init__()

        # Our VAE needs an Encoder and Decoder
        # Both need to be constructed with knowledge of the required latent-dim
        self.encoder = Encoder(latent_dim)
        self.decoder = Decoder(latent_dim)

    def forward(self, x):
        # The encoder returns the mean (mu) and log-variance (logvar) of our distribution in LS
        mu, logvar = ## FINISH_ME ##

        # We now sample a single vector z in LS using the reparameterization trick
        z = reparameterize(mu, logvar)

        # The decoder can now take this single vector and reconstruct an image
        x_recon = ## FINISH_ME ##
        # To train the model we need to return the final value
        # AND intermediate values from the construction of our LS
        return x_recon, mu, logvar

## 3.3 Defining the Model Loss

The total training loss for this model is the sum of the KL divergence and the reconstruction loss of the images themselves.

The Reconstruction loss in this case is simply defined as the F.mse_loss https://pytorch.org/docs/stable/generated/torch.nn.functional.mse_loss.html
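The KL term has a closed form because both the encoder distribution $q(z \mid x) = \mathcal{N}(\mu, \sigma^2)$ and the prior $\mathcal{N}(0, I)$ are Gaussian: $D_{KL} = -\tfrac{1}{2}\sum_j \left(1 + \log\sigma_j^2 - \mu_j^2 - \sigma_j^2\right)$, with `logvar` $= \log\sigma^2$. A standalone sketch, using random tensors as hypothetical encoder outputs, cross-checks this expression against `torch.distributions.kl_divergence`:

```python
import torch
from torch.distributions import Normal, kl_divergence

torch.manual_seed(0)
mu = torch.randn(5, 10)      # hypothetical encoder means (batch of 5, latent_dim=10)
logvar = torch.randn(5, 10)  # hypothetical encoder log-variances

# Closed form, as used for kl_loss in loss_function
kl_closed = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

# The same quantity from torch.distributions: KL(N(mu, sigma^2) || N(0, 1)), summed
q = Normal(mu, torch.exp(0.5 * logvar))
p = Normal(torch.zeros_like(mu), torch.ones_like(mu))
kl_library = kl_divergence(q, p).sum()

print(kl_closed.item(), kl_library.item())
```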

In [None]:
# Loss Function: Reconstruction Loss + KL Divergence
def loss_function(x_recon, x, mu, logvar):
    # The loss function relies on the model output, the truth, mu and logvar

    # First part of the loss is simply how 'bad' our output images are compared to 'truth'
    recon_loss = F.mse_loss(x_recon, x, reduction='sum')

    # We want to make sure our Latent-Space is encoded down to a 'Probability Space'
    # so we take the KL divergence between the encoder's N(mu, sigma^2) and the standard normal N(0, I)
    kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

    return recon_loss + kl_loss, recon_loss

# Part 4 Training our model

As with our classifier model, we need to train our model to do something useful and extract information from the dataset.

## Part 4.1 Training prerequisites

In [None]:
# Training setup
# As with last week, if you have access to anything 'non-CPU' I recommend using it here(!)
device = torch.device(
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)

# To train our model we need to first construct it, this means we can make a decision here on the latent-dim of the model
vae = VAE(latent_dim).to(device)

# To train the model we constructed we need to let the optimizer know about it
optimizer = optim.Adam(vae.parameters(), lr=learning_rate)

In [None]:
# We want to track the training loss and the loss from our validation dataset
train_losses = []
val_losses = []

## Part 4.2 Our training Loop

It's good to prove that your model trains as expected. However, training a proper model for this problem is computationally expensive (30 minutes or more on a decent GPU!).

With that in mind, feel free to train for only 2-3 epochs to demonstrate that your model's loss is indeed decreasing each epoch, and compare the raw train and validation losses in the next step.

In [None]:

# Loop over all epochs
for epoch in range(epochs):

    # Make sure the model is in training mode
    vae.train()

    # (re-) set the training loss to be 0 for each epoch
    train_loss = 0

    # Loop over all batches in the train_loader
    i=0
    for x, _ in train_loader:

        # Here we don't _need_ to know about the data labels
        # In Python, it's common/good-practice to allocate
        # returned parameters we don't care about to '_'

        # Our training loop is no different to when training
        # a classifier
        x = x.to(device)
        optimizer.zero_grad()
        x_recon, mu, logvar = vae(x)

        # Here we are calculating our training loss
        # This depends on comparing:
        # model output to input,   x_recon to x
        # LS dist to Prob-Space,   mu&logvar to n-dim normal
        loss, _ = loss_function( ## FINISH_ME ##
        loss.backward()
        optimizer.step()
        train_loss += loss.item()

        ## Useful for debugging, noisy for real training
        ##print(f'Processing Batch: {i} of {len(train_loader)}')
        ##i+=1

    ## This section is new
    ## We want to _evaluate_ our dataset once per epoch
    ## This allows us to see if a model has over-trained or not

    # Validation Loss
    vae.eval()
    val_loss = 0
    with torch.no_grad():

        # As above, iterate over all data in the dataset
        # Here we're using the Validation dataset
        for x, _ in val_loader:
            x = x.to(device)

            x_recon, mu, logvar = vae(x)
            loss, _ = loss_function( ## FINISH_ME ##
            val_loss += loss.item()

            # We DO NOT train on the validation dataset, so there is no backward() or optimizer.step() here

    # We can now calculate the average loss per sample for both data (sub-)sets
    avg_train_loss = train_loss / len(train_loader.dataset)
    avg_val_loss = ## FINISH_ME ##
    # Store the values to examine them afterwards
    train_losses.append(avg_train_loss)
    val_losses.append(avg_val_loss)

    print(f"Epoch {epoch+1}/{epochs}, Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}")

## Part 4.3 Plot the losses from our training

This allows us to judge whether our model has over-fitted (validation loss rising while training loss keeps falling) or is still under-trained.

In [None]:
# Plot Training vs Validation Loss
plt.figure(figsize=(8, 5))
plt.plot(range(1, len(train_losses)+1), train_losses, label="Train Loss", marker='o')
plt.plot(range(1, len(val_losses)+1), val_losses, label="Validation Loss", marker='s')
plt.xlabel("Epochs")
plt.ylabel("Loss")
plt.title("Train vs Validation Loss")
plt.legend()
plt.yscale("log")
#plt.xscale("log")
plt.grid(True)
plt.show()

## Part 4.4 Examine the model performance

We want to visualize what the output from our model is based on inputs.

If you prefer, you can run this after loading a pre-saved model to see how it performs.

In [None]:
# Show Input vs Reconstructed Images
vae.eval()
with torch.no_grad():
    sample_data, _ = next(iter(train_loader))
    sample_data = sample_data[:8].to(device)  # Select 8 images
    reconstructed, _, _ = vae(sample_data)

# Convert to CPU for visualization
sample_data = sample_data.cpu()
reconstructed = reconstructed.cpu()

# Plot Input vs Output
fig, axes = plt.subplots(2, 8, figsize=(12, 4))
for i in range(8):
    axes[0, i].imshow(sample_data[i].squeeze(), cmap='gray')
    axes[0, i].axis('off')
    axes[1, i].imshow(reconstructed[i].squeeze(), cmap='gray')
    axes[1, i].axis('off')

axes[0, 0].set_title("Original Images", fontsize=12)
axes[1, 0].set_title("Reconstructed Images", fontsize=12)
plt.show()

# Part 5 Load and use a pre-trained model

## Part 5.1 Load/Save Mechanism
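`torch.save(vae, ...)` pickles the whole model object, so loading it later requires the same class definitions to be available, and recent PyTorch releases need `weights_only=False` to unpickle it. A commonly recommended alternative is to save only the `state_dict` (the parameters and buffers) and rebuild the architecture in code. A minimal standalone sketch, using a small hypothetical `make_model()` in place of `VAE(latent_dim)` and a temporary file path:

```python
import os
import tempfile

import torch
import torch.nn as nn

# Hypothetical stand-in for the VAE; any nn.Module works the same way
def make_model():
    return nn.Sequential(nn.Linear(4, 8), nn.SiLU(), nn.Linear(8, 2))

model = make_model()
path = os.path.join(tempfile.mkdtemp(), "model_state.pth")

# Save only the parameters and buffers, not the pickled class
torch.save(model.state_dict(), path)

# To load: rebuild the architecture, then copy the saved weights in
restored = make_model()
restored.load_state_dict(torch.load(path, map_location="cpu", weights_only=True))
restored.eval()

x = torch.randn(3, 4)
same_output = torch.equal(model(x), restored(x))
print(same_output)
```

For the VAE the same pattern would be `torch.save(vae.state_dict(), ...)` followed by `VAE(latent_dim).load_state_dict(...)`.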

In [None]:
# If you're happy with your model from training, you can save its progress or the final model with the following
torch.save(vae, 'trained_vae_model.pth')

## IF YOU WANT TO USE A PRE-TRAINED MODEL AFTER HERE THERE IS ONE PROVIDED

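# Load the entire (pickled) model; recent PyTorch releases require weights_only=False for this
vae_entire_loaded = torch.load('trained_vae_model.pth', map_location=device, weights_only=False)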
vae_entire_loaded.eval()  # Set to evaluation mode

## Part 5.2 Let's use our pre-trained model to generate some images

### Part 5.2.1 Let's construct some random vectors in Prob-Space
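Since the VAE is trained so that latent codes follow the standard-normal prior $\mathcal{N}(0, I)$, a common way to generate new images is to draw $z$ directly with `torch.randn(num_images, latent_dim)`. Drawing random `mu` and `logvar` first and then reparameterizing also gives valid decoder inputs, but samples a wider distribution: with $\mu, \text{logvar}, \epsilon \sim \mathcal{N}(0, 1)$ independent, each dimension has variance $1 + e^{1/2} \approx 2.65$. A standalone sketch comparing the two (the sample size is chosen only to make the statistics stable):

```python
import torch

torch.manual_seed(0)
n, latent_dim = 100_000, 10  # large sample so the statistics are stable

# Option A: sample directly from the standard-normal prior N(0, I)
z_prior = torch.randn(n, latent_dim)

# Option B: draw random mu and logvar, then apply the reparameterization trick
rand_mu = torch.randn(n, latent_dim)
rand_logvar = torch.randn(n, latent_dim)
z_mixed = rand_mu + torch.randn(n, latent_dim) * torch.exp(0.5 * rand_logvar)

prior_std = z_prior.std().item()  # close to 1.0
mixed_std = z_mixed.std().item()  # close to sqrt(2.65), about 1.63
print(prior_std, mixed_std)
```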

In [None]:
# Generate New Images
num_images = 8

In [None]:
# Sample random mu and logvar values from a standard normal distribution
rand_mu = torch.randn(num_images, ## FINISH_ME ##
rand_log = torch.randn(num_images, ## FINISH_ME ##

latent_vectors = reparameterize( ## FINISH_ME ##

### Part 5.2.2 Now that we have some random vectors, use the decoder to build some images

In [None]:
# Use the decoder to generate images
with torch.no_grad():
    latent_vectors = latent_vectors.to(device)
    generated_images = vae_entire_loaded.decoder( ## FINISH_ME ##
    generated_images = generated_images.view(-1, 28, 28)  # Reshape to (28x28)
    generated_images = generated_images.detach().cpu().numpy()

### Part 5.2.3 Now let's plot these images

In [None]:
# Plot Generated Images
plt.figure(figsize=(12, 3))
for i in range(num_images):
    plt.subplot(1, num_images, i+1)
    plt.imshow(generated_images[i], cmap='gray')
    plt.axis('off')

plt.suptitle("Generated Images from Pre-trained VAE", fontsize=16)
plt.show()

# Part 6 Examining the trained model Latent-Space

## Part 6.1 Make sure our model is in evaluation mode

In [None]:
# Ensure model is in evaluation mode
vae.eval()

## Part 6.2 Iterate over the test dataset and collect the distribution of latent vectors & losses
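Per-image losses can also be computed for a whole batch at once instead of looping over images: `reduction='none'` keeps one value per pixel, which can then be summed per image, and the KL term can be summed over the latent dimension only. A standalone sketch using random tensors in place of real images, reconstructions and encoder outputs (all hypothetical):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
# Hypothetical batch: 4 images (1x28x28), their reconstructions, and encoder outputs
x = torch.rand(4, 1, 28, 28)
x_recon = torch.rand(4, 1, 28, 28)
mu = torch.randn(4, 10)
logvar = torch.randn(4, 10)

# One reconstruction value per pixel, summed over (C, H, W) to give one per image
recon_per_image = F.mse_loss(x_recon, x, reduction='none').sum(dim=(1, 2, 3))
# KL summed over the latent dimensions only, again one value per image
kl_per_image = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1)
loss_per_image = recon_per_image + kl_per_image  # shape (4,)

# Summing the per-image losses recovers the batch loss with reduction='sum'
batch_loss = F.mse_loss(x_recon, x, reduction='sum') - 0.5 * torch.sum(
    1 + logvar - mu.pow(2) - logvar.exp()
)
print(loss_per_image.shape, torch.allclose(loss_per_image.sum(), batch_loss, rtol=1e-4))
```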

In [None]:
# Collect latent representations and labels
latent_vectors = []
labels = []
all_losses = []

In [None]:
with torch.no_grad():
    for x, y in test_loader:
        # This time we want to track which number ended up where in the latent-space
        x = x.to(device)

        # Evaluate the encoder to get the latent-space distribution parameters
        mu, logvar = ## FINISH_ME ##
        z = reparameterize(mu, logvar)  # Use reparameterization trick to get final z

        # Decode the whole batch once, then compute the loss image by image
        recon_batch = vae.decoder(z)
        for image, recon_img, m, lv in zip(x, recon_batch, mu, logvar):
            img_loss, _ = loss_function( ## FINISH_ME ##
            all_losses.append(img_loss.item())

        # Store the vectors and labels for plotting
        latent_vectors.append(z.cpu().numpy())
        labels.append(y.numpy())

## Part 6.3 Plot the distribution of per-image losses

In [None]:
# Let's plot the distribution of losses for the whole test dataset
plt.figure(figsize=(10, 6))
plt.hist(all_losses, bins=50, color='skyblue', edgecolor='black', alpha=0.7)
plt.title("Distribution of VAE Losses on Test Dataset", fontsize=16)
plt.xlabel("Loss Value")
plt.ylabel("Frequency")
plt.grid(True)
plt.show()

### What would you expect if we passed a b&w image of a face through our AE?

`## FINISH_ME ##`

## Part 6.4 Explore the Latent-Space distribution for all images

In [None]:
# Convert lists to arrays
latent_vectors = np.concatenate(latent_vectors, axis=0)
labels = np.concatenate(labels, axis=0)

In [None]:
## This code just gets the different combinations of projections from n-dim down to 3D

# Get all 3D combinations from our latent space
latent_dim = latent_vectors.shape[1]
three_d_combinations = list(itertools.combinations(range(latent_dim), 3))

# Create figure grid to plot all projections
## Need to know how many projections we need, for 10-dim LS this equates to 120 possible images(!)
num_projections = len(three_d_combinations)
cols = 4  # Number of columns in the grid
rows = (num_projections // cols) + (num_projections % cols > 0)

In [None]:
# Let's make a large meta-image to start with
fig = plt.figure(figsize=(cols * 5, rows * 5))

# Iterate through all possible 3D projections
for i, (dim1, dim2, dim3) in enumerate(three_d_combinations):

    # Let's start a new sub-plot
    ax = fig.add_subplot(rows, cols, i+1, projection='3d')

    # For each digit class, scatter its latent-vector values
    for digit in range(10):
        # Boolean mask selecting only the images of this class
        mask = labels == digit
        # Draw this class of images scattered within this 3D latent-space projection
        ax.scatter(latent_vectors[mask, dim1], latent_vectors[mask, dim2], latent_vectors[mask, dim3],
                   label=f"{digit}", alpha=1.0, edgecolors='none', s=5)

    # Let's add some info about this sub-plot
    ax.set_xlabel(f"Dim {dim1+1}")
    ax.set_ylabel(f"Dim {dim2+1}")
    ax.set_zlabel(f"Dim {dim3+1}")
    ax.set_title(f"Projection ({dim1+1}, {dim2+1}, {dim3+1})")
    ax.legend(loc="upper right", fontsize=12)

# Plot the final image
plt.tight_layout()
plt.show()

# Part 7 Train a VAE over CIFAR10

I've included the bare-minimum to load the CIFAR10 dataset using the PyTorch built-ins.

This dataset is composed of 3x32x32 images which we can use to train a VAE.

This means you will likely need more compute time than when training on the MNIST digits, but feel free to try other model dimensions and settings to see how well you can build a VAE for CIFAR10.
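Most of the MNIST model carries over. What changes are the shapes: with 3 input channels and 32x32 images, the same kernel/stride/padding choices give 64x8x8 instead of 64x7x7. That means the `Linear` layer sizes, the decoder's `view(...)` and its final layer's output channels would need updating. A standalone shape check under that assumption (layer sizes copied from the MNIST encoder, not tuned for CIFAR10):

```python
import torch
import torch.nn as nn

# Same down-sampling pattern as the MNIST encoder, but with 3 input channels
stack = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1),   # 3x32x32  -> 16x32x32
    nn.SiLU(),
    nn.Conv2d(16, 32, kernel_size=4, stride=2, padding=1),  # 16x32x32 -> 32x16x16
    nn.SiLU(),
    nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1),  # 32x16x16 -> 64x8x8
    nn.SiLU(),
    nn.Flatten(),                                           # 64x8x8   -> 4096
)

x = torch.randn(2, 3, 32, 32)  # dummy batch of two CIFAR10-sized images
out = stack(x)
print(out.shape)  # the first Linear layer would need 64 * 8 * 8 = 4096 inputs
```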

In [None]:
# Basic Transform (Convert to Tensor & Normalize)
transform = transforms.Compose([
    transforms.ToTensor(),  # Converts images to PyTorch tensors
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # Normalize RGB channels to [-1, 1]; match this to your decoder's output range
])


In [None]:
# Load Training Dataset
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)

# Load Test Dataset
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)


In [None]:
# DataLoader for batching & shuffling
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=2)


In [None]:
 ## FINISH_ME ##