# Exercise 3: Variational Autoencoders
---

## Introduction

In this exercise, you will build and apply variational autoencoders (VAEs) to the [MNIST dataset](https://yann.lecun.com/exdb/mnist/) (handwritten digits). While VAEs may no longer lead in generation quality compared to newer frameworks such as diffusion models, they remain an important tool in many applications. VAEs are not only useful for data generation but also for tasks like anomaly detection, semi-supervised learning, and feature selection. In reinforcement learning, for instance, VAEs are employed to learn compact, informative representations of the environment, which can simplify state-space representations and improve policy learning. They also play a crucial role in disentangling latent features, aiding in interpretable and controllable generative processes.

> MNIST has been used for many different projects, from handwriting recognition tasks to generative AI (such as this one!). Check out the [link](https://yann.lecun.com/exdb/mnist/) to see some of these projects, and maybe even try a couple yourself.

In [33]:
import torch
import torch.nn as nn

import numpy as np

from tqdm import tqdm
from torchvision.utils import save_image, make_grid

## Setting up the dataset and training parameters:

Before building the model, we need to define some basic parameters and understand the shape of our data. MNIST images are ***28x28 pixels*** and since they are greyscale, ***they only have a single channel***. 
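
To make the shapes concrete, here is a minimal sketch of how a batch of single-channel 28x28 images is flattened into 784-dimensional vectors for a fully connected network. The random tensor is only a stand-in for real MNIST data (an assumption so the snippet runs without downloading anything).

```python
import torch

# Stand-in for one MNIST batch: 100 greyscale images, 1 channel, 28x28 pixels.
# Random values in [0, 1) are an assumption used only to make the snippet runnable.
batch = torch.rand(100, 1, 28, 28)

# Flatten each image into a single vector of 1 * 28 * 28 = 784 values.
flat = batch.view(batch.size(0), -1)

print(batch.shape)  # torch.Size([100, 1, 28, 28])
print(flat.shape)   # torch.Size([100, 784])
```

This is where the value `x_dim = 784` used later in the exercise comes from.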

We also need to decide the size of the ***latent space***: the dimensionality of the vector into which the encoder compresses the input. A common choice is ***20 dimensions***, but you can experiment with smaller or larger latent spaces to see how this affects the quality of the generated images. The latent space is the ***bottleneck*** of the VAE, i.e. the compressed representation of the data. ***Too small***, and the model cannot capture the variations in the data, so reconstructions lose detail. ***Too large***, and the latent space may become under-regularized, leading to poor generative properties (e.g., sampling from the latent space produces meaningless images).

> Practical rule for MNIST: 10-50 dimensions often work well (which is why we recommend starting with a value of 20 and then experimenting). Try a range of values, even if the first one works as expected, to see how the latent space dimension affects a VAE.
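
To get a feel for how tight the bottleneck is, the following sketch prints the compression ratio for a few latent sizes. The specific candidate values are illustrative assumptions, not prescribed settings.

```python
# Input size for a flattened 28x28 MNIST image.
x_dim = 28 * 28

# Candidate latent sizes to try; these particular values are illustrative assumptions.
candidate_latent_dims = [2, 10, 20, 50]

# Ratio between the number of input values and the number of latent values.
compression = {d: x_dim / d for d in candidate_latent_dims}

for d, ratio in compression.items():
    print(f"latent_dim={d:>2}: {x_dim} inputs -> {d} latent values ({ratio:.1f}x compression)")
```

A higher ratio means the encoder has to discard more information, which is one way to read the "too small" versus "too large" trade-off described above.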

You will also want to define other important parameters such as ***batch size***, ***learning rate***, ***hidden dimensions*** and ***epochs***.

In [34]:
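dataset_path = '~/datasets'  # where the MNIST files are downloaded to
cuda = False                 # set to True to train on a GPU, if one is available

DEVICE = torch.device("cuda" if cuda else "cpu")

batch_size = 100

x_dim = 784       # flattened image size: 28 * 28 pixels
hidden_dim = 256  # width of the hidden layers
latent_dim = 10   # size of the latent bottleneck

lr = 1e-3

epochs = 50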

## Loading the data:

Because MNIST is so widely used, it ships as one of the built-in `torchvision` datasets and can be loaded directly from that package. We therefore use `torchvision` to create the train and test datasets and data loaders. This code is provided for you; however, you can also find documentation on other `torchvision.datasets` if you want to try out other cool projects using online data!

In [35]:
from torchvision.datasets import MNIST
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

mnist_transform = transforms.Compose([transforms.ToTensor(), ])

kwargs = {'num_workers': 1, 'pin_memory': True} if cuda else {}

train_dataset = MNIST(dataset_path, train=True, download=True, transform=mnist_transform)
test_dataset = MNIST(dataset_path, train=False, download=True, transform=mnist_transform)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, **kwargs)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, **kwargs)

## Understanding the encoder-decoder structure in a VAE

---

### Encoder: Mapping inputs to a distribution

The encoder's role is to compress the input image into a representation in the latent space. But unlike a standard autoencoder, the VAE encoder outputs a **distribution**, not a single vector. For each input ***x***, the encoder produces two vectors:
- **Mean vector** ($\mu$) - the center of the latent distribution.
- **Log-variance vector** ($\log \sigma^2$) - describes how spread out the distribution is.

Formally:

$(\mu, \log \sigma^2) = \mathrm{Encoder}(x)$

This probabilistic encoding allows the latent space to capture both **what features are important** and **how uncertain the model is** about them.
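
One reason for predicting the log-variance rather than the variance itself is that a linear layer can output any real number, while a variance must be positive. The sketch below uses made-up log-variance values (an assumption for illustration) to show that the recovered standard deviation is always positive.

```python
import torch

# Hypothetical raw encoder outputs for three latent dimensions; any real number is valid.
log_var = torch.tensor([-4.0, 0.0, 2.0])

# Standard deviation recovered from the log-variance: sigma = exp(0.5 * log(sigma^2)).
sigma = torch.exp(0.5 * log_var)

print(sigma)  # strictly positive, whatever the sign of log_var
```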

In [36]:
# MODEL DEFINITION:

"""
simple Gaussian MLP Encoder and Decoder
"""

class Encoder(nn.Module):
    def __init__(self, input_dim, hidden_dim, latent_dim):
        super().__init__()

        # Use the constructor argument instead of the global x_dim
        self.fc_input = nn.Linear(input_dim, hidden_dim)
        self.fc_input2 = nn.Linear(hidden_dim, hidden_dim)
        # Two output heads: one for the mean, one for the log-variance
        self.fc2_mean = nn.Linear(hidden_dim, latent_dim)
        self.fc2_log_var = nn.Linear(hidden_dim, latent_dim)

        # nn.ReLU has no slope argument (its only parameter is `inplace`)
        self.ReLU = nn.ReLU()

    def forward(self, x):
        h = self.ReLU(self.fc_input(x))
        h = self.ReLU(self.fc_input2(h))
        z_mean = self.fc2_mean(h)
        z_log_var = self.fc2_log_var(h)
        return z_mean, z_log_var


### Decoder: Reconstructing the input

The decoder takes a point from the latent space and attempts to reconstruct the original image.

$\hat x$ = Decoder($z$)

Here $\hat x$ is the reconstructed image, and its similarity to the original input $x$ is measured using the **reconstruction loss**. If the encoder has captured the right features, the decoder can recreate images that look very much like the originals.

In [37]:
class Decoder(nn.Module):
    def __init__(self, latent_dim, hidden_dim, output_dim):
        super().__init__()

        self.fc_hidden = nn.Linear(latent_dim, hidden_dim)
        self.fc_hidden2 = nn.Linear(hidden_dim, hidden_dim)
        # Use the constructor argument instead of the global x_dim
        self.fc_output = nn.Linear(hidden_dim, output_dim)

        # nn.ReLU has no slope argument (its only parameter is `inplace`)
        self.ReLU = nn.ReLU()

    def forward(self, z):
        h = self.ReLU(self.fc_hidden(z))
        h = self.ReLU(self.fc_hidden2(h))

        # Sigmoid keeps every output pixel in [0, 1], as required by BCE
        x_hat = torch.sigmoid(self.fc_output(h))
        return x_hat

### Putting it together: the reparameterization trick

To connect the encoder and decoder, we need to sample a latent vector $z$ from the distribution produced by the encoder. However, naive sampling would break backpropagation, because gradients cannot flow through a random sampling step. The solution is the **reparameterization trick**:

$z = \mu + \sigma \odot \epsilon$

$\epsilon \sim \mathcal{N}(0, I)$

where:
- $\mu$ is the mean from the encoder
- $\sigma = \exp(0.5 \cdot \log \sigma^2)$ is the standard deviation
- $\epsilon$ is random noise drawn from a standard normal distribution
- $\odot$ denotes element-wise multiplication.

This formulation keeps the randomness while allowing gradients to flow, making the model trainable. By combining **probabilistic encoding**, **differentiable sampling**, and **decoding**, the VAE learns a **smooth and continuous latent space**. This not only enables faithful reconstruction of digits but also makes it possible to generate **entirely new samples** by drawing random vectors from the latent space.
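
The following minimal sketch illustrates why the trick keeps the model trainable: after reparameterizing, gradients reach both the mean and the log-variance. The zero-valued tensors stand in for encoder outputs and are an assumption made for illustration.

```python
import torch

torch.manual_seed(0)

# Stand-ins for encoder outputs (assumed values), marked as requiring gradients.
mean = torch.zeros(4, requires_grad=True)
log_var = torch.zeros(4, requires_grad=True)

# Reparameterization: the randomness lives in epsilon, which needs no gradient.
std = torch.exp(0.5 * log_var)
epsilon = torch.randn_like(std)
z = mean + std * epsilon

# Any scalar function of z can now be backpropagated to mean and log_var.
z.sum().backward()

print(mean.grad)     # dz/dmean = 1 for every dimension
print(log_var.grad)  # dz/dlog_var = 0.5 * std * epsilon
```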


In [38]:
class VAE(nn.Module):
    def __init__(self, Encoder, Decoder):
        super().__init__()
        self.Encoder = Encoder
        self.Decoder = Decoder

    def reparameterization(self, mean, std):
        # randn_like already creates epsilon on the same device as std
        epsilon = torch.randn_like(std)
        z = mean + std * epsilon
        return z

    def forward(self, x):
        mean, log_var = self.Encoder(x)
        # Convert the log-variance to a standard deviation before sampling
        z = self.reparameterization(mean, torch.exp(0.5 * log_var))
        x_hat = self.Decoder(z)

        return x_hat, mean, log_var

In [39]:
encoder = Encoder(input_dim=x_dim, hidden_dim=hidden_dim, latent_dim=latent_dim)
decoder = Decoder(latent_dim=latent_dim, hidden_dim=hidden_dim, output_dim=x_dim)

model = VAE(Encoder=encoder, Decoder=decoder).to(DEVICE)

## Setting up the optimizer and loss function

Now that the model is defined, we need two more ingredients before we can start training: 
1. A **loss function** that tells the model how well it's doing.
2. An **optimizer** that updates the model's parameters based on that loss.

---

### The VAE loss function

The loss in the VAE has two parts: 

1. **Reconstruction loss**:
    - Measures how close the reconstructed image $\hat x$ is to the original input image $x$.
    - Here, we use **binary cross-entropy (BCE)**, summed over all pixels.
    - Essentially the model gets penalized if it can't recreate the input digits correctly.

    $\mathrm{recon} = \mathrm{BCE}(x, \hat x)$

2. **KL divergence loss**:
    - Regularizes the latent space so that the encoded distribution $\mathcal{N}(\mu, \sigma^2)$ stays close to a standard normal distribution $\mathcal{N}(0, I)$.
    - This keeps the latent space smooth and ensures that sampling random points produces meaningful digits.

    $\mathrm{KL} = -0.5 \sum \left(1 + \log \sigma^2 - \mu^2 - \sigma^2\right)$

The final loss is the sum of both terms (with an optional scaling factor $\beta$ to control the strength of the KL term):

$\mathcal{L} = \mathrm{recon} + \beta \cdot \mathrm{KL}$

Dividing by the batch size helps keep the loss values stable.
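
As a sanity check on the closed-form KL expression, the sketch below compares it with the value computed by `torch.distributions`. The random means and log-variances are stand-ins for encoder outputs (an assumption for illustration).

```python
import torch
from torch.distributions import Normal, kl_divergence

torch.manual_seed(0)

# Stand-ins for encoder outputs on a batch of 5 inputs with 3 latent dimensions.
mean = torch.randn(5, 3)
log_var = torch.randn(5, 3)

# Closed-form KL term, summed over the batch and the latent dimensions.
kl_closed_form = -0.5 * torch.sum(1 + log_var - mean.pow(2) - log_var.exp())

# The same quantity from torch.distributions, with a standard normal prior.
posterior = Normal(mean, torch.exp(0.5 * log_var))
prior = Normal(torch.zeros_like(mean), torch.ones_like(mean))
kl_reference = kl_divergence(posterior, prior).sum()

print(kl_closed_form.item(), kl_reference.item())
```

The two values should agree up to floating-point error, which confirms that the formula measures the distance between the encoder's distribution and $\mathcal{N}(0, I)$.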

### The optimizer:

To train the model we use the **Adam optimizer**. Adam adapts the learning rate for each parameter, which often makes training faster and more stable than standard gradient descent. To create it, we pass in the model's parameters (`model.parameters()`) and the learning rate (`lr`).

In [40]:
from torch.optim import Adam
import torch.nn.functional as F

# reconstruction + KL divergence losses, summed over all elements and then averaged over the batch
def loss_function(x, x_hat, mean, log_var, beta=1.0):
    # reconstruction loss (BCE summed over pixels)
    recon = F.binary_cross_entropy(x_hat, x, reduction="sum")

    # KL divergence term
    kl = -0.5 * torch.sum(1 + log_var - mean.pow(2) - log_var.exp())

    # normalize by batch size for stability
    return (recon + beta * kl) / x.size(0)

optimizer = Adam(model.parameters(), lr=lr)


## The training loop:

With the model, loss, and optimizer defined, we can now put everything together in a training loop. Training a VAE looks similar to training other neural networks, but we track three key values:  

1. **Reconstruction loss** – measures how well the decoder reproduces the input images.  
2. **KL divergence** – regularizes the latent space so it stays close to a standard normal distribution.  
3. **Total loss** – the sum of the two, which the optimizer minimizes.  

At each epoch:  
- We loop through the training batches and pass the images through the encoder and decoder.  
- We compute the reconstruction and KL terms separately to better monitor how the model is learning.  
- The gradients are reset (`optimizer.zero_grad()`), the loss is backpropagated (`loss.backward()`), and the optimizer updates the model’s parameters (`optimizer.step()`).  

Using `tqdm`, we also show a progress bar with the current loss values for each batch, making it easier to see improvements during training. At the end of each epoch, we print the average losses so we can track the overall learning progress.  

By monitoring both reconstruction and KL divergence, we ensure the model is **balancing accurate reconstructions with a well-structured latent space**, which is the essence of training a VAE.  


In [43]:
from tqdm import tqdm

print("Start training VAE...")
model.train()

for epoch in range(epochs):
    overall_loss = 0
    overall_recon = 0
    overall_kl = 0
    
    loop = tqdm(enumerate(train_loader), total=len(train_loader), desc=f"Epoch {epoch+1}")
    
    for batch_idx, (x, _) in loop:
        x = x.view(x.size(0), -1).to(DEVICE)  # flatten; also works if the last batch is smaller
        optimizer.zero_grad()

        x_hat, mean, log_var = model(x)

        # Compute separate components
        recon = F.binary_cross_entropy(x_hat, x, reduction="sum") / x.size(0)
        kl = (-0.5 * torch.sum(1 + log_var - mean.pow(2) - log_var.exp())) / x.size(0)
        loss = recon + kl  # optionally multiply kl by beta if needed: recon + beta * kl

        loss.backward()
        optimizer.step()

        overall_loss += loss.item()
        overall_recon += recon.item()
        overall_kl += kl.item()
        
        # Update tqdm with current batch components
        loop.set_postfix(loss=loss.item(), recon=recon.item(), kl=kl.item())
    
    avg_loss = overall_loss / len(train_loader)
    avg_recon = overall_recon / len(train_loader)
    avg_kl = overall_kl / len(train_loader)
    
    print(f"\tEpoch {epoch + 1} complete! "
          f"Average Loss: {avg_loss:.4f}, Recon: {avg_recon:.4f}, KL: {avg_kl:.4f}")

print("Finish!!")


Start training VAE...


Epoch 1: 100%|██████████| 600/600 [00:13<00:00, 45.76it/s, kl=12.2, loss=137, recon=125]


	Epoch 1 complete! Average Loss: 175.4245, Recon: 168.3537, KL: 7.0708


Epoch 2: 100%|██████████| 600/600 [00:13<00:00, 43.60it/s, kl=14.6, loss=123, recon=109] 


	Epoch 2 complete! Average Loss: 128.3286, Recon: 114.6161, KL: 13.7124


Epoch 3: 100%|██████████| 600/600 [00:14<00:00, 40.39it/s, kl=15.9, loss=116, recon=100] 


	Epoch 3 complete! Average Loss: 120.0213, Recon: 104.9933, KL: 15.0280


Epoch 4: 100%|██████████| 600/600 [00:19<00:00, 30.04it/s, kl=15.6, loss=114, recon=98.5]


	Epoch 4 complete! Average Loss: 115.4076, Recon: 99.4339, KL: 15.9737


Epoch 5: 100%|██████████| 600/600 [00:13<00:00, 43.77it/s, kl=16.2, loss=110, recon=93.9]


	Epoch 5 complete! Average Loss: 112.8869, Recon: 96.5785, KL: 16.3085


Epoch 6: 100%|██████████| 600/600 [00:19<00:00, 30.96it/s, kl=17.8, loss=109, recon=91.6]


	Epoch 6 complete! Average Loss: 111.2296, Recon: 94.7058, KL: 16.5238


Epoch 7: 100%|██████████| 600/600 [00:13<00:00, 43.28it/s, kl=17.3, loss=114, recon=97.2]


	Epoch 7 complete! Average Loss: 110.0461, Recon: 93.3423, KL: 16.7038


Epoch 8: 100%|██████████| 600/600 [00:18<00:00, 32.28it/s, kl=16.8, loss=112, recon=94.9]


	Epoch 8 complete! Average Loss: 109.0284, Recon: 92.1429, KL: 16.8855


Epoch 9: 100%|██████████| 600/600 [00:13<00:00, 44.13it/s, kl=17.7, loss=109, recon=91.4]


	Epoch 9 complete! Average Loss: 108.0079, Recon: 90.7942, KL: 17.2137


Epoch 10: 100%|██████████| 600/600 [00:18<00:00, 32.41it/s, kl=18.4, loss=105, recon=86.5] 


	Epoch 10 complete! Average Loss: 107.0722, Recon: 89.5558, KL: 17.5164


Epoch 11: 100%|██████████| 600/600 [00:13<00:00, 43.99it/s, kl=17.4, loss=105, recon=87.8] 


	Epoch 11 complete! Average Loss: 106.3524, Recon: 88.6537, KL: 17.6987


Epoch 12: 100%|██████████| 600/600 [00:20<00:00, 29.48it/s, kl=18.2, loss=103, recon=84.5] 


	Epoch 12 complete! Average Loss: 105.7463, Recon: 87.9139, KL: 17.8324


Epoch 13: 100%|██████████| 600/600 [00:14<00:00, 42.76it/s, kl=18.4, loss=106, recon=88.1] 


	Epoch 13 complete! Average Loss: 105.2371, Recon: 87.2846, KL: 17.9524


Epoch 14: 100%|██████████| 600/600 [00:18<00:00, 32.60it/s, kl=18.3, loss=106, recon=87.8] 


	Epoch 14 complete! Average Loss: 104.8019, Recon: 86.7580, KL: 18.0439


Epoch 15: 100%|██████████| 600/600 [00:13<00:00, 43.07it/s, kl=18, loss=110, recon=92.3]   


	Epoch 15 complete! Average Loss: 104.3941, Recon: 86.2836, KL: 18.1104


Epoch 16: 100%|██████████| 600/600 [00:18<00:00, 31.69it/s, kl=17.7, loss=107, recon=89.5] 


	Epoch 16 complete! Average Loss: 104.0377, Recon: 85.8576, KL: 18.1801


Epoch 17: 100%|██████████| 600/600 [00:13<00:00, 44.58it/s, kl=18.4, loss=106, recon=87.1] 


	Epoch 17 complete! Average Loss: 103.7658, Recon: 85.5201, KL: 18.2457


Epoch 18: 100%|██████████| 600/600 [00:18<00:00, 32.14it/s, kl=18.5, loss=98.1, recon=79.6]


	Epoch 18 complete! Average Loss: 103.4634, Recon: 85.1905, KL: 18.2729


Epoch 19: 100%|██████████| 600/600 [00:15<00:00, 39.73it/s, kl=18, loss=95.7, recon=77.8]  


	Epoch 19 complete! Average Loss: 103.1992, Recon: 84.8775, KL: 18.3217


Epoch 20: 100%|██████████| 600/600 [00:15<00:00, 38.06it/s, kl=18.8, loss=105, recon=86]   


	Epoch 20 complete! Average Loss: 103.0356, Recon: 84.6533, KL: 18.3823


Epoch 21: 100%|██████████| 600/600 [00:17<00:00, 33.44it/s, kl=18.6, loss=106, recon=87.3] 


	Epoch 21 complete! Average Loss: 102.7546, Recon: 84.3397, KL: 18.4149


Epoch 22: 100%|██████████| 600/600 [00:14<00:00, 40.32it/s, kl=18.5, loss=106, recon=87.4] 


	Epoch 22 complete! Average Loss: 102.5360, Recon: 84.1037, KL: 18.4324


Epoch 23: 100%|██████████| 600/600 [00:17<00:00, 34.87it/s, kl=18.4, loss=105, recon=86.9] 


	Epoch 23 complete! Average Loss: 102.3837, Recon: 83.9115, KL: 18.4722


Epoch 24: 100%|██████████| 600/600 [00:15<00:00, 38.99it/s, kl=18.1, loss=99.1, recon=81]  


	Epoch 24 complete! Average Loss: 102.2310, Recon: 83.7262, KL: 18.5048


Epoch 25: 100%|██████████| 600/600 [00:15<00:00, 38.00it/s, kl=18.8, loss=104, recon=85.5] 


	Epoch 25 complete! Average Loss: 102.0981, Recon: 83.5678, KL: 18.5303


Epoch 26: 100%|██████████| 600/600 [00:14<00:00, 42.74it/s, kl=19, loss=98.8, recon=79.8]  


	Epoch 26 complete! Average Loss: 101.9289, Recon: 83.3896, KL: 18.5393


Epoch 27: 100%|██████████| 600/600 [00:18<00:00, 31.61it/s, kl=18.6, loss=99.5, recon=80.9]


	Epoch 27 complete! Average Loss: 101.7322, Recon: 83.1687, KL: 18.5636


Epoch 28: 100%|██████████| 600/600 [00:14<00:00, 41.55it/s, kl=18.5, loss=104, recon=85.6] 


	Epoch 28 complete! Average Loss: 101.6165, Recon: 83.0315, KL: 18.5851


Epoch 29: 100%|██████████| 600/600 [00:19<00:00, 30.29it/s, kl=18.7, loss=104, recon=85]   


	Epoch 29 complete! Average Loss: 101.5006, Recon: 82.8958, KL: 18.6048


Epoch 30: 100%|██████████| 600/600 [00:15<00:00, 38.71it/s, kl=17.7, loss=102, recon=84.1] 


	Epoch 30 complete! Average Loss: 101.3242, Recon: 82.7090, KL: 18.6152


Epoch 31: 100%|██████████| 600/600 [00:16<00:00, 35.56it/s, kl=19.1, loss=105, recon=85.9] 


	Epoch 31 complete! Average Loss: 101.2160, Recon: 82.5575, KL: 18.6586


Epoch 32: 100%|██████████| 600/600 [00:21<00:00, 28.34it/s, kl=18.9, loss=98.1, recon=79.2]


	Epoch 32 complete! Average Loss: 101.1458, Recon: 82.4620, KL: 18.6838


Epoch 33: 100%|██████████| 600/600 [00:15<00:00, 39.61it/s, kl=18.5, loss=102, recon=83.4] 


	Epoch 33 complete! Average Loss: 101.0302, Recon: 82.3508, KL: 18.6794


Epoch 34: 100%|██████████| 600/600 [00:19<00:00, 30.19it/s, kl=18.7, loss=100, recon=81.7] 


	Epoch 34 complete! Average Loss: 100.9092, Recon: 82.2048, KL: 18.7043


Epoch 35: 100%|██████████| 600/600 [00:14<00:00, 42.06it/s, kl=18.3, loss=99.9, recon=81.6]


	Epoch 35 complete! Average Loss: 100.8307, Recon: 82.1359, KL: 18.6948


Epoch 36: 100%|██████████| 600/600 [00:18<00:00, 32.48it/s, kl=18.5, loss=98, recon=79.5]  


	Epoch 36 complete! Average Loss: 100.7090, Recon: 81.9830, KL: 18.7260


Epoch 37: 100%|██████████| 600/600 [00:18<00:00, 32.01it/s, kl=19.2, loss=101, recon=81.8] 


	Epoch 37 complete! Average Loss: 100.6163, Recon: 81.8690, KL: 18.7473


Epoch 38: 100%|██████████| 600/600 [00:18<00:00, 33.23it/s, kl=18.8, loss=102, recon=82.8] 


	Epoch 38 complete! Average Loss: 100.5477, Recon: 81.7944, KL: 18.7533


Epoch 39: 100%|██████████| 600/600 [00:14<00:00, 40.11it/s, kl=18.8, loss=101, recon=82.3] 


	Epoch 39 complete! Average Loss: 100.4874, Recon: 81.7216, KL: 18.7658


Epoch 40: 100%|██████████| 600/600 [00:19<00:00, 31.35it/s, kl=18.5, loss=97.3, recon=78.9]


	Epoch 40 complete! Average Loss: 100.3986, Recon: 81.5999, KL: 18.7986


Epoch 41: 100%|██████████| 600/600 [00:16<00:00, 35.81it/s, kl=18.4, loss=101, recon=83.1] 


	Epoch 41 complete! Average Loss: 100.3753, Recon: 81.5746, KL: 18.8006


Epoch 42: 100%|██████████| 600/600 [00:15<00:00, 38.41it/s, kl=18.8, loss=103, recon=84.3] 


	Epoch 42 complete! Average Loss: 100.2579, Recon: 81.4511, KL: 18.8069


Epoch 43: 100%|██████████| 600/600 [00:18<00:00, 33.30it/s, kl=18.9, loss=103, recon=83.9] 


	Epoch 43 complete! Average Loss: 100.1928, Recon: 81.3616, KL: 18.8312


Epoch 44: 100%|██████████| 600/600 [00:14<00:00, 40.25it/s, kl=19.1, loss=103, recon=83.5] 


	Epoch 44 complete! Average Loss: 100.1275, Recon: 81.2945, KL: 18.8331


Epoch 45: 100%|██████████| 600/600 [00:19<00:00, 30.75it/s, kl=19.2, loss=106, recon=86.4] 


	Epoch 45 complete! Average Loss: 100.0116, Recon: 81.1918, KL: 18.8198


Epoch 46: 100%|██████████| 600/600 [00:16<00:00, 35.94it/s, kl=18.8, loss=99.2, recon=80.4]


	Epoch 46 complete! Average Loss: 99.9493, Recon: 81.0904, KL: 18.8589


Epoch 47: 100%|██████████| 600/600 [00:19<00:00, 31.50it/s, kl=19.1, loss=101, recon=81.8] 


	Epoch 47 complete! Average Loss: 99.8894, Recon: 81.0438, KL: 18.8456


Epoch 48: 100%|██████████| 600/600 [00:16<00:00, 35.53it/s, kl=18.4, loss=96.2, recon=77.8]


	Epoch 48 complete! Average Loss: 99.8273, Recon: 80.9730, KL: 18.8543


Epoch 49: 100%|██████████| 600/600 [00:15<00:00, 37.53it/s, kl=19, loss=97.7, recon=78.6]  


	Epoch 49 complete! Average Loss: 99.7953, Recon: 80.9025, KL: 18.8928


Epoch 50: 100%|██████████| 600/600 [00:19<00:00, 31.26it/s, kl=18.2, loss=94.5, recon=76.3]

	Epoch 50 complete! Average Loss: 99.7277, Recon: 80.8546, KL: 18.8731
Finish!!





## Result visualization:

We provide a simple overview of your results. The top row shows the input images and the bottom row shows the images the network reconstructs. Feel free to create other meaningful visualisations, as this tests your ability to check the performance of your models!

In [1]:
import matplotlib.pyplot as plt

def plot_reconstructions(model, data_loader, device, n=10):
    model.eval()
    x, _ = next(iter(data_loader))
    x = x.to(device).view(-1, 28*28)

    with torch.no_grad():
        x_hat, _, _ = model(x)

    x = x.view(-1, 1, 28, 28).cpu()
    x_hat = x_hat.view(-1, 1, 28, 28).cpu()

    fig, axes = plt.subplots(2, n, figsize=(2*n, 4))
    for i in range(n):
        axes[0, i].imshow(x[i][0], cmap="gray")
        axes[0, i].axis("off")
        axes[1, i].imshow(x_hat[i][0], cmap="gray")
        axes[1, i].axis("off")
    plt.show()

plot_reconstructions(model, test_loader, DEVICE)

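
Beyond reconstructions, new digits can be generated by drawing latent vectors from the prior $\mathcal{N}(0, I)$ and decoding them. The sketch below uses a small untrained stand-in network so that it runs on its own; `stand_in_decoder` is a hypothetical name, and its layer sizes are assumed to match the decoder defined earlier.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
latent_dim, hidden_dim, x_dim = 10, 256, 784

# Untrained stand-in with the same input/output sizes as the decoder above
# (an assumption so this snippet runs alone; use the trained decoder in practice).
stand_in_decoder = nn.Sequential(
    nn.Linear(latent_dim, hidden_dim),
    nn.ReLU(),
    nn.Linear(hidden_dim, x_dim),
    nn.Sigmoid(),
)
stand_in_decoder.eval()

# Draw latent vectors from the prior N(0, I) and decode them into images.
with torch.no_grad():
    z = torch.randn(16, latent_dim)
    samples = stand_in_decoder(z).view(-1, 1, 28, 28)

print(samples.shape)  # torch.Size([16, 1, 28, 28])
```

With the trained model from this exercise, `model.Decoder(z.to(DEVICE))` would take the place of the stand-in, and the resulting images can be plotted in the same way as the reconstructions.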