Skip to content

Commit

Permalink
feat(solver): calculation of learning rate
Browse files Browse the repository at this point in the history
from config

In Caffe the learning rate is calculated by the solver;
I moved the calculation into the SolverConfig because it only depends on the
current variation of the solver (which can easily be supplied) and on a lot of
config variables.
  • Loading branch information
hobofan committed Nov 4, 2015
1 parent 38c7cfb commit edfbbdf
Show file tree
Hide file tree
Showing 5 changed files with 174 additions and 167 deletions.
140 changes: 133 additions & 7 deletions src/solver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ impl<'a, S: ISolver> Solver<'a, S>{
// if (this->param_.display() && this->iter_ % this->param_.display() == 0) {
// LOG(INFO) << "Iteration " << this->iter_ << ", lr = " << rate;
// }
self.worker.apply_update(&self.param, &mut self.net);
self.worker.apply_update(&self.param, &mut self.net, self.iter);

// Increment the internal iter counter -- its value should always indicate
// the number of times the weights have been updated.
Expand Down Expand Up @@ -205,7 +205,7 @@ impl<'a, S: ISolver> Solver<'a, S>{
/// Implementation of a specific Solver.
pub trait ISolver {
/// TODO: what does this do?
fn apply_update(&self, param: &SolverConfig, network: &mut Network);
fn apply_update(&self, param: &SolverConfig, network: &mut Network, iter: usize);
}

#[derive(Debug)]
Expand All @@ -214,24 +214,40 @@ pub struct SolverConfig {
/// Name of the solver.
pub name: String,
/// The `NetworkConfig` that is used to initialize the training network.
train_net: NetworkConfig,
pub train_net: NetworkConfig,
/// Display the loss averaged over the last average_loss iterations.
///
/// Default: 1
average_loss: usize,
pub average_loss: usize,
/// The number of iterations between two testing phases.
///
/// Default: None
test_interval: Option<usize>,
pub test_interval: Option<usize>,
/// If true, run an initial test pass before the first iteration,
/// ensuring memory availability and printing the starting value of the loss.
///
/// Default: true
test_initialization: bool,
pub test_initialization: bool,
/// Accumulate gradients over minibatch_size instances.
///
/// Default: 1
minibatch_size: usize,
pub minibatch_size: usize,
/// The learning rate policy to be used.
///
/// Default: Fixed
pub lr_policy: LRPolicy,
/// The base learning rate.
///
/// Default: 0.01
pub base_lr: f32,
/// gamma as used in the calculation of most learning rate policies.
///
/// Default: 0.1
pub gamma: f32,
/// The stepsize used in Step and Sigmoid learning policies.
///
/// Default: 10
pub stepsize: usize,
}

impl Default for SolverConfig {
Expand All @@ -244,6 +260,12 @@ impl Default for SolverConfig {
test_interval: None,
test_initialization: true,
minibatch_size: 1,

lr_policy: LRPolicy::Fixed,
base_lr: 0.01f32,
gamma: 0.1f32,

stepsize: 10,
}
}
}
Expand All @@ -253,4 +275,108 @@ impl SolverConfig {
pub fn test_interval(&self) -> usize {
self.test_interval.unwrap_or(0)
}

/// Return the current learning rate. The currently implemented learning rate
/// policies are as follows:
/// - fixed: always return base_lr.
/// - step: return base_lr * gamma ^ (floor(iter / step))
/// - exp: return base_lr * gamma ^ iter
/// - inv: return base_lr * (1 + gamma * iter) ^ (- power)
/// - multistep: similar to step but it allows non uniform steps defined by
/// stepvalue
/// - poly: the effective learning rate follows a polynomial decay, to be
/// zero by the max_iter. return base_lr (1 - iter/max_iter) ^ (power)
/// - sigmoid: the effective learning rate follows a sigmod decay
/// return base_lr ( 1/(1 + exp(-gamma * (iter - stepsize))))
///
/// where base_lr, max_iter, gamma, step, stepvalue and power are defined
/// in the solver config, and iter is the current iteration.
pub fn get_learning_rate(&self, iter: usize) -> f32 {
match self.lr_policy() {
LRPolicy::Fixed => {
self.base_lr()
}
LRPolicy::Step => {
let current_step = iter / self.stepsize();
self.base_lr() * self.gamma().powf(current_step as f32)
}
LRPolicy::Multistep => {
// if (this->current_step_ < this->param_.stepvalue_size() &&
// this->iter_ >= this->param_.stepvalue(this->current_step_)) {
// this->current_step_++;
// LOG(INFO) << "MultiStep Status: Iteration " <<
// this->iter_ << ", step = " << this->current_step_;
// }
// rate = this->param_.base_lr() *
// pow(this->param_.gamma(), this->current_step_);
unimplemented!();
}
LRPolicy::Exp => {
self.base_lr() * self.gamma().powf(iter as f32)
}
LRPolicy::Inv => {
// rate = this->param_.base_lr() *
// pow(Dtype(1) + this->param_.gamma() * this->iter_,
// - this->param_.power());
unimplemented!();
}
LRPolicy::Poly => {
// rate = this->param_.base_lr() * pow(Dtype(1.) -
// (Dtype(this->iter_) / Dtype(this->param_.max_iter())),
// this->param_.power());
unimplemented!();
}
LRPolicy::Sigmoid => {
// rate = this->param_.base_lr() * (Dtype(1.) /
// (Dtype(1.) + exp(-this->param_.gamma() * (Dtype(this->iter_) -
// Dtype(this->param_.stepsize())))));
unimplemented!();
}
}
}

/// Return learning rate policy.
fn lr_policy(&self) -> LRPolicy {
self.lr_policy
}

/// Return the base learning rate.
fn base_lr(&self) -> f32 {
self.base_lr
}

/// Return the gamma for learning rate calculations.
fn gamma(&self) -> f32 {
self.gamma
}

/// Return the stepsize for learning rate calculations.
fn stepsize(&self) -> usize {
self.stepsize
}
}


#[derive(Debug, Copy, Clone)]
/// Learning Rate Policy for a Solver
pub enum LRPolicy {
/// always return base_lr
Fixed,
/// learning rate decays every `step` iterations.
/// return base_lr * gamma ^ (floor(iter / step))
Step,
/// similar to step but it allows non uniform steps defined by
/// stepvalue
Multistep,
/// return base_lr * gamma ^ iter
Exp,
/// return base_lr * (1 + gamma * iter) ^ (- power)
Inv,
/// the effective learning rate follows a polynomial decay, to be
/// zero by the max_iter.
/// return base_lr (1 - iter/max_iter) ^ (power)
Poly,
/// the effective learning rate follows a sigmod decay
/// return base_lr ( 1/(1 + exp(-gamma * (iter - stepsize))))
Sigmoid,
}
4 changes: 0 additions & 4 deletions src/solvers/mod.rs.bk

This file was deleted.

58 changes: 2 additions & 56 deletions src/solvers/sgd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,60 +6,6 @@ use network::Network;
pub struct SGD;

impl SGD {
/// Return the current learning rate. The currently implemented learning rate
/// policies are as follows:
/// - fixed: always return base_lr.
/// - step: return base_lr * gamma ^ (floor(iter / step))
/// - exp: return base_lr * gamma ^ iter
/// - inv: return base_lr * (1 + gamma * iter) ^ (- power)
/// - multistep: similar to step but it allows non uniform steps defined by
/// stepvalue
/// - poly: the effective learning rate follows a polynomial decay, to be
/// zero by the max_iter. return base_lr (1 - iter/max_iter) ^ (power)
/// - sigmoid: the effective learning rate follows a sigmod decay
/// return base_lr ( 1/(1 + exp(-gamma * (iter - stepsize))))
///
/// where base_lr, max_iter, gamma, step, stepvalue and power are defined
/// in the solver parameter protocol buffer, and iter is the current iteration.
fn get_learning_rate(&self) -> f32 {
// Dtype rate;
// const string& lr_policy = this->param_.lr_policy();
// if (lr_policy == "fixed") {
// rate = this->param_.base_lr();
// } else if (lr_policy == "step") {
// this->current_step_ = this->iter_ / this->param_.stepsize();
// rate = this->param_.base_lr() *
// pow(this->param_.gamma(), this->current_step_);
// } else if (lr_policy == "exp") {
// rate = this->param_.base_lr() * pow(this->param_.gamma(), this->iter_);
// } else if (lr_policy == "inv") {
// rate = this->param_.base_lr() *
// pow(Dtype(1) + this->param_.gamma() * this->iter_,
// - this->param_.power());
// } else if (lr_policy == "multistep") {
// if (this->current_step_ < this->param_.stepvalue_size() &&
// this->iter_ >= this->param_.stepvalue(this->current_step_)) {
// this->current_step_++;
// LOG(INFO) << "MultiStep Status: Iteration " <<
// this->iter_ << ", step = " << this->current_step_;
// }
// rate = this->param_.base_lr() *
// pow(this->param_.gamma(), this->current_step_);
// } else if (lr_policy == "poly") {
// rate = this->param_.base_lr() * pow(Dtype(1.) -
// (Dtype(this->iter_) / Dtype(this->param_.max_iter())),
// this->param_.power());
// } else if (lr_policy == "sigmoid") {
// rate = this->param_.base_lr() * (Dtype(1.) /
// (Dtype(1.) + exp(-this->param_.gamma() * (Dtype(this->iter_) -
// Dtype(this->param_.stepsize())))));
// } else {
// LOG(FATAL) << "Unknown learning rate policy: " << lr_policy;
// }
// return rate;
unimplemented!();
}

fn clip_gradients(&self) {
// const Dtype clip_gradients = this->param_.clip_gradients();
// if (clip_gradients < 0) { return; }
Expand All @@ -82,9 +28,9 @@ impl SGD {
}

impl ISolver for SGD {
fn apply_update(&self, param: &SolverConfig, net: &mut Network) {
fn apply_update(&self, param: &SolverConfig, net: &mut Network, iter: usize) {
// CHECK(Caffe::root_solver()); // Caffe
let rate = self.get_learning_rate();
let rate = param.get_learning_rate(iter);

self.clip_gradients();
for (param_id, param) in net.learnable_params().iter().enumerate() {
Expand Down
100 changes: 0 additions & 100 deletions src/solvers/sgd.rs.bk

This file was deleted.

Loading

0 comments on commit edfbbdf

Please sign in to comment.