# Action Graph 

### Goals:

- A graph store/database to store and trace any computations during eager execution
- Graph that works with the current in memory worker
- Ability to visualize the graph
- Generate a dependency list for each node, so that any dependent action can be (re)generated
- Basic query/search functionalities
- Locking/Concurrency

In [None]:
import syft as sy
from syft.service.action.action_graph_service import ActionGraphService, NodeActionDataUpdate, ExecutionStatus
from syft.service.action.action_graph import InMemoryActionGraphStore, InMemoryGraphConfig, InMemoryStoreClientConfig

from syft.service.context import AuthedServiceContext
from syft.node.credentials import SyftSigningKey
from syft.service.action.action_graph import Action
from syft.service.action.numpy import NumpyArrayObject, ActionObject
import numpy as np
import matplotlib.pyplot as plt
import networkx as nx

## Scenario for performing some computation

```python

import syft as sy

domain_client = sy.login("....")

dataset = domain_client.datasets[0]

a = dataset.assets["A"]

b = dataset.assets["B"]

c = a + b

d = domain_client.api.numpy.array([1, 2, 3])

e = c * d

# Treated as an in-place operation that mutates the value of d
d.astype('int32')

d[2] = 5

f = d + 48
```

<br>
<br>

**Corresponding Actions Generated**

```

action1 -> a + b

action2 -> initialization of variable `d`

action3 -> c * d

action4 -> in-place update of the type of `d` (d.astype('int32'))

action5 -> d[2] = 5

action6 -> d + 48

```

![graph.png](graph.png)

### Initializing the Store

In [None]:
# Create a Config

store_config = InMemoryGraphConfig()

In [None]:
# Initialize the InMemory Store

graph_store = InMemoryActionGraphStore(store_config=store_config, reset=True)

In [None]:
G = graph_store.graph.db

### Initializing Action Graph Service

In [None]:
action_graph_service = ActionGraphService(store=graph_store)

In [None]:
signing_key = SyftSigningKey.generate()
authed_context = AuthedServiceContext(credentials=signing_key.verify_key)

### Create some dummy data

In [None]:
labels_dict = {}

In [None]:
action_obj_a = ActionObject.from_obj([2, 4, 6])
action_obj_b = ActionObject.from_obj([2, 3, 4])

In [None]:
labels_dict[action_obj_a.id] = "A"
labels_dict[action_obj_b.id] = "B"

In [None]:
action_obj_a.id, action_obj_b.id

In [None]:
action_graph_service.add_action_obj(context=authed_context, action_obj=action_obj_a)

In [None]:
action_graph_service.add_action_obj(context=authed_context, action_obj=action_obj_b)

In [None]:
assert len(action_graph_service.get_all_nodes(authed_context)) == 2
assert len(action_graph_service.get_all_edges(authed_context)) == 0

### Action1 -> A + B

In [None]:
action1 = Action(
    path="action.execute",
    op="__add__",
    remote_self=action_obj_a.syft_lineage_id,
    args=[action_obj_b.syft_lineage_id],
    kwargs={}
)
action1

In [None]:
labels_dict[action1.id] = "+"
labels_dict[action1.result_id.id] = "C"

### Add the action to Graph

In [None]:
# action_graph_service.add_action

In [None]:
action_graph_service.add_action(context=authed_context, action=action1)

In [None]:
plt.figure(figsize=(20, 10))
pos = nx.spring_layout(G, seed=3113794652)
# nx.draw_networkx_nodes(G, pos=pos)
nx.draw_networkx(G, pos=pos, labels=labels_dict, with_labels=True, width=2.0, node_color="yellow", node_size=800, font_size=22)

In [None]:
assert len(action_graph_service.get_all_nodes(authed_context)) == 4
assert len(action_graph_service.get_all_edges(authed_context)) == 3

### Action2 -> np.array([1, 2, 3])

In [None]:
action_obj_d = ActionObject.from_obj([1, 2, 3])

In [None]:
labels_dict[action_obj_d.id] = "[1, 2 ,3]"

In [None]:
action_graph_service.add_action_obj(context=authed_context, action_obj=action_obj_d)

In [None]:
# Create Action2

action2 = Action(
    path="action.execute",
    op="np.array",
    remote_self=None,
    args=[action_obj_d.syft_lineage_id],
    kwargs={}
)
action2

In [None]:
labels_dict[action2.id] = "np.array"
labels_dict[action2.result_id.id] = "D"

In [None]:
# Save action to graph
action_graph_service.add_action(context=authed_context, action=action2)

In [None]:
# 747749f9494345b78e165f13351e52bf: {"data": NodeActionData()}

In [None]:
action1.result_id.id

In [None]:
plt.figure(figsize=(20, 10))
pos = nx.spring_layout(G, seed=3113794652)
# nx.draw_networkx_nodes(G, pos=pos)
nx.draw_networkx(G, pos=pos, labels=labels_dict, with_labels=True, width=2.0, node_color="yellow", node_size=800, font_size=22)

In [None]:
action_graph_service.store.graph.visualize()

In [None]:
assert len(action_graph_service.get_all_nodes(authed_context)) == 7
assert len(action_graph_service.get_all_edges(authed_context)) == 5

### Action3 -> C * D

In [None]:
action3 = Action(
    path="action.execute",
    op="__mul__",
    remote_self=action1.result_id,
    args=[action2.result_id],
    kwargs={}
)
action3

In [None]:
action_graph_service.add_action(context=authed_context, action=action3)

In [None]:
labels_dict[action3.id] = "*"
labels_dict[action3.result_id.id] = "E"

In [None]:
plt.figure(figsize=(20, 10))
pos = nx.spring_layout(G, seed=3113794651)
# nx.draw_networkx_nodes(G, pos=pos)
nx.draw_networkx(G, pos=pos, labels=labels_dict, with_labels=True, width=2.0, node_color="yellow", node_size=800, font_size=22)

In [None]:
assert len(action_graph_service.get_all_nodes(authed_context)) == 9
assert len(action_graph_service.get_all_edges(authed_context)) == 8

### Action4 -> Mutate type of D

In [None]:
as_type_action_obj = ActionObject.from_obj('np.int32')

In [None]:
action_graph_service.add_action_obj(context=authed_context, action_obj=as_type_action_obj)

In [None]:
labels_dict[as_type_action_obj.id] = "np.int32"

In [None]:
action4 = Action(
    path="action.execute",
    op="astype",
    remote_self=action2.result_id,
    args=[as_type_action_obj.syft_lineage_id],
    kwargs={},
    result_id=action2.result_id
)
action4

In [None]:
action_graph_service.add_action(context=authed_context, action=action4)

In [None]:
labels_dict[action4.id] = "astype"

In [None]:
plt.figure(figsize=(20, 10))
pos = nx.spring_layout(G, seed=3113794652)
# nx.draw_networkx_nodes(G, pos=pos)
nx.draw_networkx(G, pos=pos, labels=labels_dict, with_labels=True, width=2.0, node_color="yellow", node_size=800, font_size=22)

In [None]:
mid = G.nodes(data=True)[action2.result_id.id]["data"]['next_mutagen_node']
mid

In [None]:
G.nodes(data=True)[action2.result_id.id]['data']

In [None]:
assert G.nodes(data=True)[mid]['data'].id == action4.id

### Action5 -> D[2] = 5

In [None]:
idx_action_obj = ActionObject.from_obj(2)
action_graph_service.add_action_obj(context=authed_context, action_obj=idx_action_obj)

In [None]:
labels_dict[idx_action_obj.id] = "2"

In [None]:
item_val_action_obj = ActionObject.from_obj(5)
action_graph_service.add_action_obj(context=authed_context, action_obj=item_val_action_obj)

In [None]:
labels_dict[item_val_action_obj.id] = "5"

In [None]:
action5 = Action(
    path="action.execute",
    op="__setitem__",
    remote_self=action2.result_id,
    args=[idx_action_obj.syft_lineage_id, item_val_action_obj.syft_lineage_id],
    kwargs={},
    result_id=action2.result_id
)
action5

In [None]:
action_graph_service.add_action(context=authed_context, action=action5)

In [None]:
labels_dict[action5.id] = "__setitem__"

In [None]:
plt.figure(figsize=(20, 20))
pos = nx.spring_layout(G, seed=3113794652)
# nx.draw_networkx_nodes(G, pos=pos)
nx.draw_networkx(G, pos=pos, labels=labels_dict, with_labels=True, width=2.0, node_color="orange", node_size=800, font_size=22)

In [None]:
assert len(action_graph_service.get_all_nodes(authed_context)) == 14
assert len(action_graph_service.get_all_edges(authed_context)) == 13


### Action6 -> D + 48

In [None]:
arg_action_obj = ActionObject.from_obj(48)

In [None]:
action_graph_service.add_action_obj(context=authed_context, action_obj=arg_action_obj)

In [None]:
labels_dict[arg_action_obj.id] = "48"

In [None]:
action6 = Action(
    path="action.execute",
    op="__add__",
    remote_self=action2.result_id,
    args=[arg_action_obj.syft_lineage_id],
    kwargs={},
)
action6

In [None]:
action_graph_service.add_action(context=authed_context, action=action6)

In [None]:
labels_dict[action6.id] = "+"
labels_dict[action6.result_id.id] = "F"

In [None]:
plt.figure(figsize=(40, 40))
pos = nx.spring_layout(G)
# nx.draw_networkx_nodes(G, pos=pos)
nx.draw_networkx(G, pos=pos, labels=labels_dict, with_labels=True, width=3.0, node_color="orange", node_size=3000, font_size=30)

In [None]:
labels_dict.values()

```

action1 -> a + b

action2 -> initialization of variable `d`

action3 -> c * d

action4 -> in-place update of the type of `d` (d.astype('int32'))

action5 -> d[2] = 5

action6 -> d + 48

```

## Filtering Actions in the Graph

### Filter by ExecutionStatus

```
ExecutionStatus
- PROCESSING
- DONE
- FAILED
```

In [None]:
action_graph_service.get_by_action_status(context=authed_context, status=ExecutionStatus.PROCESSING)

In [None]:
action1

In [None]:
action_graph_service.get_by_action_status(context=authed_context, status=ExecutionStatus.PROCESSING.DONE)

### Filter by Particular User

In [None]:
# Everything recorded under this notebook's user credentials
action_graph_service.get_by_verify_key(
    context=authed_context,
    verify_key=signing_key.verify_key,
)

In [None]:
### Serde 

In [None]:
bytes_data = sy.serialize(graph_store, to_bytes=True)

In [None]:
graph_store = sy.deserialize(bytes_data, from_bytes=True)

In [None]:
for x, y in zip(graph_store.graph.db.nodes(data=True), G.nodes(data=True)):
    uid1, node1 = x 
    uid2, node2 = y
    assert uid1==uid2
    assert node1['data'] == node2['data']