AdaDelta Solver (v3) #2782

Merged · 3 commits · Aug 11, 2015
Conversation

@matthiasplappert (Contributor)

Picked up @kevinbache's branch (#2204), merged it with master, resolved merge conflicts and fixed a couple of issues due to API changes. All tests pass.

However, I need input on one change; please see my comment directly in the diff.

@@ -434,7 +434,8 @@ Dtype SGDSolver<Dtype>::GetLearningRate() {
(Dtype(1.) + exp(-this->param_.gamma() * (Dtype(this->iter_) -
Dtype(this->param_.stepsize())))));
} else {
// LOG(FATAL) << "Unknown learning rate policy: " << lr_policy;
rate = Dtype(0.);
@matthiasplappert (Contributor, Author)

I'm unsure what the best way to solve this is. The problem here is that the AdaDelta solver does not support a learning rate. However, since AdaDelta inherits from SGD, and SGD calls ApplyUpdate() which, in turn, calls this method, we trigger the default case and therefore the fatal log (which is currently commented out). Returning a rate of 0.0 works fine, but is likely to cause errors elsewhere in the code base where a valid learning rate is expected. Any input on this is greatly appreciated!
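For context, a rough sketch of that call chain; method names follow SGDSolver, but the body is simplified for illustration and is not Caffe's actual code:

template <typename Dtype>
void SGDSolver<Dtype>::ApplyUpdate() {
  // AdaDeltaSolver inherits this method, so every iteration asks for a
  // learning rate even though AdaDelta itself does not use one.
  Dtype rate = GetLearningRate();  // unknown lr_policy -> the fatal default case
  for (int param_id = 0; param_id < this->net_->params().size(); ++param_id) {
    Regularize(param_id);
    ComputeUpdateValue(param_id, rate);
  }
  this->net_->Update();
}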

One possible idea: keep the learning rate schedule, and treat it as a multiplier on the AdaDelta update step size. The only ugly part of this solution is that it would require the user to specify base_lr: 1 and lr_policy: 'fixed' in order to get the default behavior.

@matthiasplappert (Contributor, Author)


That would be a possible solution. Before going any further with this: is adding AdaDelta even of interest for Caffe? I don't want to invest time into this if it's not likely to land in master eventually.

Contributor


I would strongly argue for shipping AdaDelta within the Caffe framework. I was surprised that it isn't already in the master branch.

Member


I am also strongly in favor of having AdaDelta in Caffe. I'll go over and review this PR today.

Member


For the learning rate issue, I suggest using base_lr: 1 and lr_policy: 'fixed'.

I suppose learning rate specification is still sometimes needed even if you use AdaDelta. Take fine-tuning as an example: you may still want a smaller learning rate on pre-trained layers than on randomly initialized layers even if you use AdaDelta.

For clarity, let's change line 7 of Algorithm 1 in the AdaDelta paper from:

x(t+1) = x(t) + delta_x(t)

to

x(t+1) = x(t) + local_rate * delta_x(t)

where local_rate = base_lr * lr_mult is the local learning rate for each parameter blob.
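A sketch of how that scaling might look in the CPU update path; this is an illustration following SGDSolver conventions (params_lr() holds each blob's lr_mult), not the final implementation:

template <typename Dtype>
void AdaDeltaSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) {
  const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();
  // local_rate = base_lr * lr_mult for this parameter blob.
  Dtype local_rate = rate * this->net_->params_lr()[param_id];
  // ... accumulate gradient statistics and compute the AdaDelta step
  // delta_x(t) into the diff buffer, then scale the applied update:
  caffe_scal(net_params[param_id]->count(), local_rate,
      net_params[param_id]->mutable_cpu_diff());
}

With base_lr: 1 and lr_policy: 'fixed' in the solver prototxt, local_rate reduces to lr_mult and the paper's unscaled update is recovered.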

matthiasplappert changed the title from AdaDelta v3 to AdaDelta Solver (attempt number 3) on Jul 18, 2015
matthiasplappert changed the title from AdaDelta Solver (attempt number 3) to AdaDelta Solver (v3) on Jul 18, 2015
@matthiasplappert (Contributor, Author)

Travis failed because of a lint error (the commented-out LOG is causing the error; that line will go away before merging anyway, see comment above).

@shelhamer (Member)

@matthiasplappert thanks for making the update, but take another look at #2518 and see how the regularization and logging code was pulled out into SGDSolver.

explicit AdaDeltaSolver(const SolverParameter& param)
: SGDSolver<Dtype>(param) { PreSolve(); constructor_sanity_check(); }
explicit AdaDeltaSolver(const string& param_file)
: SGDSolver<Dtype>(param_file) { PreSolve(); constructor_sanity_check(); }
Member


I suppose you have something wrong here. You are now calling PreSolve() in the constructors of both AdaDeltaSolver and SGDSolver, and since you turned it into a virtual method, you are calling AdaDeltaSolver::PreSolve() twice when constructing an AdaDeltaSolver instance. Is that the desired behavior?

Sorry, I was wrong here. Until the derived class constructor runs, the dynamic type of the object under construction is the base class, not the derived class. For this reason, AdaDeltaSolver::AdaDeltaSolver still calls AdaDeltaSolver::PreSolve() after SGDSolver::SGDSolver has called SGDSolver::PreSolve(). However, I still don't see a reason to make PreSolve() a virtual function, and in general it is not good practice to call a virtual function inside a constructor in C++.

Also see the comment below in AdaDeltaSolver<Dtype>::PreSolve().
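To see the C++ rule in action, here is a minimal standalone example (hypothetical Base/Derived classes, unrelated to Caffe):

#include <iostream>

struct Base {
  Base() { Init(); }  // virtual call during construction...
  virtual void Init() { std::cout << "Base::Init\n"; }  // ...dispatches here
  virtual ~Base() {}
};

struct Derived : Base {
  Derived() { Init(); }  // by now the dynamic type is Derived
  virtual void Init() { std::cout << "Derived::Init\n"; }
};

int main() {
  Derived d;  // prints "Base::Init", then "Derived::Init"
  return 0;
}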

@matthiasplappert (Contributor, Author)


Virtual issue addressed in aedff90.

@ronghanghu (Member)

@matthiasplappert Thanks for your great PR to introduce the AdaDelta solver into Caffe!

The remaining work includes:

  • Add learning rate.
  • Remove regularization.
  • Add more test cases.
  • Change PreSolve() back to being non-virtual.

Please modify and update according to the reviews.

@matthiasplappert (Contributor, Author)

@ronghanghu I'll try to find some time over the weekend to get all of this done. We should also thank @kevinbache and especially @mohomran (who wrote the original code), since I just carried on with their work.

@ronghanghu (Member)

#2836 and #2866 introduced new conflicts to be resolved.

@matthiasplappert (Contributor, Author)

I'll resolve the conflicts later today and (hopefully) address the remaining issues as well:

  • Add learning rate.
  • Remove regularization.
  • Add more test cases.
  • Change PreSolve() back to being non-virtual.

@matthiasplappert (Contributor, Author)

Update: this branch is now up-to-date with master and all feedback has been addressed. The tests pass locally and I expect them to pass on CI as well.

Please review my changes and let me know if anything else is required on my end, e.g. cleaning up the commit history (not sure how you usually handle this). I've also pointed out the relevant commits in each feedback discussion to hopefully help with reviewing the changes.

Finally, I have one suggestion: having all solvers in one relatively big file (solver.cpp) proved to be a real pain while resolving the merge conflicts. RMSProp and AdaDelta were completely intermixed since they share a lot of similar code. I would propose eventually splitting the individual solvers into separate files to avoid this in the future. Should I open an issue for that?

@ronghanghu (Member)

@matthiasplappert Thanks a lot for the update. I will review the changes today.

> Finally, I have one suggestion: having all solvers in one relatively big file (solver.cpp) proved to be a real pain while resolving the merge conflicts. RMSProp and AdaDelta were completely intermixed since they share a lot of similar code. I would propose eventually splitting the individual solvers into separate files to avoid this in the future. Should I open an issue for that?

Yes, this is quite a problem. I expect to send a solver refactor PR to split solver.cpp and extract common code for these adaptive gradient solvers, after merging AdaDelta and Adam (#2856).

: SGDSolver<Dtype>(param_file) { PreSolve(); }

protected:
void PreSolve();
Member


I think it is better to rename AdaDeltaSolver::PreSolve() to AdaDeltaSolver::AdaDeltaPreSolve(). Since you are going to call AdaDeltaSolver's presolve function after SGDSolver's, it is better to avoid a name conflict with SGDSolver::PreSolve(), regardless of whether it is a virtual function.
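Applied to the constructors quoted earlier in the thread, the rename would look roughly like this (a sketch of the suggested shape, not the final code):

explicit AdaDeltaSolver(const SolverParameter& param)
    : SGDSolver<Dtype>(param) { AdaDeltaPreSolve(); }
explicit AdaDeltaSolver(const string& param_file)
    : SGDSolver<Dtype>(param_file) { AdaDeltaPreSolve(); }

protected:
  // Non-virtual, and no longer shadowing SGDSolver::PreSolve(), which the
  // SGDSolver constructor has already run by this point.
  void AdaDeltaPreSolve();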


template <typename Dtype>
void AdaDeltaSolver<Dtype>::Regularize(int param_id) {
const vector<shared_ptr<Blob<Dtype> > >& net_params = this->net_->params();
Member


Remove the entire AdaDeltaSolver::Regularize function.

The only difference between your AdaDeltaSolver::Regularize and the original SGDSolver::Regularize seems to be that you use const vector<shared_ptr<Blob<Dtype> > >& net_params rather than const vector<Blob<Dtype>*>& net_params. The rest is identical.

Note that after #2866, one should use const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params(); to be consistent.

So I believe we don't need an AdaDeltaSolver::Regularize here. Let's just use SGDSolver::Regularize instead.
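Concretely, any remaining AdaDelta method that touches the parameters would fetch them as follows (assuming the post-#2866 API):

// Raw pointers to the learnable parameter blobs, per #2866.
const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();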

@ronghanghu (Member)

@matthiasplappert I just made a few comments above. Let's get the following work done and I think this PR will be ready:

  • Rename AdaDeltaSolver::PreSolve to AdaDeltaSolver::AdaDeltaPreSolve.
  • Remove the AdaDeltaSolver::Regularize() function entirely.
  • Replace const vector<shared_ptr<Blob<Dtype> > >& net_params = this->net_->params(); with const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params(); to be consistent with #2866 (Fix weight sharing).
  • Add 4 more test cases to be consistent with #2836 (Snapshot model weights/solver state to HDF5 files) and #2866 (Fix weight sharing).
  • After that, squash the commits by each author into a single commit, and rebase against bvlc/master.

@shelhamer (Member)

@matthiasplappert a note about history: instead of squashing to a single commit, please squash the commits by each author into a single commit. This will leave three commits by @mohomran @kevinbache and yourself. In future work please make use of rebase instead of merge, as our policy is to only have merge commits for PRs. Thanks.

> having all solvers in one relatively big file (solver.cpp) proved to be a real pain while resolving the merge conflicts. [...] I would propose eventually splitting the individual solvers into separate files to avoid this in the future.

Absolutely, and this was noted in #2860 but deserves another issue so I've transplanted it to #2890.

@matthiasplappert (Contributor, Author)

@ronghanghu Thanks for the thorough review! I'm still very new to Caffe, so your feedback is very much appreciated.

I've addressed the remaining feedback and cleaned up the commit history (also: no more merges). All tests pass locally (not sure if Travis will pick this up since the branch was force-pushed to override the history). Let me know if anything else needs to be done before we can land this in master.

@ronghanghu (Member)

@matthiasplappert Thanks for the update! I'll take a final review, and I expect to merge it tomorrow. @jeffdonahue could you also take a look?

@ronghanghu (Member)

Finished final review. Thanks to @mohomran, @kevinbache and @matthiasplappert for this excellent AdaDelta solver.

ctrevino added a commit to Robotertechnik/caffe referencing this pull request on Aug 11, 2015: Merge pull request BVLC#2782 from matthiasplappert/adadelta