Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 11 additions & 5 deletions pytorch_lightning/utilities/apply_func.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
from abc import ABC
from collections import Mapping, Sequence
from copy import copy
from typing import Any, Callable, Union

import torch
from torchtext.data import Batch
from copy import copy

import importlib
TORCHTEXT_AVAILABLE = importlib.util.find_spec("torchtext") is not None
if TORCHTEXT_AVAILABLE:
from torchtext.data import Batch


def apply_to_collection(data: Any, dtype: Union[type, tuple], function: Callable, *args, **kwargs) -> Any:
Expand Down Expand Up @@ -86,16 +90,18 @@ def move_data_to_device(batch: Any, device: torch.device):
- :meth:`torch.Tensor.to`
- :class:`torch.device`
"""

def batch_to(data):
if isinstance(data, Batch):
# try to move torchtext data first
if TORCHTEXT_AVAILABLE and isinstance(data, Batch):

# Shallow copy because each Batch has a reference to Dataset which contains all examples
device_data = copy(data)
for field in data.fields:
# Batch contains output of Field.process(...) which is tensor hence .to(...) exists
device_field = getattr(data, field).to(device, non_blocking=True)
setattr(device_data, field, device_field)
return device_data
else:
return data.to(device, non_blocking=True)

return data.to(device, non_blocking=True)
return apply_to_collection(batch, dtype=(TransferableDataType, Batch), function=batch_to)
3 changes: 1 addition & 2 deletions requirements/base.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,4 @@ torch>=1.3
tensorboard>=1.14
future>=0.17.1 # required for builtins in setup.py
# pyyaml>=3.13
PyYAML>=5.1 # OmegaConf requirement
torchtext>=0.3.1
PyYAML>=5.1 # OmegaConf requirement
3 changes: 2 additions & 1 deletion requirements/extra.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,5 @@ matplotlib>=3.1.1
horovod>=0.19.1
omegaconf>=2.0.0
# scipy>=0.13.3
scikit-learn>=0.20.0
scikit-learn>=0.20.0
torchtext>=0.3.1