Skip to content

Commit

Permalink
Merge pull request #8634 from OpenMined/aziz/empty_highside
Browse files Browse the repository at this point in the history
fix always sharing to high side
  • Loading branch information
abyesilyurt committed Mar 27, 2024
2 parents 938c1b4 + a9a24ba commit eaca373
Show file tree
Hide file tree
Showing 3 changed files with 130 additions and 16 deletions.
7 changes: 5 additions & 2 deletions packages/syft/src/syft/client/syncing.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,8 +216,11 @@ def get_sync_instructions_for_batch_items_for_add(
StoragePermission(uid=diff.object_id, node_uid=diff.low_node_uid)
]

# Always share to high_side
if diff.status == "NEW" and diff.high_obj is None:
if (
diff.status == "NEW"
and diff.high_obj is None
and decision == SyncDecision.low
):
new_storage_permissions_highside = [
StoragePermission(uid=diff.object_id, node_uid=diff.high_node_uid)
]
Expand Down
16 changes: 2 additions & 14 deletions packages/syft/src/syft/service/sync/diff_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import textwrap
from typing import Any
from typing import ClassVar
from typing import Literal

# third party
from pydantic import model_validator
Expand Down Expand Up @@ -246,7 +247,7 @@ def __hash__(self) -> int:
return hash(self.object_id) + hash(self.low_obj) + hash(self.high_obj)

@property
def status(self) -> str:
def status(self) -> Literal["NEW", "SAME", "DIFF"]:
if self.low_obj is None or self.high_obj is None:
return "NEW"
if len(self.diff_list) == 0:
Expand Down Expand Up @@ -448,19 +449,6 @@ def _wrap_text(text: str, width: int, indent: int = 4) -> str:
)


def _get_hierarchy_root(
diffs: list[ObjectDiff], dependencies: dict[UID, list[UID]]
) -> list[ObjectDiff]:
all_ids = {diff.object_id for diff in diffs}
child_ids = set()
for uid in all_ids:
child_ids.update(dependencies.get(uid, []))
# Root ids are object ids with no parent
root_ids = list(all_ids - child_ids)
roots = [diff for diff in diffs if diff.object_id in root_ids]
return roots


class ObjectDiffBatch(SyftObject):
__canonical_name__ = "DiffHierarchy"
__version__ = SYFT_OBJECT_VERSION_2
Expand Down
123 changes: 123 additions & 0 deletions packages/syft/tests/syft/service/sync/sync_flow_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,129 @@ def compute() -> int:
assert high_items_to_sync.is_empty


def test_request_code_execution_multiple(worker, second_worker):
low_client = worker.root_client
client_low_ds = worker.guest_client
high_client = second_worker.root_client

@sy.syft_function_single_use()
def compute() -> int:
return 42

compute.code = dedent(compute.code)

@sy.syft_function_single_use()
def compute_twice() -> int:
return 42 * 2

compute_twice.code = dedent(compute_twice.code)

@sy.syft_function_single_use()
def compute_thrice() -> int:
return 42 * 3

compute_thrice.code = dedent(compute_thrice.code)

_ = client_low_ds.code.request_code_execution(compute)
_ = client_low_ds.code.request_code_execution(compute_twice)

diff_state = compare_clients(low_client, high_client)
low_items_to_sync, high_items_to_sync = resolve(
diff_state, decision="low", share_private_objects=True
)

assert not diff_state.is_same
assert len(diff_state.diffs) % 2 == 0
assert not low_items_to_sync.is_empty
assert not high_items_to_sync.is_empty

low_client.apply_state(low_items_to_sync)
high_client.apply_state(high_items_to_sync)

_ = client_low_ds.code.request_code_execution(compute_thrice)

diff_state = compare_clients(low_client, high_client)
low_items_to_sync, high_items_to_sync = resolve(
diff_state, decision="low", share_private_objects=True
)

assert not diff_state.is_same
assert len(diff_state.diffs) % 3 == 0
assert not low_items_to_sync.is_empty
assert not high_items_to_sync.is_empty


def test_sync_high(worker, second_worker):
low_client = worker.root_client
client_low_ds = worker.guest_client
high_client = second_worker.root_client

@sy.syft_function_single_use()
def compute() -> int:
return 42

compute.code = dedent(compute.code)

_ = client_low_ds.code.request_code_execution(compute)

diff_state = compare_clients(low_client, high_client)
low_items_to_sync, high_items_to_sync = resolve(
diff_state,
decision="high",
)

assert not diff_state.is_same
assert not low_items_to_sync.is_empty
assert high_items_to_sync.is_empty


@pytest.mark.parametrize(
"decision",
["skip", "ignore"],
)
def test_sync_skip_ignore(worker, second_worker, decision):
low_client = worker.root_client
client_low_ds = worker.guest_client
high_client = second_worker.root_client

@sy.syft_function_single_use()
def compute() -> int:
return 42

compute.code = dedent(compute.code)

_ = client_low_ds.code.request_code_execution(compute)

diff_state = compare_clients(low_client, high_client)
low_items_to_sync, high_items_to_sync = resolve(
diff_state,
decision=decision,
)

assert not diff_state.is_same
assert low_items_to_sync.is_empty
assert high_items_to_sync.is_empty


@pytest.mark.parametrize(
"decision",
["skip", "ignore", "low", "high"],
)
def test_sync_empty(worker, second_worker, decision):
low_client = worker.root_client
high_client = second_worker.root_client

diff_state = compare_clients(low_client, high_client)
low_items_to_sync, high_items_to_sync = resolve(
diff_state,
decision=decision,
)

assert diff_state.is_same
assert low_items_to_sync.is_empty
assert high_items_to_sync.is_empty


@pytest.mark.skipif(sys.platform == "win32", reason="does not run on windows")
@pytest.mark.flaky(reruns=3, reruns_delay=3)
def test_sync_flow_no_sharing():
Expand Down

0 comments on commit eaca373

Please sign in to comment.