<a href="https://colab.research.google.com/github/MatchLab-Imperial/deep-learning-course/blob/master/06_Autoencoders.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Representation Learning

In this notebook we will train a network that produces a good feature representation for the MNIST images. In this case, however, we will not use the labels available in the dataset to learn a label predictor $g$. Instead, we approach the problem in an unsupervised manner, using the autoencoder method explained in the lecture. Given an image $x$ from MNIST, we encode it (using an encoder $\phi$) into a lower-dimensional vector $\phi(x)$, for example one with only 2 values. We then decode this vector using a decoder $\psi$ and minimize a reconstruction error. In this notebook, we use the squared $l_2$-norm as the reconstruction metric: we minimize $||x - \psi(\phi(x))||^2$. Usually, the encoder $\phi$ and the decoder $\psi$ have a mirrored architecture. The following figure from the slides summarizes the idea.

<a href="https://ibb.co/0hhQ3Zt"><img src="https://i.ibb.co/6ggN5bB/Screenshot-from-2019-02-14-14-49-24.png" alt="Screenshot-from-2019-02-14-14-49-24" border="0"></a>

In [None]:
!pip install torchinfo

In [None]:
import os
import time

import cv2
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import numpy as np
from skimage.transform import resize
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from PIL import Image

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchinfo
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split, TensorDataset, Dataset
from tqdm import tqdm


DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def set_seed(seed: int) -> None:
    np.random.seed(seed)

    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

def train(
    model: nn.Module,
    train_loader: DataLoader,
    criterion: nn.Module,
    optimizer: torch.optim.Optimizer,
    num_epochs: int = 10,
    val_loader: DataLoader = None,
    device: torch.device = DEVICE,
):
    model = model.to(device)  # honour the `device` argument instead of the global

    for epoch in range(num_epochs):
        model.train()
        train_loss = 0.0
        for x, y in train_loader:
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            y_pred = model(x)
            loss = criterion(y_pred, y)
            loss.backward()
            optimizer.step()
            train_loss += loss.item() * x.size(0)
        train_loss /= len(train_loader.dataset)

        if val_loader is not None:
            model.eval()
            val_loss = 0.0
            with torch.no_grad():
                for x, y in val_loader:
                    x, y = x.to(device), y.to(device)
                    y_pred = model(x)
                    loss = criterion(y_pred, y)
                    val_loss += loss.item() * x.size(0)
            val_loss /= len(val_loader.dataset)
            print(f"Epoch [{epoch+1}/{num_epochs}] Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")
        else:
            print(f"Epoch [{epoch+1}/{num_epochs}] Train Loss: {train_loss:.4f}")

def evaluate(
    model: nn.Module,
    data_loader: DataLoader,
    criterion: nn.Module,
    device: torch.device = DEVICE,
    desc: str = "Test"
) -> float:
    model = model.to(device)  # honour the `device` argument instead of the global
    model.eval()
    total_loss = 0.0

    with torch.no_grad():
        for x, y in data_loader:
            x, y = x.to(device), y.to(device)
            y_pred = model(x)
            loss = criterion(y_pred, y)
            total_loss += loss.item() * x.size(0)

    avg_loss = total_loss / len(data_loader.dataset)
    print(f"{desc} Loss: {avg_loss:.4f}")
    return avg_loss  # matches the declared `-> float` return type

**Loading the dataset**

As usual, we start by loading the dataset. Both the encoder $\phi$ and the decoder $\psi$ will use `Linear` layers (with activation functions in some cases), so in this tutorial we flatten each image into a vector of 784 values.

In [None]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(lambda x: x.view(-1))
])

# Load datasets
train_dataset = datasets.MNIST(root='.', train=True, transform=transform, download=True)
test_dataset  = datasets.MNIST(root='.', train=False, transform=transform, download=True)

# Extract data and build TensorDatasets for autoencoder
x_train = torch.stack([d[0] for d in train_dataset])  # [60000, 784]
x_test  = torch.stack([d[0] for d in test_dataset])   # [10000, 784]

y_train = train_dataset.targets.numpy()
y_test = test_dataset.targets.numpy()

train_tensor_dataset = TensorDataset(x_train, x_train)
train_tensor_dataset, val_tensor_dataset = random_split(
    train_tensor_dataset,
     [50000, 10000],
    generator=torch.Generator().manual_seed(42)
)  # Extract val dataset from the train set
test_tensor_dataset  = TensorDataset(x_test, x_test)

## Linear Autoencoder
In our first approach we do not use any activation function, hence the model is completely linear. We encode the flattened input image of dimensionality $784$ ($28\times28$) into a vector of 2 dimensions, and then decode it back to a vector of dimensionality $784$. Even though the designed autoencoder has low capacity and the representation is constrained to dimensionality 2, we hope to see some meaningful features. To build our linear autoencoder, we design a model using a couple of `Linear` layers.
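Because no activation is applied, stacking the two layers does not add expressive power: their composition is a single affine map whose weight matrix has rank at most 2. The following is a minimal NumPy sketch of this property. The weights are randomly drawn placeholders with the same shapes as the encoder and decoder, not the trained parameters.

```python
import numpy as np

rng = np.random.default_rng(0)

# Placeholder parameters with the same shapes as the encoder and decoder layers.
W_enc, b_enc = rng.normal(size=(2, 784)), rng.normal(size=2)
W_dec, b_dec = rng.normal(size=(784, 2)), rng.normal(size=784)

x = rng.random(784)  # synthetic flattened image

# Two-step forward pass: encode, then decode.
two_step = W_dec @ (W_enc @ x + b_enc) + b_dec

# Equivalent single affine map x -> W x + b.
W = W_dec @ W_enc                 # shape (784, 784)
b = W_dec @ b_enc + b_dec
one_step = W @ x + b

print(np.allclose(two_step, one_step))
print(np.linalg.matrix_rank(W))   # at most 2, the bottleneck width
```

This is why any gain over the linear model has to come from non-linearities or a wider bottleneck rather than from simply adding more `Linear` layers.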

In [None]:
class LinearAutoencoder(nn.Module):
    def __init__(self):
        super(LinearAutoencoder, self).__init__()
        self.encoder = nn.Linear(784, 2)
        self.decoder = nn.Linear(2, 784)

    def forward(self, x):
        z = self.encoder(x)
        out = self.decoder(z)
        return out

    def encode(self, x):
        return self.encoder(x)

In [None]:
set_seed(42)

model = LinearAutoencoder()
print(torchinfo.summary(model, input_size=(1, 784)))
print()

train(
    model,
    train_loader = DataLoader(train_tensor_dataset, batch_size=128, shuffle=True),
    val_loader = DataLoader(val_tensor_dataset, batch_size=128),
    criterion = nn.MSELoss(),
    optimizer = optim.Adam(model.parameters(), lr=1e-3),
    num_epochs = 10,
)

Now, we will see if we have learnt a good representation of the images by using this linear autoencoder. First, let's check the MSE in the test set.
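Note that `evaluate` reports a dataset-level mean. A mean-reduced loss such as `nn.MSELoss` averages within each batch, so the helper multiplies every batch loss by the batch size before dividing by the dataset size. The sketch below uses made-up per-sample errors in plain NumPy to show why this matters when the last batch is smaller than the others.

```python
import numpy as np

# Made-up per-sample errors for a dataset of 5 samples, split into batches of 2.
per_sample = np.array([1.0, 3.0, 2.0, 2.0, 10.0])
batches = [per_sample[i:i + 2] for i in range(0, len(per_sample), 2)]

# What a mean-reduced loss would return for each batch.
batch_means = [b.mean() for b in batches]

# Naive: average the batch means (over-weights the final batch of size 1).
naive = np.mean(batch_means)

# Weighted: same bookkeeping as `loss.item() * x.size(0)` in the helper functions.
weighted = sum(m * len(b) for m, b in zip(batch_means, batches)) / len(per_sample)

print(naive, weighted, per_sample.mean())
```

Only the weighted version agrees with the mean over all samples; the naive average is inflated by the single large error in the last, smaller batch.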



In [None]:
evaluate(model, DataLoader(test_tensor_dataset, batch_size=128), nn.MSELoss())

We have an MSE of around 0.058 in the test set. It may be hard to know if this value means a good reconstruction or not, but we can use the value to compare the performance of our linear autoencoder to other models. To have a better understanding of what this value means qualitatively, let's plot some images, along with the corresponding MSE.

In [None]:
def plot_recons_original(image, label, model, size_image=(28, 28)):
    if isinstance(image, np.ndarray):
        image = torch.tensor(image, dtype=torch.float32)

    if image.ndim == 1:
        image = image.unsqueeze(0)  # [784] → [1, 784]
    elif image.ndim == 2 and image.shape == size_image:
        image = image.view(1, -1)   # [28, 28] → [1, 784]

    model = model.to(DEVICE)
    model.eval()
    with torch.no_grad():
        reconst = model(image.to(DEVICE))

    # Back to CPU numpy for plotting
    image_np = image.cpu().numpy().reshape(size_image)
    reconst_np = reconst.cpu().numpy().reshape(size_image)

    # Compute MSE
    mse = np.mean((image_np - reconst_np) ** 2)

    # Plot original vs reconstruction
    plt.subplots(1, 2)
    ax = plt.subplot(121)
    plt.imshow(image_np, cmap='gray')
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_title(f"Original (label={label})")

    ax = plt.subplot(122)
    plt.imshow(reconst_np, cmap='gray')
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_title(f"Reconstruction\nMSE: {mse:.2f}")

    plt.show()

# Select a random index from the test set
idx = np.random.randint(x_test.shape[0])  # upper bound is exclusive
plot_recons_original(x_test[idx], y_test[idx], model)

In most cases the reconstructed images look blurry, because the autoencoder $\psi(\phi(x))$ does not have enough representation capacity. In Task 1, you will try to improve the model by adding some non-linearities and/or more layers. Now, let's check how the feature vectors of the simple linear autoencoder we just trained are distributed in the representation space.

## Representation Space
If we have a good feature representation, we hope to see in the representation space clusters of images sharing visual similarities in the image space. For example, images from the same class should be close to each other as they are visually similar.

To check how the features are distributed in the representation space, we forward the images from the test set through the encoder part of our PyTorch model to retrieve the representation $\phi(x)$. As the vector only has dimensionality 2, we do not need to apply any dimensionality reduction technique to plot it.

We first compute the representation of the images from the MNIST test set using the `predict_representation` function. Since our encoder is implemented as a separate layer in the `LinearAutoencoder` class, we can directly call `model.encode(x)` to get the representation vectors.

In [None]:
def predict_representation(model, data, batch_size=256):
    """
    Compute representation vectors φ(x) using the encoder part of the model.
    """
    model = model.to(DEVICE)
    model.eval()
    representations = []

    data_loader = DataLoader(data, batch_size=batch_size, shuffle=False)
    with torch.no_grad():
        for batch in data_loader:
            encoded = model.encode(batch.to(DEVICE))
            representations.append(encoded.cpu().numpy())

    return np.concatenate(representations, axis=0)

# Compute representations for the test set
representation = predict_representation(model, x_test)

Now, we plot the computed representation for the test images in a 2D scatter plot. We also assign in the scatter plot different colours to the points depending on their class label.

In [None]:
def plot_representation_label(representation, labels, plot3d=0):
  ## Function used to plot the representation vectors and assign different
  ## colors to the different classes

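  # First create the figure with a single set of axes.
  # If the representation dimension is 3, we can use a 3d projection instead
  # (creating the axes only once avoids drawing 3d axes on top of 2d ones)
  fig = plt.figure(figsize=(10, 6))
  ax = fig.add_subplot(111, projection='3d' if plot3d else None)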

  # Check number of labels to separate by colors
  n_labels = labels.max() + 1
  # Color map, and give different colors to every label
  cm = plt.get_cmap('gist_rainbow')
  ax.set_prop_cycle(color=[cm(1.*i/(n_labels)) for i in range(n_labels)])
  # Loop is to plot different color for each label
  for l in range(n_labels):
    # Only select indices for corresponding label
    ind = labels == l
    if plot3d:
      ax.scatter(representation[ind, 0], representation[ind, 1],
                 representation[ind, 2], label=str(l))
    else:
      ax.scatter(representation[ind, 0], representation[ind, 1], label=str(l))
  ax.legend()
  plt.title('Features in the representation space with corresponding label')
  plt.show()
  return fig, ax

plot_representation_label(representation, y_test)

The scatter plot shows that our linear autoencoder generates representations that are mostly clustered by class. However, some of the classes overlap with each other, and some classes are more tightly clustered than others. For example, the class `1` is well clustered, but the class `9` overlaps greatly with the class `7`. The class `8` is also quite spread out, with some of its features closer to those of the class `1` and others closer to those of the classes `6` or `0`.

## Clustering the Data
We have learnt a linear autoencoder capable of producing meaningful representations, and we plotted them in the last section. We saw that the representations of images sharing the same label were close to each other, even though the label information was not used in the training process. Now, let's do a quick experiment to see how well clustered the features are. We will apply a clustering method to these features, and then assign each of the clusters to its majority class (i.e., the most represented class in the cluster). The accuracy of this simple classification method gives us a way to quantify how well clustered the features are.

First, let's cluster the features using a standard technique called K-Means. A good guess is to use the same number of clusters as classes, so in this case we use 10 clusters. Then, as mentioned, we assign each of the clusters to its majority class and compute the accuracy of this method.

In [None]:
def cluster_plot_data(representation, use_gmm=0, random_state=42, plot=1):
  from sklearn.cluster import KMeans
  from sklearn import mixture


  # Set number of clusters to 10
  n_clusters = 10
  if use_gmm:
    c_pred = mixture.GaussianMixture(n_components=n_clusters, covariance_type='full', random_state=random_state).fit_predict(representation)
  else:
    c_pred = KMeans(n_clusters=n_clusters, random_state=random_state).fit_predict(representation)
  if plot:
    _, ax = plt.subplots(1,1)
    # Color map, and give different colors to every label
    cm = plt.get_cmap('gist_rainbow')
    ax.set_prop_cycle(color=[cm(1.*i/(n_clusters)) for i in range(n_clusters)])
    # Loop is to plot different color for each label
    for c in range(n_clusters):
      # Only select indices for corresponding label
      ind = c_pred == c
      ax.scatter(representation[ind, 0], representation[ind, 1], label=str(c))
    ax.legend()
    plt.title('Clustered features in the representation space')
    plt.show()
  return c_pred

c_pred = cluster_plot_data(representation)

# Compute accuracy by checking cluster by cluster the majority class and
# assigning all of the data points in that cluster to the majority class
# Then we check the accuracy of doing so
correct = 0
for i in range(10):
  indices_c_pred = c_pred == i
  classes = y_test[indices_c_pred]
  counts = np.bincount(classes)
  class_max = np.argmax(counts)
  correct += (classes == class_max).sum()

print('Accuracy: {:.3f}'.format(correct/(1.0*y_test.shape[0])))

Even though this is not a good way to classify the samples compared to the supervised approach from past tutorials, the accuracy is still considerably higher than random guessing, which would give around 10%. An autoencoder of higher capacity (e.g. using more layers, including activation functions or increasing the size of the representation layer) would likely produce representations with better accuracy under this clustering method.
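To make the majority-class accuracy concrete, here is a tiny worked example with made-up cluster assignments and labels (hypothetical arrays, not MNIST). It follows the same steps as the loop above: for each cluster, find the most frequent true label and count the points that carry it.

```python
import numpy as np

# Hypothetical toy data: 8 points, 2 clusters, 2 classes.
cluster_ids = np.array([0, 0, 0, 0, 1, 1, 1, 1])
true_labels = np.array([0, 0, 0, 1, 1, 1, 0, 1])

correct = 0
for c in np.unique(cluster_ids):
    labels_in_cluster = true_labels[cluster_ids == c]
    # Most frequent class among the points assigned to this cluster.
    majority = np.argmax(np.bincount(labels_in_cluster))
    correct += (labels_in_cluster == majority).sum()

accuracy = correct / len(true_labels)
print(accuracy)  # 6 of the 8 points match their cluster's majority class -> 0.75
```

Keep in mind that the labels are used to name each cluster, so this number measures cluster purity rather than the accuracy of a classifier trained without labels.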

## Relationship to PCA

When introducing autoencoders, it is usually mentioned that a linear autoencoder is closely related to PCA: a linear autoencoder with representation dimensionality $d$, trained to minimize the squared reconstruction error, ideally learns the same subspace as the one spanned by the first $d$ principal components (although its representation axes need not coincide with the principal components themselves). Let's test this by fitting PCA on the train set and applying it to the test set. We will then check whether the representation space plot and the reconstruction MSE are consistent with what we obtained with the simple linear autoencoder we just trained.

In [None]:
from sklearn.decomposition import PCA
pca = PCA(n_components=2)
pca = pca.fit(x_train)
representation_pca = pca.transform(x_test)
plot_representation_label(representation_pca, y_test)
plt.show()

We can see that the distribution of the PCA representations (size and shape of the clusters, and relative positions to the rest of clusters) is similar to the one obtained in the linear autoencoder section.

We now check if the MSE in the test set is similar to the linear autoencoder we trained before.

In [None]:
reconst_test = pca.inverse_transform(representation_pca)
mse_pca = ((torch.tensor(reconst_test) - x_test)**2).mean()

print('PCA Test Loss {:.4f}'.format(mse_pca))

The reconstruction error we get is almost the same as the one obtained with the autoencoder case, which seems to validate the claim that a linear autoencoder behaves in a similar way to PCA.
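The connection can also be checked on a small synthetic dataset. The sketch below (NumPy only, random data, with `d = 2` as in this notebook) builds a linear encoder/decoder from the top principal directions and then re-parametrises the code with an arbitrary invertible matrix `A`, chosen here purely for illustration. The reconstruction is unchanged, which suggests why a trained linear autoencoder is only expected to match PCA up to a linear transformation of the representation space.

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(200, 10)) @ rng.normal(size=(10, 10))  # synthetic correlated data
mean = X.mean(axis=0)
Xc = X - mean

d = 2
# Rows of Vt are the principal directions, ordered by explained variance.
_, _, Vt = np.linalg.svd(Xc, full_matrices=False)
V_d = Vt[:d].T                          # shape (10, d)

# PCA viewed as a linear autoencoder: encode with V_d^T, decode with V_d.
Z = Xc @ V_d
X_pca = Z @ V_d.T + mean

# An arbitrary invertible change of basis in the code space (illustrative values).
A = np.array([[2.0, 1.0],
              [0.5, 3.0]])
Z_alt = Z @ A                           # a different 2-D representation
X_alt = Z_alt @ np.linalg.inv(A) @ V_d.T + mean

print(np.allclose(X_pca, X_alt))        # same reconstruction from different codes
print(np.mean((X - X_pca) ** 2))        # reconstruction MSE of the rank-2 approximation
```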

## Detecting Anomalies

After checking the representation of the MNIST images, let's focus on how to spot anomalies in the data. Anomalies are samples that deviate from the usual distribution. An autoencoder trained on a distribution 'A' will typically produce a higher reconstruction error for a data point drawn from a different distribution 'B'. Anomaly detection has several uses; for example, it can be used for quality control or to detect bank fraud.

We will now use that fact to detect anomalies. To do so, we will use the Extended MNIST (EMNIST) dataset, which, apart from digits, also includes lowercase and uppercase characters from 'a' to 'z'. These characters will act as the anomalies that we aim to detect. As our autoencoder was trained using only the digits 0-9, it should have higher reconstruction errors for those lowercase and uppercase characters, which will help us detect them as anomalies.

Let's start by downloading and loading the dataset. The following piece of code downloads the data and uncompresses some necessary files.

In [None]:
!wget https://biometrics.nist.gov/cs_links/EMNIST/gzip.zip -q
!unzip -qq ./gzip.zip
!mv gzip data
# We need to unzip a couple of files with the test
# labels and images
!gunzip ./data/emnist-byclass-test-images-idx3-ubyte.gz
!gunzip ./data/emnist-byclass-test-labels-idx1-ubyte.gz
# We also install a package to help us
!pip install python-mnist

Now that we have downloaded the data, we will use the module `mnist` to load the images in the same format as the regular MNIST dataset.

In [None]:
from mnist import MNIST

# Images in folder data
mndata = MNIST('data')

# This will load the test data from the downloaded files
emnist_x_test, emnist_y_test = mndata.load('./data/emnist-byclass-test-images-idx3-ubyte',
                               './data/emnist-byclass-test-labels-idx1-ubyte')


# Convert data to numpy arrays and normalize images to the interval [0, 1]
n_elem = len(emnist_x_test)
emnist_x_test = np.array(emnist_x_test).reshape(n_elem,28,28).transpose(0,2,1).reshape(n_elem,28**2) / 255.0
emnist_y_test = np.array(emnist_y_test)

# Get labels mapping (index in emnist_y_test to character value)
emnist_labels = map(lambda x: x.strip('\r').split(' '), open('./data/emnist-byclass-mapping.txt').read().strip().split('\n'))
emnist_labels = dict(emnist_labels)

# This function will be useful to display the actual label, which is given as
# an ascii value (https://en.wikipedia.org/wiki/ASCII) instead of characters
def label_to_char(label):
  ascii_val = emnist_labels[str(label)]
  return chr(int(ascii_val))

Here we just plot a random image from the EMNIST dataset. Each time you run it, you get a random image from the dataset, in case you want to check how the different characters look.

In [None]:
ind_plot = np.random.randint(emnist_x_test.shape[0])  # upper bound is exclusive
_, ax = plt.subplots(1,1)
plt.imshow(emnist_x_test[ind_plot].reshape(28, 28), cmap='gray')
ax.set_xticks([])
ax.set_yticks([])
# We use the label_to_char function to plot the actual character in the
# figure title
plt.title(label_to_char(emnist_y_test[ind_plot]))
plt.show()

Now we have loaded the EMNIST dataset. The images contain characters that the autoencoder has not seen before, hence the reconstruction error (we use Mean Squared Error as the reconstruction metric) for the EMNIST dataset should be higher. We will first compute the reconstruction error for both the MNIST test data and the EMNIST data. Then, we compare the distribution of reconstruction error in both sets using a histogram visualization.
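Per-sample error here means averaging the squared differences over the pixels of each image separately, rather than over the whole dataset at once. A minimal NumPy sketch with synthetic arrays (random stand-ins, not actual model outputs) shows the reduction and how it relates to the squared $l_2$ error $||x - \psi(\phi(x))||^2$:

```python
import numpy as np

rng = np.random.default_rng(0)

# Synthetic stand-ins: 4 flattened "images" and slightly perturbed "reconstructions".
x = rng.random((4, 784))
x_hat = x + rng.normal(scale=0.1, size=x.shape)

per_sample_mse = ((x - x_hat) ** 2).mean(axis=1)   # one value per image, shape (4,)
squared_l2 = ((x - x_hat) ** 2).sum(axis=1)        # ||x - x_hat||^2 per image

print(per_sample_mse.shape)
print(np.allclose(squared_l2, 784 * per_sample_mse))  # the two differ only by the factor 784
```

Since the two quantities differ by a constant factor, thresholding one is equivalent to thresholding the other with a rescaled threshold.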

In [None]:
def get_reconstruction_and_mse(model, data, batch_size):
    """
    Run reconstruction on data and compute per-sample MSE entirely in torch.
    """
    data_tensor = data if isinstance(data, torch.Tensor) else torch.tensor(data, dtype=torch.float32)
    loader = DataLoader(data_tensor, batch_size=batch_size, shuffle=False)

    reconstructions = []
    with torch.no_grad():
        for batch in loader:
            batch = batch.to(DEVICE)
            recon = model(batch)
            reconstructions.append(recon.cpu())

    # Concatenate all reconstructions into one tensor
    recon_all = torch.cat(reconstructions, dim=0)

    # Flatten inputs and reconstructions
    data_flat = data_tensor.view(data_tensor.size(0), -1)
    recon_flat = recon_all.view(recon_all.size(0), -1)

    # Per-sample MSE (tensor on CPU)
    mse = ((recon_flat - data_flat) ** 2).mean(dim=1)

    return recon_all, mse


def compute_errors(model, mnist_data, emnist_data, batch_size=256):
    """
    Compute and plot MSE distributions for MNIST vs EMNIST.
    """
    model = model.to(DEVICE)
    model.eval()

    # Compute for MNIST
    reconst_mnist, mse_mnist = get_reconstruction_and_mse(model, mnist_data, batch_size)
    # Compute for EMNIST
    reconst_emnist, mse_emnist = get_reconstruction_and_mse(model, emnist_data, batch_size)

    # Convert only the MSE tensors to numpy for plotting
    plt.hist(mse_mnist.numpy(), bins=20, label='MNIST', alpha=0.5, density=True)
    plt.hist(mse_emnist.numpy(), bins=20, label='EMNIST', alpha=0.5, density=True)
    plt.xlabel('MSE')
    plt.title('Distribution of MSE for MNIST/EMNIST')
    plt.legend()
    plt.show()

    return reconst_mnist, mse_mnist, reconst_emnist, mse_emnist

reconst_mnist, mse_mnist, reconst_emnist, mse_emnist = compute_errors(model, x_test, emnist_x_test)

As we expected, the distribution of errors for the EMNIST dataset has a higher mean and variance compared to the original MNIST. Let's now check how the image with the lowest reconstruction error and the image with the highest reconstruction error look.

In [None]:
# Most anomalous (highest error); show the character rather than the raw label index
idx_max = mse_emnist.argmax().item()
plot_recons_original(emnist_x_test[idx_max], label_to_char(emnist_y_test[idx_max]), model)

# Least anomalous (lowest error)
idx_min = mse_emnist.argmin().item()
plot_recons_original(emnist_x_test[idx_min], label_to_char(emnist_y_test[idx_min]), model)

It seems that the model is not capable of representing images from the EMNIST dataset that deviate too much from those in the MNIST dataset, for example images of the letter 'm'. However, the autoencoder behaves well for images that are similar to those from MNIST, which is what we expected.

We now perform an anomaly detection exercise. We mark a data point as an anomaly if $||x-\psi(\phi(x))||^2>\tau$, with the threshold set to $\tau = \mu + 2\sigma$, where $\mu$ is the average MSE over the whole MNIST test set and $\sigma$ is the standard deviation of the MSE in the MNIST test set. We will now plot the ratio of data points marked as anomalies for each of the classes in both MNIST and EMNIST. Additionally, we report the average MSE per class in both datasets. We expect the ratio of anomalies in MNIST to be considerably lower than in the EMNIST dataset, as our autoencoder has been trained with MNIST data.
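The thresholding rule is simple enough to sketch in isolation. The error values below are invented for illustration: `errors_ref` plays the role of the MNIST test errors and `errors_new` the role of errors on unseen data. Note that `torch.Tensor.std` applies Bessel's correction by default while `np.std` does not, so `ddof=1` is passed here to mirror the PyTorch behaviour.

```python
import numpy as np

# Invented per-sample reconstruction errors.
errors_ref = np.array([0.04, 0.05, 0.06, 0.05, 0.07, 0.03, 0.05, 0.06])
errors_new = np.array([0.05, 0.12, 0.09, 0.04, 0.20])

mu = errors_ref.mean()
sigma = errors_ref.std(ddof=1)      # sample standard deviation
tau = mu + 2 * sigma                # anomaly threshold

is_anomaly = errors_new > tau       # boolean mask, True where the error exceeds tau
print(f"tau = {tau:.4f}")
print(is_anomaly)
print(f"ratio flagged: {is_anomaly.mean():.2f}")
```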

In [None]:
def print_mse_anomalies(mse_mnist, mse_emnist, emnist_labels_dict, th=-1):
    """
    This function uses the mse per class in both MNIST and EMNIST to compute the ratio of anomalies per class (anomalies/examples per class)
    We plot anomalies and mse in a bar histogram, ordered by ratio of anomalies
    Do not worry too much about how it works.
    """
    if th == -1:
        th = mse_mnist.mean() + 2 * mse_mnist.std()

    mnist_labels = range(10)
    save_array = []

    # MNIST
    for label in mnist_labels:
        indices_class = y_test == int(label)
        if indices_class.sum() == 0:
            continue
        mse_class = mse_mnist[indices_class]
        ratio_anom = (mse_class > th).sum() / indices_class.sum()
        save_array.append([mse_class.mean(), ratio_anom, str(label), 'salmon'])

    # EMNIST
    for label in emnist_labels_dict:
        indices_class = emnist_y_test == int(label)
        if indices_class.sum() == 0:
            continue
        mse_class = mse_emnist[indices_class]
        ratio_anom = (mse_class > th).sum() / indices_class.sum()
        char_label = label_to_char(label)
        save_array.append([mse_class.mean(), ratio_anom, char_label, 'steelblue'])

    # Sort by anomaly ratio
    save_array = sorted(save_array, key=lambda x: x[1])

    mse_class = [x[0] for x in save_array]
    anomalies_class = [x[1] for x in save_array]
    labels_names = [x[2] for x in save_array]
    colors = [x[3] for x in save_array]

    # Plotting
    plt.subplots(2, 1, figsize=(12, 8))

    plt.subplot(2, 1, 1)
    bars = plt.bar(range(len(save_array)), anomalies_class, color=colors)
    plt.xticks(range(len(save_array)), labels_names, rotation=90)
    plt.legend([bars[0], bars[-1]], ['MNIST', 'EMNIST'], loc='upper center')
    plt.title('Anomalies per class')
    plt.ylabel('Ratio anomalies')

    plt.subplot(2, 1, 2)
    plt.bar(range(len(save_array)), mse_class, color=colors)
    plt.xticks(range(len(save_array)), labels_names, rotation=90)
    plt.legend([bars[0], bars[-1]], ['MNIST', 'EMNIST'], loc='upper center')
    plt.title('Average MSE per class')
    plt.ylabel('MSE')

    plt.tight_layout()
    plt.show()

print_mse_anomalies(mse_mnist, mse_emnist, emnist_labels)

The top figure shows the ratio of anomalies per class, which is the number of anomalies detected for the data points of that class divided by the number of examples of that specific class. The bottom figure shows the average error per category, where we can see that the classes from EMNIST have a higher reconstruction error than those from MNIST. The shared classes in MNIST and EMNIST, the digits 0-9, seem to have a different distribution in those two datasets as the reconstruction error for those classes in EMNIST is higher than in MNIST. Additionally, some letters that look like the number 1 (`l, I, i, t, j`) or like 0 (`O` and `o`) have a low reconstruction error as they are similar to some of the images in the MNIST dataset. Some other classes, such as `W`, `w`, or `Q` have a high reconstruction error due to being quite dissimilar to any of the images in MNIST.
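The per-class ratio in the top figure is a grouped mean of the anomaly mask. A minimal sketch with invented errors and labels follows; the threshold `tau` is an arbitrary value chosen for the example, not the one computed in this notebook.

```python
import numpy as np

# Invented errors and class labels for six samples.
errors = np.array([0.02, 0.09, 0.03, 0.11, 0.12, 0.04])
labels = np.array(["1", "1", "1", "W", "W", "W"])
tau = 0.08                                   # arbitrary threshold for this example

ratios = {}
for cls in np.unique(labels):
    in_class = labels == cls
    # Anomalies in the class divided by the number of examples in the class.
    ratios[str(cls)] = float((errors[in_class] > tau).mean())

# Sort classes by ratio, as the bar chart does.
for cls, ratio in sorted(ratios.items(), key=lambda kv: kv[1]):
    print(cls, round(ratio, 2))
```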

We have seen that data from other datasets will have a high reconstruction error. However, even within MNIST, we can look for samples that have a high reconstruction error compared to the average. We expect most images in a given class to follow a similar distribution, while some of them will deviate from it.

In that direction, let's take one of the classes from MNIST and plot some images with high reconstruction error and some others with low reconstruction error. We expect those with low reconstruction error to be quite similar to one another and to represent what the average sample of that class looks like. The samples with high reconstruction error will, in turn, contain some unusual elements (pose or shape, for example).


In [None]:
# Set the threshold. Note: this uses \mu + \sigma, which is looser than the
# \mu + 2\sigma used above
th = mse_mnist.std() + mse_mnist.mean()

# We will use the label 8 for the example
indices_class = y_test == 8

# Select the images of this class flagged as anomalies (mse > th)
anomalies_im = x_test[indices_class][mse_mnist[indices_class] > th]

# For the nonanomalies, we will sort the images by reconstruction error
# So we will plot the images with low reconstruction error
indices_sort = np.argsort(mse_mnist[indices_class])
nonanomalies_im = x_test[indices_class][indices_sort]

def plot_grid(images, N=5, title=''):
  ## Plots data in variable images in a grid of N*N
  # Create figure
  fig, axes = plt.subplots(N,N, figsize=(8,8))
  # Loop to generate grid
  for row in range(N):
    for col in range(N):
      idx = row+N*col
      axes[row,col].imshow(images[idx].reshape(28,28), cmap='gray')
      axes[row,col].set_xticks([])
      axes[row,col].set_yticks([])

  # Adjust white space
  fig.subplots_adjust(hspace=0.0)
  fig.subplots_adjust(wspace=0.0)
  fig.subplots_adjust(right=1.0)
  fig.subplots_adjust(left=0.245)
  # Set title
  fig.suptitle(title, x=0.62, y=0.93, fontsize=24)

# Plot anomalies
plot_grid(anomalies_im, title='Anomalies')
# Plot non-anomalies
plot_grid(nonanomalies_im, title='Non-anomalies')

The images classified as anomalies for the given threshold $\tau$ show a much larger variation in shape and pose. In turn, the images with low reconstruction error for the selected label (in this case `8`) show similar proportions and less variation between any two images.

# Image-to-Image Translation

As seen previously, autoencoders aim at creating low-dimensional representations that preserve only the most discriminative information of the input signals. We can train them by looking at the reconstruction error between the input and the generated data. This training process is a form of *self-supervised* learning. Thus, the encoding part is trained to create a strong feature representation, from which the decoder can reconstruct the original data.

The dimensionality reduction/compression is widely used in many signal processing tasks. However, autoencoders are also applied in many image-to-image tasks. Here we will show how to train an autoencoder as a denoising method.

## Denoising with Autoencoders

The idea behind this task is that autoencoders will preserve only the characteristic information and discard the noise in the reconstruction.

Noise is random, and it cannot be predicted. Therefore, a robust encoding/decoding learns to identify the stable patterns within the image and to ignore the non-smooth perturbations produced by the noise.

This idea is in line with the **Efficient Coding Hypothesis** from the vision neuroscience field, which says ". . . the Efficient Coding Hypothesis holds that the purpose of early visual processing is to produce an efficient representation of the incoming visual signal."

Noise can appear in images in many different ways. In this tutorial, we will synthetically generate additive Gaussian noise to train our denoising autoencoder.
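One side effect of this noise model is worth keeping in mind: once noisy pixels are clipped back to $[0, 1]$, the perturbation is no longer zero-mean for very dark or very bright pixels. The following is a small NumPy check on a synthetic all-black patch, assuming a noise scale of 0.3 and clipping to $[0, 1]$ as in the next cell.

```python
import numpy as np

rng = np.random.default_rng(42)
sigma = 0.3

clean = np.zeros(100_000)                              # synthetic all-black pixels
noisy = clean + rng.normal(scale=sigma, size=clean.shape)
clipped = np.clip(noisy, 0.0, 1.0)

print(f"mean before clipping: {noisy.mean():+.3f}")    # close to 0
print(f"mean after clipping:  {clipped.mean():+.3f}")  # positive: negative noise was cut off
# For a pixel at exactly 0, theory gives roughly sigma / sqrt(2 * pi),
# ignoring the rare clip at 1.
print(f"sigma / sqrt(2*pi):   {sigma / np.sqrt(2 * np.pi):.3f}")
```

In practice this means a denoiser trained on clipped data also has to undo a small brightness shift in the darkest and brightest regions, not just remove zero-mean noise.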

In [None]:
from torchvision.transforms.functional import to_pil_image

# Transform for loading original (clean) data
transform_clean = transforms.Compose([
    transforms.ToTensor(),  # Converts to [0,1] range
])

# Download and load data
train_dataset = datasets.CIFAR100(root='./data', train=True, download=True, transform=transform_clean)
test_dataset = datasets.CIFAR100(root='./data', train=False, download=True, transform=transform_clean)

# Convert to numpy arrays for easier noise addition
x_train = torch.stack([img for img, _ in train_dataset]).numpy().transpose(0, 2, 3, 1)  # (N, H, W, C)
x_test = torch.stack([img for img, _ in test_dataset]).numpy().transpose(0, 2, 3, 1)

# Additive Gaussian noise
np.random.seed(42)
x_train_noise = x_train + np.random.normal(scale=0.3, size=x_train.shape)
x_test_noise = x_test + np.random.normal(scale=0.3, size=x_test.shape)

# Clip values to [0, 1] after adding noise
x_train_noise = np.clip(x_train_noise, 0., 1.)
x_test_noise = np.clip(x_test_noise, 0., 1.)

We can visualise some examples to get an idea of what the noise looks like:

In [None]:
from torchvision.transforms import Resize, ToTensor, Compose

resize_transform = Compose([
    transforms.ToPILImage(),
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])

def resize_batch(imgs):
    resized = []
    for i in range(imgs.shape[0]):
        resized_img = resize_transform((imgs[i] * 255).astype(np.uint8))
        resized.append(resized_img.permute(1, 2, 0).numpy())  # (H, W, C)
    return np.stack(resized)

# Resize first 4 images from clean and noisy sets
x_train_resized = resize_batch(x_train[:4])
x_train_noise_resized = resize_batch(x_train_noise[:4])

# Visualize side-by-side comparison
N = 2
start_val = 0
fig, axes = plt.subplots(N, N, figsize=(8, 8))
plt.suptitle('Clean vs Noisy', fontsize=18)

for row in range(N):
    for col in range(N):
        idx = start_val + row + N * col
        im = np.concatenate((x_train_resized[idx], x_train_noise_resized[idx]), axis=1)
        axes[row, col].imshow(np.clip(im, 0, 1))
        axes[row, col].set_xticks([])
        axes[row, col].set_yticks([])

plt.tight_layout()
plt.show()

In [None]:
# Convert to torch tensors (transpose to NCHW)
x_train_tensor = torch.tensor(x_train_noise.transpose(0, 3, 1, 2), dtype=torch.float32)
x_test_tensor = torch.tensor(x_test_noise.transpose(0, 3, 1, 2), dtype=torch.float32)

# Self-supervised: use noisy inputs as targets
train_dataset = TensorDataset(x_train_tensor, x_train_tensor)
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)

# Supervised: use clean inputs as targets
x_train_clean_tensor = torch.tensor(x_train.transpose(0, 3, 1, 2), dtype=torch.float32)
x_test_clean_tensor = torch.tensor(x_test.transpose(0, 3, 1, 2), dtype=torch.float32)

train_dataset_supervised = TensorDataset(x_train_tensor, x_train_clean_tensor)
train_loader_supervised = DataLoader(train_dataset_supervised, batch_size=128, shuffle=True)

### Linear Autoencoders

As in regular architecture design, when defining the structure of an autoencoder, we can decide the number of convolutional layers, dense layers, normalisation mechanisms, activation functions, and so on.

Similar to the Representation Learning section, in this denoising task, we start with a linear architecture.

In [None]:
# Define linear autoencoder
class LinearDenoisingAutoencoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Flatten(),
            nn.Linear(32 * 32 * 3, 128),
            nn.Linear(128, 30)
        )
        self.decoder = nn.Sequential(
            nn.Linear(30, 128),
            nn.Linear(128, 32 * 32 * 3),
            nn.Unflatten(1, (3, 32, 32))  # (N, C, H, W)
        )

    def forward(self, x):
        z = self.encoder(x)
        return self.decoder(z)

We train our model with the standard Mean Squared Error (MSE) loss function. We compute the MSE between the reconstructed images and the noisy ones. Ideally, as mentioned above, the noise is random and therefore will not be modelled.

Sometimes we have a dataset that contains noisy and clean images. If that is the case, we could compute the MSE between the reconstructed and the clean ones. Using clean ground-truth images (*supervised learning*) improves results compared to only using noisy ones (*self-supervised learning*). However, we want to start from the perspective of image compression and feature representation, and train our denoising model using exclusively noisy images.
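
The only difference between the two settings is the target passed to the loss. The sketch below illustrates this with a single gradient step on random stand-in data. Note that `train_step`, `toy_model`, and `toy_optimizer` are hypothetical names introduced here; the notebook's own `train` helper is defined earlier and may work differently.

```python
import torch
import torch.nn as nn
import torch.optim as optim

def train_step(model, optimizer, criterion, inputs, targets):
    """One gradient step (hypothetical helper, not the notebook's `train`)."""
    model.train()
    optimizer.zero_grad()
    loss = criterion(model(inputs), targets)
    loss.backward()
    optimizer.step()
    return loss.item()

torch.manual_seed(0)
clean = torch.rand(8, 3, 32, 32)                             # stand-in for clean images
noisy = (clean + 0.3 * torch.randn_like(clean)).clamp(0, 1)  # noisy versions

# Tiny stand-in model; any module mapping (N, 3, 32, 32) -> (N, 3, 32, 32) would do.
toy_model = nn.Sequential(
    nn.Flatten(),
    nn.Linear(3 * 32 * 32, 30),
    nn.Linear(30, 3 * 32 * 32),
    nn.Unflatten(1, (3, 32, 32)),
)
toy_optimizer = optim.Adam(toy_model.parameters(), lr=1e-3)

# Self-supervised: the target is the noisy input itself.
loss_self = train_step(toy_model, toy_optimizer, nn.MSELoss(), noisy, noisy)
# Supervised: the target is the clean image.
loss_sup = train_step(toy_model, toy_optimizer, nn.MSELoss(), noisy, clean)
print(f"self-supervised loss: {loss_self:.4f}, supervised loss: {loss_sup:.4f}")
```

In this notebook, the same distinction is encoded in the data loaders: `train_loader` pairs noisy inputs with noisy targets, while `train_loader_supervised` pairs noisy inputs with clean targets.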

In [None]:
set_seed(42)

model = LinearDenoisingAutoencoder()
print(torchinfo.summary(model, input_size=(1, 3, 32, 32)))
print()

train(
    model,
    train_loader,
    nn.MSELoss(),
    optim.Adam(model.parameters(), lr=1e-3),
)

Let's check what the denoised images look like after ten epochs of training. Run the following code multiple times to see different examples:

In [None]:
# Random example from test set
idx_example = np.random.randint(0, len(x_test_tensor) - 2)
x_input = x_test_tensor[idx_example:idx_example+2].to(DEVICE)

# Run model on test example
model.eval()
with torch.no_grad():
    x_denoised = model(x_input).cpu().numpy()

# Get clean version
x_clean = x_test[idx_example:idx_example+2]
x_noisy = x_test_noise[idx_example:idx_example+2]

# Visualize Noisy vs Reconstructed vs Clean
N = 2
fig, axes = plt.subplots(N, 1, figsize=(8, 6))
plt.suptitle('Noisy Input VS Reconstructed VS Clean Image', fontsize=18)

for row in range(N):
    im = np.concatenate((
        x_noisy[row],                         # Noisy
        np.transpose(x_denoised[row], (1, 2, 0)),  # Reconstructed
        x_clean[row]                          # Clean
    ), axis=1)
    axes[row].imshow(np.clip(im, 0, 1))
    axes[row].set_xticks([])
    axes[row].set_yticks([])

plt.tight_layout()
plt.show()

The reconstructed images are heavily blurred versions of the clean ones. The proposed autoencoder can learn to remove noise, but the results are far from perfect when compared with the ground-truth images. Note, though, that with only a few neurons the network can generate images that have roughly the same colour and shape as the original ones.

### Non-linear Autoencoders

In non-linear autoencoders, we introduce non-linearities within the network by using non-linear activation functions. This is a big difference from the classical PCA technique, which performs a linear transformation of the data. Thus, non-linear autoencoders allow us to build flexible models that can learn relationships in the data beyond linear transformations.
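
To see why the activation function matters, the sketch below checks that two stacked `nn.Linear` layers without an activation are equivalent to a single linear map, and that inserting a ReLU breaks this equivalence. The layer sizes and variable names are arbitrary choices for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
layer1 = nn.Linear(8, 4)
layer2 = nn.Linear(4, 2)
x = torch.randn(5, 8)

with torch.no_grad():
    # Two stacked linear layers without an activation...
    stacked = layer2(layer1(x))
    # ...equal one linear layer with merged weights and bias.
    merged_weight = layer2.weight @ layer1.weight            # shape (2, 8)
    merged_bias = layer2.weight @ layer1.bias + layer2.bias  # shape (2,)
    single = x @ merged_weight.T + merged_bias
    # Inserting a ReLU between the layers breaks this equivalence.
    with_relu = layer2(torch.relu(layer1(x)))

print("stacked == single linear map:", torch.allclose(stacked, single, atol=1e-5))
print("with ReLU == single linear map:", torch.allclose(with_relu, single, atol=1e-5))
```

However many linear layers are stacked, the result stays a linear transformation; only the activation lets the network go beyond that.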

In [None]:
class NonlinearDenoisingAutoencoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Flatten(),
            nn.Linear(32 * 32 * 3, 128),
            nn.Linear(128, 30),
        )
        self.decoder = nn.Sequential(
            nn.Linear(30, 128),
            nn.ReLU(),
            nn.Linear(128, 32 * 32 * 3),
            nn.Sigmoid(),  # restrict outputs to [0, 1]
            nn.Unflatten(1, (3, 32, 32))  # Reshape to (C, H, W)
        )

    def forward(self, x):
        return self.decoder(self.encoder(x))

In [None]:
set_seed(42)

nonlinear_model = NonlinearDenoisingAutoencoder()
print(torchinfo.summary(nonlinear_model, input_size=(1, 3, 32, 32)))
print()

train(
    nonlinear_model,
    train_loader,
    nn.MSELoss(),
    optim.Adam(nonlinear_model.parameters(), lr=1e-3),
)

We can now visualise the reconstructed images by the proposed non-linear model.

In [None]:
idx_example = np.random.randint(0, len(x_test_tensor) - 2)
x_input = x_test_tensor[idx_example:idx_example+2].to(DEVICE)

# Predict
nonlinear_model.eval()
with torch.no_grad():
    x_denoised = nonlinear_model(x_input).cpu().numpy()

x_noisy = x_test_noise[idx_example:idx_example+2]
x_clean = x_test[idx_example:idx_example+2]

N = 2
fig, axes = plt.subplots(N, 1, figsize=(8, 6))
plt.suptitle('Noisy Input VS Reconstructed VS Clean Image', fontsize=18)

for row in range(N):
    im = np.concatenate((
        x_noisy[row],
        np.transpose(x_denoised[row], (1, 2, 0)),
        x_clean[row]
    ), axis=1)
    axes[row].imshow(np.clip(im, 0, 1))
    axes[row].set_xticks([])
    axes[row].set_yticks([])

plt.tight_layout()
plt.show()

Qualitatively, the results at this stage are difficult to interpret. Hence, we compute the MSE on the test set for both architectures and compare them:

In [None]:
def evaluate_mse(model, noisy_input, clean_target):
    model.eval()
    with torch.no_grad():
        predictions = []
        targets = []
        for i in range(0, len(noisy_input), 128):
            batch_input = noisy_input[i:i+128].to(DEVICE)
            batch_target = clean_target[i:i+128].to(DEVICE)

            pred = model(batch_input)
            predictions.append(pred.cpu())
            targets.append(batch_target.cpu())

        predictions = torch.cat(predictions)
        targets = torch.cat(targets)
        return nn.MSELoss()(predictions, targets).item()

mse_linear = evaluate_mse(model, x_test_tensor, x_test_clean_tensor)
mse_nonlinear = evaluate_mse(nonlinear_model, x_test_tensor, x_test_clean_tensor)

print(f"Linear Model MSE: {mse_linear:.4f}")
print(f"Non-linear Model MSE: {mse_nonlinear:.4f}")

Although results are close, we can observe an improvement in MSE just by applying a non-linear activation function to the network.
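
MSE values on images scaled to [0, 1] are small numbers whose differences can be hard to judge. One common way to read them is to convert them to peak signal-to-noise ratio (PSNR) in decibels, $10 \log_{10}(\text{MAX}^2 / \text{MSE})$, where higher is better. The sketch below is not part of the notebook's pipeline: `psnr_from_mse` is a hypothetical helper, and the MSE values are illustrative, not measured results.

```python
import math

def psnr_from_mse(mse, max_value=1.0):
    """Peak signal-to-noise ratio in dB for images with values in [0, max_value]."""
    return 10.0 * math.log10(max_value ** 2 / mse)

# Illustrative MSE values only, not results from this notebook.
for example_mse in (0.1, 0.01, 0.001):
    print(f"MSE {example_mse:.3f} -> PSNR {psnr_from_mse(example_mse):.1f} dB")
```

Each tenfold reduction in MSE corresponds to a gain of 10 dB. The same conversion could be applied to `mse_linear` and `mse_nonlinear` from the cell above.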

### Convolutional Autoencoders

As seen in previous tutorials, 2D convolutions are more suitable than 1D convolutions when the input is an image. In practical settings, autoencoders for images are almost always convolutional, as they simply perform much better.

CNNs are used exactly as before; however, we now have to take into account elements such as strides, pooling, upsampling, and transposed convolutions (often called deconvolutions). These elements control the size (width and height) of the generated feature maps. A common requirement in image denoising is that the output has the same dimensions as the input.
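
The effect of these elements on the feature-map size can be worked out by hand. The sketch below uses the standard output-size formula for a convolution with dilation 1, $\lfloor (n + 2p - k) / s \rfloor + 1$, and assumes a 32x32 input, 3x3 kernels, stride 2, and padding 1 for three downsampling steps, followed by three 2x upsampling steps. The helper names are hypothetical.

```python
def conv_output_size(size, kernel_size, stride=1, padding=0):
    """Spatial output size of a convolution with dilation 1."""
    return (size + 2 * padding - kernel_size) // stride + 1

def upsample_output_size(size, scale_factor=2):
    """Spatial output size after upsampling by an integer factor."""
    return size * scale_factor

size = 32
encoder_sizes = [size]
for _ in range(3):  # three stride-2 convolutions
    size = conv_output_size(size, kernel_size=3, stride=2, padding=1)
    encoder_sizes.append(size)

decoder_sizes = [size]
for _ in range(3):  # three 2x upsampling steps
    size = upsample_output_size(size)
    decoder_sizes.append(size)

print("encoder:", encoder_sizes)  # encoder: [32, 16, 8, 4]
print("decoder:", decoder_sizes)  # decoder: [4, 8, 16, 32]
```

With matching numbers of downsampling and upsampling steps, the output recovers the input resolution.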

Besides, we need to define the number of levels of compression within the architecture, which can be controlled by limiting the number of convolutional layers and/or downsampling steps:


![](https://i.ibb.co/vB2jq4C/autoencoders.png)

Let's define and train an architecture with 3 compression levels:



In [None]:
class CNNAutoencoder(nn.Module):
    def __init__(self):
        super(CNNAutoencoder, self).__init__()
        # Encoder
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.ReLU()
        )
        # Decoder
        self.decoder = nn.Sequential(
            nn.Conv2d(128, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.Conv2d(64, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.Conv2d(32, 3, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.Conv2d(3, 3, kernel_size=3, padding=1),
            nn.Sigmoid() # Added Sigmoid activation to constrain output to [0, 1]
        )

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x

In our example, the encoder uses a stride of 2 in each convolution (`stride=2`). A stride of 2 halves the width and height of the feature map at every step. To recover the original size, the decoder uses `nn.Upsample` layers with `scale_factor=2`, each of which doubles the width and height of its input feature map. These two mechanisms allow us to control the dimensions of the feature maps throughout the autoencoder.
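
A minimal sketch of the two mechanisms on a toy 4x4 feature map, using a single-channel convolution chosen only for illustration. It also shows that nearest-neighbour upsampling simply repeats each value, which is why the decoder follows each upsampling step with a convolution.

```python
import torch
import torch.nn as nn

# A stride-2 convolution halves height and width...
down = nn.Conv2d(1, 1, kernel_size=3, stride=2, padding=1)
# ...and nearest-neighbour upsampling doubles them by repeating each value.
up = nn.Upsample(scale_factor=2, mode='nearest')

feature_map = torch.arange(16, dtype=torch.float32).reshape(1, 1, 4, 4)
with torch.no_grad():
    halved = down(feature_map)
    restored = up(halved)

print(halved.shape)    # torch.Size([1, 1, 2, 2])
print(restored.shape)  # torch.Size([1, 1, 4, 4])

# Nearest-neighbour upsampling of a 2x2 map: every value becomes a 2x2 block.
tiny = torch.tensor([[[[1., 2.], [3., 4.]]]])
upsampled_tiny = up(tiny)[0, 0]
print(upsampled_tiny)
```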

Let's check how the reconstructed images look:

In [None]:
set_seed(42)

conv_model = CNNAutoencoder()
print(torchinfo.summary(conv_model, input_size=(1, 3, 32, 32)))
print()

train(
    conv_model,
    train_loader,
    nn.MSELoss(),
    optim.Adam(conv_model.parameters(), lr=1e-3),
)

In [None]:
# Predict
idx_example = np.random.randint(0, len(x_test_tensor) - 2)
x_input = x_test_tensor[idx_example:idx_example+2].to(DEVICE)

conv_model.eval()
with torch.no_grad():
    x_denoised = conv_model(x_input).cpu().numpy()

x_noisy = x_test_noise[idx_example:idx_example+2]
x_clean = x_test[idx_example:idx_example+2]

N = 2
fig, axes = plt.subplots(N, 1, figsize=(8, 6))
plt.suptitle('Noisy Input VS Reconstructed VS Clean Image', fontsize=18)

for row in range(N):
    im = np.concatenate((
        x_noisy[row],
        np.transpose(x_denoised[row], (1, 2, 0)),
        x_clean[row]
    ), axis=1)
    axes[row].imshow(np.clip(im, 0, 1))
    axes[row].set_xticks([])
    axes[row].set_yticks([])

plt.tight_layout()
plt.show()

And check the MSE of all three models:

In [None]:
mse_cnn = evaluate_mse(conv_model, x_test_tensor, x_test_clean_tensor)

print('Linear Model MSE: {:0.4f}'.format(mse_linear))
print('Non-linear Model MSE: {:0.4f}'.format(mse_nonlinear))
print('CNN Model MSE: {:0.4f}'.format(mse_cnn))

We are getting better at it!

### Denoising with Clean Images

We showed that autoencoders can encode the image's structure and recover a clean but blurry image. They ignore the non-smooth perturbations produced by the noise, allowing us to train them in a self-supervised setting.

Although clean ground-truth images are not always available, they are desirable when training denoising models. Since we generated the noise synthetically, we can train the network using the original image as the clean target:

In [None]:
class CNN_GT_Autoencoder(nn.Module):
    def __init__(self):
        super(CNN_GT_Autoencoder, self).__init__()
        # Encoder
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=1),  # -> (B, 32, H/2, W/2)
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1), # -> (B, 64, H/4, W/4)
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),# -> (B, 128, H/8, W/8)
            nn.ReLU()
        )
        # Decoder
        self.decoder = nn.Sequential(
            nn.Conv2d(128, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.Conv2d(64, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.Conv2d(32, 3, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.Conv2d(3, 3, kernel_size=3, padding=1),
            nn.ReLU()
        )

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x

In [None]:
set_seed(42)

cnn_gt_model = CNN_GT_Autoencoder()
print(torchinfo.summary(cnn_gt_model, input_size=(1, 3, 32, 32)))
print()

train(
    cnn_gt_model,
    train_loader_supervised,
    nn.MSELoss(),
    optim.Adam(cnn_gt_model.parameters(), lr=1e-3),
)

Let's now check how the reconstructed images look:

In [None]:
cnn_gt_model.eval()
idx_example = np.random.randint(0, len(x_test_noise) - 1)  # leave room for two consecutive images
with torch.no_grad():
    noisy_input = torch.tensor(x_test_noise[idx_example:idx_example+2].transpose(0, 3, 1, 2), dtype=torch.float32).to(DEVICE)
    denoised_output = cnn_gt_model(noisy_input).cpu().numpy().transpose(0, 2, 3, 1)

N = 2
start_val = 0
fig, axes = plt.subplots(N, 1)
plt.suptitle('Noisy Input VS Denoised VS Clean Image', fontsize=18)
for row in range(N):
    idx = start_val + row
    im = np.concatenate((
        x_test_noise[idx_example+idx],
        np.clip(denoised_output[idx], 0, 1),
        x_test[idx_example+idx]
    ), axis=1)
    axes[row].imshow(np.clip(im, 0, 1))
    axes[row].set_xticks([])
    axes[row].set_yticks([])
plt.show()

In addition, we can compare the MSE when having clean images for training or not:

In [None]:
mse_supervised = evaluate_mse(cnn_gt_model, x_test_tensor, x_test_clean_tensor)
print('"Self"-supervised Model (Trained only with Noisy Images) MSE: {:0.4f}'.format(mse_cnn))
print('Supervised Model (Trained also with Clean Images) MSE: {:0.4f}'.format(mse_supervised))

Even though there is an improvement, the results do not differ that much at this point. Thus, having clean images is not enough on its own to denoise images well. We could try to increase the model's complexity; instead, we now propose to use skip connections.

### Skip Connections

Skip connections in autoencoders were introduced in the *Common CNN Architectures* tutorial, together with the [UNet](https://arxiv.org/pdf/1505.04597.pdf) network. Skip connections, as the name suggests, create connections between layers within the neural network by jumping over some of them. Such a connection allows a layer to receive input not only from the previous layer but also from earlier layers that are not directly connected to it. For example, the next figure shows an architecture where the encoder and decoder share features through skip connections:

![Encoder-decoder architecture with skip connections](https://i.ibb.co/d7sX5bh/Skip-Connections.png)

The idea behind this is that the first layers contain information that is hard for later layers to recover. This information is important when upsampling the feature map in the decoder: instead of learning how to recover features from the first layers, the decoder can learn how to use them directly to generate better outputs.

In addition, skip connections help information flow through the network, which mitigates the vanishing gradient problem when the architecture is deep. As seen in the *Common CNN Architectures* tutorial, gradient information can be lost during backpropagation when it passes through many layers. Having a direct path between encoder and decoder layers helps the convergence and training of the network.

There are many ways to implement skip connections. The next example shows how to build them by concatenating features:

In [None]:
class SkipConnectionAutoencoder(nn.Module):
    def __init__(self):
        super(SkipConnectionAutoencoder, self).__init__()
        # Encoder
        self.enc1 = nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=1)  # conv1
        self.enc2 = nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1) # conv2
        self.enc3 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1) # conv3
        self.enc4 = nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=1) # conv4

        # Decoder
        self.dec1_conv = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.dec2_conv = nn.Conv2d(128 + 128, 64, kernel_size=3, padding=1)   # merge with conv3
        self.dec3_conv = nn.Conv2d(64 + 64, 32, kernel_size=3, padding=1)     # merge with conv2
        self.dec4_conv = nn.Conv2d(32 + 32, 32, kernel_size=3, padding=1)     # merge with conv1
        self.final_conv = nn.Conv2d(32, 3, kernel_size=3, padding=1)

    def forward(self, x):
        # Encoder
        c1 = F.relu(self.enc1(x))  # 16x16x32
        c2 = F.relu(self.enc2(c1)) # 8x8x64
        c3 = F.relu(self.enc3(c2)) # 4x4x128
        c4 = F.relu(self.enc4(c3)) # 2x2x128

        # Decoder + Skip connections
        u1 = F.interpolate(c4, scale_factor=2, mode='nearest')  # 4x4x128
        u1 = F.relu(self.dec1_conv(u1))
        m1 = torch.cat([c3, u1], dim=1)  # merge1

        u2 = F.interpolate(m1, scale_factor=2, mode='nearest')  # 8x8
        u2 = F.relu(self.dec2_conv(u2))
        m2 = torch.cat([c2, u2], dim=1)  # merge2

        u3 = F.interpolate(m2, scale_factor=2, mode='nearest')  # 16x16
        u3 = F.relu(self.dec3_conv(u3))
        m3 = torch.cat([c1, u3], dim=1)  # merge3

        u4 = F.interpolate(m3, scale_factor=2, mode='nearest')  # 32x32
        u4 = self.dec4_conv(u4)

        out = self.final_conv(u4)
        return out
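
To see why the decoder convolutions above take `128 + 128`, `64 + 64`, and `32 + 32` input channels, the sketch below reproduces the first merge with random stand-in tensors. The tensors reuse the names `c3`, `u1`, and `m1` from the forward pass but are not connected to the model. The additive merge at the end is shown only as one alternative and is not what the class above does.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Stand-ins with the shapes annotated in the forward pass above (batch of 1).
c3 = torch.randn(1, 128, 4, 4)   # encoder feature map at 4x4
u1 = torch.randn(1, 128, 4, 4)   # decoder feature map at the same resolution

m1 = torch.cat([c3, u1], dim=1)  # concatenate along the channel dimension
print(m1.shape)                  # torch.Size([1, 256, 4, 4])

# The next decoder convolution must therefore accept 128 + 128 input channels.
dec_conv = nn.Conv2d(128 + 128, 64, kernel_size=3, padding=1)
upsampled = F.interpolate(m1, scale_factor=2, mode='nearest')
with torch.no_grad():
    decoded = dec_conv(upsampled)
print(decoded.shape)             # torch.Size([1, 64, 8, 8])

# Alternative merge (ResNet-style): adding the tensors keeps 128 channels.
added = c3 + u1
print(added.shape)               # torch.Size([1, 128, 4, 4])
```

Concatenation lets the following convolution learn how to weight encoder and decoder features, at the cost of more input channels.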

Let's train the model with skip connections:

In [None]:
set_seed(42)

skip_model = SkipConnectionAutoencoder()
print(torchinfo.summary(skip_model, input_size=(1, 3, 32, 32)))
print()

train(
    skip_model,
    train_loader_supervised,
    nn.MSELoss(),
    optim.Adam(skip_model.parameters(), lr=1e-3),
)

In [None]:
# Evaluation
skip_model.eval()
mse = evaluate_mse(skip_model, x_test_tensor, x_test_clean_tensor)

print('Model without Skip Connections MSE: {:0.4f}'.format(mse_supervised))
print('Model with Skip Connections MSE: {:0.4f}'.format(mse))

Let's visualise how the denoised images look:

In [None]:
idx_example = np.random.randint(0, len(x_test_noise) - 1)  # leave room for two consecutive images
with torch.no_grad():
    noisy_input = torch.tensor(
        x_test_noise[idx_example:idx_example+2].transpose(0, 3, 1, 2),
        dtype=torch.float32
    ).to(DEVICE)
    denoised_output = skip_model(noisy_input).cpu().numpy().transpose(0, 2, 3, 1)

N = 2
fig, axes = plt.subplots(N, 1)
plt.suptitle('Noisy Input VS Denoised VS Clean Image', fontsize=18)
for row in range(N):
    im = np.concatenate((
        x_test_noise[idx_example+row],
        np.clip(denoised_output[row], 0, 1),
        x_test[idx_example+row]
    ), axis=1)
    axes[row].imshow(np.clip(im, 0, 1))
    axes[row].set_xticks([])
    axes[row].set_yticks([])
plt.show()

Images are a bit blurry, but noise is almost all gone!

## Image Segmentation with Autoencoders

Autoencoders have been used for multiple tasks; their applications are not limited to representation learning or image denoising. Here we show how to use them for a more general task: image segmentation.

Image segmentation is the task of assigning a class value to each pixel in the input image. To do so, we have prepared the [COCO-Person](https://www.kaggle.com/oishee30/cocopersonsegmentation) dataset. The COCO-Person dataset contains a collection of images and their respective masks. The masks indicate whether each pixel is background or person. First, let's download it:

In [None]:
# Download and unzip the person segmentation dataset
!wget -q https://imperialcollegelondon.box.com/shared/static/kq70816mxj4mnfph5kxt352xnx40rzeq.zip
!unzip -qq kq70816mxj4mnfph5kxt352xnx40rzeq.zip

In [None]:
class SegmentationDataset(Dataset):
    def __init__(self, images, masks):
        self.images = torch.tensor(images).permute(0, 3, 1, 2)  # NHWC -> NCHW
        self.masks = torch.tensor(masks)  # N,H,W

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        return self.images[idx], self.masks[idx]

def load_split(split):
    images = os.listdir(f'is_coco_dataset/{split}/')
    data, labels = [], []
    for i in images:
        img = cv2.imread(f'is_coco_dataset/{split}/{i}')
        mask = cv2.cvtColor(cv2.imread(f'is_coco_dataset/masks/{i}'), cv2.COLOR_BGR2GRAY)
        mask = np.where(mask > 100, 1, 0)  # binary mask
        data.append(img)
        labels.append(mask)
    return np.asarray(data), np.asarray(labels)

def load_data():
    print('Start loading data')
    train_data, train_labels = load_split('train')
    val_data, val_labels = load_split('val')
    test_data, test_labels = load_split('test')
    return [train_data, train_labels], [val_data, val_labels], [test_data, test_labels]

def shuffle_data(train_data, train_labels, val_data, val_labels):
    train_idx = np.random.permutation(len(train_data))
    val_idx = np.random.permutation(len(val_data))
    return train_data[train_idx], train_labels[train_idx], val_data[val_idx], val_labels[val_idx]


(x_train, y_train), (x_val, y_val), (x_test, y_test) = load_data()

x_train = x_train.astype(np.float32) / 255.
x_val = x_val.astype(np.float32) / 255.
x_test = x_test.astype(np.float32) / 255.

x_train, y_train, x_val, y_val = shuffle_data(x_train, y_train, x_val, y_val)

y_train = y_train.astype(np.int64)
y_val = y_val.astype(np.int64)
y_test = y_test.astype(np.int64)

train_dataset = SegmentationDataset(x_train, y_train)
val_dataset = SegmentationDataset(x_val, y_val)
test_dataset = SegmentationDataset(x_test, y_test)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64)
test_loader = DataLoader(test_dataset, batch_size=64)

We can now visualise the dataset. We plot the input images, the segmentation masks, and their overlay. Run the cell multiple times to see different examples:

In [None]:
N = 2
y_train_vis = np.where(y_train >= 0.5, 1., 0.).astype(np.float32)
y_train_mask = np.tile(y_train_vis[:, :, :, np.newaxis], [1, 1, 1, 3])
y_train_vis_gt = np.zeros((len(y_train), 128, 176, 3), np.float32)
y_train_vis_gt[:, :, :, 1] = y_train_vis
blank_space = np.ones((128, 70, 3), np.float32)

start_val = np.random.randint(len(y_train) - N**2)
fig, axes = plt.subplots(N, 1)
plt.suptitle('Input Image VS GT Mask VS Masked Image', fontsize=18)
for row in range(N):
    idx = start_val + row
    overlap_gt = cv2.addWeighted(x_train[idx], 1, y_train_vis_gt[idx], 0.5, 0)
    im = np.concatenate((x_train[idx], blank_space, y_train_mask[idx], blank_space, overlap_gt), 1)
    im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
    axes[row].imshow(np.clip(im, 0, 1))
    axes[row].axis("off")
plt.show()

Segmenting the background of a portrait photo is a very popular task in everyday applications. For instance, if you do not have a DSLR camera, you can still get those blurry backgrounds thanks to autoencoders. Applications do not end there; check some examples of what you can do in the [Google AI Blog](https://ai.googleblog.com/2018/03/mobile-real-time-video-segmentation.html). Here is one example from that post:


<a href="https://ibb.co/0hhQ3Zt"><img src="https://3.bp.blogspot.com/-jp16CE_SLZk/WpW1sKWU9PI/AAAAAAAACY0/sjHghHiuarEC2aEy5txhmdT6INK9C_OxACLcBGAs/s640/image3.gif" alt="Screenshot-from-2019-02-14-14-49-24" border="0"></a>


Now that we have the data and understand the task, we can define a model and train it:

In [None]:
class AutoencoderSeg(nn.Module):
    def __init__(self, num_classes=2):
        super().__init__()
        self.enc1 = nn.Conv2d(3, 32, 3, stride=2, padding=1)
        self.enc2 = nn.Conv2d(32, 64, 3, stride=2, padding=1)
        self.enc3 = nn.Conv2d(64, 128, 3, stride=2, padding=1)
        self.enc4 = nn.Conv2d(128, 128, 3, stride=2, padding=1)
        self.up1 = nn.Conv2d(128, 128, 3, padding=1)
        self.up2 = nn.Conv2d(256, 64, 3, padding=1)
        self.up3 = nn.Conv2d(128, 32, 3, padding=1)
        self.up4 = nn.Conv2d(64, 32, 3, padding=1)
        self.final = nn.Conv2d(32, num_classes, 3, padding=1)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        c1 = self.relu(self.enc1(x))
        c2 = self.relu(self.enc2(c1))
        c3 = self.relu(self.enc3(c2))
        c4 = self.relu(self.enc4(c3))
        u1 = F.interpolate(c4, scale_factor=2, mode='nearest')
        u1 = self.relu(self.up1(u1))
        u1 = torch.cat([c3, u1], dim=1)
        u2 = F.interpolate(u1, scale_factor=2, mode='nearest')
        u2 = self.relu(self.up2(u2))
        u2 = torch.cat([c2, u2], dim=1)
        u3 = F.interpolate(u2, scale_factor=2, mode='nearest')
        u3 = self.relu(self.up3(u3))
        u3 = torch.cat([c1, u3], dim=1)
        u4 = F.interpolate(u3, scale_factor=2, mode='nearest')
        u4 = self.up4(u4)
        return self.final(u4)  # logits


# Training
set_seed(42)  # seed before creating the model so that weight initialisation is reproducible

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = AutoencoderSeg(num_classes=2).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

best_val_acc = 0
patience, patience_counter = 4, 0

for epoch in range(30):
    model.train()
    train_correct, train_total, train_loss = 0, 0, 0
    for imgs, masks in train_loader:
        imgs, masks = imgs.to(device), masks.to(device)
        optimizer.zero_grad()
        outputs = model(imgs)
        loss = criterion(outputs, masks)
        loss.backward()
        optimizer.step()
        train_loss += loss.item() * imgs.size(0)
        preds = outputs.argmax(dim=1)
        train_correct += (preds == masks).sum().item()
        train_total += masks.numel()

    val_correct, val_total, val_loss = 0, 0, 0
    model.eval()
    with torch.no_grad():
        for imgs, masks in val_loader:
            imgs, masks = imgs.to(device), masks.to(device)
            outputs = model(imgs)
            loss = criterion(outputs, masks)
            val_loss += loss.item() * imgs.size(0)
            preds = outputs.argmax(dim=1)
            val_correct += (preds == masks).sum().item()
            val_total += masks.numel()

    train_acc = train_correct / train_total
    val_acc = val_correct / val_total
    print(f"Epoch [{epoch+1}/30] Train Loss: {train_loss/len(train_dataset):.4f}, "
          f"Train Acc: {train_acc*100:.2f}%, Val Loss: {val_loss/len(val_dataset):.4f}, "
          f"Val Acc: {val_acc*100:.2f}%")
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(model.state_dict(), "best_model.pth")
        patience_counter = 0
    else:
        patience_counter += 1
        if patience_counter > patience:
            print("Early stopping")
            break

# Evaluate on test set
model.load_state_dict(torch.load("best_model.pth"))
model.eval()
test_correct, test_total = 0, 0
with torch.no_grad():
    for imgs, masks in test_loader:
        imgs, masks = imgs.to(device), masks.to(device)
        outputs = model(imgs)
        preds = outputs.argmax(dim=1)
        test_correct += (preds == masks).sum().item()
        test_total += masks.numel()
print()
print(f'Test Accuracy: {test_correct/test_total*100:.2f}%')
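
The loss and accuracy in the loop above are computed per pixel. The sketch below illustrates this on random stand-in tensors: `nn.CrossEntropyLoss` accepts logits of shape `(N, num_classes, H, W)` and integer targets of shape `(N, H, W)`, and by default averages the loss over all pixels. The small tensor sizes are arbitrary.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
logits = torch.randn(2, 2, 4, 6)        # (N, num_classes, H, W), like the model output
masks = torch.randint(0, 2, (2, 4, 6))  # (N, H, W), integer class per pixel (int64)

# nn.CrossEntropyLoss averages the per-pixel losses by default.
loss = nn.CrossEntropyLoss()(logits, masks)

# Same value by hand: log-softmax over classes, pick the true class, average.
log_probs = torch.log_softmax(logits, dim=1)
manual = -log_probs.gather(1, masks.unsqueeze(1)).mean()

preds = logits.argmax(dim=1)            # (N, H, W) predicted class per pixel
pixel_acc = (preds == masks).float().mean().item()

print("losses match:", torch.allclose(loss, manual))
print(f"pixel accuracy: {pixel_acc:.2f}")
```

This is also why the masks are converted to `int64` before building the datasets: the loss expects class indices, not floating-point values.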

In addition to the previous code, we also provide code for visualising the predicted masks on the test set:

In [None]:
x_test_batch, y_test_batch = next(iter(test_loader))
x_test_batch, y_test_batch = x_test_batch.to(device), y_test_batch.to(device)
with torch.no_grad():
    preds = model(x_test_batch[:2]).argmax(dim=1).cpu().numpy()
x_test_vis = np.zeros((2, 128, 176, 3), np.float32)
x_test_vis[:, :, :, 1] = preds
y_test_vis = np.zeros((2, 128, 176, 3), np.float32)
y_test_vis[:, :, :, 1] = y_test_batch[:2].cpu().numpy()
blank_space = np.ones((128, 70, 3), np.float32)

N = 2
start_val = 0
fig, axes = plt.subplots(N, 1)
plt.suptitle('Input Image VS Predicted Mask VS GT Mask', fontsize=18)
for row in range(N):
    idx = start_val + row
    img = x_test_batch[idx].cpu().permute(1, 2, 0).numpy()
    overlap_pred = cv2.addWeighted(img, 1, x_test_vis[idx], 0.5, 0)
    overlap_gt = cv2.addWeighted(img, 1, y_test_vis[idx], 0.5, 0)
    im = np.concatenate((img, blank_space, overlap_pred, blank_space, overlap_gt), 1)
    im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
    axes[row].imshow(np.clip(im, 0, 1))
    axes[row].axis("off")
plt.show()

You could run your segmentation network on internet images and see how it performs there. You could even try to do the background blurring effect we introduced above!
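
As a starting point for the background-blurring effect, here is a minimal sketch, assuming an image tensor of shape `(3, H, W)` in [0, 1] and a binary mask of shape `(H, W)` where 1 marks person pixels. `blur_background` is a hypothetical helper, and the box blur via average pooling is a simple stand-in for the blur a real application would use.

```python
import torch
import torch.nn.functional as F

def blur_background(image, person_mask, kernel_size=9):
    """Blur pixels where person_mask == 0 and keep person pixels untouched.

    image: float tensor (3, H, W) in [0, 1]; person_mask: tensor (H, W) with 0/1 values.
    """
    # Box blur via average pooling; stride 1 and this padding keep the size for odd kernels.
    blurred = F.avg_pool2d(image.unsqueeze(0), kernel_size, stride=1,
                           padding=kernel_size // 2, count_include_pad=False)[0]
    mask = person_mask.to(image.dtype).unsqueeze(0)  # (1, H, W), broadcasts over channels
    return mask * image + (1 - mask) * blurred

# Toy data standing in for a test image and a predicted mask.
torch.manual_seed(0)
toy_image = torch.rand(3, 128, 176)
toy_mask = torch.zeros(128, 176, dtype=torch.int64)
toy_mask[32:96, 60:120] = 1                          # pretend this rectangle is the person

result = blur_background(toy_image, toy_mask)
print(result.shape)
```

With the variables from the visualisation cell, one could try something like `blur_background(x_test_batch[0].cpu(), torch.from_numpy(preds[0]))`. Since the blur acts on each channel independently, it works the same on the BGR images loaded by OpenCV.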

# Coursework

## Task 1: Non-linear Transformations for Representation Learning

PCA is a standard dimensionality reduction technique that uses a linear transformation. In this task we are going to define two autoencoders, one convolutional and one without using any convolutional layer, that are capable of learning a non-linear transformation to reduce the dimensionality of the input MNIST image, and we will compare those autoencoders to PCA. A way to evaluate the quality of the representations produced by both PCA and the autoencoder is to learn a classifier on top of those representations with reduced dimensionality. If the classifier has high accuracy, then the representations can be considered meaningful. In our case, we will use representations with dimensionality 10 and we will use those representations to train a linear classifier, which is defined in the code below.

The given example architectures for both the non-convolutional and the convolutional autoencoder already produce, after training, representations of similar quality to those of PCA. Modify the given architectures and try to increase the accuracy of a linear classifier trained on top of the autoencoder representations. The code given below may help you understand the pipeline.

As in past notebooks, treat the MNIST test set as your validation set. You can use any of the layers and techniques presented in past notebooks, the only constraints are that the non-convolutional autoencoder should not have any Conv2d layer, that the convolutional autoencoder should include Conv2d layers, and that the representation vector should have dimensionality 10.

**Report**:
* Table with the accuracy of `classifier` (defined below) obtained with the representations from your two proposed autoencoder architectures (non-convolutional and convolutional) and with PCA with 10 components, on both the training set and the validation set. Additionally, include in the table the MSE on both the training and validation sets for your non-convolutional autoencoder, your convolutional autoencoder, and the PCA method. State clearly your two final autoencoder architectures and discuss the results.

We will use MNIST for this task. First, we resize all the images to have a resolution of 32x32, which will make the definition of the convolutional autoencoder easier.

In [None]:
# Fix seed for reproducibility
set_seed(42)

# Load MNIST and preprocess: normalize and resize to 32x32 with 1 channel
transform = transforms.Compose([
    transforms.ToTensor(),
])

train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

def resize_dataset(dataset):
    imgs = []
    labels = []
    for img, label in dataset:
        # img shape (1,28,28), convert to numpy
        img_np = img.numpy().transpose(1,2,0)  # (28,28,1)
        img_resized = resize(img_np, (32,32,1), anti_aliasing=True).astype(np.float32)
        imgs.append(img_resized)
        labels.append(label)
    imgs = np.array(imgs)
    labels = np.array(labels)
    return imgs, labels

x_train_32, y_train = resize_dataset(train_dataset)
x_test_32, y_test = resize_dataset(test_dataset)

# Convert numpy arrays to torch tensors
x_train_32 = torch.tensor(x_train_32).permute(0,3,1,2)  # (N,1,32,32)
x_test_32 = torch.tensor(x_test_32).permute(0,3,1,2)
y_train = torch.tensor(y_train).long()
y_test = torch.tensor(y_test).long()

train_loader = DataLoader(TensorDataset(x_train_32, x_train_32), batch_size=128, shuffle=True)
test_loader = DataLoader(TensorDataset(x_test_32, x_test_32), batch_size=128)

You can modify the code below to define your non-convolutional autoencoder. You can use any layer you want apart from `Conv2d` layers for this autoencoder.

In [None]:
### Non-convolutional Autoencoder ###
class NonConvAutoencoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.encoder = nn.Sequential(
            nn.Linear(32*32, 10),  # Representation with dimensionality 10
        )
        self.decoder = nn.Sequential(
            nn.Linear(10, 32*32),
            nn.Unflatten(1, (1, 32, 32)),
        )

    def forward(self, x):
        x = self.flatten(x)
        z = self.encoder(x)
        x_recon = self.decoder(z)
        return x_recon

    def encode(self, x):
        x = self.flatten(x)
        z = self.encoder(x)
        return z

In [None]:
set_seed(42)

autoencoder = NonConvAutoencoder()
print(torchinfo.summary(autoencoder, input_size=(1, 1, 32, 32)))
print()

train(
    autoencoder,
    train_loader,
    nn.MSELoss(),
    optim.Adam(autoencoder.parameters()),
    num_epochs = 20,
)

You can modify the code below to define your convolutional autoencoder. For this autoencoder you need to include `Conv2d` layers in your design, but you can use any other layer too. We show an example of a simple convolutional architecture below.

In [None]:
### Convolutional Autoencoder ###
class ConvAutoencoder(nn.Module):
    def __init__(self):
        super().__init__()
        # Encoder
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1),  # output: 32x16x16
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(32*16*16, 10)  # representation 10 dim
        )
        # Decoder
        self.decoder = nn.Sequential(
            nn.Linear(10, 1*16*16),
            nn.Unflatten(1, (1,16,16)),
            nn.Upsample(scale_factor=2, mode='nearest'),  # upsample to 32x32
        )

    def forward(self, x):
        z = self.encoder(x)
        x_recon = self.decoder(z)
        return x_recon

    def encode(self, x):
        return self.encoder(x)

In [None]:
# You can modify the number of epochs or other hyperparameters
set_seed(42)

conv_autoencoder = ConvAutoencoder()
print(torchinfo.summary(conv_autoencoder, input_size=(1, 1, 32, 32)))
print()

train(
    conv_autoencoder,
    train_loader,
    nn.MSELoss(),
    optim.Adam(conv_autoencoder.parameters()),
    num_epochs = 20,
)

Below is the code you will use to train the classifier. We first extract the representations using PCA or either of the two autoencoders we just trained, and then we train the classifier on top of them. The classifier is just a single linear (`nn.Linear`) layer. Better representations should make it easier for this simple classifier to separate the classes and should therefore lead to higher accuracy.

In [None]:
### PCA ###
pca = PCA(n_components=10)
pca.fit(x_train_32.numpy().reshape(len(x_train_32), -1))

## We compute the representations for the different methods
representation_pca_train = pca.transform(x_train_32.numpy().reshape(len(x_train_32), -1))
representation_pca_test = pca.transform(x_test_32.numpy().reshape(len(x_test_32), -1))

# predict_representation is defined at the beginning of this notebook
representation_auto_train = predict_representation(autoencoder, x_train_32)
representation_auto_test = predict_representation(autoencoder, x_test_32)
representation_conv_auto_train = predict_representation(conv_autoencoder, x_train_32)
representation_conv_auto_test = predict_representation(conv_autoencoder, x_test_32)

# Compute PCA reconstruction MSE
reconst_train = pca.inverse_transform(representation_pca_train).reshape(-1,1,32,32)
train_mse_pca = ((reconst_train - x_train_32.numpy())**2).mean()

reconst_test = pca.inverse_transform(representation_pca_test).reshape(-1,1,32,32)
test_mse_pca = ((reconst_test - x_test_32.numpy())**2).mean()

# We print the MSE for PCA, which you need to include in the table
print(f'PCA Train MSE: {train_mse_pca:.4f}')
print(f'PCA Test MSE: {test_mse_pca:.4f}')
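The report also asks for the reconstruction MSE of both autoencoders. The snippet below is a minimal sketch of how that could be computed so that it is comparable with the PCA value above (mean over all pixels of all images). `reconstruction_mse` is a hypothetical helper, not a function from the tutorial, and it assumes the model's forward pass returns a reconstruction with the same shape as its input.

```python
import torch
import torch.nn as nn


@torch.no_grad()
def reconstruction_mse(model: nn.Module, images: torch.Tensor, batch_size: int = 256) -> float:
    """Mean squared error between images and model(images), averaged over all pixels."""
    model.eval()
    device = next(model.parameters()).device
    squared_error = 0.0
    # Process the data in batches so that large datasets fit in memory.
    for start in range(0, len(images), batch_size):
        batch = images[start:start + batch_size].to(device)
        squared_error += ((model(batch) - batch) ** 2).sum().item()
    # Dividing by the total number of pixels matches the `.mean()` used for PCA.
    return squared_error / images.numel()
```

For example, `reconstruction_mse(autoencoder, x_train_32)` and `reconstruction_mse(conv_autoencoder, x_test_32)` would give the training and validation values for the table.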

In [None]:
### Linear classifier to test representation quality, DO NOT MODIFY IT!
class LinearClassifier(nn.Module):
    def __init__(self, input_dim=10, num_classes=10):
        super().__init__()
        self.fc = nn.Linear(input_dim, num_classes)

    def forward(self, x):
        return self.fc(x)

In [None]:
def train_classifier(X_train, y_train, X_val, y_val, epochs=30):
    set_seed(42)

    X_train = torch.tensor(X_train).float()
    X_val = torch.tensor(X_val).float()

    train_ds = TensorDataset(X_train, y_train)
    val_ds = TensorDataset(X_val, y_val)

    train_loader = DataLoader(train_ds, batch_size=128, shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=128)

    model = LinearClassifier(input_dim=X_train.shape[1], num_classes=10).to(DEVICE)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters())

    for epoch in range(epochs):
        model.train()
        for xb, yb in train_loader:
            xb, yb = xb.to(DEVICE), yb.to(DEVICE)
            optimizer.zero_grad()
            logits = model(xb)
            loss = criterion(logits, yb)
            loss.backward()
            optimizer.step()

        # Validation accuracy
        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for xb, yb in val_loader:
                xb, yb = xb.to(DEVICE), yb.to(DEVICE)
                logits = model(xb)
                preds = torch.argmax(logits, dim=1)
                correct += (preds == yb).sum().item()
                total += yb.size(0)
        acc = correct / total

        if (epoch+1) % 10 == 0 or epoch == 0:
            print(f"Epoch [{epoch+1}/{epochs}] Validation Accuracy: {acc*100:.2f}%")

    return acc

print("Training classifier on PCA representations")
acc_pca = train_classifier(representation_pca_train, y_train, representation_pca_test, y_test)
print()

print("Training classifier on Non-Convolutional Autoencoder representations")
acc_auto = train_classifier(representation_auto_train, y_train, representation_auto_test, y_test)
print()

print("Training classifier on Convolutional Autoencoder representations")
acc_conv_auto = train_classifier(representation_conv_auto_train, y_train, representation_conv_auto_test, y_test)
print()

The code below can help you visualize the quality of the learnt representations. t-SNE is a dimensionality reduction technique that leads to nice plots, so we use it to reduce the 10-dimensional representations we just learnt to 2 dimensions and plot them. You do not have to include the figures in the report; they are just a qualitative way for you to judge the quality of your representations.

In [None]:
## We use t-SNE to reduce the features to 2D so that they can be plotted.
## However, t-SNE is tricky to use as a general dimensionality reduction method
## for clustering; its shortcomings are explained here: https://distill.pub/2016/misread-tsne/
## t-SNE: https://en.wikipedia.org/wiki/T-distributed_stochastic_neighbor_embedding

## Use these parameters, the plots are highly dependent on perplexity value
tsne = TSNE(n_components=2, verbose=1, perplexity=40, max_iter=500, n_jobs=-1)
representation_tsne = tsne.fit_transform(representation_auto_test)
plot_representation_label(representation_tsne, y_test)

You can also check how the reconstructed images look with the autoencoders you just trained.

In [None]:
ind = np.random.randint(x_test_32.shape[0])  # random index into the resized test set
## The function below is defined in the tutorial
plot_recons_original(np.expand_dims(x_test_32[ind],0), y_test[ind], conv_autoencoder, size_image=(32,32))

## Task 2: Custom Loss Functions

In image-to-image tasks, researchers have found that one way to improve the robustness of autoencoders is to replace the quadratic error with a loss function that is more robust to outliers.

There is a lot of interest in determining which loss function best suits each specific task. For instance, image super-resolution and image denoising may have different optimal loss functions, even though the models are trained in a very similar manner. In this task, we will therefore train an image denoising model with several loss functions. Sometimes you may need a loss function that PyTorch does not provide. In that case, you can define it yourself as an `nn.Module` subclass and pass it to the training loop in place of a built-in loss such as `nn.MSELoss()`. We now explain how to do that.

You must define some variables:

* **True values** are those that we are aiming to generate, e.g. GT images.
* **Predicted values** are those that the network has generated,  e.g. denoised images.
* **Loss value** is the computed loss between true and predicted values,  e.g. MSE value.

The common structure for the custom loss method is as follows:

In [None]:
class CustomLoss(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, predicted, target):
        # Placeholder: compute and return a scalar loss tensor here
        raise NotImplementedError

We are going to use TinyImageNet and add synthetic noise as before. This task requires a lot of RAM, so please free up memory by restarting the Colab session before starting it.

In [None]:
# download TinyImageNet
! git clone https://github.com/seshuad/IMagenet

In [None]:
class TinyImageNetNoisyDataset(Dataset):
    def __init__(self, root_dir, split='train', transform=None, noise_std=0.2):
        self.root_dir = root_dir
        self.split = split
        self.transform = transform
        self.noise_std = noise_std
        self.image_files = []

        if self.split == 'train':
            # For training data, images are in subfolders per class
            wnids_path = os.path.join(root_dir, 'tiny-imagenet-200', 'wnids.txt')
            with open(wnids_path, 'r') as f:
                classes = [line.strip() for line in f]
            for cls in classes:
                image_dir = os.path.join(root_dir, 'tiny-imagenet-200', split, cls, 'images')
                for img_file in os.listdir(image_dir):
                    if img_file.endswith('.JPEG'):
                        self.image_files.append(os.path.join(image_dir, img_file))

        elif self.split == 'val':
            # For validation data, images are in a single folder and annotations are in a file
            annotations_path = os.path.join(root_dir, 'tiny-imagenet-200', split, 'val_annotations.txt')
            with open(annotations_path, 'r') as f:
                for line in f:
                    img_name, _, _, _, _, _ = line.strip().split('\t')
                    self.image_files.append(os.path.join(root_dir, 'tiny-imagenet-200', split, 'images', img_name))

        else:
            raise ValueError(f"Invalid split: {split}. Use 'train' or 'val'.")

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        img_path = self.image_files[idx]
        img = Image.open(img_path).convert('RGB')

        if self.transform:
            img = self.transform(img)

        noisy_img = img + torch.randn_like(img) * self.noise_std
        noisy_img = torch.clamp(noisy_img, 0, 1)

        return noisy_img, img  # return noisy input and clean target

# Image transformations (resizing and normalization)
transform = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.ToTensor(),
])

# root_dir points to the base directory of the cloned TinyImageNet repository
train_dataset = TinyImageNetNoisyDataset('./IMagenet', transform=transform)
test_dataset = TinyImageNetNoisyDataset('./IMagenet', split='val', transform=transform)

train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=128)
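Before training a denoiser, it can help to know how far the noisy inputs already are from the clean targets, since a useful model should reach a lower MSE than this baseline. The sketch below reproduces the noise model of `TinyImageNetNoisyDataset` (Gaussian noise with standard deviation 0.2, then clamping to `[0, 1]`) on synthetic mid-grey images. `noisy_input_mse` is a hypothetical helper, and the synthetic batch is only a stand-in for real images.

```python
import torch


def noisy_input_mse(clean: torch.Tensor, noise_std: float = 0.2) -> float:
    """MSE between a clean image batch and its noisy, clamped version."""
    noisy = torch.clamp(clean + torch.randn_like(clean) * noise_std, 0, 1)
    return torch.mean((noisy - clean) ** 2).item()


torch.manual_seed(0)
# Synthetic mid-grey images stand in for real data in this sketch.
clean = torch.full((64, 3, 64, 64), 0.5)
baseline = noisy_input_mse(clean)
print(f"Noisy-input MSE: {baseline:.4f} (noise variance without clamping: {0.2 ** 2:.4f})")
```

Without clamping the expected MSE equals the noise variance, $0.2^2 = 0.04$. Clamping can only move a noisy pixel closer to its clean value, so the measured baseline is at most that.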

We are going to use the UNet architecture for this task, however, you could use any autoencoder of your choice.

In [None]:
class UNet(nn.Module):
    def __init__(self):
        super().__init__()

        # Encoder
        self.conv1 = nn.Conv2d(3, 64, 3, padding=1)
        self.conv2 = nn.Conv2d(64, 128, 3, padding=1)
        self.conv3 = nn.Conv2d(128, 256, 3, padding=1)
        self.conv4 = nn.Conv2d(256, 512, 3, padding=1)
        self.conv5 = nn.Conv2d(512, 1024, 3, padding=1)

        self.pool = nn.MaxPool2d(2, 2)
        self.dropout = nn.Dropout(0.5)

        # Decoder
        self.up6 = nn.Conv2d(1024, 512, 3, padding=1)
        self.up7 = nn.Conv2d(512, 256, 3, padding=1)
        self.up8 = nn.Conv2d(256, 128, 3, padding=1)
        self.up9 = nn.Conv2d(128, 64, 3, padding=1)

        self.up_sample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

        self.conv_final = nn.Conv2d(64, 3, 3, padding=1)

        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        # Encoder
        c1 = self.relu(self.conv1(x))
        p1 = self.pool(c1)

        c2 = self.relu(self.conv2(p1))
        p2 = self.pool(c2)

        c3 = self.relu(self.conv3(p2))
        p3 = self.pool(c3)

        c4 = self.relu(self.conv4(p3))
        d4 = self.dropout(c4)
        p4 = self.pool(d4)

        c5 = self.relu(self.conv5(p4))
        d5 = self.dropout(c5)

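        # Decoder (each up-conv layer is applied twice per stage, before and after
        # the skip concatenation, so both applications share the same weights)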
        up6 = self.up_sample(d5)
        up6 = self.relu(self.up6(up6))
        merge6 = torch.cat([d4, up6], dim=1)
        c6 = self.relu(self.up6(merge6))

        up7 = self.up_sample(c6)
        up7 = self.relu(self.up7(up7))
        merge7 = torch.cat([c3, up7], dim=1)
        c7 = self.relu(self.up7(merge7))

        up8 = self.up_sample(c7)
        up8 = self.relu(self.up8(up8))
        merge8 = torch.cat([c2, up8], dim=1)
        c8 = self.relu(self.up8(merge8))

        up9 = self.up_sample(c8)
        up9 = self.relu(self.up9(up9))
        merge9 = torch.cat([c1, up9], dim=1)
        c9 = self.relu(self.up9(merge9))

        output = self.conv_final(c9)
        output = torch.sigmoid(output)  # constrain output to [0,1]

        return output

The following example shows how to train the UNet architecture with a custom MSE loss function:

In [None]:
class CustomMSELoss(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, predicted, target):
        return torch.mean((predicted - target) ** 2)

In [None]:
set_seed(42)

model = UNet()
print(torchinfo.summary(model, input_size=(1, 3, 64, 64)))
print()

train(
    model,
    train_loader,
    CustomMSELoss(),
    optim.Adam(model.parameters(), lr=1e-4),
    num_epochs = 10,
)
print()

evaluate(model, test_loader, CustomMSELoss())

Use the following code for visualising examples:

In [None]:
def show_noisy_denoised_clean(idx=0):
    model.eval()
    noisy_img, clean_img = test_dataset[idx]
    noisy_img_t = noisy_img.unsqueeze(0).to(DEVICE)

    with torch.no_grad():
        denoised_img = model(noisy_img_t).cpu().squeeze(0)

    # Convert tensors to numpy arrays for plotting
    noisy_np = noisy_img.permute(1, 2, 0).numpy()
    denoised_np = denoised_img.permute(1, 2, 0).numpy()
    clean_np = clean_img.permute(1, 2, 0).numpy()

    # Plot side by side
    fig, axs = plt.subplots(1, 3, figsize=(12,4))
    axs[0].imshow(np.clip(noisy_np, 0, 1))
    axs[0].set_title('Noisy Input')
    axs[0].axis('off')
    axs[1].imshow(np.clip(denoised_np, 0, 1))
    axs[1].set_title('Denoised Output')
    axs[1].axis('off')
    axs[2].imshow(np.clip(clean_np, 0, 1))
    axs[2].set_title('Clean Target')
    axs[2].axis('off')
    plt.show()

show_noisy_denoised_clean()

Try to use a different loss function and see which one gives you the best result. Some well-known loss functions for image denoising are:

*  Structural Similarity Index ([SSIM](https://en.wikipedia.org/wiki/Structural_similarity))
*  Multiscale Structural Similarity Index ([MS-SSIM](https://ieeexplore.ieee.org/document/1292216?arnumber=1292216&tag=1))
* 1 / [PSNR](https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio)
* MAE
* [L0](https://arxiv.org/pdf/1803.04189.pdf)

Check the [Noise2Noise](https://arxiv.org/pdf/1803.04189.pdf) paper to learn more about alternative losses.
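As an illustration of how two of the listed losses fit the `nn.Module` structure shown earlier, here is a minimal sketch of MAE and 1 / PSNR. The class names, the `max_val` and `eps` parameters, and the clamping of the PSNR are assumptions made for this example, not part of the original notebook. PSNR is computed as $10 \log_{10}(\text{max\_val}^2 / \text{MSE})$.

```python
import torch
import torch.nn as nn


class MAELoss(nn.Module):
    """Mean absolute error, less sensitive to outliers than MSE."""

    def forward(self, predicted, target):
        return torch.mean(torch.abs(predicted - target))


class InversePSNRLoss(nn.Module):
    """1 / PSNR for images in [0, max_val]."""

    def __init__(self, max_val: float = 1.0, eps: float = 1e-8):
        super().__init__()
        self.max_val = max_val
        self.eps = eps

    def forward(self, predicted, target):
        mse = torch.mean((predicted - target) ** 2)
        # eps keeps the logarithm finite when the reconstruction is perfect.
        psnr = 10.0 * torch.log10(self.max_val ** 2 / (mse + self.eps))
        # Clamping avoids dividing by a zero or negative PSNR (assumption of this sketch).
        return 1.0 / torch.clamp(psnr, min=self.eps)
```

Either class can be passed to `train` in place of `CustomMSELoss()`, for example `train(model, train_loader, MAELoss(), optim.Adam(model.parameters(), lr=1e-4), num_epochs=10)`.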

**Report:**
*  In this task, you are asked to build a table containing the MSE results on the test split of models trained with different loss functions. Use two or three different loss functions from the previous list and discuss the differences you observe. Report denoised images to support your arguments. You may need to modify the UNet model definition to use some of the previous losses.
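SSIM can also be written as a differentiable loss. The sketch below is a simplified single-scale variant that assumes a uniform (box) window implemented with average pooling rather than the Gaussian window commonly used, and the usual constants $C_1 = (0.01 L)^2$ and $C_2 = (0.03 L)^2$ for a dynamic range $L$. The class name `SSIMLoss` and its parameters are hypothetical. Treat it as a starting point rather than a reference implementation.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SSIMLoss(nn.Module):
    """1 - mean SSIM, using a uniform (box) window instead of a Gaussian one."""

    def __init__(self, window_size: int = 7, max_val: float = 1.0):
        super().__init__()
        self.window_size = window_size
        self.c1 = (0.01 * max_val) ** 2
        self.c2 = (0.03 * max_val) ** 2

    def _local_mean(self, x):
        # Stride-1 average pooling gives the mean of every window_size x window_size patch.
        return F.avg_pool2d(x, self.window_size, stride=1)

    def forward(self, predicted, target):
        mu_p = self._local_mean(predicted)
        mu_t = self._local_mean(target)
        # Local variances and covariance via E[xy] - E[x]E[y].
        var_p = self._local_mean(predicted * predicted) - mu_p * mu_p
        var_t = self._local_mean(target * target) - mu_t * mu_t
        cov = self._local_mean(predicted * target) - mu_p * mu_t

        ssim_map = ((2 * mu_p * mu_t + self.c1) * (2 * cov + self.c2)) / (
            (mu_p * mu_p + mu_t * mu_t + self.c1) * (var_p + var_t + self.c2)
        )
        # SSIM is 1 for identical images, so 1 - SSIM is minimized by a perfect output.
        return 1.0 - ssim_map.mean()
```

Because the loss is built only from differentiable tensor operations, it can be used directly with the existing `train` function. The window size must not exceed the spatial size of the images.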

Use the following code and write your custom loss function within the provided `CustomLoss` class.



In [None]:
class CustomLoss(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, predicted, target):
        '''
        Define your loss here
        '''
        return 0