Exponential Moving Average (EMA) #8100

Closed
miraodasilva opened this issue Jun 23, 2021 · 6 comments
Labels
feature Is an improvement or enhancement help wanted Open to be worked on won't fix This will not be worked on

Comments

@miraodasilva

🚀 Feature

Keep an Exponential Moving Average (EMA) of the model's weights while it trains. This is available in TensorFlow (https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage) but not in PyTorch.
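For reference, an EMA keeps a shadow copy of every weight and, after each optimizer step, moves it toward the current value: `shadow = decay * shadow + (1 - decay) * weight`. The sketch below illustrates only that update rule, using plain Python floats as stand-ins for parameter tensors; the class name `WeightEMA` and the default `decay` are illustrative assumptions, not an existing PyTorch or TensorFlow API.

```python
from typing import Dict


class WeightEMA:
    """Hypothetical minimal EMA of named weights (floats stand in for tensors)."""

    def __init__(self, weights: Dict[str, float], decay: float = 0.999) -> None:
        if not 0.0 <= decay <= 1.0:
            raise ValueError("decay must be in [0, 1]")
        self.decay = decay
        # Shadow weights start as a copy of the initial model weights.
        self.shadow = dict(weights)

    def update(self, weights: Dict[str, float]) -> None:
        # Intended to run once per optimizer step:
        # shadow <- decay * shadow + (1 - decay) * current
        for name, w in weights.items():
            self.shadow[name] = self.decay * self.shadow[name] + (1.0 - self.decay) * w


if __name__ == "__main__":
    ema = WeightEMA({"w": 0.0}, decay=0.9)
    for _ in range(3):
        ema.update({"w": 1.0})
    print(ema.shadow["w"])  # about 0.271, i.e. 1 - 0.9**3
```

A real PyTorch version would apply the same arithmetic to each parameter tensor under `torch.no_grad()`, and usually also decides how to treat buffers such as BatchNorm running statistics.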

Motivation

EMA has been shown to be extremely beneficial for bootstrapping better models from scratch. For instance, EfficientNet (v1 and v2) benefits heavily from this method. It is also widely used in self-supervised learning.

Pitch

Basically, what needs to be done is already laid out here: https://forums.pytorchlightning.ai/t/adopting-exponential-moving-average-ema-for-pl-pipeline/488. That code requires this package: https://github.com/fadel/pytorch_ema.

@miraodasilva miraodasilva added feature Is an improvement or enhancement help wanted Open to be worked on labels Jun 23, 2021
@justusschock
Member

I think having it as a callback would be nice.
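A callback version could follow the pattern described in the forum thread from the pitch: update the shadow weights after every training batch, swap them in before validation, and restore the raw weights afterwards. The sketch below does not import Lightning. The hook names (`on_train_start`, `on_train_batch_end`, `on_validation_start`, `on_validation_end`) are real Lightning `Callback` hooks, but the simplified signatures, the hypothetical `ToyModule` class, and its `weights` dict are assumptions made so the example runs on its own.

```python
from typing import Any, Dict, Optional


class ToyModule:
    """Hypothetical stand-in for a LightningModule: named float weights."""

    def __init__(self, weights: Dict[str, float]) -> None:
        self.weights = dict(weights)


class EMACallbackSketch:
    """EMA as a callback; hook names follow Lightning, signatures are simplified."""

    def __init__(self, decay: float = 0.999) -> None:
        self.decay = decay
        self.shadow: Optional[Dict[str, float]] = None  # EMA weights
        self.backup: Optional[Dict[str, float]] = None  # raw weights during validation

    def on_train_start(self, trainer: Any, pl_module: ToyModule) -> None:
        # Initialize the shadow weights from the model at the start of training.
        self.shadow = dict(pl_module.weights)

    def on_train_batch_end(self, trainer: Any, pl_module: ToyModule, *args: Any) -> None:
        # Runs after every training batch (one optimizer step, absent grad accumulation).
        if self.shadow is None:
            self.shadow = dict(pl_module.weights)
            return
        for name, w in pl_module.weights.items():
            self.shadow[name] = self.decay * self.shadow[name] + (1.0 - self.decay) * w

    def on_validation_start(self, trainer: Any, pl_module: ToyModule) -> None:
        # Evaluate with the EMA weights; keep the raw weights aside.
        if self.shadow is not None:
            self.backup = dict(pl_module.weights)
            pl_module.weights = dict(self.shadow)

    def on_validation_end(self, trainer: Any, pl_module: ToyModule) -> None:
        # Put the raw weights back so training continues from them.
        if self.backup is not None:
            pl_module.weights = self.backup
            self.backup = None
```

In a real project the class would subclass `pytorch_lightning.Callback`, operate on `pl_module.parameters()`, and be passed to the trainer via `Trainer(callbacks=[...])`.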

@miraodasilva Are you willing to contribute this?

@miraodasilva
Author

Sorry, I don't really have the time to do it properly right now. However, I will start working with it on my own a bit, and perhaps I will contribute in the future. Thanks for responding!

@tchaton
Contributor

tchaton commented Jun 24, 2021

Hey @miraodasilva,

It can be done using the Stochastic Weight Averaging callback and replacing its avg_fn function:
https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pytorch_lightning/callbacks/stochastic_weight_avg.py#L44

Here is the mean average implementation:
https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pytorch_lightning/callbacks/stochastic_weight_avg.py#L287
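To make the suggestion concrete, the sketch below contrasts an equal-weight running mean with an EMA-style replacement. Both use the three-argument shape of `torch.optim.swa_utils.AveragedModel`'s `avg_fn` (averaged value, current value, number already averaged); plain floats stand in for tensors, and `make_ema_avg_fn` and `fold` are hypothetical helpers. Whether the installed Lightning version accepts a custom `avg_fn` argument on `StochasticWeightAveraging` should be checked against its docs. Note that this only changes how snapshots are weighted; how often the function is called is decided by the callback.

```python
from typing import Callable, List

AvgFn = Callable[[float, float, int], float]


def mean_avg_fn(averaged: float, current: float, num_averaged: int) -> float:
    # Equal-weight running mean over every snapshot averaged so far.
    return averaged + (current - averaged) / (num_averaged + 1)


def make_ema_avg_fn(decay: float) -> AvgFn:
    # Hypothetical EMA replacement: ignores num_averaged and uses a fixed decay.
    def ema_avg_fn(averaged: float, current: float, num_averaged: int) -> float:
        return decay * averaged + (1.0 - decay) * current

    return ema_avg_fn


def fold(avg_fn: AvgFn, snapshots: List[float]) -> float:
    # Mimic AveragedModel: copy the first snapshot, then apply avg_fn to the rest.
    averaged = snapshots[0]
    for n, current in enumerate(snapshots[1:], start=1):
        averaged = avg_fn(averaged, current, n)
    return averaged


if __name__ == "__main__":
    snaps = [1.0, 2.0, 3.0]
    print(fold(mean_avg_fn, snaps))          # 2.0
    print(fold(make_ema_avg_fn(0.5), snaps))  # 2.25
```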

@miraodasilva
Author

I see, I hadn't seen that, thanks a lot!

@stale

stale bot commented Jul 24, 2021

This issue has been automatically marked as stale because it hasn't had any recent activity. This issue will be closed in 7 days if no further activity occurs. Thank you for your contributions, Pytorch Lightning Team!

@stale stale bot added the won't fix This will not be worked on label Jul 24, 2021
@stale stale bot closed this as completed Aug 1, 2021
@hal-314

hal-314 commented Dec 10, 2021

For future readers: I don't think the given solution is equivalent. avg_fn is called once per epoch, while EMA updates happen at every training step. I don't think EMA can be implemented with the SWA callback; see #10914 for code.
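A toy simulation shows why the call frequency matters. It assumes a made-up weight trajectory in which the weight equals the global step index, then compares an EMA updated at every step with one that only sees the weight at the end of each epoch, using the same decay. The helper names `ema` and `simulate` are hypothetical.

```python
from typing import List, Tuple


def ema(values: List[float], decay: float) -> float:
    # Fold an EMA over a sequence of weight snapshots, starting from the first.
    shadow = values[0]
    for v in values[1:]:
        shadow = decay * shadow + (1.0 - decay) * v
    return shadow


def simulate(num_epochs: int, steps_per_epoch: int, decay: float) -> Tuple[float, float]:
    # Toy trajectory (assumption): the weight equals the global step index.
    trajectory = [float(step) for step in range(num_epochs * steps_per_epoch)]
    per_step = ema(trajectory, decay)
    # An epoch-level hook only ever sees the weight at the end of each epoch.
    epoch_ends = trajectory[steps_per_epoch - 1 :: steps_per_epoch]
    per_epoch = ema(epoch_ends, decay)
    return per_step, per_epoch


if __name__ == "__main__":
    per_step, per_epoch = simulate(num_epochs=3, steps_per_epoch=4, decay=0.5)
    print(per_step)   # 10.00048828125
    print(per_epoch)  # 8.0
```

With 3 epochs of 4 steps and `decay=0.5`, the per-step EMA ends close to the latest weight (11), while the epoch-level average ends at 8.0. Retuning the decay for epoch-level updates would not make them identical, because the intermediate weights are never observed.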

4 participants