Skip to content

Commit

Permalink
Merge pull request #8630 from OpenMined/batch-dependencies
Browse files Browse the repository at this point in the history
Batch dependencies
  • Loading branch information
koenvanderveen committed Mar 26, 2024
2 parents be30352 + c9a0781 commit 8e433e4
Show file tree
Hide file tree
Showing 4 changed files with 114 additions and 23 deletions.
1 change: 1 addition & 0 deletions packages/syft/src/syft/client/domain_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,7 @@ def apply_state(self, resolved_state: ResolvedSyncState) -> SyftSuccess | SyftEr
resolved_state.new_permissions,
resolved_state.new_storage_permissions,
ignored_batches,
unignored_batches=resolved_state.unignored_batches,
)
if isinstance(res, SyftError):
return res
Expand Down
36 changes: 25 additions & 11 deletions packages/syft/src/syft/client/syncing.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,30 +40,35 @@ def handle_ignore_skip(
batch: ObjectDiffBatch, decision: SyncDecision, other_batches: list[ObjectDiffBatch]
) -> None:
if decision == SyncDecision.skip or decision == SyncDecision.ignore:
skipped_or_ignored_ids = {
x.object_id for x in batch.get_dependents(include_roots=True)
}
skipped_or_ignored_ids = set(
[x.object_id for x in batch.get_dependents(include_roots=False)]
)
for other_batch in other_batches:
if other_batch.decision != decision:
# Currently, this is not recursive, in the future it might be
other_batch_ids = {
d.object_id for d in other_batch.get_dependents(include_roots=False)
}
other_batch_ids = set(
[
d.object_id
for d in other_batch.get_dependencies(include_roots=True)
]
)
if len(other_batch_ids & skipped_or_ignored_ids) != 0:
other_batch.decision = decision
skipped_or_ignored_ids.update(other_batch_ids)
action = "Skipping" if decision == SyncDecision.skip else "Ignoring"
print(
f"{action} other batch with root {other_batch.root_type.__name__}"
f"\n{action} other batch with root {other_batch.root_type.__name__}\n"
)


def resolve(
state: NodeDiff,
decision: str | None = None,
decision: list[str] | str | None = None,
share_private_objects: bool = False,
ask_for_input: bool = True,
) -> tuple[ResolvedSyncState, ResolvedSyncState]:
# TODO: fix this
previously_ignored_batches = state.low_state.ignored_batches
# TODO: only add permissions for objects where we manually give permission
# Maybe default read permission for some objects (high -> low)
resolved_state_low = ResolvedSyncState(node_uid=state.low_node_uid, alias="low")
Expand All @@ -83,12 +88,14 @@ def resolve(

if batch_decision is None:
batch_decision = get_user_input_for_resolve()
batch_diff.decision = batch_decision
other_batches = [b for b in state.batches if b is not batch_diff]
handle_ignore_skip(batch_diff, batch_decision, other_batches)
else:
batch_decision = SyncDecision(batch_decision)

batch_diff.decision = batch_decision

other_batches = [b for b in state.batches if b is not batch_diff]
handle_ignore_skip(batch_diff, batch_decision, other_batches)

if batch_decision not in [SyncDecision.skip, SyncDecision.ignore]:
sync_instructions = get_sync_instructions_for_batch_items_for_add(
batch_diff,
Expand All @@ -100,6 +107,13 @@ def resolve(
sync_instructions = []
if batch_decision == SyncDecision.ignore:
resolved_state_high.add_skipped_ignored(batch_diff)
resolved_state_low.add_skipped_ignored(batch_diff)

if (
batch_diff.root_id in previously_ignored_batches
and batch_diff.decision != SyncDecision.ignore
):
sync_instruction.unignore = True

print(f"Decision: Syncing {len(sync_instructions)} objects")

Expand Down
51 changes: 50 additions & 1 deletion packages/syft/src/syft/service/sync/diff_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

# relative
from ...client.sync_decision import SyncDecision
from ...types.syft_object import SYFT_OBJECT_VERSION_1
from ...types.syft_object import SYFT_OBJECT_VERSION_2
from ...types.syft_object import SyftObject
from ...types.syft_object import short_uid
Expand Down Expand Up @@ -744,6 +745,33 @@ def _hierarchy_str_recursive(tree: dict, level: int) -> str:
{res}"""


class IgnoredBatchView(SyftObject):
    """View of a single ignored batch plus the other batches of the same diff.

    Instances are built by ``NodeDiff.ignored_changes``. Rendering is delegated
    to the wrapped batch; ``stage_change`` clears the ignore decision so the
    batch (and the ignored batches it depends on) can be reviewed again.
    """

    __canonical_name__ = "IgnoredBatchView"
    __version__ = SYFT_OBJECT_VERSION_1
    # The ignored batch this view wraps.
    batch: ObjectDiffBatch
    # Every other batch of the diff; scanned for ignored dependencies of `batch`.
    other_batches: list[ObjectDiffBatch]

    def _coll_repr_(self) -> str:
        """Return the collection representation of the wrapped batch."""
        return self.batch._coll_repr_()

    def _repr_html_(self) -> str:
        """Return the HTML representation of the wrapped batch."""
        return self.batch._repr_html_()

    def stage_change(self) -> None:
        """Unignore the wrapped batch and the ignored batches it depends on.

        Resets ``decision`` to ``None`` on the wrapped batch and on every
        batch in ``other_batches`` that is currently ignored and whose root id
        is among the wrapped batch's dependencies.
        """
        self.batch.decision = None
        required_dependencies = {
            d.object_id for d in self.batch.get_dependencies(include_roots=True)
        }

        for other_batch in self.other_batches:
            if (
                other_batch.decision == SyncDecision.ignore
                and other_batch.root_id in required_dependencies
            ):
                # The decision is cleared here, so the batch is being
                # *un*ignored; the message must say so.
                print(f"unignoring other batch ({other_batch.root_type.__name__})")
                other_batch.decision = None


class NodeDiff(SyftObject):
__canonical_name__ = "NodeDiff"
__version__ = SYFT_OBJECT_VERSION_2
Expand All @@ -756,6 +784,17 @@ class NodeDiff(SyftObject):
low_state: SyncState
high_state: SyncState

@property
def ignored_changes(self):
    """Return one ``IgnoredBatchView`` per batch whose decision is ``ignore``.

    Each view pairs the ignored batch with all remaining batches in
    ``self.batches``; the ignored batch itself is excluded by identity.
    """
    views = []
    for candidate in self.batches:
        if candidate.decision != SyncDecision.ignore:
            continue
        # Everything except the ignored batch itself (identity, not equality).
        remaining = [b for b in self.batches if b is not candidate]
        views.append(IgnoredBatchView(batch=candidate, other_batches=remaining))
    return views

@classmethod
def from_sync_state(
cls: type["NodeDiff"],
Expand Down Expand Up @@ -826,12 +865,15 @@ def apply_previous_ignore_state(
if hash(batch) == batch_hash:
batch.decision = SyncDecision.ignore
else:
print(f"""A batch with type {batch.root_type.__name__} was previously ignored but has changed
It will be available for review again.""")
# batch has changed, so unignore
batch.decision = None
# then we also set the dependent batches to unignore
# currently we don't do this recursively
required_dependencies = {
d for deps in batch.dependencies.values() for d in deps
d.object_id
for d in batch.get_dependencies(include_roots=True)
}

for other_batch in batches:
Expand Down Expand Up @@ -949,6 +991,7 @@ class SyncInstruction(SyftObject):
new_permissions_lowside: list[ActionObjectPermission]
new_storage_permissions_lowside: list[StoragePermission]
new_storage_permissions_highside: list[StoragePermission]
unignore: bool = False
mockify: bool


Expand All @@ -963,6 +1006,9 @@ class ResolvedSyncState(SyftObject):
new_permissions: list[ActionObjectPermission] = []
new_storage_permissions: list[StoragePermission] = []
ignored_batches: dict[UID, int] = {} # batch root uid -> hash of the batch
unignored_batches: set[UID] = (
set()
) # NOTE: using '{}' as default value does not work here
alias: str

def add_skipped_ignored(self, batch: ObjectDiffBatch) -> None:
Expand All @@ -976,6 +1022,9 @@ def add_sync_instruction(self, sync_instruction: SyncInstruction) -> None:
return
diff = sync_instruction.diff

if sync_instruction.unignore:
self.unignored_batches.add(sync_instruction.diff.object_id)

if diff.status == "SAME":
return

Expand Down
49 changes: 38 additions & 11 deletions packages/syft/src/syft/service/sync/sync_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import cast

# third party
from result import Ok
from result import Result

# relative
Expand Down Expand Up @@ -172,6 +173,7 @@ def sync_items(
permissions: list[ActionObjectPermission],
storage_permissions: list[StoragePermission],
ignored_batches: dict[UID, int],
unignored_batches: set[UID],
) -> SyftSuccess | SyftError:
permissions_dict = defaultdict(list)
for permission in permissions:
Expand Down Expand Up @@ -201,10 +203,11 @@ def sync_items(
else:
return SyftError(message=f"Failed to sync {res.err()}")

new_state = self.build_current_state(context, ignored_batches)
if isinstance(new_state, SyftError):
return new_state
res = self.build_current_state(context, ignored_batches, unignored_batches)
if res.is_err():
return SyftError(message=res.message)
else:
new_state = res.ok()
res = self.stash.set(context.credentials, new_state)
if res.is_err():
return SyftError(message=res.message)
Expand Down Expand Up @@ -234,7 +237,7 @@ def get_permissions(

def get_all_syncable_items(
self, context: AuthedServiceContext
) -> list[SyncableSyftObject] | SyftError:
) -> Result[list[SyncableSyftObject], str]:
node = cast(AbstractNode, context.node)
all_items = []

Expand Down Expand Up @@ -269,27 +272,47 @@ def get_all_syncable_items(
context, uid, resolve_nested=False
) # type: ignore
if action_object.is_err():
return SyftError(message=action_object.err())
return action_object
all_items.append(action_object.ok())

return all_items
return Ok(all_items)

def build_current_state(
self,
context: AuthedServiceContext,
new_ignored_batches: dict[UID, int] | None = None,
) -> SyncState | SyftError:
new_unignored_batches: set[UID] | None = None,
) -> Result[SyncState, str]:
new_ignored_batches = (
new_ignored_batches if new_ignored_batches is not None else {}
)
objects = self.get_all_syncable_items(context)
new_unignored_batches = (
new_unignored_batches if new_unignored_batches is not None else {}
)
objects_res = self.get_all_syncable_items(context)
if objects_res.is_err():
return objects_res
else:
objects = objects_res.ok()
permissions, storage_permissions = self.get_permissions(context, objects)

previous_state = self.stash.get_latest(context=context)
if previous_state.is_err():
return SyftError(message=previous_state.err())
return previous_state
previous_state = previous_state.ok()

new_ignored_batches = {
**previous_ignored_batches,
**new_ignored_batches,
}
new_ignored_batches = {
k: v
for k, v in new_ignored_batches.items()
if k not in new_unignored_batches
}
new_state.ignored_batches = new_ignored_batches
print("ignored batches new", new_state.ignored_batches)

if previous_state is not None:
previous_state_link = LinkedObject.from_obj(
obj=previous_state,
Expand All @@ -316,12 +339,16 @@ def build_current_state(

new_state.add_objects(objects, context)

return new_state
return Ok(new_state)

@service_method(
path="sync._get_state",
name="_get_state",
roles=ADMIN_ROLE_LEVEL,
)
def _get_state(self, context: AuthedServiceContext) -> SyncState | SyftError:
return self.build_current_state(context)
res = self.build_current_state(context)
if res.is_err():
return SyftError(message=res.value)
else:
return res.ok()

0 comments on commit 8e433e4

Please sign in to comment.