AdaMax solver #6263

Open · wants to merge 2 commits into master
19 changes: 19 additions & 0 deletions include/caffe/sgd_solvers.hpp
@@ -143,6 +143,25 @@ class AdamSolver : public SGDSolver<Dtype> {
DISABLE_COPY_AND_ASSIGN(AdamSolver);
};

/**
 * @brief AdaMaxSolver: a variant of Adam that scales parameter updates by an
 *        exponentially weighted infinity norm of past gradients instead of
 *        the L2-based second moment estimate (Kingma & Ba, 2015).
 */
template <typename Dtype>
class AdaMaxSolver : public AdamSolver<Dtype> {
public:
explicit AdaMaxSolver(const SolverParameter& param)
: AdamSolver<Dtype>(param) { }
explicit AdaMaxSolver(const string& param_file)
: AdamSolver<Dtype>(param_file) { }
virtual inline const char* type() const { return "AdaMax"; }

protected:
virtual void ComputeUpdateValue(int param_id, Dtype rate);

DISABLE_COPY_AND_ASSIGN(AdaMaxSolver);
};

} // namespace caffe

#endif // CAFFE_SGD_SOLVERS_HPP_
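For reviewers: selecting the new solver only requires the registered type string plus the Adam hyperparameter fields this class inherits. A minimal solver.prototxt sketch (values and net path are illustrative, not part of this diff; momentum maps to beta1, momentum2 to beta2, delta to eps_hat):

net: "examples/mnist/lenet_train_test.prototxt"
type: "AdaMax"
base_lr: 0.002
momentum: 0.9
momentum2: 0.999
delta: 1e-8
lr_policy: "fixed"
max_iter: 10000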
77 changes: 77 additions & 0 deletions src/caffe/solvers/adamax_solver.cpp
@@ -0,0 +1,77 @@
#include <algorithm>
#include <vector>

#include "caffe/sgd_solvers.hpp"

namespace caffe {

#ifndef CPU_ONLY
template <typename Dtype>
void adamax_update_gpu(int N, Dtype* g, Dtype* m, Dtype* v, Dtype beta1,
Dtype beta2, Dtype eps_hat, Dtype corrected_local_rate);
#endif

template <typename Dtype>
void AdaMaxSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) {
const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();
const vector<float>& net_params_lr = this->net_->params_lr();
Dtype local_rate = rate * net_params_lr[param_id];
const Dtype beta1 = this->param_.momentum();
const Dtype beta2 = this->param_.momentum2();

  // Aliases for the solver state: m (first moment), v (exponentially
  // weighted infinity norm), and a temporary buffer.
  size_t update_history_offset = net_params.size();
  Blob<Dtype>* val_m = this->history_[param_id].get();
  Blob<Dtype>* val_v = this->history_[param_id + update_history_offset].get();
  Blob<Dtype>* val_t = this->temp_[param_id].get();

  const int t = this->iter_ + 1;
  // AdaMax bias-corrects only the first moment estimate; the infinity-norm
  // term does not need correction.
  const Dtype correction = Dtype(1) / (Dtype(1) - pow(beta1, t));
  const int N = net_params[param_id]->count();
  const Dtype eps_hat = this->param_.delta();

switch (Caffe::mode()) {
case Caffe::CPU: {
// update m <- \beta_1 m_{t-1} + (1-\beta_1)g_t
caffe_cpu_axpby(N, Dtype(1)-beta1,
net_params[param_id]->cpu_diff(), beta1,
val_m->mutable_cpu_data());

    // update v <- max(\beta_2 v_{t-1} + eps_hat, |g_t|)
    // (the small eps_hat keeps the denominator strictly positive at t = 1)
    caffe_abs(N, net_params[param_id]->cpu_diff(), val_t->mutable_cpu_data());
    Dtype* v_data = val_v->mutable_cpu_data();
    const Dtype* abs_g = val_t->cpu_data();
    for (int i = 0; i < N; ++i) {
      v_data[i] = std::max(v_data[i] * beta2 + eps_hat, abs_g[i]);
    }

// set update
caffe_div(N,
val_m->cpu_data(),
val_v->cpu_data(),
val_t->mutable_cpu_data());
caffe_cpu_scale(N, local_rate*correction, val_t->cpu_data(),
net_params[param_id]->mutable_cpu_diff());

break;
}
case Caffe::GPU: {
#ifndef CPU_ONLY
adamax_update_gpu(N, net_params[param_id]->mutable_gpu_diff(),
val_m->mutable_gpu_data(), val_v->mutable_gpu_data(), beta1, beta2,
eps_hat, local_rate*correction);
#else
NO_GPU;
#endif
break;
}
default:
LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode();
}
}

INSTANTIATE_CLASS(AdaMaxSolver);
REGISTER_SOLVER_CLASS(AdaMax);

} // namespace caffe
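For reference, ComputeUpdateValue above implements the AdaMax update from the Adam paper; the only deviation is the eps_hat term folded into the max for numerical stability (the paper uses a plain max):

\begin{aligned}
m_t &= \beta_1 m_{t-1} + (1-\beta_1)\, g_t \\
u_t &= \max\!\left(\beta_2 u_{t-1} + \hat\epsilon,\ |g_t|\right) \\
\theta_t &= \theta_{t-1} - \frac{\alpha}{1-\beta_1^{t}} \cdot \frac{m_t}{u_t}
\end{aligned}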
31 changes: 31 additions & 0 deletions src/caffe/solvers/adamax_solver.cu
@@ -0,0 +1,31 @@
#include <algorithm>

#include "caffe/util/math_functions.hpp"


namespace caffe {

template <typename Dtype>
__global__ void AdaMaxUpdate(int N, Dtype* g, Dtype* m, Dtype* v,
Dtype beta1, Dtype beta2, Dtype eps_hat, Dtype corrected_local_rate) {
  CUDA_KERNEL_LOOP(i, N) {
    // Use Dtype locals so the double instantiation keeps full precision.
    Dtype gi = g[i];
    // m <- beta1 * m + (1 - beta1) * g
    Dtype mi = m[i] = m[i]*beta1 + gi*(1-beta1);
    // v <- max(beta2 * v + eps_hat, |g|); eps_hat avoids a zero denominator.
    Dtype vi = v[i] = max(v[i]*beta2 + eps_hat, abs(gi));
    g[i] = corrected_local_rate * mi / vi;
  }
}
template <typename Dtype>
void adamax_update_gpu(int N, Dtype* g, Dtype* m, Dtype* v, Dtype beta1,
Dtype beta2, Dtype eps_hat, Dtype corrected_local_rate) {
AdaMaxUpdate<Dtype> // NOLINT_NEXT_LINE(whitespace/operators)
<<<CAFFE_GET_BLOCKS(N), CAFFE_CUDA_NUM_THREADS>>>(
N, g, m, v, beta1, beta2, eps_hat, corrected_local_rate);
CUDA_POST_KERNEL_CHECK;
}
template void adamax_update_gpu<float>(int, float*, float*, float*,
float, float, float, float);
template void adamax_update_gpu<double>(int, double*, double*, double*,
double, double, double, double);

} // namespace caffe
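As a quick sanity check of what both the CPU and GPU paths compute, consider the first iteration (t = 1) with beta1 = 0.9, base learning rate 0.01, zero-initialized history, and a gradient of 1 (eps_hat assumed negligible):

m_1 = (1-\beta_1)\,g = 0.1, \qquad u_1 = \max(\hat\epsilon,\ |g|) = 1,
\qquad \alpha_1 = \frac{0.01}{1-0.9^{1}} = 0.1, \qquad \Delta\theta = \alpha_1 \cdot \frac{m_1}{u_1} = 0.01

That is, the first step equals the base learning rate times the sign of the gradient, matching the expected value computed by the AdaMax branch added to the test below.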
115 changes: 114 additions & 1 deletion src/caffe/test/test_gradient_based_solver.cpp
@@ -292,7 +292,8 @@ class GradientBasedSolverTest : public MultiDeviceTest<TypeParam> {
// Finally, compute update.
const vector<shared_ptr<Blob<Dtype> > >& history = solver_->history();
if (solver_->type() != string("AdaDelta")
        && solver_->type() != string("Adam")
        && solver_->type() != string("AdaMax")) {
ASSERT_EQ(2, history.size()); // 1 blob for weights, 1 for bias
} else {
ASSERT_EQ(4, history.size()); // additional blobs for update history
@@ -336,6 +337,16 @@ class GradientBasedSolverTest : public MultiDeviceTest<TypeParam> {
std::sqrt(Dtype(1) - pow(momentum2, num_iters)) /
(Dtype(1.) - pow(momentum, num_iters));
update_value = alpha_t * val_m / (std::sqrt(val_v) + delta_);
      } else if (solver_->type() == string("AdaMax")) {
        const Dtype momentum2 = 0.999;
        const Dtype m = history_value;
        const Dtype v = (i == D) ?
            history[1 + num_param_blobs]->cpu_data()[0] :
            history[0 + num_param_blobs]->cpu_data()[i];
        const Dtype val_m = (1 - momentum) * grad + momentum * m;
        // The solver additionally adds eps_hat inside the max for stability;
        // it is assumed negligible at test precision and omitted here.
        const Dtype val_v = std::max(momentum2 * v, std::abs(grad));
        // AdaMax bias-corrects only the first moment.
        const Dtype alpha_t = learning_rate / (Dtype(1) - pow(momentum, num_iters));
        update_value = alpha_t * val_m / val_v;
} else {
LOG(FATAL) << "Unknown solver type: " << solver_->type();
}
@@ -1286,4 +1297,106 @@ TYPED_TEST(RMSPropSolverTest, TestSnapshotShare) {
}
}

template <typename TypeParam>
class AdaMaxSolverTest : public GradientBasedSolverTest<TypeParam> {
typedef typename TypeParam::Dtype Dtype;

protected:
virtual void InitSolver(const SolverParameter& param) {
SolverParameter new_param = param;
const Dtype momentum = 0.9;
new_param.set_momentum(momentum);
const Dtype momentum2 = 0.999;
new_param.set_momentum2(momentum2);
this->solver_.reset(new AdaMaxSolver<Dtype>(new_param));
}
};

TYPED_TEST_CASE(AdaMaxSolverTest, TestDtypesAndDevices);

TYPED_TEST(AdaMaxSolverTest, TestAdaMaxLeastSquaresUpdate) {
typedef typename TypeParam::Dtype Dtype;
const Dtype kLearningRate = 0.01;
const Dtype kWeightDecay = 0;
const Dtype kMomentum = 0.9;
this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum);
}

TYPED_TEST(AdaMaxSolverTest, TestAdaMaxLeastSquaresUpdateWithWeightDecay) {
typedef typename TypeParam::Dtype Dtype;
const Dtype kLearningRate = 0.01;
const Dtype kWeightDecay = 0.5;
const Dtype kMomentum = 0.9;
this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum);
}

TYPED_TEST(AdaMaxSolverTest, TestAdaMaxLeastSquaresUpdateWithEverything) {
typedef typename TypeParam::Dtype Dtype;
const Dtype kLearningRate = 0.01;
const Dtype kWeightDecay = 0.5;
const Dtype kMomentum = 0.9;
const int kNumIters = 4;
for (int i = 0; i <= kNumIters; ++i) {
this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i);
}
}

TYPED_TEST(AdaMaxSolverTest, TestAdaMaxLeastSquaresUpdateWithEverythingShare) {
typedef typename TypeParam::Dtype Dtype;
const Dtype kLearningRate = 0.01;
const Dtype kWeightDecay = 0.5;
const Dtype kMomentum = 0.9;
const int kNumIters = 4;
this->share_ = true;
for (int i = 0; i <= kNumIters; ++i) {
this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i);
}
}

TYPED_TEST(AdaMaxSolverTest, TestLeastSquaresUpdateWithEverythingAccum) {
typedef typename TypeParam::Dtype Dtype;
const Dtype kLearningRate = 0.01;
const Dtype kWeightDecay = 0.5;
const Dtype kMomentum = 0.9;
const int kNumIters = 4;
const int kIterSize = 2;
this->CheckAccumulation(kLearningRate, kWeightDecay, kMomentum, kNumIters,
kIterSize);
}

TYPED_TEST(AdaMaxSolverTest, TestLeastSquaresUpdateWithEverythingAccumShare) {
typedef typename TypeParam::Dtype Dtype;
const Dtype kLearningRate = 0.01;
const Dtype kWeightDecay = 0.5;
const Dtype kMomentum = 0.9;
const int kNumIters = 4;
const int kIterSize = 2;
this->share_ = true;
this->CheckAccumulation(kLearningRate, kWeightDecay, kMomentum, kNumIters,
kIterSize);
}

TYPED_TEST(AdaMaxSolverTest, TestSnapshot) {
typedef typename TypeParam::Dtype Dtype;
const Dtype kLearningRate = 0.01;
const Dtype kWeightDecay = 0.5;
const Dtype kMomentum = 0.9;
const int kNumIters = 4;
for (int i = 1; i <= kNumIters; ++i) {
this->TestSnapshot(kLearningRate, kWeightDecay, kMomentum, i);
}
}

TYPED_TEST(AdaMaxSolverTest, TestSnapshotShare) {
typedef typename TypeParam::Dtype Dtype;
const Dtype kLearningRate = 0.01;
const Dtype kWeightDecay = 0.5;
const Dtype kMomentum = 0.9;
const int kNumIters = 4;
this->share_ = true;
for (int i = 1; i <= kNumIters; ++i) {
this->TestSnapshot(kLearningRate, kWeightDecay, kMomentum, i);
}
}

} // namespace caffe
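Finally, a self-contained sketch (not part of this diff) of driving the new solver programmatically through Caffe's solver registry, assuming the stock SolverParameter fields and solver_factory API; the net path is a placeholder:

#include <boost/shared_ptr.hpp>

#include "caffe/proto/caffe.pb.h"
#include "caffe/solver.hpp"
#include "caffe/solver_factory.hpp"

int main() {
  caffe::SolverParameter param;
  param.set_type("AdaMax");    // dispatched via REGISTER_SOLVER_CLASS(AdaMax)
  param.set_momentum(0.9);     // beta1
  param.set_momentum2(0.999);  // beta2
  param.set_delta(1e-8);       // eps_hat
  param.set_net("path/to/train_net.prototxt");  // placeholder net definition
  param.set_base_lr(0.002f);
  param.set_lr_policy("fixed");
  param.set_max_iter(10000);
  // The registry resolves the type string to the AdaMaxSolver constructor.
  boost::shared_ptr<caffe::Solver<float> > solver(
      caffe::SolverRegistry<float>::CreateSolver(param));
  solver->Solve();
  return 0;
}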