### Implementing an autoencoder on the MNIST dataset (images of handwritten digits)

#### Import the relevant packages and define the device

In [1]:
from torch_snippets import *
from torchvision import transforms
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
from torch import nn
from torch import optim
import torch
# Pick the compute device once: use the GPU when CUDA is usable, otherwise the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

#### Define the transformation required for the images 

In [2]:
# Per-sample preprocessing, applied in list order by `transforms.Compose`:
#   1. convert the image to a tensor,
#   2. normalise with mean 0.5 and std 0.5 (single channel),
#   3. move the resulting tensor onto `device`.
# NOTE(review): step 3 runs once per sample inside the dataset; this presumably
# relies on the DataLoader using no worker processes -- confirm before setting
# `num_workers > 0` with a CUDA device.
_preprocessing_steps = [
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
    transforms.Lambda(lambda x: x.to(device)),
]
images_tranform = transforms.Compose(_preprocessing_steps)

#### Define the training and validation datasets

In [3]:
# Folder where MNIST is stored (downloaded on first use because download=True).
datafolder = 'dataset/'

# Build the training split (train=True) first, then the held-out split
# (train=False); both share the same preprocessing pipeline.
train_dataset, valid_dataset = (
    MNIST(datafolder, train=is_train_split, download=True, transform=images_tranform)
    for is_train_split in (True, False)
)

#### Define the dataloaders

In [4]:
# Number of images per mini-batch for both loaders.
batch_size = 256

# Training batches are reshuffled every epoch.
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# The original cell rebound `valid_dataset` to its own DataLoader, so running
# the cell a second time wrapped a DataLoader inside another DataLoader.
# Unwrap it first so the cell gives the same result however often it is run.
if isinstance(valid_dataset, DataLoader):
    valid_dataset = valid_dataset.dataset

# Validation batches keep a fixed order; the loader gets its own name, matching
# `train_dataloader`.
valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)

# Backward compatibility: the original cell left `valid_dataset` bound to the
# loader, and later cells may iterate it as such. Prefer `valid_dataloader`.
valid_dataset = valid_dataloader

#### Define the network architecture 

In [5]:
class AutoEncoder(nn.Module):
    """Fully connected autoencoder for 28x28 single-channel images.

    The encoder maps each flattened image (28*28 = 784 values) through hidden
    layers of 128 and 64 units down to ``latent_dim`` values. The decoder
    mirrors that path back up to 784 values and ends with ``Tanh``, so every
    reconstructed pixel lies in [-1, 1].

    Args:
        latent_dim: Number of units in the bottleneck (encoder output /
            decoder input).
    """

    def __init__(self, latent_dim):
        super().__init__()

        self.latent_dim = latent_dim
        # ReLU(True) is ReLU(inplace=True): activations overwrite the output
        # of the preceding Linear layer instead of allocating a new tensor.
        self.encoder = nn.Sequential(
            nn.Linear(28 * 28, 128),
            nn.ReLU(True),
            nn.Linear(128, 64),
            nn.ReLU(True),
            nn.Linear(64, latent_dim),
        )
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 64),
            nn.ReLU(True),
            nn.Linear(64, 128),
            nn.ReLU(True),
            nn.Linear(128, 28 * 28),
            nn.Tanh(),
        )

    def forward(self, x):
        """Encode and then reconstruct a batch of images.

        Args:
            x: Tensor whose first dimension is the batch and whose remaining
                dimensions hold 784 elements per sample.

        Returns:
            Tensor of shape ``(len(x), 1, 28, 28)`` with values in [-1, 1].
        """
        # reshape (rather than view) also accepts non-contiguous input such as
        # a transposed or sliced batch, copying only when it has to.
        x = x.reshape(len(x), -1)
        x = self.encoder(x)
        x = self.decoder(x)
        x = x.reshape(len(x), 1, 28, 28)

        return x


#### Visualize the preceding model

In [9]:
from torchsummary import summary
# Build a model with a 3-dimensional bottleneck and print its layer summary.
latent_dim = 3
model = AutoEncoder(latent_dim).to(device)

# `input_size` is the shape of ONE sample without the batch dimension:
# (channels, height, width). The original passed (1, 1, 28, 28), whose extra
# leading 1 only worked because `forward` flattens everything after the batch
# dimension.
summary(model, (1, 28, 28))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Linear-1                  [-1, 128]         100,480
              ReLU-2                  [-1, 128]               0
            Linear-3                   [-1, 64]           8,256
              ReLU-4                   [-1, 64]               0
            Linear-5                    [-1, 3]             195
            Linear-6                   [-1, 64]             256
              ReLU-7                   [-1, 64]               0
            Linear-8                  [-1, 128]           8,320
              ReLU-9                  [-1, 128]               0
           Linear-10                  [-1, 784]         101,136
             Tanh-11                  [-1, 784]               0
Total params: 218,643
Trainable params: 218,643
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/