Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Some minor optimizations #483

Merged
merged 2 commits into from
Mar 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions crates/argmin/src/core/executor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ where
checkpoint: None,
timeout: None,
ctrlc: true,
timer: true,
timer: false,
}
}

Expand Down Expand Up @@ -383,9 +383,9 @@ where
self
}

/// Enables or disables timing of individual iterations (default: enabled).
/// Enables or disables timing of individual iterations (default: false).
///
/// Setting this to false will silently be ignored in case a timeout is set.
/// In case a timeout is set, this will automatically be set to true.
///
/// # Example
///
Expand Down Expand Up @@ -768,7 +768,7 @@ mod tests {
let problem = TestProblem::new();
let timeout = std::time::Duration::from_secs(2);

let executor = Executor::new(problem, solver);
let executor = Executor::new(problem, solver).timer(true);
assert!(executor.timer);
assert!(executor.timeout.is_none());

Expand Down
27 changes: 23 additions & 4 deletions crates/argmin/src/core/state/iterstate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,8 @@ pub struct IterState<P, G, J, H, R, F> {
pub max_iters: u64,
/// Evaluation counts
pub counts: HashMap<String, u64>,
/// Update evaluation counts?
pub counting_enabled: bool,
/// Time required so far
pub time: Option<instant::Duration>,
/// Status of optimization execution
Expand Down Expand Up @@ -969,6 +971,20 @@ where
pub fn take_prev_residuals(&mut self) -> Option<R> {
self.prev_residuals.take()
}

/// Overrides state of counting function executions (default: false)
/// ```
/// # use argmin::core::{IterState, State};
/// # let mut state: IterState<(), (), (), (), Vec<f64>, f64> = IterState::new();
/// # assert!(!state.counting_enabled);
/// let state = state.counting(true);
/// # assert!(state.counting_enabled);
/// ```
#[must_use]
pub fn counting(mut self, mode: bool) -> Self {
self.counting_enabled = mode;
self
}
}

impl<P, G, J, H, R, F> State for IterState<P, G, J, H, R, F>
Expand Down Expand Up @@ -1039,6 +1055,7 @@ where
last_best_iter: 0,
max_iters: std::u64::MAX,
counts: HashMap::new(),
counting_enabled: false,
time: Some(instant::Duration::new(0, 0)),
termination_status: TerminationStatus::NotTerminated,
}
Expand Down Expand Up @@ -1338,7 +1355,7 @@ where
/// ```
/// # use std::collections::HashMap;
/// # use argmin::core::{Problem, IterState, State, ArgminFloat};
/// # let mut state: IterState<Vec<f64>, (), (), (), (), f64> = IterState::new();
/// # let mut state: IterState<Vec<f64>, (), (), (), (), f64> = IterState::new().counting(true);
/// # assert_eq!(state.counts, HashMap::new());
/// # state.counts.insert("test2".to_string(), 10u64);
/// #
Expand All @@ -1355,9 +1372,11 @@ where
/// # assert_eq!(state.counts, hm);
/// ```
fn func_counts<O>(&mut self, problem: &Problem<O>) {
for (k, &v) in problem.counts.iter() {
let count = self.counts.entry(k.to_string()).or_insert(0);
*count = v
if self.counting_enabled {
for (k, &v) in problem.counts.iter() {
let count = self.counts.entry(k.to_string()).or_insert(0);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

By the way, have you been able to find out whether hashing is the problem or the allocation of .to_string()? In case it was the latter, I was wondering if Cow would help?
If it's hashing, I was wondering if replacing HashMap<String, u64> with HashMap<SomeEnum, u64> would help, where

enum SomeEnum {
    CostCount,
    GradientCount,
    OperatorCount,
    JacobianCount,
    HessianCount,
    AnnealCount,
    Other(String),
}

(or something along those lines). That would cover all the function calls that in argmin, but also allows one to count function calls defined in external code via Other. It also avoids .to_string().

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As far as I remember it's 35% allocation and 65% hashing. HashMap is DoS resistant by default. Changing the hash function or switching to BTreeMap combined with enum approach you propose might help a bit I guess, but this changes public API.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't mind changing the public API in that case. Why do you propose a BTreeMap instead of a HashMap? Hashing an enum should be quick (I expect it to use it's integer representation as a hash with a bit of additional overhead for hashing Other(String)). Anyway I think this goes beyond this PR and can certainly be left as a future improvement.

*count = v
}
}
}

Expand Down
27 changes: 23 additions & 4 deletions crates/argmin/src/core/state/linearprogramstate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ pub struct LinearProgramState<P, F> {
pub max_iters: u64,
/// Evaluation counts
pub counts: HashMap<String, u64>,
/// Update evaluation counts?
pub counting_enabled: bool,
/// Time required so far
pub time: Option<instant::Duration>,
/// Status of optimization execution
Expand Down Expand Up @@ -150,6 +152,20 @@ impl<P, F> LinearProgramState<P, F> {
self.cost = cost;
self
}

/// Overrides state of counting function executions (default: false)
/// ```
/// # use argmin::core::{State, LinearProgramState};
/// # let mut state: LinearProgramState<Vec<f64>, f64> = LinearProgramState::new();
/// # assert!(!state.counting_enabled);
/// let state = state.counting(true);
/// # assert!(state.counting_enabled);
/// ```
#[must_use]
pub fn counting(mut self, mode: bool) -> Self {
self.counting_enabled = mode;
self
}
}

impl<P, F> State for LinearProgramState<P, F>
Expand Down Expand Up @@ -205,6 +221,7 @@ where
last_best_iter: 0,
max_iters: std::u64::MAX,
counts: HashMap::new(),
counting_enabled: false,
time: Some(instant::Duration::new(0, 0)),
termination_status: TerminationStatus::NotTerminated,
}
Expand Down Expand Up @@ -503,7 +520,7 @@ where
/// ```
/// # use std::collections::HashMap;
/// # use argmin::core::{Problem, LinearProgramState, State, ArgminFloat};
/// # let mut state: LinearProgramState<Vec<f64>, f64> = LinearProgramState::new();
/// # let mut state: LinearProgramState<Vec<f64>, f64> = LinearProgramState::new().counting(true);
/// # assert_eq!(state.counts, HashMap::new());
/// # state.counts.insert("test2".to_string(), 10u64);
/// #
Expand All @@ -520,9 +537,11 @@ where
/// # assert_eq!(state.counts, hm);
/// ```
fn func_counts<O>(&mut self, problem: &Problem<O>) {
for (k, &v) in problem.counts.iter() {
let count = self.counts.entry(k.to_string()).or_insert(0);
*count = v
if self.counting_enabled {
for (k, &v) in problem.counts.iter() {
let count = self.counts.entry(k.to_string()).or_insert(0);
*count = v
}
}
}

Expand Down
27 changes: 23 additions & 4 deletions crates/argmin/src/core/state/populationstate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ pub struct PopulationState<P, F> {
pub max_iters: u64,
/// Evaluation counts
pub counts: HashMap<String, u64>,
/// Update evaluation counts?
pub counting_enabled: bool,
/// Time required so far
pub time: Option<instant::Duration>,
/// Status of optimization execution
Expand Down Expand Up @@ -429,6 +431,20 @@ where
pub fn take_population(&mut self) -> Option<Vec<P>> {
self.population.take()
}

/// Overrides state of counting function executions (default: false)
/// ```
/// # use argmin::core::{State, PopulationState};
/// # let mut state: PopulationState<Vec<f64>, f64> = PopulationState::new();
/// # assert!(!state.counting_enabled);
/// let state = state.counting(true);
/// # assert!(state.counting_enabled);
/// ```
#[must_use]
pub fn counting(mut self, mode: bool) -> Self {
self.counting_enabled = mode;
self
}
}

impl<P, F> State for PopulationState<P, F>
Expand Down Expand Up @@ -483,6 +499,7 @@ where
last_best_iter: 0,
max_iters: std::u64::MAX,
counts: HashMap::new(),
counting_enabled: false,
time: Some(instant::Duration::new(0, 0)),
termination_status: TerminationStatus::NotTerminated,
}
Expand Down Expand Up @@ -782,7 +799,7 @@ where
/// ```
/// # use std::collections::HashMap;
/// # use argmin::core::{Problem, PopulationState, State, ArgminFloat};
/// # let mut state: PopulationState<Vec<f64>, f64> = PopulationState::new();
/// # let mut state: PopulationState<Vec<f64>, f64> = PopulationState::new().counting(true);
/// # assert_eq!(state.counts, HashMap::new());
/// # state.counts.insert("test2".to_string(), 10u64);
/// #
Expand All @@ -799,9 +816,11 @@ where
/// # assert_eq!(state.counts, hm);
/// ```
fn func_counts<O>(&mut self, problem: &Problem<O>) {
for (k, &v) in problem.counts.iter() {
let count = self.counts.entry(k.to_string()).or_insert(0);
*count = v
if self.counting_enabled {
for (k, &v) in problem.counts.iter() {
let count = self.counts.entry(k.to_string()).or_insert(0);
*count = v
}
}
}

Expand Down
2 changes: 1 addition & 1 deletion crates/argmin/src/solver/brent/brentopt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ mod tests {
let cost = TestFunc {};
let solver = BrentOpt::new(-10., 10.);
let res = Executor::new(cost, solver)
.configure(|state| state.max_iters(13))
.configure(|state| state.counting(true).max_iters(13))
.run()
.unwrap();
assert_eq!(
Expand Down
14 changes: 12 additions & 2 deletions crates/argmin/src/solver/linesearch/backtracking.rs
Original file line number Diff line number Diff line change
Expand Up @@ -640,7 +640,12 @@ mod tests {
ls.search_direction(vec![2.0f64, 0.0]);

let data = Executor::new(prob, ls.clone())
.configure(|config| config.param(ls.init_param.clone().unwrap()).max_iters(10))
.configure(|config| {
config
.counting(true)
.param(ls.init_param.clone().unwrap())
.max_iters(10)
})
.run();
assert!(data.is_ok());

Expand Down Expand Up @@ -689,7 +694,12 @@ mod tests {
ls.search_direction(vec![2.0f64, 0.0]);

let data = Executor::new(prob, ls.clone())
.configure(|config| config.param(ls.init_param.clone().unwrap()).max_iters(10))
.configure(|config| {
config
.counting(true)
.param(ls.init_param.clone().unwrap())
.max_iters(10)
})
.run();
assert!(data.is_ok());

Expand Down
7 changes: 6 additions & 1 deletion crates/argmin/src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,12 @@ fn test_lbfgs_func_count() {
let linesearch = MoreThuenteLineSearch::new();
let solver = LBFGS::new(linesearch, 10);
let res = Executor::new(cost.clone(), solver)
.configure(|config| config.param(cost.param_init.clone()).max_iters(100))
.configure(|config| {
config
.param(cost.param_init.clone())
.max_iters(100)
.counting(true)
})
.run()
.unwrap();

Expand Down
Loading