Skip to content

Commit

Permalink
Merge pull request #520 from riclarsson/copy-workspace
Browse files Browse the repository at this point in the history
Beginning to deepcopy workspace
  • Loading branch information
riclarsson committed Sep 13, 2022
2 parents 86ed4df + 868eed6 commit 7c731ad
Show file tree
Hide file tree
Showing 9 changed files with 227 additions and 2 deletions.
6 changes: 6 additions & 0 deletions python/pyarts/workspace/workspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,3 +373,9 @@ def __delattr__(self, attr):
raise AttributeError("You cannot delete __class__")

getattr(self, attr).delete_level()

def __copy__(self):
return Workspace(super().__copy__())

def __deepcopy__(self, *args):
return Workspace(super().__deepcopy__(*args))
52 changes: 52 additions & 0 deletions python/test/workspace/test_workspace.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
"""
Test handling of workspace of the Python interface.
"""
import copy
import pytest
import pyarts
from pyarts.workspace import Workspace, arts_agenda
from pyarts.arts import Index

class TestWorkspace:
def setup_method(self):
pass

def test_copy(self):
ws = Workspace()
ws.aaavar = Index(5)

ws2 = copy.copy(ws)
ws2.aaavar = 6

assert ws2.aaavar.value == ws.aaavar.value

def test_deepcopy(self):
ws = Workspace()
ws.aaavar = Index(5)

ws2 = copy.deepcopy(ws)
ws2.aaavar = 6

assert ws2.aaavar.value != ws.aaavar.value

def test_copy_agenda(self):
ws = Workspace()
ws.aaavar = Index(5)

@arts_agenda(ws=ws, set_agenda=True)
def test_agenda(ws):
ws.Print(ws.aaavar, 0)
ws.AgendaExecute(ws.test_agenda)

ws2 = copy.deepcopy(ws)
ws2.aaavar = 4
ws2.AgendaExecute(ws2.test_agenda)


if __name__ == "__main__":
ta = TestWorkspace()
ta.setup_method()
ta.test_copy()
ta.test_deepcopy()
ta.test_copy_agenda()

80 changes: 80 additions & 0 deletions src/agenda_class.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
#include "global_data.h"
#include "messages.h"
#include "methods.h"
#include "tokval.h"
#include "workspace_ng.h"

MRecord::MRecord() : moutput(), minput(), msetvalue(), mtasks() {}
Expand Down Expand Up @@ -911,3 +912,82 @@ ostream& operator<<(ostream& os, const MRecord& a) {
a.print(os, "");
return os;
}

ArrayOfAgenda deepcopy_if(Workspace& ws, const ArrayOfAgenda& agendas) {
ArrayOfAgenda out;
for (auto& ag : agendas) out.push_back(ag.deepcopy_if(ws));
return out;
}

// Method to share indices of variables from one workspace to another
ArrayOfIndex make_same_wsvs(Workspace& ws_out,
const Workspace& ws_in,
const ArrayOfIndex& vars) {
ArrayOfIndex out;

out.reserve(vars.size());
for (auto& v : vars) {

// Set the value position
if (v < ws_out.nelem()) {
if (ws_in.wsv_data_ptr->at(v).Name() ==
ws_out.wsv_data_ptr->at(v).Name() and
ws_in.wsv_data_ptr->at(v).Group() ==
ws_out.wsv_data_ptr->at(v).Group()) {
out.push_back(v); // This is the only tested path!
} else {
out.push_back(ws_out.add_wsv(ws_in.wsv_data_ptr->at(v)));
}
} else {
out.push_back(ws_out.add_wsv(ws_in.wsv_data_ptr->at(v)));
}

// Update if the wsv holds an agenda default value
if (auto& wsv = ws_out.wsv_data_ptr->at(out.back()); wsv.has_defaults()) {
auto& val = wsv.default_value();
if (wsv.Group() == WorkspaceGroupIndexValue<Agenda>)
wsv.update_default_value(Agenda(val).deepcopy_if(ws_out));
else if (wsv.Group() == WorkspaceGroupIndexValue<ArrayOfAgenda>)
wsv.update_default_value(deepcopy_if(ws_out, val));
}
}

return out;
}

MRecord MRecord::deepcopy_if(Workspace& workspace) const {
if (mtasks.has_same_origin(workspace)) return *this;

MRecord out(workspace);
out.mid = mid;
out.moutput = make_same_wsvs(workspace, mtasks.workspace(), moutput);
out.minput = make_same_wsvs(workspace, mtasks.workspace(), minput);

// Must update if the value is an agenda
if (msetvalue.holdsAgenda()) {
out.msetvalue = Agenda(msetvalue).deepcopy_if(workspace);
} else if (msetvalue.holdsArrayOfAgenda()) {
out.msetvalue = ::deepcopy_if(workspace, msetvalue);
} else {
out.msetvalue = msetvalue;
}

out.mtasks = mtasks.deepcopy_if(workspace);
out.minternal = minternal;

return out;
}

Agenda Agenda::deepcopy_if(Workspace& workspace) const {
if (has_same_origin(workspace)) return *this;

Agenda out(workspace);
out.mname = mname;
for (auto& method : mml) out.mml.push_back(method.deepcopy_if(workspace));
out.moutput_push = make_same_wsvs(workspace, *ws, moutput_push);
out.moutput_dup = make_same_wsvs(workspace, *ws, moutput_dup);
out.main_agenda = main_agenda;
out.mchecked = mchecked;

return out;
}
9 changes: 9 additions & 0 deletions src/agenda_class.h
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,9 @@ class Agenda final {
[[nodiscard]] Workspace& workspace() {return *ws;}
[[nodiscard]] const Workspace& workspace() const {return *ws;}

//! Creates a deep copy of the agenda if necessary (i.e., different workspace)!
Agenda deepcopy_if(Workspace&) const;

private:
std::shared_ptr<Workspace> ws; /*!< The workspace upon which this Agenda lives. */
String mname; /*!< Agenda name. */
Expand Down Expand Up @@ -153,6 +156,9 @@ class MRecord {
[[nodiscard]] const TokVal& SetValue() const { return msetvalue; }
[[nodiscard]] const Agenda& Tasks() const { return mtasks; }

//! Creates a deep copy of the method if necessary (i.e., different workspace)!
MRecord deepcopy_if(Workspace&) const;

//! Indicates the origin of this method.
/*!
Returns true if this method originates from a controlfile and false
Expand Down Expand Up @@ -222,4 +228,7 @@ class MRecord {
/** An array of Agenda. */
using ArrayOfAgenda = Array<Agenda>;

//! Same as Agenda member method but for an entire array
ArrayOfAgenda deepcopy_if(Workspace& ws, const ArrayOfAgenda& agendas);

#endif
11 changes: 11 additions & 0 deletions src/make_tokval.cc
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,17 @@ concept ArtsTypeConstRef = ArtsType<T> and std::is_same_v<std::add_const_t<std::
template <typename T>
concept ArtsTypeBase = ArtsType<T> and std::is_same_v<std::remove_cvref_t<T>, T>;
template <ArtsTypeBase> struct WorkspaceGroupIndex { static constexpr Index value=-1; };
)--";

for (Index i = 0; i < global_data::wsv_groups.nelem(); i++)
file_h << "template <> struct WorkspaceGroupIndex<"
<< global_data::wsv_groups[i]
<< "> { static constexpr Index value=" << i << "; };\n";

file_h << R"--(
template <ArtsTypeBase T> inline constexpr Index WorkspaceGroupIndexValue = WorkspaceGroupIndex<T>::value;
class TokVal {
void * ptr{nullptr};
public:
Expand Down
11 changes: 11 additions & 0 deletions src/python_interface/py_workspace.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,17 @@ void py_workspace(py::module_& m,
}),
py::arg("verbosity") = 0,
py::arg("agenda_verbosity") = 0)
.def(py::init([](Workspace& w) {return new Workspace{w};}))
.def(
"__copy__",
[](Workspace& w) -> Workspace { return w; },
py::is_operator())
.def(
"__deepcopy__",
[](Workspace& w, py::dict&) {
return w.deepcopy();
},
py::is_operator())
.def("execute_controlfile",
[](Workspace& w, const std::filesystem::path& path) {
std::unique_ptr<Agenda> a{parse_agenda(w,
Expand Down
49 changes: 49 additions & 0 deletions src/workspace_ng.cc
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,9 @@ void Workspace::pop(Index i) {

void Workspace::swap(Workspace &other) noexcept {
ws.swap(other.ws);
wsv_data_ptr.swap(other.wsv_data_ptr);
WsvMap_ptr.swap(other.WsvMap_ptr);
std::swap(original_workspace, other.original_workspace);
}

bool Workspace::is_initialized(Index i) const {
Expand Down Expand Up @@ -131,3 +134,49 @@ Workspace::Workspace()
}
}
}

std::shared_ptr<Workspace> Workspace::deepcopy() {
std::shared_ptr<Workspace> out{new Workspace{}};
out->wsv_data_ptr = std::shared_ptr<Workspace::wsv_data_type>(
new Workspace::wsv_data_type{*wsv_data_ptr});
out->WsvMap_ptr = std::shared_ptr<Workspace::WsvMap_type>(
new Workspace::WsvMap_type{*WsvMap_ptr});
out->ws.resize(nelem());

for (Index i = 0; i < out->nelem(); i++) {
auto &wsv_data = out->wsv_data_ptr->operator[](i);

if (depth(i) > 0) {
// Set the WSV by copying the top value
out->ws[i].emplace(WorkspaceVariableStruct{
workspace_memory_handler.duplicate(
wsv_data_ptr->operator[](i).Group(), ws[i].top().wsv),
is_initialized(i)});

// Copy the agenda to the new workspace
if (wsv_data.Group() == WorkspaceGroupIndexValue<Agenda>) {
Agenda *ag = static_cast<Agenda *>(out->operator[](i).get());
*ag = ag->deepcopy_if(*out);
} else if (wsv_data.Group() == WorkspaceGroupIndexValue<ArrayOfAgenda>) {
for (auto &a :
*static_cast<ArrayOfAgenda *>(out->operator[](i).get())) {
a = a.deepcopy_if(*out);
}
}
}

// If we have any default agenda types, we must copy them to the new workspace as well
if (wsv_data.has_defaults()) {
if (wsv_data.Group() == WorkspaceGroupIndexValue<Agenda>) {
wsv_data.update_default_value(
Agenda(wsv_data.default_value()).deepcopy_if(*out));
}
if (wsv_data.Group() == WorkspaceGroupIndexValue<ArrayOfAgenda>) {
wsv_data.update_default_value(
deepcopy_if(*out, wsv_data.default_value()));
}
}
}

return out;
}
10 changes: 8 additions & 2 deletions src/workspace_ng.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,11 @@ class Workspace final : public std::enable_shared_from_this<Workspace> {
/** Workspace variable container. */
Array<WorkspaceVariable> ws;

std::shared_ptr<Array<WsvRecord>> wsv_data_ptr;
using wsv_data_type = Array<WsvRecord>;
std::shared_ptr<wsv_data_type> wsv_data_ptr;

std::shared_ptr<map<String, Index>> WsvMap_ptr;
using WsvMap_type = map<String, Index>;
std::shared_ptr<WsvMap_type> WsvMap_ptr;

Workspace* original_workspace;

Expand Down Expand Up @@ -198,7 +200,11 @@ class Workspace final : public std::enable_shared_from_this<Workspace> {
outstream << (*wsv_data_ptr)[i].Name() << "(" << i << ") ";
}

//! Get a shared pointer to the object
std::shared_ptr<Workspace> shared_ptr() {return shared_from_this();}

//! Gets a full copy that owns all the data (only gets the top of the stack)
std::shared_ptr<Workspace> deepcopy();
};

template <class T>
Expand Down
1 change: 1 addition & 0 deletions src/wsv_aux.h
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ class WsvRecord {
[[nodiscard]] std::shared_ptr<void> get_copy() const;

[[nodiscard]] const TokVal& default_value() const { return defval; }
void update_default_value(ArtsType auto&& v) {defval=std::forward<decltype(v)>(v);}

private:
String mname;
Expand Down

0 comments on commit 7c731ad

Please sign in to comment.