## PyTorch Lightning DataModules



### Setup

Lightning is easy to install. Simply `pip install pytorch-lightning`.

### Introduction

First, we'll go over a regular `LightningModule` implementation without the use of a `LightningDataModule`.

In [1]:
import os

import pytorch_lightning as pl
import torch
import torch.nn.functional as F

from pytorch_lightning.metrics.functional import accuracy
from torch import nn
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from torchvision.datasets import MNIST, CIFAR10

In [2]:
PATH_DATASETS = os.environ.get('PATH_DATASETS', '.')
AVAIL_GPUS = min(1, torch.cuda.device_count())
BATCH_SIZE = 256 if AVAIL_GPUS else 64

### Defining the LitMNISTModel

Below, we reuse a `LightningModule` from our hello world tutorial that classifies MNIST Handwritten Digits.

Unfortunately, we have hardcoded dataset-specific items within the model, forever limiting it to working with MNIST Data. 😢

This is fine if you don't plan on training/evaluating your model on different datasets. However, in many cases, this can become bothersome when you want to try out your architecture with different datasets.

In [3]:
class LitMNIST(pl.LightningModule):
    
    def __init__(self, data_dir=PATH_DATASETS, hidden_size=64, learning_rate=2e-4):
        super().__init__()
        
        # We hardcode dataset specific stuff here.
        self.data_dir = data_dir
        self.num_classes = 10
        self.dims = (1, 28, 28)
        channels, width, height = self.dims
        self.transform = transforms.Compose([
            transforms.ToTensor(), 
            transforms.Normalize((0.1307, ), (0.3081, ))
        ])
        
        self.hidden_size = hidden_size
        self.learning_rate = learning_rate

        # Build model
        self.model = nn.Sequential(
            nn.Flatten(), 
            nn.Linear(channels * width * height, hidden_size), 
            nn.ReLU(), 
            nn.Dropout(0.1),
            nn.Linear(hidden_size, hidden_size), 
            nn.ReLU(), 
            nn.Dropout(0.1), 
            nn.Linear(hidden_size, self.num_classes)
        )
        
    def forward(self, x):
        x = self.model(x)
        return F.log_softmax(x, dim=1)
    
    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        return loss
    
    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        preds = torch.argmax(logits, dim=1)
        acc = accuracy(preds, y)
        self.log('val_loss', loss, prog_bar=True)
        self.log('val_acc', acc, prog_bar=True)
        return loss
    
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer
    
    def prepare_data(self):
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)
        
    def setup(self, stage=None):
        
        if stage == 'fit' or stage is None:
            mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
            self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])
            
        if stage == 'test' or stage is None:
            self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)
            
    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=BATCH_SIZE)
    
    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=BATCH_SIZE)
    
    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=BATCH_SIZE)
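`forward` returns log-probabilities (`F.log_softmax`), which is why the steps pair it with `F.nll_loss`. As a rough illustration, here is what that pairing computes for a single example, written with the standard library only. The logits are made-up values, and this is a sketch of the math, not PyTorch's implementation (which also averages over the batch by default).

```python
import math

def log_softmax(logits):
    """Return log-probabilities for one example (numerically stable form)."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(v - m) for v in logits))
    return [v - log_sum_exp for v in logits]

def nll_loss(log_probs, target):
    """Negative log-likelihood of the target class for one example."""
    return -log_probs[target]

logits = [2.0, 0.5, -1.0]  # hypothetical raw scores for 3 classes
log_probs = log_softmax(logits)
loss = nll_loss(log_probs, target=0)

# The log-probabilities exponentiate to a valid distribution.
total_prob = sum(math.exp(lp) for lp in log_probs)
print(loss, total_prob)
```

Applying `log_softmax` to values that are already log-probabilities leaves them unchanged, so the loss is only well matched when exactly one normalization step sits between the raw scores and `nll_loss`.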

### Train the `LitMNIST` Model

In [4]:
model = LitMNIST()

trainer = pl.Trainer(
    max_epochs=3,
    gpus=AVAIL_GPUS,
    progress_bar_refresh_rate=20
)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores


In [5]:
print(trainer)

<pytorch_lightning.trainer.trainer.Trainer object at 0x7f8162d6bb20>


In [6]:
trainer.fit(model)


  | Name  | Type       | Params
-------------------------------------
0 | model | Sequential | 55.1 K
-------------------------------------
55.1 K    Trainable params
0         Non-trainable params
55.1 K    Total params
0.220     Total estimated model params size (MB)






## Using `DataModules`

DataModules decouple data-related hooks from the `LightningModule` so you can develop dataset-agnostic models.

### Defining The `MNISTDataModule`

Let's go over each method in the class below and what it does:

- `__init__`
    - Takes in a `data_dir` arg that points to where you have downloaded/wish to download the MNIST dataset.
    - Defines a `transform` that will be applied across train, val, and test dataset splits.
    - Defines default `self.dims`, which is a tuple returned from `datamodule.size()` that can help you initialize models.

- `prepare_data`
    - This is where we can download the dataset. We point to our desired dataset and ask `torchvision`'s `MNIST` dataset class to download if the dataset isn't found there.
    - Note we do not make any state assignments in this function (i.e. `self.something = ...`).

- `setup`
    - Loads in data from file and prepares PyTorch tensor datasets for each split (train, val, test).
    - `setup` expects a `stage` argument, which separates the logic for `'fit'` and `'test'`.
    - If you don't mind loading all your datasets at once, you can add a condition so that both the `'fit'` and the `'test'` setup run whenever `None` is passed to `stage`.
    - Note this runs across all GPUs and it is safe to make state assignments here.

- `x_dataloader`
    - `train_dataloader()`, `val_dataloader()`, and `test_dataloader()` all return PyTorch `DataLoader` instances that are created by wrapping their respective datasets that we prepared in `setup()`.
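
The division of labour between these hooks can be sketched with a plain-Python stand-in. `ToyDataModule`, `toy_fit`, and the `calls` list are hypothetical names used only for illustration. The call order shown is a simplified assumption about a fit run, not Lightning's actual trainer logic, which may request the dataloaders in a different order.

```python
calls = []  # records the order in which hooks run (illustration only)

class ToyDataModule:
    """Plain-Python stand-in for a LightningDataModule (hypothetical)."""

    def prepare_data(self):
        # Download or write to disk only; nothing is assigned to self here.
        calls.append("prepare_data")

    def setup(self, stage=None):
        # State assignments are safe here.
        calls.append(f"setup({stage})")
        if stage == "fit" or stage is None:
            self.train, self.val = list(range(8)), list(range(8, 10))
        if stage == "test" or stage is None:
            self.test = list(range(10, 12))

    def train_dataloader(self):
        calls.append("train_dataloader")
        return iter(self.train)

    def val_dataloader(self):
        calls.append("val_dataloader")
        return iter(self.val)

def toy_fit(datamodule):
    """Drive the hooks roughly the way a fit run would (simplified)."""
    datamodule.prepare_data()
    datamodule.setup("fit")
    return list(datamodule.train_dataloader()), list(datamodule.val_dataloader())

dm_toy = ToyDataModule()
train_items, val_items = toy_fit(dm_toy)
print(calls)
```

Because `setup` was called with `"fit"`, the toy module never builds its test split, which mirrors the purpose of the `stage` argument.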

In [7]:
class MNISTDataModule(pl.LightningDataModule):
    
    def __init__(self, data_dir: str = PATH_DATASETS):
        super().__init__()
        self.data_dir = data_dir
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307, ), (0.3081, ))
        ])
        
        # self.dims is returned when you call dm.size()
        # Setting default dims here because we know them.
        # Could optionally be assigned dynamically in dm.setup()
        
        self.dims = (1, 28, 28)
        self.num_classes = 10
        
    def prepare_data(self):
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)
        
    def setup(self, stage=None):
        
        if stage == 'fit' or stage is None:
            mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
            self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])
            
        if stage == 'test' or stage is None:
            self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)
            
    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=BATCH_SIZE)

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=BATCH_SIZE)

    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=BATCH_SIZE)
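
To make the `transform` concrete: `ToTensor` scales 8-bit pixel values to the range [0, 1], and `Normalize` then subtracts the mean and divides by the standard deviation. The sketch below reproduces that arithmetic for single pixel values using the constants from the class above. The helper names are hypothetical, and the real transforms operate on whole tensors rather than scalars.

```python
MNIST_MEAN, MNIST_STD = 0.1307, 0.3081  # constants used in the transform

def to_unit_range(pixel):
    """Mimic the scaling done by ToTensor for one uint8 pixel: [0, 255] -> [0, 1]."""
    return pixel / 255.0

def normalize(value, mean=MNIST_MEAN, std=MNIST_STD):
    """Mimic Normalize for one value: subtract the mean, divide by the std."""
    return (value - mean) / std

black = normalize(to_unit_range(0))    # darkest possible pixel
white = normalize(to_unit_range(255))  # brightest possible pixel
print(round(black, 4), round(white, 4))
```

After normalization the inputs are roughly centered around zero, with black pixels mapping to a small negative value and white pixels to a value near 2.8.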

### Defining the dataset agnostic `LitModel`

Below, we define the same model as the `LitMNIST` model we made earlier.

However, this time our model has the freedom to use any input data that we'd like 🔥.

In [8]:
class LitModel(pl.LightningModule):
    
    def __init__(self, channels, width, height, num_classes, hidden_size=64, learning_rate=2e-4):
        
        super().__init__()
        
        self.channels = channels
        self.width = width
        self.height = height
        self.num_classes = num_classes
        self.hidden_size = hidden_size
        self.learning_rate = learning_rate
        
        self.model = nn.Sequential(
            nn.Flatten(),
            nn.Linear(channels * width * height, hidden_size),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size, num_classes)
        )
        
    def forward(self, x):
        x = self.model(x)
        return F.log_softmax(x, dim=1)
    
    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        # forward() already applies log_softmax, so use nll_loss as in LitMNIST
        loss = F.nll_loss(logits, y)
        return loss
    
    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        preds = torch.argmax(logits, dim=1)
        acc = accuracy(preds, y)
        self.log('val_loss', loss, prog_bar=True)
        self.log('val_acc', acc, prog_bar=True)
        return loss
    
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer

### Training the `LitModel` using the `MNISTDataModule`
Now, we initialize and train the `LitModel` using the `MNISTDataModule`'s configuration settings and dataloaders.

In [9]:
# Init DataModule
dm = MNISTDataModule()

In [10]:
dm

<__main__.MNISTDataModule at 0x7f814aee8580>

In [11]:
# Init model from datamodule's attributes
model = LitModel(*dm.size(), dm.num_classes)

In [12]:
model

LitModel(
  (model): Sequential(
    (0): Flatten(start_dim=1, end_dim=-1)
    (1): Linear(in_features=784, out_features=64, bias=True)
    (2): ReLU()
    (3): Dropout(p=0.1, inplace=False)
    (4): Linear(in_features=64, out_features=64, bias=True)
    (5): ReLU()
    (6): Dropout(p=0.1, inplace=False)
    (7): Linear(in_features=64, out_features=10, bias=True)
  )
)

In [13]:
# Init trainer
trainer = pl.Trainer(
    max_epochs=3,
    gpus=AVAIL_GPUS,
    progress_bar_refresh_rate=20
)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores


In [14]:
trainer

<pytorch_lightning.trainer.trainer.Trainer at 0x7f814d86e5e0>

In [15]:
trainer.fit(model, dm)


  | Name  | Type       | Params
-------------------------------------
0 | model | Sequential | 55.1 K
-------------------------------------
55.1 K    Trainable params
0         Non-trainable params
55.1 K    Total params
0.220     Total estimated model params size (MB)






### Defining the CIFAR10 DataModule
Let's show that the `LitModel` we made earlier is dataset agnostic by defining a new datamodule for the CIFAR10 dataset.

In [16]:
class CIFAR10DataModule(pl.LightningDataModule):
    
    def __init__(self, data_dir: str = './'):
        super().__init__()
        
        self.data_dir = data_dir
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        self.dims = (3, 32, 32)
        self.num_classes = 10
        
    def prepare_data(self):
        CIFAR10(self.data_dir, train=True, download=True)
        CIFAR10(self.data_dir, train=False, download=True)
        
    def setup(self, stage=None):
        
        if stage == 'fit' or stage is None:
            cifar_full = CIFAR10(self.data_dir, train=True, transform=self.transform)
            self.cifar_train, self.cifar_val = random_split(cifar_full, [45000, 5000])
            
        if stage == 'test' or stage is None:
            self.cifar_test = CIFAR10(self.data_dir, train=False, transform=self.transform)
            
    def train_dataloader(self):
        return DataLoader(self.cifar_train, batch_size=BATCH_SIZE)
    
    def val_dataloader(self):
        return DataLoader(self.cifar_val, batch_size=BATCH_SIZE)
    
    def test_dataloader(self):
        return DataLoader(self.cifar_test, batch_size=BATCH_SIZE)
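
Both datamodules hold out 5,000 training images for validation, and the lengths passed to `random_split` have to add up to the dataset size (55,000 + 5,000 for MNIST, 45,000 + 5,000 for CIFAR10). The following is a minimal standard-library sketch of what such a split does conceptually. `split_indices` and its `seed` parameter are hypothetical, and this is not how `torch.utils.data.random_split` is implemented.

```python
import random

def split_indices(n_items, lengths, seed=0):
    """Shuffle range(n_items) and cut it into consecutive chunks of the given lengths."""
    if sum(lengths) != n_items:
        raise ValueError("split lengths must sum to the dataset size")
    indices = list(range(n_items))
    random.Random(seed).shuffle(indices)  # seeded for a reproducible split
    chunks, start = [], 0
    for length in lengths:
        chunks.append(indices[start:start + length])
        start += length
    return chunks

# CIFAR10 ships 50,000 training images.
train_idx, val_idx = split_indices(50_000, [45_000, 5_000])
print(len(train_idx), len(val_idx))
```

Fixing the seed keeps the validation set identical across runs, which makes validation metrics comparable between experiments.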

### Training the `LitModel` using the `CIFAR10DataModule`

Our `LitModel` has no problem taking its input data from a different datamodule.

In [21]:
dm = CIFAR10DataModule()
dm

<__main__.CIFAR10DataModule at 0x7f814aef5f10>

In [24]:
model = LitModel(channels=3, width=32, height=32, num_classes=dm.num_classes, hidden_size=256)
model

LitModel(
  (model): Sequential(
    (0): Flatten(start_dim=1, end_dim=-1)
    (1): Linear(in_features=3072, out_features=256, bias=True)
    (2): ReLU()
    (3): Dropout(p=0.1, inplace=False)
    (4): Linear(in_features=256, out_features=256, bias=True)
    (5): ReLU()
    (6): Dropout(p=0.1, inplace=False)
    (7): Linear(in_features=256, out_features=10, bias=True)
  )
)

In [25]:
trainer = pl.Trainer(
    max_epochs=5,
    gpus=AVAIL_GPUS,
    progress_bar_refresh_rate=20
)
trainer.fit(model, dm)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores


Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./cifar-10-python.tar.gz



Extracting ./cifar-10-python.tar.gz to ./
Files already downloaded and verified



  | Name  | Type       | Params
-------------------------------------
0 | model | Sequential | 855 K 
-------------------------------------
855 K     Trainable params
0         Non-trainable params
855 K     Total params
3.420     Total estimated model params size (MB)
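
The parameter counts in the two model summaries follow directly from the layer shapes: each `nn.Linear` contributes `in_features * out_features` weights plus `out_features` biases. A small sketch that reproduces the totals is below. `mlp_param_count` is a hypothetical helper, and the size estimate assumes 4 bytes per float32 parameter.

```python
def mlp_param_count(in_features, hidden_size, num_classes):
    """Weights plus biases for the three Linear layers in LitModel."""
    layer_shapes = [
        (in_features, hidden_size),
        (hidden_size, hidden_size),
        (hidden_size, num_classes),
    ]
    return sum(n_in * n_out + n_out for n_in, n_out in layer_shapes)

# MNIST: 1x28x28 inputs, hidden_size=64. CIFAR10: 3x32x32 inputs, hidden_size=256.
mnist_params = mlp_param_count(1 * 28 * 28, hidden_size=64, num_classes=10)
cifar_params = mlp_param_count(3 * 32 * 32, hidden_size=256, num_classes=10)

# Size estimate assuming 4 bytes per float32 parameter.
mnist_mb = mnist_params * 4 / 1e6
cifar_mb = cifar_params * 4 / 1e6
print(mnist_params, cifar_params, mnist_mb, cifar_mb)
```

The totals come to 55,050 and 855,050 parameters, matching the 55.1 K and 855 K shown in the summaries. Most of the growth comes from the first layer, whose input width jumps from 784 to 3,072.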






In [26]:
# Start tensorboard.
%load_ext tensorboard
%tensorboard --logdir lightning_logs/

Reusing TensorBoard on port 6006 (pid 34745), started 2:30:31 ago. (Use '!kill 34745' to kill it.)