|
|
@@ -11,7 +11,7 @@ namespace caffe { |
|
|
/**
|
|
|
* @brief An interface for classes that perform optimization on Net%s.
|
|
|
*
|
|
|
- * Requires implementation of ComputeUpdateValue to compute a parameter update
|
|
|
+ * Requires implementation of ApplyUpdate to compute a parameter update
|
|
|
* given the current state of the Net parameters.
|
|
|
*/
|
|
|
template <typename Dtype>
|
|
|
@@ -39,8 +39,8 @@ class Solver { |
|
|
int iter() { return iter_; }
|
|
|
|
|
|
protected:
|
|
|
- // Get the update value for the current iteration.
|
|
|
- virtual void ComputeUpdateValue() = 0;
|
|
|
+ // Make and apply the update value for the current iteration.
|
|
|
+ virtual void ApplyUpdate() = 0;
|
|
|
// The Solver::Snapshot function implements the basic snapshotting utility
|
|
|
// that stores the learned net. You should implement the SnapshotSolverState()
|
|
|
// function that produces a SolverState protocol buffer that needs to be
|
|
|
@@ -80,7 +80,9 @@ class SGDSolver : public Solver<Dtype> { |
|
|
protected:
|
|
|
void PreSolve();
|
|
|
Dtype GetLearningRate();
|
|
|
- virtual void ComputeUpdateValue();
|
|
|
+ virtual void ApplyUpdate();
|
|
|
+ virtual void Regularize(int param_id);
|
|
|
+ virtual void ComputeUpdateValue(int param_id, Dtype rate);
|
|
|
virtual void ClipGradients();
|
|
|
virtual void SnapshotSolverState(SolverState * state);
|
|
|
virtual void RestoreSolverState(const SolverState& state);
|
|
|
@@ -102,7 +104,7 @@ class NesterovSolver : public SGDSolver<Dtype> { |
|
|
: SGDSolver<Dtype>(param_file) {}
|
|
|
|
|
|
protected:
|
|
|
- virtual void ComputeUpdateValue();
|
|
|
+ virtual void ComputeUpdateValue(int param_id, Dtype rate);
|
|
|
|
|
|
DISABLE_COPY_AND_ASSIGN(NesterovSolver);
|
|
|
};
|
|
|
@@ -116,7 +118,7 @@ class AdaGradSolver : public SGDSolver<Dtype> { |
|
|
: SGDSolver<Dtype>(param_file) { constructor_sanity_check(); }
|
|
|
|
|
|
protected:
|
|
|
- virtual void ComputeUpdateValue();
|
|
|
+ virtual void ComputeUpdateValue(int param_id, Dtype rate);
|
|
|
void constructor_sanity_check() {
|
|
|
CHECK_EQ(0, this->param_.momentum())
|
|
|
<< "Momentum cannot be used with AdaGrad.";
|
|
|
|