Skip to content

Commit

Permalink
Merge pull request #8603 from OpenMined/eelco/updated_private_object
Browse files Browse the repository at this point in the history
[WIP] Syncing updated private objects
  • Loading branch information
eelcovdw committed Mar 25, 2024
2 parents 2f4f76f + eb8c740 commit e6e53a7
Show file tree
Hide file tree
Showing 6 changed files with 89 additions and 54 deletions.
6 changes: 3 additions & 3 deletions packages/syft/src/syft/service/code/user_code.py
Expand Up @@ -960,9 +960,9 @@ def decorator(f: Any) -> SubmitUserCode:
)

if share_results_with_owners and res.output_policy_init_kwargs is not None:
res.output_policy_init_kwargs[
"output_readers"
] = res.input_owner_verify_keys
res.output_policy_init_kwargs["output_readers"] = (
res.input_owner_verify_keys
)

success_message = SyftSuccess(
message=f"Syft function '{f.__name__}' successfully created. "
Expand Down
1 change: 0 additions & 1 deletion packages/syft/src/syft/service/output/output_service.py
@@ -1,5 +1,4 @@
# stdlib
from typing import Any
from typing import ClassVar

# third party
Expand Down
6 changes: 3 additions & 3 deletions packages/syft/src/syft/service/request/request.py
Expand Up @@ -949,9 +949,9 @@ def check_requesting_user_verify_key(context: TransformContext) -> TransformCont
if context.obj.requesting_user_verify_key and context.node.is_root(
context.credentials
):
context.output[
"requesting_user_verify_key"
] = context.obj.requesting_user_verify_key
context.output["requesting_user_verify_key"] = (
context.obj.requesting_user_verify_key
)
else:
context.output["requesting_user_verify_key"] = context.credentials

Expand Down
21 changes: 11 additions & 10 deletions packages/syft/src/syft/service/sync/diff_state.py
@@ -1,11 +1,3 @@
"""
How to check differences between two objects:
* by default merge every attr
* check if there is a custom implementation of the check function
* check if there are exceptions we do not want to merge
* check if there are some restrictions on the attr set
"""

# stdlib
import html
import textwrap
Expand Down Expand Up @@ -160,6 +152,8 @@ class ObjectDiff(SyftObject): # StateTuple (compare 2 objects)
high_permissions: list[str] = []
low_storage_permissions: set[UID] = set()
high_storage_permissions: set[UID] = set()
low_status: str | None = None
high_status: str | None = None

obj_type: type
diff_list: list[AttrDiff] = []
Expand Down Expand Up @@ -198,6 +192,8 @@ def from_objects(
cls,
low_obj: SyncableSyftObject | None,
high_obj: SyncableSyftObject | None,
low_status: str | None,
high_status: str | None,
low_permissions: set[str],
high_permissions: set[str],
low_storage_permissions: set[UID],
Expand All @@ -212,6 +208,8 @@ def from_objects(
res = cls(
low_obj=low_obj,
high_obj=high_obj,
low_status=low_status,
high_status=high_status,
obj_type=obj_type,
low_node_uid=low_node_uid,
high_node_uid=high_node_uid,
Expand All @@ -224,8 +222,8 @@ def from_objects(
if (
low_obj is None
or high_obj is None
or res.is_mock("low")
or res.is_mock("high")
or (res.is_mock("low") and high_status == "SAME")
or (res.is_mock("high") and low_status == "SAME")
):
diff_list = []
else:
Expand Down Expand Up @@ -612,9 +610,12 @@ def from_sync_state(
high_obj = high_state.objects.get(obj_id, None)
high_permissions = high_state.permissions.get(obj_id, set())
high_storage_permissions = high_state.storage_permissions.get(obj_id, set())

diff = ObjectDiff.from_objects(
low_obj=low_obj,
high_obj=high_obj,
low_status=low_state.get_status(obj_id),
high_status=high_state.get_status(obj_id),
low_permissions=low_permissions,
high_permissions=high_permissions,
low_storage_permissions=low_storage_permissions,
Expand Down
55 changes: 30 additions & 25 deletions packages/syft/src/syft/service/sync/sync_service.py
Expand Up @@ -222,17 +222,11 @@ def get_permissions(
storage_permissions[_id] = store.storage_permissions[_id]
return permissions, storage_permissions

@service_method(
path="sync._get_state",
name="_get_state",
roles=ADMIN_ROLE_LEVEL,
)
def _get_state(
self, context: AuthedServiceContext, add_to_store: bool = False
) -> SyncState | SyftError:
def get_all_items(
self, context: AuthedServiceContext
) -> list[SyncableSyftObject] | SyftError:
node = cast(AbstractNode, context.node)

new_state = SyncState(node_uid=node.id)
all_items = []

services_to_sync = [
"requestservice",
Expand All @@ -246,46 +240,57 @@ def _get_state(
for service_name in services_to_sync:
service = node.get_service(service_name)
items = service.get_all(context)
new_state.add_objects(items) # type: ignore

# TODO workaround, we only need action objects from outputs for now
if isinstance(items, SyftError):
return items
all_items.extend(items)

# NOTE we only need action objects from outputs for now
action_object_ids = set()
for obj in new_state.objects.values():
for obj in all_items:
if isinstance(obj, ExecutionOutput):
action_object_ids |= set(obj.output_id_list)
elif isinstance(obj, Job) and obj.result is not None:
if isinstance(obj.result, ActionObject):
obj.result = obj.result.as_empty()
action_object_ids.add(obj.result.id)

action_objects = []
for uid in action_object_ids:
action_object = node.get_service("actionservice").get(
context, uid, resolve_nested=False
) # type: ignore
if action_object.is_err():
return SyftError(message=action_object.err())
action_objects.append(action_object.ok())

new_state.add_objects(action_objects)
all_items.append(action_object.ok())

new_state._build_dependencies() # type: ignore
return all_items

permissions, storage_permissions = self.get_permissions(
context, new_state.objects.values()
)
new_state.permissions = permissions
new_state.storage_permissions = storage_permissions
@service_method(
path="sync._get_state",
name="_get_state",
roles=ADMIN_ROLE_LEVEL,
)
def _get_state(
self, context: AuthedServiceContext, add_to_store: bool = False
) -> SyncState | SyftError:
objects = self.get_all_items(context)
permissions, storage_permissions = self.get_permissions(context, objects)

previous_state = self.stash.get_latest(context=context)
if previous_state is not None:
new_state.previous_state_link = LinkedObject.from_obj(
previous_state_link = LinkedObject.from_obj(
obj=previous_state,
service_type=SyncService,
node_uid=context.node.id, # type: ignore
)

new_state = SyncState.from_objects(
node_uid=context.node.id, # type: ignore
objects=objects,
permissions=permissions,
storage_permissions=storage_permissions,
previous_state_link=previous_state_link,
)

if add_to_store:
self.stash.set(context.credentials, new_state)

Expand Down
54 changes: 42 additions & 12 deletions packages/syft/src/syft/service/sync/sync_state.py
Expand Up @@ -82,9 +82,37 @@ class SyncState(SyftObject):
previous_state_link: LinkedObject | None = None
permissions: dict[UID, set[str]] = {}
storage_permissions: dict[UID, set[UID]] = {}
previous_state_diff: "NodeDiff" | None = None # type: ignore

__attr_searchable__ = ["created_at"]

@classmethod
def from_objects(
cls,
node_uid: UID,
objects: list[SyncableSyftObject],
permissions: dict[UID, set[str]],
storage_permissions: dict[UID, set[UID]],
previous_state_link: LinkedObject | None = None,
) -> "SyncState":
state = cls(
node_uid=node_uid,
previous_state_link=previous_state_link,
)

state._add_objects(objects)
return state

def _set_previous_state_diff(self) -> None:
# relative
from .diff_state import NodeDiff

# Re-use NodeDiff to compare to previous state
# Low = previous state, high = current state
# NOTE No previous sync state means everything is new
previous_state = self.previous_state or SyncState(node_uid=self.node_uid)
self.previous_state_diff = NodeDiff.from_sync_state(previous_state, self)

@property
def previous_state(self) -> Optional["SyncState"]:
if self.previous_state_link is not None:
Expand All @@ -95,7 +123,16 @@ def previous_state(self) -> Optional["SyncState"]:
def all_ids(self) -> set[UID]:
return set(self.objects.keys())

def add_objects(self, objects: list[SyncableSyftObject]) -> None:
def get_status(self, uid: UID) -> str | None:
if self.previous_state_diff is None:
return None
diff = self.previous_state_diff.obj_uid_to_diff.get(uid)

if diff is None:
return None
return diff.status

def _add_objects(self, objects: list[SyncableSyftObject]) -> None:
for obj in objects:
if isinstance(obj.id, LineageID):
self.objects[obj.id.id] = obj
Expand All @@ -106,6 +143,7 @@ def add_objects(self, objects: list[SyncableSyftObject]) -> None:
# need to build dependencies every time to not have UIDs
# in dependencies that are not in objects
self._build_dependencies()
self._set_previous_state_diff()

def _build_dependencies(self) -> None:
self.dependencies = {}
Expand All @@ -122,22 +160,14 @@ def _build_dependencies(self) -> None:
if len(deps):
self.dependencies[obj.id] = deps

def get_previous_state_diff(self) -> "NodeDiff":
# relative
from .diff_state import NodeDiff

# Re-use NodeDiff to compare to previous state
# Low = previous state, high = current state
# NOTE No previous sync state means everything is new
previous_state = self.previous_state or SyncState(node_uid=self.node_uid)
return NodeDiff.from_sync_state(previous_state, self)

@property
def rows(self) -> list[SyncStateRow]:
result = []
ids = set()

previous_diff = self.get_previous_state_diff()
previous_diff = self.previous_state_diff
if previous_diff is None:
raise ValueError("No previous state to compare to")
for hierarchy in previous_diff.hierarchies:
for diff, level in zip(hierarchy.diffs, hierarchy.hierarchy_levels):
if diff.object_id in ids:
Expand Down

0 comments on commit e6e53a7

Please sign in to comment.