In [None]:
# Import syft under the alias used throughout the notebook and call its
# `requires` helper with ">=0.8-beta" (presumably a version guard — confirm).
import syft as sy
sy.requires(">=0.8-beta")

In [None]:
# Launch a node named "test-domain-1" through syft's orchestra helper, using a
# single process and reset=True (presumably discards earlier state — confirm).
node = sy.orchestra.launch(name="test-domain-1", processes=1, reset=True)

In [None]:
# Log in to the launched node; the returned client is used by every later cell.
# NOTE(review): these look like default development credentials — confirm they
# are only ever valid on a throwaway local node.
domain_client = node.login(email="info@openmined.org", password="changethis")

In [None]:
from syft.core.node.new.new_policy import CustomOutputPolicy
@sy.serializable()
class RepeatedCallPolicy(CustomOutputPolicy):
    """Output policy that releases selected outputs a limited number of times.

    Each successful ``apply_output`` call increments ``state["counts"]``; once
    the counter reaches the stored limit, ``apply_output`` returns ``None``.
    """

    __canonical_name__ = "RepeatedCallPolicy"
    # NOTE(review): the import lives inside the class body, presumably so the
    # class source stays self-contained when syft handles it — confirm before
    # moving it to module level.
    from typing import List, Dict, Any

    # Stored call limit (see the note in __init__ about the +1).
    n_calls: int
    # Keys of the result dict that may be returned to the caller.
    downloadable_output_args: List[str]
    # Class-level default only; every instance gets its own dict in __init__.
    state: Dict[Any, Any] = {}

    # NOTE(review): presumably the attributes syft keeps when handling this
    # object; `state` is not listed — confirm that is intended.
    __attr_allowlist__ = [
        "n_calls",
        "downloadable_output_args",
    ]

    def __init__(self, n_calls=1, downloadable_output_args=None):
        """Create the policy.

        Args:
            n_calls: Requested number of calls; the stored limit is
                ``n_calls + 1``.
            downloadable_output_args: Names of result keys that may be
                released. ``None`` is treated as an empty list.
        """
        super().__init__(n_calls=n_calls, downloadable_output_args=downloadable_output_args)
        # NOTE(review): the limit is one higher than requested, so n_calls=1
        # permits two releases — confirm this extra call is intentional.
        self.n_calls = n_calls + 1
        # Compare the argument itself against None; the previous
        # `if not None` tested a constant and never fell back to [].
        self.downloadable_output_args = (
            downloadable_output_args if downloadable_output_args is not None else []
        )
        self.state = {"counts": 0}

    def public_state(self):
        """Return how many times outputs have been released so far."""
        return self.state["counts"]

    def apply_output(self, context, results):
        """Filter ``results`` down to the downloadable keys, if calls remain.

        Args:
            context: Unused by this policy.
            results: Object whose ``syft_action_data`` is indexable by the
                names in ``downloadable_output_args``.

        Returns:
            A dict holding only the downloadable keys, or ``None`` once the
            call limit has been reached.

        Raises:
            KeyError: If a downloadable key is missing from the result data
                (assuming the data is a dict — confirm). The counter is not
                incremented in that case.
        """
        # Read the payload first, exactly as before, so a bad `results`
        # object fails the same way regardless of the remaining budget.
        results_dict = results.syft_action_data
        if self.state["counts"] >= self.n_calls:
            return None

        output_dict = {
            output_arg: results_dict[output_arg]
            for output_arg in self.downloadable_output_args
        }
        self.state["counts"] += 1
        return output_dict

In [None]:
# Exercise the policy locally: wrap a plain dict in an ActionObject and pass it
# through apply_output. The policy never reads `context`, so None is passed.
policy = RepeatedCallPolicy(n_calls=1, downloadable_output_args=['y'])
a_obj = sy.ActionObject.from_obj({'y': [1,2,3]})
policy.apply_output(None, a_obj)

In [None]:
# Build a small NumPy array, wrap it in an ActionObject and save it through the
# client's action service; x_pointer is the input used by the cells below.
import numpy as np
x = np.array([1,2,3])
x_pointer = sy.ActionObject.from_obj(x)
domain_client.api.services.action.save(x_pointer)

In [None]:
# Instantiate the policy on its own and display it.
obj = RepeatedCallPolicy(n_calls=1, downloadable_output_args=['y'])
obj

In [None]:
# Display the policy instance again.
obj

In [None]:
# Inspect init_kwargs on the instance. The class in this notebook never sets
# it, so it is presumably populated by the CustomOutputPolicy base — confirm.
obj.init_kwargs

In [None]:
from syft.core.node.new.new_policy import ExactMatch
@sy.syft_function(
    input_policy=ExactMatch(x=x_pointer),
    output_policy=RepeatedCallPolicy(n_calls=1, downloadable_output_args=["y"]),
)
def custom_func(x):
    """Return ``x + 1`` wrapped in a dict under the key "y"."""
    incremented = x + 1
    return {"y": incremented}

In [None]:
# Submit custom_func to the client's code service as a code-execution request.
domain_client.api.services.code.request_code_execution(custom_func)

In [None]:
# Take the most recent notification, follow its link to get the request, then
# pick the request's last change and display it.
request = domain_client.notifications[-1].link
# NOTE(review): `func` is also bound to the change object here and is only
# rebound to the change's link in a later cell — the extra alias looks
# unintentional; confirm and consider dropping it.
change = func = request.changes[-1]
change

In [None]:
# Approve the pending request.
request.approve()

In [None]:
# Follow the link of the request's last change; the resulting object is the
# one whose unsafe_function is called in the next cell.
func = request.changes[-1].link
func

In [None]:
# Call unsafe_function directly on the linked object with x_pointer and
# display what it returns.
result = func.unsafe_function(x=x_pointer)
result

In [None]:
# Display the result again.
result

In [None]:
# Hand the computed result to the request via accept_by_depositing_result and
# display the response.
final_result = request.accept_by_depositing_result(result) 
final_result

In [None]:
# Call custom_func through the client's code service with the same pointer
# and display the returned object.
res = domain_client.api.services.code.custom_func(x=x_pointer)
res

In [None]:
# custom_func returns x + 1 under "y"; with x = [1, 2, 3] that is [2, 3, 4].
assert (res["y"] == np.array([2, 3, 4])).all()

In [None]:
# Display the returned object again.
res

In [None]:
# Rebind res to a plain dict built from the returned object's syft_action_data.
res = dict(res.syft_action_data)

In [None]:
# The released output must contain exactly one key, "y". A set literal is used
# because list("y") splits a string into characters and would silently break
# for any key name longer than one character.
assert set(res.keys()) == {"y"}

In [None]:
# Call get_all() on the client's code service and display the result.
domain_client.api.services.code.get_all()

In [None]:
# Call get_all() on the client's policy service and display the result.
domain_client.api.services.policy.get_all()

In [None]:
# Fetch the policy listing again and keep it in output_policy (not used by
# any later cell visible here).
output_policy = domain_client.api.services.policy.get_all()

In [None]:
# Call orchestra.land for "test-domain-1", the node launched at the top of the
# notebook (presumably shuts it down — confirm).
sy.orchestra.land("test-domain-1")