In [None]:
from torch import nn, optim, tensor
from torch.utils.data import DataLoader, random_split

from torchmetrics.functional import accuracy

from torchvision import datasets, transforms

import pytorch_lightning as pl
import os

In [None]:
# PyTorch Lightning module: fully connected image classifier
class LinearImageClassifier(pl.LightningModule):
    """
    Fully connected (Linear + ReLU) image classifier that also owns its train/validation data.

    The network is a stack of hidden Linear layers with ReLU activations, an optional
    dropout layer (optionally fed by a residual connection) and a final Linear layer
    that produces the class logits. Images are flattened per batch in the step methods.
    """

    @staticmethod
    def _get_abs_split(dataset, train_fraction):
        """
        Returns the absolute number of items to split into (train, validation) from the input dataset

        Arguments
        ---------
        dataset : Any dataset that supports len(), e.g. a torchvision dataset.

        train_fraction : Float
            The fraction of data points that will be used for training. Must be between [0, 1].

        Returns
        -------
        train_size, test_size : Tuple (int, int)
            The absolute number of data points in train vs validation. They always sum to len(dataset).

        Raises
        ------
        ValueError
            If train_fraction is outside [0, 1].
        """
        if not 0 <= train_fraction <= 1:
            raise ValueError(f'train_fraction must be between 0 and 1, got {train_fraction}')

        # len() works for any sized dataset, not only those exposing a `.data` tensor
        size = len(dataset)
        train_size = int(size * train_fraction)
        test_size = size - train_size
        return train_size, test_size

    def __init__(
            self,
            training_data,
            linear_dims,
            train_fraction = 0.8,
            dropout_rate = 0.1,
            create_residual_connection = False,
            learning_rate = 1e-2,
            batch_size = 10):
        """
        Builds the layers and stores the data/optimisation settings.

        Arguments
        ---------
        training_data : TorchVision dataset or similar
            Full labelled dataset; it is split into train/validation subsets in setup().

        linear_dims : List[int]
            input/output dimensions of linear units in sequence.
            E.g., for a 2 layer linear network it would be
            [ input_dims, linear1_dims, linear2_dims, output_dims ]
            Any number of hidden layers is supported.

        Parameters
        ----------
        train_fraction : Float (Default = 0.8)
            Fraction of training_data used for training; the rest is used for validation.

        dropout_rate : Float (Default = 0.1)
            Can be set to a value between (0, 1] to add a dropout layer before the output layer.
            A value of 0 disables dropout.

        create_residual_connection : Bool (Default = False)
            Adds the activations of the first hidden layer to those of the last hidden layer
            right before the dropout layer. Only has an effect when dropout_rate > 0 and
            requires the first and last hidden layers to have the same size.

        learning_rate : Float (Default = 1e-2)

        batch_size : Int (Default = 10)

        Raises
        ------
        ValueError
            If linear_dims has fewer than 2 entries, or if the residual connection is
            requested between hidden layers of different sizes.
        """
        super().__init__()

        if len(linear_dims) < 2:
            raise ValueError('linear_dims must contain at least the input and output dimensions')

        self._total_hidden_layers = len(linear_dims) - 2
        self._dropout_rate = dropout_rate
        self._create_residual_connection = create_residual_connection

        residual_is_active = (
            create_residual_connection
            and dropout_rate > 0
            and self._total_hidden_layers > 0)
        if residual_is_active and linear_dims[1] != linear_dims[-2]:
            raise ValueError(
                'create_residual_connection requires the first and last hidden layers '
                f'to have the same size, got {linear_dims[1]} and {linear_dims[-2]}')

        # Layer order inside the ModuleList (kept stable so state_dict keys do not change):
        #   [ hidden_0, ..., hidden_{n-1}, (dropout), output ]
        self._layers = nn.ModuleList()
        for i in range(self._total_hidden_layers):
            self._layers.append(nn.Linear(linear_dims[i], linear_dims[i + 1]))

        # Dropout layer, if any
        if self._dropout_rate > 0:
            self._layers.append(nn.Dropout(self._dropout_rate))

        # Output layer
        self._layers.append(nn.Linear(linear_dims[-2], linear_dims[-1]))

        # Keep a handle on the unsplit dataset so that setup() can be called more than once
        self._full_data = training_data
        self._training_data = training_data
        self._validation_data = None
        self._train_fraction = train_fraction
        self._learning_rate = learning_rate
        self._batch_size = batch_size

        self._loss = nn.CrossEntropyLoss()

    def prepare_data(self):
        # Can be used to download data but do not assign anything to the model here
        # i.e., no self.something = value here because this method will run only on a single processor.
        # To initialize, transform and assign data use the setup() method instead.
        pass

    def setup(self, stage = None):
        """
        Splits the full dataset into train and validation subsets.

        This method runs on every process and may be invoked once per stage, so the split
        is only performed the first time; later calls keep the existing subsets.
        """
        if self._validation_data is not None:
            return

        split_sizes = self._get_abs_split(self._full_data, self._train_fraction)
        self._training_data, self._validation_data = random_split(self._full_data, split_sizes)

    def _make_loader(self, dataset, shuffle):
        # os.cpu_count() can return None, which DataLoader does not accept as num_workers
        return DataLoader(
            dataset,
            batch_size = self._batch_size,
            shuffle = shuffle,
            num_workers = os.cpu_count() or 0)

    def train_dataloader(self):
        # Reshuffle every epoch so SGD does not see the batches in the same order each time
        return self._make_loader(self._training_data, shuffle = True)

    def val_dataloader(self):
        return self._make_loader(self._validation_data, shuffle = False)

    def configure_optimizers(self):
        # Plain SGD over all trainable parameters
        optimizer = optim.SGD(self.parameters(), lr = self._learning_rate)
        return [ optimizer ]

    def forward(self, x):
        """
        Computes the class logits for a batch of flattened inputs.

        Arguments
        ---------
        x : Tensor whose last dimension equals linear_dims[0].

        Returns
        -------
        logits : Tensor whose last dimension equals linear_dims[-1].
        """
        activations = x
        first_hidden = None

        # Hidden layers occupy the first `_total_hidden_layers` slots of the ModuleList
        for i in range(self._total_hidden_layers):
            activations = nn.functional.relu(self._layers[i](activations))
            if first_hidden is None:
                first_hidden = activations

        if self._dropout_rate > 0:
            # The residual connection feeds into the dropout layer, so it only exists with dropout
            if self._create_residual_connection and first_hidden is not None:
                activations = activations + first_hidden
            activations = self._layers[-2](activations)

        # Get the logits from the output layer
        return self._layers[-1](activations)

    def _shared_step(self, batch):
        # Common forward/metric computation for the training and validation steps
        x, y = batch

        # flatten x : batch size * (num channels * image height * image width)
        x = x.view(x.size(0), -1)

        logits = self(x)
        loss = self._loss(logits, y)

        # Computed directly instead of through torchmetrics, whose functional `accuracy`
        # takes different required arguments depending on the installed release.
        acc = (logits.argmax(dim = 1) == y).float().mean()
        return loss, acc

    def training_step(self, batch, batch_i):
        """
        Runs one optimisation step's forward pass and logs the batch loss and accuracy.

        Returns a dict whose 'loss' entry is the tensor Lightning back-propagates.
        """
        loss, acc = self._shared_step(batch)
        self.log('train_loss', loss)
        self.log('accuracy', acc, prog_bar = True)
        return { 'loss': loss, 'accuracy': acc }

    def validation_step(self, batch, batch_i):
        """
        Evaluates one validation batch.

        The metrics are logged with self.log; Lightning aggregates values logged from
        validation_step over the epoch, so no epoch-end hook is needed for the averages.
        """
        loss, acc = self._shared_step(batch)
        self.log('val_loss', loss, prog_bar = True)
        self.log('val_accuracy', acc, prog_bar = True)
        return { 'loss': loss, 'accuracy': acc }

    def validates_epoch_end(self, val_step_outputs):
        """
        Averages loss and accuracy over a list of validation_step() outputs.

        NOTE: this name is not a Lightning hook, so the trainer never calls it; the epoch
        averages shown during training come from self.log in validation_step(). The method
        is kept only for code that calls it directly.

        Arguments
        ---------
        val_step_outputs : List[dict]
            Each dict holds 'loss' and either 'accuracy' or the older nested
            'progress_bar' -> 'accuracy' entry.

        Returns
        -------
        Dict with 'val_loss' and a nested 'progress_bar' dict holding 'val_accuracy'.
        """
        def accuracy_of(output):
            if 'accuracy' in output:
                return output['accuracy']
            return output['progress_bar']['accuracy']

        avg_val_loss = tensor([ float(o['loss']) for o in val_step_outputs ]).mean()
        avg_accuracy = tensor([ float(accuracy_of(o)) for o in val_step_outputs ]).mean()
        progress_bar = { 'val_accuracy': avg_accuracy }
        return { 'val_loss': avg_val_loss, 'progress_bar': progress_bar }


In [None]:
# MNIST training split, stored under ./data (downloaded when missing), with ToTensor as the transform.
mnist_data = datasets.MNIST(
    'data',
    train=True,
    download=True,
    transform=transforms.ToTensor(),
)

In [None]:
# Model architecture parameters
mnist_image_size = mnist_data.data.size()[1] *  mnist_data.data.size()[2]
linear1_dims = 64
linear2_dims = 64
output_dims = 10 # predicting 10 digits

In [None]:
# Seed Python, NumPy and PyTorch RNGs before anything random happens (weight initialisation here,
# the train/validation split during fit) so that Restart & Run All reproduces the same run.
pl.seed_everything(42)

# Two equally sized hidden layers, so the residual connection between them is well defined.
model = LinearImageClassifier(
    training_data = mnist_data,
    linear_dims = [ mnist_image_size, linear1_dims, linear2_dims, output_dims ],
    create_residual_connection = True )

In [None]:
# Train for two epochs; only the model is passed to fit(), which defines its own train/validation dataloaders.
trainer = pl.Trainer(max_epochs=2)
trainer.fit(model=model)