In [None]:
SYFT_VERSION = ">=0.8.2.b0,<0.9"
# Wrap the requirement in double quotes so the `>=` / `<` specifiers are not
# treated as shell redirections when interpolated into the %pip line below.
package_string = '"syft' + SYFT_VERSION + '"'
# %pip install {package_string} -q

In [None]:
# third party
import numpy as np

# syft absolute
import syft as sy

# Check the installed syft against the pinned range from the first cell.
# NOTE(review): presumably raises/warns on a mismatch — confirm sy.requires semantics.
sy.requires(SYFT_VERSION)

In [None]:
# Launch a dev-mode test domain on an automatically chosen port.
# NOTE(review): reset=True presumably wipes state from previous runs so the
# notebook can be re-executed from scratch — confirm.
node = sy.orchestra.launch(
    name="test-domain-1",
    port="auto",
    dev_mode=True,
    reset=True,
)

In [None]:
# Admin client: used to upload data and to review/approve the code request.
# NOTE(review): these look like the default dev-node credentials; never reuse
# them against a non-test deployment.
domain_client = node.login(
    email="info@openmined.org",
    password="changethis",
)

In [None]:
# Create the data-scientist account that will submit code for approval.
domain_client.register(
    name="John Doe",
    email="newuser@openmined.org",
    password="pw",
    password_verify="pw",
)

In [None]:
# Log in as the newly registered user; this client submits the code request.
client_low_ds = node.login(
    email="newuser@openmined.org",
    password="pw",
)

In [None]:
# stdlib
from typing import Any

In [None]:
class RepeatedCallPolicy(sy.CustomOutputPolicy):
    """Output policy that releases selected outputs for a limited number of calls.

    While fewer than ``n_calls`` calls have been recorded, ``apply_to_output``
    returns a dict holding only the keys listed in ``downloadable_output_args``
    (and, by default, records the call). Once the budget is used up it returns
    ``None``.
    """

    n_calls: int = 0
    downloadable_output_args: list[str] = []
    state: dict[Any, Any] = {}

    def __init__(self, n_calls=1, downloadable_output_args: list[str] = None):
        # Per-instance containers, so nothing is shared via the class-level defaults.
        self.downloadable_output_args = (
            [] if downloadable_output_args is None else downloadable_output_args
        )
        self.n_calls = n_calls
        self.state = {"counts": 0}

    def public_state(self):
        # Only the number of recorded calls is exposed.
        return self.state["counts"]

    def update_policy(self, context, outputs):
        # Record one more consumed call.
        self.state["counts"] = self.state["counts"] + 1

    def apply_to_output(self, context, outputs, update_policy=True):
        # Unwrap objects that carry their payload in `syft_action_data`.
        if hasattr(outputs, "syft_action_data"):
            outputs = outputs.syft_action_data
        # Budget exhausted: release nothing and leave the counter untouched.
        if not self.state["counts"] < self.n_calls:
            return None
        released = {name: outputs[name] for name in self.downloadable_output_args}
        if update_policy:
            self.update_policy(context, outputs)
        return released

    def _is_valid(self, context):
        # Valid while there is still call budget left.
        return self.state["counts"] < self.n_calls

In [None]:
# Local instance used to exercise the policy client-side before submitting
# code: one permitted call, releasing only the "y" output.
policy = RepeatedCallPolicy(
    n_calls=1,
    downloadable_output_args=["y"],
)

In [None]:
# Should echo the n_calls=1 passed to the constructor.
policy.n_calls

In [None]:
# Should echo the ["y"] passed to the constructor.
policy.downloadable_output_args

In [None]:
# `init_kwargs` is not defined in RepeatedCallPolicy; it comes from the
# sy.CustomOutputPolicy base. NOTE(review): presumably the captured constructor kwargs.
policy.init_kwargs

In [None]:
print(policy.init_kwargs)
a_obj = sy.ActionObject.from_obj({"y": [1, 2, 3]})
# First call is within the budget (n_calls=1): the "y" entry is released and
# the call counter advances to 1. A distinct name is used so the later `x`
# (a numpy array) is not conflated with this dict.
released_output = policy.apply_to_output(None, a_obj)
assert released_output == {"y": [1, 2, 3]}
assert policy.public_state() == 1
# A second call exceeds the budget: nothing is released and the counter stays put.
assert policy.apply_to_output(None, a_obj) is None
assert policy.public_state() == 1
released_output["y"]

In [None]:
# n_calls is configuration; apply_to_output only mutates `state`, so this is unchanged.
policy.n_calls

In [None]:
# Private input for the submitted function, wrapped as an ActionObject.
x = np.array([1, 2, 3])
x_pointer = sy.ActionObject.from_obj(x)
x_pointer

In [None]:
# Upload the input to the domain as the admin. NOTE(review): the returned
# object presumably references the server-side copy — confirm.
x_pointer = x_pointer.send(domain_client)

In [None]:
# third party
from result import Err
from result import Ok

# syft absolute
from syft.client.api import AuthedServiceContext
from syft.client.api import NodeIdentity


class CustomExactMatch(sy.CustomInputPolicy):
    """Input policy that admits only the exact input objects it was built with.

    Constructed as e.g. ``CustomExactMatch(x=x_pointer)``. The ``__init__``
    below is intentionally empty. NOTE(review): presumably the
    ``sy.CustomInputPolicy`` base captures the constructor kwargs and exposes
    them as ``self.inputs``. Judging by how ``allowed_ids_only`` and
    ``_is_valid`` read it, that is a mapping from node identity to
    ``{argument name: approved uid}``. Confirm against the syft policy base class.

    The methods use ``AuthedServiceContext``, ``NodeIdentity``, ``Ok`` and
    ``Err`` as module globals (only ``NodeType``/``UID``/``TwinMode`` are
    imported locally), so those names must be resolvable wherever the methods run.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        pass

    def filter_kwargs(self, kwargs, context, code_item_id):
        """Check ``kwargs`` against the approved inputs and load their values.

        Args:
            kwargs: user-supplied call arguments.
            context: service context; its ``node`` must be a domain node.
            code_item_id: forwarded to ``retrieve_from_db`` (which does not use it).

        Returns:
            The ``Ok({name: value})`` / ``Err`` result of ``retrieve_from_db``,
            or ``Err(str(exc))`` if any step raises.
        """
        # Exceptions from the helpers (wrong node type, uid mismatch) are turned
        # into an Err, so callers always receive a Result.
        try:
            allowed_inputs = self.allowed_ids_only(
                allowed_inputs=self.inputs, kwargs=kwargs, context=context
            )
            results = self.retrieve_from_db(
                code_item_id=code_item_id,
                allowed_inputs=allowed_inputs,
                context=context,
            )
        except Exception as e:
            return Err(str(e))
        return results

    def retrieve_from_db(self, code_item_id, allowed_inputs, context):
        """Fetch every allowed input from the node's action service.

        Args:
            code_item_id: not used by this implementation.
            allowed_inputs: ``{argument name: id}`` as returned by
                ``allowed_ids_only``.
            context: service context whose ``node`` must be a domain node.

        Returns:
            ``Ok({argument name: stored value})``, or the first failed lookup
            re-wrapped as ``Err``.

        Raises:
            Exception: if ``context.node`` is not a domain node.
        """
        # syft absolute
        from syft import NodeType
        from syft.service.action.action_object import TwinMode

        action_service = context.node.get_service("actionservice")
        code_inputs = {}

        # When retrieving the inputs from the database, we use the node's verify
        # key as the credentials. Approving the code allows the private data to
        # be used only by this specific code; it does not modify the permissions
        # of the private data itself.

        root_context = AuthedServiceContext(
            node=context.node, credentials=context.node.verify_key
        )
        if context.node.node_type == NodeType.DOMAIN:
            for var_name, arg_id in allowed_inputs.items():
                # NOTE(review): TwinMode.NONE / has_permission=True appear to fetch
                # the stored object as-is, skipping a further permission check —
                # inferred from argument names; confirm against ActionService._get.
                kwarg_value = action_service._get(
                    context=root_context,
                    uid=arg_id,
                    twin_mode=TwinMode.NONE,
                    has_permission=True,
                )
                # Stop at the first failed lookup.
                if kwarg_value.is_err():
                    return Err(kwarg_value.err())
                code_inputs[var_name] = kwarg_value.ok()
        else:
            raise Exception(
                f"Invalid Node Type for Code Submission:{context.node.node_type}"
            )
        return Ok(code_inputs)

    def allowed_ids_only(
        self,
        allowed_inputs,
        kwargs,
        context,
    ):
        """Keep only the kwargs approved for this node, enforcing exact uids.

        Args:
            allowed_inputs: mapping ``NodeIdentity -> {argument name: uid}``.
            kwargs: user-supplied arguments. Each value is expected to be a
                ``UID`` or an object with an ``id`` attribute; anything else
                yields ``None`` as its uid and fails the comparison.
            context: service context whose ``node`` must be a domain node.

        Returns:
            ``{argument name: value}`` for every approved name present in
            ``kwargs``. Supplied names that are not approved are dropped.

        Raises:
            Exception: if the node is not a domain node, or if a supplied
                value's uid differs from the approved uid.
        """
        # syft absolute
        from syft import NodeType
        from syft import UID

        if context.node.node_type == NodeType.DOMAIN:
            node_identity = NodeIdentity(
                node_name=context.node.name,
                node_id=context.node.id,
                verify_key=context.node.signing_key.verify_key,
            )
            # Narrow to the inputs approved for this node (none if it has no entry).
            allowed_inputs = allowed_inputs.get(node_identity, {})
        else:
            raise Exception(
                f"Invalid Node Type for Code Submission:{context.node.node_type}"
            )
        filtered_kwargs = {}
        for key in allowed_inputs.keys():
            if key in kwargs:
                value = kwargs[key]
                uid = value
                # Accept either a bare UID or an object carrying one as `.id`.
                if not isinstance(uid, UID):
                    uid = getattr(value, "id", None)

                if uid != allowed_inputs[key]:
                    raise Exception(
                        f"Input with uid: {uid} for `{key}` not in allowed inputs: {allowed_inputs}"
                    )
                filtered_kwargs[key] = value
        return filtered_kwargs

    def _is_valid(
        self,
        context,
        usr_input_kwargs,
        code_item_id,
    ):
        """Decide whether a call with ``usr_input_kwargs`` may run.

        Returns:
            ``Ok(True)`` when filtering succeeds and every approved argument
            name (across all entries of ``self.inputs``) is supplied and
            approved. Otherwise the ``Err`` from ``filter_kwargs``, or an ``Err``
            naming the missing or unapproved argument(s).
        """
        filtered_input_kwargs = self.filter_kwargs(
            kwargs=usr_input_kwargs,
            context=context,
            code_item_id=code_item_id,
        )

        if filtered_input_kwargs.is_err():
            return filtered_input_kwargs

        filtered_input_kwargs = filtered_input_kwargs.ok()

        # Every argument name approved for any node must be supplied by the caller.
        expected_input_kwargs = set()
        for _inp_kwargs in self.inputs.values():
            for k in _inp_kwargs.keys():
                if k not in usr_input_kwargs:
                    return Err(f"Function missing required keyword argument: '{k}'")
            expected_input_kwargs.update(_inp_kwargs.keys())

        # Expected names that did not survive filtering were not approved for this node.
        permitted_input_kwargs = list(filtered_input_kwargs.keys())
        not_approved_kwargs = set(expected_input_kwargs) - set(permitted_input_kwargs)
        if len(not_approved_kwargs) > 0:
            return Err(
                f"Input arguments: {not_approved_kwargs} to the function are not approved yet."
            )
        return Ok(True)


def allowed_ids_only(
    self,
    allowed_inputs,
    kwargs,
    context,
):
    """Module-level alias of ``CustomExactMatch.allowed_ids_only``.

    This was a line-for-line copy of the method. It now delegates, so the two
    implementations cannot drift apart. ``self`` is kept for signature
    compatibility and forwarded unchanged; the method body never reads it.

    Returns:
        ``{argument name: value}`` for the approved names present in ``kwargs``.

    Raises:
        Exception: if the node is not a domain node or a supplied uid differs
            from the approved one.

    NOTE(review): nothing visible in this notebook calls this function;
    consider removing it.
    """
    return CustomExactMatch.allowed_ids_only(self, allowed_inputs, kwargs, context)

In [None]:
@sy.syft_function(
    input_policy=CustomExactMatch(x=x_pointer),
    output_policy=RepeatedCallPolicy(n_calls=10, downloadable_output_args=["y"]),
)
def func(x):
    # Only the "y" entry is downloadable under the output policy above.
    incremented = x + 1
    return {"y": incremented}

In [None]:
# Submit `func` for admin review as the data scientist.
request = client_low_ds.code.request_code_execution(func)
request

In [None]:
# Keep the id so the admin side can locate the same request.
request_id = request.id

In [None]:
# The data scientist's view of their submitted code.
client_low_ds.code.get_all()

In [None]:
# Find the submitted request from the admin side. Without the `else` clause, a
# miss would silently leave `request` bound to the last iterated request (or
# to the data scientist's request from the earlier cell).
for request in domain_client.requests:
    if request.id == request_id:
        break
else:
    raise ValueError(f"Request {request_id} not found on the domain")

In [None]:
# Admin-side view of the submitted code. NOTE: this rebinds `func`, shadowing
# the decorated function defined earlier.
func = request.code

In [None]:
# Custom policies need to be approved before they can be viewed and used,
# so both policy attributes are expected to be None at this point.
assert func.input_policy is None
assert func.output_policy is None

In [None]:
# Try to run the code before approval.
# NOTE(review): nothing is asserted here; the result presumably is an error —
# consider asserting on it so the notebook checks the behaviour it demonstrates.
result = func.run(x=x_pointer)
result

In [None]:
# Admin approves the request, which should make the custom policies available.
request.approve()

In [None]:
# After approval both custom policies should be populated.
assert func.input_policy is not None
assert func.output_policy is not None

In [None]:
# Execute the approved code as the data scientist with the approved input.
res_ptr = client_low_ds.code.func(x=x_pointer)
res_ptr

In [None]:
# Download the result; only outputs allowed by the output policy are present.
res = res_ptr.get()
res

In [None]:
# `func` adds one element-wise. assert_array_equal also checks shape and
# reports the mismatching elements on failure, unlike `(a == b).all()`.
np.testing.assert_array_equal(res["y"], np.array([2, 3, 4]))

In [None]:
# Only the "y" output is downloadable. Note that `set("y")` would be a set of
# characters; it only happened to equal {"y"} because the key is one letter.
assert set(res.keys()) == {"y"}

In [None]:
# Locate the approved code on the admin side. The `else` clause makes a miss
# fail loudly instead of inspecting an unrelated code object (or hitting a
# NameError when no code exists at all).
for code in domain_client.code.get_all():
    if code.service_func_name == "func":
        break
else:
    raise ValueError("No user code named 'func' found on the domain")
print(code.output_policy.state)
# The single execution above should have been recorded exactly once.
assert code.output_policy.state == {"counts": 1}

In [None]:
# Shut down the node only for the "python" node type.
# NOTE(review): other node types are presumably managed externally — confirm.
if node.node_type.value == "python":
    node.land()