In [1]:
import torch  
import torch.nn as nn   
import pytorch_lightning as pl 

from torchmetrics import Accuracy 
from torch.utils.data import DataLoader 
from torch.utils.data import random_split 
from torchvision.datasets import MNIST 
from torchvision import transforms
from pytorch_lightning.loggers import TensorBoardLogger

In [2]:
class MultiLayerPerceptron(pl.LightningModule): 

    def __init__(self,image_shape=(1, 28, 28), hidden_units=(32, 16)): 
        super().__init__()      
        self.train_acc = Accuracy(task="multiclass", num_classes=10) 
        self.valid_acc = Accuracy(task="multiclass", num_classes=10) 
        self.test_acc = Accuracy(task="multiclass", num_classes=10)     

        input_size = image_shape[0] * image_shape[1] * image_shape[2] 
        all_layers = [nn.Flatten()] 

        for hidden_unit in hidden_units:  
            layer = nn.Linear(input_size, hidden_unit)  
            all_layers.append(layer)  
            all_layers.append(nn.ReLU())  
            input_size = hidden_unit 

        all_layers.append(nn.Linear(hidden_units[-1], 10))  
        all_layers.append(nn.Softmax(dim=1))  
        self.model = nn.Sequential(*all_layers)  


    def forward(self, x): 
        x = self.model(x) 

        return x 
    
    
    def training_step(self, batch, batch_idx): 
        x, y = batch 
        logits = self(x) 
        loss = nn.functional.cross_entropy(self(x), y) 
        preds = torch.argmax(logits, dim=1) 
        self.train_acc.update(preds, y) 
        self.log("train_loss", loss, prog_bar=True) 

        return loss 
    
    
    def on_train_epoch_end(self): 
        self.log("train_acc", self.train_acc.compute())   


    def validation_step(self, batch, batch_idx): 
        x, y = batch 
        logits = self(x) 
        loss = nn.functional.cross_entropy(self(x), y) 
        preds = torch.argmax(logits, dim=1) 
        self.valid_acc.update(preds, y) 
        self.log("valid_loss", loss, prog_bar=True) 
        self.log("valid_acc", self.valid_acc.compute(), prog_bar=True) 

        return loss  
    

    def test_step(self, batch, batch_idx): 
        x, y = batch 
        logits = self(x) 
        loss = nn.functional.cross_entropy(self(x), y) 
        preds = torch.argmax(logits, dim=1) 
        self.test_acc.update(preds, y) 
        self.log("test_loss", loss, prog_bar=True) 
        self.log("test_acc", self.test_acc.compute(), prog_bar=True) 

        return loss  
    
    
    def configure_optimizers(self): 
        optimizer = torch.optim.Adam(self.parameters(), lr=0.001) 
        
        return optimizer

In [3]:
class MnistDataModule(pl.LightningDataModule): 

    def __init__(self, data_path='./data'): 
        super().__init__() 
        self.data_path = data_path 
        self.transform = transforms.Compose([transforms.ToTensor()])  


    def prepare_data(self): 
        MNIST(root=self.data_path, download=True)   


    def setup(self, stage=None): 
        # stage is either 'fit', 'validate', 'test', or 'predict' 
        mnist_all = MNIST(  
            root=self.data_path, 
            train=True, 
            transform=self.transform,   
            download=False 
        )   
        self.train, self.val = random_split( 
            mnist_all, [55000, 5000], generator=torch.Generator().manual_seed(1) 
        )  
        self.test = MNIST(  
            root=self.data_path, 
            train=False, 
            transform=self.transform,   
            download=False 
        )   


    def train_dataloader(self): 
        return DataLoader(self.train, batch_size=64, num_workers=4)  
    

    def val_dataloader(self): 
        return DataLoader(self.val, batch_size=64, num_workers=4)  
    
    
    def test_dataloader(self): 
        return DataLoader(self.test, batch_size=64, num_workers=4)

In [4]:
torch.manual_seed(1)  
mnist_dm = MnistDataModule()

mnistclassifier = MultiLayerPerceptron()  
logger = TensorBoardLogger("lightning_logs/")

if torch.cuda.is_available():
    trainer = pl.Trainer(max_epochs=10, devices=1, accelerator='gpu', logger=logger)

elif torch.backends.mps.is_available():
    trainer = pl.Trainer(max_epochs=10, devices=1, accelerator='mps', logger=logger)

else: 
    trainer = pl.Trainer(max_epochs=10, logger=logger)
      
trainer.fit(model=mnistclassifier, datamodule=mnist_dm)

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

  | Name      | Type               | Params | Mode 
---------------------------------------------------------
0 | train_acc | MulticlassAccuracy | 0      | train
1 | valid_acc | MulticlassAccuracy | 0      | train
2 | test_acc  | MulticlassAccuracy | 0      | train
3 | model     | Sequential         | 25.8 K | train
---------------------------------------------------------
25.8 K    Trainable params
0         Non-trainable params
25.8 K    Total params
0.103     Total estimated model params size (MB)
11        Modules in train mode
0         Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/Users/nickdinapoli/github/pytorch-playground/.conda/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:419: Consider setting `persistent_workers=True` in 'val_dataloader' to speed up the dataloader worker initialization.


                                                                           

/Users/nickdinapoli/github/pytorch-playground/.conda/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:419: Consider setting `persistent_workers=True` in 'train_dataloader' to speed up the dataloader worker initialization.


Epoch 9: 100%|██████████| 860/860 [00:09<00:00, 89.43it/s, v_num=0, train_loss=1.590, valid_loss=1.520, valid_acc=0.927] 

`Trainer.fit` stopped: `max_epochs=10` reached.


Epoch 9: 100%|██████████| 860/860 [00:09<00:00, 89.36it/s, v_num=0, train_loss=1.590, valid_loss=1.520, valid_acc=0.927]


In [5]:
# Run in terminal or here: 
!tensorboard --logdir lightning_logs/

TensorFlow installation not found - running with reduced feature set.
Serving TensorBoard on localhost; to expose to the network, use a proxy or pass --bind_all
TensorBoard 2.18.0 at http://localhost:6006/ (Press CTRL+C to quit)
^C


In [7]:
%load_ext tensorboard
%tensorboard --logdir lightning_logs/

The tensorboard extension is already loaded. To reload it, use:
  %reload_ext tensorboard


Reusing TensorBoard on port 6006 (pid 28314), started 0:00:06 ago. (Use '!kill 28314' to kill it.)