#  Fitting a Convolutional Autoencoder (CNN-AE) with Pytorch


## Import libraries, including PyTorch
PyTorch setup guidance is available here: https://pytorch.org/get-started/locally/

In [None]:
import pandas as pd
from tqdm import trange, tqdm
import torch
torch.set_num_threads(2)
import numpy as np
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.nn.functional as F                # contains activation functions, sampling layers etc
import torchvision
from torchvision import transforms
from torchvision.utils import make_grid    
from torch.utils.data import DataLoader        # to load data into batches (for SGD)
import torch.optim as optim                    # For optimization routines such as SGD, ADAM, ADAGRAD, etc
from torch.utils.data import DataLoader,random_split

### Load in the MNIST digits data

Like yesterday, we use the MNIST digits dataset. We split the training dataset into training and validation sets. 

In [None]:
# Directory where the MNIST files are stored (downloaded on first use).
data_dir = 'dataset'

train_dataset = torchvision.datasets.MNIST(data_dir, train=True, download=True)
test_dataset  = torchvision.datasets.MNIST(data_dir, train=False, download=True)

# The only preprocessing step is the conversion of each image to a tensor.
train_transform = transforms.Compose([
    transforms.ToTensor(),
])

test_transform = transforms.Compose([
    transforms.ToTensor(),
])

train_dataset.transform = train_transform
test_dataset.transform = test_transform

m = len(train_dataset)

# Hold out 20% of the training images for validation. The training size is
# obtained by subtraction so that the two lengths always add up to m
# (two independently truncated products need not, and random_split would
# then reject them).
n_val = int(m * 0.2)
n_train = m - n_val
train_data, val_data = random_split(train_dataset, [n_train, n_val])
batch_size = 256

# Reshuffle the training batches every epoch; the validation order does not
# matter. The test loader stays shuffled because a later cell uses its first
# batch as a random sample of test images.
train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True)
valid_loader = torch.utils.data.DataLoader(val_data, batch_size=batch_size)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=True)

## View basic data set information

In [None]:
# Print the training dataset object to inspect its basic information.
print(train_dataset)

### Below we define the encoder and decoder classes

In [None]:
class Encoder(nn.Module):
    """Convolutional encoder mapping an image batch to latent vectors.

    Three strided convolutions reduce a single-channel image to a
    32-channel 3x3 feature map (the layer sizes are chosen for 28x28
    inputs such as MNIST), which is flattened and passed through two
    fully connected layers.

    Args:
        encoded_space_dim: Size of the latent vector produced per image.
        fc2_input_dim: Width of the hidden fully connected layer, i.e. the
            input size of the second linear layer. This argument used to
            be ignored in favour of a hard-coded 128; passing 128 gives
            the same architecture (and state-dict layout) as before.
    """

    def __init__(self, encoded_space_dim, fc2_input_dim):
        super().__init__()

        ### Convolutional section
        # Spatial size for a 28x28 input: 28 -> 14 -> 7 -> 3.
        self.encoder_cnn = nn.Sequential(
            nn.Conv2d(1, 8, 3, stride=2, padding=1),
            nn.ReLU(True),
            nn.Conv2d(8, 16, 3, stride=2, padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(True),
            nn.Conv2d(16, 32, 3, stride=2, padding=0),
            nn.ReLU(True)
        )

        ### Flatten layer
        # Keep the batch dimension and merge (channels, height, width).
        self.flatten = nn.Flatten(start_dim=1)

        ### Linear section
        self.encoder_lin = nn.Sequential(
            nn.Linear(3 * 3 * 32, fc2_input_dim),
            nn.ReLU(True),
            nn.Linear(fc2_input_dim, encoded_space_dim)
        )

    def forward(self, x):
        """Return the latent code of shape (batch, encoded_space_dim) for ``x``."""
        x = self.encoder_cnn(x)
        x = self.flatten(x)
        x = self.encoder_lin(x)
        return x
    
    
class Decoder(nn.Module):
    """Convolutional decoder mapping latent vectors back to images.

    Mirrors the encoder: two fully connected layers expand the latent
    vector to a 32-channel 3x3 feature map, and three transposed
    convolutions upsample it to a single-channel 28x28 image. A sigmoid
    squashes the output into [0, 1].

    Args:
        encoded_space_dim: Size of the latent vector received per image.
        fc2_input_dim: Width of the hidden fully connected layer, i.e. the
            input size of the second linear layer. This argument used to
            be ignored in favour of a hard-coded 128; passing 128 gives
            the same architecture (and state-dict layout) as before.
    """

    def __init__(self, encoded_space_dim, fc2_input_dim):
        super().__init__()

        ### Linear section
        self.decoder_lin = nn.Sequential(
            nn.Linear(encoded_space_dim, fc2_input_dim),
            nn.ReLU(True),
            nn.Linear(fc2_input_dim, 3 * 3 * 32),
            nn.ReLU(True)
        )

        ### Unflatten layer
        # Reshape each flat vector back to (channels, height, width).
        self.unflatten = nn.Unflatten(dim=1, unflattened_size=(32, 3, 3))

        ### Transposed-convolution section
        # Spatial size: 3 -> 7 -> 14 -> 28.
        self.decoder_conv = nn.Sequential(
            nn.ConvTranspose2d(32, 16, 3, stride=2, output_padding=0),
            nn.BatchNorm2d(16),
            nn.ReLU(True),
            nn.ConvTranspose2d(16, 8, 3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm2d(8),
            nn.ReLU(True),
            nn.ConvTranspose2d(8, 1, 3, stride=2, padding=1, output_padding=1)
        )

    def forward(self, x):
        """Return reconstructed images with values in [0, 1] for latent codes ``x``."""
        x = self.decoder_lin(x)
        x = self.unflatten(x)
        x = self.decoder_conv(x)
        x = torch.sigmoid(x)
        return x

### Below we define the loss function, learning rate and latent space dimension


In [None]:
# Reconstruction loss: mean squared error between the decoded output and
# the input image.
loss_fn = torch.nn.MSELoss()

# Learning rate handed to the optimizer below.
lr = 0.001

# Seed the generator before the networks are built so that their initial
# weights are reproducible.
torch.manual_seed(0)

# Dimension of the latent space.
d = 4

encoder = Encoder(encoded_space_dim=d, fc2_input_dim=128)
decoder = Decoder(encoded_space_dim=d, fc2_input_dim=128)

# A single optimizer updates both networks; each one gets its own
# parameter group.
params_to_optimize = [{'params': net.parameters()} for net in (encoder, decoder)]

# NOTE: this assignment rebinds the name `optim`, hiding the `torch.optim`
# module alias created by the imports at the top of the file.
optim = torch.optim.Adam(params_to_optimize, lr=lr, weight_decay=1e-05)

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'Selected device: {device}')

# Both networks must live on the selected device.
encoder.to(device)
decoder.to(device)

### Below we define how the training is done and collect the training loss from each batch of data

In [None]:
### Training function
def train_epoch(encoder, decoder, device, dataloader, loss_fn, optimizer):
    """Train the autoencoder for one pass over ``dataloader``.

    Each batch is encoded, decoded and compared with itself using
    ``loss_fn``; the labels yielded by the loader are ignored because the
    model is trained without supervision. The loss of every batch is
    printed as it is computed.

    Args:
        encoder: Network mapping images to latent codes.
        decoder: Network mapping latent codes back to images.
        device: Device the batches are moved to before the forward pass.
        dataloader: Iterable yielding ``(images, labels)`` pairs.
        loss_fn: Callable computing the reconstruction loss.
        optimizer: Optimizer holding the parameters of both networks.

    Returns:
        The mean of the per-batch losses.
    """
    # Switch both networks to training mode.
    for net in (encoder, decoder):
        net.train()

    batch_losses = []
    for images, _labels in dataloader:
        images = images.to(device)

        # Forward pass: reconstruct the batch from its latent code.
        reconstruction = decoder(encoder(images))
        batch_loss = loss_fn(reconstruction, images)

        # Backward pass and parameter update.
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

        print('\t partial train loss (single batch): %f' % (batch_loss.data))
        batch_losses.append(batch_loss.detach().cpu().numpy())

    return np.mean(batch_losses)

### We also compute the loss over validation data to determine when is a good time to stop training

In [None]:
### Testing function
def test_epoch(encoder, decoder, device, dataloader, loss_fn):
    # Set evaluation mode for encoder and decoder
    encoder.eval()
    decoder.eval()
    with torch.no_grad(): # No need to track the gradients
        # Define the lists to store the outputs for each batch
        conc_out = []
        conc_label = []
        for image_batch, _ in dataloader:
            # Move tensor to the proper device
            image_batch = image_batch.to(device)
            # Encode data
            encoded_data = encoder(image_batch)
            # Decode data
            decoded_data = decoder(encoded_data)
            # Append the network output and the original image to the lists
            conc_out.append(decoded_data.cpu())
            conc_label.append(image_batch.cpu())
        # Create a single tensor with all the values in the lists
        conc_out = torch.cat(conc_out)
        conc_label = torch.cat(conc_label) 
        # Evaluate global loss
        val_loss = loss_fn(conc_out, conc_label)
    return val_loss.data

### Below is a function to plot the images reconstructed by the autoencoder

In [None]:
def plot_ae_outputs(encoder,decoder,n=10):
    """Plot test digits (top row) above their reconstructions (bottom row).

    For each digit ``0 .. n-1`` the first test image with that label is
    passed through the encoder and the decoder. The function reads the
    module-level ``test_dataset`` and ``device``.

    Args:
        encoder: Network mapping images to latent codes.
        decoder: Network mapping latent codes back to images.
        n: Number of digit classes to show, starting from 0. A label that
            does not occur in the test targets raises ``IndexError``.
    """

    def _hide_axes(ax):
        # Images need no ticks or axis labels.
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)

    plt.figure(figsize=(16,4.5))
    targets = test_dataset.targets.numpy()
    # Index of the first test sample of each digit.
    t_idx = {digit: np.where(targets == digit)[0][0] for digit in range(n)}

    # Evaluation mode is loop-invariant, so set it once up front.
    encoder.eval()
    decoder.eval()

    for i in range(n):
        # Add a batch dimension: the networks expect batched input.
        img = test_dataset[t_idx[i]][0].unsqueeze(0).to(device)
        with torch.no_grad():
            rec_img = decoder(encoder(img))

        # Top row: original image.
        ax = plt.subplot(2, n, i + 1)
        plt.imshow(img.cpu().squeeze().numpy(), cmap='gist_gray')
        _hide_axes(ax)
        if i == n//2:
            ax.set_title('Original images')

        # Bottom row: reconstruction.
        ax = plt.subplot(2, n, i + 1 + n)
        plt.imshow(rec_img.cpu().squeeze().numpy(), cmap='gist_gray')
        _hide_axes(ax)
        if i == n//2:
            ax.set_title('Reconstructed images')
    plt.show()

### Below is some code to train for a number of epochs (an epoch is one complete pass through the data) 

To save time, we have commented this out and instead load the trained parameters in a later block. The saved parameters were obtained by training for 20 epochs.

In [None]:
# num_epochs = 20
# diz_loss = {'train_loss':[],'val_loss':[]}
# for epoch in range(num_epochs):
#    train_loss = train_epoch(encoder,decoder,device,
#    train_loader,loss_fn,optim)
#    val_loss = test_epoch(encoder,decoder,device,valid_loader,loss_fn)
#    print('\n EPOCH {}/{} \t train loss {} \t val loss {}'.format(epoch + 1, num_epochs,train_loss,val_loss))
#    diz_loss['train_loss'].append(train_loss)
#    diz_loss['val_loss'].append(val_loss)
#    plot_ae_outputs(encoder,decoder,n=10)

In [None]:
### Code below to save the models

In [None]:
# torch.save(encoder.state_dict(), './saved_models/cnn_ae_encoder')
# torch.save(decoder.state_dict(), './saved_models/cnn_ae_decoder')

### Below we load in the pre-trained models

In [None]:
# Rebuild both networks with the architecture used when the parameters
# were saved, then restore the saved parameters.
encoder = Encoder(encoded_space_dim=d,fc2_input_dim=128)
decoder = Decoder(encoded_space_dim=d,fc2_input_dim=128)

# map_location places the stored tensors on the selected device, so a
# checkpoint written on a GPU machine still loads on a CPU-only one.
encoder.load_state_dict(torch.load('./saved_models/cnn_ae_encoder', map_location=device))
decoder.load_state_dict(torch.load('./saved_models/cnn_ae_decoder', map_location=device))
encoder.to(device)
decoder.to(device)

# The models are only used for inference from here on.
encoder.eval()
decoder.eval()

plot_ae_outputs(encoder,decoder,n=10)

### Below is the test error

In [None]:
# Loss (loss_fn) of the loaded models over the test loader, converted from a tensor to a Python number.
test_epoch(encoder,decoder,device,test_loader,loss_fn).item()

### Below we plot the training and validation loss

This is commented out since we have pre-trained the models

In [None]:
# plt.figure(figsize=(10,8))
# plt.semilogy(diz_loss['train_loss'], label='Train')
# plt.semilogy(diz_loss['val_loss'], label='Valid')
# plt.xlabel('Epoch')
# plt.ylabel('Average Loss')
# plt.legend()
# plt.show()

### Below we draw some images from random points in the latent space. 

The autoencoder latent space can be quite irregular, so some of the generated samples don't look very good.

In [None]:
def show_image(img):
    """Display a 3-D image tensor with matplotlib.

    The tensor is converted to a NumPy array and its first axis is moved
    to the end before plotting (the caller in this file passes a grid
    built by ``make_grid``).
    """
    channels_last = img.numpy().transpose(1, 2, 0)
    plt.imshow(channels_last)

encoder.eval()
decoder.eval()

with torch.no_grad():
    # Estimate the mean and standard deviation of the latent code from one
    # batch of test images.
    images, labels = next(iter(test_loader))
    images = images.to(device)
    latent = encoder(images).cpu()

    mean = latent.mean(dim=0)
    print(f"means: {mean}")
    # Square root of the mean squared deviation, per latent dimension.
    std = ((latent - mean) ** 2).mean(dim=0).sqrt()
    print(f"standard deviations: {std}")

    # Draw random latent vectors from a normal distribution with the
    # estimated mean and standard deviation.
    latent = (torch.randn(128, d) * std + mean).to(device)

    # Decode the random latent vectors into images.
    img_recon = decoder(latent).cpu()

    # Show the first 100 generated images in a grid with 10 per row.
    fig, ax = plt.subplots(figsize=(20, 8.5))
    show_image(torchvision.utils.make_grid(img_recon[:100], nrow=10, padding=5))
    plt.show()

### Below we look at where the different digits lie in the latent space

Note that the model is unsupervised and the labels were not used for training, but it can be observed that there is some reasonable clustering of the digits in the latent space. 

We have now put the encoded test data in a pandas dataframe to use the methods from the unsupervised learning labs to cluster and visualise the latent space

In [None]:
# Number of test images to encode; kept small to limit computation time.
test_subset_size = 1000

encoder.eval()
encoded_samples = []
for _draw in tqdm(range(test_subset_size)):
    # Pick a random test image (indices are drawn independently, so repeats
    # are possible).
    sample_idx = torch.randint(len(test_dataset), size=(1,)).item()
    img, label = test_dataset[sample_idx]

    # Encode the image as a batch of one and flatten the code to 1-D.
    with torch.no_grad():
        latent_code = encoder(img.unsqueeze(0).to(device)).flatten().cpu().numpy()

    # One row per image: a column per latent variable plus the digit label.
    row = {f"Enc. Variable {j}": value for j, value in enumerate(latent_code)}
    row['label'] = label
    encoded_samples.append(row)

encoded_samples = pd.DataFrame(encoded_samples)
encoded_samples

import plotly.express as px

# Scatter plot of the first two latent variables, coloured by digit label
# (the labels are passed as strings rather than integers).
px.scatter(
    encoded_samples,
    x='Enc. Variable 0',
    y='Enc. Variable 1',
    color=encoded_samples.label.astype(str),
    opacity=0.7,
)


### Below we put the labels in the index of the dataframe

Now the encoded test data is ready to be clustered or visualised as you wish. The format of the data is the same as you used in the PCA and clustering labs

In [None]:
# Move the 'label' column into the index so that the remaining columns are
# only the encoded variables.
encoded_data = encoded_samples.set_index('label')
# Keep the labels separately, as strings.
labels = encoded_data.index.astype(str)
# Display the resulting dataframe.
encoded_data