# In [None]:
# Auto Encoder for MNIST dataset

import os

import numpy as np
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torchvision
from matplotlib import pyplot as plt
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import MNIST

# Create a PyTorch Lightning class
class AutoEncoder(pl.LightningModule):
    """Fully connected autoencoder trained with an MSE reconstruction loss.

    The encoder and the decoder are each a stack of four ``nn.Linear``
    layers. Every layer is followed by a ReLU, except the last decoder
    layer, which is followed by a Sigmoid.

    Args:
        input_shape: Number of features of one flattened input sample
            (the in/out feature count of the first/last linear layer).
        num_hidden: Width of every hidden layer.
    """

    def __init__(self, input_shape, num_hidden):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(input_shape, num_hidden),
            nn.ReLU(),
            nn.Linear(num_hidden, num_hidden),
            nn.ReLU(),
            nn.Linear(num_hidden, num_hidden),
            nn.ReLU(),
            nn.Linear(num_hidden, num_hidden),
            nn.ReLU(),
        )
        self.decoder = nn.Sequential(
            nn.Linear(num_hidden, num_hidden),
            nn.ReLU(),
            nn.Linear(num_hidden, num_hidden),
            nn.ReLU(),
            nn.Linear(num_hidden, num_hidden),
            nn.ReLU(),
            nn.Linear(num_hidden, input_shape),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Encode then decode ``x``; ``x`` is not flattened here."""
        return self.decoder(self.encoder(x))

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters (lr=1e-3)."""
        return torch.optim.Adam(self.parameters(), lr=1e-3)

    def _shared_step(self, batch):
        """Flatten the batch inputs, reconstruct them and compute the MSE.

        Returns:
            A ``(x, x_hat, loss)`` tuple: the flattened inputs, their
            reconstructions and the mean squared error between the two.
        """
        x, _ = batch  # the second batch element is not used
        x = torch.flatten(x, start_dim=1)
        x_hat = self.decoder(self.encoder(x))
        loss = nn.functional.mse_loss(x_hat, x)
        return x, x_hat, loss

    def _log_reconstructions(self, images, reconstructions):
        """Log inputs (top row) and reconstructions (bottom row) as a figure."""
        experiment = getattr(self.logger, "experiment", None)
        if not hasattr(experiment, "add_figure"):
            # No logger attached, or one that cannot store figures: skip the
            # plot instead of failing the whole test run.
            return

        # Use an explicit figure so the plot never draws onto (or clears)
        # whatever pyplot figure the caller currently has active.
        fig = plt.figure()
        try:
            for i, (im, recon) in enumerate(zip(images, reconstructions)):
                # The 28x28 reshape assumes 784-feature samples.
                top = fig.add_subplot(2, 8, i + 1)
                top.imshow(im.reshape(28, 28).detach().cpu().numpy(), cmap="gray")
                bottom = fig.add_subplot(2, 8, 8 + i + 1)
                bottom.imshow(
                    recon.reshape(28, 28).detach().cpu().numpy(), cmap="gray"
                )
            experiment.add_figure("reconstructions", fig, self.current_epoch)
        finally:
            # Always release the figure so repeated runs do not pile up
            # open figures.
            plt.close(fig)

    def training_step(self, batch, batch_idx):
        """Compute and log the reconstruction loss as ``train_loss``."""
        _, _, loss = self._shared_step(batch)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute and log the reconstruction loss as ``val_loss``."""
        _, _, loss = self._shared_step(batch)
        self.log("val_loss", loss)
        return loss

    def test_step(self, batch, batch_idx):
        """Compute and log ``test_loss``; plot samples of the first batch."""
        x, x_hat, loss = self._shared_step(batch)
        self.log("test_loss", loss)
        if batch_idx == 0:
            # Plot the first 8 images in the batch next to their outputs.
            self._log_reconstructions(x[:8], x_hat[:8])
        return loss
    
# Create the datamodule for MNIST dataset
class MNISTDataModule(pl.LightningDataModule):
    """Lightning data module providing train/val/test loaders for MNIST.

    Args:
        batch_size: Batch size used by all three dataloaders.
        data_dir: Directory the dataset is downloaded to and read from.
            Defaults to the current working directory when ``None``.
    """

    def __init__(self, batch_size=32, data_dir=None):
        super().__init__()
        self.batch_size = batch_size
        self.data_dir = os.getcwd() if data_dir is None else data_dir
        # torchvision transforms are composed with `transforms.Compose`;
        # `nn.Sequential` only accepts `nn.Module`s and `torch.nn` has no
        # `Normalize`.
        # NOTE(review): after this normalisation the pixel values are no
        # longer confined to [0, 1], while the AutoEncoder decoder ends in a
        # Sigmoid -- confirm the normalisation is wanted for that model.
        self.transform = torchvision.transforms.Compose(
            [
                torchvision.transforms.ToTensor(),
                torchvision.transforms.Normalize((0.1307,), (0.3081,)),
            ]
        )

    def prepare_data(self):
        """Download both MNIST splits; deliberately assigns no state."""
        # Datasets are built in `setup`, so nothing is stored on `self` here.
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        """Build the train/val split and the test set from downloaded data."""
        mnist_full = MNIST(
            self.data_dir, train=True, download=False, transform=self.transform
        )
        self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])
        self.mnist_test = MNIST(
            self.data_dir, train=False, download=False, transform=self.transform
        )

    def train_dataloader(self):
        """Return the dataloader over the training split."""
        return DataLoader(self.mnist_train, batch_size=self.batch_size)

    def val_dataloader(self):
        """Return the dataloader over the validation split."""
        return DataLoader(self.mnist_val, batch_size=self.batch_size)

    def test_dataloader(self):
        """Return the dataloader over the test set."""
        return DataLoader(self.mnist_test, batch_size=self.batch_size)
    
# Create the model and the data module, then train for 10 epochs.
model = AutoEncoder(input_shape=28 * 28, num_hidden=128)
dataset = MNISTDataModule()
# The `gpus=` and `progress_bar_refresh_rate=` Trainer arguments are no longer
# accepted by current Lightning releases, and `gpus=1` fails on machines
# without a GPU. `accelerator="auto"` picks the available hardware, and the
# refresh rate is configured through the progress-bar callback instead.
trainer = pl.Trainer(
    max_epochs=10,
    accelerator="auto",
    devices=1,
    callbacks=[pl.callbacks.TQDMProgressBar(refresh_rate=20)],
)
trainer.fit(model, dataset)
