added broadcast option to tpu (#3814)
* added broadcast option to tpu

* add device

* moved tpu broadcast to tpu_backend

* removed Lightning dist

* decode bytes

* pep8 fix

* fix bug

* test for broadcast

* updated changelog
lezwon committed Oct 4, 2020
1 parent 093535d commit 4da240e
Showing 3 changed files with 30 additions and 6 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -29,6 +29,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added support for datamodule in learning rate finder ([#3425](https://github.com/PyTorchLightning/pytorch-lightning/pull/3425))

- Added `broadcast` to `TPUBackend` ([#3814](https://github.com/PyTorchLightning/pytorch-lightning/pull/3814))

### Changed

- Changed `LearningRateLogger` to `LearningRateMonitor` ([#3251](https://github.com/PyTorchLightning/pytorch-lightning/pull/3251))
13 changes: 12 additions & 1 deletion pytorch_lightning/accelerators/tpu_backend.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import io
import os
import re

@@ -21,6 +21,7 @@
from pytorch_lightning import _logger as log
from pytorch_lightning.accelerators.base_backend import Accelerator
from pytorch_lightning.core import LightningModule
from pytorch_lightning.distributed import LightningDistributed
from pytorch_lightning.utilities import AMPType, rank_zero_info, rank_zero_only, rank_zero_warn
from pytorch_lightning.utilities.cloud_io import atomic_save
from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -316,3 +317,13 @@ def transfer_distrib_spawn_state_on_fit_end(self, model, mp_queue, results):
                last_path = re.sub('.ckpt', '.tmp_end.ckpt', best_model_path)
                atomic_save(model.state_dict(), last_path)
                mp_queue.put(last_path)

    def broadcast(self, obj, src=0):
        buffer = io.BytesIO()
        torch.save(obj, buffer)
        data = bytearray(buffer.getbuffer())
        data_tensor = torch.tensor(data).to(xm.xla_device(), dtype=torch.float)
        data = xm.all_gather(data_tensor)
        buffer = io.BytesIO(data.cpu().byte().numpy())
        obj = torch.load(buffer)
        return obj
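The new broadcast method pickles the object into an in-memory buffer, views the bytes as a float tensor on the XLA device, all-gathers it across cores, and unpickles the result; since xm.all_gather concatenates the byte tensors in rank order and torch.load stops at the end of the first pickled payload, every core ends up with the object serialized by core 0. Below is a minimal, CPU-only sketch of that serialize/deserialize round trip (illustrative only, not part of the commit; the xm.all_gather collective is deliberately skipped):

import io

import torch

# Sketch of the byte-tensor round trip used by TPUBackend.broadcast above.
# No TPU is assumed: the xm.all_gather collective is omitted, so this only
# shows that an arbitrary picklable object survives the conversion to a
# float tensor and back (byte values 0-255 are represented exactly as floats).
obj = ("ver_0.5", "logger_name", 0)

buffer = io.BytesIO()
torch.save(obj, buffer)                      # pickle the object into memory
payload = bytearray(buffer.getbuffer())      # raw pickle bytes
data_tensor = torch.tensor(payload).to(dtype=torch.float)

# On a TPU, data_tensor would be moved to xm.xla_device() and passed through
# xm.all_gather before decoding; here we simply decode the local copy.
restored = torch.load(io.BytesIO(data_tensor.byte().numpy()))
assert restored == obj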
21 changes: 16 additions & 5 deletions tests/models/test_tpu.py
@@ -5,20 +5,17 @@

import tests.base.develop_pipelines as tpipes
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.accelerators import TPUBackend
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.base import EvalModelTemplate
from tests.base.datasets import TrialMNIST
from tests.base.develop_utils import pl_multi_process_test

try:
    import torch_xla
    import torch_xla.core.xla_model as xm
    import torch_xla.distributed.xla_multiprocessing as xmp

    SERIAL_EXEC = xmp.MpSerialExecutor()
    # TODO: The tests are aborted if the following lines are uncommented. Must be resolved with XLA team
    # device = torch_xla.core.xla_model.xla_device()
    # device_type = torch_xla.core.xla_model.xla_device_hw(device)
    # TPU_AVAILABLE = device_type == 'TPU'
except ImportError:
    TPU_AVAILABLE = False
else:
@@ -272,3 +269,17 @@ def test_result_obj_on_tpu(tmpdir):
    )

    tpipes.run_model_test(trainer_options, model, on_gpu=False, with_hpc=False)


@pytest.mark.skipif(not TPU_AVAILABLE, reason="test requires TPU machine")
@pl_multi_process_test
def test_broadcast_on_tpu():
    """Checks that an object from the master process is broadcast to the other processes correctly."""
    def test_broadcast(rank):
        trainer = Trainer(tpu_cores=8)
        backend = TPUBackend(trainer)
        obj = ("ver_0.5", "logger_name", rank)
        result = backend.broadcast(obj)
        assert result == ("ver_0.5", "logger_name", 0)

    xmp.spawn(test_broadcast, nprocs=8, start_method='fork')
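The test follows the usual spawn-and-assert pattern: each of the eight cores builds a tuple containing its own rank and then checks that, after the broadcast, it holds rank 0's tuple. For readers without a TPU, here is a hedged, CPU-only sketch of the same check using torch.distributed's gloo backend and broadcast_object_list (not part of the commit; the helper name _check_broadcast, the world size, and the rendezvous address/port are illustrative assumptions):

import os

import torch.distributed as dist
import torch.multiprocessing as mp


def _check_broadcast(rank, world_size):
    # Illustrative rendezvous settings; any free port on localhost works.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    # Each process starts with its own rank baked into the object ...
    payload = [("ver_0.5", "logger_name", rank)]
    dist.broadcast_object_list(payload, src=0)

    # ... but after the broadcast every rank holds rank 0's object.
    assert payload[0] == ("ver_0.5", "logger_name", 0)
    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(_check_broadcast, args=(4,), nprocs=4)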
