Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Syncing updated private objects #8603

Merged
merged 2 commits into from Mar 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
6 changes: 3 additions & 3 deletions packages/syft/src/syft/service/code/user_code.py
Expand Up @@ -960,9 +960,9 @@ def decorator(f: Any) -> SubmitUserCode:
)

if share_results_with_owners and res.output_policy_init_kwargs is not None:
res.output_policy_init_kwargs[
"output_readers"
] = res.input_owner_verify_keys
res.output_policy_init_kwargs["output_readers"] = (
res.input_owner_verify_keys
)

success_message = SyftSuccess(
message=f"Syft function '{f.__name__}' successfully created. "
Expand Down
1 change: 0 additions & 1 deletion packages/syft/src/syft/service/output/output_service.py
@@ -1,5 +1,4 @@
# stdlib
from typing import Any
from typing import ClassVar

# third party
Expand Down
6 changes: 3 additions & 3 deletions packages/syft/src/syft/service/request/request.py
Expand Up @@ -949,9 +949,9 @@ def check_requesting_user_verify_key(context: TransformContext) -> TransformCont
if context.obj.requesting_user_verify_key and context.node.is_root(
context.credentials
):
context.output[
"requesting_user_verify_key"
] = context.obj.requesting_user_verify_key
context.output["requesting_user_verify_key"] = (
context.obj.requesting_user_verify_key
)
else:
context.output["requesting_user_verify_key"] = context.credentials

Expand Down
21 changes: 11 additions & 10 deletions packages/syft/src/syft/service/sync/diff_state.py
@@ -1,11 +1,3 @@
"""
How to check differences between two objects:
* by default merge every attr
* check if there is a custom implementation of the check function
* check if there are exceptions we do not want to merge
* check if there are some restrictions on the attr set
"""

# stdlib
import html
import textwrap
Expand Down Expand Up @@ -160,6 +152,8 @@ class ObjectDiff(SyftObject): # StateTuple (compare 2 objects)
high_permissions: list[str] = []
low_storage_permissions: set[UID] = set()
high_storage_permissions: set[UID] = set()
low_status: str | None = None
high_status: str | None = None

obj_type: type
diff_list: list[AttrDiff] = []
Expand Down Expand Up @@ -198,6 +192,8 @@ def from_objects(
cls,
low_obj: SyncableSyftObject | None,
high_obj: SyncableSyftObject | None,
low_status: str | None,
high_status: str | None,
low_permissions: set[str],
high_permissions: set[str],
low_storage_permissions: set[UID],
Expand All @@ -212,6 +208,8 @@ def from_objects(
res = cls(
low_obj=low_obj,
high_obj=high_obj,
low_status=low_status,
high_status=high_status,
obj_type=obj_type,
low_node_uid=low_node_uid,
high_node_uid=high_node_uid,
Expand All @@ -224,8 +222,8 @@ def from_objects(
if (
low_obj is None
or high_obj is None
or res.is_mock("low")
or res.is_mock("high")
or (res.is_mock("low") and high_status == "SAME")
or (res.is_mock("high") and low_status == "SAME")
):
diff_list = []
else:
Expand Down Expand Up @@ -612,9 +610,12 @@ def from_sync_state(
high_obj = high_state.objects.get(obj_id, None)
high_permissions = high_state.permissions.get(obj_id, set())
high_storage_permissions = high_state.storage_permissions.get(obj_id, set())

diff = ObjectDiff.from_objects(
low_obj=low_obj,
high_obj=high_obj,
low_status=low_state.get_status(obj_id),
high_status=high_state.get_status(obj_id),
low_permissions=low_permissions,
high_permissions=high_permissions,
low_storage_permissions=low_storage_permissions,
Expand Down
55 changes: 30 additions & 25 deletions packages/syft/src/syft/service/sync/sync_service.py
Expand Up @@ -222,17 +222,11 @@ def get_permissions(
storage_permissions[_id] = store.storage_permissions[_id]
return permissions, storage_permissions

@service_method(
path="sync._get_state",
name="_get_state",
roles=ADMIN_ROLE_LEVEL,
)
def _get_state(
self, context: AuthedServiceContext, add_to_store: bool = False
) -> SyncState | SyftError:
def get_all_items(
self, context: AuthedServiceContext
) -> list[SyncableSyftObject] | SyftError:
node = cast(AbstractNode, context.node)

new_state = SyncState(node_uid=node.id)
all_items = []

services_to_sync = [
"requestservice",
Expand All @@ -246,46 +240,57 @@ def _get_state(
for service_name in services_to_sync:
service = node.get_service(service_name)
items = service.get_all(context)
new_state.add_objects(items) # type: ignore

# TODO workaround, we only need action objects from outputs for now
if isinstance(items, SyftError):
return items
all_items.extend(items)

# NOTE we only need action objects from outputs for now
action_object_ids = set()
for obj in new_state.objects.values():
for obj in all_items:
if isinstance(obj, ExecutionOutput):
action_object_ids |= set(obj.output_id_list)
elif isinstance(obj, Job) and obj.result is not None:
if isinstance(obj.result, ActionObject):
obj.result = obj.result.as_empty()
action_object_ids.add(obj.result.id)

action_objects = []
for uid in action_object_ids:
action_object = node.get_service("actionservice").get(
context, uid, resolve_nested=False
) # type: ignore
if action_object.is_err():
return SyftError(message=action_object.err())
action_objects.append(action_object.ok())

new_state.add_objects(action_objects)
all_items.append(action_object.ok())

new_state._build_dependencies() # type: ignore
return all_items

permissions, storage_permissions = self.get_permissions(
context, new_state.objects.values()
)
new_state.permissions = permissions
new_state.storage_permissions = storage_permissions
@service_method(
path="sync._get_state",
name="_get_state",
roles=ADMIN_ROLE_LEVEL,
)
def _get_state(
self, context: AuthedServiceContext, add_to_store: bool = False
) -> SyncState | SyftError:
objects = self.get_all_items(context)
permissions, storage_permissions = self.get_permissions(context, objects)

previous_state = self.stash.get_latest(context=context)
if previous_state is not None:
new_state.previous_state_link = LinkedObject.from_obj(
previous_state_link = LinkedObject.from_obj(
obj=previous_state,
service_type=SyncService,
node_uid=context.node.id, # type: ignore
)

new_state = SyncState.from_objects(
node_uid=context.node.id, # type: ignore
objects=objects,
permissions=permissions,
storage_permissions=storage_permissions,
previous_state_link=previous_state_link,
)

if add_to_store:
self.stash.set(context.credentials, new_state)

Expand Down
54 changes: 42 additions & 12 deletions packages/syft/src/syft/service/sync/sync_state.py
Expand Up @@ -82,9 +82,37 @@ class SyncState(SyftObject):
previous_state_link: LinkedObject | None = None
permissions: dict[UID, set[str]] = {}
storage_permissions: dict[UID, set[UID]] = {}
previous_state_diff: "NodeDiff" | None = None # type: ignore

__attr_searchable__ = ["created_at"]

@classmethod
def from_objects(
cls,
node_uid: UID,
objects: list[SyncableSyftObject],
permissions: dict[UID, set[str]],
storage_permissions: dict[UID, set[UID]],
previous_state_link: LinkedObject | None = None,
) -> "SyncState":
state = cls(
node_uid=node_uid,
previous_state_link=previous_state_link,
)

state._add_objects(objects)
return state

def _set_previous_state_diff(self) -> None:
# relative
from .diff_state import NodeDiff

# Re-use NodeDiff to compare to previous state
# Low = previous state, high = current state
# NOTE No previous sync state means everything is new
previous_state = self.previous_state or SyncState(node_uid=self.node_uid)
self.previous_state_diff = NodeDiff.from_sync_state(previous_state, self)

@property
def previous_state(self) -> Optional["SyncState"]:
if self.previous_state_link is not None:
Expand All @@ -95,7 +123,16 @@ def previous_state(self) -> Optional["SyncState"]:
def all_ids(self) -> set[UID]:
return set(self.objects.keys())

def add_objects(self, objects: list[SyncableSyftObject]) -> None:
def get_status(self, uid: UID) -> str | None:
if self.previous_state_diff is None:
return None
diff = self.previous_state_diff.obj_uid_to_diff.get(uid)

if diff is None:
return None
return diff.status

def _add_objects(self, objects: list[SyncableSyftObject]) -> None:
for obj in objects:
if isinstance(obj.id, LineageID):
self.objects[obj.id.id] = obj
Expand All @@ -106,6 +143,7 @@ def add_objects(self, objects: list[SyncableSyftObject]) -> None:
# need to build dependencies every time to not have UIDs
# in dependencies that are not in objects
self._build_dependencies()
self._set_previous_state_diff()

def _build_dependencies(self) -> None:
self.dependencies = {}
Expand All @@ -122,22 +160,14 @@ def _build_dependencies(self) -> None:
if len(deps):
self.dependencies[obj.id] = deps

def get_previous_state_diff(self) -> "NodeDiff":
# relative
from .diff_state import NodeDiff

# Re-use NodeDiff to compare to previous state
# Low = previous state, high = current state
# NOTE No previous sync state means everything is new
previous_state = self.previous_state or SyncState(node_uid=self.node_uid)
return NodeDiff.from_sync_state(previous_state, self)

@property
def rows(self) -> list[SyncStateRow]:
result = []
ids = set()

previous_diff = self.get_previous_state_diff()
previous_diff = self.previous_state_diff
if previous_diff is None:
raise ValueError("No previous state to compare to")
for hierarchy in previous_diff.hierarchies:
for diff, level in zip(hierarchy.diffs, hierarchy.hierarchy_levels):
if diff.object_id in ids:
Expand Down