In this notebook we will practice working with autoencoders.

# Installations and imports

In [None]:
! pip install -q pytorch_lightning
! pip install -q torchvision

In [None]:
from torch.utils.data import random_split, DataLoader
from torchvision.datasets import MNIST

In [None]:
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import transforms
import matplotlib.pyplot as plt
%matplotlib inline

# Simple AutoEncoder implementation

First, here is an example of a simple autoencoder (AE) implementation.

In [None]:
# Encoder

class SimpleEncoder(nn.Module):
    """Single-layer fully connected encoder.

    Every input sample is flattened and passed through one linear layer
    followed by a ReLU, giving a ``code_size``-dimensional representation.
    """

    def __init__(self, input_shape, code_size):
        """Create the encoder.

        Args:
            input_shape: Per-sample shape of the input, e.g. ``(1, 28, 28)``.
            code_size: Number of features in the produced representation.
        """
        super().__init__()
        self.input_shape = input_shape
        self.code_size = code_size

        # Number of scalars in one sample, i.e. the product of all dimensions.
        self.flattened_size = torch.Size(input_shape).numel()

        self.input_to_representation = nn.Linear(
            in_features=self.flattened_size,
            out_features=self.code_size,
        )

    def forward(self, image_batch):
        """Encode ``image_batch`` into a ``(-1, code_size)`` tensor."""
        flat_batch = image_batch.view(-1, self.flattened_size)
        pre_activation = self.input_to_representation(flat_batch)
        return F.relu(pre_activation)

In [None]:
# Decoder

class SimpleDecoder(nn.Module):
    """Single-layer fully connected decoder.

    Maps a ``code_size``-dimensional representation back to a tensor of
    shape ``(-1, *input_shape)`` with one linear layer followed by a
    sigmoid, so every output value lies in ``[0, 1]``.
    """

    def __init__(self, input_shape, code_size):
        """Create the decoder.

        Args:
            input_shape: Per-sample shape of the reconstructed output,
                e.g. ``(1, 28, 28)``.
            code_size: Number of features in the incoming representation.
        """
        super().__init__()
        self.input_shape = input_shape
        self.code_size = code_size

        # Calculate the flattened size
        self.flattened_size = 1
        for x in self.input_shape:
            self.flattened_size *= x

        self.representation_to_output = nn.Linear(self.code_size, self.flattened_size)

    def forward(self, representation):
        """Decode ``representation`` into a batch shaped like the input."""
        # torch.sigmoid replaces the deprecated F.sigmoid alias (older PyTorch
        # releases emit a deprecation warning for it); the result is identical.
        flat_reconstructed = torch.sigmoid(self.representation_to_output(representation))
        reconstructed = flat_reconstructed.view(-1, *self.input_shape)
        return reconstructed

In [None]:
# AutoEncoder

class SimpleAutoEncoder(pl.LightningModule):
    """Autoencoder trained to reproduce its input under an MSE loss.

    The module chains ``SimpleEncoder`` and ``SimpleDecoder``. Subclasses may
    replace ``self.encoder`` / ``self.decoder`` after calling the parent
    constructor; the training, validation and test steps only rely on
    calling the module itself.
    """

    def __init__(self, input_shape, code_size):
        """Create the autoencoder.

        Args:
            input_shape: Per-sample shape of the images, e.g. ``(1, 28, 28)``.
            code_size: Size of the bottleneck representation.
        """
        super().__init__()

        self.save_hyperparameters()  # save input_shape, code_size

        self.input_shape = input_shape
        self.code_size = code_size

        self.encoder = SimpleEncoder(input_shape, code_size)
        self.decoder = SimpleDecoder(input_shape, code_size)

    def forward(self, image_batch):
        """Return the reconstruction of ``image_batch``."""
        return self.decoder(self.encoder(image_batch))

    def _reconstruction_loss(self, batch):
        """Compute the MSE between the images in ``batch`` and their reconstruction.

        ``batch`` is indexed with ``[0]`` to obtain the images; any further
        entries (e.g. labels) are ignored.
        """
        batch_images = batch[0]
        # Call the module rather than .forward() so registered hooks run too.
        reconstructed_images = self(batch_images)
        return F.mse_loss(reconstructed_images, batch_images)

    def training_step(self, batch, batch_idx):
        """Log and return the reconstruction loss used for optimisation."""
        loss = self._reconstruction_loss(batch)
        self.log('train_loss', loss, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log the reconstruction loss of a validation batch as ``val_loss``."""
        loss = self._reconstruction_loss(batch)
        self.log('val_loss', loss)

    def test_step(self, batch, batch_idx):
        """Log the reconstruction loss of a test batch as ``test_loss``."""
        loss = self._reconstruction_loss(batch)
        self.log('test_loss', loss)

    def configure_optimizers(self):
        """Optimise all parameters with Adam at its default settings."""
        return optim.Adam(self.parameters())

In [None]:
# Data module

class MNISTDataModule(pl.LightningDataModule):
    """Data module serving MNIST train/validation/test loaders.

    The 60000 training images are split at random into 55000 training and
    5000 validation samples; images are converted with ``ToTensor`` only.
    """

    def __init__(self, data_dir='./', batch_size=64, num_workers=4):
        """Store loader settings.

        Args:
            data_dir: Directory the dataset is downloaded to / read from.
            batch_size: Batch size used by all three loaders.
            num_workers: Worker processes used by all three loaders.
        """
        super().__init__()

        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers

        # We hardcode dataset specific stuff here.
        self.num_classes = 10
        self.dims = (1, 28, 28)
        self.transform = transforms.Compose([
            transforms.ToTensor()
        ])

    def size(self, dim=None):
        """Return ``self.dims``, or a single entry of it when ``dim`` is given.

        Defined here because the notebook calls ``mnist_dm.size()`` and newer
        pytorch_lightning releases no longer provide this method on the base
        class.
        """
        if dim is not None:
            return self.dims[dim]
        return self.dims

    def prepare_data(self):
        """Download both MNIST splits into ``data_dir``."""
        # download
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        """Create the datasets needed for ``stage`` (all of them if ``None``)."""
        # Assign train/val datasets for use in dataloaders
        if stage == 'fit' or stage is None:
            mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
            self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])

        # Assign test dataset for use in dataloader(s)
        if stage == 'test' or stage is None:
            self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        """Return the training loader; samples are reshuffled every epoch."""
        # random_split only fixes one permutation; shuffle=True makes each
        # epoch see the training samples in a new order.
        return DataLoader(
            self.mnist_train,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            pin_memory=True,
        )

    def val_dataloader(self):
        """Return the validation loader (fixed order)."""
        return DataLoader(self.mnist_val, batch_size=self.batch_size, num_workers=self.num_workers)

    def test_dataloader(self):
        """Return the test loader (fixed order)."""
        return DataLoader(self.mnist_test, batch_size=self.batch_size, num_workers=self.num_workers)

In [None]:
# Configuration

# The install cell does not pin pytorch_lightning, so the Trainer arguments
# use the current names: `gpus` -> accelerator/devices,
# `progress_bar_refresh_rate` -> TQDMProgressBar callback,
# `weights_summary="full"` -> ModelSummary(max_depth=-1) callback.
trainer_kwargs = {
    'accelerator': 'gpu',
    'devices': 1,
    'max_epochs': 10,
    'precision': 16,
    'callbacks': [
        pl.callbacks.TQDMProgressBar(refresh_rate=5),
        pl.callbacks.ModelSummary(max_depth=-1),
    ],
}

mnist_dm = MNISTDataModule()

ae_kwargs = {
    # `dims` is set in MNISTDataModule.__init__; reading the attribute avoids
    # the `size()` accessor that newer pytorch_lightning releases dropped.
    'input_shape': mnist_dm.dims,
    'code_size': 128
}

In [None]:
# Training

# Build the baseline autoencoder and a trainer from the shared kwargs, then
# fit on the MNIST data module. `model` and `trainer` are reused by the
# evaluation and plotting cells below.
model = SimpleAutoEncoder(**ae_kwargs)
trainer = pl.Trainer(**trainer_kwargs)
trainer.fit(model, mnist_dm)

In [None]:
# Evaluation

model.eval()
# trainer.test returns a list of metric dicts; keep only the scalar logged as
# 'test_loss', matching how the later evaluation cells read the result.
test_loss = trainer.test(model, mnist_dm)[0]['test_loss']

In [None]:
# Plots
trans = transforms.ToPILImage()


def model_plots(model, num_images=4):
    """Show validation images next to their reconstructions.

    Takes the first batch of ``mnist_dm.val_dataloader()`` and draws one
    figure per image: the input on the left, the model output on the right.

    Args:
        model: Module mapping an image batch to a batch of the same shape.
        num_images: How many images of the batch to display.
    """
    model.eval()
    batch = next(iter(mnist_dm.val_dataloader()), None)
    if batch is None:
        return

    original_imgs = batch[0]
    # Feed the batch on whatever device the weights live on, so the function
    # works whether or not the model is still on the GPU after training.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else original_imgs.device
    with torch.no_grad():  # plotting needs no gradients
        outputs = model(original_imgs.to(device))
    # Values outside [0, 1] would wrap around when converted to 8-bit pixels.
    outputs = outputs.float().clamp(0, 1).cpu()

    shown_pairs = zip(original_imgs[:num_images], outputs[:num_images])
    for original, reconstructed in shown_pairs:
        f = plt.figure()
        f.add_subplot(1, 2, 1)
        plt.imshow(trans(original).convert("RGB"))
        plt.title('origin')
        plt.axis('off')
        f.add_subplot(1, 2, 2)
        plt.imshow(trans(reconstructed).convert("RGB"))
        plt.title('AE')
        plt.axis('off')
        plt.show(block=True)

In [None]:
# Plot reconstructions produced by the current `model`.
model_plots(model)

It looks like even the simplest autoencoder works.

# Task 1 (1/5 points)

Now it's your turn! The first task is to plot how the test loss depends on the code size (leave the rest of the parameters as in the example).

In [None]:
# Code sizes to compare; used as the x values of the plot below.
code_size_list = [32, 64, 128, 256, 512, 1024]

In [None]:
test_loss_list = <<your code here>>

In [None]:
# Test loss as a function of the code size.
plt.plot(code_size_list, test_loss_list)
plt.xlabel('code size')
# The loss is on the vertical axis; a second xlabel call would overwrite
# the 'code size' label and leave the y axis unlabeled.
plt.ylabel('loss')
plt.title('Dependence loss on code size')

Please note that with an increase in the size of the code, the loss falls, but if the size is too large, the loss begins to grow for the selected architecture.

# Task 2 (3/5 points)

Create your own encoder and decoder models, use convolutional layers. Do not change the training parameters and try to get the smallest loss!

In [None]:
# Encoder

class CoolEncoder(nn.Module):
    <<your code here>>


In [None]:
## Decoder

class CoolDecoder(nn.Module):
    <<your code here>>
    

In [None]:
# Shape contract of the custom models: the encoder maps a (1, 1, 28, 28) batch
# to (1, code_size) and the decoder maps (1, code_size) back to (1, 1, 28, 28).
# `mnist_dm.dims` is read directly because newer pytorch_lightning releases no
# longer provide `size()` on data modules.
for code_size in (128, 10):
    encoder = CoolEncoder(input_shape=mnist_dm.dims, code_size=code_size)
    assert encoder(torch.rand([1, 1, 28, 28])).size() == torch.Size([1, code_size])

for code_size in (128, 10):
    decoder = CoolDecoder(input_shape=mnist_dm.dims, code_size=code_size)
    assert decoder(torch.rand((1, code_size))).size() == torch.Size([1, 1, 28, 28])

In [None]:
# AutoEncoder

class CoolAutoEncoder(SimpleAutoEncoder):
    """``SimpleAutoEncoder`` that uses ``CoolEncoder`` and ``CoolDecoder``.

    Training, validation and test logic are inherited unchanged; only the
    two sub-modules differ.
    """

    def __init__(self, input_shape, code_size):
        """Build the parent module, then swap in the custom sub-modules."""
        super().__init__(input_shape, code_size)
        shared_kwargs = dict(input_shape=input_shape, code_size=code_size)
        self.encoder = CoolEncoder(**shared_kwargs)
        self.decoder = CoolDecoder(**shared_kwargs)

In [None]:
# Training

# Same kwargs as the baseline run, so only the architecture differs.
# A fresh Trainer is created for this model.
model = CoolAutoEncoder(**ae_kwargs)
trainer = pl.Trainer(**trainer_kwargs)
trainer.fit(model, mnist_dm)

In [None]:
# Evaluation

model.eval()
# The result of trainer.test is indexed as a list of metric dicts; the value
# logged under 'test_loss' is what the grading cell below compares.
test_loss = trainer.test(model, mnist_dm)[0]['test_loss']

In [None]:
# Assignment

# Map the test loss to a grade. Thresholds are checked from best to worst.
# NOTE(review): 0.065 is an order of magnitude above the other thresholds --
# confirm it is not meant to be 0.0065.
if test_loss <= 0.002:
    print('BEYOND GODLIKE!!! (3 point)')
elif test_loss <= 0.0035:
    print('GODLIKE! (2.5 point)')
elif test_loss <= 0.005:
    print('Unstoppable (2 point)')
elif test_loss <= 0.065:
    print('Dominating (1 point)')
else:
    # Every loss that earns no points fails here. A plain `else` also catches
    # values that previously slipped through all branches without any
    # message: losses in (0.065, 0.08] and NaN.
    print('Try again =)')
    # Raise explicitly: `assert False` is removed when Python runs with -O.
    raise AssertionError(f'test_loss={test_loss} is above the 0.065 threshold')

In [None]:
# Plots

# Plot reconstructions produced by the current `model`.
model_plots(model)

# Denoising


In [None]:
# Standard deviation passed to noise_generator wherever noise is added.
STD = 0.005
# CUDA device when one is available, otherwise the CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
def noise_generator(size, std):
    """Sample zero-mean Gaussian noise.

    Args:
        size: Shape of the tensor to generate.
        std: Standard deviation of every element.

    Returns:
        A tensor of shape ``size`` whose elements are drawn from a normal
        distribution with mean 0 and standard deviation ``std``.
    """
    gaussian_noise = torch.normal(mean=0, std=std, size=size)
    return gaussian_noise

# Task 3 (1/5 points)

Fill in the gaps to obtain a denoising model (you could reuse the previous models).

In [None]:
class DenoisingAutoEncoder(SimpleAutoEncoder):
    def __init__(self, input_shape, code_size):
        super().__init__(input_shape, code_size)
        self.encoder = <<your code here>>
        self.decoder = <<your code here>>

    def forward(self, image_batch):
        if self.training:
            noise = noise_generator(image_batch.size(), STD)
            if image_batch.is_cuda:
                noise = noise.cuda()
            image_batch = image_batch + noise
        x = self.encoder(image_batch)
        return self.decoder(x)

In [None]:
# Train the denoising model with the same kwargs as the earlier runs.
model = DenoisingAutoEncoder(**ae_kwargs)
trainer = pl.Trainer(**trainer_kwargs)
trainer.fit(model, mnist_dm)

In [None]:
# NOTE(review): forward() only adds noise when self.training is True, but the
# Lightning test loop presumably switches the module to eval mode itself, in
# which case this train() call does not make the test run on noisy inputs --
# confirm the intent.
model.train()
test_loss = trainer.test(model, mnist_dm)[0]['test_loss']
# The task is passed when the test reconstruction loss is at most 0.003.
assert test_loss <= 0.003

In [None]:
def denoising_model_plots(model, num_images=4):
    """Show noised inputs, denoised outputs and clean originals side by side.

    Takes the first batch of ``mnist_dm.val_dataloader()``, adds Gaussian
    noise with standard deviation ``STD`` and draws one three-panel figure
    per image.

    Args:
        model: Module mapping an image batch to a batch of the same shape.
        num_images: How many images of the batch to display.
    """
    model.eval()
    batch = next(iter(mnist_dm.val_dataloader()), None)
    if batch is None:
        return

    original_imgs = batch[0]
    noised_imgs = original_imgs + noise_generator(original_imgs.size(), STD)

    # Feed the batch on whatever device the weights live on, so the function
    # works whether or not the model is still on the GPU after training.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else noised_imgs.device
    with torch.no_grad():  # plotting needs no gradients
        outputs = model(noised_imgs.to(device))

    # Noise pushes some pixels slightly outside [0, 1]; such values would wrap
    # around when converted to 8-bit pixels. Clamp for display only -- the
    # model above still receives the unclamped noisy batch.
    noised_for_display = noised_imgs.clamp(0, 1)
    outputs = outputs.float().clamp(0, 1).cpu()

    panels = zip(
        noised_for_display[:num_images],
        outputs[:num_images],
        original_imgs[:num_images],
    )
    for noised, denoised, original in panels:
        f = plt.figure()
        f.add_subplot(1, 3, 1)
        plt.imshow(trans(noised).convert("RGB"))
        plt.title('noised origin')
        f.add_subplot(1, 3, 2)
        plt.imshow(trans(denoised).convert("RGB"))
        plt.title('AE denoising')
        f.add_subplot(1, 3, 3)
        plt.imshow(trans(original).convert("RGB"))
        plt.title('origin')
        # Render each figure explicitly, as model_plots does.
        plt.show()

In [None]:
# Plot noised inputs, denoised outputs and originals for the current `model`.
denoising_model_plots(model)

Looks useful.

# Some links

* https://www.youtube.com/watch?v=E2d8NRYt2e4
* https://pytorch-lightning.readthedocs.io/en/stable/notebooks/course_UvA-DL/08-deep-autoencoders.html?highlight=autoencoder