diff --git a/flashlight/optim/AdamOptimizer.cpp b/flashlight/optim/AdamOptimizer.cpp index abe877c02..ba49c093c 100644 --- a/flashlight/optim/AdamOptimizer.cpp +++ b/flashlight/optim/AdamOptimizer.cpp @@ -52,6 +52,11 @@ AdamOptimizer::AdamOptimizer( } void AdamOptimizer::step() { + count_++; + float correctedBias1 = 1 - std::pow(beta1_, count_); + float correctedBias2 = 1 - std::pow(beta2_, count_); + float correctedLr = lr_ * std::sqrt(correctedBias2) / correctedBias1; + for (size_t i = 0; i < parameters_.size(); i++) { if (!parameters_[i].isGradAvailable()) { continue; @@ -74,12 +79,6 @@ void AdamOptimizer::step() { af::eval(biasedFirst); af::eval(biasedSecond); - count_++; - - float correctedBias1 = 1 - std::pow(beta1_, count_); - float correctedBias2 = 1 - std::pow(beta2_, count_); - float correctedLr = lr_ * std::sqrt(correctedBias2) / correctedBias1; - data = data - (correctedLr * biasedFirst) / (af::sqrt(biasedSecond) + eps_); af::eval(data);