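Hi! In this notebook we will train an autoencoder to encode and decode images of the digit 1 only. Then we will feed it images of the digit 8, which are "anomalous" with respect to the 1s it was trained on, and compare the original images with their reconstructions. Because the model has only learned to reconstruct 1s, we need the anomaly score (the reconstruction error) for 8 to be significantly higher than for 1.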

In [None]:
! pip install pytorch-lightning

In [None]:
import os
import numpy as np
import torch
from torch import nn
from torch.autograd import Variable
from torch.utils.data import DataLoader,Dataset
from torchvision import transforms
from torchvision.datasets import MNIST
import pylab
import matplotlib.pyplot as plt
import pytorch_lightning as pl
import torch.optim as optim

# Set params

In [None]:
# Mini-batch size used by the training DataLoader.
BATCH_SIZE = 16
Z_DIM = 64  # autoencoder code size
# Use the GPU only when one is actually present, so the notebook also runs on
# CPU-only machines instead of failing when the trainer/model asks for 'cuda'.
CUDA = torch.cuda.is_available()

# Datasets

In [None]:
class MNIST_anom_dataset(Dataset):
    """In-memory subset of a labelled dataset, restricted to chosen labels.

    Args:
        datasets: Indexable, sized collection whose items are
            ``(image, label, ...)`` sequences (e.g. a torchvision ``MNIST``).
        labels: Labels to keep; an item is kept when ``item[1] in labels``.

    Attributes:
        dataset: List of the kept images, in their original order.
        labels: The ``labels`` argument, stored as given.
        len_oneclass: ``len(dataset) // 10`` (kept for backward compatibility).
    """

    def __init__(self, datasets, labels:list):
        self.dataset = []
        for index in range(len(datasets)):
            # Fetch each sample exactly once: indexing the wrapped dataset can
            # be expensive (it may load and transform the image), and the
            # previous implementation indexed every kept sample twice.
            sample = datasets[index]
            if sample[1] in labels:
                self.dataset.append(sample[0])
        self.labels = labels
        self.len_oneclass = int(len(self.dataset)/10)

    def __len__(self):
        """Return the number of kept images."""
        return len(self.dataset)

    def __getitem__(self, index):
        """Return ``(image, [])``; the empty list is a placeholder target."""
        img = self.dataset[index]
        return img, []

In [None]:
# Convert images to tensors and rescale them with mean 0.5 / std 0.5, i.e. from
# ToTensor's [0, 1] range to [-1, 1] (the evaluation cell undoes this with
# x / 2 + 0.5).
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, ), (0.5, ))
])

# Training data: only images of the digit 1 ("normal" class).
train_dataset = MNIST('./data', download=True, train=True, transform=transform)
train_1 = MNIST_anom_dataset(train_dataset, [1])
train_loader = DataLoader(train_1, batch_size=BATCH_SIZE, shuffle=True)

# Anomalous data: only images of the digit 8, taken from the test split.
test_dataset = MNIST('./data', train=False, download=True, transform=transform)
test_8 = MNIST_anom_dataset(test_dataset, [8])
# One batch holding every 8: size it from the filtered subset (not the full
# test split) and keep the order deterministic for evaluation.
test_loader = DataLoader(test_8, batch_size=len(test_8), shuffle=False)

# Autoencoder model (Task: 2/3 points)

In [None]:
# Autoencoder exercise template: `encoder` maps an image to a z_dim-sized code
# and `decoder` maps that code back to an image; `forward` chains the two.
# The two `# your code` placeholders must be filled in before this cell runs.
class Autoencoder(nn.Module):
    # z_dim: size of the latent code produced by the encoder.
    def __init__(self, z_dim):
        super(Autoencoder, self).__init__()
        # FOR STUDENTS: make your encoder.
        # size: 28 * 28 -> z_dim
        # note: in this task we do not need a strong complex encoder – linear layers are enough.
        self.encoder = # your code

        # FOR STUDENTS: make your decoder.
        # size: z_dim -> 28 * 28
        # note: use tanh activation at the end (inputs are normalized to [-1, 1],
        # which is tanh's output range)
        # note: in this task we do not need a strong complex decoder – linear layers are enough.
        self.decoder = # your code

    # Returns the reconstruction of x.
    # NOTE(review): the evaluation cell passes flattened (1, 28 * 28) tensors;
    # whether training batches arrive flattened depends on how training_step /
    # the encoder are implemented – confirm.
    def forward(self, x):
        z = self.encoder(x) # encoding
        xhat = self.decoder(z) # decoding
        return xhat # image after autoencoder

# Trainer (Task: 1/3 points)

In [None]:
class AutoencoderTrainer(pl.LightningModule):
    def __init__(self, model):
        super().__init__()
        self.model = model
        self.loss = nn.MSELoss()

    def training_step(self, batch, batch_idx):
        # FOR STUDENTS: make the training step
        # loss: MSE between original image (x) 
        # and image after the autoencoder (x and xhat) 
        
        # your code
        loss =  # your code
        self.log('train_loss', loss)
        return loss
    
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(
            model.parameters(),
            lr=3.0e-4,
            weight_decay=1e-5
        )
        return optimizer

# Train model

In [None]:
# Build the network to train, sized by the configured code dimension.
model = Autoencoder(Z_DIM)

# One device either way; the CUDA flag decides between GPU and CPU.
accelerator = "gpu" if CUDA else "cpu"
trainer = pl.Trainer(devices=1, accelerator=accelerator, max_epochs=10)

# Wrap the network in its LightningModule and run the training loop.
lightning_module = AutoencoderTrainer(model)
trainer.fit(lightning_module, train_loader)

# Result

In [None]:
def anomaly_score(x, xhat):
    """Return the summed absolute reconstruction error between x and xhat.

    Args:
        x: Original image(s) as an array-like.
        xhat: Reconstruction(s), same shape as ``x``.

    Returns:
        Sum of ``|x - xhat|`` over all elements. For the single-row
        ``(1, 28 * 28)`` batches used in this notebook this equals the score
        of that one image.
    """
    # Operate on the arguments themselves; the previous version indexed them
    # with a notebook-global `i` and failed unless another cell had defined it.
    abs_diff = np.abs(np.asarray(x) - np.asarray(xhat))
    return np.sum(abs_diff)

In [None]:
def _score_dataset(model, dataset, device):
    """Run the autoencoder on every image of `dataset` and score it.

    Args:
        model: Trained autoencoder, already on `device`.
        dataset: Iterable of ``(img, _)`` pairs with tensor images.
        device: Device string the inputs are moved to ('cuda' or 'cpu').

    Returns:
        List of ``(x, xhat, score)`` tuples: the flattened input and its
        reconstruction as numpy arrays rescaled by ``v / 2 + 0.5``, plus the
        value of ``anomaly_score(x, xhat)``.
    """
    results = []
    # Inference only: no gradients are needed, which saves memory and time.
    with torch.no_grad():
        for img, _ in dataset:
            x = img.view(img.size(0), -1).to(device)
            xhat = model(x)
            # Undo the (0.5, 0.5) normalization so values are back in [0, 1].
            x = x.cpu().numpy() / 2 + 0.5
            xhat = xhat.cpu().numpy() / 2 + 0.5
            results.append((x, xhat, anomaly_score(x, xhat)))
    return results


device = 'cuda' if CUDA else 'cpu'
model = model.to(device)
model.eval()  # switch layers such as dropout/batch-norm to inference mode

# Kept at 0 for any helper that indexes the single-row batch via the global `i`.
i = 0

# apply model on 1 digit images
digit_1_results = _score_dataset(model, train_1, device)

# apply model on 8 digit images
digit_8_results = _score_dataset(model, test_8, device)


In [None]:
# Each result tuple is (x, xhat, score); average the scores per digit.
scores_1 = [score for _, _, score in digit_1_results]
scores_8 = [score for _, _, score in digit_8_results]
average_1_score = np.average(scores_1)
average_8_score = np.average(scores_8)

print("average 1 score:", average_1_score)
print("average 8 score:", average_8_score)

# The exercise passes only if 8s score more than five times higher than 1s.
assert average_1_score * 5 < average_8_score
print(f"great - the average score for digit 8 is {average_8_score/average_1_score:.3f} times higher than for 1!")

In [None]:
n = 5  # number of example images shown per digit


def _plot_examples(name, results, n_cols):
    """Plot inputs, reconstructions and absolute differences for one digit.

    Args:
        name: Figure title (e.g. "digit 1").
        results: List of ``(x, xhat, score)`` tuples; ``x`` and ``xhat`` are
            arrays reshapeable to 28 x 28.
        n_cols: Maximum number of examples (columns) to show.
    """
    # Never index past the available results; nothing to draw if there are none.
    n_cols = min(n_cols, len(results))
    if n_cols <= 0:
        return

    fig, axes = plt.subplots(3, n_cols, figsize=(12, 6), squeeze=False)
    for col in range(n_cols):
        x, xhat, score = results[col]
        # Colormaps are passed explicitly instead of calling plt.gray(), which
        # would change the default colormap for every later plot.
        panels = (
            ('x', x, 'gray'),
            ('xhat', xhat, 'gray'),
            ('diff', np.abs(x - xhat), 'jet'),
        )
        for row, (label, image, cmap) in enumerate(panels):
            ax = axes[row, col]
            ax.imshow(image.reshape(28, 28), cmap=cmap)
            if col == 0:
                ax.set_ylabel(label)
            else:
                ax.get_xaxis().set_visible(False)
                ax.get_yaxis().set_visible(False)

        # plot score: the bottom row needs visible axes for the xlabel to show
        bottom = axes[2, col]
        bottom.get_xaxis().set_visible(True)
        bottom.get_yaxis().set_visible(True)
        bottom.set_xlabel(f'score: {score:.3f}')

    fig.suptitle(name)
    plt.show()
    plt.close(fig)


for name, results in zip(
    ["digit 1", "digit 8"],
    [digit_1_results, digit_8_results]
):
    _plot_examples(name, results, n)