In [1]:
# stdlib
from copy import deepcopy
from datetime import datetime
from typing import Any
from typing import ClassVar
from typing import Dict
from typing import List
from typing import Optional
from typing import Type

# third party
from pydantic import Field

# syft absolute
from syft import UID
from syft.types.base import SyftBaseModel

In [2]:
class MyBaseModel(SyftBaseModel):
    id: UID = Field(default_factory=lambda: UID())

In [3]:
class Event(MyBaseModel):
    creator: UID
    creation_date: datetime = Field(default_factory=lambda: datetime.now())

    def handler(self, node):
        method_name = event_handler_registry[self.__class__.__name__]
        return getattr(node, method_name)

In [4]:
class EventLog(MyBaseModel):
    log: List[Event] = []

In [5]:
class LinkedObject(MyBaseModel):
    node_id: UID
    obj_id: UID

In [6]:
class Dataset(MyBaseModel):
    real: LinkedObject
    mock: str
    description: str

In [7]:
class UserCode(MyBaseModel):
    code: str
    approved: bool = False

In [8]:
def register_event_handler(event_type):
    def inner(method):
        event_handler_registry[event_type.__name__] = method.__name__
        return method

    return inner

In [9]:
event_handler_registry = {}

In [10]:
# class CUDObjectEvent(Event):
#     object_type: ClassVar[Type]
# #     object_type: Optional[ClassVar[Type]

In [11]:
class CRUDEvent(Event):
    object_type: ClassVar[Type] = Type
    object_id: UID

    @property
    def merge_updates_repr(self):
        return f"{self.updates} for object {self.object_id} by {self.creator}"

In [12]:
class CreateObjectEvent(CRUDEvent):
    @property
    def updated_properties(self):
        return list(self.object_type.__annotations__.keys())

    @property
    def updates(self):
        return {p: getattr(self, p) for p in self.updated_properties}

    @property
    def update_tuples(self):
        return list(self.updates.items())


class UpdateObjectEvent(CRUDEvent):
    updates: Dict[str, Any]

    @property
    def updated_properties(self):
        return list(self.updates.keys())

    @property
    def update_tuples(self):
        return list(self.updates.items())

# Events

In [13]:
class CreateDatasetEvent(CreateObjectEvent):
    object_type: ClassVar[Type] = Dataset
    mock: Any
    real: LinkedObject
    description: str
    creator: UID

    def execute(self, node):
        handler = self.handler(node)
        handler(
            object_id=self.real.obj_id,
            mock=self.mock,
            real=self.real,
            description=self.description,
        )

In [14]:
class UpdateDatasetEvent(UpdateObjectEvent):
    object_type: ClassVar[Type] = Dataset
    object_id: UID

    def execute(self, node):
        handler = self.handler(node)
        handler(object_id=self.object_id, updates=self.updates)

In [15]:
class CreateUserCodeEvent(CreateObjectEvent):
    object_type: ClassVar[Type] = UserCode
    code: UserCode

    def execute(self, node):
        handler = self.handler(node)
        handler(code=self.code)

In [16]:
class ApproveUserCodeEvent(Event):
    object_type: ClassVar[Type] = UserCode
    code_id: UID
    value: bool

    def execute(self, node):
        handler = self.handler(node)
        handler(self.code_id, self.value)

# Node

In [17]:
class Node(MyBaseModel):
    event_log: EventLog = EventLog()
    store: Dict[UID, Any] = {}
    private_store: Dict[UID, Any] = {}

    def apply_log(self, log):
        self.store = {}
        self.event_log = deepcopy(log)

        for event in self.event_log.log:
            event.execute(self)

    def create_usercode(self, usercode: str):
        obj = UserCode(code=usercode)
        event = CreateUserCodeEvent(code=obj, object_id=obj.id, creator=self.id)
        self.event_log.log.append(event)

        self._create_usercode(obj)
        return obj.id

    @register_event_handler(CreateUserCodeEvent)
    def _create_usercode(self, code):
        self.store[code.id] = code

    def approve_usercode(self, code_id: UID, to: bool):
        event = ApproveUserCodeEvent(code_id=code_id, creator=self.id, value=to)
        self.event_log.log.append(event)
        self._approve_usercode(code_id, to)

    @register_event_handler(ApproveUserCodeEvent)
    def _approve_usercode(self, code_id, to):
        self.store[code_id].approved = to

    def create_dataset(self, mock: str, real: Optional[str], description: str):
        object_id = UID()
        real_id = UID()
        real_obj = LinkedObject(node_id=self.id, obj_id=object_id)

        self.private_store[real_id] = real

        event = CreateDatasetEvent(
            object_id=object_id,
            mock=mock,
            real=real_obj,
            description=description,
            creator=self.id,
        )

        self.event_log.log.append(event)
        self._create_dataset(object_id, mock, real_obj, description)

    @register_event_handler(CreateDatasetEvent)
    def _create_dataset(self, object_id, mock, real, description):
        dataset = Dataset(id=object_id, mock=mock, real=real, description=description)
        self.store[dataset.id] = dataset

    def update_dataset(self, id, updates):
        event = UpdateDatasetEvent(object_id=id, updates=updates, creator=self.id)
        self.event_log.log.append(event)
        self._update_dataset(id, updates)

    @register_event_handler(UpdateDatasetEvent)
    def _update_dataset(self, object_id, updates):
        dataset = self.store[object_id]

        for k, v in updates.items():
            setattr(dataset, k, v)

        self.store[object_id] = dataset

we want to check for 'mutations' of the same object, which is defined as:

- CUD (from CRUD) of objects with the same unique keys
  - create changes all attributes
  - delete changes all attributes
  - update only changes the attributes that were updated


In the case of update, if only non overlapping sets of properties were updated its not a merge conflict, as long as those are not code approval mutations.

# MergeState

In [18]:
class MergeState(SyftBaseModel):
    proposed_merge: List[Event]
    fork_idx: int
    new_log: List[Event] = []

    @property
    def current_merge_events(self):
        return self.new_log[self.fork_idx :]

    #     @property
    #     def updates_since_fork(self):
    #         updates_since_fork: Dict[UID, List[Event]] = defaultdict(list)
    #         # {node_id -> {obj_id -> Event}}
    #         for event in self.new_events:
    #             updates_since_fork[e.creator] += [event]
    #         return updates_since_fork

    def merge(self):
        self.new_log = self.proposed_merge[: self.fork_idx]
        for event in self.proposed_merge[self.fork_idx :]:
            if self.add_event(event):
                print("merge conflict")

    def request_input(self, event, conflicting_event):
        s = input(
            f"""
            {event.object_id} was changed by {event.creator} and {conflicting_event.creator}
            Change 0: {event.merge_updates_repr}
            Change 1: {conflicting_event.merge_updates_repr}
            Type 0/1 to keep the corresponding change
            """
        )
        idx = int(s)
        assert idx in [0, 1]
        return idx == 1

    def object_updates(self, object_id, exclude_node: UID):
        #         other_node_ids = [node_id for node_id in self.updates_since_fork.keys()
        #                           if node_id != event.creator]
        #         other_events_updating_object = [e for i in other_node_ids for e in self.updates_since_fork[i]
        #                                         if e.object_id == object_id]

        other_events_updating_object = [
            e
            for e in self.current_merge_events
            if e.object_id == object_id and e.creator != exclude_node
        ]

        object_updates = {}

        for e in other_events_updating_object:
            for p in e.updated_properties:
                val = e.updates[p]
                object_updates[p] = (val, e)

        return object_updates

    def add_event(self, event):
        merge_object_updates = self.object_updates(
            event.object_id, exclude_node=event.creator
        )
        # we want to find all the events from other nodes that updated the same object
        # then we want to find which properties they updated and to what value
        # if they updated the same property to a different value => merge conflict
        # (property, value) => event

        skip_current_event = False
        for prop, val in event.updates.items():
            if skip_current_event:
                continue
            # val -> event
            if prop not in merge_object_updates:
                continue

            other_val, other_event = merge_object_updates[prop]
            if other_val != val:
                conflicting_event = other_event
                skip_current_event = self.request_input(event, conflicting_event)
                skip_conflicting_event = not skip_current_event

                # merge strategies:
                # accept entire event, reject other event entirely
                # cherry pick per property

                if skip_conflicting_event:
                    print("skip conflicting event")
                    # remove conflicting event from new_log
                    self.new_log = [
                        e for e in self.new_log if e.id != conflicting_event.id
                    ]

        if not skip_current_event:
            self.new_log += [event]

# Sync

In [19]:
def sync(node_high, now_low):
    log1 = node_high.event_log.log
    log2 = node_low.event_log.log

    # find idx of the fork
    fork_idx = max(len(log1), len(log2))
    for i, (e1, e2) in enumerate(list(zip(log1, log2))):
        if e1.id != e2.id:
            fork_idx = i
            break

    branch1 = log1[fork_idx:]
    branch2 = log2[fork_idx:]

    proposed_merge = log1[:fork_idx] + sorted(
        branch1 + branch2, key=lambda e: e.creation_date
    )
    print(f"proposed merge (before merging): {proposed_merge}")
    merge_state = MergeState(fork_idx=fork_idx, proposed_merge=proposed_merge)
    merge_state.merge()

    new_log = EventLog(log=merge_state.new_log)

    node_low.apply_log(new_log)
    node_high.apply_log(new_log)

    assert all(
        [x == y for x, y in zip(node_low.event_log.log, node_high.event_log.log)]
    ) and len(node_low.event_log.log) == len(node_high.event_log.log)

# Sync 1: create dataset and sync

In [20]:
node_high = Node()
node_low = Node()

In [21]:
node_high.create_dataset(real="abc", mock="def", description="blabla")

In [22]:
node_high.event_log.log

In [23]:
sync(node_high, node_low)

proposed merge (before merging): [CreateDatasetEvent(id=<UID: 82f8e63aa515456483f6de9cf6eda223>, creator=<UID: 5905d8a7d328416c81cd114f8cc0f060>, creation_date=datetime.datetime(2024, 1, 23, 13, 0, 35, 140729), object_id=<UID: 8031ca8c83da4217bdc7e2f51fdb0dca>, mock='def', real=LinkedObject(id=<UID: c36ca6511ccf42b094836640df93b1a5>, node_id=<UID: 5905d8a7d328416c81cd114f8cc0f060>, obj_id=<UID: 8031ca8c83da4217bdc7e2f51fdb0dca>), description='blabla')]


In [24]:
node_high.event_log.log

In [25]:
assert node_high.store.keys() == node_low.store.keys()

# Sync 2: both update same property to same value

In [26]:
dataset = list(node_high.store.values())[0]

In [27]:
node_high.update_dataset(dataset.id, {"description": "a"})

In [28]:
node_high.event_log.log

In [29]:
node_low.update_dataset(dataset.id, {"description": "a"})

In [30]:
node_high.event_log.log

In [31]:
# node_low.event_log.log

In [32]:
sync(node_high, node_low)

proposed merge (before merging): [CreateDatasetEvent(id=<UID: 82f8e63aa515456483f6de9cf6eda223>, creator=<UID: 5905d8a7d328416c81cd114f8cc0f060>, creation_date=datetime.datetime(2024, 1, 23, 13, 0, 35, 140729), object_id=<UID: 8031ca8c83da4217bdc7e2f51fdb0dca>, mock='def', real=LinkedObject(id=<UID: c36ca6511ccf42b094836640df93b1a5>, node_id=<UID: 5905d8a7d328416c81cd114f8cc0f060>, obj_id=<UID: 8031ca8c83da4217bdc7e2f51fdb0dca>), description='blabla'), UpdateDatasetEvent(id=<UID: abbf245d31ce44b487a505e379c4645e>, creator=<UID: 5905d8a7d328416c81cd114f8cc0f060>, creation_date=datetime.datetime(2024, 1, 23, 13, 0, 35, 172952), object_id=<UID: 8031ca8c83da4217bdc7e2f51fdb0dca>, updates={'description': 'a'}), UpdateDatasetEvent(id=<UID: 6b698e9fccd04f2499a0bcdb588ff49d>, creator=<UID: 9a1f1170f50546a68f11734b9e60b534>, creation_date=datetime.datetime(2024, 1, 23, 13, 0, 35, 182961), object_id=<UID: 8031ca8c83da4217bdc7e2f51fdb0dca>, updates={'description': 'a'})]


In [33]:
dataset_high = list(node_high.store.values())[0]
dataset_low = list(node_low.store.values())[0]

In [34]:
assert dataset_high.description == dataset_low.description

In [35]:
node_high.event_log.log

In [36]:
node_low.event_log.log

In [37]:
# we keep both events
assert len(node_high.event_log.log) == 3 and len(node_low.event_log.log) == 3

# Sync 3: both update same property to different value

In [38]:
dataset = list(node_high.store.values())[0]

In [39]:
# node_low.event_log.log

In [40]:
if False:
    node_high.update_dataset(dataset.id, {"description": "b"})

    node_high.event_log.log

    node_low.update_dataset(dataset.id, {"description": "c"})

    node_high.event_log.log

    sync(node_high, node_low)

    dataset_high = list(node_high.store.values())[0]
    dataset_low = list(node_low.store.values())[0]

    assert dataset_high.description == dataset_low.description

    node_high.event_log.log

    assert len(node_high.event_log.log) == 4 and len(node_low.event_log.log) == 4

# Sync 4: UserCode

In [41]:
user_code_id = node_low.create_usercode("print('a')")

In [42]:
sync(node_low, node_high)

proposed merge (before merging): [CreateDatasetEvent(id=<UID: 82f8e63aa515456483f6de9cf6eda223>, creator=<UID: 5905d8a7d328416c81cd114f8cc0f060>, creation_date=datetime.datetime(2024, 1, 23, 13, 0, 35, 140729), object_id=<UID: 8031ca8c83da4217bdc7e2f51fdb0dca>, mock='def', real=LinkedObject(id=<UID: c36ca6511ccf42b094836640df93b1a5>, node_id=<UID: 5905d8a7d328416c81cd114f8cc0f060>, obj_id=<UID: 8031ca8c83da4217bdc7e2f51fdb0dca>), description='blabla'), UpdateDatasetEvent(id=<UID: abbf245d31ce44b487a505e379c4645e>, creator=<UID: 5905d8a7d328416c81cd114f8cc0f060>, creation_date=datetime.datetime(2024, 1, 23, 13, 0, 35, 172952), object_id=<UID: 8031ca8c83da4217bdc7e2f51fdb0dca>, updates={'description': 'a'}), UpdateDatasetEvent(id=<UID: 6b698e9fccd04f2499a0bcdb588ff49d>, creator=<UID: 9a1f1170f50546a68f11734b9e60b534>, creation_date=datetime.datetime(2024, 1, 23, 13, 0, 35, 182961), object_id=<UID: 8031ca8c83da4217bdc7e2f51fdb0dca>, updates={'description': 'a'}), CreateUserCodeEvent(id=<U

# Sync 4: Approve UserCode

In [43]:
node_high.approve_usercode(user_code_id, True)

KeyError: <UID: 315d9b9b704b46b2bdf5379f389de967>

In [None]:
# TODO: is this result valid?
sync(node_low, node_high)

In [None]:
node_low.event_log.log

In [None]:
node_high.event_log.log

# Scenario list


- create a dataset and sync
  - should create the dataset object on both sides
- both update the same property (conflict)
- both update a different property (no conflict)
- code approval should have same state
- code execution should be approved