In [35]:
import os

import torch
from pytorch_lightning import LightningDataModule, LightningModule, Trainer
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split
from torchmetrics import Accuracy
from torchvision import transforms
from torchvision.datasets import MNIST

PATH_DATASETS = os.environ.get("PATH_DATASETS", ".")
AVAIL_GPUS = min(1, torch.cuda.device_count())
BATCH_SIZE = 256 if AVAIL_GPUS else 64

In [34]:
class MNISTModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.l1 = torch.nn.Linear(28 * 28, 10)

    def forward(self, x):
        return torch.relu(self.l1(x.view(x.size(0), -1)))

    def training_step(self, batch, batch_nb):
        x, y = batch
        loss = F.cross_entropy(self(x), y)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.02)

By using the `Trainer` you automatically get:
1. Tensorboard logging
2. Model checkpointing
3. Training and validation loop
4. early-stopping

In [None]:
# Init our model
mnist_model = MNISTModel()

# Init DataLoader from MNIST Dataset
train_ds = MNIST(PATH_DATASETS, train=True, download=True, transform=transforms.ToTensor())
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE)

# Initialize a trainer
trainer = Trainer(
    gpus=AVAIL_GPUS,
    max_epochs=3,
    progress_bar_refresh_rate=20,
)

# Train the model ⚡
trainer.fit(mnist_model, train_loader)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to /__w/1/s/.datasets/MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting /__w/1/s/.datasets/MNIST/raw/train-images-idx3-ubyte.gz to /__w/1/s/.datasets/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to /__w/1/s/.datasets/MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting /__w/1/s/.datasets/MNIST/raw/train-labels-idx1-ubyte.gz to /__w/1/s/.datasets/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to /__w/1/s/.datasets/MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting /__w/1/s/.datasets/MNIST/raw/t10k-images-idx3-ubyte.gz to /__w/1/s/.datasets/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to /__w/1/s/.datasets/MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting /__w/1/s/.datasets/MNIST/raw/t10k-labels-idx1-ubyte.gz to /__w/1/s/.datasets/MNIST/raw

Processing...


  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)
  rank_zero_deprecation(
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


Done!


Missing logger folder: /__w/1/s/lightning_logs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name | Type   | Params
--------------------------------
0 | l1   | Linear | 7.9 K 
--------------------------------
7.9 K     Trainable params
0         Non-trainable params
7.9 K     Total params
0.031     Total estimated model params size (MB)
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

## MNIST Lightning Module Example


---

### Note what the following built-in functions are doing:

1. [prepare_data()](https://pytorch-lightning.readthedocs.io/en/stable/common/lightning_module.html#prepare-data) 💾
    - This is where we can download the dataset. We point to our desired dataset and ask torchvision's `MNIST` dataset class to download if the dataset isn't found there.
    - **Note we do not make any state assignments in this function** (i.e. `self.something = ...`)

2. [setup(stage)](https://pytorch-lightning.readthedocs.io/en/stable/common/lightning_module.html#setup) ⚙️
    - Loads in data from file and prepares PyTorch tensor datasets for each split (train, val, test).
    - Setup expects a 'stage' arg which is used to separate logic for 'fit' and 'test'.
    - If you don't mind loading all your datasets at once, you can set up a condition to allow for both 'fit' related setup and 'test' related setup to run whenever `None` is passed to `stage` (or ignore it altogether and exclude any conditionals).
    - **Note this runs across all GPUs and it *is* safe to make state assignments here**

3. [x_dataloader()](https://pytorch-lightning.readthedocs.io/en/stable/api_references.html#core-api) ♻️
    - `train_dataloader()`, `val_dataloader()`, and `test_dataloader()` all return PyTorch `DataLoader` instances that are created by wrapping their respective datasets that we prepared in `setup()`

In [54]:
class MNISTDataModule(LightningDataModule):
    def __init__(
        self,
        batch_size =  256,
        data_dir=PATH_DATASETS
    ):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        
        self.transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize((0.5,), (0.5,)),
            ]
        )

    def prepare_data(self):
        # download
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage=None):

        # Assign train/val datasets for use in dataloaders
        if stage == "fit" or stage is None:
            mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
            self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])

        # Assign test dataset for use in dataloader(s)
        if stage == "test" or stage is None:
            self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=self.batch_size)

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=self.batch_size)

    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=self.batch_size)

In [61]:
class MNISTClassifier(LightningModule):
    def __init__(self, data_dir=PATH_DATASETS, hidden_size=64, learning_rate=2e-4):

        super().__init__()

        # Set our init args as class attributes
        self.data_dir = data_dir
        self.hidden_size = hidden_size
        self.learning_rate = learning_rate

        # Hardcode some dataset specific attributes
        self.num_classes = 10
        self.dims = (1, 28, 28)
        channels, width, height = self.dims


        # Define PyTorch model
        self.model = nn.Sequential(
            nn.Flatten(),
            nn.Linear(channels * width * height, hidden_size),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size, self.num_classes),
        )

        self.accuracy = Accuracy()

    def forward(self, x):
        x = self.model(x)
        return F.log_softmax(x, dim=1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        return {'loss' : loss}
    
    def training_epoch_end(self, outputs):

        mean_loss = torch.stack([x["loss"] for x in outputs]).mean()

        self.logger.experiment.add_scalar(
            "Training/loss", mean_loss, self.current_epoch
        )
        
    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        preds = torch.argmax(logits, dim=1)
        acc = self.accuracy(preds, y)

        # Calling self.log will surface up scalars for you in TensorBoard
        # self.log("val_loss", loss, prog_bar=True)
        # self.log("val_acc", self.accuracy, prog_bar=True)
        return {'loss' : loss, 'accuracy' : acc}
    
    def validation_epoch_end(self, outputs):

        mean_loss = torch.stack([x["loss"] for x in outputs]).mean()
        mean_acc = torch.stack([x["accuracy"] for x in outputs]).mean()
        

        self.logger.experiment.add_scalar(
            "Validation/loss", mean_loss, self.current_epoch
        )
        self.logger.experiment.add_scalar(
            "Validation/accuracy", mean_acc, self.current_epoch
        )
        
    def test_step(self, batch, batch_idx):
        # Here we just reuse the validation_step for testing
        return self.validation_step(batch, batch_idx)
    
    def test_epoch_end(self, outputs):
        mean_loss = torch.stack([x["loss"] for x in outputs]).mean()
        mean_acc = torch.stack([x["accuracy"] for x in outputs]).mean()
        # print(mean_acc, mean_loss)
        self.log("val_loss", mean_loss, prog_bar=True)
        self.log("val_acc", mean_acc, prog_bar=True)
        # return {'loss' : mean_loss, 'accuracy' : mean_acc}
    
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer




In [62]:
model = MNISTClassifier()
data_module = MNISTDataModule()
trainer = Trainer(
    gpus=AVAIL_GPUS,
    max_epochs=3,
    progress_bar_refresh_rate=20,
)
trainer.fit(model,data_module)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name     | Type       | Params
----------------------------------------
0 | model    | Sequential | 55.1 K
1 | accuracy | Accuracy   | 0     
----------------------------------------
55.1 K    Trainable params
0         Non-trainable params
55.1 K    Total params
0.220     Total estimated model params size (MB)


Epoch 0:  94%|█████████▎| 220/235 [00:07<00:00, 29.02it/s, loss=0.746, v_num=9]
Validating: 0it [00:00, ?it/s][A
Validating:   0%|          | 0/20 [00:00<?, ?it/s][A
Epoch 0: 100%|██████████| 235/235 [00:08<00:00, 28.50it/s, loss=0.693, v_num=9]
Epoch 1:  94%|█████████▎| 220/235 [00:07<00:00, 28.48it/s, loss=0.496, v_num=9]
Validating: 0it [00:00, ?it/s][A
Validating:   0%|          | 0/20 [00:00<?, ?it/s][A
Epoch 1: 100%|██████████| 235/235 [00:08<00:00, 27.90it/s, loss=0.474, v_num=9]
Epoch 2:  94%|█████████▎| 220/235 [00:07<00:00, 29.29it/s, loss=0.426, v_num=9]
Validating: 0it [00:00, ?it/s][A
Validating:   0%|          | 0/20 [00:00<?, ?it/s][A
Epoch 2: 100%|██████████| 235/235 [00:08<00:00, 28.82it/s, loss=0.404, v_num=9]
Epoch 2: 100%|██████████| 235/235 [00:08<00:00, 28.78it/s, loss=0.404, v_num=9]


### Testing

To test a model, call `trainer.test(model)`.

Or, if you've just trained a model, you can just call `trainer.test()` and Lightning will automatically
test using the best saved checkpoint (conditioned on val_loss).

In [63]:
trainer.test(model,data_module)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: 100%|██████████| 40/40 [00:01<00:00, 30.74it/s]--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'val_acc': 0.9122070670127869, 'val_loss': 0.31087276339530945}
--------------------------------------------------------------------------------
Testing: 100%|██████████| 40/40 [00:01<00:00, 30.19it/s]


[{'val_loss': 0.31087276339530945, 'val_acc': 0.9122070670127869}]

In Colab, you can use the TensorBoard magic function to view the logs that Lightning has created for you!

In [65]:
# Start tensorboard.
%load_ext tensorboard
%tensorboard --host "0.0.0.0" --port 1997 --logdir lightning_logs/ 

The tensorboard extension is already loaded. To reload it, use:
  %reload_ext tensorboard


Reusing TensorBoard on port 1997 (pid 11662), started 0:07:17 ago. (Use '!kill 11662' to kill it.)