### Same task as in `AutoencoderMNIST.ipynb`, this time implemented with PyTorch Lightning

In [None]:
import torch
from torchvision import datasets
from torchvision import transforms
import matplotlib.pyplot as plt
import numpy as np
import pytorch_lightning as pl

In [None]:
# MNIST training images; ToTensor converts each image to a float tensor in [0, 1].
dataset = datasets.MNIST(root="./data",
                         train=True,
                         download=True,
                         transform=transforms.ToTensor())
# For quick experiments, uncomment to keep only a random ~4% subset of the data:
# n_subset = int(0.04 * len(dataset))
# dataset, _ = torch.utils.data.random_split(dataset, (n_subset, len(dataset) - n_subset))

# 70% train / 15% test / 15% validation. The last split takes the remainder so
# the three lengths always add up to len(dataset), which random_split requires.
n_total = len(dataset)
n_train = int(0.7 * n_total)
n_test = int(0.15 * n_total)
n_valid = n_total - n_train - n_test
# Fixed seed: the same images land in the same subset on every run, so test
# and validation results are comparable across kernel restarts.
train_set, test_set, valid_set = torch.utils.data.random_split(
    dataset,
    (n_train, n_test, n_valid),
    generator=torch.Generator().manual_seed(42),
)

# Only the training data is shuffled (a new batch order each epoch); the
# evaluation loaders keep a fixed order.
train_loader = torch.utils.data.DataLoader(dataset=train_set, batch_size=32, shuffle=True)
validation_loader = torch.utils.data.DataLoader(dataset=valid_set, batch_size=32)
test_loader = torch.utils.data.DataLoader(dataset=test_set, batch_size=32)

In [None]:
# Network sizes for MNIST_AE.
Nfeatures = 28 * 28              # flattened 28x28 image = encoder input / decoder output width
Layers = [128, 64, 36, 18]       # hidden widths, listed from the input side inwards
NTargets = 10                    # bottleneck width: encoder output and decoder input

In [None]:
class MNIST_AE(pl.LightningModule):
    """Fully connected autoencoder for flattened MNIST images.

    Encoder: Nfeatures -> Layers[0] -> Layers[1] -> Layers[2] -> Layers[3] -> Ntargets.
    Decoder mirrors it: Ntargets -> Layers[3] -> Layers[2] -> Layers[1] -> Layers[0] -> Nfeatures.
    All hidden layers use ReLU. The bottleneck (encoder output) and the final
    reconstruction are linear. Training minimises the mean-squared error
    between the reconstruction and the flattened input images.

    Args:
        Nfeatures: Number of input features per sample (28*28 for MNIST).
        Layers: Hidden-layer widths from the input side inwards; only the first
            four entries are used.
        Ntargets: Width of the bottleneck code.
        lr: Adam learning rate (defaults to the previously hard-coded 1e-2).
        weight_decay: Adam weight decay (defaults to the previously hard-coded 1e-8).
    """

    ### Model ###
    def __init__(self, Nfeatures, Layers, Ntargets, lr=1e-2, weight_decay=1e-8):
        super().__init__()  # must run before any submodule is assigned
        # Encoder. Attribute names are kept stable because they are the state_dict keys.
        self.encoderIn = torch.nn.Linear(Nfeatures, Layers[0])
        self.encoderl1 = torch.nn.Linear(Layers[0], Layers[1])
        self.encoderl2 = torch.nn.Linear(Layers[1], Layers[2])
        self.encoderl3 = torch.nn.Linear(Layers[2], Layers[3])
        self.encoderOut = torch.nn.Linear(Layers[3], Ntargets)

        # Decoder, mirroring the encoder.
        self.decoderIn = torch.nn.Linear(Ntargets, Layers[3])
        self.decoderl1 = torch.nn.Linear(Layers[3], Layers[2])
        self.decoderl2 = torch.nn.Linear(Layers[2], Layers[1])
        self.decoderl3 = torch.nn.Linear(Layers[1], Layers[0])
        self.decoderOut = torch.nn.Linear(Layers[0], Nfeatures)

        self.mse_loss = torch.nn.MSELoss(reduction='mean')
        self.lr = lr
        self.weight_decay = weight_decay
        # Per-batch validation losses and batch sizes, accumulated by
        # validation_step and consumed/cleared by on_validation_epoch_end.
        self.validation_step_outputs = []
        self._validation_batch_sizes = []

    def encode(self, x):
        """Map inputs whose last dimension is Nfeatures to Ntargets-wide codes."""
        x = torch.relu(self.encoderIn(x))
        x = torch.relu(self.encoderl1(x))
        x = torch.relu(self.encoderl2(x))
        x = torch.relu(self.encoderl3(x))
        return self.encoderOut(x)  # linear bottleneck, no activation

    def decode(self, z):
        """Map Ntargets-wide codes back to Nfeatures-wide reconstructions."""
        z = torch.relu(self.decoderIn(z))
        z = torch.relu(self.decoderl1(z))
        z = torch.relu(self.decoderl2(z))
        z = torch.relu(self.decoderl3(z))
        return self.decoderOut(z)  # linear output, no activation

    def forward(self, x):
        """Return the reconstruction of ``x`` (encode followed by decode)."""
        return self.decode(self.encode(x))

    ### The Optimizer ###
    def configure_optimizers(self):
        """Adam over all parameters, using the lr/weight_decay given at construction."""
        return torch.optim.Adam(self.parameters(),
                                lr=self.lr,
                                weight_decay=self.weight_decay)

    ### Shared step ###
    def _reconstruction_loss(self, batch):
        """Flatten the batch's images and return the MSE of their reconstruction."""
        images, _labels = batch  # labels are unused by an autoencoder
        # Flatten each sample to one row; the width comes from the data itself.
        images = images.reshape(images.size(0), -1)
        reconstruction = self(images)
        return self.mse_loss(reconstruction, images)

    ### Training ###
    def training_step(self, batch, batch_idx):
        """Compute and log the training reconstruction loss for one batch."""
        loss = self._reconstruction_loss(batch)
        self.log("train_loss", loss)
        return loss

    def test_step(self, batch, batch_idx):
        """Compute and log the test reconstruction loss for one batch."""
        loss = self._reconstruction_loss(batch)
        self.log("test_loss", loss)

    ### Validation ###
    def validation_step(self, batch, batch_idx):
        """Record the validation reconstruction loss and batch size for one batch."""
        loss = self._reconstruction_loss(batch)
        self.validation_step_outputs.append(loss.detach())
        self._validation_batch_sizes.append(batch[0].size(0))
        return {'val_loss': loss}

    def on_validation_epoch_end(self):
        """Log the per-sample mean validation loss over the epoch and reset the buffers."""
        if not self.validation_step_outputs:
            # No validation batches were run: nothing to average.
            self._validation_batch_sizes.clear()
            return
        losses = torch.stack(self.validation_step_outputs)
        sizes = torch.tensor(self._validation_batch_sizes,
                             dtype=losses.dtype, device=losses.device)
        # Each batch loss is a mean over its own samples, so weight by batch
        # size; otherwise a smaller final batch would be over-represented.
        avg_loss = (losses * sizes).sum() / sizes.sum()
        self.log("validation_epoch_average", avg_loss)
        self.validation_step_outputs.clear()
        self._validation_batch_sizes.clear()

In [None]:
# Build the autoencoder from the sizes defined above.
model = MNIST_AE(Nfeatures=Nfeatures, Layers=Layers, Ntargets=NTargets)

In [None]:
# Number of entries (weight/bias tensors) in the model's state dict.
len(model.state_dict().keys())

In [None]:
# Lightning's summary of the model's submodules, expanded up to depth 3.
model_summary = pl.utilities.model_summary.ModelSummary(model, max_depth=3)
model_summary

In [None]:
# Print every key of the model's state dict, one per line.
for key in model.state_dict().keys():
    print(key)

In [None]:
%load_ext tensorboard
# The TensorBoard notebook extension is only loaded here; the `%tensorboard`
# magic is never invoked in this cell.
# NOTE(review): to view the curves, run `%tensorboard --logdir lightning_logs`
# (assumes Lightning's default log directory; confirm).

# Fit for 7 epochs. validation_loader is evaluated during fitting, which drives
# the model's validation hooks.
trainer = pl.Trainer(max_epochs = 7)
trainer.fit(model, train_loader, validation_loader)

In [None]:
def plot_digits(*args):
    """Plot rows of 28x28 images side by side for visual comparison.

    Each positional argument becomes one row of the figure (for example the
    original images, then their reconstructions). Element ``j`` of every
    argument is drawn in column ``j``. The number of columns is the length of
    the shortest argument. Each element must be a torch tensor whose values
    reshape to ``(-1, 28, 28)``; only the first 28x28 slice is shown.

    Args:
        *args: Batches of torch tensors, one batch per row.
    """
    if not args:
        return
    n_rows = len(args)
    n_cols = min(batch.shape[0] for batch in args)
    if n_cols == 0:
        return

    plt.figure(figsize=(2 * n_cols, 2 * n_rows))  # 2x2 inches per image
    for row, batch in enumerate(args):
        for col in range(n_cols):
            # detach() drops autograd history; cpu() is required before
            # numpy() when the tensor lives on a GPU.
            img = batch[col].reshape(-1, 28, 28).detach().cpu().numpy()
            ax = plt.subplot(n_rows, n_cols, row * n_cols + col + 1)  # nrows, ncols, 1-based index
            # Per-image grayscale colormap, instead of plt.gray(), which also
            # changes the global default for every later plot.
            ax.imshow(img[0], cmap="gray")
            ax.get_xaxis().set_visible(False)
            ax.get_yaxis().set_visible(False)

    plt.show()

# Reconstruct the first five images of one test batch and plot the originals
# (top row) above their reconstructions (bottom row).
xbatch, ybatch = next(iter(test_loader))
xbatch = xbatch.reshape(-1, 28 * 28)
model.eval()  # this model has no dropout/batch-norm, but make inference mode explicit
with torch.no_grad():
    # One batched forward pass (same output shape as stacking per-sample calls).
    # The input is moved to the model's device and the result back to the CPU
    # for plotting.
    a = model(xbatch[:5].to(model.device)).cpu()

plot_digits(xbatch[:5], a)

In [None]:
# Evaluate the trained model on the validation and test splits, in that order.
# Both runs go through test_step, so each result's metric is keyed "test_loss".
val_result, test_result = (
    trainer.test(model, dataloaders=loader, verbose=False)
    for loader in (validation_loader, test_loader)
)