The goal of this notebook is to understand how normalizing flows work by applying them to the MNIST dataset.

Imports 

In [7]:
import os
import torch
import numpy as np
import pandas as pd
import imageio
from torch import nn
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms
from torchvision.transforms import ToTensor, Lambda
import matplotlib.pyplot as plt

In [8]:
# Pick the compute device once; later cells move the model and batches onto it.
available = torch.cuda.is_available()
if available:
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(f'Using device: {device}')
# (Adapted) Code from PyTorch's ResNet implementation: https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py
# ResNet-style blocks serve as the backbone of the scale (s) and translation (t) networks.

# Helper constructors for 3x3 and 1x1 convolution layers
def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d:
    """Build a bias-free 3x3 convolution.

    The padding is set equal to ``dilation``, so with ``stride=1`` the output
    keeps the spatial size of the input.
    """
    conv_kwargs = dict(
        kernel_size=3,
        stride=stride,
        padding=dilation,
        groups=groups,
        bias=False,
        dilation=dilation,
    )
    return nn.Conv2d(in_planes, out_planes, **conv_kwargs)


def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
    """Build a bias-free 1x1 (pointwise) convolution."""
    pointwise = nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=1,
        stride=stride,
        bias=False,
    )
    return pointwise

# Basic Building Block
class BasicBlock(nn.Module):
    """Two-convolution residual block (conv3x3 -> norm -> ReLU -> conv3x3 -> norm, plus skip).

    Differs from the torchvision original in that the default normalisation
    layer is ``nn.InstanceNorm2d`` instead of ``nn.BatchNorm2d``.
    """
    expansion: int = 1

    def __init__(
        self,
        inplanes: int,
        planes: int,
        stride: int = 1,
        downsample = None,
        groups: int = 1,
        base_width: int = 64,
        dilation: int = 1,
        norm_layer = None,
    ) -> None:
        super().__init__()
        # Default normalisation is instance norm (batch norm was the torchvision default).
        norm_layer = nn.InstanceNorm2d if norm_layer is None else norm_layer

        unsupported_width = groups != 1 or base_width != 64
        if unsupported_width:
            raise ValueError("BasicBlock only supports groups=1 and base_width=64")
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")

        # When stride != 1 both conv1 and the optional downsample module reduce the resolution.
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the residual branch and add the (optionally downsampled) input."""
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        # The skip connection is the raw input unless a downsample module was supplied.
        shortcut = x if self.downsample is None else self.downsample(x)

        out += shortcut
        return self.relu(out)

Using device: cuda:0


In [9]:
class Bottleneck(nn.Module):
    """Three-convolution residual block (1x1 -> 3x3 -> 1x1) with a skip connection.

    The default normalisation layer is ``nn.InstanceNorm2d`` and ``expansion``
    is 1, so the block outputs ``planes`` channels.
    """
    # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
    # while original implementation places the stride at the first 1x1 convolution(self.conv1)
    # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385.
    # This variant is also known as ResNet V1.5 and improves accuracy according to
    # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
    expansion: int = 1

    def __init__(
        self,
        inplanes: int,
        planes: int,
        stride: int = 1,
        downsample = None,
        groups: int = 1,
        base_width: int = 64,
        dilation: int = 1,
        norm_layer = None,
    ) -> None:
        super().__init__()
        # Default normalisation is instance norm (batch norm was the torchvision default).
        norm_layer = nn.InstanceNorm2d if norm_layer is None else norm_layer

        width = int(planes * (base_width / 64.0)) * groups
        out_planes = planes * self.expansion

        # When stride != 1 both conv2 and the optional downsample module reduce the resolution.
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = norm_layer(width)
        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.bn2 = norm_layer(width)
        self.conv3 = conv1x1(width, out_planes)
        self.bn3 = norm_layer(out_planes)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the bottleneck branch and add the (optionally downsampled) input."""
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        # The skip connection is the raw input unless a downsample module was supplied.
        shortcut = x if self.downsample is None else self.downsample(x)

        out += shortcut
        return self.relu(out)

In [10]:
class BatchNorm2d(nn.modules.batchnorm._NormBase):
    ''' Invertible, non-affine batch normalisation used between coupling layers.

        Unlike ``nn.BatchNorm2d`` this layer
        * normalises with the *running* statistics even in training mode (the
          running statistics are first updated with the current batch),
        * returns the per-element log-scale term ``-0.5 * log(var + eps)``
          alongside the output, and
        * applies the inverse transform in eval mode unless ``validation=True``.

        Partially based on:
        https://pytorch.org/docs/stable/_modules/torch/nn/modules/batchnorm.html#BatchNorm2d
        https://discuss.pytorch.org/t/implementing-batchnorm-in-pytorch-problem-with-updating-self-running-mean-and-self-running-var/49314/5
    '''
    def __init__(
        self,
        num_features,
        eps=1e-5,
        momentum=0.005,
        device=None,
        dtype=None
    ):
        # No learnable affine parameters; running statistics are always tracked.
        factory_kwargs = {'device': device, 'dtype': dtype, 'affine': False, 'track_running_stats': True}
        super(BatchNorm2d, self).__init__(
            num_features, eps, momentum, **factory_kwargs
        )

    def _check_input_dim(self, input):
        """Raise ``ValueError`` unless ``input`` is 4-dimensional."""
        if input.dim() != 4:
            raise ValueError("expected 4D input (got {}D input)".format(input.dim()))

    def forward(self, input, validation=False):
        """Normalise (or de-normalise) ``input``.

        Args:
            input: 4D tensor; statistics are taken over dims 0, 2 and 3 (per channel).
            validation: only read in eval mode. ``True`` applies the normalising
                direction with the stored running statistics; ``False`` applies
                the inverse (``y = x * sqrt(var + eps) + mean``).

        Returns:
            ``(y, log_term)`` where ``log_term = -0.5 * log(var + eps)`` expanded
            to the shape of ``input``. The same expression is returned in every
            mode, including the inverse direction.
        """
        self._check_input_dim(input)

        if self.training:
            mean = torch.mean(input, [0, 2, 3]) # along channel axis
            unbiased_var = torch.var(input, [0, 2, 3], unbiased=True) # along channel axis

            # The previous running statistics are detached so gradients only
            # flow through the contribution of the current batch.
            running_mean = (1.0 - self.momentum) * self.running_mean.detach() + self.momentum * mean
            running_var = (1.0 - self.momentum) * self.running_var.detach() + self.momentum * unbiased_var
            current_mean = running_mean.view([1, self.num_features, 1, 1]).expand_as(input)
            current_var = running_var.view([1, self.num_features, 1, 1]).expand_as(input)

            denom = (current_var + self.eps)
            y = (input - current_mean) / denom.sqrt()

            self.running_mean = running_mean
            self.running_var = running_var

            return y, -0.5 * torch.log(denom)
        else:
            current_mean = self.running_mean.view([1, self.num_features, 1, 1]).expand_as(input)
            current_var = self.running_var.view([1, self.num_features, 1, 1]).expand_as(input)

            if validation:
                denom = (current_var + self.eps)
                y = (input - current_mean) / denom.sqrt()
            else:
                # Reverse operation for testing
                denom = (current_var + self.eps)
                y = input * denom.sqrt() + current_mean

            return y, -0.5 * torch.log(denom)

class Reshape(nn.Module):
    """Module that reshapes its input to ``(-1, *shape)``, keeping a free leading batch dim."""

    def __init__(self, shape):
        super().__init__()
        # A leading -1 lets the batch dimension be inferred at call time.
        self.shape = (-1, *shape)

    def forward(self, x):
        """Return ``x`` reshaped to the stored target shape."""
        target_shape = self.shape
        return torch.reshape(x, target_shape)

def dense_backbone(shape, network_width):
    """Build a one-hidden-layer MLP whose output has the same ``shape`` as each input sample.

    ``shape`` is a 3-element sequence; the input is flattened to the product of
    its entries, passed through ``Linear -> ReLU -> Linear`` and reshaped back.
    """
    input_width = shape[0] * shape[1] * shape[2]
    layers = [
        nn.Flatten(),
        nn.Linear(input_width, network_width),
        nn.ReLU(),
        nn.Linear(network_width, input_width),
        Reshape(shape),
    ]
    return nn.Sequential(*layers)

def bottleneck_backbone(in_planes, planes):
    """Build the convolutional backbone: conv3x3 -> 2 x BasicBlock -> conv3x3.

    The first convolution widens ``in_planes`` to ``planes`` channels and the
    last one maps back to ``in_planes``. Despite the name, the residual blocks
    are ``BasicBlock`` instances.
    """
    layers = [conv3x3(in_planes, planes)]
    layers.extend(BasicBlock(planes, planes) for _ in range(2))
    layers.append(conv3x3(planes, in_planes))
    return nn.Sequential(*layers)

# Caches keyed by shape: CPU masks and their copies on the global `device`.
check_mask = {}
check_mask_device = {}
def checkerboard_mask(shape, to_device=True):
    """Return a cached 0/1 checkerboard mask of the given ``shape``.

    An entry is 1 where the sum of its indices is even and 0 otherwise. With
    ``to_device=True`` the copy living on the module-level ``device`` is returned.
    """
    cpu_mask = check_mask.get(shape)
    if cpu_mask is None:
        index_parity = np.indices(shape).sum(axis=0) % 2
        cpu_mask = torch.Tensor(1 - index_parity)
        check_mask[shape] = cpu_mask

    if not to_device:
        return cpu_mask

    if shape not in check_mask_device:
        check_mask_device[shape] = cpu_mask.to(device)
    return check_mask_device[shape]

# Caches keyed by shape: CPU masks and their copies on the global `device`.
chan_mask = {}
chan_mask_device = {}
def channel_mask(shape, to_device=True):
    """Return a cached channel-wise mask of the given 3D ``shape``.

    The first half of the channels is 0 and the second half is 1, so the number
    of channels (``shape[0]``) must be even. With ``to_device=True`` the copy
    living on the module-level ``device`` is returned.
    """
    assert len(shape) == 3, shape
    assert shape[0] % 2 == 0, shape

    if shape not in chan_mask:
        half_shape = (shape[0] // 2, shape[1], shape[2])
        halves = [torch.zeros(half_shape), torch.ones(half_shape)]
        chan_mask[shape] = torch.cat(halves, dim=0)
        assert chan_mask[shape].shape == shape, (chan_mask[shape].shape, shape)

    if not to_device:
        return chan_mask[shape]

    if shape not in chan_mask_device:
        chan_mask_device[shape] = chan_mask[shape].to(device)
    return chan_mask_device[shape]

Main Architecture 

In [11]:
class NormalizingFlowMNist(nn.Module):
    """Multi-scale coupling-layer normalizing flow for 1x28x28 inputs.

    The model has three modes, selected by ``train()``, ``validate()`` and ``eval()``:

    * training / validation: ``forward`` maps images ``(B, 1, 28, 28)`` to a flat
      latent and returns the terms needed for the log-determinant.
    * eval: ``forward`` runs the inverse, mapping a flat latent ``(B, 784)``
      back to image space.

    Layer schedule for the ``num_coupling`` multi-scale layers (index ``i``):
    checkerboard masks for ``i % 6 < 3`` and channel masks otherwise; a 2x2
    pixel-unshuffle after ``i % 6 == 2``; half of the channels are factored out
    after ``i % 6 == 5``. They are followed by ``num_final_coupling``
    checkerboard layers.

    NOTE: the inverse pass unconditionally re-attaches one factored-out chunk
    right after inverting the final couplings, so it only matches the forward
    pass when ``num_coupling`` is a positive multiple of 6.
    """
    EPSILON = 1e-5

    def __init__(self, num_coupling=6, num_final_coupling=4, planes=64):
        super(NormalizingFlowMNist, self).__init__()
        self.num_coupling = num_coupling
        self.num_final_coupling = num_final_coupling
        self.shape = (1, 28, 28)

        self.planes = planes
        self.s = nn.ModuleList()
        self.t = nn.ModuleList()
        self.norms = nn.ModuleList()

        # Learnable element-wise scaling parameters for the outputs of S and T.
        # They start at zero, so every coupling layer is initially the identity.
        self.s_scale = nn.ParameterList()
        self.t_scale = nn.ParameterList()
        self.t_bias = nn.ParameterList()
        # Input shape (C, H, W) of each coupling layer, indexed like self.s / self.t.
        self.shapes = []

        shape = self.shape
        for i in range(num_coupling):
            self.s.append(bottleneck_backbone(shape[0], planes))
            self.t.append(bottleneck_backbone(shape[0], planes))

            self.s_scale.append(torch.nn.Parameter(torch.zeros(shape), requires_grad=True))
            self.t_scale.append(torch.nn.Parameter(torch.zeros(shape), requires_grad=True))
            self.t_bias.append(torch.nn.Parameter(torch.zeros(shape), requires_grad=True))

            self.norms.append(BatchNorm2d(shape[0]))

            self.shapes.append(shape)

            if i % 6 == 2:
                # Squeeze: trade spatial resolution for channels (pixel_unshuffle in forward).
                shape = (4 * shape[0], shape[1] // 2, shape[2] // 2)

            if i % 6 == 5:
                # Factoring out half the channels
                shape = (shape[0] // 2, shape[1], shape[2])
                planes = 2 * planes

        # Final coupling layers checkerboard
        for i in range(num_final_coupling):
            self.s.append(bottleneck_backbone(shape[0], planes))
            self.t.append(bottleneck_backbone(shape[0], planes))

            self.s_scale.append(torch.nn.Parameter(torch.zeros(shape), requires_grad=True))
            self.t_scale.append(torch.nn.Parameter(torch.zeros(shape), requires_grad=True))
            self.t_bias.append(torch.nn.Parameter(torch.zeros(shape), requires_grad=True))

            self.norms.append(BatchNorm2d(shape[0]))

            self.shapes.append(shape)

        self.validation = False

    def validate(self):
        """Switch to validation mode: eval-mode modules, but image -> latent direction.

        Returns ``self`` so the call can be chained.
        """
        self.eval()
        self.validation = True
        return self

    def train(self, mode=True):
        """Set training mode and clear the validation flag.

        ``nn.Module.eval()`` is implemented as ``self.train(False)``, so this
        override is also what makes ``eval()`` select the inverse direction.
        Returns ``self`` as ``nn.Module.train`` does, so that
        ``model = model.train()`` / ``model.eval()`` keep working.
        """
        nn.Module.train(self, mode)
        self.validation = False
        return self

    def forward(self, x):
        """Run the flow in the direction selected by the current mode.

        Training / validation mode:
            ``x`` is an image batch matching ``self.shape``. Returns a 4-tuple
            ``(latent, s_vals, norm_vals, s_scale)``: the flat latent of shape
            ``(B, prod(self.shape))``, the flattened masked log-scales of all
            coupling layers, the flattened log-scale terms returned by the
            normalisation layers, and the flattened ``s_scale`` parameters.

        Eval mode:
            ``x`` is a flat latent laid out as produced above; returns the
            reconstructed tensor in image shape.
        """
        if self.training or self.validation:
            s_vals = []
            norm_vals = []
            y_vals = []

            for i in range(self.num_coupling):
                shape = self.shapes[i]
                mask = checkerboard_mask(shape) if i % 6 < 3 else channel_mask(shape)
                # Alternate which half is transformed from layer to layer.
                mask = mask if i % 2 == 0 else (1 - mask)

                t = (self.t_scale[i]) * self.t[i](mask * x) + (self.t_bias[i])
                s = (self.s_scale[i]) * torch.tanh(self.s[i](mask * x))
                # Affine coupling: masked entries pass through, the rest are scaled and shifted.
                y = mask * x + (1 - mask) * (x * torch.exp(s) + t)
                s_vals.append(torch.flatten((1 - mask) * s))

                if self.norms[i] is not None:
                    y, norm_loss = self.norms[i](y, validation=self.validation)
                    norm_vals.append(norm_loss)

                if i % 6 == 2:
                    y = torch.nn.functional.pixel_unshuffle(y, 2)

                if i % 6 == 5:
                    # Factor out the second half of the channels as finished latent variables.
                    factor_channels = y.shape[1] // 2
                    y_vals.append(torch.flatten(y[:, factor_channels:, :, :], 1))
                    y = y[:, :factor_channels, :, :]

                x = y

            # Final checkboard coupling
            for i in range(self.num_coupling, self.num_coupling + self.num_final_coupling):
                shape = self.shapes[i]
                mask = checkerboard_mask(shape)
                mask = mask if i % 2 == 0 else (1 - mask)

                t = (self.t_scale[i]) * self.t[i](mask * x) + (self.t_bias[i])
                s = (self.s_scale[i]) * torch.tanh(self.s[i](mask * x))
                y = mask * x + (1 - mask) * (x * torch.exp(s) + t)
                s_vals.append(torch.flatten((1 - mask) * s))

                if self.norms[i] is not None:
                    y, norm_loss = self.norms[i](y, validation=self.validation)
                    norm_vals.append(norm_loss)

                x = y

            y_vals.append(torch.flatten(y, 1))

            # Return outputs and vars needed for determinant
            return (torch.flatten(torch.cat(y_vals, 1), 1),
                    torch.cat(s_vals),
                    torch.cat([torch.flatten(v) for v in norm_vals]) if len(norm_vals) > 0 else torch.zeros(1),
                    torch.cat([torch.flatten(s) for s in self.s_scale]))
        else:
            y = x
            y_remaining = y

            # The latent is consumed from the end: the last chunk is the output of the final layer.
            layer_vars = np.prod(self.shapes[-1])
            y = torch.reshape(y_remaining[:, -layer_vars:], (-1,) + self.shapes[-1])
            y_remaining = y_remaining[:, :-layer_vars]

            # Reversed final checkboard coupling
            for i in reversed(range(self.num_coupling, self.num_coupling + self.num_final_coupling)):
                shape = self.shapes[i]
                mask = checkerboard_mask(shape)
                mask = mask if i % 2 == 0 else (1 - mask)

                # Undo the normalisation first (it was applied after the coupling in the forward pass).
                if self.norms[i] is not None:
                    y, _ = self.norms[i](y)

                t = (self.t_scale[i]) * self.t[i](mask * y) + (self.t_bias[i])
                s = (self.s_scale[i]) * torch.tanh(self.s[i](mask * y))
                x = mask * y + (1 - mask) * ((y - t) * torch.exp(-s))

                y = x

            # Re-attach the channels factored out after the last multi-scale layer.
            layer_vars = np.prod(shape)
            y = torch.cat((y, torch.reshape(y_remaining[:, -layer_vars:], (-1,) + shape)), 1)
            y_remaining = y_remaining[:, :-layer_vars]

            # Multi-scale coupling layers
            for i in reversed(range(self.num_coupling)):
                shape = self.shapes[i]
                mask = checkerboard_mask(shape) if i % 6 < 3 else channel_mask(shape)
                mask = mask if i % 2 == 0 else (1 - mask)

                if self.norms[i] is not None:
                    y, _ = self.norms[i](y)

                t = (self.t_scale[i]) * self.t[i](mask * y) + (self.t_bias[i])
                s = (self.s_scale[i]) * torch.tanh(self.s[i](mask * y))
                x = mask * y + (1 - mask) * ((y - t) * torch.exp(-s))

                if i % 6 == 3:
                    # Undo the pixel_unshuffle applied after layer i - 1 in the forward pass.
                    x = torch.nn.functional.pixel_shuffle(x, 2)

                y = x

                if i > 0 and i % 6 == 0:
                    # Re-attach the channels factored out after layer i - 1.
                    layer_vars = np.prod(shape)
                    y = torch.cat((y, torch.reshape(y_remaining[:, -layer_vars:], (-1,) + shape)), 1)
                    y_remaining = y_remaining[:, :-layer_vars]

            # Every latent variable must have been consumed.
            assert np.prod(y_remaining.shape) == 0
            return x

Training

In [None]:
PI = torch.tensor(np.pi).to(device)
def loss_fn(y, s, norms, scale, batch_size):
    """Negative log-likelihood of a batch under a standard-normal prior.

    Args:
        y: latent variables produced by the flow.
        s: log-scale terms of the coupling layers (summed into the log-determinant).
        norms: log-scale terms of the normalisation layers.
        scale: ``s_scale`` parameters, L2-regularised with weight 5e-5.
        batch_size: divisor applied to the total loss.

    Returns:
        ``(loss, (-logpx, -det, -norms, reg))``; only ``loss`` is divided by
        ``batch_size``, the components are batch totals.
    """
    # same as in 2D
    gaussian_terms = 0.5 * torch.log(2 * PI) + 0.5 * y**2
    logpx = -gaussian_terms.sum()
    det = s.sum()
    norms = norms.sum()
    reg = 5e-5 * (scale ** 2).sum()

    loss = -(logpx + det + norms) + reg
    components = (-logpx, -det, -norms, reg)
    return loss / batch_size, components
# Training
def pre_process(x):
    """Dequantise ``x`` (expected in [0, 1]) and map it to logit space.

    Pixel values are rescaled to [0, 255], uniform noise in [0, 1) is added, and
    the result is squeezed into [0.05, 0.95) before applying the logit.
    """
    pixels = x * 255.
    dequantised = pixels + torch.rand(x.shape, device=x.device)
    # Apply transform to deal with boundary effects (see realNVP paper)
    return torch.logit(0.05 + 0.90 * dequantised / 256)

def post_process(x):
    """Invert ``pre_process``: map logit-space values back to images in [0, 1].

    ``pre_process`` computes ``logit(0.05 + 0.90 * (255 * v + noise) / 256)``,
    so the model output lives in logit space. Clipping that output directly (as
    was done before) does not recover pixel intensities; the logit and the
    affine squeeze have to be undone first.

    Returns:
        Tensor of the same shape as ``x`` whose values are integer pixel
        levels divided by 255, i.e. within [0, 1].
    """
    # Undo the logit, then the affine squeeze into [0.05, 0.95].
    unit = (torch.sigmoid(x) - 0.05) / 0.90
    # Convert back to integer values; floor removes the dequantisation noise.
    pixels = torch.floor(unit * 256)
    return torch.clip(pixels, min=0, max=255) / 255

# Both splits share the same transform: PIL image -> float tensor in [0, 1].
mnist_transform = transforms.Compose([
    transforms.ToTensor(),
])

train_dataset = datasets.MNIST('data', train=True, download=True, transform=mnist_transform)

test_dataset = datasets.MNIST('data', train=False, download=True, transform=mnist_transform)

def train_loop(dataloader, model, loss_fn, optimizer, batch_size, report_iters=10, num_pixels=28*28):
    """Run one optimisation pass over ``dataloader``.

    Args:
        dataloader: yields ``(images, labels)`` batches; labels are ignored.
        model: flow model in training mode, returning ``(y, s, norms, scale)``.
        loss_fn: callable ``(y, s, norms, scale, batch_size) -> (loss, components)``.
        optimizer: optimiser stepping the model parameters.
        batch_size: divisor forwarded to ``loss_fn``.
        report_iters: print progress every this many batches.
        num_pixels: pixels per image, used for the bits/pixel report.
    """
    size = len(dataloader)
    for batch, (X, _) in enumerate(dataloader):
        # Dequantise on the CPU, then transfer to the training device
        X = pre_process(X)
        X = X.to(device)

        # Compute prediction and loss
        y, s, norms, scale = model(X)
        loss, comps = loss_fn(y, s, norms, scale, batch_size)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % report_iters == 0:
            loss_value, current = loss.item(), batch
            # Account for x/255 preprocessing
            loss_value += num_pixels * np.log(255)
            print(f"loss: {loss_value:.2f} = -logpx[{comps[0]:.1f}], -det[{comps[1]:.1f}], -norms[{comps[2]:.1f}], reg[{comps[3]:.4f}]"
                  f"; bits/pixel: {loss_value / num_pixels / np.log(2):>.2f}  [{current:>5d}/{size:>5d}]")
            
        
def test_loop(dataloader, model, loss_fn, num_pixels=28*28, batch_size=None):
    """Evaluate the average loss of ``model`` over ``dataloader``.

    The model is put in validation mode for the evaluation and switched back to
    training mode afterwards.

    Args:
        dataloader: yields ``(images, labels)`` batches; labels are ignored.
        model: flow model exposing ``validate()`` and ``train()``.
        loss_fn: callable ``(y, s, norms, scale, batch_size) -> (loss, components)``.
        num_pixels: pixels per image, used for the bits/pixel report.
        batch_size: divisor forwarded to ``loss_fn``. Defaults to
            ``dataloader.batch_size`` so the function no longer depends on a
            notebook-level global.

    Returns:
        The mean per-batch loss plus the ``num_pixels * log(255)`` correction.
    """
    if batch_size is None:
        batch_size = dataloader.batch_size

    num_batches = len(dataloader)
    test_loss = 0

    with torch.no_grad():
        model.validate()
        for X, _ in dataloader:
            X = pre_process(X)
            X = X.to(device)
            y, s, norms, scale = model(X)
            loss, _ = loss_fn(y, s, norms, scale, batch_size)
            test_loss += loss

        model.train()

    test_loss /= num_batches
    # Account for x/255 preprocessing
    test_loss += num_pixels * np.log(255)
    print(f"Test Error: \n Avg loss: {test_loss:.2f}; {test_loss / num_pixels / np.log(2):.2f} \n")
    return test_loss

# --- Hyper-parameters ---
learning_rate = 0.0005
batch_size = 50
epochs = 10 # increase for better results

# --- Model, optimiser, schedule and evaluation data ---
model = NormalizingFlowMNist(num_coupling=12, num_final_coupling=4, planes=64).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.2)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)

model.train()

best_validation = None
PATH = 'checkpoints/' # Directory path
os.makedirs(PATH, exist_ok=True)

print("\n--- Starting Training ---")

for t in range(epochs):
    # Recreate loader for shuffle
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0, pin_memory=True)
    print(f"\n--- Epoch {t+1}/{epochs} ---")
    train_loop(train_loader, model, loss_fn, optimizer, batch_size)
    validation_loss = test_loop(test_loader, model, loss_fn)

    # A checkpoint is written every epoch; only the path of the best one is remembered.
    checkpoint_path = PATH + f'mnist-{t}.model'
    checkpoint_state = {
        'epoch': t,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': validation_loss,
    }
    torch.save(checkpoint_state, checkpoint_path)

    if best_validation is None or validation_loss < best_validation:
        best_validation, best_path = validation_loss, checkpoint_path

    scheduler.step()

print("Done - Best model saved at epoch corresponding to:", best_path) # Print path

Evaluate

In [None]:
# Rebuild the architecture used for training and restore the best checkpoint into it.
model = NormalizingFlowMNist(num_coupling=12, num_final_coupling=4, planes=64).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# map_location keeps the checkpoint loadable when the current device differs
# from the one it was saved on (consistent with the interpolation cell below).
checkpoint = torch.load(best_path, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

# DEBUG - Check model: encode one batch (image -> latent) in validation mode ...
model.validate()
with torch.no_grad():
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False) #shuffle=True)
    for x, _ in train_loader:
        x_pre = pre_process(x).to(device)
        y, s, norms, scale = model(x_pre)
        print(y.shape)
        break

# ... then decode it again (latent -> image). eval() also clears the validation flag.
model.eval()
with torch.no_grad():
    xp = model(y)
    x_post = post_process(xp)

# Round-trip check: report whether any pixel differs by more than one intensity level.
diff = x.to(device) - x_post
print(torch.any(torch.abs(diff) > 1 / 255))

print(diff.shape)
for i in range(batch_size):
    if torch.any(torch.abs(diff[i]) > 1 / 255):
        #print(diff[i])
        # For the first offending image, print the first offending pixel of each row.
        for j in range(28):
            for k in range(28):
                if torch.any(torch.abs(diff[i, 0, j, k]) > 1 / 255):
                    print(i, 1, j, k, diff[i, 0, j, k].cpu().numpy())
                    break
        break

# Distribution of the latent values of the encoded batch.
latent_stats = pd.Series(torch.flatten(y).cpu().numpy())
print(latent_stats.describe())
latent_stats.hist(bins=50)

# Show a 3x3 grid of random training images for reference.
figure = plt.figure(figsize=(8, 8))
cols, rows = 3, 3
for i in range(1, cols * rows + 1):
    sample_idx = torch.randint(len(train_dataset), size=(1,)).item()
    img, label = train_dataset[sample_idx]
    figure.add_subplot(rows, cols, i)
    plt.title(str(label))
    plt.axis("off")
    plt.imshow(img.squeeze(), cmap="gray")
plt.show()

In [None]:
# Draw standard-normal latents and decode them (eval mode = latent -> image).
model.eval()

cols, rows = 8, 8
num_samples = cols * rows
latent_dim = 28 * 28 * 1
with torch.no_grad():
    X = torch.normal(torch.zeros(num_samples, latent_dim),
                     torch.ones(num_samples, latent_dim)).to(device)
    Y = model(X)
    samples = post_process(Y).cpu().numpy()

figure = plt.figure(figsize=(15, 15))
for i, img in enumerate(samples[:num_samples], start=1):
    figure.add_subplot(rows, cols, i)
    plt.axis("off")
    plt.imshow(img.squeeze(), cmap="gray")
plt.show()

Visualizations 

1) Generating samples resembling a specific digit

In [None]:
import matplotlib.pyplot as plt
import torch
import numpy as np
from torch.utils.data import DataLoader, Subset
import random


def get_latents_for_digit(model, dataset, digit, num_samples, batch_size, device):
    """Encode a random subset of the images labelled ``digit`` into latent vectors.

    Args:
        model: flow model exposing ``validate()`` / ``eval()``.
        dataset: dataset yielding ``(image, label)`` pairs.
        digit: label to select.
        num_samples: number of images to encode (capped at the number available).
        batch_size: batch size used for encoding.
        device: device the pre-processed images are moved to.

    Returns:
        CPU tensor with one latent row per encoded image.

    Raises:
        ValueError: if no image could be encoded for ``digit``.

    The model is left in eval (latent -> image) mode on return.
    """
    model.validate()
    latents = []

    # Read labels from ``dataset.targets`` when that is safe; iterating the
    # dataset itself decodes and transforms every image just to get its label.
    targets = getattr(dataset, "targets", None)
    if targets is not None and getattr(dataset, "target_transform", None) is None:
        labels = torch.as_tensor(targets).tolist()
        indices = [i for i, label in enumerate(labels) if label == digit]
    else:
        indices = [i for i, (_, label) in enumerate(dataset) if label == digit]

    if len(indices) < num_samples:
        print(f"Warning: Only found {len(indices)} samples for digit {digit}, requested {num_samples}")
        num_samples = len(indices)

    # Get a random subset of indices for the target digit
    random_indices = random.sample(indices, num_samples)
    subset = Subset(dataset, random_indices)
    loader = DataLoader(subset, batch_size=batch_size)

    with torch.no_grad():
        for X, _ in loader:
            X_pre = pre_process(X).to(device)
            # In validate mode, forward pass maps image -> latent
            y, _, _, _ = model(X_pre)
            latents.append(y.cpu())

    # eval() already implies train(False), so a single call is enough.
    model.eval()
    if not latents:
         raise ValueError(f"No samples found for digit {digit}")
    return torch.cat(latents, dim=0)

def generate_digit_samples(model, dataset, digit, num_to_generate, num_ref_samples, noise_std, batch_size, device):
    """Generate and plot images resembling ``digit``.

    The mean latent vector of ``num_ref_samples`` real images of the digit is
    perturbed with Gaussian noise of standard deviation ``noise_std`` and
    decoded; the ``num_to_generate`` results are shown in an 8-column grid.
    """
    print(f"\n--- Generating {num_to_generate} samples for digit {digit} ---")
    ref_latents = get_latents_for_digit(model, dataset, digit, num_ref_samples, batch_size, device)

    z_mean = ref_latents.mean(dim=0, keepdim=True)
    print(f"Calculated mean latent vector shape: {z_mean.shape}")

    # Replicate the mean latent and jitter each copy independently.
    z_repeated = z_mean.repeat(num_to_generate, 1)
    jitter = torch.randn_like(z_repeated) * noise_std
    z_samples = (z_repeated + jitter).to(device)
    print(f"Generated noisy latent samples shape: {z_samples.shape}")

    model.eval() # Ensure inverse pass mode
    gen_batch_size = min(batch_size, num_to_generate)
    decoded_batches = []
    with torch.no_grad():
        for start in range(0, num_to_generate, gen_batch_size):
            # In eval mode, forward pass maps latent -> image
            decoded = model(z_samples[start:start + gen_batch_size])
            decoded_batches.append(decoded.cpu())

    generated_images_pre = torch.cat(decoded_batches, dim=0)
    print(f"Generated images (before post-processing) shape: {generated_images_pre.shape}")
    generated_images = post_process(generated_images_pre)
    print(f"Generated images (after post-processing) shape: {generated_images.shape}")

    cols = 8
    rows = (num_to_generate + cols - 1) // cols  # ceiling division
    figure = plt.figure(figsize=(cols * 1.5, rows * 1.5))
    plt.suptitle(f"Generated Samples for Digit: {digit}", fontsize=16)
    for position in range(1, num_to_generate + 1):
        figure.add_subplot(rows, cols, position)
        plt.axis("off")
        plt.imshow(generated_images[position - 1].squeeze(), cmap="gray")
    plt.tight_layout(rect=[0, 0.03, 1, 0.95]) # Adjust layout to prevent title overlap
    plt.show()

# --- Parameters for Generation ---
DIGIT_TO_GENERATE = 3
NUM_SAMPLES_TO_GENERATE = 32 # How many samples to show
NUM_REFERENCE_SAMPLES = 100 # How many real images to use to find the average latent space
NOISE_STANDARD_DEVIATION = 0.6 # How much noise to add to the average latent vector (controls variety)


# map_location keeps the checkpoint loadable when the current device differs
# from the one it was saved on.
checkpoint = torch.load(best_path, map_location=device) # Or specify the path directly
model.load_state_dict(checkpoint['model_state_dict'])
model.to(device)

generate_digit_samples(
    model=model,
    dataset=train_dataset,
    digit=DIGIT_TO_GENERATE,
    num_to_generate=NUM_SAMPLES_TO_GENERATE,
    num_ref_samples=NUM_REFERENCE_SAMPLES,
    noise_std=NOISE_STANDARD_DEVIATION,
    batch_size=batch_size,
    device=device
)

2) Interpolate between two specific images and save as a video 

In [None]:
import imageio # For creating Videos
from IPython.display import Video as IPVideo, display 
import random 
def get_image_and_latent(model, dataset, idx, device):
    """Return ``(image, latent, label)`` for ``dataset[idx]``.

    The image is pre-processed and encoded in validation mode (image -> latent).
    The latent is returned on the CPU with a leading batch dimension of 1, so
    several latents can be concatenated along dim 0. The model is left in eval
    (latent -> image) mode.
    """
    img, label = dataset[idx]
    model.validate()
    with torch.no_grad():
        x_pre = pre_process(img.unsqueeze(0)).to(device)
        z, _, _, _ = model(x_pre)
    model.eval()
    return img, z.cpu(), label


def interpolate_images_to_video(model, dataset, source_idx, target_idx, num_steps, filename, fps, device, batch_size_override=None):
    """Linearly interpolate between two images in latent space and save an MP4.

    Args:
        model: trained flow model.
        dataset: dataset yielding ``(image, label)`` pairs.
        source_idx, target_idx: dataset indices of the two end points.
        num_steps: number of interpolation steps (= video frames), at least 1.
        filename: output path; must end with ``.mp4``.
        fps: frames per second of the video.
        device: device used for decoding.
        batch_size_override: decoding batch size; defaults to ``min(128, num_steps)``.

    Raises:
        AssertionError: if ``filename`` does not end with ``.mp4``.
        ValueError: if ``num_steps`` is smaller than 1.

    Saving or displaying failures are reported via ``print`` and not raised.
    """
    print(f"\n--- Interpolating from index {source_idx} to {target_idx} into video ---")
    assert filename.lower().endswith('.mp4'), "Filename must end with .mp4"
    if num_steps < 1:
        raise ValueError(f"num_steps must be at least 1, got {num_steps}")

    img_src, z_src, label_src = get_image_and_latent(model, dataset, source_idx, device)
    img_tgt, z_tgt, label_tgt = get_image_and_latent(model, dataset, target_idx, device)

    print(f"Source Label: {label_src}, Target Label: {label_tgt}")
    print(f"Latent vector shape: {z_src.shape}")

    # One latent per frame, moving linearly from the source (alpha=0) to the target (alpha=1).
    alphas = torch.linspace(0, 1, num_steps)
    interpolated_latents = [(1 - alpha) * z_src + alpha * z_tgt for alpha in alphas]
    interpolated_latents = torch.cat(interpolated_latents, dim=0).to(device)
    print(f"Interpolated latents batch shape: {interpolated_latents.shape}")

    model.eval() # Ensure inverse pass mode
    generated_images_pre = []
    gen_batch_size = batch_size_override if batch_size_override else min(128, num_steps) # Use override or default
    print(f"Using generation batch size: {gen_batch_size}")
    with torch.no_grad():
         for i in range(0, num_steps, gen_batch_size):
             z_batch = interpolated_latents[i:i+gen_batch_size]
             # In eval mode, forward pass maps latent -> image
             y_gen_output = model(z_batch)
             # Defensive: accept a tuple output by taking its first element.
             y_gen_tensor = y_gen_output[0] if isinstance(y_gen_output, tuple) else y_gen_output
             generated_images_pre.append(y_gen_tensor.cpu())

    generated_images_pre = torch.cat(generated_images_pre, dim=0)
    print(f"Generated interpolated images (before post-processing) shape: {generated_images_pre.shape}")
    interpolated_images = post_process(generated_images_pre)
    print(f"Generated interpolated images (after post-processing) shape: {interpolated_images.shape}")

    # --- Create Video Frames ---
    frames = []
    print("Preparing video frames...")
    for i in range(num_steps):
        img_tensor = interpolated_images[i].squeeze() # Remove channel dim (H, W)
        # Convert grayscale (H, W) float [0,1] to RGB uint8 [0,255] (H, W, 3)
        # Most video codecs prefer RGB. Stack the grayscale channel 3 times.
        frame_np = np.stack([img_tensor.numpy()] * 3, axis=-1)
        frame_np_uint8 = (frame_np * 255).astype(np.uint8)
        frames.append(frame_np_uint8)
    print(f"Prepared {len(frames)} frames.")

    # --- Save as MP4 Video ---
    print(f"Saving MP4 video to {filename}...")
    try:
        imageio.mimwrite(filename, frames, fps=fps, codec='libx264', quality=8, macro_block_size=1)
        print("MP4 video saved successfully.")
    except Exception as e:
        print("\n-------------------------------------")
        print(f"Error saving MP4 video: {e}")
        print("Ensure imageio and ffmpeg backend are installed correctly.")
        print("Try: pip install imageio-ffmpeg")
        print("-------------------------------------")
        return # Exit if saving failed
    print("\nAttempting to display generated video:")
    try:
        display(IPVideo(filename, embed=True, width=200)) # Smaller display width
    except FileNotFoundError:
        print(f"Video file '{filename}' not found. Could not display.")
    except Exception as e:
        print(f"Could not display video inline: {e}")


# --- Parameters for Interpolation Video ---
SOURCE_IMAGE_INDEX = 10   # Index from the dataset
TARGET_IMAGE_INDEX = 20   # Index from the dataset
INTERPOLATION_STEPS = 60 # Number of frames in the video
VIDEO_FPS = 15           # Frames per second for the video
VIDEO_FILENAME = f'mnist_interpolation_{SOURCE_IMAGE_INDEX}_to_{TARGET_IMAGE_INDEX}.mp4'
GENERATION_BATCH_SIZE = 64

# Restore the best weights onto the active device and switch to inverse (eval) mode.
checkpoint = torch.load(best_path, map_location=device) # Load to correct device
model.load_state_dict(checkpoint['model_state_dict'])
model.to(device)
model.eval() # Set to evaluation mode

interpolation_kwargs = dict(
    model=model,
    dataset=test_dataset, # Use train or test dataset
    source_idx=SOURCE_IMAGE_INDEX,
    target_idx=TARGET_IMAGE_INDEX,
    num_steps=INTERPOLATION_STEPS,
    filename=VIDEO_FILENAME,
    fps=VIDEO_FPS,
    device=device,
    batch_size_override=GENERATION_BATCH_SIZE, # Pass batch size for generation
)
interpolate_images_to_video(**interpolation_kwargs)

3) Generate an MP4 video of multiple interpolations

In [None]:
import torchvision.utils as vutils
def generate_interpolation_grid_video(model, dataset, pairs, num_steps, grid_rows, filename, fps, device):
    """Render several latent-space interpolations side by side as one MP4.

    ``pairs`` is a sequence of ``(source_idx, target_idx)`` dataset indices and
    must contain exactly ``grid_rows * grid_rows`` entries. Each of the
    ``num_steps`` frames shows all pairs at the same interpolation weight,
    arranged in a grid with ``grid_rows`` images per row.

    Requires a ``get_image_and_latent(model, dataset, idx, device)`` helper
    returning ``(image, latent, label)`` to be in scope. Saving failures are
    reported via ``print`` and not raised.
    """
    print(f"\n--- Generating interpolation video for {len(pairs)} pairs ---")
    num_pairs = len(pairs)
    assert num_pairs == grid_rows * grid_rows, "Number of pairs must match grid size (rows*rows)"
    assert filename.lower().endswith('.mp4'), "Filename should end with .mp4 for video output"

    print("Getting latent vectors...")
    source_latents, target_latents = [], []
    for source_idx, target_idx in pairs:
        # Only the latent (second element) of each helper result is needed here.
        source_latents.append(get_image_and_latent(model, dataset, source_idx, device)[1])
        target_latents.append(get_image_and_latent(model, dataset, target_idx, device)[1])

    z_sources = torch.cat(source_latents, dim=0).to(device)
    z_targets = torch.cat(target_latents, dim=0).to(device)
    print(f"Source latents batch shape: {z_sources.shape}")
    print(f"Target latents batch shape: {z_targets.shape}")

    model.eval() # Ensure inverse pass mode
    frames = []
    print("Generating video frames...")
    for i, alpha in enumerate(torch.linspace(0, 1, num_steps)):
        with torch.no_grad():
            # Latent -> Image for every pair at this interpolation weight
            generated_images_pre = model((1 - alpha) * z_sources + alpha * z_targets)

        # Post-process, tile into a grid and convert to an (H, W, C) uint8 frame
        interpolated_images = post_process(generated_images_pre).cpu()
        grid_img = vutils.make_grid(interpolated_images, nrow=grid_rows, padding=2, normalize=False)
        frame = (grid_img.permute(1, 2, 0).numpy() * 255).astype(np.uint8)
        frames.append(frame)
        print(f"  Generated frame {i+1}/{num_steps} for alpha = {alpha:.2f}", end='\r')

    print(f"Saving MP4 video to {filename}...")
    try:
        imageio.mimwrite(filename, frames, fps=fps, codec='libx264', quality=8, macro_block_size=1)
        print("MP4 video saved successfully.")
    except Exception as e:
        for message in (
            "\n-------------------------------------",
            f"Error saving MP4 video: {e}",
            "This often means 'ffmpeg' is not installed or not found by imageio.",
            "Try installing it:",
            "  Using pip:   pip install imageio-ffmpeg",
            "  Using conda: conda install ffmpeg -c conda-forge",
            "  Or install ffmpeg via your system's package manager (apt, brew, etc.)",
            "-------------------------------------",
        ):
            print(message)


# --- Parameters for Grid Video ---
NUM_GRID_ROWS = 4 # Creates a 4x4 grid
NUM_PAIRS = NUM_GRID_ROWS * NUM_GRID_ROWS
GRID_INTERPOLATION_STEPS = 120 # Number of frames for smoother video
VIDEO_FPS = 15 # Frames per second for the output video

# Draw NUM_PAIRS disjoint random (source, target) index pairs from the test set.
# You could also manually select specific indices for interesting transitions
image_pairs = []
available_indices = list(range(len(test_dataset)))
random.shuffle(available_indices)
for _ in range(NUM_PAIRS):
     if len(available_indices) < 2:
         print("Warning: Not enough unique indices left in dataset for pairs.")
         break
     idx1 = available_indices.pop()
     idx2 = available_indices.pop()
     image_pairs.append((idx1, idx2))

print("Selected pairs (source_idx, target_idx):", image_pairs)

# best_path should be defined from your training loop.
# map_location keeps the checkpoint loadable when the current device differs
# from the one it was saved on.
checkpoint = torch.load(best_path, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.to(device)


video_filename = 'mnist_interpolation_grid.mp4'
generate_interpolation_grid_video(
    model=model,
    dataset=test_dataset, # Use train or test dataset
    pairs=image_pairs,
    num_steps=GRID_INTERPOLATION_STEPS,
    grid_rows=NUM_GRID_ROWS,
    filename=video_filename,
    fps=VIDEO_FPS,
    device=device
)

# Optional: Display the generated Video inline (if in Jupyter and file saved)
print("\nAttempting to display generated video (may not work in all environments):")
try:
    display(IPVideo(video_filename, embed=True, width=400))
except FileNotFoundError:
    print(f"Video file '{video_filename}' not found. Could not display.")
except Exception as e:
    print(f"Could not display video inline: {e}")

In [None]:
# TODO : Fix some functions