In [1]:
import os

import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split
 
from torchvision import transforms
from torchvision.datasets import FashionMNIST

import pytorch_lightning as pl
from pytorch_lightning.callbacks import TQDMProgressBar, EarlyStopping
from pytorch_lightning.loggers import TensorBoardLogger

from torchmetrics import Accuracy, F1Score, AUROC, MetricCollection

In [2]:
torch.set_float32_matmul_precision("medium")

In [3]:
class FashionMNISTDataModule(pl.LightningDataModule):
    def __init__(self, data_dir: str = "./data", batch_size: int = 128):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,))
        ])
        self.num_workers = os.cpu_count()

    def prepare_data(self):
        FashionMNIST(self.data_dir, train=True, download=True)
        FashionMNIST(self.data_dir, train=False, download=True)

    def setup(self, stage: str):
        if stage == "fit":
              data_full = FashionMNIST(self.data_dir, train=True, transform=self.transform)
              self.trainset, self.valset = random_split(data_full, [50000, 10000])

        if stage == "test":
              self.testset = FashionMNIST(self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        return DataLoader(
            self.trainset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            persistent_workers=True
        )

    def val_dataloader(self):
        return DataLoader(
            self.valset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            persistent_workers=True
        )

    def test_dataloader(self):
        return DataLoader(
            self.testset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            persistent_workers=True
        )

In [4]:
class FashionMNISTModel(pl.LightningModule):
    def __init__(self, learning_rate=1e-3):
        super().__init__()
        self.save_hyperparameters()
        
        self.conv1 = nn.Conv2d(1, 32, 3)
        self.conv2 = nn.Conv2d(32, 64, 3)
        self.conv3 = nn.Conv2d(64, 64, 3)
        self.pool = nn.MaxPool2d(2)
        self.dropout1 = nn.Dropout2d(0.25)
        self.dropout2 = nn.Dropout2d(0.5)
        
        self.fc1 = nn.Linear(1600, 128)
        self.fc2 = nn.Linear(128, 10)

        self.metrics = MetricCollection([
            F1Score(task='multiclass', num_classes=10),
            AUROC(task='multiclass', num_classes=10),
            Accuracy(task='multiclass', num_classes=10),
        ])

        self.val_metrics = self.metrics.clone(prefix='val_')
        self.test_metrics = self.metrics.clone(prefix='test_')

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.pool(F.relu(self.conv2(x)))
        x = self.dropout1(x)
        x = F.relu(self.conv3(x))
        x = self.pool(x)
        x = self.dropout2(x)
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        self.val_metrics.update(logits, y)
        self.log('val_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def on_validation_epoch_end(self):
        self.log_dict(self.val_metrics.compute(), prog_bar=True)
        self.val_metrics.reset()

    def test_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        self.test_metrics.update(logits, y)

    def on_test_epoch_end(self):
        self.log_dict(self.test_metrics.compute(), prog_bar=True)
        self.test_metrics.reset()

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.hparams.learning_rate)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            mode='min',
            factor=0.1,
            patience=3,
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "monitor": "val_loss",
                "frequency": 1
            },
        }

In [5]:
pl.seed_everything(42)

Seed set to 42


42

In [6]:
datamodule = FashionMNISTDataModule()
model = FashionMNISTModel()

trainer = pl.Trainer(
    max_epochs=25,
    accelerator='auto',
    devices=1,
    callbacks=[
        TQDMProgressBar(refresh_rate=1),
        EarlyStopping(monitor="val_loss", mode="min", patience=5)
    ],
    logger=TensorBoardLogger(save_dir='lightning_logs', name='fashion_mnist'),
    enable_progress_bar=True,
    enable_model_summary=True,
    log_every_n_steps=1
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [7]:
trainer.fit(model=model, datamodule=datamodule)


Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to ./data/FashionMNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 26.4M/26.4M [00:02<00:00, 11.3MB/s]


Extracting ./data/FashionMNIST/raw/train-images-idx3-ubyte.gz to ./data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 29.5k/29.5k [00:00<00:00, 485kB/s]


Extracting ./data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to ./data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 4.42M/4.42M [00:00<00:00, 7.12MB/s]


Extracting ./data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to ./data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 5.15k/5.15k [00:00<00:00, 17.7MB/s]
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Extracting ./data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw




   | Name         | Type             | Params | Mode 
-----------------------------------------------------------
0  | conv1        | Conv2d           | 320    | train
1  | conv2        | Conv2d           | 18.5 K | train
2  | conv3        | Conv2d           | 36.9 K | train
3  | pool         | MaxPool2d        | 0      | train
4  | dropout1     | Dropout2d        | 0      | train
5  | dropout2     | Dropout2d        | 0      | train
6  | fc1          | Linear           | 204 K  | train
7  | fc2          | Linear           | 1.3 K  | train
8  | metrics      | MetricCollection | 0      | train
9  | val_metrics  | MetricCollection | 0      | train
10 | test_metrics | MetricCollection | 0      | train
-----------------------------------------------------------
261 K     Trainable params
0         Non-trainable params
261 K     Total params
1.048     Total estimated model params size (MB)
20        Modules in train mode
0         Modules in eval mode


Epoch 24: 100%|██████████| 391/391 [00:09<00:00, 41.18it/s, v_num=0, train_loss_step=0.0241, val_loss_step=0.0926, val_loss_epoch=0.186, val_MulticlassF1Score=0.935, val_MulticlassAUROC=0.997, val_MulticlassAccuracy=0.935, train_loss_epoch=0.117]

`Trainer.fit` stopped: `max_epochs=25` reached.


Epoch 24: 100%|██████████| 391/391 [00:09<00:00, 41.08it/s, v_num=0, train_loss_step=0.0241, val_loss_step=0.0926, val_loss_epoch=0.186, val_MulticlassF1Score=0.935, val_MulticlassAUROC=0.997, val_MulticlassAccuracy=0.935, train_loss_epoch=0.117]


In [8]:
trainer.test(datamodule=datamodule)

Restoring states from the checkpoint path at lightning_logs/fashion_mnist/version_0/checkpoints/epoch=24-step=9775.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from the checkpoint at lightning_logs/fashion_mnist/version_0/checkpoints/epoch=24-step=9775.ckpt


Testing DataLoader 0: 100%|██████████| 79/79 [00:00<00:00, 92.94it/s]
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
  test_MulticlassAUROC      0.9964348077774048
 test_MulticlassAccuracy    0.9315999746322632
 test_MulticlassF1Score     0.9315999746322632
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'test_MulticlassF1Score': 0.9315999746322632,
  'test_MulticlassAUROC': 0.9964348077774048,
  'test_MulticlassAccuracy': 0.9315999746322632}]

In [10]:
%reload_ext tensorboard
%tensorboard --logdir lightning_logs/

Reusing TensorBoard on port 6006 (pid 649926), started 0:02:30 ago. (Use '!kill 649926' to kill it.)