Deduplicate solver regularization, logging, and local rates and decays #2518

Merged
merged 3 commits on May 27, 2015
View
@@ -11,7 +11,7 @@ namespace caffe {
/**
* @brief An interface for classes that perform optimization on Net%s.
*
- * Requires implementation of ComputeUpdateValue to compute a parameter update
+ * Requires implementation of ApplyUpdate to compute a parameter update
* given the current state of the Net parameters.
*/
template <typename Dtype>
@@ -39,8 +39,8 @@ class Solver {
int iter() { return iter_; }
protected:
- // Get the update value for the current iteration.
- virtual void ComputeUpdateValue() = 0;
+ // Make and apply the update value for the current iteration.
+ virtual void ApplyUpdate() = 0;
// The Solver::Snapshot function implements the basic snapshotting utility
// that stores the learned net. You should implement the SnapshotSolverState()
// function that produces a SolverState protocol buffer that needs to be
@@ -80,7 +80,9 @@ class SGDSolver : public Solver<Dtype> {
protected:
void PreSolve();
Dtype GetLearningRate();
- virtual void ComputeUpdateValue();
+ virtual void ApplyUpdate();
+ virtual void Regularize(int param_id);
+ virtual void ComputeUpdateValue(int param_id, Dtype rate);
virtual void ClipGradients();
virtual void SnapshotSolverState(SolverState * state);
virtual void RestoreSolverState(const SolverState& state);
@@ -102,7 +104,7 @@ class NesterovSolver : public SGDSolver<Dtype> {
: SGDSolver<Dtype>(param_file) {}
protected:
- virtual void ComputeUpdateValue();
+ virtual void ComputeUpdateValue(int param_id, Dtype rate);
DISABLE_COPY_AND_ASSIGN(NesterovSolver);
};
@@ -116,7 +118,7 @@ class AdaGradSolver : public SGDSolver<Dtype> {
: SGDSolver<Dtype>(param_file) { constructor_sanity_check(); }
protected:
- virtual void ComputeUpdateValue();
+ virtual void ComputeUpdateValue(int param_id, Dtype rate);
void constructor_sanity_check() {
CHECK_EQ(0, this->param_.momentum())
<< "Momentum cannot be used with AdaGrad.";
Oops, something went wrong.