
[bug-fix] Trainer.test points to latest best_model_path #5161

Merged
merged 33 commits into master from bugfix/5091_resume_from_checkpoint_test on Jan 5, 2021

Conversation

tchaton
Contributor

@tchaton tchaton commented Dec 16, 2020

What does this PR do?

Fixes #5091 #5318 #5288

The test for PipeRPC was failing; this PR also resolves it.

Before submitting

  • Was this discussed/approved via a GitHub issue? (no need for typos and docs improvements)
  • Did you read the contributor guideline, Pull Request section?
  • Did you make sure your PR does only one thing, instead of bundling different changes together? Otherwise, we ask you to create a separate PR for every change.
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests?
  • Did you verify new and existing tests pass locally with your changes?
  • If you made a notable change (that affects users), did you update the CHANGELOG?

PR review

Anyone in the community is free to review the PR once the tests have passed.
Before you start reviewing, make sure you have read the Review guidelines. In short, see the following bullet list:

  • Is this pull request ready for review? (if not, please submit in draft mode)
  • Check that all items from Before submitting are resolved
  • Make sure the title is self-explanatory and the description concisely explains the PR
  • Add labels and milestones (and optionally projects) to the PR so it can be classified; bug fixes should be included in bug-fix release milestones (m.f.X) and features should be included in (m.X.b) releases.

Did you have fun?

Make sure you had fun coding 🙃

@tchaton tchaton self-assigned this Dec 16, 2020
@tchaton tchaton added this to the 1.1.x milestone Dec 16, 2020
@tchaton tchaton added the checkpointing Related to checkpointing label Dec 16, 2020
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# Running special tests
set -e
Member

what does this do?

Contributor Author

I noticed special_tests.sh doesn't return an error exit code when a test fails; adding set -e should address that.

}

def on_load_checkpoint(self, checkpointed_state: Dict[str, Any]):
    self.best_model_score = checkpointed_state["best_model_score"]
    self.best_model_path = checkpointed_state["best_model_path"]
    self.dirpath = checkpointed_state.get("dirpath", self.dirpath)
Contributor

what if dirpath is changed when Trainer is reinitialized with a new checkpoint callback??

Contributor

also, I don't see it being used anywhere.

Contributor Author

Great question! I had a chat with Adrian; it is more of a design choice. Say we want to fine-tune a model from a given checkpoint: it would make sense for the new checkpoint to be saved in the same folder. Happy to brainstorm on this one.
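For illustration only, here is a minimal sketch of the scenario being discussed. The paths, the callback configuration, and the behaviour comments are assumptions based on the diff above, not code from this PR:

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

# A fine-tuning run that resumes from an earlier training run (placeholder paths).
checkpoint_callback = ModelCheckpoint(dirpath="finetune_run/checkpoints")
trainer = Trainer(
    max_epochs=3,
    callbacks=[checkpoint_callback],
    resume_from_checkpoint="pretrain_run/checkpoints/epoch=9.ckpt",
)
# With the on_load_checkpoint change shown above, restoring the checkpoint would set
# checkpoint_callback.dirpath to the dirpath stored in the old checkpoint (if present),
# so checkpoints created while fine-tuning land in the original folder
# ("pretrain_run/checkpoints") rather than in "finetune_run/checkpoints".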

Contributor

@rohitgr7 rohitgr7 Dec 28, 2020

I think there are more things we need to restore, like save_top_k, best_k_models, etc., so I'd suggest keeping that for another PR. We also need to handle the case where the new checkpoint instance has a different dirpath when fine-tuning. There is an open issue for that.

Comment on lines 56 to 67
def resolve_resume_from_checkpoint(self):
    if not self._trainer_has_checkpoint_callbacks():
        return self.trainer.resume_from_checkpoint
    checkpoint_callbacks = self.trainer.checkpoint_callbacks[0]
    if os.path.exists(checkpoint_callbacks.best_model_path):
        resume_from_checkpoint_options = [
            checkpoint_callbacks.best_model_path,
            self.trainer.resume_from_checkpoint
        ]
        resume_from_checkpoint_options.sort()
        return resume_from_checkpoint_options[-1]
    return self.trainer.resume_from_checkpoint
Contributor

@rohitgr7 rohitgr7 Dec 23, 2020

I don't think this is correct. If I set test(ckpt_path=some_checkpoint_path), it will possibly reload best_model_path again during restore. Why does it even reload from resume_from_checkpoint while testing?

I think to better resolve this issue, this should be fixed/refactored properly:
https://github.com/PyTorchLightning/pytorch-lightning/blob/176735097ab5be9ee21d3e7a3dedc174f3e0dd3f/pytorch_lightning/accelerators/gpu_accelerator.py#L61-L69
setup_training is called for both training and testing, and it contains a few things required in both cases and others required only during training (e.g. restoring the checkpoint).

Contributor Author

Hey @rohitgr7.

Here is the problem I was trying to resolve:

  1. Load a checkpoint from resume_from_checkpoint.
  2. Fine-tune the model for several epochs, which might create a new best_model_path checkpoint.
  3. When calling trainer.test(), it should use the new best_model_path and not resume_from_checkpoint (see the sketch below).
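For illustration, here is a minimal, hypothetical sketch of that flow. TinyModel, the random data, and the checkpoint path are placeholders rather than code from this PR, and it uses the 1.1-era resume_from_checkpoint argument:

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from pytorch_lightning import LightningModule, Trainer


class TinyModel(LightningModule):  # placeholder model, for illustration only
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return F.cross_entropy(self(x), y)

    def test_step(self, batch, batch_idx):
        x, y = batch
        self.log("test_loss", F.cross_entropy(self(x), y))

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


def loader():
    # random data, just to make the sketch self-contained
    return DataLoader(TensorDataset(torch.randn(64, 32), torch.randint(0, 2, (64,))), batch_size=8)


model = TinyModel()
trainer = Trainer(
    max_epochs=3,
    resume_from_checkpoint="pretrain_run/checkpoints/epoch=9.ckpt",  # placeholder path (step 1)
)
trainer.fit(model, loader())             # fine-tuning may produce a new best_model_path (step 2)
trainer.test(test_dataloaders=loader())  # should load the new best_model_path, not the resumed one (step 3)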

If I understand correctly, you are saying we should skip this restore from resume_from_checkpoint in .test and load directly from best_model_path?

Best,
T.C

Contributor

Yes, the best checkpoint or any other checkpoint will be loaded here: https://github.com/PyTorchLightning/pytorch-lightning/blob/d1e97a4f114a285349e31e330c7bf8937bc1ee04/pytorch_lightning/trainer/trainer.py#L770-L785, so we can just skip the .restore call. Also, a few more hooks like on_pretrain_routine_start and on_pretrain_routine_end are called while doing .test, since setup_training is called for both train and test, which is incorrect too IMO. So I'd suggest splitting/refactoring setup_training a bit.

@codecov

codecov bot commented Dec 23, 2020

Codecov Report

Merging #5161 (16ccc66) into master (062800a) will increase coverage by 0%.
The diff coverage is 100%.

@@          Coverage Diff           @@
##           master   #5161   +/-   ##
======================================
  Coverage      93%     93%           
======================================
  Files         134     134           
  Lines        9970    9976    +6     
======================================
+ Hits         9286    9294    +8     
+ Misses        684     682    -2     

tests/checkpointing/test_trainer_checkpoint.py: 5 review threads (outdated, resolved)
pytorch_lightning/trainer/connectors/callback_connector.py: review thread (outdated, resolved)
Member

@justusschock justusschock left a comment

One general question: if I pretrain on one PC, does this mean I cannot fine-tune on another if the filesystem isn't the same?

@pep8speaks

pep8speaks commented Jan 4, 2021

Hello @tchaton! Thanks for updating this PR.

Line 26:121: E501 line too long (132 > 120 characters)

Comment last updated at 2021-01-05 08:52:09 UTC

pytorch_lightning/trainer/training_loop.py: review thread (outdated, resolved)
tests/checkpointing/test_trainer_checkpoint.py: 2 review threads (outdated, resolved)
tests/plugins/test_ddp_sequential_plugin.py: review thread (outdated, resolved)
Member

@Borda Borda left a comment

missing changelog

pytorch_lightning/plugins/rpc_plugin.py: review thread (outdated, resolved)
tests/plugins/test_ddp_sequential_plugin.py: review thread (outdated, resolved)
@tchaton
Contributor Author

tchaton commented Jan 4, 2021

missing changelog

Done!

@tchaton tchaton enabled auto-merge (squash) January 4, 2021 20:45
@tchaton tchaton merged commit d5b3678 into master Jan 5, 2021
@tchaton tchaton deleted the bugfix/5091_resume_from_checkpoint_test branch January 5, 2021 10:02
@carmocca carmocca mentioned this pull request Jan 5, 2021
Borda pushed a commit that referenced this pull request Jan 6, 2021
* resolve bug

* update code

* add set -e

* Update pytorch_lightning/callbacks/model_checkpoint.py

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>

* update test

* Update tests/checkpointing/test_trainer_checkpoint.py

Co-authored-by: Sean Naren <sean.narenthiran@gmail.com>

* Update tests/checkpointing/test_trainer_checkpoint.py

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>

* update on comments

* resolve test

* convert to set

* update

* add error triggering

* update

* update on comments

* update

* resolve import

* update

* update

* Update pytorch_lightning/plugins/rpc_plugin.py

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>

* update

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
Co-authored-by: Sean Naren <sean.narenthiran@gmail.com>
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
Co-authored-by: Ubuntu <ubuntu@ip-172-31-62-109.ec2.internal>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>

(cherry picked from commit d5b3678)
@ananthsub
Contributor

ananthsub commented Jan 12, 2021

@tchaton by not restoring when testing, this use case breaks:

model = BoringModel()
callback_that_implements_on_load_checkpoint = MyCallback()
trainer = Trainer(
    default_root_dir=root_dir,
    max_steps=1,
    callbacks=[callback_that_implements_on_load_checkpoint],
    resume_from_checkpoint=some_dummy_path,
)
trainer.test(model)

This is because we skip calling on_load_checkpoint() for all callbacks here: https://github.com/PyTorchLightning/pytorch-lightning/blob/1f6236accce78303249c55de656b71501e607d1a/pytorch_lightning/trainer/connectors/checkpoint_connector.py#L149-L150

A concrete use case to motivate this: we developed an Exponential Moving Average (EMA) callback. It loads the EMA state in its on_load_checkpoint hook, so we can then use the model with EMA weights for testing.
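For readers unfamiliar with the pattern, here is a minimal, hypothetical sketch of such an EMA callback. It is not the actual callback referenced above; the hook signatures follow the 1.1-era Callback API used elsewhere in this PR:

import copy
from typing import Any, Dict

import torch
from pytorch_lightning.callbacks import Callback


class EMACallback(Callback):  # hypothetical sketch, not the callback described above
    """Keep an exponential moving average (EMA) of the model parameters."""

    def __init__(self, decay: float = 0.999):
        self.decay = decay
        self.ema_state: Dict[str, torch.Tensor] = {}

    def on_train_batch_end(self, trainer, pl_module, *args, **kwargs):
        # update the moving average after every training batch
        with torch.no_grad():
            for name, param in pl_module.named_parameters():
                if name not in self.ema_state:
                    self.ema_state[name] = param.detach().clone()
                else:
                    self.ema_state[name].mul_(self.decay).add_(param.detach(), alpha=1.0 - self.decay)

    def on_save_checkpoint(self, trainer, pl_module) -> Dict[str, Any]:
        # persist the EMA weights alongside the regular checkpoint state
        return {"ema_state": copy.deepcopy(self.ema_state)}

    def on_load_checkpoint(self, checkpointed_state: Dict[str, Any]):
        # restore the EMA weights; this is the hook that never runs
        # if the callback-state restore is skipped during .test()
        self.ema_state = checkpointed_state["ema_state"]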

I think we should be able to restore other states when resuming for testing. WDYT?

@rohitgr7
Contributor

Seems reasonable to load the trainer & callback states at least 👍. Or should we have trainer.test(..., restore_states=True/False)?

Labels
checkpointing Related to checkpointing
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Trainer.test() in combination with resume_from_checkpoint is broken
10 participants