Added Adam Optimizer #28
Conversation
* Update mod.rs
* Update mod.rs
LGTM, just a small question about an import
src/optimizers/adam.rs (outdated)
grad += param.value() * self.weight_decay;

exp_avgs[i] = (beta1 * exp_avgs[i]) + ((1. - beta1) * grad);
exp_avg_sqs[i] = (beta2 * exp_avg_sqs[i]) + ((1. - beta2) * grad.pow(2));
I believe you had to use num_traits::Pow because of the grad.pow(2), but you can do without the import by using grad.powf(2.0).
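For reference, a minimal sketch of the two spellings, assuming grad is a plain f32 rather than the crate's own gradient type:

```rust
use num_traits::Pow; // only needed for the trait-based .pow(...) form

fn main() {
    let grad: f32 = 3.0;

    // Trait-based form; requires the num_traits::Pow import:
    let via_pow = grad.pow(2i32);

    // Inherent f32 method; no extra import needed:
    let via_powf = grad.powf(2.0);

    // powi and powf may differ by a rounding error in general,
    // so compare with a small tolerance.
    assert!((via_pow - via_powf).abs() < 1e-6);
}
```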
You're absolutely right; the reason I did it is that RMSProp uses num_traits::Pow as well. I can go ahead and make the change to both, though!
* change pow to powf
* Update adam.rs
Awesome job!
Added an Adam struct that implements the Optimizer trait, closely following its PyTorch implementation. It essentially combines the two types of momentum found in SGD and RMSProp (with respective decay rates β1 and β2) by keeping track of several gradient-based averages.

Source: https://pytorch.org/docs/stable/generated/torch.optim.Adam.html#torch.optim.Adam
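To make that description concrete, here is a minimal, self-contained sketch of the Adam update over plain f32 slices. The struct and field names (AdamSketch, exp_avgs, exp_avg_sqs, etc.) are hypothetical simplifications, not the crate's actual Optimizer trait or tensor types, but the update follows the same first-/second-moment scheme with decay rates β1 and β2 and bias correction described in the PyTorch docs linked above:

```rust
/// Minimal Adam sketch over plain `f32` slices (hypothetical names; the
/// crate's real `Adam` works through its `Optimizer` trait and own types).
struct AdamSketch {
    lr: f32,               // learning rate
    beta1: f32,            // decay rate for the first-moment (SGD-style momentum) average
    beta2: f32,            // decay rate for the second-moment (RMSProp-style) average
    eps: f32,              // numerical-stability term added to the denominator
    weight_decay: f32,     // optional L2-style penalty, as in the PR's snippet
    exp_avgs: Vec<f32>,    // first moment, one entry per parameter
    exp_avg_sqs: Vec<f32>, // second moment, one entry per parameter
    t: u64,                // number of steps taken so far
}

impl AdamSketch {
    fn new(n_params: usize, lr: f32) -> Self {
        Self {
            lr,
            beta1: 0.9,
            beta2: 0.999,
            eps: 1e-8,
            weight_decay: 0.0,
            exp_avgs: vec![0.0; n_params],
            exp_avg_sqs: vec![0.0; n_params],
            t: 0,
        }
    }

    fn step(&mut self, params: &mut [f32], grads: &[f32]) {
        self.t += 1;
        // Both averages start at zero, so correct their bias toward zero.
        let bias_correction1 = 1.0 - self.beta1.powi(self.t as i32);
        let bias_correction2 = 1.0 - self.beta2.powi(self.t as i32);

        for i in 0..params.len() {
            let mut grad = grads[i];
            grad += params[i] * self.weight_decay;

            // First moment: exponential average of gradients (momentum, beta1).
            self.exp_avgs[i] = self.beta1 * self.exp_avgs[i] + (1.0 - self.beta1) * grad;
            // Second moment: exponential average of squared gradients (beta2).
            self.exp_avg_sqs[i] =
                self.beta2 * self.exp_avg_sqs[i] + (1.0 - self.beta2) * grad.powf(2.0);

            let m_hat = self.exp_avgs[i] / bias_correction1;
            let v_hat = self.exp_avg_sqs[i] / bias_correction2;

            params[i] -= self.lr * m_hat / (v_hat.sqrt() + self.eps);
        }
    }
}

fn main() {
    // Toy usage: minimize f(x) = x^2, whose gradient is 2x.
    let mut params = vec![5.0_f32];
    let mut opt = AdamSketch::new(params.len(), 0.1);
    for _ in 0..500 {
        let grads: Vec<f32> = params.iter().map(|x| 2.0 * x).collect();
        opt.step(&mut params, &grads);
    }
    println!("x after 500 steps: {}", params[0]); // close to 0
}
```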