In [None]:
# Version range of syft this notebook is written against.
SYFT_VERSION = ">=0.8.2.b0,<0.9"
# Double-quoted pip requirement, e.g. '"syft>=0.8.2.b0,<0.9"', for the install line below.
package_string = '"syft' + SYFT_VERSION + '"'
# %pip install {package_string} -q

In [None]:
# third party
import numpy as np

# syft absolute
import syft as sy

# Ask syft to check the installed version against SYFT_VERSION (first cell);
# presumably it fails or warns on a mismatch — confirm against the syft docs.
sy.requires(SYFT_VERSION)

In [None]:
# Launch a dev-mode domain node named "test-domain-1".
# port="auto" presumably lets syft pick a free port, and reset=True presumably
# discards state from earlier runs — confirm in the orchestra docs.
node = sy.orchestra.launch(name="test-domain-1", port="auto", dev_mode=True, reset=True)

In [None]:
# Log in to the freshly launched node to get the client used by every later cell.
# NOTE(review): these look like syft's built-in dev credentials. That is fine for a
# throwaway local node, but they must never be used against a real deployment.
domain_client = node.login(email="info@openmined.org", password="changethis")

In [None]:
# stdlib
from typing import Any

In [None]:
class RepeatedCallPolicy(sy.CustomOutputPolicy):
    # Output policy that releases a whitelisted subset of an output dict a bounded
    # number of times. Once the budget is used up, apply_output returns None.

    # Class-level annotated defaults. Every instance overwrites all three in __init__.
    # Presumably syft's policy machinery reads these annotations — verify before
    # removing or renaming them.
    n_calls: int = 0
    downloadable_output_args: list[str] = []
    state: dict[Any, Any] = {}

    def __init__(self, n_calls=1, downloadable_output_args: list[str] = None):
        # n_calls: number of calls requested by the caller (see the note below on
        #   what is actually stored).
        # downloadable_output_args: keys of the output dict that apply_output copies
        #   into its result. None means "no keys".
        # None is mapped to a fresh list, so instances never share the mutable
        # class-level [] default.
        self.downloadable_output_args = (
            downloadable_output_args if downloadable_output_args is not None else []
        )
        # NOTE(review): this stores one MORE than requested, so apply_output succeeds
        # n_calls + 1 times. Presumably the extra call covers the result deposit
        # (accept_by_depositing_result). Confirm before "fixing" this off-by-one.
        self.n_calls = n_calls + 1
        # Per-instance counter of successful apply_output calls.
        self.state = {"counts": 0}

    def public_state(self):
        # Expose only how many calls have been consumed so far.
        return self.state["counts"]

    def apply_output(self, context, outputs):
        # Return {key: outputs[key]} for every whitelisted key and consume one call.
        # Return None once self.n_calls calls have been served.
        # `context` is not used by this implementation.

        # Unwrap objects that carry their payload in `syft_action_data`
        # (e.g. what sy.ActionObject.from_obj produces — assumption, verify).
        if hasattr(outputs, "syft_action_data"):
            outputs = outputs.syft_action_data
        output_dict = {}
        if self.state["counts"] < self.n_calls:
            # Raises KeyError if a whitelisted key is missing from outputs.
            for output_arg in self.downloadable_output_args:
                output_dict[output_arg] = outputs[output_arg]

            self.state["counts"] += 1
        else:
            # Budget exhausted: release nothing.
            return None

        return output_dict

In [None]:
# Standalone policy instance, used only for the local sanity checks in the next cells.
policy = RepeatedCallPolicy(n_calls=1, downloadable_output_args=["y"])

In [None]:
# Stored budget: __init__ adds 1, so this displays 2 for n_calls=1.
policy.n_calls

In [None]:
# Keys that apply_output will copy out of the outputs dict.
policy.downloadable_output_args

In [None]:
# init_kwargs is never assigned in RepeatedCallPolicy.__init__. Presumably
# sy.CustomOutputPolicy records the constructor arguments here — verify.
policy.init_kwargs

In [None]:
# Local smoke test of the policy, with no node involved.
print(policy.init_kwargs)
# Wrap a plain dict. apply_output unwraps it again via `syft_action_data`
# (assuming ActionObject exposes that attribute).
a_obj = sy.ActionObject.from_obj({"y": [1, 2, 3]})
# apply_output ignores `context`, so None is fine. This consumes one call:
# state["counts"] goes from 0 to 1.
x = policy.apply_output(None, a_obj)
x["y"]

In [None]:
# Unchanged by apply_output: only state["counts"] is incremented.
policy.n_calls

In [None]:
# Input data for the syft function below.
# Note: this rebinds `x`, discarding the policy output from the smoke test.
x = np.array([1, 2, 3])
x_pointer = sy.ActionObject.from_obj(x)
x_pointer

In [None]:
# Store the input object on the domain, presumably so the ExactMatch input policy
# below can resolve it server-side.
domain_client.api.services.action.set(x_pointer)

In [None]:
# Syft function under test.
# - Input is pinned to x_pointer via ExactMatch.
# - Output is released through RepeatedCallPolicy with n_calls=10 (stored as 11
#   by its __init__), exposing only the "y" key.
@sy.syft_function(
    input_policy=sy.ExactMatch(x=x_pointer),
    output_policy=RepeatedCallPolicy(n_calls=10, downloadable_output_args=["y"]),
)
def func(x):
    # Return a dict so the output policy can select the "y" key.
    return {"y": x + 1}

In [None]:
# Submit func to the domain for approval; the returned request is displayed.
request = domain_client.code.request_code_execution(func)
request

In [None]:
# List all user code on the domain; the submitted func is expected to appear here.
domain_client.code.get_all()

In [None]:
# Rebind `func` to the code object attached to the request, replacing the local
# decorated function (presumably the server-side copy — verify).
func = request.code

In [None]:
# Run the function directly on x_pointer to precompute the result that gets
# deposited next. Judging by the name, this bypasses the policies — confirm.
result = func.unsafe_function(x=x_pointer)
result

In [None]:
# Approve the request by depositing the precomputed result.
# NOTE(review): this may consume one call of the output policy, which would explain
# the +1 in RepeatedCallPolicy.__init__ — verify.
final_result = request.accept_by_depositing_result(result)
final_result

In [None]:
# Call the approved function through the client API, as a data scientist would.
res_ptr = domain_client.code.func(x=x_pointer)
res_ptr

In [None]:
# Fetch the released output data from the returned pointer.
res = res_ptr.get()
res

In [None]:
# func computed x + 1 for x = [1, 2, 3].
# NOTE(review): the elementwise == broadcasts, so e.g. a (k, 3) array of identical
# rows would also pass. Consider also checking the shape.
assert (res["y"] == np.array([2, 3, 4])).all()

In [None]:
# The output policy whitelists only "y", so it must be the only released key.
# Use a set literal: set("y") builds a set of *characters* and would silently
# break for any multi-character key.
assert set(res.keys()) == {"y"}

In [None]:
# Inspect the output policy attached to the first user-code entry
# (assumes func is the only/first submitted code on this fresh node).
domain_client.code.get_all()[0].output_policy

In [None]:
# List all policies stored on the domain.
domain_client.api.services.policy.get_all()

In [None]:
# Same call as the previous cell, kept in a variable.
# Note: despite the singular name, this holds whatever get_all() returns
# (a collection of policies).
output_policy = domain_client.api.services.policy.get_all()
output_policy

In [None]:
# Shut the node down only when its node_type is "python". Presumably that means it
# was launched in-process; other launch types are left for external cleanup — verify.
if node.node_type.value == "python":
    node.land()