diff --git a/pycode/memilio-simulation/memilio/simulation/bindings/io/mobility_io.h b/pycode/memilio-simulation/memilio/simulation/bindings/io/mobility_io.h index 23b6dbe053..694825d527 100644 --- a/pycode/memilio-simulation/memilio/simulation/bindings/io/mobility_io.h +++ b/pycode/memilio-simulation/memilio/simulation/bindings/io/mobility_io.h @@ -22,6 +22,7 @@ #include "memilio/mobility/metapopulation_mobility_instant.h" #include "memilio/io/mobility_io.h" +#include "pybind_util.h" #include "pybind11/pybind11.h" #include @@ -32,11 +33,27 @@ namespace pymio template void bind_write_graph(pybind11::module_& m) { - m.def("write_graph", - [&](const mio::Graph>& graph, const std::string& directory) { - int ioflags = mio::IOF_None; - auto ioresult = mio::write_graph(graph, directory, ioflags); - }); + m.def( + "write_graph", + [&](const mio::Graph>& graph, const std::string& directory) { + int ioflags = mio::IOF_None; + auto ioresult = mio::write_graph(graph, directory, ioflags); + }, + "Write a graph (nodes and edges) as JSON files to the given directory.", pybind11::arg("graph"), + pybind11::arg("directory")); +} + +template +void bind_read_graph(pybind11::module_& m) +{ + m.def( + "read_graph", + [&](const std::string& directory) { + auto result = mio::read_graph(directory, 0, true); + return pymio::check_and_throw(result); + }, + "Read a graph from JSON files in the given directory (see write_graph).", pybind11::arg("directory"), + pybind11::return_value_policy::move); } } // namespace pymio diff --git a/pycode/memilio-simulation/memilio/simulation/bindings/models/osecir.cpp b/pycode/memilio-simulation/memilio/simulation/bindings/models/osecir.cpp index cd8d6fcbe6..273da48f2d 100644 --- a/pycode/memilio-simulation/memilio/simulation/bindings/models/osecir.cpp +++ b/pycode/memilio-simulation/memilio/simulation/bindings/models/osecir.cpp @@ -287,11 +287,11 @@ PYBIND11_MODULE(_simulation_osecir, m) mio::osecir::InfectionState::InfectedSymptoms, mio::osecir::InfectionState::Recovered}; auto weights = std::vector{0., 0., 1.0, 1.0, 0.33, 0., 0.}; auto result = mio::set_edges, mio::MobilityParameters, - mio::MobilityCoefficientGroup, mio::osecir::InfectionState, - decltype(mio::read_mobility_plain)>(mobility_data_file, params_graph, - mobile_comp, contact_locations_size, - mio::read_mobility_plain, weights); + ContactLocation, mio::osecir::Model, mio::MobilityParameters, + mio::MobilityCoefficientGroup, mio::osecir::InfectionState, + decltype(mio::read_mobility_plain)>(mobility_data_file, params_graph, + mobile_comp, contact_locations_size, + mio::read_mobility_plain, weights); return pymio::check_and_throw(result); }, py::return_value_policy::move); @@ -302,6 +302,7 @@ PYBIND11_MODULE(_simulation_osecir, m) #ifdef MEMILIO_HAS_JSONCPP pymio::bind_write_graph>(m); + pymio::bind_read_graph>(m); m.def( "read_input_data_county", [](std::vector>& model, mio::Date date, const std::vector& county, @@ -314,8 +315,9 @@ PYBIND11_MODULE(_simulation_osecir, m) py::return_value_policy::move); #endif // MEMILIO_HAS_JSONCPP - m.def("interpolate_simulation_result", py::overload_cast( - &mio::interpolate_simulation_result>)); + m.def("interpolate_simulation_result", + py::overload_cast( + &mio::interpolate_simulation_result>)); m.def("interpolate_ensemble_results", &mio::interpolate_ensemble_results); diff --git a/pycode/memilio-simulation/memilio/simulation/bindings/models/osecirvvs.cpp b/pycode/memilio-simulation/memilio/simulation/bindings/models/osecirvvs.cpp index c7919b4237..8e1e922404 100755 --- a/pycode/memilio-simulation/memilio/simulation/bindings/models/osecirvvs.cpp +++ b/pycode/memilio-simulation/memilio/simulation/bindings/models/osecirvvs.cpp @@ -340,9 +340,9 @@ PYBIND11_MODULE(_simulation_osecirvvs, m) mio::osecirvvs::InfectionState::InfectedSymptomsImprovedImmunity}; auto weights = std::vector{0., 0., 1.0, 1.0, 0.33, 0., 0.}; auto result = mio::set_edges, - mio::MobilityParameters, mio::MobilityCoefficientGroup, - mio::osecirvvs::InfectionState, decltype(mio::read_mobility_plain)>( + ContactLocation, mio::osecirvvs::Model, + mio::MobilityParameters, mio::MobilityCoefficientGroup, + mio::osecirvvs::InfectionState, decltype(mio::read_mobility_plain)>( mobility_data_file, params_graph, mobile_comp, contact_locations_size, mio::read_mobility_plain, weights); return pymio::check_and_throw(result); @@ -355,6 +355,7 @@ PYBIND11_MODULE(_simulation_osecirvvs, m) #ifdef MEMILIO_HAS_JSONCPP pymio::bind_write_graph>(m); + pymio::bind_read_graph>(m); m.def( "read_input_data_county", [](std::vector>& model, mio::Date date, const std::vector& county, diff --git a/pycode/memilio-simulation/memilio/simulation_test/test_mobility.py b/pycode/memilio-simulation/memilio/simulation_test/test_mobility.py index 62ca7cce8f..92e80b4301 100644 --- a/pycode/memilio-simulation/memilio/simulation_test/test_mobility.py +++ b/pycode/memilio-simulation/memilio/simulation_test/test_mobility.py @@ -18,6 +18,7 @@ # limitations under the License. ############################################################################# import unittest +import tempfile import numpy as np @@ -75,6 +76,36 @@ def test_mobility_sim(self): self.assertGreaterEqual(sim.graph.get_node( 0).property.result.get_num_time_points(), 3) + def test_write_read_graph_simple(self): + # build a simple model graph + model = osecir.Model(1) + model.parameters.TestAndTraceCapacity.value = 42 + model.apply_constraints() + + graph = osecir.ModelGraph() + graph.add_node(0, model) + graph.add_node(1, model) + + num_compartments = 10 + graph.add_edge(0, 1, 0.1 * np.ones(num_compartments)) + graph.add_edge(1, 0, 0.1 * np.ones(num_compartments)) + + with tempfile.TemporaryDirectory() as tmpdir: + # save graph + osecir.write_graph(graph, tmpdir) + # read graph back + g_read = osecir.read_graph(tmpdir) + + # basic structure + self.assertEqual(graph.num_nodes, g_read.num_nodes) + self.assertEqual(graph.num_edges, g_read.num_edges) + + # check one parameter + self.assertEqual( + g_read.get_node(0).property.parameters.TestAndTraceCapacity.value, + 42, + ) + if __name__ == '__main__': unittest.main()