In [31]:
import torch
from pytorch_lightning import LightningModule, Trainer,LightningDataModule
from pytorch_lightning.metrics.functional import accuracy, confusion_matrix
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from torchvision.datasets import MNIST
import os
from datetime import datetime
from pl_bolts.callbacks import PrintTableMetricsCallback

# Root directory for dataset files; the PATH_DATASETS environment variable
# overrides the current-directory default.
PATH_DATASETS = os.environ.get('PATH_DATASETS', '.')
# Clamp the visible CUDA device count to at most one: 1 if any device, else 0.
AVAIL_GPUS = min(torch.cuda.device_count(), 1)
# Use the larger batch only when a GPU was detected.
BATCH_SIZE = 64 if not AVAIL_GPUS else 256

In [32]:
class LitMNIST(LightningModule):
    """Three-layer fully connected classifier sized from the input dimensions.

    The network flattens each input, passes it through two hidden
    ``Linear -> ReLU -> Dropout(0.1)`` stages and a final ``Linear`` layer,
    and ``forward`` returns log-probabilities over ``num_classes``.

    Args:
        channels: Number of input channels.
        width: Input width.
        height: Input height.
        num_classes: Number of output classes.
        hidden_size: Width of both hidden layers.
        learning_rate: Learning rate handed to the Adam optimizer.
    """

    def __init__(self, channels, width, height, num_classes, hidden_size=64, learning_rate=2e-4):

        super().__init__()

        # We take in input dimensions as parameters and use those to dynamically build model.
        self.channels = channels
        self.width = width
        self.height = height
        self.num_classes = num_classes
        self.hidden_size = hidden_size
        self.learning_rate = learning_rate

        # Define PyTorch model. Layer order is kept fixed so that the
        # positional sub-module names inside the Sequential do not change.
        in_features = channels * width * height
        self.model = nn.Sequential(
            nn.Flatten(),
            nn.Linear(in_features, hidden_size),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size, self.num_classes),
        )

    def forward(self, x):
        """Run the network on ``x`` and return log-softmax scores over dim 1."""
        return F.log_softmax(self.model(x), dim=1)

    def training_step(self, batch, batch_idx):
        """Return the negative log-likelihood loss for one ``(x, y)`` batch."""
        x, y = batch
        return F.nll_loss(self(x), y)

    def _shared_eval_step(self, batch, stage):
        """Compute loss and accuracy for one batch and log them as ``<stage>_*``.

        Args:
            batch: An ``(x, y)`` pair of inputs and target labels.
            stage: Prefix for the logged metric names, e.g. ``'val'`` or ``'test'``.

        Returns:
            The negative log-likelihood loss for the batch.
        """
        x, y = batch
        log_probs = self(x)
        loss = F.nll_loss(log_probs, y)
        preds = torch.argmax(log_probs, dim=1)
        acc = accuracy(preds, y)

        # Calling self.log will surface up scalars for you in TensorBoard
        self.log(f'{stage}_loss', loss, prog_bar=True)
        self.log(f'{stage}_acc', acc, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Evaluate one validation batch; logs ``val_loss`` and ``val_acc``."""
        return self._shared_eval_step(batch, 'val')

    def test_step(self, batch, batch_idx):
        """Evaluate one test batch; logs ``test_loss`` and ``test_acc``.

        Shares its computation with validation but uses a ``test_`` prefix so
        that test results are not reported under validation metric names.
        """
        return self._shared_eval_step(batch, 'test')

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters at ``self.learning_rate``."""
        return torch.optim.Adam(self.parameters(), lr=self.learning_rate)



In [33]:
class MNISTDataModule(LightningDataModule):
    """Data module that downloads MNIST and serves train/val/test dataloaders.

    Args:
        data_dir: Directory passed to ``MNIST`` as the dataset root.
            Defaults to ``PATH_DATASETS``.
    """

    def __init__(self, data_dir: str = PATH_DATASETS):
        super().__init__()
        self.data_dir = data_dir
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307, ), (0.3081, )),
        ])
        # self.dims is returned when you call dm.size()
        # Setting default dims here because we know them.
        # Could optionally be assigned dynamically in dm.setup()
        self.dims = (1, 28, 28)
        self.num_classes = 10

    def prepare_data(self):
        """Download both the train and test portions into ``self.data_dir``."""
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        """Build the datasets needed for ``stage``.

        Args:
            stage: ``'fit'`` builds the train/val split, ``'test'`` builds the
                test set, and ``None`` builds both.
        """
        # Assign train/val datasets for use in dataloaders
        if stage == 'fit' or stage is None:
            mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
            # Randomly split the training data into 55000 train / 5000 val samples.
            self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])

        # Assign test dataset for use in dataloader(s)
        if stage == 'test' or stage is None:
            self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        """Return a ``DataLoader`` over the training split."""
        return DataLoader(self.mnist_train, batch_size=BATCH_SIZE)

    def val_dataloader(self):
        """Return a ``DataLoader`` over the validation split."""
        return DataLoader(self.mnist_val, batch_size=BATCH_SIZE)

    def test_dataloader(self):
        """Return a ``DataLoader`` over the test set."""
        return DataLoader(self.mnist_test, batch_size=BATCH_SIZE)

In [34]:
# Init DataModule
dm = MNISTDataModule('MNIST')
# Init model from datamodule's attributes
model = LitMNIST(*dm.size(), dm.num_classes)
# Init trainer. Use the GPU count detected in AVAIL_GPUS instead of a
# hard-coded 1 so the same cell also runs on a CPU-only machine.
trainer = Trainer(
    max_epochs=5,
    progress_bar_refresh_rate=20,
    gpus=AVAIL_GPUS,
    callbacks=[PrintTableMetricsCallback()]
)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs


In [35]:
# Fit the model on the datamodule, printing the wall-clock start time
# beforehand and the elapsed time afterwards.
start = datetime.now()
print('Training started at', start)
trainer.fit(model, dm)
elapsed = datetime.now() - start
print('Training duration:', elapsed)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type       | Params
-------------------------------------
0 | model | Sequential | 55.1 K
-------------------------------------
55.1 K    Trainable params
0         Non-trainable params
55.1 K    Total params
0.220     Total estimated model params size (MB)


Training started at 2021-08-08 13:20:58.720823


Validation sanity check: 0it [00:00, ?it/s]

  rank_zero_warn(
val_loss│val_acc
────────────────
2.3002967834472656│0.09375
  rank_zero_warn(


Training: -1it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

val_loss│val_acc
────────────────
2.3002967834472656│0.09375
0.42443785071372986│0.885200023651123
val_loss│val_acc
────────────────
2.3002967834472656│0.09375
0.42443785071372986│0.885200023651123
0.42443785071372986│0.885200023651123


Validating: 0it [00:00, ?it/s]

val_loss│val_acc
────────────────
2.3002967834472656│0.09375
0.42443785071372986│0.885200023651123
0.42443785071372986│0.885200023651123
0.30732351541519165│0.9057999849319458
val_loss│val_acc
────────────────
2.3002967834472656│0.09375
0.42443785071372986│0.885200023651123
0.42443785071372986│0.885200023651123
0.30732351541519165│0.9057999849319458
0.30732351541519165│0.9057999849319458


Validating: 0it [00:00, ?it/s]

val_loss│val_acc
────────────────
2.3002967834472656│0.09375
0.42443785071372986│0.885200023651123
0.42443785071372986│0.885200023651123
0.30732351541519165│0.9057999849319458
0.30732351541519165│0.9057999849319458
0.26309934258461│0.9187999963760376
val_loss│val_acc
────────────────
2.3002967834472656│0.09375
0.42443785071372986│0.885200023651123
0.42443785071372986│0.885200023651123
0.30732351541519165│0.9057999849319458
0.30732351541519165│0.9057999849319458
0.26309934258461│0.9187999963760376
0.26309934258461│0.9187999963760376


Validating: 0it [00:00, ?it/s]

val_loss│val_acc
────────────────
2.3002967834472656│0.09375
0.42443785071372986│0.885200023651123
0.42443785071372986│0.885200023651123
0.30732351541519165│0.9057999849319458
0.30732351541519165│0.9057999849319458
0.26309934258461│0.9187999963760376
0.26309934258461│0.9187999963760376
0.2321128100156784│0.9291999936103821
val_loss│val_acc
────────────────
2.3002967834472656│0.09375
0.42443785071372986│0.885200023651123
0.42443785071372986│0.885200023651123
0.30732351541519165│0.9057999849319458
0.30732351541519165│0.9057999849319458
0.26309934258461│0.9187999963760376
0.26309934258461│0.9187999963760376
0.2321128100156784│0.9291999936103821
0.2321128100156784│0.9291999936103821


Validating: 0it [00:00, ?it/s]

val_loss│val_acc
────────────────
2.3002967834472656│0.09375
0.42443785071372986│0.885200023651123
0.42443785071372986│0.885200023651123
0.30732351541519165│0.9057999849319458
0.30732351541519165│0.9057999849319458
0.26309934258461│0.9187999963760376
0.26309934258461│0.9187999963760376
0.2321128100156784│0.9291999936103821
0.2321128100156784│0.9291999936103821
0.21075111627578735│0.9354000091552734
val_loss│val_acc
────────────────
2.3002967834472656│0.09375
0.42443785071372986│0.885200023651123
0.42443785071372986│0.885200023651123
0.30732351541519165│0.9057999849319458
0.30732351541519165│0.9057999849319458
0.26309934258461│0.9187999963760376
0.26309934258461│0.9187999963760376
0.2321128100156784│0.9291999936103821
0.2321128100156784│0.9291999936103821
0.21075111627578735│0.9354000091552734
0.21075111627578735│0.9354000091552734


Training duration: 0:00:33.635890


In [36]:
# Load the TensorBoard notebook extension, then launch TensorBoard inline
# with its log directory set to lightning_logs/.
%load_ext tensorboard
%tensorboard --logdir lightning_logs/

The tensorboard extension is already loaded. To reload it, use:
  %reload_ext tensorboard


Launching TensorBoard...