Skip to content

Commit

Permalink
simpler atol handling
Browse files Browse the repository at this point in the history
  • Loading branch information
Siel committed May 7, 2024
1 parent e8fe85d commit e591c24
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 25 deletions.
27 changes: 4 additions & 23 deletions src/simulator/ode/diffsol_traits.rs
Expand Up @@ -3,32 +3,29 @@ use anyhow::{Ok, Result};
use diffsol::{
matrix::Matrix,
ode_solver::{equations::OdeSolverEquations, problem::OdeSolverProblem},
op::{unit::UnitCallable, Op},
op::unit::UnitCallable,
vector::Vector,
OdeEquations,
};

use std::rc::Rc;

use super::closure::PMClosure;

pub fn build_pm_ode<M, F, I, Ite, T>(
pub fn build_pm_ode<M, F, I>(
rhs: F,
init: I,
p: M::V,
t0: f64,
h0: f64,
rtol: f64,
atol: Ite,
atol: f64,
cov: Covariates,
infusions: Vec<Infusion>,
) -> Result<OdeSolverProblem<OdeSolverEquations<M, PMClosure<M, F>, I>>>
where
M: Matrix,
F: Fn(&M::V, &M::V, M::T, &mut M::V, M::V, &Covariates),
I: Fn(&M::V, M::T) -> M::V,
Ite: IntoIterator<Item = T>,
f64: From<T>,
{
let p = Rc::new(p);
let t0 = M::T::from(t0);
Expand All @@ -38,8 +35,7 @@ where
let mass = Rc::new(UnitCallable::new(nstates));
let rhs = Rc::new(rhs);
let eqn = OdeSolverEquations::new(rhs, mass, None, init, p, false);
let atol = atol.into_iter().map(|x| f64::from(x)).collect();
let atol = build_atol(atol, eqn.rhs().nstates())?;
let atol = M::V::from_element(nstates, M::T::from(atol));
Ok(OdeSolverProblem::new(
eqn,
M::T::from(rtol),
Expand All @@ -48,18 +44,3 @@ where
M::T::from(h0),
))
}
fn build_atol<V: Vector>(atol: Vec<f64>, nstates: usize) -> Result<V> {
if atol.len() == 1 {
Ok(V::from_element(nstates, V::T::from(atol[0])))
} else if atol.len() != nstates {
Err(anyhow::anyhow!(
"atol must have length 1 or equal to the number of states"
))
} else {
let mut v = V::zeros(nstates);
for (i, &a) in atol.iter().enumerate() {
v[i] = V::T::from(a);
}
Ok(v)
}
}
4 changes: 2 additions & 2 deletions src/simulator/ode/mod.rs
Expand Up @@ -22,14 +22,14 @@ pub fn simulate_ode_event(
ti: f64,
tf: f64,
) -> V {
let problem = build_pm_ode::<M, _, _, _, _>(
let problem = build_pm_ode::<M, _, _>(
diffeq.clone(),
move |_p: &V, _t: T| x.clone(),
V::from_vec(support_point.to_vec()),
ti,
1.0,
RTOL,
[ATOL],
ATOL,
cov.clone(),
infusions.clone(),
)
Expand Down

0 comments on commit e591c24

Please sign in to comment.