# Action Graph 

### Goals:

- A graph store/database to store and trace any computations during eager execution
- A graph that works with the current in-memory worker
- Ability to visualize the graph
- Generate a dependency list of nodes, so that any dependent action can be generated
- Basic query/search functionalities
- Locking/Concurrency
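
As a rough illustration of what these goals ask for, here is a minimal sketch of an in-memory directed graph store guarded by a single `threading.Lock`. The class and method names (`InMemoryGraphSketch`, `add_node`, `add_edge`, `is_parent`) are hypothetical and are not Syft's `InMemoryActionGraphStore` API; a real store would likely need richer node data and finer-grained locking.

```python
import threading
from collections import defaultdict


class InMemoryGraphSketch:
    """Hypothetical thread-safe directed graph store (not Syft's API)."""

    def __init__(self):
        self._lock = threading.Lock()        # serializes every read and write
        self._nodes = {}                     # node id -> attribute dict
        self._children = defaultdict(set)    # node id -> ids of its children

    def add_node(self, node_id, **attrs):
        with self._lock:
            self._nodes[node_id] = dict(attrs)

    def add_edge(self, parent, child):
        with self._lock:
            if parent not in self._nodes or child not in self._nodes:
                raise KeyError("add both endpoints before adding the edge")
            self._children[parent].add(child)

    def is_parent(self, parent, child):
        with self._lock:
            return child in self._children.get(parent, ())

    def nodes(self):
        with self._lock:
            return list(self._nodes)

    def edges(self):
        with self._lock:
            return [(p, c) for p, kids in self._children.items() for c in kids]


# Several threads add nodes concurrently; the lock keeps the store consistent.
store = InMemoryGraphSketch()
workers = [
    threading.Thread(target=store.add_node, args=(f"obj{i}",)) for i in range(8)
]
for t in workers:
    t.start()
for t in workers:
    t.join()

store.add_edge("obj0", "obj1")
print(len(store.nodes()), store.is_parent("obj0", "obj1"))  # 8 True
```

A single coarse lock is the simplest correct choice for a sketch; it trades throughput for the guarantee that no reader ever sees a half-applied edge.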

In [None]:
import syft as sy
from syft.service.action.action_graph_service import ActionGraphService, NodeActionDataUpdate, ExecutionStatus
from syft.service.action.action_graph import InMemoryActionGraphStore, InMemoryGraphConfig, InMemoryStoreClientConfig
from syft.service.context import AuthedServiceContext
from syft.node.credentials import SyftSigningKey
from syft.service.action.action_graph import Action
from syft.service.action.numpy import NumpyArrayObject, ActionObject

import numpy as np
import matplotlib.pyplot as plt
import networkx as nx

## Scenario for performing some computation

```python

import syft as sy

domain_client = sy.login("....")

dataset = domain_client.datasets[0]

a = dataset.assets["A"]

b = dataset.assets["B"]

c = a + b

d = domain_client.api.numpy.array([1, 2, 3])

e = c * d

# Treated here as an in-place mutation of d (plain NumPy astype would return a copy)
d.astype('int32')

d[2] = 5

f = d + 48
```

<br>
<br>

**Corresponding Actions Generated**

```

action1 -> a + b

action2 -> initialization of variable `d`

action3 -> c * d

action4 -> in-place update of the dtype of `d` (d.astype('int32'))

action5 -> d[2] = 5 (__setitem__)

action6 -> d + 48

```
<br>
<br>

**The graph has two types of nodes: `action_object_node` and `action_node`. The corresponding nodes generated in the Action Graph are:**

```
node1  -> action_object_node(a)
node2  -> action_object_node(b)
node3  -> action_node for the add action (action1)
node4  -> action_object_node(c) - automatically generated
node5  -> action_object_node([1,2,3])
node6  -> action_node for the np.array action (action2)
node7  -> action_object_node(d) - automatically generated
node8  -> action_node for the multiply action (action3)
node9  -> action_object_node(e)
node10 -> action_object_node(int32)
node11 -> action_node for the astype action (action4)
node12 -> action_object_node(index=2)
node13 -> action_object_node(value=5)
node14 -> action_node for the __setitem__ action (action5)
node15 -> action_object_node(48)
node16 -> action_node for the add action (action6)
node17 -> action_object_node(f)
```
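
To make the node and edge bookkeeping concrete, here is a minimal pure-Python sketch that replays the six actions above into a bipartite graph of object nodes and action nodes. It assumes two rules: an in-place action reuses its input object as `result_id` and creates no new object node, and any later consumer of a mutated object is linked to that object's most recent mutating action node (the notebook checks this for Action6 below). `SketchAction`, `build_graph` and the node names are hypothetical, not Syft identifiers.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class SketchAction:
    """Hypothetical stand-in for a Syft Action."""
    name: str                   # unique id of the action node
    remote_self: Optional[str]  # object the op is called on (None for np.array)
    args: List[str]             # ids of the argument objects
    result_id: str              # equals remote_self for an in-place op


def build_graph(objects, actions):
    nodes = set(objects)        # pre-existing action object nodes
    edges = set()               # (parent, child) pairs
    last_mutagen = {}           # object id -> latest action that mutated it

    for act in actions:
        nodes.add(act.name)
        inputs = ([act.remote_self] if act.remote_self else []) + act.args
        for obj in inputs:
            # Consumers of a mutated object hang off its most recent mutagen.
            edges.add((last_mutagen.get(obj, obj), act.name))
        if act.result_id == act.remote_self:
            # In-place op: no new object node, just extend the mutation chain.
            last_mutagen[act.result_id] = act.name
        else:
            nodes.add(act.result_id)
            edges.add((act.name, act.result_id))
    return nodes, edges


objects = ["A", "B", "lit_123", "int32", "idx", "val", "48"]
actions = [
    SketchAction("add1", "A", ["B"], "C"),              # action1: c = a + b
    SketchAction("array", None, ["lit_123"], "D"),      # action2: d = np.array(...)
    SketchAction("mul", "C", ["D"], "E"),               # action3: e = c * d
    SketchAction("astype", "D", ["int32"], "D"),        # action4: in-place
    SketchAction("setitem", "D", ["idx", "val"], "D"),  # action5: in-place
    SketchAction("add2", "D", ["48"], "F"),             # action6: f = d + 48
]
nodes, edges = build_graph(objects, actions)
print(len(nodes), len(edges))  # 17 16
```

These two rules reproduce the 17 nodes and 16 edges that the notebook asserts after Action6.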

In [None]:
node = sy.orchestra.launch(name="test-domain-1", dev_mode=True, reset=True)

In [None]:
domain_client = node.login(email="info@openmined.org", password="changethis")
domain_client

### Initializing the Store

In [None]:
# Create a Config
store_config = InMemoryGraphConfig()
# Initialize the InMemory Store
graph_store = InMemoryActionGraphStore(store_config=store_config, reset=True)
# Get the networkx graph
G = graph_store.graph.db

### Initializing Action Graph Service

In [None]:
action_graph_service = ActionGraphService(store=graph_store)

In [None]:
signing_key = SyftSigningKey.generate()
authed_context = AuthedServiceContext(node=node.python_node, credentials=signing_key.verify_key)

In [None]:
authed_context.node

### Create some dummy data

In [None]:
labels_dict = {}

In [None]:
action_obj_a = ActionObject.from_obj([2, 4, 6])
action_obj_b = ActionObject.from_obj([2, 3, 4])

In [None]:
labels_dict[action_obj_a.id] = "A"
labels_dict[action_obj_b.id] = "B"

In [None]:
action_obj_a.id, action_obj_b.id

In [None]:
action_graph_service.add_action_obj(context=authed_context, action_obj=action_obj_a)

In [None]:
action_graph_service.add_action_obj(context=authed_context, action_obj=action_obj_b)

In [None]:
assert len(action_graph_service.get_all_nodes(authed_context)) == 2
assert len(action_graph_service.get_all_edges(authed_context)) == 0

### Action1 -> A + B

In [None]:
action1 = Action(
    path="action.execute",
    op="__add__",
    remote_self=action_obj_a.syft_lineage_id,
    args=[action_obj_b.syft_lineage_id],
    kwargs={}
)
action1

In [None]:
labels_dict[action1.id] = "+"
labels_dict[action1.result_id.id] = "C"

#### Add action1 to the Graph

In [None]:
action_graph_service.add_action(context=authed_context, action=action1)

In [None]:
plt.figure(figsize=(20, 10))
pos = nx.spring_layout(G, seed=3113794652)
# nx.draw_networkx_nodes(G, pos=pos)
nx.draw_networkx(G, pos=pos, labels=labels_dict, with_labels=True, 
                 width=2.0, node_color="orange", node_size=800, font_size=22)

In [None]:
action_graph_service.store.graph.visualize()

In [None]:
assert len(action_graph_service.get_all_nodes(authed_context)) == 4
assert len(action_graph_service.get_all_edges(authed_context)) == 3

### Action2 -> np.array([1, 2, 3])

In [None]:
action_obj_d = ActionObject.from_obj([1, 2, 3])

In [None]:
labels_dict[action_obj_d.id] = "[1, 2, 3]"

In [None]:
action_graph_service.add_action_obj(context=authed_context, action_obj=action_obj_d)

In [None]:
# Create Action2

action2 = Action(
    path="action.execute",
    op="np.array",
    remote_self=None,
    args=[action_obj_d.syft_lineage_id],
    kwargs={}
)
action2

In [None]:
labels_dict[action2.id] = "np.array"
labels_dict[action2.result_id.id] = "D"

In [None]:
# Save action to graph
np_array_node, d_node = action_graph_service.add_action(context=authed_context, action=action2)

In [None]:
plt.figure(figsize=(20, 10))
pos = nx.spring_layout(G, seed=3113794652)
# nx.draw_networkx_nodes(G, pos=pos)
nx.draw_networkx(G, pos=pos, labels=labels_dict, with_labels=True, width=2.0, 
                 node_color="orange", node_size=800, font_size=22)

In [None]:
action_graph_service.store.graph.visualize()

In [None]:
assert len(action_graph_service.get_all_nodes(authed_context)) == 7
assert len(action_graph_service.get_all_edges(authed_context)) == 5

### Action3 -> C * D

In [None]:
action3 = Action(
    path="action.execute",
    op="__mul__",
    remote_self=action1.result_id,
    args=[action2.result_id],
    kwargs={}
)
action3

In [None]:
mul_action_node, _ = action_graph_service.add_action(context=authed_context, action=action3)

In [None]:
labels_dict[action3.id] = "*"
labels_dict[action3.result_id.id] = "E"

In [None]:
plt.figure(figsize=(20, 10))
pos = nx.spring_layout(G, seed=3113794652)
# nx.draw_networkx_nodes(G, pos=pos)
nx.draw_networkx(G, pos=pos, labels=labels_dict, with_labels=True, width=2.0, 
                 node_color="orange", node_size=800, font_size=22)

In [None]:
assert len(action_graph_service.get_all_nodes(authed_context)) == 9
assert len(action_graph_service.get_all_edges(authed_context)) == 8

Check that the `*` action node is a child of `D`:

In [None]:
assert action_graph_service.store.is_parent(parent=d_node.id, child=mul_action_node.id).ok() == True

### Action4 -> Mutate type of D

Let's look at `d_node` before the mutation:

In [None]:
d_node

Now mutate it:

In [None]:
as_type_action_obj = ActionObject.from_obj('np.int32')

In [None]:
action_graph_service.add_action_obj(context=authed_context, action_obj=as_type_action_obj)

In [None]:
labels_dict[as_type_action_obj.id] = "np.int32"

In [None]:
action4 = Action(
    path="action.execute",
    op="astype",
    remote_self=action2.result_id,
    args=[as_type_action_obj.syft_lineage_id],
    kwargs={},
    result_id=action2.result_id
)
action4

In [None]:
astype_node, _ = action_graph_service.add_action(context=authed_context, action=action4)

In [None]:
labels_dict[action4.id] = "astype"

In [None]:
plt.figure(figsize=(20, 10))
pos = nx.spring_layout(G, seed=3113794652)
# nx.draw_networkx_nodes(G, pos=pos)
nx.draw_networkx(G, pos=pos, labels=labels_dict, with_labels=True, 
                 width=2.0, node_color="orange", node_size=800, font_size=22)

In [None]:
assert len(action_graph_service.get_all_nodes(authed_context)) == 11
assert len(action_graph_service.get_all_edges(authed_context)) == 10

`d_node` is updated after the mutation:

In [None]:
d_node

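`is_mutagen` indicates that a node causes a mutation, and `is_mutated` indicates that a node has been mutated (by a mutagen node). On a mutated node, `next_mutagen_node` points to the mutagen node that directly follows it in the mutation chain, and `last_nm_mutagen_node` points to the last mutagen node in that chain.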

In [None]:
assert d_node.is_mutated == True
assert astype_node.is_mutagen == True
assert d_node.next_mutagen_node == astype_node.id
assert d_node.last_nm_mutagen_node == astype_node.id

### Action5 -> D[2] = 5

Another mutation of node `D`:

In [None]:
idx_action_obj = ActionObject.from_obj(2)
action_graph_service.add_action_obj(context=authed_context, action_obj=idx_action_obj)

In [None]:
labels_dict[idx_action_obj.id] = "idx=2"

In [None]:
item_val_action_obj = ActionObject.from_obj(5)
action_graph_service.add_action_obj(context=authed_context, action_obj=item_val_action_obj)

In [None]:
labels_dict[item_val_action_obj.id] = "val=5"

In [None]:
action5 = Action(
    path="action.execute",
    op="__setitem__",
    remote_self=action2.result_id,
    args=[idx_action_obj.syft_lineage_id, item_val_action_obj.syft_lineage_id],
    kwargs={},
    result_id=action2.result_id
)
action5

In [None]:
set_item_node, _ = action_graph_service.add_action(context=authed_context, action=action5)

In [None]:
labels_dict[action5.id] = "__setitem__"

In [None]:
plt.figure(figsize=(35, 20))
pos = nx.spring_layout(G, seed=3113794652)
# nx.draw_networkx_nodes(G, pos=pos)
nx.draw_networkx(G, pos=pos, labels=labels_dict, with_labels=True, width=2.0, 
                 node_color="orange", node_size=800, font_size=22)

In [None]:
assert len(action_graph_service.get_all_nodes(authed_context)) == 14
assert len(action_graph_service.get_all_edges(authed_context)) == 13

Let's look at `d_node` after the second mutation:

In [None]:
d_node

After the second mutation, the `last_nm_mutagen_node` of `d_node` points to the `__setitem__` node, while `next_mutagen_node` still points to the `astype` node.

In [None]:
assert d_node.is_mutated == True
assert set_item_node.is_mutagen == True
assert d_node.next_mutagen_node == astype_node.id
assert d_node.last_nm_mutagen_node == set_item_node.id
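
The mutation bookkeeping checked by these assertions can be sketched with a small dataclass. The field names mirror the ones asserted above, but `SketchNode` and `record_mutation` are hypothetical helpers, and the forward link from the previous tail (`astype`) to the new mutagen (`setitem`) is an assumption that the notebook does not assert.

```python
from dataclasses import dataclass
from typing import Dict, Optional


@dataclass
class SketchNode:
    id: str
    is_mutated: bool = False                    # this object was changed in place
    is_mutagen: bool = False                    # this action changed an object
    next_mutagen_node: Optional[str] = None     # next link in the mutation chain
    last_nm_mutagen_node: Optional[str] = None  # tail of the chain (objects only)


def record_mutation(target: SketchNode, mutagen: SketchNode,
                    nodes: Dict[str, SketchNode]) -> None:
    mutagen.is_mutagen = True
    if not target.is_mutated:
        # First mutation: the object points straight at its mutagen.
        target.is_mutated = True
        target.next_mutagen_node = mutagen.id
    else:
        # Assumption: the previous tail links forward to the new mutagen.
        nodes[target.last_nm_mutagen_node].next_mutagen_node = mutagen.id
    target.last_nm_mutagen_node = mutagen.id


nodes = {name: SketchNode(name) for name in ("D", "astype", "setitem")}
record_mutation(nodes["D"], nodes["astype"], nodes)
record_mutation(nodes["D"], nodes["setitem"], nodes)
d = nodes["D"]
print(d.next_mutagen_node, d.last_nm_mutagen_node)  # astype setitem
```

Keeping both a head pointer and a tail pointer on the mutated object means a new consumer can attach to the latest state without walking the whole chain.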


### Action6 -> D + 48

In [None]:
arg_action_obj = ActionObject.from_obj(48)

In [None]:
action_graph_service.add_action_obj(context=authed_context, action_obj=arg_action_obj)

In [None]:
labels_dict[arg_action_obj.id] = "48"

In [None]:
action6 = Action(
    path="action.execute",
    op="__add__",
    remote_self=action2.result_id,
    args=[arg_action_obj.syft_lineage_id],
    kwargs={},
)
action6

In [None]:
action6_node, f_node = action_graph_service.add_action(context=authed_context, action=action6)

In [None]:
labels_dict[action6.id] = "+"
labels_dict[action6.result_id.id] = "F"

In [None]:
plt.figure(figsize=(60, 40))
pos = nx.spring_layout(G, seed=3113794652)
nx.draw_networkx(G, pos=pos, labels=labels_dict, with_labels=True, 
                 width=3.0, node_color="orange", node_size=3000, font_size=30)

In [None]:
assert len(action_graph_service.get_all_nodes(authed_context)) == 17
assert len(action_graph_service.get_all_edges(authed_context)) == 16

The final add action node is a child of the `__setitem__` node, because `__setitem__` is the last mutation of `D`.

In [None]:
assert action_graph_service.store.is_parent(parent=set_item_node.id, child=action6_node.id).ok() == True
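
With the graph complete, the goal listed at the top of generating a dependency list can be illustrated. A minimal sketch, assuming the graph is available as `(parent, child)` edge pairs with illustrative node names (not Syft IDs): walk back from a target node to collect its ancestors, then order them with the standard library's `graphlib.TopologicalSorter` so that every node appears after everything it depends on. `dependency_list` is a hypothetical helper, not part of `ActionGraphService`.

```python
from graphlib import TopologicalSorter

# (parent, child) edges of the final graph; names are illustrative, not Syft IDs.
EDGES = [
    ("A", "add1"), ("B", "add1"), ("add1", "C"),
    ("lit_123", "array"), ("array", "D"),
    ("C", "mul"), ("D", "mul"), ("mul", "E"),
    ("D", "astype"), ("int32", "astype"),
    ("astype", "setitem"), ("idx", "setitem"), ("val", "setitem"),
    ("setitem", "add2"), ("48", "add2"), ("add2", "F"),
]


def dependency_list(edges, target):
    """Return target and all of its ancestors, dependencies first."""
    parents = {}
    for parent, child in edges:
        parents.setdefault(child, set()).add(parent)

    # Walk backwards from the target to collect every ancestor.
    needed, stack = set(), [target]
    while stack:
        node = stack.pop()
        if node not in needed:
            needed.add(node)
            stack.extend(parents.get(node, ()))

    # TopologicalSorter takes {node: predecessors} and yields predecessors first.
    graph = {node: parents.get(node, set()) for node in needed}
    return list(TopologicalSorter(graph).static_order())


order = dependency_list(EDGES, "F")
print(order)
```

Because `F` reads `D` through its latest mutagen, the list pulls in `array`, `astype` and `setitem` but nothing from the `A`/`B`/`C`/`E` branch.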

## Filtering Actions in the Graph

### Filter by ExecutionStatus

```
ExecutionStatus
- PROCESSING
- DONE
- FAILED
```

In [None]:
action_graph_service.get_by_action_status(context=authed_context, status=ExecutionStatus.PROCESSING)

In [None]:
action_graph_service.get_by_action_status(context=authed_context, status=ExecutionStatus.DONE)

In [None]:
assert len(action_graph_service.get_by_action_status(context=authed_context, status=ExecutionStatus.PROCESSING))==17
assert len(action_graph_service.get_by_action_status(context=authed_context, status=ExecutionStatus.DONE))==0

Let's mark one node, `d_node`, as DONE:

In [None]:
d_node.status = ExecutionStatus.DONE

In [None]:
assert len(action_graph_service.get_by_action_status(context=authed_context, status=ExecutionStatus.PROCESSING))==16
assert len(action_graph_service.get_by_action_status(context=authed_context, status=ExecutionStatus.DONE))==1

### Filter by a Particular User

In [None]:
assert len(action_graph_service.get_by_verify_key(context=authed_context, verify_key=signing_key.verify_key))==17
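
Both filters above amount to selecting nodes whose attributes match a value. A minimal sketch of that idea, assuming each node carries a status and the creator's verify key; `SketchStatus`, `NodeData` and `filter_nodes` are hypothetical stand-ins for `ExecutionStatus`, Syft's node data and the service's `get_by_*` methods, and the string `"key-1"` replaces a real verify key.

```python
from dataclasses import dataclass
from enum import Enum


class SketchStatus(Enum):
    PROCESSING = "processing"
    DONE = "done"
    FAILED = "failed"


@dataclass
class NodeData:
    status: SketchStatus
    user_verify_key: str   # stand-in for the creator's verify key


def filter_nodes(nodes, **criteria):
    """Return ids of nodes whose attributes equal every given criterion."""
    return [
        node_id
        for node_id, data in nodes.items()
        if all(getattr(data, attr) == value for attr, value in criteria.items())
    ]


# 17 nodes, all created by the same user and still processing.
nodes = {f"node{i}": NodeData(SketchStatus.PROCESSING, "key-1") for i in range(1, 18)}
nodes["node7"].status = SketchStatus.DONE  # mark one node as done

processing = filter_nodes(nodes, status=SketchStatus.PROCESSING)
done = filter_nodes(nodes, status=SketchStatus.DONE)
by_user = filter_nodes(nodes, user_verify_key="key-1")
print(len(processing), len(done), len(by_user))  # 16 1 17
```

A linear scan is enough for an in-memory sketch; a larger store would probably keep secondary indexes keyed by status and user instead.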

## Serde (Serialization and Deserialization)

In [None]:
bytes_data = sy.serialize(action_graph_service, to_bytes=True)

In [None]:
action_graph_service_back = sy.deserialize(bytes_data, from_bytes=True)

In [None]:
assert action_graph_service.get_all_nodes(authed_context) == action_graph_service_back.get_all_nodes(authed_context)

In [None]:
assert action_graph_service.get_all_edges(authed_context) == action_graph_service_back.get_all_edges(authed_context)
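
`sy.serialize` and `sy.deserialize` handle the real service above. Purely to illustrate the round-trip property that the last two assertions check (same nodes and edges before and after), here is a stdlib sketch that uses JSON as a swapped-in format; it is not Syft's serde, and `graph_to_bytes`/`graph_from_bytes` are hypothetical helpers.

```python
import json


def graph_to_bytes(nodes, edges):
    # Sort so that equal graphs always serialize to identical bytes.
    payload = {"nodes": sorted(nodes), "edges": sorted(list(e) for e in edges)}
    return json.dumps(payload).encode("utf-8")


def graph_from_bytes(data):
    payload = json.loads(data.decode("utf-8"))
    # JSON has no tuples, so edge pairs are converted back from lists.
    return set(payload["nodes"]), {tuple(e) for e in payload["edges"]}


nodes = {"A", "B", "add1", "C"}
edges = {("A", "add1"), ("B", "add1"), ("add1", "C")}
data = graph_to_bytes(nodes, edges)
nodes_back, edges_back = graph_from_bytes(data)
print(nodes_back == nodes, edges_back == edges)  # True True
```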