Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: handle empty kwargs and high state update in high-to-low case #8718

Merged
merged 11 commits into from
Apr 18, 2024
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -76,3 +76,4 @@ notebooks/helm/scenario_data.jsonl

# tox syft.build.helm generated file
out.*
.git-blame-ignore-revs
3 changes: 2 additions & 1 deletion packages/syft/src/syft/client/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -573,7 +573,8 @@ def __getattribute__(self, name: str) -> Any:
except AttributeError:
raise SyftAttributeError(
f"'APIModule' api{self.path} object has no submodule or method '{name}', "
"you may not have permission to access the module you are trying to access"
"you may not have permission to access the module you are trying to access."
"If you think this is an error, try calling `client.refresh()` to update the API."
)

def __getitem__(self, key: str | int) -> Any:
Expand Down
5 changes: 5 additions & 0 deletions packages/syft/src/syft/service/action/action_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -2059,6 +2059,11 @@ def __rrshift__(self, other: Any) -> Any:

@serializable()
class AnyActionObject(ActionObject):
"""
This is a catch-all class for all objects that are not
defined in the `action_types` dictionary.
"""

__canonical_name__ = "AnyActionObject"
__version__ = SYFT_OBJECT_VERSION_3

Expand Down
10 changes: 4 additions & 6 deletions packages/syft/src/syft/service/action/action_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,15 @@ def action_type_for_type(obj_or_type: Any) -> type:
obj_or_type: Union[object, type]
Can be an object or a class
"""
if isinstance(obj_or_type, ActionDataEmpty):
obj_or_type = obj_or_type.syft_internal_type
if type(obj_or_type) != type:
if isinstance(obj_or_type, ActionDataEmpty):
obj_or_type = obj_or_type.syft_internal_type
else:
obj_or_type = type(obj_or_type)
obj_or_type = type(obj_or_type)

if obj_or_type not in action_types:
debug(f"WARNING: No Type for {obj_or_type}, returning {action_types[Any]}")
return action_types[Any]

return action_types[obj_or_type]
return action_types.get(obj_or_type, action_types[Any])


def action_type_for_object(obj: Any) -> type:
Expand Down
4 changes: 3 additions & 1 deletion packages/syft/src/syft/service/code/user_code_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,7 +384,9 @@ def keep_owned_kwargs(
def is_execution_on_owned_args(
self, kwargs: dict[str, Any], context: AuthedServiceContext
) -> bool:
return len(self.keep_owned_kwargs(kwargs, context)) == len(kwargs)
return bool(kwargs) and len(self.keep_owned_kwargs(kwargs, context)) == len(
kwargs
)

@service_method(path="code.call", name="call", roles=GUEST_ROLE_LEVEL)
def call(
Expand Down
42 changes: 41 additions & 1 deletion packages/syft/src/syft/service/sync/diff_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from typing_extensions import Self

# relative
from ...client.api import APIRegistry
from ...client.client import SyftClient
from ...client.sync_decision import SyncDecision
from ...client.sync_decision import SyncDirection
Expand Down Expand Up @@ -585,6 +586,15 @@ def target_node_uid(self) -> UID:
else:
return self.low_node_uid

@property
def source_node_uid(self) -> UID:
if self.sync_direction is None:
raise ValueError("no direction specified")
if self.sync_direction == SyncDirection.LOW_TO_HIGH:
return self.low_node_uid
else:
return self.high_node_uid

@property
def target_verify_key(self) -> SyftVerifyKey:
if self.sync_direction is None:
Expand All @@ -594,6 +604,35 @@ def target_verify_key(self) -> SyftVerifyKey:
else:
return self.user_verify_key_low

@property
def source_verify_key(self) -> SyftVerifyKey:
if self.sync_direction is None:
raise ValueError("no direction specified")
if self.sync_direction == SyncDirection.LOW_TO_HIGH:
return self.user_verify_key_low
else:
return self.user_verify_key_high

@property
def source_client(self) -> SyftClient:
return self.build(self.source_node_uid, self.source_verify_key)

@property
def target_client(self) -> SyftClient:
return self.build(self.target_node_uid, self.target_verify_key)

def build(self, node_uid: UID, syft_client_verify_key: SyftVerifyKey): # type: ignore
# relative
from ...client.domain_client import DomainClient

api = APIRegistry.api_for(node_uid, syft_client_verify_key)
client = DomainClient(
api=api,
connection=api.connection, # type: ignore
credentials=api.signing_key, # type: ignore
)
return client

def get_dependencies(
self,
include_roots: bool = False,
Expand Down Expand Up @@ -1241,8 +1280,9 @@ def from_widget_state(
)
]

# mockify
mockify = widget.mockify
if widget.has_unused_share_button:
print("Share button was not used, so we will mockify the object")

# storage permissions
new_storage_permissions = []
Expand Down
52 changes: 31 additions & 21 deletions packages/syft/src/syft/service/sync/resolve_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from ipywidgets import VBox

# relative
from ...client.api import APIRegistry
from ...client.sync_decision import SyncDecision
from ...client.sync_decision import SyncDirection
from ...node.credentials import SyftVerifyKey
Expand Down Expand Up @@ -120,6 +119,11 @@ def set_share_private_data(self) -> None:
def mockify(self) -> bool:
return not self.share_private_data

@property
def has_unused_share_button(self) -> bool:
# does not have share button
return False

@property
def share_private_data(self) -> bool:
# there are TwinAPIEndpoint.__private_sync_attr_mocks__
Expand Down Expand Up @@ -191,11 +195,15 @@ def __init__(
def mockify(self) -> bool:
if isinstance(self.diff.non_empty_object, TwinAPIEndpoint):
return True
if self.show_share_button and not self.share_private_data:
if self.has_unused_share_button:
return True
else:
return False

@property
def has_unused_share_button(self) -> bool:
return self.show_share_button and not self.share_private_data

@property
def show_share_button(self) -> bool:
return isinstance(self.diff.non_empty_object, SyftLog | ActionObject)
Expand Down Expand Up @@ -417,8 +425,12 @@ def button_callback(self, *args: list, **kwargs: dict) -> SyftSuccess | SyftErro
# Maybe default read permission for some objects (high -> low)

# TODO: UID
resolved_state_low = ResolvedSyncState(node_uid=UID(), alias="low")
resolved_state_high = ResolvedSyncState(node_uid=UID(), alias="high")
resolved_state_low = ResolvedSyncState(
node_uid=self.obj_diff_batch.low_node_uid, alias="low"
)
resolved_state_high = ResolvedSyncState(
node_uid=self.obj_diff_batch.high_node_uid, alias="high"
)

batch_diff = self.obj_diff_batch
if batch_diff.is_unchanged:
Expand Down Expand Up @@ -486,25 +498,23 @@ def button_callback(self, *args: list, **kwargs: dict) -> SyftSuccess | SyftErro
resolved_state_low.add_sync_instruction(sync_instruction)
resolved_state_high.add_sync_instruction(sync_instruction)

# TODO: ONLY WORKS FOR LOW TO HIGH
# relative
from ...client.domain_client import DomainClient

api = APIRegistry.api_for(
self.obj_diff_batch.target_node_uid, self.obj_diff_batch.target_verify_key
)
client = DomainClient(
api=api,
connection=api.connection, # type: ignore
credentials=api.signing_key, # type: ignore
)

if self.obj_diff_batch.sync_direction is None:
raise ValueError("no direction specified")
if self.obj_diff_batch.sync_direction == SyncDirection.LOW_TO_HIGH:
res = client.apply_state(resolved_state_high)
else:
res = client.apply_state(resolved_state_low)
sync_direction = self.obj_diff_batch.sync_direction
resolved_state = (
resolved_state_high
if sync_direction == SyncDirection.LOW_TO_HIGH
else resolved_state_low
)
res = self.obj_diff_batch.target_client.apply_state(resolved_state)

if sync_direction == SyncDirection.HIGH_TO_LOW:
# apply empty state to generete a new state
resolved_state_high = ResolvedSyncState(
node_uid=self.obj_diff_batch.high_node_uid, alias="high"
)
high_client = self.obj_diff_batch.source_client
res = high_client.apply_state(resolved_state_high)

self.is_synced = True
self.set_result_state(res)
Expand Down
126 changes: 37 additions & 89 deletions packages/syft/tests/syft/service/sync/sync_flow_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,52 @@
import pytest

# syft absolute
import syft
import syft as sy
from syft.abstract_node import NodeSideType
from syft.client.domain_client import DomainClient
from syft.client.sync_decision import SyncDecision
from syft.client.syncing import compare_clients
from syft.client.syncing import compare_states
from syft.client.syncing import resolve
from syft.client.syncing import resolve_single
from syft.service.action.action_object import ActionObject
from syft.service.response import SyftError
from syft.service.response import SyftSuccess


def compare_and_resolve(*, from_client: DomainClient, to_client: DomainClient):
diff_state_before = compare_clients(from_client, to_client)
for obj_diff_batch in diff_state_before.batches:
widget = resolve_single(obj_diff_batch)
widget.click_share_all_private_data()
res = widget.click_sync()
assert isinstance(res, SyftSuccess)
from_client.refresh()
to_client.refresh()
diff_state_after = compare_clients(from_client, to_client)
return diff_state_before, diff_state_after


def run_and_accept_result(client):
job_high = client.code.compute(blocking=True)
client.requests[0].accept_by_depositing_result(job_high)
return job_high


@syft.syft_function_single_use()
def compute() -> int:
return 42


def get_ds_client(client: DomainClient) -> DomainClient:
client.register(
name="a",
email="a@a.com",
password="asdf",
password_verify="asdf",
)
return client.login(email="a@a.com", password="asdf")


@pytest.mark.skipif(sys.platform == "win32", reason="does not run on windows")
Expand Down Expand Up @@ -211,41 +248,6 @@ def compute_mean(data) -> float:
high_worker.cleanup()


def test_diff_state(low_worker, high_worker):
low_client = low_worker.root_client
client_low_ds = low_worker.guest_client
high_client = high_worker.root_client

@sy.syft_function_single_use()
def compute() -> int:
return 42

compute.code = dedent(compute.code)

_ = client_low_ds.code.request_code_execution(compute)

diff_state = compare_clients(low_client, high_client)
low_items_to_sync, high_items_to_sync = resolve(
diff_state, decision="low", share_private_objects=True
)

assert not diff_state.is_same
assert not low_items_to_sync.is_empty
assert not high_items_to_sync.is_empty

low_client.apply_state(low_items_to_sync)
high_client.apply_state(high_items_to_sync)

diff_state = compare_clients(low_client, high_client)
low_items_to_sync, high_items_to_sync = resolve(
diff_state, decision="low", share_private_objects=True
)

assert diff_state.is_same
assert low_items_to_sync.is_empty
assert high_items_to_sync.is_empty


def test_forget_usercode(low_worker, high_worker):
low_client = low_worker.root_client
client_low_ds = low_worker.guest_client
Expand Down Expand Up @@ -293,60 +295,6 @@ def skip_if_user_code(diff):
)


@sy.mock_api_endpoint()
def mock_function(context) -> str:
return -42


@sy.private_api_endpoint()
def private_function(context) -> str:
return 42


def test_twin_api_integration(low_worker, high_worker):
low_client = low_worker.root_client
high_client = high_worker.root_client

low_client.register(
email="newuser@openmined.org",
name="John Doe",
password="pw",
password_verify="pw",
)

client_low_ds = low_client.login(
email="newuser@openmined.org",
password="pw",
)

new_endpoint = sy.TwinAPIEndpoint(
path="testapi.query",
private_function=private_function,
mock_function=mock_function,
description="",
)
high_client.api.services.api.add(endpoint=new_endpoint)
high_client.refresh()
high_private_res = high_client.api.services.testapi.query.private()
assert high_private_res == 42

low_state = low_client.get_sync_state()
high_state = high_client.get_sync_state()
diff_state = compare_states(high_state, low_state)
obj_diff_batch = diff_state[0]
widget = resolve_single(obj_diff_batch)
widget.click_sync()

client_low_ds.refresh()
low_private_res = client_low_ds.api.services.testapi.query.private()
assert isinstance(
low_private_res, SyftError
), "Should not have access to private on low side"
low_mock_res = client_low_ds.api.services.testapi.query.mock()
high_mock_res = high_client.api.services.testapi.query.mock()
assert low_mock_res == high_mock_res == -42


def test_skip_user_code(low_worker, high_worker):
low_client = low_worker.root_client
client_low_ds = low_worker.guest_client
Expand Down