# Autoencoders

The objective of this notebook is to build an autoencoder that reconstructs MNIST images.

We will use MLP and convolutional autoencoders and compare their performance. We will also apply residual connections to check whether they improve the quality of the image reconstruction.



In [16]:
# First, let's load the necessary libraries
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# For plotting
import matplotlib.pyplot as plt

The second stage is to load the MNIST dataset; more details about it are available in the CNN notebook.

In [17]:
# MNIST dataset
# Preprocessing: convert each image to a tensor, then normalise it with
# mean 0.5 and standard deviation 0.5.
to_tensor = transforms.ToTensor()
normalise = transforms.Normalize((0.5,), (0.5,))
transform = transforms.Compose([to_tensor, normalise])

# The train and test splits share the same storage location and preprocessing;
# only the `train` flag differs.
mnist_kwargs = dict(root='./data', transform=transform, download=True)
train_dataset = datasets.MNIST(train=True, **mnist_kwargs)
test_dataset = datasets.MNIST(train=False, **mnist_kwargs)

# Shuffled mini-batches of 32 for training; fixed-order batches of 1000 for testing.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1000, shuffle=False)

## Multi-layer perceptron

For our first autoencoder, let's use an MLP and check its performance. We will also have a look at the latent space representation, so we will add both a `forward` method and a `get_encoded_representation` method to our class.

In [18]:
class Autoencoder(nn.Module):
    """Fully connected autoencoder for flattened 28x28 images.

    The encoder narrows 784 -> 256 -> 128 -> 64 -> 12 -> 3 features, and the
    decoder mirrors those widths back up to 784, ending in a Tanh.
    """

    def __init__(self):
        super().__init__()
        widths = [28 * 28, 256, 128, 64, 12, 3]
        # Encoder: the last Linear produces the 3-dimensional encoded representation.
        self.encoder = nn.Sequential(*self._linear_stack(widths))
        # Decoder: same widths in reverse, with a Tanh on the output.
        self.decoder = nn.Sequential(*self._linear_stack(widths[::-1]), nn.Tanh())

    @staticmethod
    def _linear_stack(widths):
        """Return Linear layers for consecutive widths with a ReLU between each pair."""
        layers = []
        for n_in, n_out in zip(widths, widths[1:]):
            layers.append(nn.Linear(n_in, n_out))
            layers.append(nn.ReLU())
        # No activation after the final Linear layer.
        return layers[:-1]

    def forward(self, x):
        """Return the tuple ``(decoded, encoded)`` for the input batch ``x``."""
        latent = self.encoder(x)
        reconstruction = self.decoder(latent)
        return reconstruction, latent

    def get_encoded_representation(self, x):
        """Return only the encoder output for ``x``, without running the decoder."""
        return self.encoder(x)

Now let's define the loss and the optimizer.

In [19]:
# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = Autoencoder().to(device)

# Other hyperparameters
num_epochs = 15
learning_rate = 1e-3

# Loss and optimizer
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Training loop: the model learns to reproduce its own input, so the labels
# that come with each batch are ignored.
for epoch in range(num_epochs):
    for data in train_loader:
        img, _ = data
        # Flatten each image into a vector, as expected by the Linear layers.
        img = img.view(img.size(0), -1).to(device)

        # Forward pass. Autoencoder.forward returns (decoded, encoded); only
        # the reconstruction enters the loss, so the tuple must be unpacked
        # before it is handed to the criterion.
        output, _ = model(img)
        loss = criterion(output, img)

        # Backward pass and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Note: this reports the loss of the last mini-batch of the epoch only.
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')

Now that the training is finished, let's plot the results.

In [None]:
dataiter = iter(test_loader)
images, labels = next(dataiter)

# Get sample outputs
images = images.view(images.size(0), -1).to(device)
model.eval()
# Inference only: skip autograd bookkeeping for the whole test batch.
with torch.no_grad():
    # Autoencoder.forward returns (decoded, encoded); only the reconstruction
    # is plotted here.
    output, _ = model(images)
images = images.cpu().view(-1, 1, 28, 28)
output = output.cpu().view(-1, 1, 28, 28)

# Plot the first 10 test images (top row) and their reconstructions (bottom row)
fig, axes = plt.subplots(nrows=2, ncols=10, sharex=True, sharey=True, figsize=(20,4))
# A separate loop name keeps `images` bound to the original batch.
for batch, row in zip([images[:10], output[:10]], axes):
    for img, ax in zip(batch, row):
        ax.imshow(img.reshape((28, 28)).numpy(), cmap='gray')
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
plt.show()

## Convolutional neural networks

Now, let's compare the results with a CNN.

In [6]:
# CNN Autoencoder Model
class CNNAutoencoder(nn.Module):
    """Convolutional autoencoder for single-channel 28x28 images.

    Two stride-2 convolutions and a 7x7 convolution reduce the input to a
    64-channel 1x1 code; three transposed convolutions mirror the path back
    to a single-channel 28x28 output passed through a Tanh.
    """

    def __init__(self):
        super().__init__()

        # Encoder layers, listed in the order they are applied.
        down_to_14 = nn.Conv2d(1, 16, kernel_size=3, stride=2, padding=1)   # -> [batch, 16, 14, 14]
        down_to_7 = nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1)   # -> [batch, 32, 7, 7]
        to_code = nn.Conv2d(32, 64, kernel_size=7)                          # -> [batch, 64, 1, 1]
        self.encoder = nn.Sequential(
            down_to_14,
            nn.ReLU(),
            down_to_7,
            nn.ReLU(),
            to_code,
        )

        # Decoder layers: the transposed counterparts, in reverse order.
        from_code = nn.ConvTranspose2d(64, 32, kernel_size=7)               # -> [batch, 32, 7, 7]
        up_to_14 = nn.ConvTranspose2d(
            32, 16, kernel_size=3, stride=2, padding=1, output_padding=1
        )                                                                   # -> [batch, 16, 14, 14]
        up_to_28 = nn.ConvTranspose2d(
            16, 1, kernel_size=3, stride=2, padding=1, output_padding=1
        )                                                                   # -> [batch, 1, 28, 28]
        self.decoder = nn.Sequential(
            from_code,
            nn.ReLU(),
            up_to_14,
            nn.ReLU(),
            up_to_28,
            nn.Tanh(),
        )

    def forward(self, x):
        """Encode ``x`` and return its reconstruction."""
        return self.decoder(self.encoder(x))

In [7]:
# Hyperparameters
num_epochs = 5
learning_rate = 1e-3

# Fresh CNN autoencoder with its loss and optimizer
model = CNNAutoencoder().to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Train the model: images keep their [batch, 1, 28, 28] layout (no flattening),
# and the labels that come with each batch are ignored.
for epoch in range(num_epochs):
    for img, _ in train_loader:
        img = img.to(device)

        # Clear stale gradients, then compute the reconstruction loss.
        optimizer.zero_grad()
        loss = criterion(model(img), img)

        # Backpropagate and update the weights.
        loss.backward()
        optimizer.step()

    # Reports the loss of the last mini-batch of the epoch.
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')

In [None]:
dataiter = iter(test_loader)
images, labels = next(dataiter)

# Get sample outputs
images = images.to(device)
model.eval()
# Inference only: skip autograd bookkeeping for the whole test batch.
with torch.no_grad():
    output = model(images)
images = images.cpu().view(-1, 1, 28, 28)
output = output.cpu().view(-1, 1, 28, 28)

# Plot the first 10 test images (top row) and their reconstructions (bottom row)
fig, axes = plt.subplots(nrows=2, ncols=10, sharex=True, sharey=True, figsize=(20,4))
# A separate loop name keeps `images` bound to the original batch.
for batch, row in zip([images[:10], output[:10]], axes):
    for img, ax in zip(batch, row):
        ax.imshow(img.reshape((28, 28)).numpy(), cmap='gray')
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
plt.show()

# Residual neural networks

Now let's check whether we can further improve the reconstruction performance by employing residual neural networks.

TODO