-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
new indirection and project structure
- Loading branch information
Showing
12 changed files
with
263 additions
and
282 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,54 +1,67 @@ | ||
use crate::prelude::{self, datafile::Scenario, output::NPCycle, settings::run::Data}; | ||
use ndarray::Array2; | ||
use crate::prelude::{self, output::NPCycle, settings::run::Data}; | ||
|
||
use output::NPResult; | ||
use prelude::*; | ||
use simulation::predict::{Engine, Predict}; | ||
use tokio::sync::mpsc::UnboundedSender; | ||
use tokio::sync::mpsc; | ||
|
||
pub mod npag; | ||
mod npag; | ||
mod postprob; | ||
|
||
pub enum Type { | ||
NPAG, | ||
POSTPROB, | ||
} | ||
|
||
pub trait Algorithm<S> { | ||
// fn initialize( | ||
// self, | ||
// sim_eng: Engine<S>, | ||
// ranges: Vec<(f64, f64)>, | ||
// theta: Array2<f64>, | ||
// scenarios: Vec<Scenario>, | ||
// c: (f64, f64, f64, f64), | ||
// tx: UnboundedSender<NPCycle>, | ||
// settings: Data, | ||
// ) -> Self | ||
// where | ||
// S: Predict + std::marker::Sync; | ||
pub trait Algorithm { | ||
fn fit(&mut self) -> NPResult; | ||
fn to_npresult(&self) -> NPResult; | ||
} | ||
|
||
pub fn initialize_algorithm<S>( | ||
alg_type: Type, | ||
sim_eng: Engine<S>, | ||
ranges: Vec<(f64, f64)>, | ||
theta: Array2<f64>, | ||
scenarios: Vec<Scenario>, | ||
c: (f64, f64, f64, f64), | ||
tx: UnboundedSender<NPCycle>, | ||
engine: Engine<S>, | ||
settings: Data, | ||
) -> Box<dyn Algorithm<S>> | ||
tx: mpsc::UnboundedSender<NPCycle>, | ||
) -> Box<dyn Algorithm> | ||
where | ||
S: Predict + std::marker::Sync + 'static + Clone, | ||
S: Predict + std::marker::Sync + Clone + 'static, | ||
{ | ||
match alg_type { | ||
Type::NPAG => Box::new(npag::NPAG::new( | ||
sim_eng, ranges, theta, scenarios, c, tx, settings, | ||
if std::path::Path::new("stop").exists() { | ||
match std::fs::remove_file("stop") { | ||
Ok(_) => log::info!("Removed previous stop file"), | ||
Err(err) => panic!("Unable to remove previous stop file: {}", err), | ||
} | ||
} | ||
let ranges = settings.computed.random.ranges.clone(); | ||
let theta = initialization::sample_space(&settings, &ranges); | ||
let mut scenarios = datafile::parse(&settings.parsed.paths.data).unwrap(); | ||
if let Some(exclude) = &settings.parsed.config.exclude { | ||
for val in exclude { | ||
scenarios.remove(val.as_integer().unwrap() as usize); | ||
} | ||
} | ||
//This should be a macro, so it can automatically expands as soon as we add a new option in the Type Enum | ||
match settings.parsed.config.engine.as_str() { | ||
"NPAG" => Box::new(npag::NPAG::new( | ||
engine, | ||
ranges, | ||
theta, | ||
scenarios, | ||
settings.parsed.error.poly, | ||
tx, | ||
settings, | ||
)), | ||
Type::POSTPROB => Box::new(postprob::POSTPROB::new( | ||
sim_eng, theta, scenarios, c, tx, settings, | ||
"POSTPROB" => Box::new(postprob::POSTPROB::new( | ||
engine, | ||
theta, | ||
scenarios, | ||
settings.parsed.error.poly, | ||
tx, | ||
settings, | ||
)), | ||
alg => { | ||
eprintln!("Error: Algorithm not recognized: {}", alg); | ||
std::process::exit(-1) | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,104 @@ | ||
use crate::algorithms::initialize_algorithm; | ||
use crate::prelude::output::NPCycle; | ||
use crate::prelude::{ | ||
output::NPResult, | ||
predict::{Engine, Predict}, | ||
settings::run::Data, | ||
*, | ||
}; | ||
use csv::{ReaderBuilder, WriterBuilder}; | ||
use eyre::Result; | ||
|
||
use log::LevelFilter; | ||
use log4rs::append::file::FileAppender; | ||
use log4rs::config::{Appender, Config, Root}; | ||
use log4rs::encode::pattern::PatternEncoder; | ||
use ndarray::Array2; | ||
use ndarray_csv::Array2Reader; | ||
use predict::sim_obs; | ||
use std::fs; | ||
use std::fs::File; | ||
use std::thread::spawn; | ||
use std::time::Instant; | ||
use tokio::sync::mpsc::{self}; | ||
|
||
pub fn simulate<S>(engine: Engine<S>, settings_path: String) -> Result<()> | ||
where | ||
S: Predict + std::marker::Sync + std::marker::Send + 'static + Clone, | ||
{ | ||
let settings = settings::simulator::read(settings_path); | ||
let theta_file = File::open(settings.paths.theta).unwrap(); | ||
let mut reader = ReaderBuilder::new() | ||
.has_headers(false) | ||
.from_reader(theta_file); | ||
let theta: Array2<f64> = reader.deserialize_array2_dynamic().unwrap(); | ||
let scenarios = datafile::parse(&settings.paths.data).unwrap(); | ||
|
||
let ypred = sim_obs(&engine, &scenarios, &theta, false); | ||
|
||
let sim_file = File::create("simulation_output.csv").unwrap(); | ||
let mut sim_writer = WriterBuilder::new() | ||
.has_headers(false) | ||
.from_writer(sim_file); | ||
sim_writer | ||
.write_record(["id", "point", "time", "sim_obs"]) | ||
.unwrap(); | ||
for (id, scenario) in scenarios.iter().enumerate() { | ||
let time = scenario.obs_times.clone(); | ||
for (point, _spp) in theta.rows().into_iter().enumerate() { | ||
for (i, time) in time.iter().enumerate() { | ||
sim_writer.write_record(&[ | ||
id.to_string(), | ||
point.to_string(), | ||
time.to_string(), | ||
ypred.get((id, point)).unwrap().get(i).unwrap().to_string(), | ||
])?; | ||
} | ||
} | ||
} | ||
Ok(()) | ||
} | ||
pub fn start<S>(engine: Engine<S>, settings_path: String) -> Result<NPResult> | ||
where | ||
S: Predict + std::marker::Sync + std::marker::Send + 'static + Clone, | ||
{ | ||
let now = Instant::now(); | ||
let settings = settings::run::read(settings_path); | ||
setup_log(&settings); | ||
let (tx, rx) = mpsc::unbounded_channel::<NPCycle>(); | ||
let mut algorithm = initialize_algorithm(engine.clone(), settings.clone(), tx); | ||
// Spawn new thread for TUI | ||
let settings_tui = settings.clone(); | ||
if settings.parsed.config.tui { | ||
let _ui_handle = spawn(move || { | ||
start_ui(rx, settings_tui).expect("Failed to start TUI"); | ||
}); | ||
} | ||
|
||
let result = algorithm.fit(); | ||
log::info!("Total time: {:.2?}", now.elapsed()); | ||
|
||
if let Some(write) = &settings.parsed.config.pmetrics_outputs { | ||
result.write_outputs(*write, &engine); | ||
} | ||
|
||
Ok(result) | ||
} | ||
|
||
//TODO: move elsewhere | ||
fn setup_log(settings: &Data) { | ||
if let Some(log_path) = &settings.parsed.paths.log_out { | ||
if fs::remove_file(log_path).is_ok() {}; | ||
let logfile = FileAppender::builder() | ||
.encoder(Box::new(PatternEncoder::new("{l}: {m}\n"))) | ||
.build(log_path) | ||
.unwrap(); | ||
|
||
let config = Config::builder() | ||
.appender(Appender::builder().build("logfile", Box::new(logfile))) | ||
.build(Root::builder().appender("logfile").build(LevelFilter::Info)) | ||
.unwrap(); | ||
|
||
log4rs::init_config(config).unwrap(); | ||
}; | ||
} |
Oops, something went wrong.