RMSprop clean up and rebase #2867

Merged
merged 1 commit into from Aug 9, 2015

Conversation

Projects
None yet
4 participants
Member

ronghanghu commented Aug 6, 2015

Rebased and adapted RMSprop implementation #1890 to the new solver interface #2518 and #1977. The original author is @erogol. Pulled against master instead of dev.

The RMSprop solver is based on G. Hinton's lecture (http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf). Param gradients are divided by average root mean square of gradients in recent batches. It can be seen as a mini-batch version of using only the sign of gradients.

Update rule:

MeanSquare(t) = rms_decay * MeanSquare(t-1) + (1 - rms_decay) * gradient(t)^2
param_update(t) = gradient(t) / (sqrt(MeanSquare(t)) + delta)

Momentum is not supported for RMSprop solver, as in #1890.

@shelhamer shelhamer added focus JD labels Aug 6, 2015

Contributor

erogol commented Aug 7, 2015

thanks for handling this :)

@jeffdonahue jeffdonahue commented on an outdated diff Aug 7, 2015

src/caffe/test/test_gradient_based_solver.cpp
@@ -521,7 +531,7 @@ TYPED_TEST(NesterovSolverTest, TestNesterovLeastSquaresUpdateWithMomentum) {
const Dtype kMomentum = 0.5;
const int kNumIters = 1;
for (int i = 0; i <= kNumIters; ++i) {
- this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i);
+ this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, 0., i);
@jeffdonahue

jeffdonahue Aug 7, 2015

Contributor

These should be declared as constants (e.g. const Dtype kRMSDecay = 0) like the other args to make the meaning clear.

@jeffdonahue jeffdonahue commented on an outdated diff Aug 7, 2015

include/caffe/solver.hpp
@@ -128,6 +128,29 @@ class AdaGradSolver : public SGDSolver<Dtype> {
DISABLE_COPY_AND_ASSIGN(AdaGradSolver);
};
+
+template <typename Dtype>
+class RMSpropSolver : public SGDSolver<Dtype> {
@jeffdonahue

jeffdonahue Aug 7, 2015

Contributor

The capitalization (everywhere) should be RMSProp rather than RMSprop. (I think a global find/replace should do the trick.)

Contributor

jeffdonahue commented Aug 7, 2015

Thanks @erogol for the original work and thanks @ronghanghu for the rebase. This looks good except as noted above.

Member

ronghanghu commented Aug 7, 2015

@jeffdonahue OK, I'll handle them. Thanks for the comments!

Member

ronghanghu commented Aug 7, 2015

Fixed those issues. I expect this PR to be merged after #2856 and #2782.

Contributor

jeffdonahue commented Aug 7, 2015

Cool, LGTM. @ronghanghu feel free to merge whenever it's easiest for you, before or after the other two PRs.

@erogol @ronghanghu erogol Implement RMSProp Solver
Implement RMSProp solver and cleaned up to adjust to new solver interface that uses
accumulated gradients and refactored regularization.
abe99e8
Member

ronghanghu commented Aug 9, 2015

Took a further rebase on #2866. Authorship preserved for @erogol in commit

Ready to merge.

@ronghanghu ronghanghu added a commit that referenced this pull request Aug 9, 2015

@ronghanghu ronghanghu Merge pull request #2867 from ronghanghu/rms-prop
RMSProp clean up and rebase
698fc76

@ronghanghu ronghanghu merged commit 698fc76 into BVLC:master Aug 9, 2015

1 check passed

continuous-integration/travis-ci/pr The Travis CI build passed
Details

ronghanghu deleted the ronghanghu:rms-prop branch Aug 9, 2015

@shelhamer shelhamer commented on the diff Aug 9, 2015

src/caffe/test/test_gradient_based_solver.cpp
@@ -867,10 +906,124 @@ TYPED_TEST(NesterovSolverTest, TestSnapshotShare) {
const Dtype kLearningRate = 0.01;
const Dtype kWeightDecay = 0.5;
const Dtype kMomentum = 0.9;
+ const Dtype kRMSDecay = 0;
+ const int kNumIters = 4;
+ this->share_ = true;
+ for (int i = 1; i <= kNumIters; ++i) {
+ this->TestSnapshot(kLearningRate, kWeightDecay, kMomentum, kRMSDecay, i);
+ }
+}
+
+template <typename TypeParam>
+class RMSPropSolverTest : public GradientBasedSolverTest<TypeParam> {
+ typedef typename TypeParam::Dtype Dtype;
+
+ protected:
+ virtual void InitSolver(const SolverParameter& param) {
+ this->solver_.reset(new RMSPropSolver<Dtype>(param));
@shelhamer

shelhamer Aug 9, 2015

Owner

Could you set the RMS decay here, instead of introducing the decay argument to least squares and snapshotting tests? Since it is unique to this solver I think it is best handled here.

@shelhamer shelhamer commented on the diff Aug 9, 2015

src/caffe/test/test_gradient_based_solver.cpp
@@ -173,6 +174,9 @@ class GradientBasedSolverTest : public MultiDeviceTest<TypeParam> {
if (momentum != 0) {
proto << "momentum: " << momentum << " ";
}
+ if (rms_decay != 0) {
@shelhamer

shelhamer Aug 9, 2015

Owner

I think this could be handled in RMSPropSolverTest::InitSolver() since this detail is particular to the RMSProp solver alone.

Owner

shelhamer commented Aug 9, 2015

@ronghanghu Sorry I didn't catch this earlier, but I have a suggestion for the RMS decay parameter in the tests. Instead of introducing another argument and setting it for every test, this param could be set by the RMSProp test class for encapsulation. Could you send a follow-up PR to make this change?

ronghanghu restored the ronghanghu:rms-prop branch Aug 9, 2015

Member

ronghanghu commented Aug 9, 2015

@shelhamer Yes, I can send another PR to do that. Adam solver is also going to introduce a momentum2 parameter, which can be handle in the same way (put into InitSolver()).

Addressed in #2888.

ronghanghu deleted the ronghanghu:rms-prop branch Aug 9, 2015

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment