# Autoencoders

<p>
CAS on Advanced Machine Learning <br>
Data Science Lab, University of Bern, 2024<br>
Prepared by Dr. Mykhailo Vladymyrov.

</p>

This work is licensed under a <a href="https://creativecommons.org/licenses/by-sa/4.0/">Creative Commons Attribution-ShareAlike 4.0 International License</a>.

# Libs and utils

In [None]:
# on colab:
# !pip install einops
# !pip install mlflow
# !pip install optuna


In [None]:
import os

# set env var to allow duplicated lib
os.environ['KMP_DUPLICATE_LIB_OK']='TRUE'

This code snippet is doing the following:

#import os - Imports Python's built-in operating system module, which provides functions for interacting with the operating system.
#os.environ['KMP_DUPLICATE_LIB_OK']='TRUE' - Sets an environment variable called 'KMP_DUPLICATE_LIB_OK' to 'TRUE'.

#This environment variable is read by the Intel OpenMP runtime (libiomp5), which ships with Intel MKL and is used by some PyTorch/NumPy builds. It addresses a known issue where two copies of the OpenMP runtime are loaded into the same process (common with conda installs, particularly on macOS), which makes the runtime abort with "OMP: Error #15".
#By setting this to 'TRUE', you tell the runtime to tolerate the duplicate library and keep running instead of aborting. This is a common workaround when using PyTorch on macOS.
#Note that Intel itself describes this as an unsafe, unsupported workaround: the program usually runs normally, but it may crash or silently produce incorrect results. The cleaner fix is to ensure only one OpenMP runtime is installed in the environment.

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import einops as eo
import pathlib as pl

import matplotlib.cm as cm
from matplotlib import collections  as mc
from matplotlib import animation
%matplotlib inline

from scipy.stats import norm
from scipy.stats import entropy

import pandas as pd
import pickle
from PIL import Image
from time import time as timer
#import umap

from IPython.display import HTML
from IPython.display import Audio
import IPython

import tqdm.auto as tqdm

import torch
from torchvision import datasets, transforms
from torch import nn
from torch import optim
import torch.nn.functional as F

import sys
is_colab = 'google.colab' in sys.modules


In [None]:
# get mean and std of an array with numpy:
def get_mean_std(x):
    """Return ``(mean, std)`` of all elements of ``x``, computed with NumPy.

    Both statistics are taken over the flattened data; the standard deviation
    is NumPy's default population std (ddof=0).
    """
    return np.mean(x), np.std(x)

# get min and max of an array with numpy:
def get_min_max(x):
    """Return ``(min, max)`` over all elements of ``x``, computed with NumPy."""
    return np.min(x), np.max(x)

def is_iterable(obj):
    """Return True when ``iter(obj)`` succeeds, False when it raises.

    Any exception raised by ``iter`` (not only TypeError) counts as
    "not iterable".
    """
    try:
        iter(obj)
        return True
    except Exception:
        return False

def type_len(obj):
    """Print a one-line debugging summary of ``obj``: type, length and shape.

    The output has the form ``type: <type>[, len: <n>][, shape: <shape>]``.
    * ``len`` is included only when ``len(obj)`` is defined, so scalars,
      0-d arrays and generators print just their type instead of raising.
    * ``shape`` is included whenever ``obj`` has a ``shape`` attribute, which
      covers NumPy arrays and torch tensors.

    The summary is printed, and the function returns None, so call it directly
    rather than wrapping it in ``print()``.
    """
    parts = [f'type: {type(obj)}']
    try:
        parts.append(f'len: {len(obj)}')
    except TypeError:
        pass  # no length for this object (scalar, 0-d array, generator, ...)
    shape = getattr(obj, 'shape', None)
    if shape is not None:
        parts.append(f'shape: {shape}')
    print(', '.join(parts))


In [None]:
def to_np_showable(pt_img):
  """Convert a torch image tensor into a NumPy ``(h, w, c)`` array in [0, 1] for ``imshow``.

  Accepted layouts:
    * ``(b, c, h, w)``: only the first image of the batch is kept;
    * ``(c, h, w)``;
    * ``(h, w)``: treated as a single-channel image.
  If there are more than 3 channels, only the last 3 are kept.

  The tensor is detached and copied to the CPU first, so it may live on any
  device. Values are mapped with ``x / 2 + 0.5`` and clipped to [0, 1], i.e.
  inputs are assumed to lie in [-1, 1]. Data normalised with m=0.5, s=1
  (range [-0.5, 0.5]) therefore ends up in [0.25, 0.75].
  """
  np_im = pt_img.detach().cpu().numpy()
  if np_im.ndim == 4:
    np_im = np_im[0]  # keep the first image of the batch
  elif np_im.ndim == 2:
    # Without an explicit channel axis, the "> 3 channels" check below would
    # slice off image rows and the 'c h w' rearrange would fail.
    np_im = np_im[None]

  if np_im.shape[0] > 3:
    np_im = np_im[-3:]  # keep only the last 3 channels

  # channels-first (torch) -> channels-last (matplotlib)
  return (eo.rearrange(np_im, 'c h w -> h w c') / 2 + .5).clip(0., 1.)

def plot_im(im, is_torch=True):
  """Show a single image with a grayscale colormap, then close the figure.

  Args:
    im: image to display. When ``is_torch`` is True it is a torch tensor that
      is first converted with ``to_np_showable``; otherwise it is passed to
      ``plt.imshow`` unchanged.
    is_torch: whether ``im`` needs converting from torch first.
  """
  if is_torch:
    im = to_np_showable(im)
  plt.imshow(im, cmap='gray')
  plt.show()
  plt.close()  # release the figure's memory

def plot_im_samples(ds, n=5, is_torch=False):
  """Plot the first ``n`` images of ``ds`` side by side in a single row.

  Args:
    ds: sliceable collection of images, e.g. a NumPy array or a torch batch.
    n: number of panels; must be >= 1. If ``ds`` holds fewer than ``n``
      images, the remaining panels stay blank.
    is_torch: if True, each image is a torch tensor converted with
      ``to_np_showable``; otherwise it is passed to ``imshow`` unchanged.

  Raises:
    ValueError: if ``n < 1``.
  """
  if n < 1:
    raise ValueError(f'n must be >= 1, got {n}')
  # A height of 16/n keeps every panel roughly square. A height of n inches
  # would make the figure very tall and mostly empty for large n.
  # squeeze=False keeps `axs` 2-D, so indexing also works when n == 1.
  fig, axs = plt.subplots(1, n, figsize=(16, 16 / n), squeeze=False)
  for ax in axs[0]:
    ax.set_axis_off()  # also blanks panels left empty when ds is short
  for ax, image in zip(axs[0], ds[:n]):
    ax.imshow(to_np_showable(image) if is_torch else image, cmap='gray')
  plt.show()
  plt.close(fig)  # release the figure's memory

In [None]:
# merging 2d matrix of images in 1 image
def mosaic(mtr_of_ims):
  """Merge a 2-D grid (list of rows) of equally sized images into one image.

  Args:
    mtr_of_ims: non-empty sequence of rows, each a sequence of images shaped
      ``(h, w)`` or ``(h, w, c)``. Every image is expected to have the shape
      of ``mtr_of_ims[0][0]``.

  Returns:
    float32 array of shape ``(rows*h + rows-1, cols*w + cols-1)`` (plus a
    trailing ``c`` axis for multichannel images), where
    ``cols = len(mtr_of_ims[0])``. Neighbouring images are separated by a
    1-pixel border, and every pixel not covered by an image is 0.5 (gray,
    assuming image values in [0, 1]).

  Raises:
    AssertionError: if the grid or its first row is empty, or if the first
      image is neither 2-D nor 3-D.
  """
  n_rows = len(mtr_of_ims)
  assert n_rows != 0
  n_cols = len(mtr_of_ims[0])
  assert n_cols != 0

  tile_shape = mtr_of_ims[0][0].shape
  assert 2 <= len(tile_shape) <= 3
  tile_h, tile_w = tile_shape[0], tile_shape[1]
  channels = tuple(tile_shape[2:])  # () for grayscale, (c,) for multichannel

  gap = 1  # border width in pixels between neighbouring tiles
  canvas_shape = (tile_h * n_rows + gap * (n_rows - 1),
                  tile_w * n_cols + gap * (n_cols - 1)) + channels
  canvas = np.full(canvas_shape, 0.5, dtype=np.float32)

  for row_idx, row in enumerate(mtr_of_ims):
    top = row_idx * (tile_h + gap)
    for col_idx, tile in enumerate(row):
      left = col_idx * (tile_w + gap)
      canvas[top:top + tile_h, left:left + tile_w] = tile
  return canvas

In [None]:
# Select the compute device: a CUDA GPU (e.g. on Colab) if present, otherwise
# Apple-Silicon MPS if it is usable at runtime, otherwise the CPU.
# torch.backends.mps.is_built() only says PyTorch was compiled with MPS support;
# is_available() additionally checks that the MPS device can actually be used.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

# Confirm the selected device
print("Using device:", device)

In [None]:
device

# Setup

## Load dataset

Let's start with the simple, well-understood MNIST dataset.

In [None]:
NOISE_RATE = 0.1   # probability that collate_ae_dataset replaces an input value with salt-and-pepper noise (0 disables noise)
N_SAMPLE = 32      # validation images drawn per class by get_samples
N_VIS_SAMPLE = 2   # images per class listed in get_samples' 'single_el_idx' (used for plotting)
BATCH_SIZE = 128   # batch size of train_loader and valid_loader

In [None]:
def collate_ae_dataset(samples):
    """
    Collate ``(image, label, ...)`` samples into a batch for (denoising) autoencoder training.

    Images (``sample[0]``) are stacked along a new leading batch axis. Labels
    (``sample[1]``, here 1-element LongTensors) are concatenated into one tensor.

    When the global ``NOISE_RATE`` is positive, each input value is independently
    replaced with probability ``NOISE_RATE`` by salt-and-pepper noise. The noise
    value is -0.5 or +0.5, chosen with equal probability, which matches the
    [-0.5, 0.5] range produced by the notebook's m=0.5, s=1 normalisation.
    When ``NOISE_RATE`` is 0, the inputs are returned unchanged.

    Returns:
        ``(noisy_inputs, clean_inputs, labels)``, all moved to the global
        ``device``. The noisy inputs go to the encoder; the clean inputs are the
        reconstruction targets.
    """
    clean = torch.stack([sample[0] for sample in samples])
    labels = torch.concat([sample[1] for sample in samples])

    if NOISE_RATE > 0.:
        shape = clean.shape
        # 1 where a value gets corrupted, 0 where it is kept
        corrupt = torch.bernoulli(torch.full(shape, NOISE_RATE))
        # bernoulli(0.5) yields {0, 1}; shifting by -0.5 yields {-0.5, +0.5}
        salt_pepper = torch.bernoulli(torch.full(shape, 0.5)) - 0.5
        noisy = clean * (1 - corrupt) + salt_pepper * corrupt
    else:
        noisy = clean

    return noisy.to(device), clean.to(device), labels.to(device)

In [None]:
# Normalisation parameters: Normalize maps each pixel p in [0, 1] to (p - m) / s.
# Active choice: m=0.5, s=1.0, so pixels end up in [-0.5, 0.5].
m, s = 0.5, 1.
# Alternatives that were tried:
# m, s = 0.5, 0.5
#m, s = 0., 1.

# Image pipeline: PIL image -> float tensor in [0, 1] (ToTensor) -> (p - m) / s (Normalize).
# The commented-out Pad(2) would turn the 28x28 MNIST images into 32x32.
transform = transforms.Compose([
    transforms.ToTensor(),
    #transforms.Pad(2), # to make images 32x32
    transforms.Normalize((m,), (s,))
])

# Label pipeline: wraps each integer class label in a 1-element LongTensor (shape (1,));
# collate_ae_dataset later concatenates these into one label tensor per batch.
# NOTE: the name is spelled 'lable' (sic); it is kept as is because other cells may use it.
lable_transform = transforms.Compose([lambda x:torch.LongTensor([x])])

# MNIST training split for training and MNIST test split for validation,
# downloaded into ./data on first use; both apply the transforms above.
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform, target_transform=lable_transform)
valid_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform, target_transform=lable_transform)


# Batched loaders. collate_ae_dataset turns every batch into
# (noisy images, clean images, labels) already moved to `device`.
# Training batches are shuffled; validation batches keep dataset order.
# drop_last=True discards the final partial batch so every batch has BATCH_SIZE samples.
# NOTE(review): for validation this means the last len(valid_dataset) % BATCH_SIZE
# images are never evaluated — consider drop_last=False for valid_loader if
# downstream code does not rely on a fixed batch size.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_ae_dataset, drop_last=True)
valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_ae_dataset, drop_last=True)
# This prepares MNIST for denoising-autoencoder training: noisy inputs for the
# encoder, with the clean images kept as reconstruction targets.


In [None]:
# # given a AE model `model`
# for img, label in valid_dataset:
#     reconstruction = model(img)
#     loss_value = loss(img, reconstruction).item()

In [None]:
valid_dataset

In [None]:
type(valid_dataset)

In [None]:
# Load the first sample of the validation dataset (image and label after the
# transforms defined above) and examine its format and value distribution.
# For a map-style dataset, [0] returns the same sample the first iteration of
# `for sample in valid_dataset` would.
sample = valid_dataset[0]
img, label = sample

# type_len prints its summary itself and returns None, so call it directly;
# wrapping it in print() would only add a stray "None" line.
type_len(img)    # image tensor
type_len(label)  # label tensor produced by lable_transform
print(img.shape, label.shape)

# Histogram of this image's normalised pixel values (100 bins).
# With m=0.5, s=1.0 the original [0, 1] pixels are shifted to [-0.5, 0.5]. MNIST
# digits are mostly black background, so expect a tall spike at -0.5 and a few
# digit-stroke values spread up to +0.5, not a distribution centred at 0.
# This is a quick check that preprocessing produces the expected format and range.
plt.hist(img.flatten(), bins=100);

In [None]:
#This line of code creates a histogram of the raw MNIST validation dataset pixel values:

plt.hist(valid_dataset.data.numpy().flatten(), bins=100);

#Let's break it down:
#
#valid_dataset.data - Accesses the raw image data from the MNIST validation dataset. This is the original pixel data before any transformations are applied.
#.numpy() - Converts the PyTorch tensor to a NumPy array, since the plotting functions in matplotlib work with NumPy arrays.
#.flatten() - Converts the multi-dimensional array (which has shape [10000, 28, 28] for the 10,000 MNIST test images, each 28×28 pixels) into a 1D array with all pixel values.
#plt.hist(..., bins=100) - Creates a histogram with 100 bins, showing the distribution of all pixel values.
#The semicolon at the end suppresses any additional output that might be generated.
#
#This histogram will show the original, untransformed distribution of pixel values in the MNIST dataset. Since MNIST consists of grayscale images with pixel values from 0 to 255, you'll see a distribution heavily skewed toward 0 (black), with a much smaller number of pixels having higher values (gray to white). This is typical for MNIST, which has black backgrounds and white digit strokes.
#This is useful for comparing against the transformed data to visualize exactly how your normalization is affecting the pixel value distribution.

## Properties of PyTorch dataset classes: 
### Me:
Let me ask a question about torchvision datasets: looping through valid_dataset yields the transformed data, whereas valid_dataset.data accesses the original untransformed data, if my understanding is correct. Is this correct? Is this a general property of torchvision datasets (and am I correct in calling this object a torchvision dataset)?

### Claude: 
Yes, your understanding is correct. When you loop through a PyTorch dataset object like valid_dataset, you get the transformed data (after all transforms have been applied), while accessing valid_dataset.data gives you the original, untransformed data.
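This holds for datasets.MNIST and several other built-in torchvision datasets that keep the whole dataset in memory (e.g. CIFAR10), but it is not a general guarantee. Many torchvision datasets (for example ImageFolder) have no .data attribute and load each file lazily inside __getitem__. For MNIST: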

valid_dataset.data - Contains the raw, untransformed tensor data
valid_dataset.targets - Contains the raw, untransformed labels
When you iterate through the dataset or use valid_dataset[index], you get the data after both the transform and target_transform functions have been applied

This design allows you to:

Apply transformations efficiently (only when data is actually accessed)
Access the original data when needed (for visualization or debugging)
Maintain a single copy of the dataset in memory

Applying transformations on the fly in __getitem__ is a common pattern for PyTorch datasets in general, including custom ones. Whether the raw data is also exposed through an attribute, and under which name, is up to each dataset class. torch.utils.data.Dataset itself only requires __getitem__ (and usually __len__).

In [None]:
valid_loader

In [None]:
# Inspect the first batch produced by valid_loader.
# collate_ae_dataset returns (noisy inputs, clean inputs, labels):
#   noisy_img: input images with noise added (encoder input)
#   img:       original clean images (reconstruction targets)
#   label:     class labels
sample = next(iter(valid_loader))
noisy_img, img, label = sample

# type_len prints its summary itself and returns None, so call it directly;
# wrapping it in print() would only add a stray "None" line.
type_len(noisy_img)
type_len(img)
type_len(label)

print(noisy_img.shape, img.shape, label.shape)

# Optional: histogram of all pixel values in the clean batch. Call .cpu() first
# if `device` is not the CPU.
#plt.hist(img.flatten(), bins=100)

# Useful to confirm that the loader and the custom collate function work as expected:
# noisy and clean images must have the same shape, and the batch size must match
# BATCH_SIZE from the DataLoader setup.


### Me: 
Thanks for your explanations. Could you also explain properties and behaviour of the valid_loader object? It is the output of the torch.utils.data.DataLoader() function applied to the valid_dataset object if I understand correctly.


### Claude: 
You're welcome! Let me explain the properties and behavior of the valid_loader object.
Yes, valid_loader is indeed the output of the torch.utils.data.DataLoader() function applied to the valid_dataset object. The DataLoader is a crucial PyTorch utility that wraps a dataset and provides several important features:
Key Properties and Behaviors of valid_loader:

#### Batching:

Takes individual samples from the dataset and combines them into batches
The batch size was defined by the batch_size=BATCH_SIZE parameter
Each iteration returns a batch instead of a single sample


#### Custom Collation:

Uses the collate_fn=collate_ae_dataset parameter to customize how samples are combined
In your case, this is where noise is added to create the noisy images for denoising autoencoder training
This function transforms the standard (image, label) pairs into (noisy_image, clean_image, label) triplets


#### No Shuffling:

shuffle=False means samples are processed in sequential order
This is typical for validation sets where randomization isn't needed


#### Dropping Incomplete Batches:

drop_last=True means that if the number of samples isn't divisible by the batch size, the last incomplete batch is discarded
This ensures all batches have consistent size, which can be important for some operations


#### Iteration Interface:

Can be iterated through using a for loop: for batch in valid_loader:
Each iteration yields one batch of data
For validation data, you typically iterate through the entire loader once per epoch


#### Memory Efficiency:

Loads data in batches on-demand rather than loading the entire dataset at once
Applies transformations on-the-fly


#### Multi-processing Support (if enabled):

Can load data in parallel using multiple worker processes
This wasn't explicitly set in your code, so it's using the default (single process)


#### Length Property:

len(valid_loader) would return the number of batches
With drop_last=True, this equals floor(len(valid_dataset) / BATCH_SIZE)


#### Device Handling:

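The collate function in your code moves every batch to the global `device` selected earlier in the notebook (a GPU or Apple MPS device if available, otherwise the CPU)
This saves you from explicitly moving data to the right device in your training loop. Caveat: this only works reliably with the default num_workers=0. With worker processes, moving tensors to a CUDA device inside collate_fn is generally not supported, so the transfer would then have to happen in the training loop.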



When you iterate through valid_loader, you get batches in the format determined by your custom collate function. In your case, each batch contains three tensors: noisy images, clean images, and labels, all already moved to the target device and ready for model processing.

In [None]:
# Grab one (shuffled) training batch: noisy inputs, clean inputs and labels.
s = next(iter(train_loader))
xns, xs, ys = s
print(xns.shape, xs.shape, ys.shape)

In [None]:
plot_im_samples(xns, is_torch=True)
plot_im_samples(xs, is_torch=True)

In [None]:
# Collect all preprocessed (clean) training images into one NumPy array, e.g. for
# dataset-wide statistics and histograms.
train_images = []

# Each batch from train_loader is (noisy images, clean images, labels), as returned
# by collate_ae_dataset. Only the clean images (`data`) are kept; the noisy copies
# are still generated by the collate function but are not used here.
for batch_idx, (noisy_data, data, target) in enumerate(train_loader):
    
    # detach() drops autograd history, cpu() copies the batch back from `device`
    # (.numpy() needs a CPU tensor), and numpy() converts it to a NumPy array.
    train_images.append(data.detach().cpu().numpy())

# Join the per-batch arrays along the batch axis into one array.
# Because train_loader uses drop_last=True, the final partial batch is not included.
train_images = np.concatenate(train_images, axis=0)


# Unlike the DataLoader, which yields one batch at a time, this keeps the whole
# (normalised, float) training set in memory at once; free it with
# `del train_images` when done.
# 
# 

In [None]:
print ("train_images.shape = ", train_images.shape)
print ("train_images.dtype = ", train_images.dtype)
print ("train_images.mean/std() = ", get_mean_std(train_images))
print ("train_images.min/max() = ", get_min_max(train_images))

In [None]:
# This line of code creates a histogram visualization of all pixel values in the training dataset:
plt.hist(train_images.flatten(), bins=100, log=True);

# Breaking it down:
# 
# train_images.flatten() - Converts the multi-dimensional array of all training images (which likely has shape [num_samples, 1, 28, 28] for MNIST) into a one-dimensional array containing every pixel value from all images.
# plt.hist(..., bins=100, log=True) - Creates a histogram with the following properties:
# 
# bins=100 divides the range of pixel values into 100 equal-width bins
# log=True uses a logarithmic scale for the y-axis (frequency count), which helps visualize the distribution when some bins have very high counts compared to others
# 
# 
# The semicolon at the end suppresses any additional output that might be generated.
# 
# This visualization shows the distribution of all pixel values across the entire training dataset after preprocessing (normalization). With the normalization parameters you used (m=0.5, s=1.0), the histogram will likely show a bimodal distribution:
# 
# A large peak around -0.5 (corresponding to the black background pixels that were originally near 0)
# A smaller peak around 0.5 (corresponding to the white digit strokes that were originally near 255)
# 
# The logarithmic scale helps to see both the very common background pixel values and the less common digit pixel values in the same visualization.

In [None]:
del train_images

We will also prepare a class-balanced subsample of the validation set for visualisation purposes.

In [None]:
def _gather_loader_arrays(loader):
  """Concatenate all ``(noisy, clean, label)`` batches of ``loader`` into NumPy arrays.

  Raises:
    ValueError: if ``loader`` yields no batches.
  """
  noisy_batches, clean_batches, label_batches = [], [], []
  for noisy_data, data, target in loader:
    # detach/cpu so tensors living on an accelerator can be converted too
    noisy_batches.append(noisy_data.detach().cpu().numpy())
    clean_batches.append(data.detach().cpu().numpy())
    label_batches.append(target.detach().cpu().numpy())
  if not label_batches:
    raise ValueError('valid_loader yielded no batches')
  return (np.concatenate(noisy_batches, axis=0),
          np.concatenate(clean_batches, axis=0),
          np.concatenate(label_batches, axis=0))


def get_samples(valid_loader, n_sample=None, n_vis_sample=None, rng=None):
  """Collect a class-balanced subset of validation samples for visualisation.

  Args:
    valid_loader: iterable of ``(noisy_images, images, labels)`` batches (as
      produced with ``collate_ae_dataset``). Labels are assumed to be integer
      class ids ``0..n_class-1`` in a 1-D tensor per batch.
    n_sample: samples drawn (without replacement) per class; defaults to the
      global ``N_SAMPLE``. A class with fewer examples contributes all it has.
    n_vis_sample: samples per class listed in ``single_el_idx``; defaults to
      the global ``N_VIS_SAMPLE``.
    rng: optional ``numpy.random.Generator`` for reproducible sampling. When
      None, the global ``np.random`` state is used.

  Returns:
    dict with
      'images_noisy', 'images': sampled noisy / clean images, grouped by class
        in increasing class order (noisy and clean entries correspond 1:1);
      'labels': class id of every sampled image;
      'single_el_idx': indices into those arrays of the first ``n_vis_sample``
        samples of every class.

  Raises:
    ValueError: if ``valid_loader`` yields no batches.
  """
  if n_sample is None:
    n_sample = N_SAMPLE
  if n_vis_sample is None:
    n_vis_sample = N_VIS_SAMPLE
  choice = np.random.choice if rng is None else rng.choice

  # 1. NumPy arrays of all validation images and labels.
  val_images_noisy, val_images, val_labels = _gather_loader_arrays(valid_loader)

  # 2. Balanced per-class subset.
  sample_images_noisy, sample_images, sample_labels = [], [], []
  single_el_idx = []
  offset = 0  # index where the current class starts in the concatenated output

  n_class = int(np.max(val_labels)) + 1
  for class_idx in range(n_class):
    map_c = val_labels == class_idx
    ims_c_noisy = val_images_noisy[map_c]
    ims_c = val_images[map_c]

    # Take at most n_sample examples: a small (or empty) class contributes what it
    # has instead of making choice(..., replace=False) raise.
    n_c = min(n_sample, len(ims_c))
    samples_idx = choice(len(ims_c), n_c, replace=False)

    # Same indices for noisy and clean images, so they stay paired.
    sample_images_noisy.append(ims_c_noisy[samples_idx])
    sample_images.append(ims_c[samples_idx])
    sample_labels.append(np.full(n_c, class_idx, dtype=int))

    # Use the real running offset, so indices stay correct when n_c < n_sample.
    single_el_idx.extend(range(offset, offset + min(n_vis_sample, n_c)))
    offset += n_c

  return {
      'images_noisy': np.concatenate(sample_images_noisy, axis=0),
      'images': np.concatenate(sample_images, axis=0),
      'labels': np.concatenate(sample_labels, axis=0),
      'single_el_idx': np.array(single_el_idx, dtype=int),
  }






In [None]:
samples = get_samples(valid_loader)

In [None]:
samples['images'].shape, samples['labels'].shape

In [None]:
single_el_idx = samples['single_el_idx']
plot_im_samples(samples['images_noisy'][single_el_idx, 0], n=20, is_torch=False)
plot_im_samples(samples['images'][single_el_idx, 0], n=20, is_torch=False)

## Helper Autoencoder Interface

We will start by implementing an autoencoder model base class.

In [None]:
class AutoEncoder(nn.Module):
    """Fully connected autoencoder: flatten -> MLP encoder -> code -> MLP decoder -> unflatten.

    Also used as the base class for the other autoencoders in this notebook, which
    call this initializer and then replace ``self.encoder`` / ``self.decoder``.

    Args:
        input_size: shape of a single data sample without the batch dimension
            (e.g. ``(1, 28, 28)`` for MNIST); any iterable of ints such as ``torch.Size``.
        code_size: dimensionality of the latent code (the bottleneck).
        hidden_size: width of the single hidden layer of encoder and decoder
            (defaults to 128, the previously hard-coded value).
    """

    def __init__(self, input_size, code_size, hidden_size=128):
        # Initialise nn.Module first, before any attribute is assigned on it.
        super(AutoEncoder, self).__init__()

        self.input_size = list(input_size)  # shape of one data sample
        # np.prod returns a NumPy scalar; nn.Linear / nn.Unflatten expect plain ints.
        self.flat_data_size = int(np.prod(self.input_size))
        self.hidden_size = hidden_size
        self.code_size = code_size

        self.encoder = nn.Sequential(
            nn.Flatten(),  # (batch, *input_size) -> (batch, flat_data_size)

            nn.Linear(self.flat_data_size, self.hidden_size),
            nn.ReLU(),

            nn.Linear(self.hidden_size, self.code_size),
            nn.Sigmoid(),  # constrains every code component to (0, 1)
        )

        self.decoder = nn.Sequential(
            nn.Linear(self.code_size, self.hidden_size),
            nn.ReLU(),

            nn.Linear(self.hidden_size, self.flat_data_size),
            # Think: why tanh? It gives a bounded, zero-centred output in (-1, 1),
            # which suits data normalised around 0 (the sample plots in this notebook
            # use vmin=-.5, vmax=.5) rather than data in [0, 1].
            nn.Tanh(),

            nn.Unflatten(1, self.input_size),  # back to (batch, *input_size)
        )

    def forward(self, x, return_z=False):
        """Reconstruct ``x``.

        Returns the reconstruction, or ``(reconstruction, code)`` when ``return_z`` is True.
        """
        encoded = self.encode(x)
        decoded = self.decode(encoded)
        return (decoded, encoded) if return_z else decoded

    def encode(self, x):
        """Map a batch of inputs to their latent codes."""
        return self.encoder(x)

    def decode(self, z):
        """Map latent codes back to data space.

        The tanh output is multiplied by 1.1, widening the output range to
        (-1.1, 1.1); presumably so that targets near the ends of the data range
        can be reached without driving tanh into saturation.
        """
        return self.decoder(z)*1.1

    def get_n_params(self):
        """Return the number of trainable parameters (those with ``requires_grad``)."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)


def eval_on_samples(ae_model, epoch, samples):
    """Run the autoencoder on the fixed sample set; called at the end of each training epoch.

    Args:
        ae_model: autoencoder whose ``forward(x, return_z=True)`` returns
            ``(reconstruction, code)``.
        epoch: index of the epoch just finished; stored unchanged in the result.
        samples: dict as produced by ``get_samples``; only ``'images_noisy'`` is read.

    Returns:
        dict with
            ``'z'``: list of NumPy arrays with every model output after the
                reconstruction. For ``AutoEncoder`` this is a single array of latent
                codes, so the codes are ``res['z'][0]``.
            ``'y'``: NumPy array of reconstructions.
            ``'epoch'``: the ``epoch`` argument.
    """
    # Place the inputs on the device the model's weights live on instead of relying
    # on a global ``device`` that may not match where the model actually is.
    first_param = next(ae_model.parameters(), None)
    target_device = first_param.device if first_param is not None else None
    xns = torch.as_tensor(samples['images_noisy'], dtype=torch.float32, device=target_device)

    with torch.no_grad():  # evaluation only, no autograd graph needed
        outputs = [el.detach().cpu().numpy() for el in ae_model(xns, return_z=True)]

    y = outputs[0]
    z = outputs[1:]  # kept as a list: callers index the codes as res['z'][0]
    return {'z': z, 'y': y, 'epoch': epoch}

In [None]:
def plot_hist(history, logscale=True):
    """
    Plot the training and validation loss curves against the epoch index.

    Args:
        history: dict with equally long lists under 'loss' (training),
            'val_loss' (validation) and 'epoch'.
        logscale: use a logarithmic y axis (default) instead of a linear one;
            loss curves often span several orders of magnitude.
    """
    # Read all series up front (same order as before: loss, val_loss, epoch).
    curves = [(history['loss'], 'training'), (history['val_loss'], 'validation')]
    x_epochs = history['epoch']

    draw = plt.plot
    if logscale:
        draw = plt.semilogy

    for values, curve_label in curves:
        draw(x_epochs, values, label=curve_label)

    plt.legend()
    plt.ylabel('loss')
    plt.xlabel('epoch')
    plt.show()
    plt.close()  # release the figure after displaying it
    # Adds a legend, axis labels, displays the plot, and then closes the figure.



def plot_samples(sample_history, samples, epoch_stride=5, fig_scale=1):
    """
    Show noisy inputs, reconstructions and clean targets for selected epochs.

    One figure is drawn for every epoch whose index is a multiple of
    `epoch_stride`, plus always the last recorded epoch. Each figure is a mosaic
    with three rows: the noisy inputs (relevant for the DAE), the model's
    reconstruction after that epoch, and the clean images.

    Args:
        sample_history: dict epoch index -> result of `eval_on_samples`.
        samples: dict from `get_samples`; 'single_el_idx' selects the shown samples.
        epoch_stride: plot every `epoch_stride`-th epoch.
        fig_scale: figure size per image (inches).
    """
    idx = samples['single_el_idx']
    noisy_row = samples['images_noisy'][idx, 0]  # channel 0 of each selected sample
    clean_row = samples['images'][idx, 0]

    final_epoch = np.max(list(sample_history.keys()))

    for epoch_idx, hist_el in sample_history.items():
        on_stride = epoch_idx % epoch_stride == 0
        if not (on_stride or epoch_idx == final_epoch):
            continue

        rows = [noisy_row, hist_el['y'][idx, 0], clean_row]
        n_rows = len(rows)
        n_cols = len(rows[0])

        plt.figure(figsize=(fig_scale*n_cols, fig_scale*n_rows))
        grid = mosaic(rows)
        plt.title(f'after epoch {int(epoch_idx)}')
        # Fixed grey-level window so figures from different epochs are comparable.
        plt.imshow(grid, cmap='gray', vmin=-.5, vmax=.5)
        plt.tight_layout(pad=0.1, h_pad=0, w_pad=0)
        plt.show()
        plt.close()

In [None]:
# These are utility functions for working with trained models at different stages of training. Let me break them down:

def run_on_trained(model, root_dir, run_fn, ep=None, model_filename=None):
    """
    Helper function to execute any function on model in state after `ep` training epoch.

    Loads a checkpoint (a dict with a 'model_state_dict' entry, as written by the
    training loops of this notebook) into `model` in place and calls `run_fn(model)`.

    Args:
        model: module whose state dict matches the checkpoint; updated in place.
        root_dir: directory holding `model_XXX.pth` checkpoints (str or path).
        run_fn: callable invoked as `run_fn(model)` after loading.
        ep: zero-based epoch index; selects `root_dir/model_{ep:03d}.pth`.
        model_filename: explicit checkpoint path; takes precedence over `ep`.
            If neither is given, the last `*.pth` file in name order is used.

    Returns:
        Whatever `run_fn` returns.

    Raises:
        FileNotFoundError: if no checkpoint is specified and `root_dir` contains
            no `*.pth` file.
    """
    if model_filename is None:
        root_dir = pl.Path(root_dir)  # accept plain strings as well as paths
        if ep is not None:
            model_filename = root_dir/f'model_{ep:03d}.pth'
        else:
            checkpoints = sorted(root_dir.glob('*.pth'))
            if not checkpoints:
                raise FileNotFoundError(f'no *.pth checkpoints found in {root_dir}')
            # zero-padded names (model_000.pth, ...) sort chronologically
            model_filename = checkpoints[-1]  # last model state

    # Load tensors directly onto the device the model lives on, so checkpoints
    # saved on a GPU can be restored on a CPU-only machine.
    first_param = next(model.parameters(), None)
    map_location = first_param.device if first_param is not None else None

    # SECURITY: weights_only=False unpickles arbitrary Python objects. It is kept
    # because checkpoints written by this notebook may contain non-tensor entries
    # (e.g. 'loss'); only load checkpoint files from trusted sources.
    model_dict = torch.load(model_filename, map_location=map_location, weights_only=False)
    model.load_state_dict(model_dict['model_state_dict'])

    return run_fn(model)
    # Calls the provided function on the loaded model

def run_on_all_training_history(model, root_dir, run_fn, n_ep=None):
    """
    Helper function to execute any function on model state after each of the training epochs.

    Useful for post-training analysis, e.g. to see how the model's behaviour
    evolves over the course of training.

    Args:
        model: module instance the checkpoints are loaded into (in place).
        root_dir: directory holding the `*.pth` checkpoints (str or path).
        run_fn: callable invoked as `run_fn(model)` after each checkpoint is loaded.
        n_ep: if given, process the checkpoints of epochs 0 .. n_ep-1 by epoch index;
            otherwise process every `*.pth` file in `root_dir` in name order.
    """
    root_dir = pl.Path(root_dir)  # `.glob` below needs a path object, not a str

    if n_ep is not None:
        for ep in range(n_ep):
            print(f'running on epoch {ep+1}/{n_ep}...')
            run_on_trained(model, root_dir, run_fn, ep=ep)
    else:
        # Zero-padded checkpoint names (model_000.pth, ...) sort chronologically.
        for model_filename in sorted(root_dir.glob('*.pth')):
            print(f'running on checkpoint {model_filename}...')
            run_on_trained(model, root_dir, run_fn, model_filename=model_filename)

    print('done')

In [None]:
# Smoke test: build an untrained autoencoder and run it on one training batch
# to check shapes before training.
train_batch = next(iter(train_loader))
# next(iter(...)) takes only the first batch, without running a full epoch.
xns, xs, ys = train_batch
# The loader yields (noisy images, clean images, labels).

print('sample shapes:', xns.shape, xs.shape, ys.shape)

in_size = xns.shape[1:]
print(in_size)
# Shape of one sample without the batch dimension; reused by later cells
# to build every model in this notebook.

ae = AutoEncoder(input_size=in_size, code_size=10).to(device)
# Temporary model with a 10-dimensional code, used only for this check.

y = ae(xns)
# NOTE(review): xns is not moved to `device` here, so this assumes the loader
# already yields tensors on the same device as the model — confirm.
# With return_z=False (default) only the reconstruction is returned.

print('output shape:', y.shape)
# Should equal the input shape: the decoder unflattens back to in_size.

plot_im_samples(xns, is_torch=True)
# A few of the noisy inputs.

plot_im_samples(y, is_torch=True)
# The corresponding reconstructions; since the model is untrained,
# they look like noise.

In [None]:
# Compare the pixel-value distribution of one noisy input with that of its
# (untrained) reconstruction.

x = xns[0]  # first noisy input of the batch
d = y[0]    # its reconstruction

im0 = x[0].detach().cpu().numpy()
# Channel 0 of the input, detached from autograd and copied to a NumPy array.

im1 = d[0].detach().cpu().numpy()
# Same conversion for the reconstruction.

# plt.imshow(im, cmap='gray', vmin=-1, vmax=1)
# (optional: display an image instead of its histogram)

bins = np.linspace(-1, 1, 100)
# 100 bin edges, i.e. 99 equal-width bins spanning [-1, 1].
# Values outside this range are not counted by plt.hist; reconstructions can
# reach up to ±1.1 because decode() scales the tanh output by 1.1.

plt.hist(im0.flatten(), bins, alpha=0.3);
plt.hist(im1.flatten(), bins, alpha=0.3);
# Two semi-transparent overlaid histograms (input first, reconstruction second)
# make differences between the two distributions easy to see.

In [None]:
# Number of trainable parameters of the temporary test model.
ae.get_n_params()

# Simple Autoencoder

First, we create and train an autoencoder with 5 latent variables on the MNIST dataset.
Each image is 28x28 pixels.
We start from image data since it's easy to interpret and judge the reconstruction quality visually, but the very same applies to other data types.

In [None]:
# Configuration of the plain (non-denoising) autoencoder and its monitoring samples.

CODE_SIZE = 5
# Dimensionality of the latent code: a stronger bottleneck than the
# code_size=10 used in the smoke test above.


NOISE_RATE = 0
# Intended fraction of corrupted input pixels: 0 means no corruption, i.e. a
# standard autoencoder rather than a denoising one.
# NOTE(review): this only takes effect if the datasets behind train_loader /
# valid_loader read the global NOISE_RATE when producing samples — confirm.

MODEL_NAME = 'ae_model'
# Also used by the training cell as the checkpoint directory name.

model = AutoEncoder(input_size=in_size, code_size=CODE_SIZE).to(device)
# Fresh, untrained model with the input shape determined by the smoke test.

samples = get_samples(valid_loader)
# Class-balanced validation samples (N_SAMPLE per class) used to monitor
# reconstruction quality after every training epoch.


In [None]:
model  # display the layer structure of the model

Train the model for 50 epochs

In [None]:
# train the autoencoder model, for N_EPOCHS epochs,
# save history of loss values for training and validation sets,
# history of validation samples evolution, and model weights history.
#
# Outputs of this cell:
#   history        - mean training / validation loss per epoch
#   sample_history - epoch index -> eval_on_samples result (reconstructions 'y',
#                    latent codes 'z') for the fixed monitoring samples
#   model_root     - one checkpoint per epoch: model_root/model_XXX.pth

N_EPOCHS = 50
LR = 0.0009  # Adam learning rate

model_root = pl.Path(MODEL_NAME)
model_root.mkdir(exist_ok=True)  # checkpoint directory named after MODEL_NAME

optimizer = optim.Adam(model.parameters(), lr=LR)

# implement loss explicitly
loss = nn.MSELoss()  # mean squared error between reconstruction and clean image

# train the model
history = {'loss': [], 'val_loss': [], 'epoch': []}
sample_history = {}

pbar = tqdm.tqdm(range(N_EPOCHS), postfix=f'epoch 0/{N_EPOCHS}')
for epoch_idx in pbar:
    # --- training: the model sees the noisy input, the target is the clean image ---
    epoch_loss = 0
    model.train()
    for batch_idx, (noisy_data, data, target) in enumerate(train_loader):
        optimizer.zero_grad()
        output = model(noisy_data)
        loss_value = loss(output, data)
        loss_value.backward()
        optimizer.step()
        epoch_loss += loss_value.detach().cpu().item()
    epoch_loss /= len(train_loader)  # mean of the per-batch losses
    history['loss'].append(epoch_loss)
    history['epoch'].append(epoch_idx)

    # --- validation: no gradients needed ---
    model.eval()
    with torch.no_grad():
        val_loss = 0
        for batch_idx, (noisy_data, data, target) in enumerate(valid_loader):
            output = model(noisy_data)
            loss_value = loss(output, data)
            val_loss += loss_value.detach().cpu().item()
        val_loss /= len(valid_loader)
        history['val_loss'].append(val_loss)

    # update progress bar
    pbar.set_postfix({'epoch': f'{epoch_idx+1}/{N_EPOCHS}', 'loss':f'{epoch_loss:.4f}', 'val_loss':f'{val_loss:.4f}'})

    # evaluate on samples: eval_on_samples calls the model with return_z=True, so the
    # stored result holds both the reconstructions ('y') and the latent codes ('z').
    sample_res = eval_on_samples(model, epoch_idx, samples=samples)
    sample_history[epoch_idx] = sample_res

    # save model weights (plus optimizer state, which allows resuming training).
    # The loss *values* are stored rather than the nn.MSELoss module object: a
    # checkpoint holding only tensors and plain numbers can also be read with
    # torch.load(..., weights_only=True).
    torch.save({
                'epoch': epoch_idx,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': epoch_loss,
                'val_loss': val_loss,
                }, model_root/f'model_{epoch_idx:03d}.pth')
    

Plot loss function evolution during training:

In [None]:
plot_hist(history)  # training vs. validation loss per epoch (log scale by default)

Let's visually compare the network's output with the input:

In [None]:
# Rows: noisy input, reconstruction, clean image; every 5th epoch plus the last one.
plot_samples(sample_history, samples=samples, epoch_stride=5, fig_scale=1)

In [None]:

# Latent codes of the monitoring samples after the last recorded epoch
# (max key instead of a hard-coded 49, so this still works if N_EPOCHS changes).
# 'z'[0] has shape (n_samples, CODE_SIZE); after .T the shape is (CODE_SIZE, n_samples),
# where n_samples = N_SAMPLE * number of classes.
sample_history[max(sample_history)]['z'][0].T.shape

In [None]:
# One histogram per latent dimension over the monitoring samples, after the last
# recorded epoch (max key instead of a hard-coded 49). The encoder ends in a sigmoid,
# so all values lie in (0, 1).
for d in sample_history[max(sample_history)]['z'][0].T:
    plt.hist(d, 100, alpha=0.3)

# Denoising Autoencoder

In a denoising AE, we corrupt the input samples and ask the model to reconstruct the original (clean) samples. Here, 20% of the pixels will be corrupted.
This forces the model to learn more robust representations.

In [None]:
# Configuration of the denoising autoencoder (DAE): same architecture and code
# size as the plain autoencoder above, but trained on corrupted inputs.

CODE_SIZE = 5
# Dimensionality of the latent code, unchanged.

NOISE_RATE = 0.2
# Intended fraction of corrupted pixels per input image. The model is trained to
# reconstruct the clean image from the corrupted one, which is what makes it a
# denoising autoencoder.
# NOTE(review): this only takes effect if the datasets behind train_loader /
# valid_loader read the global NOISE_RATE when producing samples — confirm.

MODEL_NAME = 'dae_model'
# Separate checkpoint directory, so the 'ae_model' checkpoints are kept.

model = AutoEncoder(input_size=in_size, code_size=CODE_SIZE).to(device)
# Fresh, untrained model with the same architecture as before.

samples = get_samples(valid_loader)
# Re-draw the monitoring samples so that (given the NOTE above) their noisy
# versions reflect the new NOISE_RATE.


In [None]:
model  # display the layer structure of the model

Train the model for 15 epochs

In [None]:
# train the autoencoder model, for N_EPOCHS epochs,
# save history of loss values for training and validation sets,
# history of validation samples evolution, and model weights history.
# Same procedure as for the plain autoencoder; the model gets the noisy input and
# is scored against the clean image.

N_EPOCHS = 15
LR = 0.0009  # Adam learning rate


model_root = pl.Path(MODEL_NAME)
model_root.mkdir(exist_ok=True)  # checkpoint directory named after MODEL_NAME



optimizer = optim.Adam(model.parameters(), lr=LR)

# implement loss explicitly
loss = nn.MSELoss()

# train the model
history = {'loss': [], 'val_loss': [], 'epoch': []}
sample_history = {}

pbar = tqdm.tqdm(range(N_EPOCHS), postfix=f'epoch 0/{N_EPOCHS}')
for epoch_idx in pbar:
    epoch_loss = 0
    model.train()
    for batch_idx, (noisy_data, data, target) in enumerate(train_loader):
        optimizer.zero_grad()
        output = model(noisy_data)
        loss_value = loss(output, data)
        loss_value.backward()
        optimizer.step()
        epoch_loss += loss_value.detach().cpu().item()
    epoch_loss /= len(train_loader)  # mean of the per-batch losses
    history['loss'].append(epoch_loss)
    history['epoch'].append(epoch_idx)

    # evaluate on validation set
    model.eval()
    with torch.no_grad():
        val_loss = 0
        for batch_idx, (noisy_data, data, target) in enumerate(valid_loader):
            output = model(noisy_data)
            loss_value = loss(output, data)
            val_loss += loss_value.detach().cpu().item()
        val_loss /= len(valid_loader)
        history['val_loss'].append(val_loss)

    # update progress bar
    pbar.set_postfix({'epoch': f'{epoch_idx+1}/{N_EPOCHS}', 'loss':f'{epoch_loss:.4f}', 'val_loss':f'{val_loss:.4f}'})
    # evaluate on samples (reconstructions 'y' and latent codes 'z')
    sample_res = eval_on_samples(model, epoch_idx, samples=samples)
    sample_history[epoch_idx] = sample_res

    # save model weights. The loss *values* are stored rather than the nn.MSELoss
    # module object, so the checkpoint can also be read with weights_only=True.
    torch.save({
                'epoch': epoch_idx,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': epoch_loss,
                'val_loss': val_loss,
                }, model_root/f'model_{epoch_idx:03d}.pth')

Plot loss function evolution during training:

In [None]:
plot_hist(history)  # training vs. validation loss per epoch of the DAE

Visualise evolution of reconstruction:

In [None]:
# Rows: noisy input, reconstruction, clean image; every 5th epoch plus the last one.
plot_samples(sample_history, samples=samples, epoch_stride=5, fig_scale=1)

## How to access the latent space (conversation with Claude):

The latent space is already saved in the sample history: `eval_on_samples` calls the model with `return_z=True`, so each entry stores the latent codes under `'z'` alongside the reconstructions.

Alternatively, one could access the latent space by applying the trained model again.

### 3 Methods to produce the latent space using the trained model (according to Claude): 

#### 1. Using the encode() method directly:

In [None]:
# For a batch of images. `image_batch` was never defined in this notebook, so take
# the clean images of one validation batch (loaders yield (noisy, clean, label)).
image_batch = next(iter(valid_loader))[1]

with torch.no_grad():
    encoded_batch = model.encode(image_batch)  # image_batch should be on the same device as the model

# Convert to numpy for visualization if needed
encoded_batch_np = encoded_batch.detach().cpu().numpy()

#### 2. Using the forward() method with return_z=True:

In [None]:
# `image_batch` was never defined in this notebook: take the clean images of one
# validation batch (loaders yield (noisy, clean, label)).
image_batch = next(iter(valid_loader))[1]

with torch.no_grad():
    reconstructed, encoded = model(image_batch, return_z=True)

encoded_np = encoded.detach().cpu().numpy()

#### 3. For the entire validation set:

In [None]:
# Get encoded representations for all validation data, together with the labels.
all_encoded = []
all_labels = []

model.eval()
with torch.no_grad():
    for batch_idx, (noisy_data, data, target) in enumerate(valid_loader):
        encoded = model.encode(data)  # Using clean images, not noisy ones
        
        all_encoded.append(encoded.detach().cpu().numpy())
        all_labels.append(target.detach().cpu().numpy())
        
# Stack the per-batch arrays: all_encoded is (n_valid, CODE_SIZE), all_labels is (n_valid,).
all_encoded = np.concatenate(all_encoded, axis=0)
all_labels = np.concatenate(all_labels, axis=0)

## Exercise 1

Use autoencoder to find outliers:
1. perform reconstruction for the validation dataset
2. evaluate the reconstruction loss and plot its distribution
3. visualise poorly reconstructed samples

Extra/Homework:
Perform 2D UMAP embedding of the raw data points and image encoding (variables in latent space) for validation set, plot it and highlight outliers.

# Convolutional autoencoder

Instead of fully connected layers we can use strided convolutional layers in encoder, and transposed convolutions in decoder.
This model will have fewer parameters due to weight sharing, and is thus easier to train.

After upscaling, the image will be a bit bigger than the original, so we also crop the reconstruction to the input image size.

In [None]:
class ConvolutionalAutoEncoder(AutoEncoder):
    """Autoencoder built from strided convolutions and transposed convolutions.

    Inherits from ``AutoEncoder`` but replaces its encoder and decoder. The
    layer sizes are fixed for single-channel 28x28 images:

    * encoder: 3x3 convolutions take the spatial size 28 -> 28 -> 14 -> 7 -> 4 -> 2
      with 32 channels in the last map (32*2*2 = 128 values). Two linear
      layers then reduce this to ``code_size`` values, with no final
      activation.
    * decoder: a linear layer expands the code to 32x2x2. Four stride-2
      transposed convolutions grow the map 2 -> 4 -> 8 -> 16 -> 32. A final
      convolution plus Tanh gives one channel in [-1, 1]. ``decode`` crops the
      32x32 result to 28x28.
    """

    def __init__(self, input_size, code_size):
        # attributes assigned before the base-class initialiser runs
        # (presumably so the base class can see them - TODO confirm)
        self.input_size = list(input_size)  # shape of one data sample
        self.hidden_size = 32 * 2 * 2       # flattened size of the last conv feature map
        self.code_size = code_size          # dimensionality of the latent code

        super(ConvolutionalAutoEncoder, self).__init__(input_size, code_size)
        # the encoder/decoder created by the base class are replaced below

        def leaky():
            return nn.LeakyReLU(negative_slope=0.3)

        # (in_channels, out_channels, stride) for each 3x3 encoder convolution
        encoder_convs = [(1, 8, 1), (8, 8, 2), (8, 16, 2), (16, 16, 2), (16, 32, 2)]
        encoder_layers = []
        for c_in, c_out, stride in encoder_convs:
            encoder_layers.append(nn.Conv2d(c_in, c_out, 3, padding=1, stride=stride))
            encoder_layers.append(leaky())
        encoder_layers += [
            nn.Flatten(),
            nn.Linear(self.hidden_size, self.hidden_size // 8), leaky(),
            # linear output: the latent code is not squashed (a Tanh here was tried and disabled)
            nn.Linear(self.hidden_size // 8, self.code_size),
        ]
        self.encoder = nn.Sequential(*encoder_layers)

        # (in_channels, out_channels) for each stride-2 transposed convolution;
        # output_padding=1 makes each one exactly double the spatial size
        decoder_deconvs = [(32, 16), (16, 16), (16, 8), (8, 8)]
        decoder_layers = [
            nn.Linear(self.code_size, self.hidden_size), leaky(),
            nn.Unflatten(1, (32, 2, 2)),
        ]
        for c_in, c_out in decoder_deconvs:
            decoder_layers.append(
                nn.ConvTranspose2d(c_in, c_out, 3, padding=1, output_padding=1, stride=2))
            decoder_layers.append(leaky())
        decoder_layers += [
            nn.Conv2d(8, 1, 3, padding=1, stride=1),
            nn.Tanh(),  # pixel values in [-1, 1]
        ]
        self.decoder = nn.Sequential(*decoder_layers)

    def decode(self, z):
        """Decode latent codes ``z`` and crop 2 pixels from every image border."""
        full_size = self.decoder(z)
        # the decoder produces 32x32 maps; trim them to the 28x28 input size
        return full_size[:, :, 2:-2, 2:-2]

# Compared with the fully connected autoencoder, this convolutional version:
#
# 1. preserves spatial relationships in the data,
# 2. shares parameters across positions (far fewer weights),
# 3. can learn hierarchical features (edges, textures, patterns),
# 4. usually gives better reconstructions and more meaningful latent codes for images.
#
# Moving from a fully connected to a convolutional autoencoder is a common
# step when improving deep learning models for image data.


Thus we can try to reduce the size of the latent space:

In [None]:
CODE_SIZE = 2  # 2D latent space, so it can be plotted directly
NOISE_RATE = 0.2  # presumably the input noise level of the denoising setup (not used in this cell)
MODEL_NAME = 'cdae_model'  # also used as the checkpoint directory name
model = ConvolutionalAutoEncoder(input_size=in_size, code_size=CODE_SIZE).to(device)

# fixed validation samples whose outputs are tracked during training
samples = get_samples(valid_loader)

In [None]:
# quick shape check: encode and reconstruct the fixed noisy validation samples
xns = torch.tensor(samples['images_noisy']).to(device)
print(xns.shape)
zs = model.encode(xns)  # latent codes, expected (n_samples, CODE_SIZE)
ys = model(xns)  # reconstructions, expected to match the input shape
print(zs.shape)
print(ys.shape)

In [None]:
# display the module structure (layer list)
model

In [None]:
# number of model parameters (get_n_params is presumably provided by the AutoEncoder base class)
model.get_n_params()

In [None]:
# (re)initialise the training records; the training cell appends one entry per epoch
history = {'loss': [], 'val_loss': [], 'epoch': []}
sample_history = {}  # epoch index -> result of eval_on_samples on `samples`

Train the model for 70 epochs

In [None]:
# Train the convolutional denoising autoencoder for N_EPOCHS epochs on
# (noisy input, clean target) pairs. Per epoch this records the mean training
# and validation losses in `history`, the outputs on the fixed validation
# `samples` in `sample_history`, and writes a model/optimizer checkpoint.

N_EPOCHS = 70
LR = 0.0004


model_root = pl.Path(MODEL_NAME)
model_root.mkdir(exist_ok=True)


optimizer = optim.Adam(model.parameters(), lr=LR)

# reconstruction criterion: mean absolute error
loss = nn.L1Loss()

# `history` and `sample_history` are initialised in an earlier cell and only
# appended to here.

pbar = tqdm.tqdm(range(0, N_EPOCHS), postfix=f'epoch 0/{N_EPOCHS}')
for epoch_idx in pbar:
    epoch_loss = 0
    model.train()
    for batch_idx, (noisy_data, data, target) in enumerate(train_loader):
        optimizer.zero_grad()
        output = model(noisy_data)  # denoising: reconstruct the clean image from the noisy one
        loss_value = loss(output, data)
        loss_value.backward()
        optimizer.step()
        epoch_loss += loss_value.detach().cpu().item()
    epoch_loss /= len(train_loader)  # mean of the per-batch losses
    history['loss'].append(epoch_loss)
    history['epoch'].append(epoch_idx)

    # evaluate on validation set
    model.eval()
    with torch.no_grad():
        val_loss = 0
        for batch_idx, (noisy_data, data, target) in enumerate(valid_loader):
            output = model(noisy_data)
            loss_value = loss(output, data)
            val_loss += loss_value.detach().cpu().item()
        val_loss /= len(valid_loader)
        history['val_loss'].append(val_loss)

    # update progress bar
    pbar.set_postfix({'epoch': f'{epoch_idx+1}/{N_EPOCHS}', 'loss':f'{epoch_loss:.4f}', 'val_loss':f'{val_loss:.4f}'})
    # evaluate on samples
    sample_res = eval_on_samples(model, epoch_idx, samples=samples)
    sample_history[epoch_idx] = sample_res

    # save model weights. Store the epoch's loss *values*, not the criterion
    # module: checkpoints then contain only tensors and plain Python data, so
    # they also load with torch.load(..., weights_only=True) (the default
    # since PyTorch 2.6).
    torch.save({
                'epoch': epoch_idx,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': epoch_loss,
                'val_loss': val_loss,
                }, model_root/f'model_{epoch_idx:03d}.pth')

Plot loss function evolution during training:

In [None]:
# plot the per-epoch training / validation losses recorded in `history`
plot_hist(history)

Visualise evolution of reconstruction:

In [None]:
# Compare log-scale value histograms of the first and the last recorded epoch:
# the first figure also shows the clean sample images next to the
# reconstructions 'y', the second shows the latent codes 'z'.
# The last epoch is taken from `sample_history` itself: after the training
# loop `epoch_idx` already equals the last epoch, so `epoch_idx-1` would
# select the second-to-last one.
last_epoch = max(sample_history)
plt.hist(samples['images'].flatten(), bins=100, log=True);
for k in ['y', 'z']:
    plt.hist(np.array(sample_history[0][k]).flatten(), bins=100, log=True, alpha=0.3);
    plt.hist(np.array(sample_history[last_epoch][k]).flatten(), bins=100, log=True, alpha=0.3);
    plt.show()
    plt.close()



In [None]:
# show the recorded sample outputs for every 5th epoch
plot_samples(sample_history, samples=samples, epoch_stride=5, fig_scale=1)

## Latent space

So far we have only looked at the output: the model more or less does the job. But what does it learn?

First let's use animation to visualize reconstruction evolution:

In [None]:
%%capture
# Cell magic (must stay the first line): suppresses this cell's output, so the
# figure is not shown as a static image; the animation is displayed separately.

single_el_idx = samples['single_el_idx']
images_noisy = samples['images_noisy'][single_el_idx, 0]
images = samples['images'][single_el_idx, 0]
# Selects the chosen sample indices and takes their noisy inputs and clean
# originals; `, 0` picks the first (only, grayscale) channel

smpl_ims = []
for epoch_idx, hist_el in sample_history.items():
    samples_arr = [images_noisy, hist_el['y'][single_el_idx, 0], images]
    smpl_ims.append(samples_arr)
# One entry per recorded epoch: [noisy inputs, reconstructions, clean originals].
# Note: this loop rebinds the global `epoch_idx`.

ny, nx = len(smpl_ims[0]), len(smpl_ims[0][0])
# rows = 3 (noisy, reconstructed, original), columns = number of selected samples

plt.rcParams["animation.html"] = "jshtml"  # for matplotlib 2.1 and above animations use JavaScript
# Renders animations as JavaScript/HTML in Jupyter

s=1
fig = plt.figure(figsize=(s*nx, s*ny))
# figure size proportional to the number of images

m = mosaic(smpl_ims[0])

ttl = plt.title(f'after epoch {int(0)}')
# plot 0th epoch - 0th frame
imsh = plt.imshow(m, cmap='gray', vmin=-0.5, vmax=0.5)
# NOTE(review): the decoder ends in Tanh (range [-1, 1]), so vmin/vmax of
# +-0.5 clips the extremes; later cells use [0, 1] - confirm the intended range.

# this function will be called to render each of the frames
def animate(i):
    m = mosaic(smpl_ims[i])
    imsh.set_data(m)

    ttl.set_text(f'after epoch {i}')

    return imsh

# Updates the image data and title for frame i; the frame index is used as the
# epoch number, which holds when sample_history keys are 0..N_EPOCHS-1.

# create animation
ani = animation.FuncAnimation(fig, animate, frames=len(smpl_ims))

# One frame per recorded epoch, showing how the reconstructions of the noisy
# inputs evolve during training.



In [None]:
# display animation (rendered as interactive JS/HTML via the 'jshtml' setting)
ani

And let's see evolution of the latent representations:

In [None]:
%%capture
plt.rcParams["animation.html"] = "jshtml"  # for matplotlib 2.1 and above, uses JavaScript
fig = plt.figure(figsize=(8,8))
# %%capture (first line) suppresses the static output of this cell;
# the 8x8-inch figure holds the latent-space scatter plot

labels = samples['labels']
epochs = sorted(sample_history.keys())
z_res = [sample_history[ep]['z'][0] for ep in epochs]
# Labels of the fixed samples, the recorded epochs in order, and the latent
# codes per epoch.
# NOTE(review): `['z'][0]` assumes eval_on_samples stores 'z' as a container
# whose first element is the (n_samples, CODE_SIZE) code array - confirm.

scat = plt.scatter(z_res[0][:,0], z_res[0][:,1], c=labels, cmap=cm.rainbow)
# Scatter of the first two latent dimensions at the first recorded epoch,
# coloured by label with the rainbow colormap

plt.xlim(-6.1, 6.1)
plt.ylim(-6.1, 6.1)

ax = plt.gca()
legend1 = ax.legend(*scat.legend_elements(), title="digits")
ax.add_artist(legend1)
ax.set_aspect('equal')
ttl = plt.title(f'after epoch {0}')
# Fixed axis limits keep all frames comparable; legend maps colours to labels;
# equal aspect ratio; the title shows the epoch

def animate(i):
    z = z_res[i]
    scat.set_offsets(z)
    ttl.set_text(f'after epoch {i}')
    return scat

# Moves the points to the latent codes of frame i and updates the title
# (frame index used as epoch number)

ani = animation.FuncAnimation(fig, animate, frames=len(z_res))

# One frame per recorded epoch. Even though the labels are not used in
# training, samples of the same class are expected to cluster as the
# autoencoder learns.


In [None]:
# display the latent-space animation
ani

## Sampling from latent space

But we can also use the trained model to generate samples based on the latent representation:

In [None]:
# Generate images from random points in the latent space, using the model
# state saved after each training epoch.

z0 = np.random.uniform(size=(25, CODE_SIZE))*12-6  # 25 random latent points, uniform in [-6, 6)
# uniform [0, 1) samples scaled to [-6, 6) (multiply by 12, subtract 6),
# roughly the axis range of the latent-space plot above

z0_t = torch.tensor(z0, dtype=torch.float32).to(device)
# float32 tensor on the model's device

ims_all = []
# collects one batch of generated images per processed checkpoint

# this function will be called in saved model state after each training epoch
def fn(ae):
  with torch.no_grad():
    ims = ae.decode(z0_t)
    ims = ims.detach().cpu().numpy()
    ims_all.append(ims)
# Decodes the fixed random latent vectors with the given model state
# (without gradient tracking) and stores the result as a NumPy array

run_on_all_training_history(model, model_root, fn)
# NOTE(review): presumably loads each checkpoint saved in `model_root` into
# `model` and calls `fn` on it - confirm against run_on_all_training_history.


ims_all = np.array(ims_all)
print(ims_all.shape)
# expected shape: [n_checkpoints, 25, 1, height, width]
# (25 latent vectors, 1 grayscale channel)

ims_all = ims_all[:, :, 0, :, :]  # remove channel dimension
# -> [n_checkpoints, 25, height, width]
#
# The same latent points are decoded with every checkpoint, so this shows how
# the decoder's interpretation of those points changes during training.



In [None]:
# log-scale pixel-value histograms: first vs last processed checkpoint
plt.hist(ims_all[0].flatten(), bins=100, log=True);
plt.hist(ims_all[-1].flatten(), bins=100, log=True, alpha=0.5);


In [None]:

# This code creates an animation showing how the decoder's ability to generate 
# images from random latent vectors evolves over training epochs:

%%capture
fig = plt.figure(figsize=(20, 1.5))
# Suppresses output with %%capture
# Creates a wide, short figure to display the generated images in a row

m = mosaic([ims_all[0]])
# Uses the mosaic function to arrange the first epoch's images in a grid
# Since ims_all[0] contains all 25 images from epoch 0, this creates 
# a single row of images

imsh = plt.imshow(m, cmap='gray', vmin=0, vmax=1)
ttl = plt.title(f'after epoch {int(0)}')
# Displays the mosaic with a grayscale colormap
# Sets the value range to [0, 1] for proper contrast
# Adds a title showing the current epoch

def animate(i):
    m = mosaic([ims_all[i]])
    imsh.set_data(m)
    ttl.set_text(f'after epoch {i}')
    return imsh
# Defines a function to update the plot for each animation frame
# Creates a mosaic of the generated images from the current epoch
# Updates the image data and title
# Returns the updated image object

ani = animation.FuncAnimation(fig, animate, frames=len(ims_all))

# Creates an animation that cycles through all epochs
# 
# This visualization shows how the model learns to map points in the latent space 
# to meaningful images over time. Initially, the generated images may look 
# like noise or very basic patterns. As training progresses, they should gradually 
# transform into more recognizable digit-like shapes.
# What makes this particularly interesting is that these images are generated 
# from random points in the latent space, not from encoded real images. 
# This demonstrates the generative capability of the autoencoder - its ability 
# to create new, plausible examples by sampling from the latent space.



In [None]:
# display the generation animation
ani

## Interpolation in latent space

In [None]:
# Regular grid of points in the 2D latent space; decoding every grid point
# shows how the latent space is organised (a "manifold visualisation").

n_samples = 10
# grid points per dimension -> 10x10 grid

zs = np.meshgrid(np.linspace(-6, 6, n_samples),
                 np.linspace(-6, 6, n_samples))
# two [10, 10] coordinate arrays built from 10 equally spaced values in [-6, 6]

zs = np.stack(zs, axis=-1).reshape(-1, 2)
# stack to [10, 10, 2], then flatten to [100, 2]: one (z1, z2) point per row
# (assumes a 2D latent space, i.e. CODE_SIZE == 2)

zs_t = torch.tensor(zs, dtype=torch.float32).to(device)
# float32 tensor on the model's device, ready for decode



In [None]:
def fn(ae):
    """Decode the latent grid `zs_t` with `ae` and show it as an image mosaic."""
    with torch.no_grad():
        decoded = ae.decode(zs_t).detach().cpu().numpy()
        frames = decoded[:, 0, :, :]  # drop the channel axis

    tiles = frames.reshape([n_samples, n_samples] + list(frames.shape)[1:])
    plt.figure(figsize=(10, 10))

    # Rows are flipped so that latent y grows upwards, as on a plot axis;
    # image rows are drawn from top to bottom.
    plt.imshow(mosaic(tiles[::-1]), vmin=0, vmax=1, cmap='gray')
    plt.show()

run_on_trained(model, model_root, fn)


## Exercise 2


Work in 2 groups:


Option 1. Compare distribution in latent space for noise rate 0, 0.2, 0.6, 0.8.

Option 2. Compare reconstruction depending on size of latent space: 2, 4, 8, 16, 32

## Exercise 3: Writing latent Dataset


In this exercise, we will
1. modify the CNN to be fully convolutional, with 4 downscaling layers (x16), 16 channels
2. save dataset of latent states, see code below for example.

When using a VAE, save the z mean and the z log-variance separately.

In [None]:
import pickle


def save_pckl(obj, fname):
    """Pickle ``obj`` into the file ``fname`` using the default protocol."""
    with open(fname, mode='wb') as handle:
        pickle.dump(obj, handle, protocol=pickle.DEFAULT_PROTOCOL)


def load_pckl(fname):
    """Return the object stored in the pickle file ``fname``.

    Only load files you created yourself: unpickling can execute arbitrary code.
    """
    with open(fname, mode='rb') as handle:
        return pickle.load(handle)

In [None]:
#load latest model checkpoint:
# NOTE(review): `ep=3` is passed - if `ep` selects the epoch index in
# run_on_trained, this loads epoch 3 rather than the latest; confirm.
run_on_trained(model, model_root, lambda m:None, ep=3)

# dataset for latent DDM, file list (split name -> list of pickle paths):
data_files = {"train": [], "valid": []}

ds_root = pl.Path('data')
ds_root.mkdir(exist_ok=True)

# generate latent values per blocks of samples:
block_sz = 5000
for ds, sfx in zip([valid_dataset, train_dataset], ['valid', 'train']):
  x = ds.data.numpy().reshape(-1, 1, 28, 28)
  x = x/255.*2-1  # scale pixel values from [0, 255] to [-1, 1]
  l = ds.targets.numpy()

  for i in range(0, len(x), block_sz):
    x_b = x[i:i+block_sz]
    l_b = l[i:i+block_sz]

    x_b_t = torch.tensor(x_b, dtype=torch.float32).to(device)

    with torch.no_grad():
        z_t = model.encode(x_b_t)
        z_b = z_t.detach().cpu().numpy()

    # save latent values and their shape and labels to a dataframe
    # (each latent is flattened; its original shape is kept to restore it):
    data_d = {
    'z': [zi.flatten() for zi in z_b],
    'shape': [zi.shape for zi in z_b],
    'label': l_b
    }

    df = pd.DataFrame(data=data_d)

    # save dataframe to a file and record its path in the reference dictionary:
    fname = ds_root/f'df_z_{sfx}_{i}.pckl'
    save_pckl(df, fname)
    data_files[sfx].append(str(fname))

# the file list is what the `datasets` loader reads below
save_pckl(data_files, ds_root/'data_files.pkl')

In [None]:
# install the Hugging Face `datasets` package (provides load_dataset)
!pip install datasets

In [None]:
from datasets import load_dataset, Dataset

In [None]:
# now we can load the dataset:
# data_files maps each split ('train'/'valid') to its pickled DataFrame files;
# the 'pandas' builder reads those pickles
data_files = load_pckl(ds_root/'data_files.pkl')
lds = load_dataset('pandas', data_files=data_files)

In [None]:
# convert to torch dataset (columns are returned as torch tensors):
tds = lds.with_format("torch")

In [None]:
# shape of the stacked (flattened) latent vectors of the training split
tds['train']['z'].shape

# VAE

While the previous model learns to separate subpopulations in the latent space, there remain significant overlaps and unpopulated regions.

A variational AE puts additional constraints on the distribution in the latent space and performs variational inference.

(see pptx for details)

Here the encoder predicts `2 x code_size` values per sample: the means and the logarithms of the variances. Since these values can be any real number ($\mathbb{R}$), no activation function is used in the last layer of the encoder.

Then, for reconstruction, we sample from this distribution using the reparameterisation trick.

The technically complex part is implementing the custom loss function and the training loop.

In [None]:
class VariationalConvolutionalAutoencoder(AutoEncoder):
    """Convolutional variational autoencoder (VAE).

    For every sample the encoder predicts ``code_size`` latent means followed
    by ``code_size`` latent log-variances. There is no output activation,
    since both can be any real value. ``forward`` draws a latent sample with
    the reparameterisation trick and decodes it.

    Layer sizes are fixed for single-channel 28x28 inputs. The encoder reduces
    the spatial size 28 -> 28 -> 14 -> 7 -> 4 -> 2 -> 1 (64 channels). The
    decoder upsamples 1 -> 2 -> 4 -> 8 -> 16 -> 32, and ``decode`` crops the
    result back to 28x28.
    """

    def __init__(self, input_size, code_size):
        super(VariationalConvolutionalAutoencoder, self).__init__(input_size, code_size)

        self.input_size = list(input_size)  # shape of one data sample
        self.npix = np.prod(self.input_size)  # values per sample; normalises the KL term

        self.hidden_size = 64*1  # flattened size of the last encoder map (64 x 1 x 1)

        self.code_size = code_size  # latent dimensionality

        self.encoder = nn.Sequential(
            nn.Conv2d(1,   16, 3, padding=1, stride=1), nn.LeakyReLU(negative_slope=0.3),
            nn.Conv2d(16,  16, 3, padding=1, stride=2), nn.LeakyReLU(negative_slope=0.3),
            nn.Conv2d(16,  32, 3, padding=1, stride=2), nn.LeakyReLU(negative_slope=0.3),
            nn.Conv2d(32,  32, 3, padding=1, stride=2), nn.LeakyReLU(negative_slope=0.3),
            nn.Conv2d(32,  32, 3, padding=1, stride=2), nn.LeakyReLU(negative_slope=0.3),
            nn.Conv2d(32,  64, 3, padding=1, stride=2), nn.LeakyReLU(negative_slope=0.3),

            nn.Flatten(),

            nn.Linear(self.hidden_size, 64), nn.LeakyReLU(negative_slope=0.3),

            # code_size values for the mean + code_size for log(variance); both are
            # unbounded real values, so no activation function is applied
            nn.Linear(64, self.code_size * 2),
        )
        self.decoder = nn.Sequential(
            nn.Linear(self.code_size, 64), nn.LeakyReLU(negative_slope=0.3),
            nn.Linear(64, self.hidden_size), nn.LeakyReLU(negative_slope=0.3),

            nn.Unflatten(1, (64, 1, 1)),

            # each stride-2 transposed convolution doubles the spatial size
            nn.ConvTranspose2d(64, 32, 3, padding=1, output_padding=1, stride=2), nn.ReLU(),
            nn.ConvTranspose2d(32, 32, 3, padding=1, output_padding=1, stride=2), nn.ReLU(),
            nn.ConvTranspose2d(32, 32, 3, padding=1, output_padding=1, stride=2), nn.ReLU(),
            nn.ConvTranspose2d(32, 16, 3, padding=1, output_padding=1, stride=2), nn.ReLU(),
            nn.ConvTranspose2d(16, 16, 3, padding=1, output_padding=1, stride=2), nn.ReLU(),
            nn.Conv2d(16, 1, 3, padding=1, stride=1), nn.Tanh(),  # pixels in [-1, 1]
        )

    def sample(self, eps=None):
        """Decode latent vectors ``eps``; by default 100 draws from N(0, I).

        The default noise is created on the device of the model parameters,
        so sampling also works when the model lives on a GPU.
        """
        if eps is None:
            # sample from standard normal distribution
            device = next(self.parameters()).device
            eps = torch.randn((100, self.code_size), device=device)
        return self.decode(eps)

    def encode(self, x):
        """Return ``(z_mean, z_logvar)``, each of shape (batch, code_size)."""
        z = self.encoder(x)
        # Split the 2*code_size outputs into two halves of size code_size.
        # A fixed chunk size of 2 would only be correct for code_size == 2.
        z_mean, z_logvar = torch.split(z, split_size_or_sections=self.code_size, dim=1)
        return z_mean, z_logvar

    def reparameterize(self, z_mean, z_logvar):
        """Draw z ~ N(z_mean, exp(z_logvar)) in a differentiable way."""
        # reparametrization trick: to sample z from N(mean, std):
        # z = mean + std * eps, where eps sampled from N(0, 1)
        eps = torch.randn_like(z_mean)
        z_std = torch.exp(z_logvar * .5)
        return eps * z_std + z_mean

    def decode(self, z):
        """Decode latent vectors and crop 2 pixels from each image border (32x32 -> 28x28)."""
        reconstruction = self.decoder(z)
        reconstruction = reconstruction[:, :, 2:-2, 2:-2]
        return reconstruction

    def forward(self, x, return_z=False):
        """Reconstruct ``x`` from a sampled latent.

        Returns the reconstruction, or ``(reconstruction, z_mean, z_logvar)``
        when ``return_z`` is True.
        """
        z_mean, z_logvar = self.encode(x)
        z = self.reparameterize(z_mean, z_logvar)
        reconstruction = self.decode(z)
        return (reconstruction, z_mean, z_logvar) if return_z else reconstruction

    def forward_and_KL_loss(self, x, y):
        """Return ``(reconstruction, kl_loss)`` for input ``x``.

        ``kl_loss`` is the batch-mean KL divergence of N(z_mean, exp(z_logvar))
        from N(0, I), divided by the number of values per sample. That puts it
        on the same scale as a per-pixel mean reconstruction loss. ``y`` is
        accepted for interface symmetry but is not used.
        """
        reconstruction, z_mean, z_logvar = self(x, return_z=True)

        # closed-form KL(N(mu, sigma^2) || N(0, 1)) summed over latent dimensions
        loss_z_kl = 0.5 * torch.sum(torch.exp(z_logvar) + torch.square(z_mean) -1.0 - z_logvar, dim=1)
        loss_z_kl = torch.mean(loss_z_kl) / self.npix

        return reconstruction, loss_z_kl


In [None]:
CODE_SIZE = 2 # 50  (2 allows plotting the latent space directly)
NOISE_RATE = 0.  # presumably the input noise level (not used in this cell)
MODEL_NAME = 'vcae_model'  # also used as the checkpoint directory name
model = VariationalConvolutionalAutoencoder(input_size=in_size, code_size=CODE_SIZE).to(device)

# fixed validation samples whose outputs are tracked during training
samples = get_samples(valid_loader)

In [None]:
# quick shape check of the VAE forward pass on the fixed noisy samples
xns = torch.tensor(samples['images_noisy']).to(device)
print(xns.shape)
ys = model(xns)  # reconstruction only (return_z=False); the latent is randomly sampled
ys.shape

In [None]:
# display the module structure (layer list)
model

In [None]:
# number of model parameters (get_n_params is presumably provided by the AutoEncoder base class)
model.get_n_params()

In [None]:
# Train the variational autoencoder for N_EPOCHS epochs.
# The objective is a reconstruction MSE plus the KL term returned by
# model.forward_and_KL_loss. Per epoch, the mean total / reconstruction / KL
# losses on the training and validation sets go to `history`, the outputs on
# the fixed validation `samples` go to `sample_history`, and a model/optimizer
# checkpoint is written to `model_root`.

N_EPOCHS = 20 # 20
LR = 0.0009


model_root = pl.Path(MODEL_NAME)
model_root.mkdir(exist_ok=True)


optimizer = optim.Adam(model.parameters(), lr=LR)

# reconstruction criterion; the KL term is computed by the model itself
loss = nn.MSELoss()

# train the model
history = {'loss': [], 'val_loss': [], 'rec_loss': [], 'rec_val_loss': [], 'kl_loss': [], 'kl_val_loss': [], 'epoch': []}
sample_history = {}

pbar = tqdm.tqdm(range(N_EPOCHS), postfix=f'epoch 0/{N_EPOCHS}')
for epoch_idx in pbar:
    epoch_loss = 0
    epoch_rec_loss = 0
    epoch_kl_loss = 0

    model.train()
    for batch_idx, (noisy_data, data, target) in enumerate(train_loader):
        optimizer.zero_grad()
        output, kl_loss = model.forward_and_KL_loss(noisy_data, data)
        rec_loss = loss(output, data)
        loss_value = rec_loss + kl_loss
        loss_value.backward()
        optimizer.step()

        epoch_loss += loss_value.detach().cpu().item()
        epoch_rec_loss += rec_loss.detach().cpu().item()
        epoch_kl_loss += kl_loss.detach().cpu().item()

    n_elements = len(train_loader)  # number of batches: losses are averaged per batch

    epoch_loss /= n_elements
    epoch_rec_loss /= n_elements
    epoch_kl_loss /= n_elements

    history['loss'].append(epoch_loss)
    history['rec_loss'].append(epoch_rec_loss)
    history['kl_loss'].append(epoch_kl_loss)

    history['epoch'].append(epoch_idx)

    # evaluate on validation set
    model.eval()
    with torch.no_grad():
        val_loss = 0
        val_rec_loss = 0
        val_kl_loss = 0

        for batch_idx, (noisy_data, data, target) in enumerate(valid_loader):
            output, kl_loss = model.forward_and_KL_loss(noisy_data, data)
            rec_loss = loss(output, data)
            loss_value = rec_loss + kl_loss

            val_loss += loss_value.detach().cpu().item()
            val_rec_loss += rec_loss.detach().cpu().item()
            val_kl_loss += kl_loss.detach().cpu().item()

        val_loss /= len(valid_loader)
        val_rec_loss /= len(valid_loader)
        val_kl_loss /= len(valid_loader)

        history['val_loss'].append(val_loss)
        history['rec_val_loss'].append(val_rec_loss)
        history['kl_val_loss'].append(val_kl_loss)

    # update progress bar
    pbar.set_postfix({'epoch': f'{epoch_idx+1}/{N_EPOCHS}', 'loss':f'{epoch_loss:.4f}', 'val_loss':f'{val_loss:.4f}'})
    # evaluate on samples
    sample_res = eval_on_samples(model, epoch_idx, samples=samples)
    sample_history[epoch_idx] = sample_res

    # save model weights. Store the epoch's loss *values*, not the criterion
    # module: checkpoints then contain only tensors and plain Python data, so
    # they also load with torch.load(..., weights_only=True) (the default
    # since PyTorch 2.6).
    torch.save({
                'epoch': epoch_idx,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': epoch_loss,
                'rec_loss': epoch_rec_loss,
                'kl_loss': epoch_kl_loss,
                'val_loss': val_loss,
                }, model_root/f'model_{epoch_idx:03d}.pth')

In [None]:
# display the module structure (layer list)
model

In [None]:
# plot the loss curves recorded in `history`
plot_hist(history)

In [None]:
# show the recorded sample outputs for every 5th epoch
plot_samples(sample_history, samples=samples, epoch_stride=5, fig_scale=1)

Visualize reconstruction:

In [None]:
%%capture
# (cell magic, must stay first: suppresses the static output of this cell)

# noisy inputs and clean originals of the selected samples (channel 0)
single_el_idx = samples['single_el_idx']
images_noisy = samples['images_noisy'][single_el_idx, 0]
images = samples['images'][single_el_idx, 0]

# one entry per recorded epoch: [noisy inputs, reconstructions, clean originals]
smpl_ims = []
for epoch_idx, hist_el in sample_history.items():
    samples_arr = [images_noisy, hist_el['y'][single_el_idx, 0], images]
    smpl_ims.append(samples_arr)

ny, nx = len(smpl_ims[0]), len(smpl_ims[0][0])  # 3 rows, one column per sample

plt.rcParams["animation.html"] = "jshtml"  # for matplotlib 2.1 and above animations use JavaScript

s=1
fig = plt.figure(figsize=(s*nx, s*ny))

m = mosaic(smpl_ims[0])

ttl = plt.title(f'after epoch {int(0)}')
# plot 0th epoch - 0th frame
# NOTE(review): the decoder ends in Tanh ([-1, 1]); vmin=0 shows negative values
# as black - confirm the intended display range.
imsh = plt.imshow(m, cmap='gray', vmin=0, vmax=1)

# this function will be called to render each of the frames
def animate(i):
    m = mosaic(smpl_ims[i])
    imsh.set_data(m)

    ttl.set_text(f'after epoch {i}')

    return imsh

# create animation (one frame per recorded epoch)
ani = animation.FuncAnimation(fig, animate, frames=len(smpl_ims))

In [None]:
# display the reconstruction animation
ani

Let's look at the latent representation:

In [None]:
# shape of the stored latent entry at index 0 (used as the latent means below)
sample_history[0]['z'][0].shape

In [None]:
%%capture
plt.rcParams["animation.html"] = "jshtml"  # for matplotlib 2.1 and above, uses JavaScript
fig = plt.figure(figsize=(8,8))

smpl_epochs = list(sample_history.keys())

z_res_mean = [sample_history[ep]['z'][0] for ep in smpl_epochs]
z_logvar = [sample_history[ep]['z'][1] for ep in smpl_epochs]
z_res_std = np.exp(np.array(z_logvar) * .5)


l_res = samples['labels']
l_col = [cm.rainbow(l_i/10) for l_i in l_res]

# error bars:
def get_lines(m, s):
  l = []
  for (x, y), (sx, sy) in zip(m, s):
    l.append([(x-sx, y), (x+sx, y)]) # h
    l.append([(x, y-sy), (x, y+sy)]) # w
  return np.array(l)

lines = get_lines(z_res_mean[0], z_res_std[0])
lc = mc.LineCollection(lines, color=l_col, linewidths=2, alpha=0.3)
fig.gca().add_collection(lc)

scat = plt.scatter(z_res_mean[0][:,0], z_res_mean[0][:,1], c=l_res, cmap=cm.rainbow)
#scat_err = plt.errorbar(z_res_mean[0][:,0], z_res_mean[0][:,1], xerr=z_res_std[0][:,0], yerr=z_res_std[0][:,1], fmt="o")

plt.xlim(-4, 4)
plt.ylim(-4, 4)

legend1 = plt.gca().legend(*scat.legend_elements(), title="digits")
plt.gca().add_artist(legend1)
plt.gca().set_aspect('equal')
ttl = plt.title(f'after epoch {0}')

def animate(i):
    """Update the latent-space plot to frame `i`.

    Moves the scatter points to the latent means recorded for smpl_epochs[i] and
    redraws the +/- 1 std error bars. The title shows the actual recorded epoch key
    smpl_epochs[i] rather than the frame index; the frame index equals that key
    only when the keys are 0, 1, 2, ...

    Returns:
        The modified artists (scatter and error-bar collection).
    """
    z_mean, z_std = z_res_mean[i], z_res_std[i]
    scat.set_offsets(z_mean)
    lc.set_segments(get_lines(z_mean, z_std))
    ttl.set_text(f'after epoch {smpl_epochs[i]}')
    return scat, lc

# one frame per recorded epoch (z_res_mean is aligned with smpl_epochs)
ani = animation.FuncAnimation(fig, animate, frames=len(z_res_mean))

In [None]:
# Display the latent-space animation as the cell output
ani

And let's sample from the latent space:

In [None]:
# Shape of the latent grid `zs`.
# NOTE(review): `zs` is (re)built in the next cell. If it was not defined earlier in the
# notebook, this raises NameError when cells are run top to bottom; run the next cell first.
zs.shape

In [None]:
# n x n grid over the 2D latent space. The coordinates are standard-normal quantiles at
# evenly spaced probabilities, so grid points are denser where the N(0, 1) prior density
# is higher. (For a uniform grid, use np.linspace(-4, 4, n) for both axes instead.)
n = 20
zs = np.stack(
    np.meshgrid(
        norm.ppf(np.linspace(0.0001, 0.9999, n)),
        norm.ppf(np.linspace(0.0001, 0.9999, n)),
    ),
    axis=-1,
).reshape((-1, 2))  # (n*n, 2): one latent point per row

# float32 tensor on `device`, ready to be decoded
zs_t = torch.tensor(zs, dtype=torch.float32).to(device)

def fn(ae):
  """Decode every latent grid point in `zs_t` with `ae` and show the results as an n x n mosaic.

  The decoded batch is reshaped to (n, n, ...) and flipped along the row axis before
  tiling, presumably so the second latent coordinate increases upwards like a y-axis.
  """
  with torch.no_grad():
    decoded = ae.decode(zs_t).detach().cpu().numpy()
  decoded = decoded[:, 0, :, :]  # select channel 0 (drops the channel axis)

  grid = decoded.reshape((n, n) + decoded.shape[1:])
  plt.figure(figsize=(n, n))
  plt.imshow(mosaic(grid[::-1]), vmin=0, vmax=1, cmap='gray')
  plt.show()
  plt.close()

# Apply `fn` (decode the latent grid) to the trained model;
# presumably run_on_trained restores `model` from `model_root` first — see its definition.
run_on_trained(model, model_root, fn)


In [None]:
def fn(ae):
  """Encode the whole training set with `ae` and plot its latent distribution.

  Shows (1) a scatter of the latent means colored by label and (2) histograms of
  all latent mean values and all std values, where std = exp(0.5 * logvar).

  Encoding runs under torch.no_grad(): plotting needs no gradients, and building
  an autograd graph over the full dataset would keep every intermediate
  activation alive in memory.
  """
  # Alternatives: samples['images_noisy'] / samples['labels'], or ds = valid_dataset.
  ds = train_dataset
  # ds.data presumably holds raw pixels in [0, 255]; scaled here to [-1, 1].
  # NOTE(review): the image plots above use vmin=0, vmax=1; confirm this scaling
  # matches the preprocessing the model was trained with.
  x = ds.data.numpy().reshape(-1, 1, 28, 28)
  x = x/255.*2-1
  l = ds.targets.numpy()
  x_t = torch.tensor(x, dtype=torch.float32).to(device)
  with torch.no_grad():
    z_m, z_s = ae.encode(x_t)  # latent mean and log-variance
  z_m, z_s = [el.cpu().numpy() for el in (z_m, z_s)]

  z_res_std = np.exp(z_s * .5)

  plt.figure(figsize=(10, 10))
  scat = plt.scatter(*z_m.T, c=l, s=2, cmap='jet')
  plt.gca().legend(*scat.legend_elements(), title="digits")
  plt.show()
  plt.close()

  fig, ax = plt.subplots(figsize=(10, 2), ncols=2, nrows=1)
  ax[0].hist(z_m.flatten(), bins=100)
  ax[1].hist(z_res_std.flatten(), bins=100)
  ax[0].set_title('mean')
  ax[1].set_title('std')
  plt.show()
  plt.close()

# Apply `fn` (encode the training set and plot its latent distribution) to the trained model;
# presumably run_on_trained restores `model` from `model_root` first — see its definition.
run_on_trained(model, model_root, fn)


## Visualization exercise (homework)

Option 1: Decode a 31x31 grid of latent samples and overlay it with the encoded data points.

Option 2: Visualize the data points as image thumbnails at their positions in the latent space, for example:


```
from matplotlib.offsetbox import OffsetImage, AnnotationBbox

plt.figure(figsize=(10,10))
ax = plt.gca()

# given the image sample array `image_arr` and the pair of coordinate arrays
# of the latent representation z0_arr, z1_arr:
for z0_i, z1_i, im_i in zip(z0_arr, z1_arr, image_arr):
    im = OffsetImage(im_i, zoom=0.5)
    ab = AnnotationBbox(im, (z0_i, z1_i), xycoords='data', frameon=False)

    ax.add_artist(ab)
    ax.update_datalim([(z0_i, z1_i)])
    ax.autoscale()

```

## Example of project work

Use a VAE to explore Fashion-MNIST or any other dataset:
 1. train model
 2. explore latent representation
 3. find outliers
 4. sample from the latent distribution
 5. overlay with data points
 6. Train classifier/regressor using the latent space representation of the samples