Skip to content

Commit

Permalink
Switches CMA-ES to bulk_cost
Browse files Browse the repository at this point in the history
  • Loading branch information
VolodymyrOrlov authored and stefan-k committed Jul 24, 2022
1 parent 62d6e7a commit b819f00
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 16 deletions.
22 changes: 10 additions & 12 deletions argmin/src/solver/cma_es/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
//! For details see [`CMAES`].

use crate::core::{
ArgminFloat, CostFunction, Error, PopulationState, Problem, SerializeAlias, Solver, KV,
ArgminFloat, CostFunction, Error, PopulationState, Problem, SerializeAlias, Solver, SyncAlias,
KV,
};
use argmin_math::{
ArgminAdd, ArgminArgsort, ArgminAxisIter, ArgminBroadcast, ArgminDiv, ArgminDot,
Expand Down Expand Up @@ -45,10 +46,10 @@ use std::ops::{AddAssign, MulAssign};
/// }
///
/// impl CostFunction for Rosenbrock {
/// type Param = Vec<f32>;
/// type Output = f32;
/// type Param = Vec<f32>;
/// type Output = f32;
///
/// fn cost(&self, p: &Self::Param) -> Result<Self::Output, Error> {
/// fn cost(&self, p: &Self::Param) -> Result<Self::Output, Error> {
/// Ok(rosenbrock_2d(p, self.a, self.b))
/// }
/// }
Expand Down Expand Up @@ -203,11 +204,12 @@ where

impl<O, P, F> Solver<O, PopulationState<P, F, P::Array2D>> for CMAES<P, F>
where
O: CostFunction<Param = P, Output = F>,
O: CostFunction<Param = P, Output = F> + SyncAlias,
Vec<F>: ArgminArgsort,
F: ArgminFloat + MulAssign + AddAssign + NumCast + ArgminDiv<P, P>,
P: SerializeAlias
+ Clone
+ SyncAlias
+ ArgminTransition
+ ArgminSize<usize>
+ ArgminZeroLike
Expand Down Expand Up @@ -251,12 +253,9 @@ where

state.population = Some(self.generate());

let fitness: Vec<F> = state
.get_population()
.unwrap()
.row_iterator()
.map(|p| problem.cost(&p).unwrap())
.collect();
let fitness: Vec<F> = problem
.bulk_cost(&state.get_population().unwrap().row_iterator().collect())
.unwrap();

let fitness_indices = fitness.argsort();

Expand Down Expand Up @@ -411,7 +410,6 @@ mod tests {
assert!(state.best_individual.is_some());

let solution = state.best_individual.unwrap();
println!("{:?}", solution);
assert!((solution[0] - 1.0).abs() <= precision);
assert!((solution[1] - 1.0).abs() <= precision);
}
Expand Down
4 changes: 0 additions & 4 deletions media/book/src/concept.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,8 @@ There are three components needed for solving an optimization problem in argmin:
The [Executor](https://docs.rs/argmin/latest/argmin/core/struct.Executor.html) applies the solver to the optimization problem.
It also accepts observers and checkpointing mechanisms, as well as an initial guess of the parameter vector, the cost function value at that initial guess, gradient, and so on.

<<<<<<< HEAD:media/book/src/concept.md
A solver is anything that implements the [Solver](https://docs.rs/argmin/latest/argmin/core/trait.Solver.html) trait.
This trait defines how the optimization algorithm is initialized, how a single iteration is performed and when and how to terminate the iterations.
=======
A solver is anything that implements the [Solver](https://docs.rs/argmin/latest/argmin/core/trait.Solver.html) trait. This trait defines how the optimization algorithm is initialized, how a single iteration is performed and when and how to terminate the iterations.
>>>>>>> 90a03788 (Adds CMA-ES algorithm):docs/book/src/concept.md

The optimization problem needs to implement a subset of the traits
[`CostFunction`](https://docs.rs/argmin/latest/argmin/core/trait.CostFunction.html),
Expand Down

0 comments on commit b819f00

Please sign in to comment.