# Imports

In [31]:
import matplotlib.pyplot as plt
import numpy as np


import torch
from torch import optim, nn, utils, Tensor

from torchvision.datasets import MNIST
from torchvision.utils import make_grid
from torchvision.transforms import ToTensor
from torch.utils.data import Dataset, DataLoader

from utils.my_transforms.transform import to_numpy


import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger

# Dataset and Dataloader

In [2]:
# Root directory where MNIST is downloaded/cached -- adjust for your machine.
DATA_DIR = '/datasets/'
BATCH_SIZE = 64

mnist_dataset = MNIST(DATA_DIR, train=True, download=True, transform=ToTensor())
train_loader = DataLoader(mnist_dataset, batch_size=BATCH_SIZE, shuffle=True)

In [None]:
# Optional sanity check (disabled): display the first six training samples.
# NOTE(review): `uint8_image` is presumably a scaled float array rather than a
# real uint8 array -- confirm what to_numpy returns before re-enabling.
# for idx, (image, label) in enumerate(mnist_dataset):

#     uint8_image = to_numpy(image)*255
#     plt.imshow(uint8_image)
#     plt.axis('off')
#     plt.show()
#     if idx >= 5:
#         break

# Model definition

In [81]:
# References:
#   TensorBoard logging with Lightning: https://learnopencv.com/tensorboard-with-pytorch-lightning/
#   LightningModule style guide: https://pytorch-lightning.readthedocs.io/en/latest/starter/style_guide.html

class LitAutoEncoder(pl.LightningModule):
    """Fully connected autoencoder for 28x28 MNIST images with TensorBoard diagnostics.

    The encoder maps a flattened image (28*28 values) to a ``z_size``-dimensional
    latent vector; the decoder maps it back to 28*28 values squashed by a final
    Sigmoid. Training minimises the MSE between input and reconstruction.

    At the end of every training epoch the module writes to
    ``self.logger.experiment`` (used here as a TensorBoard ``SummaryWriter``):
    parameter histograms, a grid of reconstructions of a fixed reference batch,
    and the latent embedding of that batch.

    Args:
        z_size: Dimensionality of the latent space.
        lr: Learning rate for the Adam optimizer.
    """

    def __init__(self, z_size: int = 5, lr: float = 1e-3):
        super().__init__()
        # Record z_size / lr so checkpoints can be reloaded with the same architecture.
        self.save_hyperparameters()
        self.z_size = z_size
        self.lr = lr

        self.encoder = nn.Sequential(
            nn.Linear(28 * 28, 64), nn.ReLU(),
            nn.Linear(64, self.z_size),
        )
        self.decoder = nn.Sequential(
            nn.Linear(self.z_size, 64), nn.ReLU(),
            nn.Linear(64, 28 * 28), nn.Sigmoid(),
        )

        # First training batch seen by this module; reused at every epoch end
        # so logged reconstructions/embeddings are comparable across epochs.
        self.reference_train_batch = None

    def forward(self, x):
        """Reconstruct ``x`` (any shape with 28*28 values per sample); returns (N, 28*28)."""
        return self.decoder(self.encoder(x.reshape(x.size(0), -1)))

    def training_step(self, batch, batch_idx):
        """Return (and log as ``train_loss``) the reconstruction MSE of one batch."""
        x, _ = batch
        x = x.reshape(x.size(0), -1)

        # Remember the first batch for epoch-end visualisation. Testing for None
        # (instead of epoch 0 / batch 0) also covers runs resumed from a checkpoint.
        if self.reference_train_batch is None:
            self.reference_train_batch = batch

        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = nn.functional.mse_loss(x_hat, x)
        # Logged to TensorBoard by the attached logger.
        self.log("train_loss", loss, prog_bar=True)
        return loss

    def training_epoch_end(self, outputs):
        """Write per-epoch diagnostics (histograms, reconstructions, embedding) to the logger."""
        if self.logger is None:  # e.g. Trainer(logger=False): nothing to write to
            return
        self.custom_histogram_adder()
        if self.reference_train_batch is not None:
            self.custom_show_batch_images(self.reference_train_batch)
            self.custom_log_embedding(self.reference_train_batch)

    def configure_optimizers(self):
        """Adam over all parameters with learning rate ``self.lr``."""
        return optim.Adam(self.parameters(), lr=self.lr)

    # ---- custom TensorBoard logs -------------------------------------------

    @torch.no_grad()
    def custom_log_embedding(self, batch):
        """Log the latent codes of ``batch`` as an embedding, labelled by digit and image."""
        images, labels = batch
        latent = self.encoder(images.reshape(images.size(0), -1))
        self.logger.experiment.add_embedding(
            latent,
            metadata=labels.tolist(),  # .tolist() also works for CUDA tensors
            label_img=images,
            global_step=self.current_epoch,
            tag='latent-space',
        )

    @torch.no_grad()
    def custom_show_batch_images(self, batch):
        """Log a grid of reconstructions of ``batch`` under the tag 'X_hat'."""
        images, _ = batch
        flat = images.reshape(images.size(0), -1)
        x_hat = self.decoder(self.encoder(flat))  # shape: (BATCH_SIZE, 28*28)
        grid = make_grid(x_hat.reshape(-1, 1, 28, 28))
        self.logger.experiment.add_image('X_hat', grid, self.current_epoch, dataformats='CHW')

    def custom_histogram_adder(self):
        """Log a histogram of every named parameter tensor for the current epoch."""
        for name, params in self.named_parameters():
            self.logger.experiment.add_histogram(name, params, self.current_epoch)

# Init autoencoder

# Model Training

In [82]:
# Seed Python/NumPy/torch RNGs so weight init and data shuffling are reproducible.
pl.seed_everything(42)

autoencoder = LitAutoEncoder()

logger = TensorBoardLogger('tb-logs', 'mnist-exp')

# limit_train_batches keeps each epoch short for quick iteration; add
# accelerator/devices arguments here to train on a GPU.
trainer = pl.Trainer(limit_train_batches=100,
                     max_epochs=10,
                     logger=logger,
                     )

trainer.fit(model=autoencoder, train_dataloaders=train_loader)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name    | Type       | Params
---------------------------------------
0 | encoder | Sequential | 50.6 K
1 | decoder | Sequential | 51.3 K
---------------------------------------
101 K     Trainable params
0         Non-trainable params
101 K     Total params
0.408     Total estimated model params size (MB)
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

In [None]:
%reload_ext tensorboard
# Launch TensorBoard inline on the directory the TensorBoardLogger writes to.
# (lightning_logs/ would be the log directory if no logger were passed.)
# %tensorboard --logdir=lightning_logs/
%tensorboard --logdir=tb-logs/

# Results

In [None]:
# Encode a single training image to inspect its latent representation.
image, label = mnist_dataset[5]

autoencoder.eval()
with torch.no_grad():  # inference only -- no autograd graph needed
    z = autoencoder.encoder(image.reshape(1, -1))
z.shape

In [None]:
# Decode a random latent vector sampled from N(0, I). Its width must equal the
# model's latent dimension, otherwise the decoder's first Linear layer fails.
with torch.no_grad():
    z_rnd = torch.randn(1, autoencoder.z_size)
    y_hat = autoencoder.decoder(z_rnd).reshape(1, 28, 28)
y_hat = to_numpy(y_hat)

plt.imshow(y_hat, cmap='gray')
plt.title('Decoded random latent vector')
plt.axis('off')
plt.show()

In [None]:
# Reconstruct the image encoded above by decoding its latent vector `z`.
with torch.no_grad():
    y_hat = autoencoder.decoder(z).reshape(1, 28, 28)
y_hat = to_numpy(y_hat)

plt.imshow(y_hat, cmap='gray')
plt.title('Reconstruction of the encoded sample')
plt.axis('off')
plt.show()