In [55]:
import torch
import torch.nn             as nn
from torch.utils.data       import DataLoader, random_split
from torchvision            import transforms
from torchvision.datasets   import MNIST

import lightning                    as     L
from   lightning.pytorch.callbacks  import EarlyStopping
from   lightning.pytorch.callbacks  import TQDMProgressBar, ModelCheckpoint
from   lightning.pytorch.loggers    import TensorBoardLogger, CSVLogger

In [56]:

class ConvNet(L.LightningModule):
    def __init__(self):
        super(ConvNet, self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(1, 32, 3, 1),             # 28x28 -> 32x26x26
            nn.ReLU(),
            nn.MaxPool2d(2),                    # 32x26x26 -> 32x13x13
            nn.Conv2d(32, 64, 3, 1),            # 32x13x13 -> 64x11x11
            nn.ReLU(),
            nn.MaxPool2d(2),                    # 64x11x11 -> 64x5x5
            nn.Conv2d(64, 128, 3, 1),           # 64x5x5 -> 128x3x3
            nn.ReLU(),
            nn.Conv2d(128, 256, 3, 1),          # 128x3x3 -> 256x1x1
            nn.ReLU(),
            nn.Flatten(),                       # 256x1x1 -> 256
            nn.Linear(256, 128),                # 256 -> 128
            nn.ReLU(),
            nn.Linear(128, 10),                 # 128 -> 10
        )
        self.criteria = nn.CrossEntropyLoss()
        mnist_dataset = MNIST(root='.', train=True, download=True, transform=transforms.ToTensor())
        self.train_dataset, self.val_dataset = random_split(mnist_dataset, [55000, 5000])
        self.test_dataset = MNIST(root='.', train=False, download=True, transform=transforms.ToTensor())

    def forward(self, x):
        return self.model(x)

    def _common_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        return y_hat, self.criteria(y_hat, y)

    def training_step(self, batch, batch_idx):
        y_hat, loss = self._common_step(batch, batch_idx)
        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.001)
    
    def validation_step(self, batch, batch_idx):
        y_hat, loss = self._common_step(batch, batch_idx)

        self.log('val_loss', loss)
        return loss

    def test_step(self, batch, batch_idx):
        y_hat, loss = self._common_step(batch, batch_idx)
        self.log('test_loss', loss)
        return loss
    
    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=32, num_workers=10, shuffle=True)
    
    def val_dataloader(self):
        return DataLoader(self.val_dataset,   batch_size=32, num_workers=10, shuffle=False, persistent_workers=True)

    def test_dataloader(self):
        return DataLoader(self.test_dataset,  batch_size=32, num_workers=10, shuffle=False, persistent_workers=True)

In [57]:
train_dataset, val_dataset   = random_split(MNIST(root='.', train=True, transform=transforms.ToTensor()), [55000, 5000])
test_dataset                = MNIST(root='.', train=False, transform=transforms.ToTensor())

train_loader                = DataLoader(train_dataset, batch_size=64, shuffle=True,  num_workers=4, persistent_workers=True)
val_loader                  = DataLoader(val_dataset,   batch_size=64, shuffle=False, num_workers=4, persistent_workers=True)
test_loader                 = DataLoader(test_dataset,  batch_size=64, shuffle=False, num_workers=4, persistent_workers=True)
print(f"Train: {len(train_loader)} batches, Val: {len(val_loader)} batches, Test: {len(test_loader)} batches")

Train: 860 batches, Val: 79 batches, Test: 157 batches


In [58]:
model               = ConvNet()
trainer             = L.Trainer(
                            max_epochs=100, enable_progress_bar=True, devices=1,
                            callbacks=[
                                EarlyStopping(monitor="val_loss", patience=5), 
                                TQDMProgressBar(refresh_rate=10),
                                ModelCheckpoint(filename='{epoch}_{step}_{val_loss:.4f}_val_acc={val_acc:.4f}', save_top_k=3, monitor='val_loss')
                                ],
                            logger=[
                                CSVLogger(        "logs/", name='csv/'),
                                TensorBoardLogger("logs/", name='tensorboard/'),
                                ]
                            )
trainer.fit(model)

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Missing logger folder: logs/csv/

  | Name     | Type             | Params
----------------------------------------------
0 | model    | Sequential       | 422 K 
1 | criteria | CrossEntropyLoss | 0     
----------------------------------------------
422 K     Trainable params
0         Non-trainable params
422 K     Total params
1.688     Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

In [61]:
trainer.test()

Restoring states from the checkpoint path at logs/csv/version_0/checkpoints/epoch=5_step=10314_val_loss=0.0349_val_acc=val_acc=0.0000.ckpt
Loaded model weights from the checkpoint at logs/csv/version_0/checkpoints/epoch=5_step=10314_val_loss=0.0349_val_acc=val_acc=0.0000.ckpt


Testing: |          | 0/? [00:00<?, ?it/s]

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test_loss           0.03595905005931854
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'test_loss': 0.03595905005931854}]

In [66]:
print(f'Best Checkpoint: {trainer.checkpoint_callback.best_model_path}')

Best Checkpoint: logs/csv/version_0/checkpoints/epoch=5_step=10314_val_loss=0.0349_val_acc=val_acc=0.0000.ckpt


In [15]:
import pandas as pd
import plotly.express as px

df = pd.read_csv("logs/csv/version_0/metrics.csv")
train_losses_step   = df[df['train_loss_step' ].notna()].copy()
train_losses_epoch  = df[df['train_loss_epoch'].notna()].copy()
val_losses          = df[df['val_loss'        ].notna()].copy()
test_losses         = df[df['test_loss'       ].notna()].copy()

train_losses_step['Dataset']  = 'Train (Step)'
train_losses_epoch['Dataset'] = 'Train (Epoch)'
val_losses['Dataset']         = 'Val'
test_losses['Dataset']        = 'Test'

train_losses_step['Loss']    = train_losses_step['train_loss_step']
train_losses_epoch['Loss']   = train_losses_epoch['train_loss_epoch']
val_losses['Loss']           = val_losses['val_loss']
test_losses['Loss']          = test_losses['test_loss']

df = pd.concat([train_losses_step, train_losses_epoch, val_losses, test_losses])
fig = px.line(df, x='step', y='Loss', color='Dataset', title='Losses').show()