# Using DALI in PyTorch Lightning

### Overview

This example shows how to use DALI in PyTorch Lightning.

Let us take [a toy example](https://pytorch-lightning.readthedocs.io/en/latest/introduction_guide.html) of a classification network and see how DALI can accelerate it.

The `DALI_EXTRA_PATH` environment variable should point to the location where the data from the DALI extra repository has been downloaded. Please make sure that the proper release tag is checked out.

In [1]:
import torch
from torch.nn import functional as F
from torch import nn
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning import Trainer
from torch.optim import Adam
from torchvision.datasets import MNIST
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split

import os

# Number of samples per training batch.
BATCH_SIZE: int = 64

Now let us implement the bare training class with the native PyTorch data loader.

In [2]:
class LitMNIST(LightningModule):
    """Three-layer fully connected MNIST classifier trained with PyTorch Lightning.

    Training data comes from torchvision's MNIST dataset through a native
    PyTorch ``DataLoader``. ``process_batch`` is a hook that turns whatever
    the data loader yields into an ``(images, labels)`` pair, so subclasses
    can plug in a different loader without touching ``training_step``.
    """

    def __init__(self):
        super().__init__()

        # MNIST images are (1, 28, 28) = (channels, height, width); they are
        # flattened to 28 * 28 features before the first layer.
        self.layer_1 = torch.nn.Linear(28 * 28, 128)
        self.layer_2 = torch.nn.Linear(128, 256)
        self.layer_3 = torch.nn.Linear(256, 10)

    def forward(self, x):
        """Return log-probabilities over the 10 output classes.

        Args:
            x: image batch whose first dimension is the batch size, e.g.
                ``(batch, 1, 28, 28)``; all remaining dimensions are
                flattened into 28 * 28 features.

        Returns:
            Tensor of shape ``(batch, 10)`` with log-softmax scores.
        """
        # (b, 1, 28, 28) -> (b, 1*28*28)
        x = x.view(x.size(0), -1)
        x = F.relu(self.layer_1(x))
        x = F.relu(self.layer_2(x))
        x = self.layer_3(x)
        return F.log_softmax(x, dim=1)

    def process_batch(self, batch):
        """Return ``batch`` unchanged; it already unpacks into ``(images, labels)``.

        Subclasses override this to adapt other batch formats.
        """
        return batch

    def training_step(self, batch, batch_idx):
        """Compute the training loss for one batch."""
        x, y = self.process_batch(batch)
        logits = self(x)
        return self.cross_entropy_loss(logits, y)

    def cross_entropy_loss(self, logits, labels):
        """Negative log-likelihood of ``labels`` under log-softmax ``logits``.

        Since ``forward`` already applies ``log_softmax``, NLL loss here is
        equivalent to cross-entropy on the raw class scores.
        """
        return F.nll_loss(logits, labels)

    def configure_optimizers(self):
        """Optimize all parameters with Adam at a learning rate of 1e-3."""
        return Adam(self.parameters(), lr=1e-3)

    def prepare_data(self):
        """Download MNIST into the working directory and build the training set."""
        # ToTensor scales pixels to [0, 1], then they are normalized with
        # mean 0.1307 and std 0.3081.
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),
        ])
        # NOTE(review): Lightning documents prepare_data as a download hook
        # that may run on only one process; assigning state here could leave
        # other DDP ranks without self.mnist_train - consider setup() instead.
        self.mnist_train = MNIST(os.getcwd(), train=True, download=True, transform=transform)

    def train_dataloader(self):
        """Wrap the training set in a DataLoader producing ``BATCH_SIZE`` samples per batch."""
        # NOTE(review): shuffle is not enabled here, unlike the DALI pipeline's
        # random_shuffle=True; confirm whether Lightning's DDP sampler
        # replacement is relied upon for shuffling.
        return DataLoader(self.mnist_train, batch_size=BATCH_SIZE, num_workers=8, pin_memory=True)


Let us see how it works:

In [3]:
# Baseline run: native PyTorch DataLoader, one GPU, DDP backend, 5 epochs.
model = LitMNIST()
trainer = Trainer(max_epochs=5, gpus=1, distributed_backend="ddp")
trainer.fit(model)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
initializing ddp: GLOBAL_RANK: 0, MEMBER: 1/1
GeForce GT 710 with CUDA capability sm_35 is not compatible with the current PyTorch installation.
The current PyTorch install supports CUDA capabilities sm_37 sm_50 sm_60 sm_70 sm_75.
If you want to use the GeForce GT 710 GPU with PyTorch, please check the instructions at https://pytorch.org/get-started/locally/


  | Name    | Type   | Params
-----------------------------------
0 | layer_1 | Linear | 100 K 
1 | layer_2 | Linear | 33 K  
2 | layer_3 | Linear | 2 K   
I1014 17:16:18.776570 139719591925568 lightning.py:1215] 
  | Name    | Type   | Params
-----------------------------------
0 | layer_1 | Linear | 100 K 
1 | layer_2 | Linear | 33 K  
2 | layer_3 | Linear | 2 K   


HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…




1

Now let us define a DALI pipeline that loads the data.

In [4]:
import nvidia.dali as dali
from nvidia.dali.pipeline import Pipeline
import nvidia.dali.ops as ops
import nvidia.dali.types as types
from nvidia.dali.plugin.pytorch import DALIClassificationIterator

# Location of the MNIST training database inside the DALI_extra checkout.
# Evaluated at import time: raises KeyError if DALI_EXTRA_PATH is not set.
data_path = os.path.join(
    os.environ['DALI_EXTRA_PATH'],
    'db/MNIST/training/',
)

class MnistPipeline(Pipeline):
    """DALI pipeline producing normalized MNIST images and INT64 labels.

    Samples are read from a Caffe2 database with random shuffling, decoded
    as grayscale, normalized and laid out as CHW float tensors.

    Args:
        batch_size: samples per batch.
        device: ``'gpu'`` to decode with the mixed backend and run the
            normalization and label cast on the GPU; ``'cpu'`` to keep all
            processing on the CPU.
        device_id: id of the GPU associated with the pipeline.
        shard_id: index of the shard read by this pipeline.
        num_shards: number of shards the dataset is split into.
        num_threads: number of CPU threads used by the pipeline.
        seed: random seed passed to the pipeline.
        path: location of the Caffe2 database; when ``None`` the
            module-level ``data_path`` is used.
    """

    def __init__(self, batch_size, device, device_id=0, shard_id=0, num_shards=1,
                 num_threads=4, seed=0, path=None):
        super(MnistPipeline, self).__init__(batch_size, num_threads, device_id, seed)
        self.device = device
        reader_path = data_path if path is None else path
        self.reader = ops.Caffe2Reader(path=reader_path, shard_id=shard_id,
                                       num_shards=num_shards, random_shuffle=True)
        self.decode = ops.ImageDecoder(
            device='mixed' if device == 'gpu' else 'cpu',
            output_type=types.GRAY)
        # Mean/std are the ToTensor() + Normalize((0.1307,), (0.3081,)) values
        # scaled by 255, assuming the decoder emits 8-bit pixels in [0, 255].
        self.cmn = ops.CropMirrorNormalize(
            device=device,
            dtype=types.FLOAT,
            std=[0.3081 * 255],
            mean=[0.1307 * 255],
            output_layout="CHW")
        self.to_int64 = ops.Cast(dtype=types.INT64, device=device)

    def define_graph(self):
        """Wire reader -> decoder -> normalization and return ``(images, labels)``."""
        # Named so the DALI iterators can refer to it via reader_name="Reader".
        inputs, labels = self.reader(name="Reader")
        images = self.cmn(self.decode(inputs))
        if self.device == "gpu":
            # Move labels next to the images so the cast below runs on the GPU.
            labels = labels.gpu()
        # PyTorch expects labels as INT64
        labels = self.to_int64(labels)
        return (images, labels)

Add DALI to the data preparation step of the training class and adjust `process_batch` to accept the data returned by `DALIClassificationIterator`. The iterator returns a list of dictionaries: each list element corresponds to one pipeline wrapped by the iterator, and the entries of each dictionary correspond to the relevant pipeline outputs. See the `DALIGenericIterator` documentation for details.

In [5]:
class DALILitMNIST(LitMNIST):
    """LitMNIST variant that feeds training data through a DALI pipeline.

    Batches arrive in the raw ``DALIClassificationIterator`` format (a list
    with one dict per pipeline), so ``process_batch`` unpacks them.
    """

    def prepare_data(self):
        """Build the sharded DALI pipeline and wrap it in a classification iterator.

        The local rank selects the GPU and the global rank selects the shard,
        so each DDP process reads its own shard of the dataset.
        """
        mnist_pipeline = MnistPipeline(
            BATCH_SIZE,
            device='cpu',
            device_id=self.local_rank,
            shard_id=self.global_rank,
            num_shards=self.trainer.world_size,
            num_threads=8,
        )
        self.train_loader = DALIClassificationIterator(mnist_pipeline, reader_name="Reader",
                                                       fill_last_batch=False, auto_reset=True)

    def train_dataloader(self):
        """Return the DALI iterator built in ``prepare_data``."""
        return self.train_loader

    def process_batch(self, batch):
        """Extract ``(images, labels)`` from the single pipeline's output dict."""
        x = batch[0]["data"]
        # Flatten labels to 1-D. reshape(-1) keeps the batch dimension even for
        # a one-sample batch (possible with fill_last_batch=False), where
        # squeeze() would produce a 0-d tensor that nll_loss rejects.
        y = batch[0]["label"].reshape(-1).cuda().long()
        return (x, y)

Now, let us run the training:

In [6]:
# Same training setup as the baseline, now with data loaded by DALI.
model = DALILitMNIST()
trainer = Trainer(max_epochs=5, gpus=1, distributed_backend="ddp")
trainer.fit(model)

GPU available: True, used: True
I1014 17:16:44.541009 139719591925568 distributed.py:49] GPU available: True, used: True
TPU available: False, using: 0 TPU cores
I1014 17:16:44.543509 139719591925568 distributed.py:49] TPU available: False, using: 0 TPU cores
Using environment variable NODE_RANK for node rank (0).
I1014 17:16:44.545432 139719591925568 distributed.py:49] Using environment variable NODE_RANK for node rank (0).
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
I1014 17:16:44.547166 139719591925568 accelerator_connector.py:333] LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type   | Params
-----------------------------------
0 | layer_1 | Linear | 100 K 
1 | layer_2 | Linear | 33 K  
2 | layer_3 | Linear | 2 K   
I1014 17:16:44.609724 139719591925568 lightning.py:1215] 
  | Name    | Type   | Params
-----------------------------------
0 | layer_1 | Linear | 100 K 
1 | layer_2 | Linear | 33 K  
2 | layer_3 | Linear | 2 K   


HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…




1

Now let us provide a custom DALI iterator wrapper so that no extra processing is needed inside `LitMNIST.process_batch`. The wrapper also lets PyTorch Lightning know how big the dataset is.

In [7]:
class BetterDALILitMNIST(LitMNIST):
    """LitMNIST variant fed by a Lightning-friendly DALI iterator.

    The iterator wrapper yields plain ``[images, labels]`` lists and defines
    ``__len__``, so the inherited ``process_batch`` works unchanged and
    Lightning knows how many batches make up an epoch.
    """

    def prepare_data(self):
        """Build the sharded DALI pipeline and its Lightning-friendly iterator."""
        mnist_pipeline = MnistPipeline(BATCH_SIZE, device='cpu',
                                       device_id=self.local_rank,
                                       shard_id=self.global_rank,
                                       num_shards=self.trainer.world_size,
                                       num_threads=8)

        class LightningWrapper(DALIClassificationIterator):
            """DALIClassificationIterator that yields ``[data, label]`` lists."""

            def __len__(self):
                # With fill_last_batch=False a trailing partial batch is still
                # returned, so round up; flooring under-counts by one batch
                # whenever size is not a multiple of batch_size.
                # NOTE(review): assumes self.size is the number of samples this
                # shard yields - confirm for the installed DALI version.
                return (self.size + self.batch_size - 1) // self.batch_size

            def __next__(self):
                out = super().__next__()
                # DALIClassificationIterator calls next during construction, so
                # the first batch was already converted to a list by this
                # method; pass it through untouched.
                if not isinstance(out[0], dict):
                    return out
                # DDP is used, so there is only one pipeline per process. Turn
                # its dict into a list ordered by output_map and flatten labels
                # to 1-D; reshape(-1) keeps the batch dimension even for a
                # one-sample batch, where squeeze() would yield a 0-d tensor.
                sample = out[0]
                return [sample[k].reshape(-1) if k == "label" else sample[k]
                        for k in self.output_map]

        self.train_loader = LightningWrapper(mnist_pipeline, reader_name="Reader",
                                             fill_last_batch=False, auto_reset=True)

    def train_dataloader(self):
        """Return the wrapped DALI iterator built in ``prepare_data``."""
        return self.train_loader

Now, let us rerun the training:

In [8]:
# Rerun training with the Lightning-friendly DALI iterator wrapper.
model = BetterDALILitMNIST()
trainer = Trainer(max_epochs=5, gpus=1, distributed_backend="ddp")
trainer.fit(model)

GPU available: True, used: True
I1014 17:17:07.718175 139719591925568 distributed.py:49] GPU available: True, used: True
TPU available: False, using: 0 TPU cores
I1014 17:17:07.720351 139719591925568 distributed.py:49] TPU available: False, using: 0 TPU cores
Using environment variable NODE_RANK for node rank (0).
I1014 17:17:07.722338 139719591925568 distributed.py:49] Using environment variable NODE_RANK for node rank (0).
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
I1014 17:17:07.724163 139719591925568 accelerator_connector.py:333] LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type   | Params
-----------------------------------
0 | layer_1 | Linear | 100 K 
1 | layer_2 | Linear | 33 K  
2 | layer_3 | Linear | 2 K   
I1014 17:17:07.770608 139719591925568 lightning.py:1215] 
  | Name    | Type   | Params
-----------------------------------
0 | layer_1 | Linear | 100 K 
1 | layer_2 | Linear | 33 K  
2 | layer_3 | Linear | 2 K   


HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…




1