This notebook trains a fully connected autoencoder that encodes and decodes images from the MNIST database of handwritten digits 0–9. It uses the PyTorch library to implement the neural networks.

# Load Packages

In [None]:
import torch
import torchvision
from torch import nn
from torchvision import transforms
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from tqdm import tqdm

# Load Data

In [None]:
# The default MNIST download location was not loading, so point torchvision at
# the PyTorch S3 mirror instead.
# Original workaround: https://stackoverflow.com/a/66820249/7038204
#
# How to do that depends on the installed torchvision version (per the
# torchvision source -- verify against your installed version):
# * Older releases store full download URLs in ``MNIST.resources``, so the
#   URLs themselves are replaced (this is the original workaround).
# * Newer releases store bare file names in ``MNIST.resources`` plus base URLs
#   in ``MNIST.mirrors``, and build each URL as mirror + file name. Writing full
#   URLs into ``resources`` there yields invalid URLs/file names, so only the
#   mirror order is changed: the S3 mirror is tried first.
_MNIST_S3_MIRROR = 'https://ossci-datasets.s3.amazonaws.com/mnist/'

if hasattr(MNIST, 'mirrors'):
    MNIST.mirrors = [_MNIST_S3_MIRROR] + [
        mirror for mirror in MNIST.mirrors if mirror != _MNIST_S3_MIRROR
    ]
else:
    MNIST.resources = [
        ('https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz', 'f68b3c2dcbeaaa9fbdd348bbdeb94873'),
        ('https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz', 'd53e105ee54ea40749a09fcbcd1e9432'),
        ('https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz', '9fb629c4189551a2d022fa330f9573f3'),
        ('https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz', 'ec29112dd5afa0611ce80d1b7f02629c'),
    ]

In [None]:
# Load and transform data
batch_size = 32
# Shuffle the training set every epoch. The validation loader below keeps a
# fixed order so evaluation and the "original images" shown later are
# reproducible.
shuffle = True

# ToTensor() scales pixels to [0, 1]. That is the range the decoder's final
# Sigmoid can produce, and the range save_image displays without clipping.
# Standardizing with Normalize((0.1307,), (0.3081,)) would map pixels to
# roughly [-0.42, 2.82], a target a Sigmoid output can never reach.
img_transform = transforms.Compose([
    transforms.ToTensor(),
])

train_dataset = MNIST('./data', transform=img_transform, download=True, train=True)
val_dataset = MNIST('./data', transform=img_transform, download=True, train=False)

train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=shuffle)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

# Architecture

In [None]:
class Encoder(nn.Module):
    """Map a batch of 28x28 images to latent vectors of size ``out_size``.

    Args:
        out_size: Dimension of the latent representation.
    """

    def __init__(self, out_size=2):
        super().__init__()
        self.flatten = nn.Flatten()
        self.first_layer = nn.Sequential(
            nn.Linear(in_features=28 * 28, out_features=1024),
            nn.ReLU(),
        )
        # No activation on the output: the latent code is left unconstrained.
        self.second_layer = nn.Linear(in_features=1024, out_features=out_size)

    def forward(self, x):
        """Encode ``x`` into a ``(N, out_size)`` tensor.

        Each sample must flatten to 28*28 features; presumably the input is
        ``(N, 1, 28, 28)`` image batches.
        """
        x = self.flatten(x)
        x = self.first_layer(x)
        return self.second_layer(x)


class Decoder(nn.Module):
    """Map latent vectors of size ``in_size`` back to 1x28x28 images.

    Args:
        in_size: Dimension of the latent representation.
    """

    def __init__(self, in_size=2):
        super().__init__()
        self.first_layer = nn.Sequential(
            nn.Linear(in_features=in_size, out_features=1024),
            nn.ReLU(),
        )
        # Sigmoid keeps every reconstructed pixel in [0, 1].
        self.second_layer = nn.Sequential(
            nn.Linear(in_features=1024, out_features=28 * 28),
            nn.Sigmoid(),
        )

    def forward(self, z):
        """Decode ``(N, in_size)`` latents into ``(N, 1, 28, 28)`` images."""
        z = self.first_layer(z)
        z = self.second_layer(z)
        return z.reshape([z.shape[0], 1, 28, 28])


class AutoEncoder(nn.Module):
    """Fully connected autoencoder: ``Encoder`` followed by ``Decoder``.

    Encoder and Decoder are defined at module level rather than inside
    ``__init__``. This avoids redefining them on every instantiation and keeps
    the whole model picklable (e.g. ``torch.save(model)``). The attribute names
    are unchanged, so ``state_dict`` keys are the same as before.

    Args:
        dim_latent_representation: Size of the bottleneck vector.
    """

    def __init__(self, dim_latent_representation=2):
        super().__init__()
        self.encoder = Encoder(out_size=dim_latent_representation)
        self.decoder = Decoder(in_size=dim_latent_representation)

    def forward(self, x):
        """Reconstruct ``x`` by encoding it to the bottleneck and decoding it back."""
        latent = self.encoder(x)
        return self.decoder(latent)

# Training

In [None]:
# Optimisation hyperparameters.
learning_rate = 1e-3
weight_decay = 1e-5
num_epochs = 10

# Width of the bottleneck between encoder and decoder.
dim_latent_representation = 16

# Reconstruction error between the decoder output and the input image.
loss_criterion = nn.MSELoss()

# Run on the GPU when PyTorch can see one, otherwise on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

model = AutoEncoder(dim_latent_representation=dim_latent_representation)
model = model.to(device)
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=learning_rate,
    weight_decay=weight_decay,
)


In [None]:
def train(epoch):
    """Run one optimisation pass over ``train_dataloader`` and report the loss.

    Uses the notebook-level ``model``, ``optimizer``, ``loss_criterion``,
    ``train_dataloader`` and ``device``.

    Args:
        epoch: Epoch number, used only in the printed summary.

    Returns:
        The epoch's training loss as a Python float. Each batch's criterion
        value is weighted by its batch size, then divided by the number of
        samples seen.
    """
    # Training mode matters once layers like dropout/batch-norm are added.
    model.train()
    total_loss = 0.0
    num_samples = 0
    for img, _ in tqdm(train_dataloader):
        img = img.to(device)
        optimizer.zero_grad()

        # An autoencoder reproduces its input, so the input is also the target.
        output = model(img)
        loss = loss_criterion(output, img)

        loss.backward()
        optimizer.step()

        # nn.MSELoss() (as configured above) averages over the batch. Weighting
        # by the real batch size keeps a short final batch from being
        # over-counted. .item() yields a plain float instead of a device tensor.
        batch_len = img.shape[0]
        total_loss += loss.item() * batch_len
        num_samples += batch_len

    training_loss = total_loss / max(num_samples, 1)
    print('Epoch: {}; Training loss: {:.4f}'.format(epoch, training_loss))
    return training_loss

In [None]:
def validate(epoch):
    """Evaluate the model on ``val_dataloader`` and report the loss.

    Uses the notebook-level ``model``, ``loss_criterion``, ``val_dataloader``
    and ``device``. The model runs in eval mode without gradient tracking. Its
    previous train/eval mode is restored afterwards.

    Args:
        epoch: Epoch number, used only in the printed summary.

    Returns:
        The validation loss as a Python float. Each batch's criterion value is
        weighted by its batch size, then divided by the number of samples seen.
    """
    was_training = model.training
    model.eval()
    total_loss = 0.0
    num_samples = 0
    try:
        with torch.no_grad():
            for img, _ in tqdm(val_dataloader):
                img = img.to(device)
                output = model(img)
                loss = loss_criterion(output, img)
                # The dataset size is generally not a multiple of the batch size
                # (10000 / 32 = 312.5, while the loader yields 313 batches), so
                # weight each batch by its actual size.
                batch_len = img.shape[0]
                total_loss += loss.item() * batch_len
                num_samples += batch_len
    finally:
        model.train(was_training)

    val_loss = total_loss / max(num_samples, 1)
    print('Epoch: {}; Validation loss: {:.4f}'.format(epoch, val_loss))
    return val_loss

In [None]:
# Each epoch runs one training pass, then one validation pass. Epochs are
# numbered from 1 in the printed logs.
for epoch in range(1, num_epochs + 1):
    for run_pass in (train, validate):
        run_pass(epoch)

# Reconstructing images

In [None]:
# Save a tensor of 28 * 28 images as one image grid and optionally display it
from torchvision.utils import save_image
import matplotlib.pyplot as plt
import matplotlib.image as mpimg

def display_images(images, file_path='./tmp.png', display=True):
    """Write ``images`` to ``file_path`` as a single grid and optionally show it.

    Args:
        images: Tensor whose elements regroup into ``(N, 1, 28, 28)``, e.g. a
            batch of MNIST images. It may live on any device.
        file_path: Destination of the saved grid. It is overwritten on each call.
        display: When True, read the saved file back and show it with matplotlib.
    """
    path = str(file_path)
    # reshape (not view) also accepts non-contiguous tensors such as
    # permuted or strided slices. detach: only pixel values are needed.
    grid_input = images.detach().reshape(-1, 1, 28, 28)
    save_image(grid_input, path)

    if display:
        plt.imshow(mpimg.imread(path))
        plt.axis('off')


In [None]:
# Show original images.
num_images = 64

# Take the first num_images images in val_dataloader's iteration order. Stop
# reading batches once enough are collected, instead of stacking the whole
# validation set and discarding almost all of it.
batches = []
collected = 0
for batch, _ in val_dataloader:
    batches.append(batch)
    collected += batch.shape[0]
    if collected >= num_images:
        break
original_images = torch.cat(batches)[:num_images]

print("Original images")
display_images(original_images)

In [None]:
# Show reconstructed images.
with torch.no_grad():
    original_images = original_images.to(device)
    reconstructed_images = model(original_images).to(device)

print("Reconstructed images")
display_images(reconstructed_images)