Best model in ModelCheckpoint #1960

Closed
mpariente opened this issue May 26, 2020 · 8 comments
Labels
feature Is an improvement or enhancement help wanted Open to be worked on

Comments

@mpariente
Contributor

🚀 Feature

Add a best_model attribute in ModelCheckpoint.

Motivation

After training, it would be nice to easily have the path to the checkpoint with the best val loss.

Or is there an argument in the trainer to resume best state after training and I missed it? (totally possible)

@mpariente mpariente added feature Is an improvement or enhancement help wanted Open to be worked on labels May 26, 2020
@mpariente
Contributor Author

BTW, this is my solution but I don't think this should be necessary

best_k = checkpoint.best_k_models
best_path = [b for b, v in best_k.items() if v == min(best_k.values())][0]

@rohitgr7
Contributor

@mpariente A better condition would be:

best_k = checkpoint.best_k_models
best_path = [b for b, v in best_k.items() if v == checkpoint.best][0]

checkpoint.mode can be max, though.
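
To make the `mode` caveat concrete, here is a minimal sketch of a mode-aware lookup. It assumes `best_k_models` is a plain dict mapping checkpoint paths to monitored scores, as the snippets above suggest. The helper name `best_checkpoint_path` and the sample paths and scores are hypothetical, not part of the library.

```python
def best_checkpoint_path(best_k_models, mode="min"):
    """Return the path whose monitored score is best for the given mode.

    best_k_models: dict mapping checkpoint path -> score (assumed layout).
    mode: "min" when lower is better (e.g. val_loss), "max" otherwise.
    """
    if not best_k_models:
        raise ValueError("no checkpoints have been recorded yet")
    if mode not in ("min", "max"):
        raise ValueError(f"unsupported mode: {mode!r}")
    pick = min if mode == "min" else max
    # Compare on the score, but return the path.
    return pick(best_k_models, key=best_k_models.get)


# Hypothetical scores, for illustration only.
scores = {"epoch=3.ckpt": 0.71, "epoch=5.ckpt": 0.58, "epoch=7.ckpt": 0.64}
print(best_checkpoint_path(scores, mode="min"))  # epoch=5.ckpt
print(best_checkpoint_path(scores, mode="max"))  # epoch=3.ckpt
```

Selecting with `min`/`max` on the dict keys avoids scanning the values twice and does not depend on exact equality with a separately stored best score.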

@mpariente
Contributor Author

True, thanks, but this still doesn't change the fact that a best_model_path attribute should be integrated in the checkpoint IMO

@HansBambel
Contributor

HansBambel commented May 27, 2020

The save_top_k parameter keeps the k best models. When set to -1, it saves all models. The latest master version also has a save_last parameter that additionally keeps the checkpoint from the latest epoch.
E.g.:

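import os
from pytorch_lightning.callbacks import ModelCheckpoint

# `net` is the model being trained; its class name becomes the filename prefix.
checkpoint_callback = ModelCheckpoint(
    filepath=os.path.join(os.getcwd(), "<your-folder>", "{epoch}-{val_loss:.6f}"),
    save_top_k=1,  # keep only the single best checkpoint
    verbose=False,
    monitor="val_loss",
    mode="min",  # lower val_loss is better
    prefix=net.__class__.__name__ + "_",
)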

This results in a file called: UNet_epoch=5-val_loss=0.581735.ckpt

@mpariente
Contributor Author

Thanks for the example, I also use it like that.
Still, asking the user to reverse the dict or look for the best loss in the filenames is not very user-friendly. It only takes a few extra lines of code, but they shouldn't be necessary, don't you think?
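
For reference, the filename route mentioned here could look roughly like the sketch below. It assumes checkpoints follow the `{epoch}-{val_loss:.6f}` naming from the example above and that the loss is a plain non-negative decimal. `parse_val_loss` and `best_by_filename` are hypothetical helper names, and the second filename in the usage lines is made up.

```python
import re

# Matches the "val_loss=<number>" fragment produced by the
# "{epoch}-{val_loss:.6f}" template (assumed naming convention).
_VAL_LOSS = re.compile(r"val_loss=(\d+(?:\.\d+)?)")


def parse_val_loss(filename):
    """Return the val_loss encoded in a checkpoint filename, or None."""
    match = _VAL_LOSS.search(filename)
    return float(match.group(1)) if match else None


def best_by_filename(filenames):
    """Return the filename with the lowest encoded val_loss."""
    scored = [(parse_val_loss(name), name) for name in filenames]
    # Ignore files that do not carry a val_loss value in their name.
    scored = [(loss, name) for loss, name in scored if loss is not None]
    if not scored:
        raise ValueError("no filename contains a val_loss value")
    return min(scored)[1]


files = [
    "UNet_epoch=5-val_loss=0.581735.ckpt",
    "UNet_epoch=4-val_loss=0.612004.ckpt",  # made-up example
]
print(best_by_filename(files))  # UNet_epoch=5-val_loss=0.581735.ckpt
```

This works, but it ties the lookup to the filename template, which is part of why a dedicated attribute on the callback is the friendlier option.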

@williamFalcon
Contributor

williamFalcon commented May 27, 2020

this is already in master.

#1799.

@kepler can you add to docs? (on the checkpoint page)
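
A hedged sketch of how code could take advantage of this while still working on older versions. It assumes the attribute is exposed as `best_model_path` (the name proposed earlier in this thread, and the one `ModelCheckpoint` carries in later PyTorch Lightning releases) and falls back to searching `best_k_models` otherwise. `resolve_best_path` is a hypothetical helper, and the `SimpleNamespace` objects only stand in for a real callback.

```python
from types import SimpleNamespace


def resolve_best_path(checkpoint):
    """Prefer a built-in best_model_path, else search best_k_models."""
    path = getattr(checkpoint, "best_model_path", None)
    if path:
        return path
    # Fallback for versions without the attribute (assumed dict layout:
    # checkpoint path -> monitored score).
    scores = checkpoint.best_k_models
    pick = min if getattr(checkpoint, "mode", "min") == "min" else max
    return pick(scores, key=scores.get)


# Stand-ins for a ModelCheckpoint callback; all values are made up.
newer = SimpleNamespace(best_model_path="ckpts/epoch=5.ckpt")
older = SimpleNamespace(mode="min", best_k_models={"a.ckpt": 0.9, "b.ckpt": 0.4})
print(resolve_best_path(newer))  # ckpts/epoch=5.ckpt
print(resolve_best_path(older))  # b.ckpt
```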

@mpariente
Contributor Author

Oh ok, sorry for bothering then, I should have checked master
Thanks for the pointer William

@williamFalcon
Contributor

it's a great idea :) just someone beat you to it haha

4 participants