In [1]:
import os
import torch
import torch.nn.functional as F
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
import pytorch_lightning as pl
from torchmetrics.classification import Accuracy

import wandb
from pytorch_lightning.loggers import WandbLogger



Important: `wandb login` / `!wandb login` with your API key is required

In [2]:
artifacts_dir = os.getcwd() + '/artifacts'

wandb_dir = artifacts_dir + '/wandb'
os.makedirs(wandb_dir, exist_ok=True)
wandb.init(dir=wandb_dir)

[34m[1mwandb[0m: Currently logged in as: [33mgleblion1[0m ([33mdlhf[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [3]:
class MNISTDataModule(pl.LightningDataModule):
    def __init__(self, batch_size):
        super().__init__()
        self.batch_size = batch_size

    def prepare_data(self):
        MNIST(artifacts_dir, train=True, download=True)
        MNIST(artifacts_dir, train=False, download=True)

    def setup(self, stage):
        # transforms
        transform = transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
        )
        mnist_train = MNIST(artifacts_dir, train=True, transform=transform)

        # split dataset
        self.mnist_train, self.mnist_val = random_split(mnist_train, [55000, 5000])
        self.mnist_test = MNIST(artifacts_dir, train=False, transform=transform)

    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=self.batch_size)

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=self.batch_size)

    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=self.batch_size)

In [4]:
class Model(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer1 = torch.nn.Linear(28 * 28, 128)
        self.layer2 = torch.nn.Linear(128, 10)

    def forward(self, x):
        x = x.view(x.size(0), -1)
        x = self.layer1(x)
        x = F.relu(x)
        x = self.layer2(x)
        return x

class Classifier(pl.LightningModule):
    def __init__(self, model, num_classes=10):
        super().__init__()
        self.model = model
        self.train_acc = Accuracy(task="multiclass", num_classes=num_classes)
        self.valid_acc = Accuracy(task="multiclass", num_classes=num_classes)
        self.test_acc = Accuracy(task="multiclass", num_classes=num_classes)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self.model(x)
        loss = F.cross_entropy(logits, y)
        self.log('train_loss', loss)

        train_acc = self.train_acc(logits, y)
        # self.log('train_acc_step', train_acc, prog_bar=True)

        return loss

    def on_train_epoch_end(self):
        self.train_acc.reset()

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self.model(x)
        val_loss = F.cross_entropy(logits, y)
        self.log('val_loss', val_loss, prog_bar=True)
        self.valid_acc.update(logits, y)

    def on_validation_epoch_end(self):
        self.log('valid_acc_epoch', self.valid_acc.compute(), prog_bar=True)
        self.valid_acc.reset()

    def test_step(self, batch, batch_idx):
        x, y = batch
        logits = self.model(x)
        test_loss = F.cross_entropy(logits, y)
        self.log('test_loss', test_loss, prog_bar=True)
        self.test_acc.update(logits, y)
        return test_loss

    def on_test_epoch_end(self):
        self.log('test_acc_epoch', self.test_acc.compute(), prog_bar=True)
        self.test_acc.reset()

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

Train the classifier

In [5]:
from pytorch_lightning.callbacks import ModelCheckpoint

checkpoint_callback = ModelCheckpoint(
    dirpath=artifacts_dir + '/best_models'
)

In [6]:
batch_size = 32
dm = MNISTDataModule(batch_size=batch_size)
wandb_logger = WandbLogger(project='fashion-mnist-test', log_model='all')
wandb_logger.experiment.config["batch_size"] = batch_size

model = Model()
classifier = Classifier(model)
wandb_logger.watch(classifier, log='all')

accelerator = "cpu"
trainer = pl.Trainer(
    accelerator=accelerator,
    limit_train_batches=750,
    max_epochs=5, 
    logger=wandb_logger, 
    callbacks=[checkpoint_callback]
)
trainer.fit(classifier, dm)

  rank_zero_warn(
[34m[1mwandb[0m: logging graph, to disable use `wandb.watch(log_graph=False)`
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name      | Type               | Params
-------------------------------------------------
0 | model     | Model              | 101 K 
1 | train_acc | MulticlassAccuracy | 0     
2 | valid_acc | MulticlassAccuracy | 0     
3 | test_acc  | MulticlassAccuracy | 0     
-------------------------------------------------
101 K     Trainable params
0         Non-trainable params
101 K     Total params
0.407     Total estimated model params size (MB)


Sanity Checking DataLoader 0:   0%|                                                                               | 0/2 [00:00<?, ?it/s]

  rank_zero_warn(


                                                                                                                                        

  rank_zero_warn(


Epoch 4: 100%|████████████████████████| 907/907 [00:10<00:00, 85.51it/s, loss=0.0649, v_num=1ptx, val_loss=0.123, valid_acc_epoch=0.965]

`Trainer.fit` stopped: `max_epochs=5` reached.


Epoch 4: 100%|████████████████████████| 907/907 [00:10<00:00, 83.52it/s, loss=0.0649, v_num=1ptx, val_loss=0.123, valid_acc_epoch=0.965]


Test the classifier

In [7]:
trainer.test(classifier, dm)

  rank_zero_warn(


Testing DataLoader 0: 100%|██████████████████████████████████████████████████████████████████████████| 313/313 [00:02<00:00, 124.06it/s]
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
     test_acc_epoch         0.9653000235557556
        test_loss           0.11786127835512161
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'test_loss': 0.11786127835512161, 'test_acc_epoch': 0.9653000235557556}]

Finish logger

In [8]:
wandb.finish()



0,1
epoch,▁▁▁▁▁▁▁▁▂▂▂▂▂▂▂▂▄▄▄▄▄▄▄▄▅▅▅▅▅▅▅▅▇▇▇▇▇▇▇█
test_acc_epoch,▁
test_loss,▁
train_loss,█▄▄▅▃▂▄▃▄▄▂▂▂▂▃▂▃▃▁▂▁▂▂▁▂▃▁▁▁▁▂▁▂▂▁▁▁▁▁▁
trainer/global_step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇████
val_loss,█▄▂▁▁
valid_acc_epoch,▁▆▇██

0,1
epoch,5.0
test_acc_epoch,0.9653
test_loss,0.11786
train_loss,0.0148
trainer/global_step,3750.0
val_loss,0.12333
valid_acc_epoch,0.9652


Load from checkpoint

In [9]:
# classif = Classifier.load_from_checkpoint("/content/lightning_logs/version_7/checkpoints/epoch=77-step=134082.ckpt", model=Model())
# trainer.test(classifier, dm)

Continue training from a checkpoint

In [10]:
# classif = Classifier(Model())
# trainer = pl.Trainer()
# # automatically restores model, epoch, step, LR schedulers, etc...
# trainer.fit(classif, dm, ckpt_path="/content/lightning_logs/version_7/checkpoints/epoch=77-step=134082.ckpt")