Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sync TwinAPIEndpoint #8696

Merged
merged 15 commits into from
Apr 11, 2024
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
6 changes: 3 additions & 3 deletions packages/syft/src/syft/client/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,7 @@ def private(self, *args: Any, **kwargs: Any) -> Any:
message="This function doesn't support public/private calls as it's not custom."
)

def custom_function_id(self) -> UID | SyftError:
def custom_function_actionobject_id(self) -> UID | SyftError:
if self.custom_function and self.pre_kwargs is not None:
custom_path = self.pre_kwargs.get("path", "")
api_call = SyftAPICall(
Expand All @@ -324,7 +324,7 @@ def custom_function_id(self) -> UID | SyftError:
endpoint = self.make_call(api_call=api_call)
if isinstance(endpoint, SyftError):
return endpoint
return endpoint.id
return endpoint.action_object_id
return SyftError(message="This function is not a custom function")

def _repr_markdown_(self, wrap_as_python: bool = False, indent: int = 0) -> str:
Expand Down Expand Up @@ -389,7 +389,7 @@ def prepare_args_and_kwargs(

for k, v in kwargs.items():
if isinstance(v, RemoteFunction) and v.custom_function:
kwargs[k] = v.custom_function_id()
kwargs[k] = v.custom_function_actionobject_id()

args, kwargs = convert_to_pointers(
api=self.api,
Expand Down
4 changes: 2 additions & 2 deletions packages/syft/src/syft/protocol/protocol_version.json
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@
"CustomAPIView": {
"1": {
"version": 1,
"hash": "21cada3f8b8609e91e4f01f3bfdbdab3f8b96003163e09dba1c4b31041598ca2",
"hash": "5cdfaea6b5af235d0f9ab8e6f01d6cebcde7c8bf4efd00aec26c343965ef865e",
"action": "add"
}
},
Expand Down Expand Up @@ -94,7 +94,7 @@
"TwinAPIEndpoint": {
"1": {
"version": 1,
"hash": "edcd67ab41edfae56deb23d9ef838edc442f587bdb16b8e8c46efa20c04e3c25",
"hash": "03a0244c6b322e72032c5c33c1db09b913f83d72cfad1cc51ac4ec90506afcf5",
"action": "add"
}
},
Expand Down
2 changes: 1 addition & 1 deletion packages/syft/src/syft/service/action/action_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,7 @@ def _user_code_execute(

if not override_execution_permission:
if input_policy is None:
if not code_item.output_policy_approved:
if not code_item.is_output_policy_approved(context):
return Err("Execution denied: Your code is waiting for approval")
return Err(f"No input policy defined for user code: {code_item.id}")

Expand Down
9 changes: 7 additions & 2 deletions packages/syft/src/syft/service/api/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,12 @@
from ...types.syft_object import PartialSyftObject
from ...types.syft_object import SYFT_OBJECT_VERSION_1
from ...types.syft_object import SyftObject
from ...types.syncable_object import SyncableSyftObject
from ...types.transforms import TransformContext
from ...types.transforms import generate_action_object_id
from ...types.transforms import generate_id
from ...types.transforms import transform
from ...types.uid import UID
from ..context import AuthedServiceContext
from ..response import SyftError

Expand Down Expand Up @@ -61,6 +64,7 @@ class TwinAPIEndpointView(SyftObject):
__version__ = SYFT_OBJECT_VERSION_1

path: str
action_object_id: UID
signature: Signature
access: str = "Public"
mock_function: str | None = None
Expand Down Expand Up @@ -234,7 +238,7 @@ class CreateTwinAPIEndpoint(BaseTwinAPIEndpoint):


@serializable()
class TwinAPIEndpoint(SyftObject):
class TwinAPIEndpoint(SyncableSyftObject):
# version
__canonical_name__ = "TwinAPIEndpoint"
__version__ = SYFT_OBJECT_VERSION_1
Expand All @@ -247,6 +251,7 @@ def __init__(self, **kwargs: Any) -> None:
mock_function: PublicAPIEndpoint
signature: Signature
description: str | None = None
action_object_id: UID

__attr_searchable__ = ["path"]
__attr_unique__ = ["path"]
Expand Down Expand Up @@ -461,7 +466,7 @@ def code_string(context: TransformContext) -> TransformContext:

@transform(CreateTwinAPIEndpoint, TwinAPIEndpoint)
def endpoint_create_to_twin_endpoint() -> list[Callable]:
return [generate_id, check_and_cleanup_signature]
return [generate_id, generate_action_object_id, check_and_cleanup_signature]


@transform(TwinAPIEndpoint, TwinAPIEndpointView)
Expand Down
28 changes: 25 additions & 3 deletions packages/syft/src/syft/service/api/api_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from ..response import SyftError
from ..response import SyftSuccess
from ..service import AbstractService
from ..service import TYPE_TO_SERVICE
from ..service import service_method
from ..user.user_roles import ADMIN_ROLE_LEVEL
from ..user.user_roles import DATA_SCIENTIST_ROLE_LEVEL
Expand Down Expand Up @@ -44,11 +45,18 @@ def __init__(self, store: DocumentStore) -> None:
roles=ADMIN_ROLE_LEVEL,
)
def set(
self, context: AuthedServiceContext, endpoint: CreateTwinAPIEndpoint
self,
context: AuthedServiceContext,
endpoint: CreateTwinAPIEndpoint | TwinAPIEndpoint,
) -> SyftSuccess | SyftError:
"""Register an CustomAPIEndpoint."""
try:
new_endpoint = endpoint.to(TwinAPIEndpoint)
if isinstance(endpoint, CreateTwinAPIEndpoint): # type: ignore
new_endpoint = endpoint.to(TwinAPIEndpoint)
elif isinstance(endpoint, TwinAPIEndpoint): # type: ignore
new_endpoint = endpoint
else:
return SyftError(message="Invalid endpoint type.")
except ValueError as e:
return SyftError(message=str(e))

Expand All @@ -68,7 +76,7 @@ def set(

result = result.ok()
action_obj = ActionObject.from_obj(
id=result.id,
id=new_endpoint.action_object_id,
syft_action_data=CustomEndpointActionObject(endpoint_id=result.id),
syft_node_location=context.node.id,
syft_client_verify_key=context.credentials,
Expand Down Expand Up @@ -213,6 +221,17 @@ def api_endpoints(

return api_endpoint_view

@service_method(path="api.get_all", name="get_all", roles=ADMIN_ROLE_LEVEL)
def get_all(
self,
context: AuthedServiceContext,
) -> list[TwinAPIEndpoint] | SyftError:
"""Get all API endpoints."""
result = self.stash.get_all(context.credentials)
if result.is_ok():
return result.ok()
return SyftError(message=result.err())

@service_method(path="api.call", name="call", roles=GUEST_ROLE_LEVEL)
def call(
self,
Expand Down Expand Up @@ -357,3 +376,6 @@ def get_code(
return result.ok()

return SyftError(message=f"Unable to get {endpoint_path} CustomAPIEndpoint")


TYPE_TO_SERVICE[TwinAPIEndpoint] = APIService
2 changes: 1 addition & 1 deletion packages/syft/src/syft/service/policy/policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ def partition_by_node(kwargs: dict[str, Any]) -> dict[NodeIdentity, dict[str, UI
if isinstance(v, TwinObject):
uid = v.id
if isinstance(v, RemoteFunction):
uid = v.custom_function_id()
uid = v.custom_function_actionobject_id()
if isinstance(v, Asset):
uid = v.action_id
if not isinstance(uid, UID):
Expand Down
31 changes: 18 additions & 13 deletions packages/syft/src/syft/service/sync/diff_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from ..action.action_permissions import ActionObjectPermission
from ..action.action_permissions import ActionPermission
from ..action.action_permissions import StoragePermission
from ..api.api import TwinAPIEndpoint
from ..code.user_code import UserCode
from ..code.user_code import UserCodeStatusCollection
from ..job.job_stash import Job
Expand Down Expand Up @@ -776,6 +777,10 @@ def visual_hierarchy(self) -> tuple[type, dict]:
ExecutionOutput: [Job],
Job: [ActionObject, SyftLog, Job],
}
elif isinstance(root_obj, TwinAPIEndpoint):
return TwinAPIEndpoint, { # type: ignore
TwinAPIEndpoint: [],
}
else:
raise ValueError(f"Unknown root type: {self.root.obj_type}")

Expand Down Expand Up @@ -1160,12 +1165,11 @@ def hierarchies(

for diff in obj_uid_to_diff.values():
diff_obj = diff.low_obj if diff.low_obj is not None else diff.high_obj
if isinstance(diff_obj, Request):
root_ids.append(diff.object_id)
elif isinstance(diff_obj, Job) and diff_obj.parent_job_id is None: # type: ignore
root_ids.append(diff.object_id) # type: ignore
elif isinstance(diff_obj, UserCode):
if isinstance(diff_obj, Request | UserCode | TwinAPIEndpoint):
# TODO: Figure out nested user codes, do we even need that?

root_ids.append(diff.object_id) # type: ignore
elif isinstance(diff_obj, Job) and diff_obj.parent_job_id is None: # type: ignore
root_ids.append(diff.object_id) # type: ignore

for root_uid in root_ids:
Expand Down Expand Up @@ -1218,14 +1222,15 @@ def from_widget_state(
if sync_direction == SyncDirection.HIGH_TO_LOW:
if widget.share_private_data or diff.object_type == "Job":
if share_to_user is None:
raise ValueError("empty to user to share with")
new_permissions_low_side = [
ActionObjectPermission(
uid=widget.diff.object_id,
permission=ActionPermission.READ,
credentials=share_to_user, # type: ignore
)
]
raise ValueError("share_to_user is required for private data")
else:
new_permissions_low_side = [
ActionObjectPermission(
uid=widget.diff.object_id,
permission=ActionPermission.READ,
credentials=share_to_user, # type: ignore
)
]

# mockify
mockify = widget.mockify
Expand Down
9 changes: 8 additions & 1 deletion packages/syft/src/syft/service/sync/sync_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from ..action.action_permissions import ActionObjectPermission
from ..action.action_permissions import ActionPermission
from ..action.action_permissions import StoragePermission
from ..api.api import TwinAPIEndpoint
from ..code.user_code import UserCodeStatusCollection
from ..context import AuthedServiceContext
from ..job.job_stash import Job
Expand Down Expand Up @@ -150,6 +151,12 @@ def set_object(
creds = context.credentials

exists = stash.get_by_uid(context.credentials, item.id).ok() is not None

if isinstance(item, TwinAPIEndpoint):
return context.node.get_service("apiservice").set(
context=context, endpoint=item
)

if exists:
res = stash.update(creds, item)
else:
Expand Down Expand Up @@ -254,6 +261,7 @@ def get_all_syncable_items(
"logservice",
"outputservice",
"usercodestatusservice",
"apiservice",
]

for service_name in services_to_sync:
Expand All @@ -263,7 +271,6 @@ def get_all_syncable_items(
return items
all_items.extend(items)

# NOTE we only need action objects from outputs for now
action_object_ids = set()
for obj in all_items:
if isinstance(obj, ExecutionOutput):
Expand Down
10 changes: 10 additions & 0 deletions packages/syft/src/syft/types/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,16 @@ def generate_id(context: TransformContext) -> TransformContext:
return context


def generate_action_object_id(context: TransformContext) -> TransformContext:
if context.output is None:
return context
if "action_object_id" not in context.output or not isinstance(
context.output["action_object_id"], UID
):
context.output["action_object_id"] = UID()
return context


def validate_url(context: TransformContext) -> TransformContext:
if context.output and context.output["url"] is not None:
context.output["url"] = GridURL.from_url(context.output["url"]).url_no_port
Expand Down
55 changes: 55 additions & 0 deletions packages/syft/tests/syft/service/sync/sync_flow_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from syft.client.syncing import compare_clients
from syft.client.syncing import compare_states
from syft.client.syncing import resolve
from syft.client.syncing import resolve_single
from syft.service.action.action_object import ActionObject
from syft.service.response import SyftError

Expand Down Expand Up @@ -292,6 +293,60 @@ def skip_if_user_code(diff):
)


@sy.mock_api_endpoint()
def mock_function(context) -> str:
return -42


@sy.private_api_endpoint()
def private_function(context) -> str:
return 42


def test_twin_api_integration(low_worker, high_worker):
low_client = low_worker.root_client
high_client = high_worker.root_client

low_client.register(
email="newuser@openmined.org",
name="John Doe",
password="pw",
password_verify="pw",
)

client_low_ds = low_client.login(
email="newuser@openmined.org",
password="pw",
)

new_endpoint = sy.TwinAPIEndpoint(
path="testapi.query",
private_function=private_function,
mock_function=mock_function,
description="",
)
high_client.api.services.api.add(endpoint=new_endpoint)
high_client.refresh()
high_private_res = high_client.api.services.testapi.query.private()
assert high_private_res == 42

low_state = low_client.get_sync_state()
high_state = high_client.get_sync_state()
diff_state = compare_states(high_state, low_state)
obj_diff_batch = diff_state[0]
widget = resolve_single(obj_diff_batch)
widget.click_sync()

client_low_ds.refresh()
low_private_res = client_low_ds.api.services.testapi.query.private()
assert isinstance(
low_private_res, SyftError
), "Should not have access to private on low side"
low_mock_res = client_low_ds.api.services.testapi.query.mock()
high_mock_res = high_client.api.services.testapi.query.mock()
assert low_mock_res == high_mock_res == -42


def test_skip_user_code(low_worker, high_worker):
low_client = low_worker.root_client
client_low_ds = low_worker.guest_client
Expand Down