Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

#include "memilio/mobility/metapopulation_mobility_instant.h"
#include "memilio/io/mobility_io.h"
#include "pybind_util.h"

#include "pybind11/pybind11.h"
#include <cstddef>
Expand All @@ -32,11 +33,27 @@ namespace pymio
template <class Model>
void bind_write_graph(pybind11::module_& m)
{
m.def("write_graph",
[&](const mio::Graph<Model, mio::MobilityParameters<double>>& graph, const std::string& directory) {
int ioflags = mio::IOF_None;
auto ioresult = mio::write_graph<double, Model>(graph, directory, ioflags);
});
m.def(
"write_graph",
[&](const mio::Graph<Model, mio::MobilityParameters<double>>& graph, const std::string& directory) {
int ioflags = mio::IOF_None;
auto ioresult = mio::write_graph<double, Model>(graph, directory, ioflags);
},
"Write a graph (nodes and edges) as JSON files to the given directory.", pybind11::arg("graph"),
pybind11::arg("directory"));
}

template <class Model>
void bind_read_graph(pybind11::module_& m)
{
m.def(
"read_graph",
[&](const std::string& directory) {
auto result = mio::read_graph<double, Model>(directory, 0, true);
return pymio::check_and_throw(result);
},
"Read a graph from JSON files in the given directory (see write_graph).", pybind11::arg("directory"),
pybind11::return_value_policy::move);
}

} // namespace pymio
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -287,11 +287,11 @@ PYBIND11_MODULE(_simulation_osecir, m)
mio::osecir::InfectionState::InfectedSymptoms, mio::osecir::InfectionState::Recovered};
auto weights = std::vector<ScalarType>{0., 0., 1.0, 1.0, 0.33, 0., 0.};
auto result = mio::set_edges<double, // FP
ContactLocation, mio::osecir::Model<double>, mio::MobilityParameters<double>,
mio::MobilityCoefficientGroup<double>, mio::osecir::InfectionState,
decltype(mio::read_mobility_plain)>(mobility_data_file, params_graph,
mobile_comp, contact_locations_size,
mio::read_mobility_plain, weights);
ContactLocation, mio::osecir::Model<double>, mio::MobilityParameters<double>,
mio::MobilityCoefficientGroup<double>, mio::osecir::InfectionState,
decltype(mio::read_mobility_plain)>(mobility_data_file, params_graph,
mobile_comp, contact_locations_size,
mio::read_mobility_plain, weights);
return pymio::check_and_throw(result);
},
py::return_value_policy::move);
Expand All @@ -302,6 +302,7 @@ PYBIND11_MODULE(_simulation_osecir, m)

#ifdef MEMILIO_HAS_JSONCPP
pymio::bind_write_graph<mio::osecir::Model<double>>(m);
pymio::bind_read_graph<mio::osecir::Model<double>>(m);
m.def(
"read_input_data_county",
[](std::vector<mio::osecir::Model<double>>& model, mio::Date date, const std::vector<int>& county,
Expand All @@ -314,8 +315,9 @@ PYBIND11_MODULE(_simulation_osecir, m)
py::return_value_policy::move);
#endif // MEMILIO_HAS_JSONCPP

m.def("interpolate_simulation_result", py::overload_cast<const MobilityGraph&>(
&mio::interpolate_simulation_result<double, mio::osecir::Simulation<double>>));
m.def("interpolate_simulation_result",
py::overload_cast<const MobilityGraph&>(
&mio::interpolate_simulation_result<double, mio::osecir::Simulation<double>>));

m.def("interpolate_ensemble_results", &mio::interpolate_ensemble_results<MobilityGraph>);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -340,9 +340,9 @@ PYBIND11_MODULE(_simulation_osecirvvs, m)
mio::osecirvvs::InfectionState::InfectedSymptomsImprovedImmunity};
auto weights = std::vector<ScalarType>{0., 0., 1.0, 1.0, 0.33, 0., 0.};
auto result = mio::set_edges<double, // FP,
ContactLocation, mio::osecirvvs::Model<double>,
mio::MobilityParameters<double>, mio::MobilityCoefficientGroup<double>,
mio::osecirvvs::InfectionState, decltype(mio::read_mobility_plain)>(
ContactLocation, mio::osecirvvs::Model<double>,
mio::MobilityParameters<double>, mio::MobilityCoefficientGroup<double>,
mio::osecirvvs::InfectionState, decltype(mio::read_mobility_plain)>(
mobility_data_file, params_graph, mobile_comp, contact_locations_size, mio::read_mobility_plain,
weights);
return pymio::check_and_throw(result);
Expand All @@ -355,6 +355,7 @@ PYBIND11_MODULE(_simulation_osecirvvs, m)

#ifdef MEMILIO_HAS_JSONCPP
pymio::bind_write_graph<mio::osecirvvs::Model<double>>(m);
pymio::bind_read_graph<mio::osecirvvs::Model<double>>(m);
m.def(
"read_input_data_county",
[](std::vector<mio::osecirvvs::Model<double>>& model, mio::Date date, const std::vector<int>& county,
Expand Down
31 changes: 31 additions & 0 deletions pycode/memilio-simulation/memilio/simulation_test/test_mobility.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
# limitations under the License.
#############################################################################
import unittest
import tempfile

import numpy as np

Expand Down Expand Up @@ -75,6 +76,36 @@ def test_mobility_sim(self):
self.assertGreaterEqual(sim.graph.get_node(
0).property.result.get_num_time_points(), 3)

def test_write_read_graph_simple(self):
# build a simple model graph
model = osecir.Model(1)
model.parameters.TestAndTraceCapacity.value = 42
model.apply_constraints()

graph = osecir.ModelGraph()
graph.add_node(0, model)
graph.add_node(1, model)

num_compartments = 10
graph.add_edge(0, 1, 0.1 * np.ones(num_compartments))
graph.add_edge(1, 0, 0.1 * np.ones(num_compartments))

with tempfile.TemporaryDirectory() as tmpdir:
# save graph
osecir.write_graph(graph, tmpdir)
# read graph back
g_read = osecir.read_graph(tmpdir)

# basic structure
self.assertEqual(graph.num_nodes, g_read.num_nodes)
self.assertEqual(graph.num_edges, g_read.num_edges)

# check one parameter
self.assertEqual(
g_read.get_node(0).property.parameters.TestAndTraceCapacity.value,
42,
)


if __name__ == '__main__':
unittest.main()