diff --git a/python/pyarts/workspace/workspace.py b/python/pyarts/workspace/workspace.py index c1a5bd5e26..7a5ffe2a84 100644 --- a/python/pyarts/workspace/workspace.py +++ b/python/pyarts/workspace/workspace.py @@ -373,3 +373,9 @@ def __delattr__(self, attr): raise AttributeError("You cannot delete __class__") getattr(self, attr).delete_level() + + def __copy__(self): + return Workspace(super().__copy__()) + + def __deepcopy__(self, *args): + return Workspace(super().__deepcopy__(*args)) diff --git a/python/test/workspace/test_workspace.py b/python/test/workspace/test_workspace.py new file mode 100644 index 0000000000..7685eea065 --- /dev/null +++ b/python/test/workspace/test_workspace.py @@ -0,0 +1,52 @@ +""" +Test handling of workspace of the Python interface. +""" +import copy +import pytest +import pyarts +from pyarts.workspace import Workspace, arts_agenda +from pyarts.arts import Index + +class TestWorkspace: + def setup_method(self): + pass + + def test_copy(self): + ws = Workspace() + ws.aaavar = Index(5) + + ws2 = copy.copy(ws) + ws2.aaavar = 6 + + assert ws2.aaavar.value == ws.aaavar.value + + def test_deepcopy(self): + ws = Workspace() + ws.aaavar = Index(5) + + ws2 = copy.deepcopy(ws) + ws2.aaavar = 6 + + assert ws2.aaavar.value != ws.aaavar.value + + def test_copy_agenda(self): + ws = Workspace() + ws.aaavar = Index(5) + + @arts_agenda(ws=ws, set_agenda=True) + def test_agenda(ws): + ws.Print(ws.aaavar, 0) + ws.AgendaExecute(ws.test_agenda) + + ws2 = copy.deepcopy(ws) + ws2.aaavar = 4 + ws2.AgendaExecute(ws2.test_agenda) + + +if __name__ == "__main__": + ta = TestWorkspace() + ta.setup_method() + ta.test_copy() + ta.test_deepcopy() + ta.test_copy_agenda() + diff --git a/src/agenda_class.cc b/src/agenda_class.cc index 3b9d7f98a0..73f6403aad 100644 --- a/src/agenda_class.cc +++ b/src/agenda_class.cc @@ -39,6 +39,7 @@ #include "global_data.h" #include "messages.h" #include "methods.h" +#include "tokval.h" #include "workspace_ng.h" MRecord::MRecord() : moutput(), minput(), msetvalue(), mtasks() {} @@ -911,3 +912,82 @@ ostream& operator<<(ostream& os, const MRecord& a) { a.print(os, ""); return os; } + +ArrayOfAgenda deepcopy_if(Workspace& ws, const ArrayOfAgenda& agendas) { + ArrayOfAgenda out; + for (auto& ag : agendas) out.push_back(ag.deepcopy_if(ws)); + return out; +} + +// Method to share indices of variables from one workspace to another +ArrayOfIndex make_same_wsvs(Workspace& ws_out, + const Workspace& ws_in, + const ArrayOfIndex& vars) { + ArrayOfIndex out; + + out.reserve(vars.size()); + for (auto& v : vars) { + + // Set the value position + if (v < ws_out.nelem()) { + if (ws_in.wsv_data_ptr->at(v).Name() == + ws_out.wsv_data_ptr->at(v).Name() and + ws_in.wsv_data_ptr->at(v).Group() == + ws_out.wsv_data_ptr->at(v).Group()) { + out.push_back(v); // This is the only tested path! + } else { + out.push_back(ws_out.add_wsv(ws_in.wsv_data_ptr->at(v))); + } + } else { + out.push_back(ws_out.add_wsv(ws_in.wsv_data_ptr->at(v))); + } + + // Update if the wsv holds an agenda default value + if (auto& wsv = ws_out.wsv_data_ptr->at(out.back()); wsv.has_defaults()) { + auto& val = wsv.default_value(); + if (wsv.Group() == WorkspaceGroupIndexValue) + wsv.update_default_value(Agenda(val).deepcopy_if(ws_out)); + else if (wsv.Group() == WorkspaceGroupIndexValue) + wsv.update_default_value(deepcopy_if(ws_out, val)); + } + } + + return out; +} + +MRecord MRecord::deepcopy_if(Workspace& workspace) const { + if (mtasks.has_same_origin(workspace)) return *this; + + MRecord out(workspace); + out.mid = mid; + out.moutput = make_same_wsvs(workspace, mtasks.workspace(), moutput); + out.minput = make_same_wsvs(workspace, mtasks.workspace(), minput); + + // Must update if the value is an agenda + if (msetvalue.holdsAgenda()) { + out.msetvalue = Agenda(msetvalue).deepcopy_if(workspace); + } else if (msetvalue.holdsArrayOfAgenda()) { + out.msetvalue = ::deepcopy_if(workspace, msetvalue); + } else { + out.msetvalue = msetvalue; + } + + out.mtasks = mtasks.deepcopy_if(workspace); + out.minternal = minternal; + + return out; +} + +Agenda Agenda::deepcopy_if(Workspace& workspace) const { + if (has_same_origin(workspace)) return *this; + + Agenda out(workspace); + out.mname = mname; + for (auto& method : mml) out.mml.push_back(method.deepcopy_if(workspace)); + out.moutput_push = make_same_wsvs(workspace, *ws, moutput_push); + out.moutput_dup = make_same_wsvs(workspace, *ws, moutput_dup); + out.main_agenda = main_agenda; + out.mchecked = mchecked; + + return out; +} diff --git a/src/agenda_class.h b/src/agenda_class.h index 1d370ecefb..6e12e58945 100644 --- a/src/agenda_class.h +++ b/src/agenda_class.h @@ -108,6 +108,9 @@ class Agenda final { [[nodiscard]] Workspace& workspace() {return *ws;} [[nodiscard]] const Workspace& workspace() const {return *ws;} + //! Creates a deep copy of the agenda if necessary (i.e., different workspace)! + Agenda deepcopy_if(Workspace&) const; + private: std::shared_ptr ws; /*!< The workspace upon which this Agenda lives. */ String mname; /*!< Agenda name. */ @@ -153,6 +156,9 @@ class MRecord { [[nodiscard]] const TokVal& SetValue() const { return msetvalue; } [[nodiscard]] const Agenda& Tasks() const { return mtasks; } + //! Creates a deep copy of the method if necessary (i.e., different workspace)! + MRecord deepcopy_if(Workspace&) const; + //! Indicates the origin of this method. /*! Returns true if this method originates from a controlfile and false @@ -222,4 +228,7 @@ class MRecord { /** An array of Agenda. */ using ArrayOfAgenda = Array; +//! Same as Agenda member method but for an entire array +ArrayOfAgenda deepcopy_if(Workspace& ws, const ArrayOfAgenda& agendas); + #endif diff --git a/src/make_tokval.cc b/src/make_tokval.cc index 7ee099b122..0813c0af37 100644 --- a/src/make_tokval.cc +++ b/src/make_tokval.cc @@ -65,6 +65,17 @@ concept ArtsTypeConstRef = ArtsType and std::is_same_v concept ArtsTypeBase = ArtsType and std::is_same_v, T>; +template struct WorkspaceGroupIndex { static constexpr Index value=-1; }; +)--"; + + for (Index i = 0; i < global_data::wsv_groups.nelem(); i++) + file_h << "template <> struct WorkspaceGroupIndex<" + << global_data::wsv_groups[i] + << "> { static constexpr Index value=" << i << "; };\n"; + + file_h << R"--( +template inline constexpr Index WorkspaceGroupIndexValue = WorkspaceGroupIndex::value; + class TokVal { void * ptr{nullptr}; public: diff --git a/src/python_interface/py_workspace.cpp b/src/python_interface/py_workspace.cpp index 291eec3292..9678b40aa7 100644 --- a/src/python_interface/py_workspace.cpp +++ b/src/python_interface/py_workspace.cpp @@ -29,6 +29,17 @@ void py_workspace(py::module_& m, }), py::arg("verbosity") = 0, py::arg("agenda_verbosity") = 0) + .def(py::init([](Workspace& w) {return new Workspace{w};})) + .def( + "__copy__", + [](Workspace& w) -> Workspace { return w; }, + py::is_operator()) + .def( + "__deepcopy__", + [](Workspace& w, py::dict&) { + return w.deepcopy(); + }, + py::is_operator()) .def("execute_controlfile", [](Workspace& w, const std::filesystem::path& path) { std::unique_ptr a{parse_agenda(w, diff --git a/src/workspace_ng.cc b/src/workspace_ng.cc index c1aadf6bfe..60f35ddc37 100644 --- a/src/workspace_ng.cc +++ b/src/workspace_ng.cc @@ -92,6 +92,9 @@ void Workspace::pop(Index i) { void Workspace::swap(Workspace &other) noexcept { ws.swap(other.ws); + wsv_data_ptr.swap(other.wsv_data_ptr); + WsvMap_ptr.swap(other.WsvMap_ptr); + std::swap(original_workspace, other.original_workspace); } bool Workspace::is_initialized(Index i) const { @@ -131,3 +134,49 @@ Workspace::Workspace() } } } + +std::shared_ptr Workspace::deepcopy() { + std::shared_ptr out{new Workspace{}}; + out->wsv_data_ptr = std::shared_ptr( + new Workspace::wsv_data_type{*wsv_data_ptr}); + out->WsvMap_ptr = std::shared_ptr( + new Workspace::WsvMap_type{*WsvMap_ptr}); + out->ws.resize(nelem()); + + for (Index i = 0; i < out->nelem(); i++) { + auto &wsv_data = out->wsv_data_ptr->operator[](i); + + if (depth(i) > 0) { + // Set the WSV by copying the top value + out->ws[i].emplace(WorkspaceVariableStruct{ + workspace_memory_handler.duplicate( + wsv_data_ptr->operator[](i).Group(), ws[i].top().wsv), + is_initialized(i)}); + + // Copy the agenda to the new workspace + if (wsv_data.Group() == WorkspaceGroupIndexValue) { + Agenda *ag = static_cast(out->operator[](i).get()); + *ag = ag->deepcopy_if(*out); + } else if (wsv_data.Group() == WorkspaceGroupIndexValue) { + for (auto &a : + *static_cast(out->operator[](i).get())) { + a = a.deepcopy_if(*out); + } + } + } + + // If we have any default agenda types, we must copy them to the new workspace as well + if (wsv_data.has_defaults()) { + if (wsv_data.Group() == WorkspaceGroupIndexValue) { + wsv_data.update_default_value( + Agenda(wsv_data.default_value()).deepcopy_if(*out)); + } + if (wsv_data.Group() == WorkspaceGroupIndexValue) { + wsv_data.update_default_value( + deepcopy_if(*out, wsv_data.default_value())); + } + } + } + + return out; +} diff --git a/src/workspace_ng.h b/src/workspace_ng.h index 366ce12c97..92848bdd31 100644 --- a/src/workspace_ng.h +++ b/src/workspace_ng.h @@ -55,9 +55,11 @@ class Workspace final : public std::enable_shared_from_this { /** Workspace variable container. */ Array ws; - std::shared_ptr> wsv_data_ptr; + using wsv_data_type = Array; + std::shared_ptr wsv_data_ptr; - std::shared_ptr> WsvMap_ptr; + using WsvMap_type = map; + std::shared_ptr WsvMap_ptr; Workspace* original_workspace; @@ -198,7 +200,11 @@ class Workspace final : public std::enable_shared_from_this { outstream << (*wsv_data_ptr)[i].Name() << "(" << i << ") "; } + //! Get a shared pointer to the object std::shared_ptr shared_ptr() {return shared_from_this();} + + //! Gets a full copy that owns all the data (only gets the top of the stack) + std::shared_ptr deepcopy(); }; template diff --git a/src/wsv_aux.h b/src/wsv_aux.h index e47be90917..618c6e23a5 100644 --- a/src/wsv_aux.h +++ b/src/wsv_aux.h @@ -102,6 +102,7 @@ class WsvRecord { [[nodiscard]] std::shared_ptr get_copy() const; [[nodiscard]] const TokVal& default_value() const { return defval; } + void update_default_value(ArtsType auto&& v) {defval=std::forward(v);} private: String mname;