Skip to content

Commit

Permalink
new indirection and project structure
Browse files Browse the repository at this point in the history
  • Loading branch information
Siel committed Oct 3, 2023
1 parent e2a35a3 commit f2609e2
Show file tree
Hide file tree
Showing 12 changed files with 263 additions and 282 deletions.
8 changes: 3 additions & 5 deletions examples/bimodal_ke/main.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
use std::collections::HashMap;

use eyre::Result;
use npcore::{
prelude::{
datafile::{CovLine, Infusion, Scenario},
predict::{Engine, Predict},
},
use npcore::prelude::{
datafile::{CovLine, Infusion, Scenario},
predict::{Engine, Predict},
start,
};
use ode_solvers::*;
Expand Down
2 changes: 1 addition & 1 deletion examples/simulator/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use eyre::Result;
use npcore::prelude::{
datafile::{CovLine, Infusion, Scenario},
predict::{Engine, Predict},
simulator::simulate,
simulate,
};
use ode_solvers::*;
use std::collections::HashMap;
Expand Down
8 changes: 3 additions & 5 deletions examples/two_eq_lag/main.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
use eyre::Result;
use npcore::{
prelude::{
datafile::{CovLine, Infusion, Scenario},
predict::{Engine, Predict},
},
use npcore::prelude::{
datafile::{CovLine, Infusion, Scenario},
predict::{Engine, Predict},
start,
};
use ode_solvers::*;
Expand Down
8 changes: 3 additions & 5 deletions examples/vori/main.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
#![allow(dead_code)]
#![allow(unused_variables)]
use eyre::Result;
use npcore::{
prelude::{
datafile::{CovLine, Infusion, Scenario},
predict::{Engine, Predict},
},
use npcore::prelude::{
datafile::{CovLine, Infusion, Scenario},
predict::{Engine, Predict},
start,
};
const ATOL: f64 = 1e-4;
Expand Down
75 changes: 44 additions & 31 deletions src/algorithms.rs
Original file line number Diff line number Diff line change
@@ -1,54 +1,67 @@
use crate::prelude::{self, datafile::Scenario, output::NPCycle, settings::run::Data};
use ndarray::Array2;
use crate::prelude::{self, output::NPCycle, settings::run::Data};

use output::NPResult;
use prelude::*;
use simulation::predict::{Engine, Predict};
use tokio::sync::mpsc::UnboundedSender;
use tokio::sync::mpsc;

pub mod npag;
mod npag;
mod postprob;

pub enum Type {
NPAG,
POSTPROB,
}

pub trait Algorithm<S> {
// fn initialize(
// self,
// sim_eng: Engine<S>,
// ranges: Vec<(f64, f64)>,
// theta: Array2<f64>,
// scenarios: Vec<Scenario>,
// c: (f64, f64, f64, f64),
// tx: UnboundedSender<NPCycle>,
// settings: Data,
// ) -> Self
// where
// S: Predict + std::marker::Sync;
pub trait Algorithm {
fn fit(&mut self) -> NPResult;
fn to_npresult(&self) -> NPResult;
}

pub fn initialize_algorithm<S>(
alg_type: Type,
sim_eng: Engine<S>,
ranges: Vec<(f64, f64)>,
theta: Array2<f64>,
scenarios: Vec<Scenario>,
c: (f64, f64, f64, f64),
tx: UnboundedSender<NPCycle>,
engine: Engine<S>,
settings: Data,
) -> Box<dyn Algorithm<S>>
tx: mpsc::UnboundedSender<NPCycle>,
) -> Box<dyn Algorithm>
where
S: Predict + std::marker::Sync + 'static + Clone,
S: Predict + std::marker::Sync + Clone + 'static,
{
match alg_type {
Type::NPAG => Box::new(npag::NPAG::new(
sim_eng, ranges, theta, scenarios, c, tx, settings,
if std::path::Path::new("stop").exists() {
match std::fs::remove_file("stop") {
Ok(_) => log::info!("Removed previous stop file"),
Err(err) => panic!("Unable to remove previous stop file: {}", err),
}
}
let ranges = settings.computed.random.ranges.clone();
let theta = initialization::sample_space(&settings, &ranges);
let mut scenarios = datafile::parse(&settings.parsed.paths.data).unwrap();
if let Some(exclude) = &settings.parsed.config.exclude {
for val in exclude {
scenarios.remove(val.as_integer().unwrap() as usize);
}
}
//This should be a macro, so it can automatically expands as soon as we add a new option in the Type Enum
match settings.parsed.config.engine.as_str() {
"NPAG" => Box::new(npag::NPAG::new(
engine,
ranges,
theta,
scenarios,
settings.parsed.error.poly,
tx,
settings,
)),
Type::POSTPROB => Box::new(postprob::POSTPROB::new(
sim_eng, theta, scenarios, c, tx, settings,
"POSTPROB" => Box::new(postprob::POSTPROB::new(
engine,
theta,
scenarios,
settings.parsed.error.poly,
tx,
settings,
)),
alg => {
eprintln!("Error: Algorithm not recognized: {}", alg);
std::process::exit(-1)
}
}
}
10 changes: 1 addition & 9 deletions src/algorithms/npag.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ where
settings: Data,
}

impl<S> Algorithm<S> for NPAG<S>
impl<S> Algorithm for NPAG<S>
where
S: Predict + std::marker::Sync + Clone,
{
Expand Down Expand Up @@ -194,14 +194,6 @@ where
}

pub fn run(&mut self) -> NPResult {
// TODO: Move to an initialization routine?
if std::path::Path::new("stop").exists() {
match std::fs::remove_file("stop") {
Ok(_) => log::info!("Removed previous stop file"),
Err(err) => panic!("Unable to remove previous stop file: {}", err),
}
}

while self.eps > THETA_E {
// log::info!("Cycle: {}", cycle);
// psi n_sub rows, nspp columns
Expand Down
5 changes: 1 addition & 4 deletions src/algorithms/postprob.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ where
settings: Data,
}

impl<S> Algorithm<S> for POSTPROB<S>
impl<S> Algorithm for POSTPROB<S>
where
S: Predict + std::marker::Sync + Clone,
{
Expand Down Expand Up @@ -95,7 +95,6 @@ where

pub fn run(&mut self) -> NPResult {
let ypred = sim_obs(&self.engine, &self.scenarios, &self.theta, false);

self.psi = prob::calculate_psi(
&ypred,
&self.scenarios,
Expand All @@ -105,11 +104,9 @@ where
e_type: &self.error_type,
},
);

let (w, objf) = ipm::burke(&self.psi).expect("Error in IPM");
self.w = w;
self.objf = objf;

self.to_npresult()
}
}
104 changes: 104 additions & 0 deletions src/entrypoints.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
use crate::algorithms::initialize_algorithm;
use crate::prelude::output::NPCycle;
use crate::prelude::{
output::NPResult,
predict::{Engine, Predict},
settings::run::Data,
*,
};
use csv::{ReaderBuilder, WriterBuilder};
use eyre::Result;

use log::LevelFilter;
use log4rs::append::file::FileAppender;
use log4rs::config::{Appender, Config, Root};
use log4rs::encode::pattern::PatternEncoder;
use ndarray::Array2;
use ndarray_csv::Array2Reader;
use predict::sim_obs;
use std::fs;
use std::fs::File;
use std::thread::spawn;
use std::time::Instant;
use tokio::sync::mpsc::{self};

pub fn simulate<S>(engine: Engine<S>, settings_path: String) -> Result<()>
where
S: Predict + std::marker::Sync + std::marker::Send + 'static + Clone,
{
let settings = settings::simulator::read(settings_path);
let theta_file = File::open(settings.paths.theta).unwrap();
let mut reader = ReaderBuilder::new()
.has_headers(false)
.from_reader(theta_file);
let theta: Array2<f64> = reader.deserialize_array2_dynamic().unwrap();
let scenarios = datafile::parse(&settings.paths.data).unwrap();

let ypred = sim_obs(&engine, &scenarios, &theta, false);

let sim_file = File::create("simulation_output.csv").unwrap();
let mut sim_writer = WriterBuilder::new()
.has_headers(false)
.from_writer(sim_file);
sim_writer
.write_record(["id", "point", "time", "sim_obs"])
.unwrap();
for (id, scenario) in scenarios.iter().enumerate() {
let time = scenario.obs_times.clone();
for (point, _spp) in theta.rows().into_iter().enumerate() {
for (i, time) in time.iter().enumerate() {
sim_writer.write_record(&[
id.to_string(),
point.to_string(),
time.to_string(),
ypred.get((id, point)).unwrap().get(i).unwrap().to_string(),
])?;
}
}
}
Ok(())
}
pub fn start<S>(engine: Engine<S>, settings_path: String) -> Result<NPResult>
where
S: Predict + std::marker::Sync + std::marker::Send + 'static + Clone,
{
let now = Instant::now();
let settings = settings::run::read(settings_path);
setup_log(&settings);
let (tx, rx) = mpsc::unbounded_channel::<NPCycle>();
let mut algorithm = initialize_algorithm(engine.clone(), settings.clone(), tx);
// Spawn new thread for TUI
let settings_tui = settings.clone();
if settings.parsed.config.tui {
let _ui_handle = spawn(move || {
start_ui(rx, settings_tui).expect("Failed to start TUI");
});
}

let result = algorithm.fit();
log::info!("Total time: {:.2?}", now.elapsed());

if let Some(write) = &settings.parsed.config.pmetrics_outputs {
result.write_outputs(*write, &engine);
}

Ok(result)
}

//TODO: move elsewhere
fn setup_log(settings: &Data) {
if let Some(log_path) = &settings.parsed.paths.log_out {
if fs::remove_file(log_path).is_ok() {};
let logfile = FileAppender::builder()
.encoder(Box::new(PatternEncoder::new("{l}: {m}\n")))
.build(log_path)
.unwrap();

let config = Config::builder()
.appender(Appender::builder().build("logfile", Box::new(logfile)))
.build(Root::builder().appender("logfile").build(LevelFilter::Info))
.unwrap();

log4rs::init_config(config).unwrap();
};
}
Loading

0 comments on commit f2609e2

Please sign in to comment.