Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Beginning to deepcopy workspace #520

Merged
merged 6 commits into from
Sep 13, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions python/pyarts/workspace/workspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,3 +373,9 @@ def __delattr__(self, attr):
raise AttributeError("You cannot delete __class__")

getattr(self, attr).delete_level()

def __copy__(self):
return Workspace(super().__copy__())

def __deepcopy__(self, *args):
return Workspace(super().__deepcopy__(*args))
52 changes: 52 additions & 0 deletions python/test/workspace/test_workspace.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
"""
Test handling of workspace of the Python interface.
"""
import copy
import pytest
import pyarts
from pyarts.workspace import Workspace, arts_agenda
from pyarts.arts import Index

class TestWorkspace:
def setup_method(self):
pass

def test_copy(self):
ws = Workspace()
ws.aaavar = Index(5)

ws2 = copy.copy(ws)
ws2.aaavar = 6

assert ws2.aaavar.value == ws.aaavar.value

def test_deepcopy(self):
ws = Workspace()
ws.aaavar = Index(5)

ws2 = copy.deepcopy(ws)
ws2.aaavar = 6

assert ws2.aaavar.value != ws.aaavar.value

def test_copy_agenda(self):
ws = Workspace()
ws.aaavar = Index(5)

@arts_agenda(ws=ws, set_agenda=True)
def test_agenda(ws):
ws.Print(ws.aaavar, 0)
ws.AgendaExecute(ws.test_agenda)

ws2 = copy.deepcopy(ws)
ws2.aaavar = 4
ws2.AgendaExecute(ws2.test_agenda)


if __name__ == "__main__":
ta = TestWorkspace()
ta.setup_method()
ta.test_copy()
ta.test_deepcopy()
ta.test_copy_agenda()

80 changes: 80 additions & 0 deletions src/agenda_class.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
#include "global_data.h"
#include "messages.h"
#include "methods.h"
#include "tokval.h"
#include "workspace_ng.h"

MRecord::MRecord() : moutput(), minput(), msetvalue(), mtasks() {}
Expand Down Expand Up @@ -911,3 +912,82 @@ ostream& operator<<(ostream& os, const MRecord& a) {
a.print(os, "");
return os;
}

ArrayOfAgenda deepcopy_if(Workspace& ws, const ArrayOfAgenda& agendas) {
ArrayOfAgenda out;
for (auto& ag : agendas) out.push_back(ag.deepcopy_if(ws));
return out;
}

// Method to share indices of variables from one workspace to another
ArrayOfIndex make_same_wsvs(Workspace& ws_out,
const Workspace& ws_in,
const ArrayOfIndex& vars) {
ArrayOfIndex out;

out.reserve(vars.size());
for (auto& v : vars) {

// Set the value position
if (v < ws_out.nelem()) {
if (ws_in.wsv_data_ptr->at(v).Name() ==
ws_out.wsv_data_ptr->at(v).Name() and
ws_in.wsv_data_ptr->at(v).Group() ==
ws_out.wsv_data_ptr->at(v).Group()) {
out.push_back(v); // This is the only tested path!
} else {
out.push_back(ws_out.add_wsv(ws_in.wsv_data_ptr->at(v)));
}
} else {
out.push_back(ws_out.add_wsv(ws_in.wsv_data_ptr->at(v)));
}

// Update if the wsv holds an agenda default value
if (auto& wsv = ws_out.wsv_data_ptr->at(out.back()); wsv.has_defaults()) {
auto& val = wsv.default_value();
if (wsv.Group() == WorkspaceGroupIndexValue<Agenda>)
wsv.update_default_value(Agenda(val).deepcopy_if(ws_out));
else if (wsv.Group() == WorkspaceGroupIndexValue<ArrayOfAgenda>)
wsv.update_default_value(deepcopy_if(ws_out, val));
}
}

return out;
}

MRecord MRecord::deepcopy_if(Workspace& workspace) const {
if (mtasks.has_same_origin(workspace)) return *this;

MRecord out(workspace);
out.mid = mid;
out.moutput = make_same_wsvs(workspace, mtasks.workspace(), moutput);
out.minput = make_same_wsvs(workspace, mtasks.workspace(), minput);

// Must update if the value is an agenda
if (msetvalue.holdsAgenda()) {
out.msetvalue = Agenda(msetvalue).deepcopy_if(workspace);
} else if (msetvalue.holdsArrayOfAgenda()) {
out.msetvalue = ::deepcopy_if(workspace, msetvalue);
} else {
out.msetvalue = msetvalue;
}

out.mtasks = mtasks.deepcopy_if(workspace);
out.minternal = minternal;

return out;
}

Agenda Agenda::deepcopy_if(Workspace& workspace) const {
if (has_same_origin(workspace)) return *this;

Agenda out(workspace);
out.mname = mname;
for (auto& method : mml) out.mml.push_back(method.deepcopy_if(workspace));
out.moutput_push = make_same_wsvs(workspace, *ws, moutput_push);
out.moutput_dup = make_same_wsvs(workspace, *ws, moutput_dup);
out.main_agenda = main_agenda;
out.mchecked = mchecked;

return out;
}
9 changes: 9 additions & 0 deletions src/agenda_class.h
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,9 @@ class Agenda final {
[[nodiscard]] Workspace& workspace() {return *ws;}
[[nodiscard]] const Workspace& workspace() const {return *ws;}

//! Creates a deep copy of the agenda if necessary (i.e., different workspace)!
Agenda deepcopy_if(Workspace&) const;

private:
std::shared_ptr<Workspace> ws; /*!< The workspace upon which this Agenda lives. */
String mname; /*!< Agenda name. */
Expand Down Expand Up @@ -153,6 +156,9 @@ class MRecord {
[[nodiscard]] const TokVal& SetValue() const { return msetvalue; }
[[nodiscard]] const Agenda& Tasks() const { return mtasks; }

//! Creates a deep copy of the method if necessary (i.e., different workspace)!
MRecord deepcopy_if(Workspace&) const;

//! Indicates the origin of this method.
/*!
Returns true if this method originates from a controlfile and false
Expand Down Expand Up @@ -222,4 +228,7 @@ class MRecord {
/** An array of Agenda. */
using ArrayOfAgenda = Array<Agenda>;

//! Same as Agenda member method but for an entire array
ArrayOfAgenda deepcopy_if(Workspace& ws, const ArrayOfAgenda& agendas);

#endif
11 changes: 11 additions & 0 deletions src/make_tokval.cc
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,17 @@ concept ArtsTypeConstRef = ArtsType<T> and std::is_same_v<std::add_const_t<std::
template <typename T>
concept ArtsTypeBase = ArtsType<T> and std::is_same_v<std::remove_cvref_t<T>, T>;

template <ArtsTypeBase> struct WorkspaceGroupIndex { static constexpr Index value=-1; };
)--";

for (Index i = 0; i < global_data::wsv_groups.nelem(); i++)
file_h << "template <> struct WorkspaceGroupIndex<"
<< global_data::wsv_groups[i]
<< "> { static constexpr Index value=" << i << "; };\n";

file_h << R"--(
template <ArtsTypeBase T> inline constexpr Index WorkspaceGroupIndexValue = WorkspaceGroupIndex<T>::value;

class TokVal {
void * ptr{nullptr};
public:
Expand Down
11 changes: 11 additions & 0 deletions src/python_interface/py_workspace.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,17 @@ void py_workspace(py::module_& m,
}),
py::arg("verbosity") = 0,
py::arg("agenda_verbosity") = 0)
.def(py::init([](Workspace& w) {return new Workspace{w};}))
.def(
"__copy__",
[](Workspace& w) -> Workspace { return w; },
py::is_operator())
.def(
"__deepcopy__",
[](Workspace& w, py::dict&) {
return w.deepcopy();
},
py::is_operator())
.def("execute_controlfile",
[](Workspace& w, const std::filesystem::path& path) {
std::unique_ptr<Agenda> a{parse_agenda(w,
Expand Down
49 changes: 49 additions & 0 deletions src/workspace_ng.cc
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,9 @@ void Workspace::pop(Index i) {

void Workspace::swap(Workspace &other) noexcept {
ws.swap(other.ws);
wsv_data_ptr.swap(other.wsv_data_ptr);
WsvMap_ptr.swap(other.WsvMap_ptr);
std::swap(original_workspace, other.original_workspace);
}

bool Workspace::is_initialized(Index i) const {
Expand Down Expand Up @@ -131,3 +134,49 @@ Workspace::Workspace()
}
}
}

std::shared_ptr<Workspace> Workspace::deepcopy() {
std::shared_ptr<Workspace> out{new Workspace{}};
out->wsv_data_ptr = std::shared_ptr<Workspace::wsv_data_type>(
new Workspace::wsv_data_type{*wsv_data_ptr});
out->WsvMap_ptr = std::shared_ptr<Workspace::WsvMap_type>(
new Workspace::WsvMap_type{*WsvMap_ptr});
out->ws.resize(nelem());

for (Index i = 0; i < out->nelem(); i++) {
auto &wsv_data = out->wsv_data_ptr->operator[](i);

if (depth(i) > 0) {
// Set the WSV by copying the top value
out->ws[i].emplace(WorkspaceVariableStruct{
workspace_memory_handler.duplicate(
wsv_data_ptr->operator[](i).Group(), ws[i].top().wsv),
is_initialized(i)});

// Copy the agenda to the new workspace
if (wsv_data.Group() == WorkspaceGroupIndexValue<Agenda>) {
Agenda *ag = static_cast<Agenda *>(out->operator[](i).get());
*ag = ag->deepcopy_if(*out);
} else if (wsv_data.Group() == WorkspaceGroupIndexValue<ArrayOfAgenda>) {
for (auto &a :
*static_cast<ArrayOfAgenda *>(out->operator[](i).get())) {
a = a.deepcopy_if(*out);
}
}
}

// If we have any default agenda types, we must copy them to the new workspace as well
if (wsv_data.has_defaults()) {
if (wsv_data.Group() == WorkspaceGroupIndexValue<Agenda>) {
wsv_data.update_default_value(
Agenda(wsv_data.default_value()).deepcopy_if(*out));
}
if (wsv_data.Group() == WorkspaceGroupIndexValue<ArrayOfAgenda>) {
wsv_data.update_default_value(
deepcopy_if(*out, wsv_data.default_value()));
}
}
}

return out;
}
10 changes: 8 additions & 2 deletions src/workspace_ng.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,11 @@ class Workspace final : public std::enable_shared_from_this<Workspace> {
/** Workspace variable container. */
Array<WorkspaceVariable> ws;

std::shared_ptr<Array<WsvRecord>> wsv_data_ptr;
using wsv_data_type = Array<WsvRecord>;
std::shared_ptr<wsv_data_type> wsv_data_ptr;

std::shared_ptr<map<String, Index>> WsvMap_ptr;
using WsvMap_type = map<String, Index>;
std::shared_ptr<WsvMap_type> WsvMap_ptr;

Workspace* original_workspace;

Expand Down Expand Up @@ -198,7 +200,11 @@ class Workspace final : public std::enable_shared_from_this<Workspace> {
outstream << (*wsv_data_ptr)[i].Name() << "(" << i << ") ";
}

//! Get a shared pointer to the object
std::shared_ptr<Workspace> shared_ptr() {return shared_from_this();}

//! Gets a full copy that owns all the data (only gets the top of the stack)
std::shared_ptr<Workspace> deepcopy();
};

template <class T>
Expand Down
1 change: 1 addition & 0 deletions src/wsv_aux.h
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ class WsvRecord {
[[nodiscard]] std::shared_ptr<void> get_copy() const;

[[nodiscard]] const TokVal& default_value() const { return defval; }
void update_default_value(ArtsType auto&& v) {defval=std::forward<decltype(v)>(v);}

private:
String mname;
Expand Down