In [None]:
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor

# Not included by default in Colab
!pip install torchinfo
from torchinfo import summary

from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler

import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm

# For presentation purposes only (not needed in Colab)
#plt.style.use('notebook.mplstyle')
# Keeps the kernel from dying in notebooks on Windows machines (not needed in Colab)
#import os
#os.environ['KMP_DUPLICATE_LIB_OK']='True'

### Download the MNIST dataset

In [None]:
# MNIST training split: fetched into ./data if it is not there yet,
# with every image converted to a tensor on access.
training_data = datasets.MNIST(root="data", train=True, download=True, transform=ToTensor())

# MNIST test split, stored and transformed the same way.
test_data = datasets.MNIST(root="data", train=False, download=True, transform=ToTensor())

### Use PCA to set a baseline for further comparisons

In [None]:
# Get numpy matrices of the data to use with sklearn.
# Every 28x28 image is flattened into one row and divided by 255. The number
# of rows is inferred (-1) rather than hard-coded, so the cell keeps working
# when only a subset of the data is loaded.
X_train = training_data.data.detach().numpy().reshape(-1, 28*28) / 255
X_test = test_data.data.detach().numpy().reshape(-1, 28*28) / 255
y_train = training_data.targets.detach().numpy()
y_test = test_data.targets.detach().numpy()

# Mean center only (with_std=False: no variance scaling). The mean is
# estimated on the training set and re-used for the test set.
scaler = StandardScaler(with_std=False)
X_train_mc = scaler.fit_transform(X_train)
X_test_mc = scaler.transform(X_test)

# Fit a PCA model on the mean-centered training images
pca = PCA(n_components=100)
pca.fit(X_train_mc);

Next, we reconstruct the original images using various numbers of PCA components and compute the mean squared error (MSE) of the reconstructions.

In [None]:
print('PCA reconstruction error (MSE)')

# PCA scores: project the mean-centered images onto the principal axes
X_train_pca = X_train_mc @ pca.components_.T
X_test_pca = X_test_mc @ pca.components_.T

# Reconstruction errors, one entry per evaluated number of latent dimensions
mse_pca_train, mse_pca_test = [], []

# Numbers of latent dimensions to evaluate: 10, 20, ..., 100
n_latent_dims_pca = np.arange(10, 101, 10)
for n_latent_dims in n_latent_dims_pca:

    # Training images: rebuild them from the first n_latent_dims scores only
    # and record the mean squared deviation from the mean-centered originals
    X_train_decoded = X_train_pca[:, :n_latent_dims] @ pca.components_[:n_latent_dims, :]
    mse_pca_train.append(np.mean(np.square(X_train_mc - X_train_decoded)))

    # Test images: same procedure
    X_test_decoded = X_test_pca[:, :n_latent_dims] @ pca.components_[:n_latent_dims, :]
    mse_pca_test.append(np.mean(np.square(X_test_mc - X_test_decoded)))

    print('d={:d} \t Train: {:1.4f} \t Test: {:1.4f}'.format(n_latent_dims, mse_pca_train[-1], mse_pca_test[-1]))

### Visualize what the MSE values correspond to visually

In [None]:
# Index of the first test image of every digit 0-9. Looked up once and shared
# by all figures below instead of being searched for again in every loop
# iteration.
example_idx = [np.where(y_test == digit)[0][0] for digit in range(10)]

# First figure: the original example images
fig, axs = plt.subplots(1, 10, figsize=[15, 4])

for i, idx_tmp in enumerate(example_idx):

    # Plot the number as an image. The gray scale is pinned to [0, 1] so this
    # row uses the same color mapping as the reconstructions plotted below.
    axs[i].imshow(X_test[idx_tmp, :].reshape(28, 28), cmap=cm.Greys_r, vmin=0, vmax=1)
    axs[i].set(xticks=[], yticks=[])
    if i == 0:
        axs[i].set(ylabel='Original')

# One additional figure per number of PCA components used for reconstruction
for n_latent_dims in [10, 30, 50, 70, 90]:

    # Create a figure window
    fig, axs = plt.subplots(1, 10, figsize=[15, 4])

    for i, idx_tmp in enumerate(example_idx):

        # Decode: map the first n_latent_dims scores back to pixel space ...
        X_decoded_tmp = X_test_pca[idx_tmp, :n_latent_dims] @ pca.components_[:n_latent_dims, :]
        # ... and undo the mean centering applied by the scaler
        X_decoded_tmp = scaler.inverse_transform(X_decoded_tmp.reshape(1, 28*28))

        # Plot the reconstruction as an image
        axs[i].imshow(X_decoded_tmp.reshape(28, 28), cmap=cm.Greys_r, vmin=0, vmax=1)
        axs[i].set(xticks=[], yticks=[])
        if i == 0:
            axs[i].set(ylabel='d={:d}'.format(n_latent_dims))
            

### Define a deep convolutional autoencoder

In [None]:
class AE_ConvNet(nn.Module):
    """Convolutional autoencoder for single-channel 28x28 images.

    The encoder reduces an image to a vector with ``encoded_space_dim``
    entries; the decoder maps such a vector back to a 1x28x28 image. The
    shape annotations next to the layers omit the batch dimension.
    """

    def __init__(self, encoded_space_dim):
        super().__init__()

        # Image --> latent vector
        self.encoder = nn.Sequential(
            # 1x28x28 --> 64x28x28
            nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            # 64x28x28 --> 64x14x14
            nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1),
            nn.ReLU(inplace=True),
            # 64x14x14 --> 32x7x7
            nn.Conv2d(64, 32, kernel_size=3, stride=2, padding=1),
            nn.ReLU(inplace=True),
            # 32x7x7 --> 32*7*7
            nn.Flatten(start_dim=1),
            # 32*7*7 --> 256
            nn.Linear(7 * 7 * 32, 256),
            nn.ReLU(inplace=True),
            # 256 --> encoded_space_dim
            nn.Linear(256, encoded_space_dim),
        )

        # Latent vector --> image (pre-sigmoid values)
        self.decoder = nn.Sequential(
            # encoded_space_dim --> 256
            nn.Linear(encoded_space_dim, 256),
            nn.ReLU(inplace=True),
            # 256 --> 32*7*7
            nn.Linear(256, 7 * 7 * 32),
            nn.ReLU(inplace=True),
            # 32*7*7 --> 32x7x7
            nn.Unflatten(dim=1, unflattened_size=(32, 7, 7)),
            # 32x7x7 --> 64x14x14
            nn.ConvTranspose2d(32, 64, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(inplace=True),
            # 64x14x14 --> 64x28x28
            nn.ConvTranspose2d(64, 64, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(inplace=True),
            # 64x28x28 --> 1x28x28
            nn.ConvTranspose2d(64, 1, kernel_size=3, stride=1, padding=1),
        )

    def forward(self, x):
        """Encode ``x``, decode it again and squash the result into (0, 1)."""
        latent = self.encoder(x)
        reconstruction = self.decoder(latent)
        # The sigmoid keeps every output pixel between 0 and 1
        return torch.sigmoid(reconstruction)

### Initialize the model

In [None]:
# The number of images to process in one batch
# before making a model parameter update.
batch_size = 128

# Create data loaders.
# The training data is reshuffled at the start of every epoch so that the
# batches differ from epoch to epoch; the test data keeps its fixed order.
train_dataloader = DataLoader(training_data, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=batch_size)

# Set parameters
n_latent_dim = 20

# Initialize the model
model = AE_ConvNet(encoded_space_dim=n_latent_dim)
# Print the model summary
print(summary(model, input_size=(batch_size, 1, 28, 28)))

# Get cpu or gpu device for training.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device")

# Move the model to the selected device (GPU when available, otherwise CPU)
model.to(device);

### Define functions for evaluating and visualizing progress

In [None]:
# Training function
def train_epoch(model, device, dataloader, loss_fn, optimizer, scheduler=None):
    """Train ``model`` for one pass over ``dataloader``.

    Every image batch serves both as model input and as reconstruction
    target; the labels delivered by the dataloader are ignored. The loss of
    every 50th batch is printed as a progress indicator.

    Parameters
    ----------
    model : torch.nn.Module
        Network to train; it is put in training mode.
    device : str or torch.device
        Device each image batch is moved to before the forward pass.
    dataloader : torch.utils.data.DataLoader
        Yields ``(image_batch, labels)`` pairs; ``dataloader.dataset`` is
        used for the total sample count in the progress print-out.
    loss_fn : callable
        Called as ``loss_fn(model_output, image_batch)``; must return a
        scalar tensor.
    optimizer : torch.optim.Optimizer
        Updated once per batch.
    scheduler : optional
        Learning-rate scheduler stepped once per batch, right after the
        optimizer. Pass ``None`` (the default) to train without one.

    Returns
    -------
    float
        Unweighted mean of the per-batch losses, or ``nan`` when the
        dataloader yields no batches.
    """
    # Set model in train mode
    model.train()

    # Total number of samples; only needed for the progress print-out
    size = len(dataloader.dataset)

    train_loss = []
    # Samples processed before the current batch. Counting them explicitly
    # keeps the progress figure right even when the batches differ in size.
    n_seen = 0

    # Iterate the dataloader
    for batch, (image_batch, _) in enumerate(dataloader):

        # Move tensor to the proper device
        image_batch = image_batch.to(device)

        # Run the model
        decoded_data = model(image_batch)

        # Evaluate loss
        loss = loss_fn(decoded_data, image_batch)

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if scheduler is not None:
            scheduler.step()

        # Store the loss as a plain Python float (always the same type,
        # detached from the graph and independent of the device)
        batch_loss = loss.item()

        # Print the loss of every 50th batch
        if batch % 50 == 0:
            print(f"loss: {batch_loss:>7f}  [{n_seen:>5d}/{size:>5d}]")

        train_loss.append(batch_loss)
        n_seen += len(image_batch)

    if not train_loss:
        # No batches at all: avoid numpy's "mean of empty slice" warning
        return float("nan")

    return np.mean(train_loss)

# Testing function
def test_epoch(model, device, dataloader, loss_fn):

    # Set the model in evaluation mode
    model.eval()

    # Speed up things by not computing gradients
    with torch.no_grad():

        test_loss = []
        # Iterate the dataloader 
        for batch, (image_batch, _) in enumerate(dataloader):

            # Move tensor to the proper device
            image_batch = image_batch.to(device)

            # Run the model
            decoded_data = model(image_batch)
            
            # Evaluate loss
            loss = loss_fn(decoded_data, image_batch)
            # Detach the loss variable so we can save it
            loss = loss.cpu().detach().numpy()
            test_loss.append(loss.item())

    return np.mean(test_loss)

def plot_ae_outputs(model):
    """Plot one test image per digit above its reconstruction by ``model``.

    Uses the module-level ``test_data`` and ``device``. The top row shows the
    first test image of each digit 0-9, the bottom row the corresponding
    model output. The model is switched to evaluation mode.
    """
    plt.figure(figsize=(16, 4.5))

    # Position of the first test image of every digit
    labels = test_data.targets.numpy()
    first_of_digit = [np.where(labels == digit)[0][0] for digit in range(10)]

    model.eval()
    for col, sample_idx in enumerate(first_of_digit):

        # Single image with a leading batch dimension, on the model's device
        img = test_data[sample_idx][0].unsqueeze(0).to(device)
        with torch.no_grad():
            rec_img = model(img)

        # Row 0: original, row 1: reconstruction. Each row gets its title
        # above the sixth column.
        rows = ((img, 'Original images'), (rec_img, 'Reconstructed images'))
        for row, (picture, heading) in enumerate(rows):
            ax = plt.subplot(2, 10, row * 10 + col + 1)
            plt.imshow(picture.cpu().squeeze().numpy(), cm.Greys_r, vmin=0, vmax=1)
            ax.set(xticks=[], yticks=[])
            if col == 5:
                ax.set_title(heading)

    plt.show()

### Train the network
We will train the whole network (encoder and decoder) at the same time.

In [None]:
# Loss function: the mean squared error between the reconstruction and
# the input image is what the training minimizes
loss_fn = torch.nn.MSELoss()

n_epochs = 15

# Adam optimizer over all model parameters
# NOTE(review): OneCycleLR presumably replaces this lr with its own schedule
# (starting at max_lr / div_factor) - confirm against the PyTorch docs.
optim = torch.optim.Adam(model.parameters(), lr=2e-4)

# Scheduler that changes the learning rate over time. It is sized for
# n_epochs * len(train_dataloader) steps and is stepped once per batch
# inside train_epoch.
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optim,
    max_lr=0.001,
    epochs=n_epochs,
    steps_per_epoch=len(train_dataloader),
    div_factor=10,
    final_div_factor=100,
)

for epoch in range(n_epochs):

    # One pass over the training data, then an evaluation on the test data
    train_loss = train_epoch(model=model, device=device, dataloader=train_dataloader,
                             loss_fn=loss_fn, optimizer=optim, scheduler=scheduler)
    test_loss = test_epoch(model=model, device=device, dataloader=test_dataloader,
                           loss_fn=loss_fn)

    print('\n EPOCH {}/{} \t train loss {} \t test loss {}'.format(epoch + 1, n_epochs, train_loss, test_loss))

    # Show example reconstructions after every epoch
    plot_ae_outputs(model)

### Test to reconstruct a random image
Autoencoders can be used for anomaly detection if they are trained on "normal" data. The logic is that they are only trained to reconstruct the "normal" training data well, and will thus have large reconstruction errors for data that deviates from what was used during training.

In [None]:
# Two axes side by side: network input on the left, network output on the right
fig, axs = plt.subplots(1, 2)

# Input image: random values drawn uniformly from [0, 1)
X_rnd = np.random.rand(28, 28)
#X_rnd = np.zeros([28, 28])
#X_rnd[:, 5:23] = 1.

# Add a leading channel axis --> 1x28x28
X_rnd = X_rnd.reshape(1, 28, 28)

# Show the input image
axs[0].imshow(X_rnd[0, :, :], cmap=cm.Greys_r, vmin=0, vmax=1)
axs[0].set(xticks=[], yticks=[], title='Input')

# Inference only: evaluation mode and no gradient tracking
model.eval()
with torch.no_grad():
    # float32 tensor with an extra batch axis --> 1x1x28x28
    test_img = torch.Tensor(X_rnd).float().unsqueeze(0)
    decoded_img = model(test_img.to(device)).cpu().numpy()

# Show the reconstructed image
axs[1].imshow(decoded_img[0, 0, :, :], cmap=cm.Greys_r, vmin=0, vmax=1)
axs[1].set(xticks=[], yticks=[], title='Output')