Adam solver #2918

Merged
merged 2 commits into from Aug 14, 2015
Jump to file or symbol
Failed to load files and symbols.
+307 −28
Split
@@ -0,0 +1,26 @@
+# The train/test net protocol buffer definition
+# this follows "ADAM: A METHOD FOR STOCHASTIC OPTIMIZATION"
+net: "examples/mnist/lenet_train_test.prototxt"
+# test_iter specifies how many forward passes the test should carry out.
+# In the case of MNIST, we have test batch size 100 and 100 test iterations,
+# covering the full 10,000 testing images.
+test_iter: 100
+# Carry out testing every 500 training iterations.
+test_interval: 500
+# All parameters are from the cited paper above
+base_lr: 0.001
+momentum: 0.9
+momentum2: 0.999
+# since Adam dynamically changes the learning rate, we set the base learning
+# rate to a fixed value
+lr_policy: "fixed"
+# Display every 100 iterations
+display: 100
+# The maximum number of iterations
+max_iter: 10000
+# snapshot intermediate results
+snapshot: 5000
+snapshot_prefix: "examples/mnist/lenet"
+# solver mode: CPU or GPU
+solver_type: ADAM
+solver_mode: GPU
@@ -0,0 +1,3 @@
+#!/usr/bin/env sh
+
+./build/tools/caffe train --solver=examples/mnist/lenet_solver_adam.prototxt
View
@@ -217,6 +217,29 @@ class AdaDeltaSolver : public SGDSolver<Dtype> {
DISABLE_COPY_AND_ASSIGN(AdaDeltaSolver);
};
@jeffdonahue

jeffdonahue Aug 14, 2015

Contributor

We need to cite the ADAM paper somewhere*. I suggest putting a reference here, e.g. in a doxygen formatted comment like this. Eventually it would also be good to add sections to the solver tutorial on these new solvers, where the reference should also then be added.

*We probably also need to go back and add references for some of the other recently merged solvers.

+/**
+ * @brief AdamSolver, an algorithm for first-order gradient-based optimization
+ * of stochastic objective functions, based on adaptive estimates of
+ * lower-order moments. Described in [1].
+ *
+ * [1] D. P. Kingma and J. L. Ba, "ADAM: A Method for Stochastic Optimization."
+ * arXiv preprint arXiv:1412.6980v8 (2014).
+ */
+template <typename Dtype>
+class AdamSolver : public SGDSolver<Dtype> {
+ public:
+ explicit AdamSolver(const SolverParameter& param)
+ : SGDSolver<Dtype>(param) { AdamPreSolve();}
+ explicit AdamSolver(const string& param_file)
+ : SGDSolver<Dtype>(param_file) { AdamPreSolve(); }
+
+ protected:
+ void AdamPreSolve();
+ virtual void ComputeUpdateValue(int param_id, Dtype rate);
+
+ DISABLE_COPY_AND_ASSIGN(AdamSolver);
+};
+
template <typename Dtype>
Solver<Dtype>* GetSolver(const SolverParameter& param) {
SolverParameter_SolverType type = param.solver_type();
@@ -232,6 +255,8 @@ Solver<Dtype>* GetSolver(const SolverParameter& param) {
return new RMSPropSolver<Dtype>(param);
case SolverParameter_SolverType_ADADELTA:
return new AdaDeltaSolver<Dtype>(param);
+ case SolverParameter_SolverType_ADAM:
+ return new AdamSolver<Dtype>(param);
default:
LOG(FATAL) << "Unknown SolverType: " << type;
}
@@ -98,7 +98,7 @@ message NetParameter {
// NOTE
// Update the next available ID when you add a new SolverParameter field.
//
-// SolverParameter next available ID: 39 (last added: rms_decay)
+// SolverParameter next available ID: 40 (last added: momentum2)
message SolverParameter {
//////////////////////////////////////////////////////////////////////////////
// Specifying the train and test networks
@@ -216,10 +216,13 @@ message SolverParameter {
ADAGRAD = 2;
RMSPROP = 3;
ADADELTA = 4;
+ ADAM = 5;
}
optional SolverType solver_type = 30 [default = SGD];
- // numerical stability for AdaGrad
+ // numerical stability for RMSProp, AdaGrad and AdaDelta and Adam
optional float delta = 31 [default = 1e-8];
+ // parameters for the Adam solver
+ optional float momentum2 = 39 [default = 0.999];
// RMSProp decay value
// MeanSquare(t) = rms_decay*MeanSquare(t-1) + (1-rms_decay)*SquareGradient(t)
View
@@ -1114,11 +1114,115 @@ void AdaDeltaSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) {
}
}
+template <typename Dtype>
+void AdamSolver<Dtype>::AdamPreSolve() {
+ // Add the extra history entries for Adam after those from
+ // SGDSolver::PreSolve
+ const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();
+ for (int i = 0; i < net_params.size(); ++i) {
+ const vector<int>& shape = net_params[i]->shape();
+ this->history_.push_back(
+ shared_ptr<Blob<Dtype> >(new Blob<Dtype>(shape)));
+ }
+}
+
+template <typename Dtype>
+void AdamSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) {
+ const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();
+ const vector<float>& net_params_lr = this->net_->params_lr();
+ Dtype local_rate = rate * net_params_lr[param_id];
+ const Dtype beta1 = this->param_.momentum();
+ const Dtype beta2 = this->param_.momentum2();
+
+ // we create aliases for convenience
+ size_t update_history_offset = net_params.size();
+ Blob<Dtype>* val_m = this->history_[param_id].get();
+ Blob<Dtype>* val_v = this->history_[param_id + update_history_offset].get();
+ Blob<Dtype>* val_t = this->temp_[param_id].get();
+
+ const int t = this->iter_ + 1;
+ const Dtype correction = std::sqrt(Dtype(1) - pow(beta2, t)) /
+ (Dtype(1.) - pow(beta1, t));
+ const int N = net_params[param_id]->count();
+ const Dtype eps_hat = this->param_.delta();
+
+ switch (Caffe::mode()) {
+ case Caffe::CPU: {
+ // update m <- \beta_1 m_{t-1} + (1-\beta_1)g_t
+ caffe_cpu_axpby(N, Dtype(1)-beta1,
+ net_params[param_id]->cpu_diff(), beta1,
+ val_m->mutable_cpu_data());
+
+ // update v <- \beta_2 m_{t-1} + (1-\beta_2)g_t^2
+ caffe_mul(N,
+ net_params[param_id]->cpu_diff(),
+ net_params[param_id]->cpu_diff(),
+ val_t->mutable_cpu_data());
+ caffe_cpu_axpby(N, Dtype(1)-beta2,
+ val_t->cpu_data(), beta2,
+ val_v->mutable_cpu_data());
+
+ // set update
+ caffe_powx(N,
+ val_v->cpu_data(), Dtype(0.5),
+ val_t->mutable_cpu_data());
+ caffe_add_scalar(N, eps_hat, val_t->mutable_cpu_data());
+ caffe_div(N,
+ val_m->cpu_data(),
+ val_t->cpu_data(),
+ val_t->mutable_cpu_data());
+
+ caffe_cpu_scale(N, local_rate*correction,
+ val_t->cpu_data(),
+ net_params[param_id]->mutable_cpu_diff());
+ break;
+ }
+ case Caffe::GPU: {
+#ifndef CPU_ONLY
+ // update m <- \beta_1 m_{t-1} + (1-\beta_1)g_t
+ caffe_gpu_axpby(N, Dtype(1)-beta1,
+ net_params[param_id]->gpu_diff(), beta1,
+ val_m->mutable_gpu_data());
+
+ // update v <- \beta_2 m_{t-1} + (1-\beta_2)g_t^2
+ caffe_gpu_mul(N,
+ net_params[param_id]->gpu_diff(),
+ net_params[param_id]->gpu_diff(),
+ val_t->mutable_gpu_data());
+ caffe_gpu_axpby(N, Dtype(1)-beta2,
+ val_t->gpu_data(), beta2,
+ val_v->mutable_gpu_data());
+
+ // set update
+ caffe_gpu_powx(N,
+ val_v->gpu_data(), Dtype(0.5),
+ val_t->mutable_gpu_data());
+ caffe_gpu_add_scalar(N, eps_hat,
+ val_t->mutable_gpu_data());
+ caffe_gpu_div(N,
+ val_m->gpu_data(),
+ val_t->gpu_data(),
+ val_t->mutable_gpu_data());
+
+ caffe_gpu_scale(N, local_rate*correction,
+ val_t->gpu_data(),
+ net_params[param_id]->mutable_gpu_diff());
+#else
+ NO_GPU;
+#endif
+ break;
+ }
+ default:
+ LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode();
+ }
+}
+
INSTANTIATE_CLASS(Solver);
INSTANTIATE_CLASS(SGDSolver);
INSTANTIATE_CLASS(NesterovSolver);
INSTANTIATE_CLASS(AdaGradSolver);
INSTANTIATE_CLASS(RMSPropSolver);
INSTANTIATE_CLASS(AdaDeltaSolver);
+INSTANTIATE_CLASS(AdamSolver);
} // namespace caffe
Oops, something went wrong.