Commit 34b8356

resolve bug

tchaton committed Apr 1, 2021
1 parent 13f67ad commit 34b8356
Showing 1 changed file with 5 additions and 12 deletions.
17 changes: 5 additions & 12 deletions pytorch_lightning/plugins/training_type/tpu_spawn.py
@@ -118,11 +118,13 @@ def new_process(self, process_idx: int, trainer, mp_queue) -> None:
         self.__save_end_of_training_weights(self.lightning_module)
         self.transfer_distrib_spawn_state_on_fit_end(results)
 
+        # https://github.com/pytorch/xla/issues/1801#issuecomment-602799542
+        self.barrier("end-process")
+
+        # https://github.com/pytorch/xla/issues/2190#issuecomment-641665358
         if self.global_rank == 0:
             time.sleep(2)
-
-        self.barrier("end-process")
 
     def __save_end_of_training_weights(self, model: LightningModule) -> None:
         # when training ends on these platforms dump weights to get out of the main process
         if on_colab_kaggle():
@@ -158,16 +160,7 @@ def transfer_distrib_spawn_state_on_fit_end(self, results):
         self.mp_queue.put(results)
 
     def save(self, state_dict: Dict, path: str) -> None:
-        """
-        Saving with ``xm.save`` can be unstable and miss the rendez-vous after ``torch.save``.
-        The rendez-vous doesn't affect directly saving.
-        We can ignore the ``RuntimeError`` to reduce friction with TPUs.
-        """
-        try:
-            xm.save(state_dict, path)
-        except RuntimeError as e:
-            if "Failed to meet rendezvous" not in str(e):
-                raise e
+        xm.save(state_dict, path)
 
     def broadcast(self, obj: object, src: int = 0) -> object:
         buffer = io.BytesIO()
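Read together, the two hunks move the end-of-process barrier so it runs before the rank-0 time.sleep(2) instead of after it, and drop the try/except that used to swallow "Failed to meet rendezvous" RuntimeErrors from xm.save. A minimal sketch of how the touched methods read after this commit, reconstructed from the context and + lines above (imports, the surrounding plugin class, and the rest of new_process are elided, as they are in the diff itself):

    # Sketch reconstructed from the diff above; `results`, `self.barrier`,
    # and the enclosing class live elsewhere in tpu_spawn.py.
    def new_process(self, process_idx: int, trainer, mp_queue) -> None:
        ...
        self.__save_end_of_training_weights(self.lightning_module)
        self.transfer_distrib_spawn_state_on_fit_end(results)

        # https://github.com/pytorch/xla/issues/1801#issuecomment-602799542
        self.barrier("end-process")

        # https://github.com/pytorch/xla/issues/2190#issuecomment-641665358
        if self.global_rank == 0:
            time.sleep(2)

    def save(self, state_dict: Dict, path: str) -> None:
        # The guard that ignored rendezvous RuntimeErrors was removed
        # in this commit; failures now propagate to the caller.
        xm.save(state_dict, path)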
