In [None]:
# -- browsing datasets
# -- getting a pointer
# -- mock vs private
# -- Pointer UIDs
# -- choosing an input policy
# -- choosing an output policy
# -- using the syft function decorator
# -- testing code locally
# -- submitting code for approval
# -- code is denied
# -- changing code and re-uploading a new version

In [None]:
# syft absolute
import syft as sy

In [None]:
# Launch a server named "test-datasite-1" on port 8081 through syft's orchestra
# helper. The returned handle is used below for the logins and the final land().
server = sy.orchestra.launch(name="test-datasite-1", port=8081)

In [None]:
# Open two sessions against the same server: one with the info@openmined.org
# account (used later to review and deny the request) and one with the
# scientist@test.com account (used to browse datasets and submit code).
admin_client = server.login(email="info@openmined.org", password="changethis")
user_client = server.login(email="scientist@test.com", password="123")

In [None]:
# Browse the datasets exposed through the scientist's client.
user_client.datasets

In [None]:
# TODO: clarify what the "getting a pointer" step should demonstrate here and what still needs to be added.

In [None]:
# Pick the first dataset from that listing.
user_client.datasets[0]

In [None]:
# Drill down to the first asset of the first dataset.
user_client.datasets[0].assets[0]

In [None]:
# Read the asset's `.mock` attribute through the scientist's client and
# display it.
mock_data = user_client.datasets[0].assets[0].mock
mock_data

In [None]:
# Read the same asset's `.data` attribute, still through the scientist's client.
# NOTE(review): presumably the scientist has no permission on the private data,
# so this is expected not to reveal the real values — confirm.
private_data = user_client.datasets[0].assets[0].data
private_data

# Displaying the private data as well, because the mock and private data are completely different

In [None]:
# Same `.data` access, this time through the admin client; this rebinds
# `private_data` from the previous cell.
private_data = admin_client.datasets[0].assets[0].data
private_data

# Standard and custom Input/Output Policies and syft function decorator

In [None]:
# Keep a handle to the first asset (via the scientist's client); it is bound as
# the `ages_data` input of the syft functions defined below.
asset = user_client.datasets[0].assets[0]

In [None]:
@sy.syft_function_single_use(ages_data=asset)
def how_are_people_dying_statistics(ages_data):
    """Summarise deaths by manner and the average age of death by gender.

    Args:
        ages_data: Table with "Gender", "Age of death" and "Manner of death"
            columns. It is used through ``groupby``, so presumably a pandas
            DataFrame — confirm.

    Returns:
        A tuple ``(manner_of_death_count, avg_age_death_gender)``: the number
        of rows per manner of death, sorted by descending "Count", and the
        mean "Age of death" per gender in an "Avg_Age_of_Death" column.
    """
    mean_age_by_gender = ages_data.groupby("Gender")["Age of death"].mean()
    avg_age_death_gender = mean_age_by_gender.reset_index(name="Avg_Age_of_Death")

    rows_per_manner = ages_data.groupby("Manner of death").size()
    manner_of_death_count = rows_per_manner.reset_index(name="Count").sort_values(
        by="Count", ascending=False
    )

    return (manner_of_death_count, avg_age_death_gender)

In [None]:
# stdlib
from typing import Any

# third party
from result import Err
from result import Ok

# syft absolute
from syft.client.api import AuthedServiceContext
from syft.client.api import ServerIdentity


class CustomExactMatch(sy.CustomInputPolicy):
    """Input policy that admits a keyword argument only when its id matches.

    The allowed ids are read from ``self.inputs``, which is not assigned in
    this class (presumably populated by the syft base class from the
    constructor kwargs — confirm).
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        # Nothing is initialised here; the arguments are accepted and ignored.
        pass

    def filter_kwargs(self, kwargs, context, code_item_id):
        """Check the caller's kwargs against the allowed ids and load them.

        Returns whatever ``retrieve_from_db`` returns, or ``Err(str(exc))``
        when either step raises.
        """
        try:
            permitted = self.allowed_ids_only(
                allowed_inputs=self.inputs, kwargs=kwargs, context=context
            )
            return self.retrieve_from_db(
                code_item_id=code_item_id,
                allowed_inputs=permitted,
                context=context,
            )
        except Exception as exc:
            return Err(str(exc))

    def retrieve_from_db(self, code_item_id, allowed_inputs, context):
        """Fetch the value behind every allowed input id.

        ``code_item_id`` is accepted but not used by this body.

        Returns:
            ``Ok`` wrapping a ``{variable name: value}`` dict, or ``Err`` with
            the error of the first lookup that fails.

        Raises:
            Exception: if the server is not of type ``ServerType.DATASITE``.
        """
        # syft absolute
        from syft import ServerType
        from syft.service.action.action_object import TwinMode

        action_service = context.server.get_service("actionservice")

        # When we are retrieving the code inputs from the database, we need to
        # use the server's verify key as the credentials. This is because when
        # we approve the code, we allow the private data to be used only for
        # this specific code, but we are not modifying the permissions of the
        # private data.
        root_context = AuthedServiceContext(
            server=context.server, credentials=context.server.verify_key
        )

        if context.server.server_type != ServerType.DATASITE:
            raise Exception(
                f"Invalid Server Type for Code Submission:{context.server.server_type}"
            )

        code_inputs = {}
        for var_name, arg_id in allowed_inputs.items():
            lookup = action_service._get(
                context=root_context,
                uid=arg_id,
                twin_mode=TwinMode.NONE,
                has_permission=True,
            )
            if lookup.is_err():
                return Err(lookup.err())
            code_inputs[var_name] = lookup.ok()
        return Ok(code_inputs)

    def allowed_ids_only(
        self,
        allowed_inputs,
        kwargs,
        context,
    ):
        """Keep only the kwargs whose id equals the id allowed for that name.

        ``allowed_inputs`` is first narrowed to the entry keyed by this
        server's identity (an empty dict when there is none). Allowed names
        that are absent from ``kwargs`` are skipped.

        Raises:
            Exception: if the server is not of type ``ServerType.DATASITE``,
                or if a supplied value's id differs from the allowed id.
        """
        # syft absolute
        from syft import ServerType
        from syft import UID

        if context.server.server_type != ServerType.DATASITE:
            raise Exception(
                f"Invalid Server Type for Code Submission:{context.server.server_type}"
            )

        this_server = ServerIdentity(
            server_name=context.server.name,
            server_id=context.server.id,
            verify_key=context.server.signing_key.verify_key,
        )
        allowed_inputs = allowed_inputs.get(this_server, {})

        filtered_kwargs = {}
        for key, expected_uid in allowed_inputs.items():
            if key not in kwargs:
                continue
            value = kwargs[key]
            # A value is either a UID itself or an object carrying one as `.id`.
            uid = value if isinstance(value, UID) else getattr(value, "id", None)
            if uid != expected_uid:
                raise Exception(
                    f"Input with uid: {uid} for `{key}` not in allowed inputs: {allowed_inputs}"
                )
            filtered_kwargs[key] = value
        return filtered_kwargs

    def _is_valid(
        self,
        context,
        usr_input_kwargs,
        code_item_id,
    ):
        """Validate the user's kwargs against the policy.

        Returns:
            ``Ok(True)`` when every expected argument was supplied and passed
            the filter; otherwise an ``Err`` (the filter's own error, a
            missing-argument message, or a not-approved message).
        """
        outcome = self.filter_kwargs(
            kwargs=usr_input_kwargs,
            context=context,
            code_item_id=code_item_id,
        )
        if outcome.is_err():
            return outcome
        approved = outcome.ok()

        # Every name registered for any server must be present in the call.
        expected = set()
        for per_server_inputs in self.inputs.values():
            for name in per_server_inputs:
                if name not in usr_input_kwargs:
                    return Err(f"Function missing required keyword argument: '{name}'")
            expected.update(per_server_inputs)

        not_approved_kwargs = set(expected) - set(approved)
        if not_approved_kwargs:
            return Err(
                f"Input arguments: {not_approved_kwargs} to the function are not approved yet."
            )
        return Ok(True)


def allowed_ids_only(
    self,
    allowed_inputs,
    kwargs,
    context,
):
    """Keep only the kwargs whose id equals the id allowed for that name.

    Module-level function with the same signature and logic as
    ``CustomExactMatch.allowed_ids_only``; ``self`` is accepted but unused.
    ``allowed_inputs`` is first narrowed to the entry keyed by this server's
    identity (an empty dict when there is none). Allowed names that are
    absent from ``kwargs`` are skipped.

    Raises:
        Exception: if the server is not of type ``ServerType.DATASITE``, or
            if a supplied value's id differs from the allowed id.
    """
    # syft absolute
    from syft import ServerType
    from syft import UID
    from syft.client.api import ServerIdentity

    if context.server.server_type != ServerType.DATASITE:
        raise Exception(
            f"Invalid Server Type for Code Submission:{context.server.server_type}"
        )

    this_server = ServerIdentity(
        server_name=context.server.name,
        server_id=context.server.id,
        verify_key=context.server.signing_key.verify_key,
    )
    allowed_inputs = allowed_inputs.get(this_server, {})

    filtered_kwargs = {}
    for key, expected_uid in allowed_inputs.items():
        if key not in kwargs:
            continue
        value = kwargs[key]
        # A value is either a UID itself or an object carrying one as `.id`.
        uid = value if isinstance(value, UID) else getattr(value, "id", None)
        if uid != expected_uid:
            raise Exception(
                f"Input with uid: {uid} for `{key}` not in allowed inputs: {allowed_inputs}"
            )
        filtered_kwargs[key] = value
    return filtered_kwargs

In [None]:
class RepeatedCallPolicy(sy.CustomOutputPolicy):
    """Output policy that releases selected outputs for a limited number of calls.

    The call counter lives in ``state["counts"]`` and is compared against
    ``n_calls``.
    """

    # Maximum number of times outputs are released.
    n_calls: int = 0
    # Keys of the output that may be handed back to the caller.
    downloadable_output_args: list[str] = []
    # Mutable policy state; holds the "counts" counter after __init__.
    state: dict[Any, Any] = {}

    def __init__(self, n_calls=1, downloadable_output_args: list[str] = None):
        """Store the call budget and the releasable keys; reset the counter."""
        if downloadable_output_args is None:
            downloadable_output_args = []
        self.downloadable_output_args = downloadable_output_args
        self.n_calls = n_calls
        self.state = {"counts": 0}

    def public_state(self):
        """Return the number of calls counted so far."""
        return self.state["counts"]

    def update_policy(self, context, outputs):
        """Count one more call; ``context`` and ``outputs`` are not used."""
        self.state["counts"] += 1

    def apply_to_output(self, context, outputs, update_policy=True):
        """Return the releasable subset of ``outputs``, or None once exhausted.

        If ``outputs`` has a ``syft_action_data`` attribute, that attribute is
        used in its place. While the counter is below ``n_calls``, a dict with
        one entry per name in ``downloadable_output_args`` is built by
        indexing ``outputs`` with that name, and the counter is advanced when
        ``update_policy`` is true.
        """
        if hasattr(outputs, "syft_action_data"):
            outputs = outputs.syft_action_data

        if self.state["counts"] >= self.n_calls:
            return None

        released = {name: outputs[name] for name in self.downloadable_output_args}
        if update_policy:
            self.update_policy(context, outputs)
        return released

    def _is_valid(self, context):
        """Tell whether the call budget still has room."""
        calls_so_far = self.state["counts"]
        return calls_so_far < self.n_calls

In [None]:
# NOTE(review): RepeatedCallPolicy.apply_to_output indexes the output with each
# name in downloadable_output_args ("y" here), while this function returns a
# tuple — confirm that the policy arguments match the intended output.
@sy.syft_function(
    input_policy=CustomExactMatch(ages_data=asset),
    output_policy=RepeatedCallPolicy(n_calls=10, downloadable_output_args=["y"]),
)
def how_are_people_dying_statistics_custom(ages_data):
    """Summarise deaths by manner and the average age of death by gender.

    Same computation as ``how_are_people_dying_statistics``, but registered
    with the custom input and output policies instead of the single-use
    decorator.

    Returns:
        A tuple ``(manner_of_death_count, avg_age_death_gender)``.
    """
    mean_age_by_gender = ages_data.groupby("Gender")["Age of death"].mean()
    avg_age_death_gender = mean_age_by_gender.reset_index(name="Avg_Age_of_Death")

    rows_per_manner = ages_data.groupby("Manner of death").size()
    manner_of_death_count = rows_per_manner.reset_index(name="Count").sort_values(
        by="Count", ascending=False
    )

    return (manner_of_death_count, avg_age_death_gender)

# Test on mock data

In [None]:
# Call the decorated single-use function directly with the asset, then resolve
# the returned object with .get().
# NOTE(review): per the "Test on mock data" heading this presumably runs
# locally against the mock data — confirm.
pointer = how_are_people_dying_statistics(ages_data=asset)
result = pointer.get()

In [None]:
# First element of the returned tuple: the count per manner of death.
result[0]

In [None]:
# Second element of the returned tuple: the average age of death per gender.
result[1]

# Submit code

In [None]:
# Create a new project
new_project = sy.Project(
    name="The project about death",
    description="Hi, I want to calculate some statistics on how folks are dying",
    members=[user_client],
)
new_project

In [None]:
# Attach a code request for the single-use function to the project, passing
# the scientist's client. This rebinds `result` from the mock test above.
result = new_project.create_code_request(how_are_people_dying_statistics, user_client)

In [None]:
# Show what create_code_request returned.
result

In [None]:
# Submit the project with send() and keep whatever it returns as `project`.
project = new_project.send()
project

In [None]:
# Fetch the project back by name through the scientist's client and make sure
# something truthy came back.
project = user_client.get_project(name="The project about death")
assert project
# The stricter checks below are disabled. RequestStatus is not imported in the
# cells shown here, so the last one needs that import before being re-enabled.
# assert len(project.events) == 1
# assert isinstance(project.events[0], sy.service.project.project.ProjectRequest)
# assert project.events[0].request.status == RequestStatus.PENDING

In [None]:
# Requests attached to the fetched project.
project.requests

In [None]:
# Call the submitted function through the scientist's code API.
# NOTE(review): nothing has approved the request at this point in the
# notebook, so this presumably shows the not-yet-approved response — confirm.
result = user_client.code.how_are_people_dying_statistics(ages_data=asset)
result

# Code is denied

In [None]:
# Projects as seen through the admin client.
admin_client.projects

In [None]:
# Take the first project from the admin's listing and show its requests.
project_view = admin_client.projects[0]
project_view.requests

In [None]:
# Pick the first request of that project for review.
request = project_view.requests[0]
request

In [None]:
# The code object attached to the request under review.
func = request.code
func

In [None]:
# Inspect the submitted source through the code object's `show_code` attribute.
func.show_code

In [None]:
# First asset referenced by the submitted code, and its `.data` as seen from
# the admin side.
asset_view = func.assets[0]
asset_view.data

In [None]:
# Deny the request with a free-text reason, then display what deny() returned.
result = request.deny(
    reason=("The Submitted UserCode is too grim in it's study. Study something else.")
)
result

# Change code

In [None]:
# NOTE(review): RepeatedCallPolicy.apply_to_output indexes the output with each
# name in downloadable_output_args ("y" here), while this function returns a
# table with "Name"/"Lifespan" columns — confirm that the policy arguments
# match the intended output.
@sy.syft_function(
    input_policy=CustomExactMatch(ages_data=asset),
    output_policy=RepeatedCallPolicy(n_calls=10, downloadable_output_args=["y"]),
)
def how_are_people_dying_statistics(ages_data):
    """Return the name and lifespan of the longest-lived entry.

    Redefines the earlier function of the same name with a different study:
    lifespan is computed as "Death year" minus "Birth year".

    Args:
        ages_data: Table with "Name", "Birth year" and "Death year" columns
            (presumably a pandas DataFrame — confirm).

    Returns:
        A one-row table with the "Name" and "Lifespan" columns of the row
        that sorts first by descending "Lifespan".
    """
    # assign() works on a copy, so the input table does not gain a "Lifespan"
    # column as a side effect of running this function.
    with_lifespan = ages_data.assign(
        Lifespan=ages_data["Death year"] - ages_data["Birth year"]
    )
    ranked = with_lifespan.sort_values(by="Lifespan", ascending=False)
    longest_lifespan = ranked.head(1)[["Name", "Lifespan"]]

    return longest_lifespan

In [None]:
# TODO: find out how to submit a new request for the updated function to the existing project

In [None]:
# Land (presumably shut down — confirm) the server launched at the top of the
# notebook.
server.land()