In [None]:
# Pin the syft client to the 0.8.x line (0.8.1b0 or newer, below 0.9).
SYFT_VERSION = ">=0.8.1b0,<0.9"
# Wrap the requirement in double quotes, presumably so the `<`/`>` in the
# specifier are not treated as shell redirection when the magic expands it.
package_string = f'"syft{SYFT_VERSION}"'
# IPython magic: install into the running kernel's environment; -f (pip's
# --find-links) adds the given wheel page as a source, -q keeps pip quiet.
%pip install {package_string} -f https://whls.blob.core.windows.net/unstable/index.html -q

In [None]:
import syft as sy
# Check the installed syft against SYFT_VERSION (presumably warns/raises on a
# mismatch — confirm against the syft docs).
sy.requires(SYFT_VERSION)

In [None]:
# Start a local dev domain node on port 8080. reset=True presumably wipes state
# from earlier runs; the relaunch further down omits it so state survives.
node = sy.orchestra.launch(name="test-domain-1", port=8080, dev_mode=True, reset=True)

In [None]:
# Log in to the node; these look like default dev credentials — local testing only.
domain_client = node.login(email="info@openmined.org", password="changethis")
domain_client

In [None]:
from typing import List, Dict, Any, Optional

In [None]:
class RepeatedCallPolicy(sy.CustomOutputPolicy):
    """Output policy that releases selected outputs a limited number of times.

    Each successful ``apply_output`` call increments ``state["counts"]``. Once
    ``n_calls`` releases have been made, further calls return ``None``. Only
    the keys listed in ``downloadable_output_args`` are copied from the
    function's outputs into the released result.
    """

    # Class-level annotated field declarations; __init__ assigns the
    # per-instance values.
    n_calls: int = 0
    downloadable_output_args: List[str] = []
    state: Dict[Any, Any] = {}

    def __init__(self, n_calls=1, downloadable_output_args: List[str] = None):
        """Configure the policy.

        Args:
            n_calls: Exact number of times outputs may be released.
            downloadable_output_args: Output keys the user may download;
                ``None`` means no keys are released.
        """
        self.downloadable_output_args = (
            downloadable_output_args if downloadable_output_args is not None else []
        )
        # Store the budget as given: apply_output releases while
        # counts < n_calls, so n_calls releases in total. Adding 1 here would
        # allow one extra release.
        self.n_calls = n_calls
        self.state = {"counts": 0}

    def public_state(self):
        """Return how many times outputs have been released so far."""
        return self.state["counts"]

    def apply_output(self, context, outputs):
        """Release the whitelisted outputs if the call budget is not used up.

        Args:
            context: Execution context from the caller; not used by this policy.
            outputs: Indexable result of the user function, looked up by the
                names in ``downloadable_output_args``.

        Returns:
            A dict mapping each whitelisted name to its output, or ``None``
            once ``n_calls`` releases have already been made.
        """
        if self.state["counts"] >= self.n_calls:
            return None

        # Build the result before counting, so a missing key does not use up
        # a call.
        output_dict = {
            output_arg: outputs[output_arg]
            for output_arg in self.downloadable_output_args
        }
        self.state["counts"] += 1
        return output_dict

In [None]:
# Exercise the policy locally, outside any node: one release of output "y".
policy = RepeatedCallPolicy(n_calls=1, downloadable_output_args=['y'])

In [None]:
# Call budget as stored by __init__.
policy.n_calls

In [None]:
policy.downloadable_output_args

In [None]:
# init_kwargs is not defined on RepeatedCallPolicy itself; presumably the
# sy.CustomOutputPolicy base records the constructor arguments — confirm.
policy.init_kwargs

In [None]:
print(policy.init_kwargs)
# Wrap a plain dict in an ActionObject (presumably mirroring how outputs are
# passed on a node) and apply the policy by hand. apply_output never reads
# `context`, so None is fine; this first call is within budget, so x should
# hold the "y" entry.
a_obj = sy.ActionObject.from_obj({'y': [1,2,3]})
x = policy.apply_output(None, a_obj)
x

In [None]:
# apply_output only advances state["counts"]; n_calls itself is unchanged.
policy.n_calls

In [None]:
import numpy as np
# Input data for the syft function defined below (note: rebinds `x`).
x = np.array([1,2,3])
x_pointer = sy.ActionObject.from_obj(x)
# Presumably stores the object on the domain so code can refer to it via x_pointer.
domain_client.api.services.action.save(x_pointer)

In [None]:
@sy.syft_function(
    input_policy=sy.ExactMatch(x=x_pointer),
    output_policy=RepeatedCallPolicy(n_calls=10, downloadable_output_args=['y']),
)
def func(x):
    """Return the input shifted up by one, published under the key "y"."""
    shifted = x + 1
    return {"y": shifted}

In [None]:
@sy.syft_function(
    input_policy=sy.ExactMatch(x=x_pointer),
    output_policy=sy.SingleExecutionExactOutput(),
)
def train_mlp(x):
    """Placeholder that hands the matched input straight back.

    NOTE(review): not submitted or referenced elsewhere in this notebook.
    """
    passthrough = x
    return passthrough

In [None]:
# Submit `func` (with its input/output policies) as a code-execution request.
domain_client.code.request_code_execution(func)

In [None]:
messages = domain_client.notifications.get_all_unread()
messages

In [None]:
domain_client.requests

In [None]:
# The node was launched with reset=True, so the request just submitted is
# presumably the only one.
request = domain_client.requests[0]
request

In [None]:
request.changes

In [None]:
# The first change links to the submitted user code. Note this rebinds `func`
# from the local syft function to that linked object.
func = request.changes[0].link
func

In [None]:
# Run the submitted code directly on the input; judging by the name this skips
# policy enforcement — confirm before using outside a test.
result = func.unsafe_function(x=x_pointer)
result

In [None]:
# Approve the request by depositing the result computed above.
final_result = request.accept_by_depositing_result(result)
final_result

In [None]:
request.changes

In [None]:
assert request.changes[0].approved

In [None]:
# Call the approved code through the client API, as the data scientist would.
res = domain_client.code.func(x=x_pointer)
res

In [None]:
# Fetch the released value; the checks below index it as a dict keyed by "y".
res = res.get()
res

In [None]:
# func returns x + 1 for x = [1, 2, 3]. np.array_equal also requires equal
# shapes, so a broadcastable but wrong-shaped result cannot pass by accident.
assert np.array_equal(res["y"], np.array([2, 3, 4]))

In [None]:
# The policy must release exactly the whitelisted key "y" and nothing else.
assert set(res.keys()) == {"y"}

In [None]:
# Output policy attached to the first stored user code.
output_policy = domain_client.code.get_all()[0].output_policy
output_policy

In [None]:
domain_client.api.services.policy.get_all()

In [None]:
# NOTE(review): same call as the previous cell; this rebinds output_policy
# (set two cells above) to the get_all() result.
output_policy = domain_client.api.services.policy.get_all()
output_policy

In [None]:
# Shut the node down; the next section relaunches it without reset=True.
node.land()

#### Verify that the approved code's custom policy is correctly reloaded after the node restarts

In [None]:
from syft.node.node import CODE_RELOADER
from syft.serde.recursive import TYPE_BANK
# Empty the cached code reloaders, so the check after the relaunch below shows
# that they were repopulated rather than left over.
CODE_RELOADER.clear()
assert not len(CODE_RELOADER)

In [None]:
# Relaunch the same node without reset=True, so previously stored state
# (including the approved code) is presumably loaded back.
node = sy.orchestra.launch(name="test-domain-1", port=8080, dev_mode=True)

In [None]:
client = node.login(email="info@openmined.org", password="changethis")

In [None]:
# The reloader cache emptied before the relaunch should be populated again.
assert bool(CODE_RELOADER)

In [None]:
# After the restart, load the stored code's policy and check that its class is
# known to the serde TYPE_BANK under the "syft.user." prefix (presumably the
# bank is keyed by type path — confirm).
output_policy = client.code.get_all()[0].output_policy
policy_type_path = f"syft.user.{output_policy.__class__.__name__}"
assert policy_type_path in TYPE_BANK

In [None]:
client.notifications

In [None]:
# Notifications should survive the restart (presumably those created before it).
assert client.notifications

In [None]:
client.api.services.request

In [None]:
node.land()