diff --git a/pytorch_lightning/trainer/distrib_parts.py b/pytorch_lightning/trainer/distrib_parts.py
index 78bc22d21589d..8d19655175f8e 100644
--- a/pytorch_lightning/trainer/distrib_parts.py
+++ b/pytorch_lightning/trainer/distrib_parts.py
@@ -222,7 +222,7 @@ def tpu_train(self, tpu_core_idx, model):
         self.run_pretrain_routine(model)
 
         # when training ends on these platforms dump weights to get out of the main process
-        if self.on_colab_kaggle:
+        if self.on_colab_kaggle and not self.testing:
             rank_zero_warn('cleaning up... please do not interrupt')
             self.save_spawn_weights(model)
 
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index c104ab8b2f78d..c91b6a63a04d5 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -1014,7 +1014,7 @@ def fit(
                 xmp.spawn(self.tpu_train, args=(model,), nprocs=self.tpu_cores, start_method=start_method)
 
             # load weights if not interrupted
-            if self.on_colab_kaggle:
+            if self.on_colab_kaggle and not self.testing:
                 self.load_spawn_weights(model)
 
             self.model = model
diff --git a/tests/callbacks/test_model_checkpoint.py b/tests/callbacks/test_model_checkpoint.py
index b5cb7ca0c756e..1091a4cf3a8dd 100644
--- a/tests/callbacks/test_model_checkpoint.py
+++ b/tests/callbacks/test_model_checkpoint.py
@@ -15,7 +15,9 @@
 
 
 @pytest.mark.parametrize('save_top_k', [-1, 0, 1, 2])
 def test_model_checkpoint_with_non_string_input(tmpdir, save_top_k):
-    """ Test that None in checkpoint callback is valid and that chkp_path is set correctly """
+    """
+    Test that None in checkpoint callback is valid and that chkp_path is set correctly
+    """
     tutils.reset_seed()
     model = EvalModelTemplate()