Skip to content

Commit

Permalink
Disable counting by default
Browse files Browse the repository at this point in the history
  • Loading branch information
pacak authored and stefan-k committed Mar 30, 2024
1 parent 6492020 commit a9e3eb7
Show file tree
Hide file tree
Showing 6 changed files with 79 additions and 12 deletions.
14 changes: 14 additions & 0 deletions crates/argmin/src/core/state/iterstate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -971,6 +971,20 @@ where
pub fn take_prev_residuals(&mut self) -> Option<R> {
self.prev_residuals.take()
}

/// Overrides state of counting function executions (default: false)
/// ```
/// # use argmin::core::{IterState, State};
/// # let mut state: IterState<(), (), (), (), Vec<f64>, f64> = IterState::new();
/// # assert!(!state.counting_enabled);
/// let state = state.counting(true);
/// # assert!(state.counting_enabled);
/// ```
#[must_use]
pub fn counting(mut self, mode: bool) -> Self {
self.counting_enabled = mode;
self
}
}

impl<P, G, J, H, R, F> State for IterState<P, G, J, H, R, F>
Expand Down
27 changes: 23 additions & 4 deletions crates/argmin/src/core/state/linearprogramstate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ pub struct LinearProgramState<P, F> {
pub max_iters: u64,
/// Evaluation counts
pub counts: HashMap<String, u64>,
/// Update evaluation counts?
pub counting_enabled: bool,
/// Time required so far
pub time: Option<instant::Duration>,
/// Status of optimization execution
Expand Down Expand Up @@ -150,6 +152,20 @@ impl<P, F> LinearProgramState<P, F> {
self.cost = cost;
self
}

/// Overrides state of counting function executions (default: false)
/// ```
/// # use argmin::core::{State, LinearProgramState};
/// # let mut state: LinearProgramState<Vec<f64>, f64> = LinearProgramState::new();
/// # assert!(!state.counting_enabled);
/// let state = state.counting(true);
/// # assert!(state.counting_enabled);
/// ```
#[must_use]
pub fn counting(mut self, mode: bool) -> Self {
self.counting_enabled = mode;
self
}
}

impl<P, F> State for LinearProgramState<P, F>
Expand Down Expand Up @@ -205,6 +221,7 @@ where
last_best_iter: 0,
max_iters: std::u64::MAX,
counts: HashMap::new(),
counting_enabled: false,
time: Some(instant::Duration::new(0, 0)),
termination_status: TerminationStatus::NotTerminated,
}
Expand Down Expand Up @@ -503,7 +520,7 @@ where
/// ```
/// # use std::collections::HashMap;
/// # use argmin::core::{Problem, LinearProgramState, State, ArgminFloat};
/// # let mut state: LinearProgramState<Vec<f64>, f64> = LinearProgramState::new();
/// # let mut state: LinearProgramState<Vec<f64>, f64> = LinearProgramState::new().counting(true);
/// # assert_eq!(state.counts, HashMap::new());
/// # state.counts.insert("test2".to_string(), 10u64);
/// #
Expand All @@ -520,9 +537,11 @@ where
/// # assert_eq!(state.counts, hm);
/// ```
fn func_counts<O>(&mut self, problem: &Problem<O>) {
for (k, &v) in problem.counts.iter() {
let count = self.counts.entry(k.to_string()).or_insert(0);
*count = v
if self.counting_enabled {
for (k, &v) in problem.counts.iter() {
let count = self.counts.entry(k.to_string()).or_insert(0);
*count = v
}
}
}

Expand Down
27 changes: 23 additions & 4 deletions crates/argmin/src/core/state/populationstate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ pub struct PopulationState<P, F> {
pub max_iters: u64,
/// Evaluation counts
pub counts: HashMap<String, u64>,
/// Update evaluation counts?
pub counting_enabled: bool,
/// Time required so far
pub time: Option<instant::Duration>,
/// Status of optimization execution
Expand Down Expand Up @@ -429,6 +431,20 @@ where
pub fn take_population(&mut self) -> Option<Vec<P>> {
self.population.take()
}

/// Overrides state of counting function executions (default: false)
/// ```
/// # use argmin::core::{State, PopulationState};
/// # let mut state: PopulationState<Vec<f64>, f64> = PopulationState::new();
/// # assert!(!state.counting_enabled);
/// let state = state.counting(true);
/// # assert!(state.counting_enabled);
/// ```
#[must_use]
pub fn counting(mut self, mode: bool) -> Self {
self.counting_enabled = mode;
self
}
}

impl<P, F> State for PopulationState<P, F>
Expand Down Expand Up @@ -483,6 +499,7 @@ where
last_best_iter: 0,
max_iters: std::u64::MAX,
counts: HashMap::new(),
counting_enabled: false,
time: Some(instant::Duration::new(0, 0)),
termination_status: TerminationStatus::NotTerminated,
}
Expand Down Expand Up @@ -782,7 +799,7 @@ where
/// ```
/// # use std::collections::HashMap;
/// # use argmin::core::{Problem, PopulationState, State, ArgminFloat};
/// # let mut state: PopulationState<Vec<f64>, f64> = PopulationState::new();
/// # let mut state: PopulationState<Vec<f64>, f64> = PopulationState::new().counting(true);
/// # assert_eq!(state.counts, HashMap::new());
/// # state.counts.insert("test2".to_string(), 10u64);
/// #
Expand All @@ -799,9 +816,11 @@ where
/// # assert_eq!(state.counts, hm);
/// ```
fn func_counts<O>(&mut self, problem: &Problem<O>) {
for (k, &v) in problem.counts.iter() {
let count = self.counts.entry(k.to_string()).or_insert(0);
*count = v
if self.counting_enabled {
for (k, &v) in problem.counts.iter() {
let count = self.counts.entry(k.to_string()).or_insert(0);
*count = v
}
}
}

Expand Down
2 changes: 1 addition & 1 deletion crates/argmin/src/solver/brent/brentopt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ mod tests {
let cost = TestFunc {};
let solver = BrentOpt::new(-10., 10.);
let res = Executor::new(cost, solver)
.configure(|state| state.max_iters(13))
.configure(|state| state.counting(true).max_iters(13))
.run()
.unwrap();
assert_eq!(
Expand Down
14 changes: 12 additions & 2 deletions crates/argmin/src/solver/linesearch/backtracking.rs
Original file line number Diff line number Diff line change
Expand Up @@ -640,7 +640,12 @@ mod tests {
ls.search_direction(vec![2.0f64, 0.0]);

let data = Executor::new(prob, ls.clone())
.configure(|config| config.param(ls.init_param.clone().unwrap()).max_iters(10))
.configure(|config| {
config
.counting(true)
.param(ls.init_param.clone().unwrap())
.max_iters(10)
})
.run();
assert!(data.is_ok());

Expand Down Expand Up @@ -689,7 +694,12 @@ mod tests {
ls.search_direction(vec![2.0f64, 0.0]);

let data = Executor::new(prob, ls.clone())
.configure(|config| config.param(ls.init_param.clone().unwrap()).max_iters(10))
.configure(|config| {
config
.counting(true)
.param(ls.init_param.clone().unwrap())
.max_iters(10)
})
.run();
assert!(data.is_ok());

Expand Down
7 changes: 6 additions & 1 deletion crates/argmin/src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,12 @@ fn test_lbfgs_func_count() {
let linesearch = MoreThuenteLineSearch::new();
let solver = LBFGS::new(linesearch, 10);
let res = Executor::new(cost.clone(), solver)
.configure(|config| config.param(cost.param_init.clone()).max_iters(100))
.configure(|config| {
config
.param(cost.param_init.clone())
.max_iters(100)
.counting(true)
})
.run()
.unwrap();

Expand Down

0 comments on commit a9e3eb7

Please sign in to comment.