use parallel loader
lezwon committed on May 31, 2020
commit ed6e758 (1 parent: fdbbe96)
Showing 2 changed files with 6 additions and 4 deletions.
pytorch_lightning/trainer/evaluation_loop.py (5 changes: 3 additions & 2 deletions)
@@ -135,6 +135,7 @@
 from pytorch_lightning.utilities import rank_zero_warn
 
 try:
+    import torch_xla
     import torch_xla.distributed.parallel_loader as xla_pl
     import torch_xla.core.xla_model as xm
 except ImportError:
@@ -249,8 +250,8 @@ def _evaluate(self, model: LightningModule, dataloaders, max_batches: int, test_
             dl_outputs = []
 
             # on TPU we have to wrap it under the ParallelLoader
-            if self.use_tpu and self.tpu_id is None:
-                device = xm.xla_device()
+            if self.use_tpu:
+                device = torch_xla._XLAC._xla_get_default_device()
                 dataloader = xla_pl.ParallelLoader(dataloader, [device])
                 dataloader = dataloader.per_device_loader(device)

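Note (not part of the commit): a minimal, standalone sketch of the ParallelLoader pattern applied in both loops. The toy dataset, batch size, and variable names below are illustrative assumptions; only the ParallelLoader and per_device_loader calls mirror the diff.

import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as xla_pl

# Illustrative stand-in for a real DataLoader (assumption, not repository code).
dataset = torch.utils.data.TensorDataset(torch.randn(64, 4), torch.randn(64, 1))
loader = torch.utils.data.DataLoader(dataset, batch_size=8)

device = xm.xla_device()  # the current process's XLA device

# ParallelLoader preloads batches onto the listed devices in background threads;
# per_device_loader exposes the per-device iterator consumed by the loop.
para_loader = xla_pl.ParallelLoader(loader, [device])
for x, y in para_loader.per_device_loader(device):
    pass  # batches arrive already placed on the XLA device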
pytorch_lightning/trainer/training_loop.py (5 changes: 3 additions & 2 deletions)
@@ -167,6 +167,7 @@ def training_step(self, batch, batch_idx):
 APEX_AVAILABLE = True
 
 try:
+    import torch_xla
     import torch_xla.distributed.parallel_loader as xla_pl
     import torch_xla.core.xla_model as xm
 except ImportError:
@@ -412,8 +413,8 @@ def run_training_epoch(self):
         train_dataloader = self.train_dataloader
 
         # on TPU we have to wrap it under the ParallelLoader
-        if self.use_tpu and self.tpu_id is None:
-            device = xm.xla_device()
+        if self.use_tpu:
+            device = torch_xla._XLAC._xla_get_default_device()
             train_dataloader = xla_pl.ParallelLoader(train_dataloader, [device])
             train_dataloader = train_dataloader.per_device_loader(device)

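Note (not part of the commit): the device lookup also changes from the public helper xm.xla_device() to the private binding torch_xla._XLAC._xla_get_default_device(). A hedged sketch of the difference, with return types stated as assumptions since _XLAC is an internal module:

import torch_xla
import torch_xla.core.xla_model as xm

# Public helper: returns a torch.device for this process's XLA device.
device_public = xm.xla_device()

# Private binding used by the new code: assumed to return the default XLA
# device identifier (e.g. a string like 'xla:0') without any ordinal handling.
device_internal = torch_xla._XLAC._xla_get_default_device()

print(device_public, device_internal)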
