From 366267a1d426933d80f480a78734f8c268fab987 Mon Sep 17 00:00:00 2001 From: Jirka Date: Sun, 28 Jun 2020 00:20:04 +0200 Subject: [PATCH 1/3] torchtext --- pytorch_lightning/utilities/apply_func.py | 14 ++++++++++---- requirements/base.txt | 3 +-- requirements/extra.txt | 3 ++- 3 files changed, 13 insertions(+), 7 deletions(-) diff --git a/pytorch_lightning/utilities/apply_func.py b/pytorch_lightning/utilities/apply_func.py index b5ec664b2abd0..7f0052f9c608e 100644 --- a/pytorch_lightning/utilities/apply_func.py +++ b/pytorch_lightning/utilities/apply_func.py @@ -1,10 +1,14 @@ from abc import ABC from collections import Mapping, Sequence +from copy import copy from typing import Any, Callable, Union import torch -from torchtext.data import Batch -from copy import copy + +try: + from torchtext.data import Batch +except ImportError: + Batch = None def apply_to_collection(data: Any, dtype: Union[type, tuple], function: Callable, *args, **kwargs) -> Any: @@ -86,9 +90,11 @@ def move_data_to_device(batch: Any, device: torch.device): - :meth:`torch.Tensor.to` - :class:`torch.device` """ - def batch_to(data): - if isinstance(data, Batch): + if Batch is None: + raise ImportError('You want to use `torchtext` package which is not installed yet,' + ' install it with `pip install torchtext`.') + elif isinstance(data, Batch): # Shallow copy because each Batch has a reference to Dataset which contains all examples device_data = copy(data) for field in data.fields: diff --git a/requirements/base.txt b/requirements/base.txt index e045fab912919..64375637e1711 100644 --- a/requirements/base.txt +++ b/requirements/base.txt @@ -6,5 +6,4 @@ torch>=1.3 tensorboard>=1.14 future>=0.17.1 # required for builtins in setup.py # pyyaml>=3.13 -PyYAML>=5.1 # OmegaConf requirement -torchtext>=0.3.1 \ No newline at end of file +PyYAML>=5.1 # OmegaConf requirement \ No newline at end of file diff --git a/requirements/extra.txt b/requirements/extra.txt index a9d8b6bdf486d..71a9ea9ccc6d3 100644 --- a/requirements/extra.txt +++ b/requirements/extra.txt @@ -10,4 +10,5 @@ matplotlib>=3.1.1 horovod>=0.19.1 omegaconf>=2.0.0 # scipy>=0.13.3 -scikit-learn>=0.20.0 \ No newline at end of file +scikit-learn>=0.20.0 +torchtext>=0.3.1 \ No newline at end of file From 31d14b4df01b205a615c70f13d4d72c701f151c7 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Sat, 27 Jun 2020 19:32:36 -0400 Subject: [PATCH 2/3] Update pytorch_lightning/utilities/apply_func.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Adrian Wälchli --- pytorch_lightning/utilities/apply_func.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/utilities/apply_func.py b/pytorch_lightning/utilities/apply_func.py index 7f0052f9c608e..b2d758eea1bcb 100644 --- a/pytorch_lightning/utilities/apply_func.py +++ b/pytorch_lightning/utilities/apply_func.py @@ -5,10 +5,10 @@ import torch -try: +import importlib +TORCHTEXT_AVAILABLE = importlib.util.find_spec("torchtext") is not None +if TORCHTEXT_AVAILABLE: from torchtext.data import Batch -except ImportError: - Batch = None def apply_to_collection(data: Any, dtype: Union[type, tuple], function: Callable, *args, **kwargs) -> Any: From 64afcc1e8cf4a7e2824b580de280fcac00a6ade5 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Sat, 27 Jun 2020 19:43:10 -0400 Subject: [PATCH 3/3] Update apply_func.py --- pytorch_lightning/utilities/apply_func.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/utilities/apply_func.py b/pytorch_lightning/utilities/apply_func.py index b2d758eea1bcb..2003bca6fd2e1 100644 --- a/pytorch_lightning/utilities/apply_func.py +++ b/pytorch_lightning/utilities/apply_func.py @@ -91,10 +91,9 @@ def move_data_to_device(batch: Any, device: torch.device): - :class:`torch.device` """ def batch_to(data): - if Batch is None: - raise ImportError('You want to use `torchtext` package which is not installed yet,' - ' install it with `pip install torchtext`.') - elif isinstance(data, Batch): + # try to move torchtext data first + if TORCHTEXT_AVAILABLE and isinstance(data, Batch): + # Shallow copy because each Batch has a reference to Dataset which contains all examples device_data = copy(data) for field in data.fields: @@ -102,6 +101,7 @@ def batch_to(data): device_field = getattr(data, field).to(device, non_blocking=True) setattr(device_data, field, device_field) return device_data + else: + return data.to(device, non_blocking=True) - return data.to(device, non_blocking=True) return apply_to_collection(batch, dtype=(TransferableDataType, Batch), function=batch_to)