## Data Module

In [1]:
import warnings

In [2]:
# Suppress all Python warnings raised from this point on.
# NOTE(review): this is a blanket filter, so deprecation warnings from the
# libraries used below are hidden as well.
warnings.simplefilter("ignore")

In [3]:
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import MNIST
from torchvision import transforms
from pathlib import Path

In [19]:
class MNIST_DataModule(LightningDataModule):
    """LightningDataModule that downloads MNIST and serves train/val/test loaders.

    The MNIST training set is split into 55,000 training and 5,000 validation
    samples; the separate MNIST test set is used for testing.

    Args:
        datadir: Directory MNIST is downloaded to and read from. Defaults to
            ``~/mldata/mnist`` when ``None``.
        batch_size: Batch size used by all three dataloaders.
        split_seed: Seed for the train/validation split, so that every call to
            ``setup`` yields the same partition.
    """

    def __init__(self, datadir=None, batch_size=256, split_seed=42):
        super().__init__()
        if datadir is None:
            self.datadir = Path.home() / "mldata" / "mnist"
        else:
            self.datadir = Path(datadir)
        self.batch_size = batch_size
        self.split_seed = split_seed
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            # Single-channel normalisation with mean 0.1307 and std 0.3081
            # (presumably the MNIST training-set statistics).
            transforms.Normalize((0.1307,), (0.3081,))
        ])

    def prepare_data(self):
        """Download both MNIST splits. Called only once and on 1 GPU."""
        MNIST(self.datadir, train=True, download=True)
        MNIST(self.datadir, train=False, download=True)

    def setup(self, stage=None):
        """Build the datasets for the requested stage. Called on each GPU separately.

        Args:
            stage: ``"fit"`` or ``"validate"`` builds the train/val split,
                ``"test"`` builds the test set, ``None`` builds everything.
        """
        if stage in [None, "fit", "validate"]:
            # Imported here because the notebook's import cell only pulls
            # names out of torch.utils.data, not torch itself.
            import torch

            trainvalset = MNIST(self.datadir, train=True, transform=self.transform)
            # A dedicated, seeded generator keeps the partition identical across
            # repeated setup() calls (e.g. "fit" followed by "validate"), so
            # validation samples never drift into the training subset.
            split_rng = torch.Generator().manual_seed(self.split_seed)
            self.trainset, self.valset = random_split(
                trainvalset, [55000, 5000], generator=split_rng
            )
        if stage == "test" or stage is None:
            self.testset = MNIST(self.datadir, train=False, transform=self.transform)

    def train_dataloader(self):
        """Return the training loader; samples are reshuffled every epoch."""
        return DataLoader(self.trainset, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        """Return the validation loader (fixed order)."""
        return DataLoader(self.valset, batch_size=self.batch_size)

    def test_dataloader(self):
        """Return the test loader (fixed order)."""
        return DataLoader(self.testset, batch_size=self.batch_size)
    

In [20]:
# Instantiate the MNIST data module; it is handed to the trainer further down.
mnist = MNIST_DataModule()

## Model

In [21]:
import torch as t
import torch.nn.functional as F
import torchmetrics as tm
from pytorch_lightning import LightningModule
from pytorch_lightning.callbacks import Callback

In [22]:
class MNIST_Classifier(LightningModule):
    """Three-layer fully connected classifier for 28x28 images with 10 classes.

    Args:
        n_layer_1: Output width of the first linear layer.
        n_layer_2: Output width of the second linear layer.
        lr: Learning rate passed to the Adam optimizer.
    """

    def __init__(self, n_layer_1=128, n_layer_2=256, lr=1e-3):
        super().__init__()

        self.layer_1 = t.nn.Linear(28*28, n_layer_1)
        self.layer_2 = t.nn.Linear(n_layer_1, n_layer_2)
        self.layer_3 = t.nn.Linear(n_layer_2, 10)

        self.loss = t.nn.CrossEntropyLoss()
        self.lr = lr
        self.accuracy = tm.Accuracy()

        # Record the constructor arguments as hyperparameters.
        self.save_hyperparameters()

    def forward(self, x):
        """Return the logits for a batch of images.

        Args:
            x: Tensor whose first dimension is the batch; every remaining
                dimension is flattened and must total 28*28 features.

        Returns:
            Tensor of shape (batch, 10) holding unnormalised class scores.
        """
        # (b, 1, 28, 28) --> (b, 1*28*28); also accepts already-flat (b, 784)
        x = x.flatten(start_dim=1)

        x = self.layer_1(x)
        x = F.relu(x)
        x = self.layer_2(x)
        x = F.relu(x)
        logits = self.layer_3(x)

        return logits

    def training_step(self, batch, batch_idx):
        """Log the batch loss/accuracy and return the loss for backpropagation."""
        _, loss, acc = self._get_preds_loss_accuracy(batch)

        self.log("train_loss", loss)
        self.log("train_acc", acc)

        return loss

    def validation_step(self, batch, batch_idx):
        """Log the validation loss/accuracy and return the predicted labels."""
        preds, loss, acc = self._get_preds_loss_accuracy(batch)

        self.log("val_loss", loss)
        self.log("val_acc", acc)

        # Not needed by default, but I'll use this in a custom callback
        return preds

    def test_step(self, batch, batch_idx):
        """Log the test loss/accuracy for one batch."""
        _, loss, acc = self._get_preds_loss_accuracy(batch)

        self.log("test_loss", loss)
        self.log("test_acc", acc)

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters with learning rate ``self.lr``."""
        return t.optim.Adam(self.parameters(), lr=self.lr)

    def _get_preds_loss_accuracy(self, batch):
        """Run the model on one ``(x, y)`` batch.

        Returns:
            Tuple ``(preds, loss, acc)``: the argmax class per sample, the
            cross-entropy loss of the logits against ``y``, and the value
            returned by ``self.accuracy(preds, y)``.
        """
        x, y = batch
        logits = self(x)
        preds = t.argmax(logits, dim=1)
        loss = self.loss(logits, y)
        acc = self.accuracy(preds, y)
        return preds, loss, acc

In [23]:
# Seed torch's global RNG so the random weight initialisation below is
# reproducible under Restart Kernel -> Run All.
t.manual_seed(42)
model = MNIST_Classifier(n_layer_1=128, n_layer_2=256)

In [24]:
class LogPredictionsCallback(Callback):
    """Log sample validation images with true and predicted labels to W&B.

    On the first validation batch of each validation run, the leading
    ``n_examples`` images are logged under the key ``"examples"`` through
    ``pl_module.logger.experiment``.

    Args:
        n_examples: Maximum number of images taken from the first batch.
    """

    def __init__(self, n_examples=20):
        super().__init__()
        self.n_examples = n_examples

    def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0):
        """Log captioned example images for the first validation batch.

        ``dataloader_idx`` defaults to 0 so the hook also works when the
        trainer does not pass it (single validation dataloader).
        """
        # Only the first batch is logged, to keep the upload small.
        if batch_idx != 0:
            return

        n = self.n_examples
        x, y = batch
        # `outputs` comes from `LightningModule.validation_step`.
        # int(...) turns the 0-d label tensors into plain numbers so captions
        # read "5" rather than "tensor(5)".
        examples = [
            wandb.Image(x_i, caption=f"Ground Truth: {int(y_i)}\nPrediction: {int(y_pred)}")
            for x_i, y_i, y_pred in zip(x[:n], y[:n], outputs[:n])
        ]
        pl_module.logger.experiment.log({"examples": examples})

## Training Module

In [25]:
from pytorch_lightning import Trainer

In [26]:
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger
import wandb

In [27]:
# Checkpoint on the logged "val_acc" metric; mode="max" means higher is better.
checkpoint_callback = ModelCheckpoint(monitor="val_acc", mode="max")

In [28]:
# Weights & Biases logger writing to the "MNIST" project.
# log_model="all" asks the logger to upload model checkpoints as they are
# saved during training.
logger = WandbLogger(project="MNIST", log_model="all")

In [29]:
# Register the model with the W&B logger so it can track the network during training.
logger.watch(model)

In [30]:
# Build the trainer: metrics go to W&B, checkpoints follow "val_acc", example
# predictions are logged by the custom callback, and training stops after 5 epochs.
trainer = Trainer(
    logger=logger,
    callbacks=[checkpoint_callback, LogPredictionsCallback()],
    max_epochs=5
)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs


In [31]:
# Train the model using the loaders supplied by the MNIST data module.
trainer.fit(model, datamodule=mnist)


  | Name     | Type             | Params
----------------------------------------------
0 | layer_1  | Linear           | 100 K 
1 | layer_2  | Linear           | 33.0 K
2 | layer_3  | Linear           | 2.6 K 
3 | loss     | CrossEntropyLoss | 0     
4 | accuracy | Accuracy         | 0     
----------------------------------------------
136 K     Trainable params
0         Non-trainable params
136 K     Total params
0.544     Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

In [None]:
# Evaluate the model on the test split supplied by the data module.
# NOTE(review): this cell has no execution count, and the recorded run summary
# contains no test_loss/test_acc, so it was not run for the saved outputs.
trainer.test(model, datamodule=mnist)

In [32]:
# Pull the identifiers of the active W&B run from the trainer's logger and
# print them as: entity, project, run id, run name.
experiment = trainer.logger.experiment
run_id, project, entity, name = (
    experiment.id,
    experiment.project,
    experiment.entity,
    experiment.name,
)
print(entity, project, run_id, name)

avilay MNIST 3jsbobir resilient-jazz-4


In [33]:
# Close the active W&B run; the output below shows the final upload and run summary.
wandb.finish()

VBox(children=(Label(value=' 7.87MB of 7.87MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
_runtime,101.0
_timestamp,1625363129.0
_step,31.0
train_loss,0.01566
train_acc,0.99609
epoch,4.0
trainer/global_step,1074.0
val_loss,0.10381
val_acc,0.971


0,1
_runtime,▁▁▂▂▂▂▂▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▇▇▇▇▇████
_timestamp,▁▁▂▂▂▂▂▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▇▇▇▇▇████
_step,▁▁▁▂▂▂▂▃▃▃▃▃▄▄▄▄▅▅▅▅▆▆▆▆▆▇▇▇▇███
train_loss,▇█▄▄▃▄▃▅▂▃▃▃▂▂▃▂▃▂▂▂▁
train_acc,▂▁▅▅▅▅▆▅█▅▆▅▆▆▆▆▆▆▇██
epoch,▁▁▁▁▁▃▃▃▃▃▅▅▅▅▅▆▆▆▆▆▆█████
trainer/global_step,▁▁▂▂▂▂▂▃▃▃▃▄▄▄▄▅▅▅▅▅▆▆▆▆▇▇▇▇███
val_loss,█▃▁▁▁
val_acc,▁▆▇██


In [None]:
# Display the Trainer object (this cell has no execution count in the saved notebook).
trainer