In [1]:
import pytorch_lightning as pl
import torch 
import torch.nn as nn 

from torchmetrics import __version__ as torchmetrics_version
from pkg_resources import parse_version

from torchmetrics import Accuracy

### Lightning class

In [51]:
class MultiLayerPerceptron(pl.LightningModule):
    def __init__(self, image_shape = (1,28,28), hidden_units = (32,16)):
        super().__init__()
        self.train_acc = Accuracy(task = "multiclass", num_classes = 10)
        self.valid_acc = Accuracy(task = "multiclass", num_classes = 10)
        self.test = Accuracy(task = "multiclass", num_classes = 10)
        
        input_size = image_shape[0]*image_shape[1]*image_shape[2]
        all_layers = [nn.Flatten()]
        
        for hidden_unit in hidden_units:
            layer = nn.Linear(input_size, hidden_unit)
            all_layers.append(layer)
            all_layers.append(nn.ReLU())
            input_size = hidden_unit
        
        all_layers.append(nn.Linear(hidden_units[-1], 10))
        self.model = nn.Sequential(*all_layers)

    
    def forward(self,x):
        x = self.model(x)
        return x
    
    def training_step(self, batch, batch_idx):
        x,y = batch
        logits = self(x)
        loss = nn.functional.cross_entropy(logits,y)
        preds = torch.argmax(logits, dim = 1 )
        self.train_acc.update(preds, y)
        self.log("train_loss", loss, prog_bar = True)
        return loss
    
    def train_epoch_end(self,outs):
        self.log("train_acc", self.train_acc.compute())
        self.train_acc.reset()
        
    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = nn.functional.cross_entropy(logits, y)
        preds = torch.argmax(logits, dim=1)
        self.valid_acc.update(preds, y)
        self.log("valid_loss", loss, prog_bar=True)
        return loss
    
    def validation_epoch_end(self, outs):
        self.log("valid_acc", self.valid_acc.compute(), prog_bar=True)
        self.valid_acc.reset()

    def test_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = nn.functional.cross_entropy(logits, y)
        preds = torch.argmax(logits, dim=1)
        self.test_acc.update(preds, y)
        self.log("test_loss", loss, prog_bar=True)
        self.log("test_acc", self.test_acc.compute(), prog_bar=True)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=0.001)
        return optimizer
    
        

In [48]:
model = MultiLayerPerceptron()

In [50]:
model

MultiLayerPerceptron(
  (train_acc): MulticlassAccuracy()
  (valid_acc): MulticlassAccuracy()
  (test): MulticlassAccuracy()
  (model): Sequential(
    (0): Flatten(start_dim=1, end_dim=-1)
    (1): Linear(in_features=784, out_features=32, bias=True)
    (2): ReLU()
    (3): Linear(in_features=32, out_features=16, bias=True)
    (4): ReLU()
    (5): Linear(in_features=16, out_features=10, bias=True)
  )
)

In [26]:
x = torch.rand(10,1*28,28)

In [44]:
x = torch.rand(2,3)

In [45]:
x

tensor([[0.3420, 0.2535, 0.7935],
        [0.0587, 0.8880, 0.5211]])

In [46]:
torch.argmax(x, dim = 1)

tensor([2, 1])

In [52]:
from torch.utils.data import DataLoader
from torch.utils.data import random_split
 
from torchvision.datasets import MNIST
from torchvision import transforms