In [9]:
import torch
import torch.nn             as nn
import torch.nn.functional  as F
from torch.utils.data       import DataLoader, random_split, Dataset
from torchvision            import transforms
from torchvision.datasets   import MNIST

import lightning                    as     L
from   lightning.pytorch.callbacks  import EarlyStopping
from   lightning.pytorch.callbacks  import TQDMProgressBar, ModelCheckpoint
from   lightning.pytorch.loggers    import TensorBoardLogger, CSVLogger

import numpy            as np
import pandas           as pd

In [10]:
class ConvNet(L.LightningModule):
    def __init__(self):
        super(ConvNet, self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(1, 32, 3, 1),             # 28x28 -> 32x26x26
            nn.ReLU(),
            nn.MaxPool2d(2),                    # 32x26x26 -> 32x13x13
            nn.Conv2d(32, 64, 3, 1),            # 32x13x13 -> 64x11x11
            nn.ReLU(),
            nn.MaxPool2d(2),                    # 64x11x11 -> 64x5x5
            nn.Conv2d(64, 128, 3, 1),           # 64x5x5 -> 128x3x3
            nn.ReLU(),
            nn.Conv2d(128, 256, 3, 1),          # 128x3x3 -> 256x1x1
            nn.ReLU(),
            nn.Flatten(),                       # 256x1x1 -> 256
            nn.Linear(256, 128),                # 256 -> 128
            nn.ReLU(),
            nn.Linear(128, 10),                 # 128 -> 10
            # nn.Flatten(),                       # 64x5x5 -> 1600
            # nn.Linear(1600, 128),               # 1600 -> 128
            # nn.ReLU(),
            # nn.Linear(128, 10),                 # 128 -> 10
        )

        self.criteria = nn.CrossEntropyLoss()

    def forward(self, x):
        return self.model(x)

    def _common_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        return y_hat, self.criteria(y_hat, y)

    def training_step(self, batch, batch_idx):
        y_hat, loss = self._common_step(batch, batch_idx)
        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.001)
    
    def validation_step(self, batch, batch_idx):
        y_hat, loss = self._common_step(batch, batch_idx)
        self.log('val_loss', loss)
        # Log accuracy
        pred = torch.argmax(y_hat, dim=1)
        acc = torch.sum(pred == batch[1]).item() / len(pred)
        self.log('val_acc', acc)
        return loss

    def test_step(self, batch, batch_idx):
        y_hat, loss = self._common_step(batch, batch_idx)
        self.log('test_loss', loss)
        # Log accuracy
        pred = torch.argmax(y_hat, dim=1)
        acc = torch.sum(pred == batch[1]).item() / len(pred)
        self.log('test_acc', acc)
        return loss
    



In [11]:
train_dataset, val_dataset   = random_split(MNIST(root='.', train=True, transform=transforms.ToTensor()), [55000, 5000])
test_dataset                = MNIST(root='.', train=False, transform=transforms.ToTensor())

train_loader                = DataLoader(train_dataset, batch_size=64, shuffle=True,  num_workers=4, persistent_workers=True)
val_loader                  = DataLoader(val_dataset,   batch_size=64, shuffle=False, num_workers=4, persistent_workers=True)
test_loader                 = DataLoader(test_dataset,  batch_size=64, shuffle=False, num_workers=4, persistent_workers=True)
print(f"Train: {len(train_loader)} batches, Val: {len(val_loader)} batches, Test: {len(test_loader)} batches")

Train: 860 batches, Val: 79 batches, Test: 157 batches


In [33]:
model               = ConvNet()

trainer             = L.Trainer(
                            max_epochs=100, accelerator='mps', enable_progress_bar=True, devices=1,
                            callbacks=[
                                EarlyStopping(monitor="val_acc", mode="max", patience=20), 
                                TQDMProgressBar(refresh_rate=10),
                                ModelCheckpoint(
                                                # dirpath='logs/checkpoints/',

                                                filename='{epoch}_{step}_{val_loss:.4f}_val_acc={val_acc:.4f}', 
                                                monitor='val_acc', save_top_k=3, mode='max', save_on_train_epoch_end=False), 
                                ],
                            logger=[
                                    CSVLogger(        "logs/", name='csv/'),
                                    TensorBoardLogger("logs/", name='tensorboard/'),
                                    ]

                            
                            )
# trainer = pl.Trainer(max_epochs=2, accelerator='m
trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Missing logger folder: logs/csv/

  | Name     | Type             | Params
----------------------------------------------
0 | model    | Sequential       | 422 K 
1 | criteria | CrossEntropyLoss | 0     
----------------------------------------------
422 K     Trainable params
0         Non-trainable params
422 K     Total params
1.688     Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

In [None]:
L.

In [26]:
csvlogger.log_dir

'logs/csv/version_3'

In [5]:
trainer.test(dataloaders=test_loader)

Restoring states from the checkpoint path at lightning_logs/convnet/version_4/checkpoints/epoch=31_step=27520_val_loss=0.0480_val_acc=val_acc=0.9922.ckpt
Loaded model weights from the checkpoint at lightning_logs/convnet/version_4/checkpoints/epoch=31_step=27520_val_loss=0.0480_val_acc=val_acc=0.9922.ckpt


Testing: |          | 0/? [00:00<?, ?it/s]

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test_acc            0.9937000274658203
        test_loss          0.043489180505275726
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'test_loss': 0.043489180505275726, 'test_acc': 0.9937000274658203}]

In [6]:
# load best model
# best_model = ConvNet.load_from_checkpoint(trainer.checkpoint_callback.best_model_path)
print(trainer.checkpoint_callback.best_model_path)

lightning_logs/convnet/version_4/checkpoints/epoch=31_step=27520_val_loss=0.0480_val_acc=val_acc=0.9922.ckpt


In [7]:

for i in trainer.callback_metrics.keys():
    print(i, trainer.callback_metrics[i].cpu().numpy())

test_loss 0.04348918
test_acc 0.9937


In [8]:
print(trainer.callback_metrics['train_loss'])
# print(trainer.callback_metrics['test_loss'])
print(trainer.callback_metrics['val_loss'])

KeyError: 'train_loss'