In [49]:
import pandas as pd
import numpy as np
import networkx as nx
import plotly.graph_objects as go
import plotly.offline as pyo

In [50]:
# Transition-probability table: one row per count state, one column per
# destination state (plus PLAY). Displayed rounded to two decimals.
df = pd.read_csv(filepath_or_buffer='master_df.csv')
df.round(decimals=2)

Unnamed: 0.1,Unnamed: 0,[0-0],[1-0],[0-1],[2-0],[1-1],[0-2],[3-0],[2-1],[1-2],[3-1],[2-2],[3-2],OUT,WALK,PLAY
0,[0-0],0.0,0.38,0.51,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.11
1,[1-0],0.0,0.0,0.0,0.34,0.5,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.16
2,[0-1],0.0,0.0,0.0,0.0,0.4,0.43,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.17
3,[2-0],0.0,0.0,0.0,0.0,0.0,0.0,0.31,0.53,0.0,0.0,0.0,0.0,0.0,0.0,0.17
4,[1-1],0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.34,0.46,0.0,0.0,0.0,0.0,0.0,0.2
5,[0-2],0.0,0.0,0.0,0.0,0.0,0.19,0.0,0.0,0.45,0.0,0.0,0.0,0.18,0.0,0.18
6,[3-0],0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.62,0.0,0.0,0.0,0.33,0.05
7,[2-1],0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.29,0.49,0.0,0.0,0.0,0.22
8,[1-2],0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.22,0.0,0.38,0.0,0.19,0.0,0.21
9,[3-1],0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.51,0.0,0.27,0.22


In [51]:
G = nx.DiGraph()

# Directed edges between integer state ids (renamed to count / terminal names
# in the next cell). The pairs mirror the transition matrix above, including
# self-loops such as (6, 6).
elist = [
    (1, 2), (1, 3),
    (2, 4), (2, 5),
    (3, 5), (3, 6),
    (4, 7), (4, 8),
    (5, 8), (5, 9),
    (6, 6), (6, 9), (6, 13),
    (7, 10), (7, 14),
    (8, 10), (8, 11),
    (9, 9), (9, 11), (9, 13),
    (10, 12), (10, 14),
    (11, 11), (11, 12), (11, 13),
    (12, 12), (12, 13), (12, 14),
]

# Placeholder text for every edge label (one entry per edge, all '0%').
labels = ['0%'] * len(elist)

# Pair each edge with its label as a networkx edge-attribute dict.
edges_with_labels = [(src, dst, {'label': text}) for (src, dst), text in zip(elist, labels)]

In [52]:
G.add_edges_from(edges_with_labels)

# Integer ids 1..14 -> display names, listed in id order.
label_mapping = dict(enumerate(
    ['[0,0]', '[1,0]', '[0,1]', '[2,0]', '[1,1]', '[0,2]',
     '[3,0]', '[2,1]', '[1,2]', '[3,1]', '[2,2]', '[3,2]',
     'OUT', 'WALK'],
    start=1,
))
G = nx.relabel_nodes(G, label_mapping)

In [53]:
# Plot coordinates for every node. Count nodes '[balls,strikes]' sit on a
# diamond lattice rooted at '[0,0]' = (150, 600). Each pitch moves one row
# down (y -= 100); a ball shifts left (x -= 30) and a strike shifts right
# (x += 30). The terminal states OUT and WALK share the bottom row.
pos = {
    f'[{balls},{strikes}]': (150 + 30 * (strikes - balls), 600 - 100 * (balls + strikes))
    for balls in range(4)
    for strikes in range(3)
}
pos['OUT'] = (180, 0)
pos['WALK'] = (90, 0)

In [54]:
# Line trace that will hold every edge segment; its coordinates start empty
# and are filled in by the next cell.
edge_trace = go.Scatter(
    x=[],
    y=[],
    mode='lines',
    hoverinfo='none',
    line_color='black',
    line_width=2,
)

# Layout annotations carrying the per-edge label text.
edge_annotations = []

In [55]:
# Fill the edge trace and build one text annotation per edge.
# Everything is rebuilt from scratch so re-running this cell does not append
# duplicate segments or labels to the results of a previous run.
edge_x, edge_y = [], []
edge_annotations = []

for src, dst in G.edges():
    x0, y0 = pos[src]
    x1, y1 = pos[dst]
    # The trailing None breaks the polyline, so consecutive edges are not joined.
    edge_x += [x0, x1, None]
    edge_y += [y0, y1, None]

    if src == dst:
        # Self-loop: the segment above has zero length, so put the label just
        # to the right of and above the node instead of at a midpoint.
        x_pos, y_pos = x0 + 15, y0 + 5
    else:
        # Ordinary edge: label at the segment midpoint, nudged slightly left.
        x_pos, y_pos = (x0 + x1) / 2 - 5, (y0 + y1) / 2

    edge_annotations.append(dict(
        x=x_pos,
        y=y_pos,
        xref='x',
        yref='y',
        text=str(G.edges[src, dst]['label']),
        showarrow=False,
        font=dict(color='black', size=10),
    ))

# Assign all coordinates once. Growing a tuple edge by edge re-copied and
# re-validated the whole array on every iteration.
edge_trace.x = edge_x
edge_trace.y = edge_y

In [59]:
node_diameter = 60

# Gather one entry per node, in graph order: its position and the text drawn
# inside its marker (the node name).
node_x, node_y, node_text = [], [], []
for node in G.nodes():
    x, y = pos[node]
    node_x.append(x)
    node_y.append(y)
    node_text.append(str(node))
n_nodes = len(node_text)

# Marker trace: uniform light-gray circles with a thin black outline, each
# labelled with its node name centred inside.
node_trace = go.Scatter(
    x=node_x,
    y=node_y,
    mode='markers+text',
    hoverinfo='none',
    text=node_text,
    textposition="middle center",
    textfont=dict(family="Arial", size=14, color="black"),
    marker=dict(
        showscale=False,
        color=['lightgray'] * n_nodes,
        size=[node_diameter] * n_nodes,
        symbol='circle',
        line_width=1,
        line_color='black',
        opacity=1,
    ),
)


In [60]:
# Assemble the figure: edge lines first, node markers on top, and the edge
# labels supplied as layout annotations. All axis decorations are hidden.
fig = go.Figure(
    data=[edge_trace, node_trace],
    layout=go.Layout(
        width=600,
        height=800,
        # Title font goes through title.font. The legacy `titlefont` attribute
        # is deprecated and is not accepted by newer Plotly versions.
        title=dict(text='<br>State Transition Graph', font=dict(size=16)),
        showlegend=False,
        hovermode='closest',
        margin=dict(b=20, l=5, r=5, t=40),
        annotations=edge_annotations,
        # One x unit spans 3x the pixels of one y unit. This widens the narrow
        # x-range of the node layout relative to its tall y-range.
        # Only the x-axis carries the constraint. Also anchoring y back to x
        # forms a constraint loop, which Plotly resolves by ignoring one side.
        xaxis=dict(showgrid=False, zeroline=False, showticklabels=False,
                   scaleanchor="y", scaleratio=3),
        yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
    ),
)

In [61]:
# Draw self-loops. In the straight-line edge trace a self-loop is a
# zero-length segment, so each one is drawn here as a cubic Bezier curve with
# a small filled triangle as its arrow cap. Offsets are in data units and were
# tuned by eye for the figure's axis scaling.
loops = []
arrows = []

# Take the loop nodes from the graph itself rather than hard-coding them, so
# the drawing stays in sync with the edge list. Order follows G's node order.
selfloop_nodes = [u for u, _ in nx.selfloop_edges(G)]

for loop_node in selfloop_nodes:
    centre_x, centre_y = pos[loop_node]
    loop = dict(
        start_pt=[centre_x + 8, centre_y + 20],
        end_pt=[centre_x + 10.5, centre_y - 7],
        control_pt1=[centre_x + 40, centre_y + 23],
        control_pt2=[centre_x + 15, centre_y - 25],
    )
    loops.append(loop)

    # Closed triangle ("Z") anchored at the curve's end point.
    tip_x, tip_y = loop['end_pt']
    arrows.append(dict(
        type="path",
        path=f"M {tip_x},{tip_y-1} L {tip_x+3},{tip_y+1} L {tip_x+2},{tip_y-7} Z",
        fillcolor="Black",
        line=dict(color="Black"),
    ))

# One SVG cubic-Bezier path per loop ("M start C ctrl1 ctrl2 end"), followed by
# the arrow caps.
shapes = [
    dict(
        type='path',
        path=f"M {loop['start_pt'][0]}, {loop['start_pt'][1]} C {loop['control_pt1'][0]},{loop['control_pt1'][1]} {loop['control_pt2'][0]},{loop['control_pt2'][1]} {loop['end_pt'][0]},{loop['end_pt'][1]}",
        line_color="Black"
    ) for loop in loops
]
shapes.extend(arrows)

fig.update_layout(shapes=shapes)
fig.show()