In [7]:
import pytorch_lightning as pl
import torch
import torch.nn as nn
from torchmetrics import Accuracy
from torch.utils.data import DataLoader
from torch.utils.data import random_split
from torchvision.datasets import MNIST
from torchvision import transforms

In [8]:
class MultiLayerPerceptron(pl.LightningModule):
    def __init__(self, image_shape=(1,28,28),hidden_units=(32,16)):
        super().__init__()
        #new PL attributes
        self.train_acc = Accuracy()
        self.valid_acc = Accuracy()
        self.test_acc = Accuracy()

        #Modelling using all layers
        input_size = image_shape[0]*image_shape[1]*image_shape[2]
        all_layers = [nn.Flatten()]
        for hidden_unit in hidden_units:
            layer = nn.Linear(input_size,hidden_unit)
            all_layers.append(layer)
            all_layers.append(nn.ReLU())
            input_size = hidden_unit
        all_layers.append(nn.Linear(hidden_units[-1],10))
        all_layers.append(nn.Softmax(dim=1))
        self.model = nn.Sequential(*all_layers)

    def forward(self,x):
        x = self.model(x)
        return x

    def training_step(self,batch,batch_idx):
        x, y = batch
        logits = self(x)
        loss = nn.functional.cross_entropy(self(x), y)
        preds = torch.argmax(logits, dim=1)
        self.train_acc.update(preds, y)
        self.log("train loss",loss,prog_bar=True)
        return loss

    def training_epoch_end(self,outs):
        self.log("train_acc",self.train_acc.compute())

    def validation_step(self,batch,batch_idx):
        x, y = batch
        logits = self(x)
        loss = nn.functional.cross_entropy(self(x), y)
        preds = torch.argmax(logits, dim=1)
        self.valid_acc.update(preds, y)
        self.log("valid_loss",loss,prog_bar=True)
        self.log("valid_acc",self.valid_acc.compute(),prog_bar=True)
        return loss

    def test_step(self,batch,batch_idx):
        x, y = batch
        logits = self(x)
        loss = nn.functional.cross_entropy(self(x), y)
        preds = torch.argmax(logits, dim=1)
        self.test_acc.update(preds, y)
        self.log("test_loss",loss,prog_bar=True)
        self.log("test_acc",self.test_acc.compute(),prog_bar=True)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(),lr=0.001)
        return optimizer

In [9]:
class MnistDataModule(pl.LightningDataModule):
    def __init__(self,data_path="../data/"):
        super().__init__()
        self.data_path = data_path
        self.transform = transforms.Compose([transforms.ToTensor()])

    def prepare_data(self):
        MNIST(root=self.data_path,download=False)

    def setup(self, stage=None):
        #state is either fit, validate, test or predict
        mnist_all = MNIST(
            root=self.data_path,
            train=True,
            transform=self.transform,
            download=False
        )
        self.train, self.val = random_split(
            mnist_all, [55000, 5000], generator=torch.Generator().manual_seed(1)
        )

        self.test = MNIST(
            root=self.data_path,
            train=False,
            transform=self.transform,
            download=False
        )

    def train_dataloader(self):
        return DataLoader(self.train, batch_size=64,num_workers=8)

    def val_dataloader(self):
        return DataLoader(self.val, batch_size=64,num_workers=6)

    def test_dataloader(self):
        return DataLoader(self.test, batch_size=64)


In [10]:
torch.manual_seed(1)
mnist_dm = MnistDataModule()

In [11]:
mnistclassifier = MultiLayerPerceptron()
if torch.cuda.is_available():
    trainer = pl.Trainer(max_epochs=10,gpus=1)
else:
    trainer = pl.Trainer(max_epochs=10)

trainer.fit(model=mnistclassifier,datamodule=mnist_dm)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name      | Type       | Params
-----------------------------------------
0 | train_acc | Accuracy   | 0     
1 | valid_acc | Accuracy   | 0     
2 | test_acc  | Accuracy   | 0     
3 | model     | Sequential | 25.8 K
-----------------------------------------
25.8 K    Trainable params
0         Non-trainable params
25.8 K    Total params
0.103     Total estimated model params size (MB)


Epoch 9: 100%|██████████| 939/939 [01:14<00:00, 12.59it/s, loss=1.5, v_num=3, train loss=1.600, valid_loss=1.520, valid_acc=0.927]  


In [12]:
trainer.test(model=mnistclassifier,datamodule=mnist_dm)

Testing DataLoader 0: 100%|██████████| 157/157 [00:00<00:00, 185.70it/s]
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test_acc            0.9388871788978577
        test_loss           1.5170652866363525
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'test_loss': 1.5170652866363525, 'test_acc': 0.9388871788978577}]