# Training an Autoencoder model

In this demo, we walk through the process of training an autoencoder model, using a basic fully connected architecture implemented in PyTorch.

As input data, we use the Fashion MNIST dataset, since it demonstrates the possibilities of autoencoders without the hassle of collecting new data.

If you would like to explore other datasets for training your autoencoder, see the datasets provided by the torchvision library: https://pytorch.org/vision/stable/datasets.html

This tutorial is based on the tutorial from: https://www.geeksforgeeks.org/implementing-an-autoencoder-in-pytorch/

With the following modifications:
* changed the dataset to Fashion MNIST
* added visualization of the autoencoder
* minor bug-fixes

In [None]:
# installing the visualization module that
# we will later use for drawing the autoencoder architecture
!pip install -q torchviz

In [None]:
# imports needed for the autoencoder itself
import torch
from torchvision import datasets
from torchvision import transforms
import matplotlib.pyplot as plt

# standard libraries
import pandas as pd
import numpy as np

# imports needed for the visualization of the network
from torchviz import make_dot

In [None]:
# NOTE(review): this installs the PyPI `datasets` package (Hugging Face
# Datasets), which nothing below appears to use. The `datasets` name used for
# FashionMNIST is imported from torchvision. Consider removing this cell.
!pip install -q datasets

In [None]:
# Convert each PIL image to a float tensor with pixel values scaled to [0, 1]
tensor_transform = transforms.ToTensor()

# Download (if not already present) the Fashion MNIST training split into ./data
dataset = datasets.FashionMNIST(root="./data",
                                train=True,
                                download=True,
                                transform=tensor_transform)

# The DataLoader yields shuffled mini-batches of 32 (image, label) pairs
# for training
loader = torch.utils.data.DataLoader(dataset=dataset,
                                     batch_size=32,
                                     shuffle=True)

In [None]:
# Fully connected autoencoder.
# Default shape: 28*28 ==> 128 ==> 64 ==> 36 ==> 18 ==> 9 ==> ... ==> 28*28
class AE(torch.nn.Module):
    """Symmetric fully connected autoencoder.

    The encoder maps a flattened input of ``in_features`` values through the
    ``hidden_sizes`` layers down to a ``latent_dim`` code. A ReLU follows
    every Linear layer except the last one. The decoder mirrors those widths
    back to ``in_features`` and ends in a Sigmoid, so outputs lie in [0, 1].

    The defaults reproduce the original 784 -> 9 -> 784 architecture. ``AE()``
    builds exactly the same layers, in the same order and with the same
    parameter names, as before.

    Args:
        in_features: Size of the flattened input (28 * 28 for Fashion MNIST).
        hidden_sizes: Encoder hidden-layer widths, widest first.
        latent_dim: Size of the bottleneck code.
    """

    def __init__(self, in_features=28 * 28, hidden_sizes=(128, 64, 36, 18),
                 latent_dim=9):
        super().__init__()
        widths = [in_features, *hidden_sizes, latent_dim]

        # Encoder: in_features ==> latent_dim, with no activation on the code
        self.encoder = torch.nn.Sequential(*self._linear_relu_stack(widths))

        # The decoder mirrors the encoder. The final Sigmoid keeps outputs in
        # [0, 1], the same range as the ToTensor() pixel values.
        self.decoder = torch.nn.Sequential(
            *self._linear_relu_stack(widths[::-1]),
            torch.nn.Sigmoid(),
        )

    @staticmethod
    def _linear_relu_stack(widths):
        """Return Linear layers between consecutive ``widths``, with a ReLU
        after each one except the last."""
        layers = []
        for n_in, n_out in zip(widths[:-1], widths[1:]):
            layers.append(torch.nn.Linear(n_in, n_out))
            layers.append(torch.nn.ReLU())
        layers.pop()  # drop the trailing ReLU after the final Linear
        return layers

    def forward(self, x):
        """Encode ``x`` (last dimension ``in_features``) and decode it back."""
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return decoded

In [None]:

# Model Initialization
model = AE()

# Training objective: mean squared reconstruction error between the
# autoencoder's output and its input
loss_function = torch.nn.MSELoss()

# Adam optimizer with lr = 1e-3 (Adam's default). The previous lr = 0.1 is
# two orders of magnitude above that and makes training of this deep ReLU
# network unstable.
optimizer = torch.optim.Adam(model.parameters(),
                             lr=1e-3,
                             weight_decay=1e-8)

In [None]:
epochs = 10
outputs = []  # one (epoch, images, reconstructions) snapshot per epoch
losses = []   # per-iteration training loss, stored as Python floats
for epoch in range(epochs):
    for (image, _) in loader:

        # Flatten each image into a 784-vector: shape (batch, 28*28)
        image = image.reshape(-1, 28 * 28)

        # Output of Autoencoder
        reconstructed = model(image)

        # Reconstruction error between output and input
        loss = loss_function(reconstructed, image)

        # Clear stale gradients, backpropagate, then update the parameters
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Store a plain float. Keeping the tensor would also keep a reference
        # to its autograd graph alive for every iteration.
        losses.append(loss.item())

    # Snapshot the last batch of this epoch. Use `epoch`, not the constant
    # `epochs`, and detach so no autograd graph is retained.
    outputs.append((epoch, image, reconstructed.detach()))


In [None]:
# Defining the Plot Style. Matplotlib 3.6 renamed the 'seaborn' style to
# 'seaborn-v0_8' and 3.8 removed the old name, so prefer whichever exists.
style = 'seaborn-v0_8' if 'seaborn-v0_8' in plt.style.available else 'seaborn'
plt.style.use(style)
plt.xlabel('Iterations')
plt.ylabel('Loss')

# Plot the per-iteration loss in chronological order, so the x axis is the
# iteration index. float() accepts both Python floats and 0-d tensors.
loss_values = [float(value) for value in losses]
plt.plot(loss_values)

In [None]:
# Show up to 8 samples from the last training batch: originals in the top
# row, their reconstructions in the bottom row. Each sample gets its own
# axes; drawing them all on one axes would leave only the last one visible.
n_show = min(8, len(image))
fig, axes = plt.subplots(2, n_show, figsize=(2 * n_show, 4), squeeze=False)
pairs = zip(image[:n_show], reconstructed[:n_show].detach())
for col, (original, recon) in enumerate(pairs):
    # Reshape each flat 784-vector back to a 28x28 image for plotting
    axes[0, col].imshow(original.reshape(28, 28).numpy(), cmap='gray')
    axes[1, col].imshow(recon.reshape(28, 28).numpy(), cmap='gray')
    axes[0, col].axis('off')
    axes[1, col].axis('off')
axes[0, 0].set_title('Original')
axes[1, 0].set_title('Reconstructed')
plt.tight_layout()

In [None]:
# Layer-by-layer summary of the model. This requires the torchsummary
# package, which no cell in this notebook installs.
# device="cpu" matches where the model lives. torchsummary defaults to
# device="cuda" and would feed CUDA inputs to this CPU model whenever a
# GPU is available.
from torchsummary import summary
summary(model, (1, 28 * 28), device="cpu")

In [None]:
!pip install torchviz
from torchviz import make_dot


In [None]:
# Run one forward pass on a batch containing a single sample from the last
# training batch. make_dot below traces the autograd graph back from yhat.
yhat = model(image[:1])

In [None]:
from torchviz import make_dot

# Draw the autograd graph of yhat. Parameter nodes are labelled with their
# names from model.named_parameters(). show_attrs / show_saved also display
# the backward nodes' attributes and saved tensors.
make_dot(yhat,
         params=dict(model.named_parameters()),
         show_attrs=True,
         show_saved=True)

In [None]:
# Render the labelled graph to autoencoder_graph.png in the working directory.
# This requires the Graphviz system binaries. cleanup=True deletes the
# intermediate DOT source file after rendering.
make_dot(yhat, params=dict(model.named_parameters())).render(
    "autoencoder_graph", format="png", cleanup=True)