Add transfer_batch_to_device hook to DataModule #3038

Merged · 5 commits · Aug 20, 2020
Changes from 4 commits
51 changes: 51 additions & 0 deletions pytorch_lightning/core/datamodule.py
@@ -18,6 +18,7 @@
from argparse import ArgumentParser, Namespace
from typing import Any, List, Optional, Tuple, Union

import torch
from torch.utils.data import DataLoader

from pytorch_lightning.utilities import parsing, rank_zero_only, rank_zero_warn
@@ -306,6 +307,56 @@ def test_dataloader(self):
return loader
"""

@abstractmethod
def transfer_batch_to_device(self, batch: Any, device: torch.device) -> Any:
"""
Override this hook if your :class:`~torch.utils.data.DataLoader` returns tensors
wrapped in a custom data structure.

The data types listed below (and any arbitrary nesting of them) are supported out of the box:

- :class:`torch.Tensor` or anything that implements `.to(...)`
- :class:`list`
- :class:`dict`
- :class:`tuple`
- :class:`torchtext.data.batch.Batch`

For anything else, you need to define how the data is moved to the target device (CPU, GPU, TPU, ...).

Example::

def transfer_batch_to_device(self, batch, device):
    if isinstance(batch, CustomBatch):
        # move all tensors in your custom data structure to the device
        batch.samples = batch.samples.to(device)
        batch.targets = batch.targets.to(device)
    else:
        batch = super().transfer_batch_to_device(batch, device)
    return batch

Args:
batch: A batch of data that needs to be transferred to a new device.
device: The target device as defined in PyTorch.

Returns:
A reference to the data on the new device.

Note:
This hook should only transfer the data, not modify it, nor should it move the data to
any device other than the one passed in as an argument (unless you know what you are doing).

Note:
This hook only runs during single-GPU training (no data-parallel). If you need multi-GPU support
for your custom batch objects, you need to define your custom
:class:`~torch.nn.parallel.DistributedDataParallel` or
:class:`~pytorch_lightning.overrides.data_parallel.LightningDistributedDataParallel` and
override :meth:`~pytorch_lightning.core.lightning.LightningModule.configure_ddp`.

See Also:
- :func:`~pytorch_lightning.utilities.apply_func.move_data_to_device`
- :func:`~pytorch_lightning.utilities.apply_func.apply_to_collection`
"""

@classmethod
def add_argparse_args(cls, parent_parser: ArgumentParser) -> ArgumentParser:
r"""Extends existing argparse by default `LightningDataModule` attributes.
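For reference, here is a minimal standalone sketch of how a user-defined DataModule could override the new hook, mirroring the docstring example above. CustomBatch and MyDataModule are hypothetical names, and the super() fallback assumes the base hook provides the default tensor/list/dict/tuple handling described in the docstring:

    from pytorch_lightning import LightningDataModule

    class CustomBatch:
        # hypothetical container produced by a custom collate_fn
        def __init__(self, samples, targets):
            self.samples = samples
            self.targets = targets

    class MyDataModule(LightningDataModule):
        def transfer_batch_to_device(self, batch, device):
            if isinstance(batch, CustomBatch):
                # move every tensor held by the custom structure
                batch.samples = batch.samples.to(device)
                batch.targets = batch.targets.to(device)
                return batch
            # fall back to the default handling for tensors, lists, dicts and tuples
            return super().transfer_batch_to_device(batch, device)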
9 changes: 4 additions & 5 deletions pytorch_lightning/trainer/trainer.py
@@ -1112,11 +1112,6 @@ def __attach_datamodule(self, model, datamodule, stage):
# If we have a datamodule, attach necessary hooks + dataloaders
if datamodule:

# If datamodule.setup('test') has not been called yet, call it
# if stage == 'test':
# if self.is_overridden('setup', datamodule) and not datamodule.has_setup_test:
# datamodule.setup('test')

# Override loader hooks
if self.is_overridden('train_dataloader', datamodule):
model.train_dataloader = datamodule.train_dataloader
@@ -1125,6 +1120,10 @@ def __attach_datamodule(self, model, datamodule, stage):
if self.is_overridden('test_dataloader', datamodule):
model.test_dataloader = datamodule.test_dataloader

# Override transfer_batch_to_device if dataset-specific to_device logic has been defined in datamodule
if self.is_overridden('transfer_batch_to_device', datamodule):
model.transfer_batch_to_device = datamodule.transfer_batch_to_device

self.datamodule = datamodule

def run_pretrain_routine(self, model: LightningModule):
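Why patching the model attribute is sufficient: the trainer's single-device transfer path delegates batch movement to the model-level hook, so once the datamodule's implementation is assigned to the model, the existing transfer code picks it up unchanged. A rough sketch of that delegation (simplified; not the actual Trainer internals):

    import torch

    def transfer_batch_to_gpu(model, batch, gpu_id):
        # simplified stand-in for the trainer's transfer path: it calls the
        # model hook, which may now be the datamodule's override
        device = torch.device('cuda', gpu_id)
        return model.transfer_batch_to_device(batch, device)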
40 changes: 39 additions & 1 deletion tests/core/test_datamodules.py
@@ -1,10 +1,11 @@
import pickle
from argparse import ArgumentParser
from unittest.mock import MagicMock

import pytest
import torch

from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning import LightningDataModule, Trainer, seed_everything
from tests.base import EvalModelTemplate
from tests.base.datamodules import TrialMNISTDataModule
from tests.base.develop_utils import reset_seed
@@ -317,3 +318,40 @@ def test_full_loop_ddp_spawn(tmpdir):
result = trainer.test(datamodule=dm)
result = result[0]
assert result['test_acc'] > 0.8


@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="test requires GPU machine")
def test_dm_transfer_batch_to_device(tmpdir):
class CustomBatch:

def __init__(self, data):
self.samples = data[0]
self.targets = data[1]

class CurrentTestDM(LightningDataModule):

hook_called = False

def transfer_batch_to_device(self, data, device):
self.hook_called = True
if isinstance(data, CustomBatch):
data.samples = data.samples.to(device)
data.targets = data.targets.to(device)
else:
data = super().transfer_batch_to_device(data, device)
return data

model = EvalModelTemplate()
dm = CurrentTestDM()
batch = CustomBatch((torch.zeros(5, 28), torch.ones(5, 1, dtype=torch.long)))

trainer = Trainer()
# running .fit() would require us to implement custom data loaders, so we mock the model reference instead
trainer.get_model = MagicMock(return_value=model)
if trainer.is_overridden('transfer_batch_to_device', dm):
model.transfer_batch_to_device = dm.transfer_batch_to_device

batch_gpu = trainer.transfer_batch_to_gpu(batch, 0)
expected = torch.device('cuda', 0)
assert dm.hook_called
assert batch_gpu.samples.device == batch_gpu.targets.device == expected
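As a usage note (not part of the diff): in a real run the wiring mocked above happens inside the trainer, so on a GPU machine passing the datamodule to fit is enough for the override to take effect, e.g.:

    trainer = Trainer(gpus=1)
    trainer.fit(model, datamodule=dm)  # batches now pass through dm.transfer_batch_to_device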