In [282]:
import numpy as np
import pandas as pd
import panel as pn
import networkx as nx
import plotly.graph_objs as go
import plotly.express as px

import ipywidgets as widgets
from IPython.display import display

import dash
from dash import dcc, html
from dash.dependencies import Input, Output, State

import asyncio
import threading

from itertools import combinations

# Testing Visualizations Using a Toy Network

In [283]:
# Helper Functions

# changing node size based on activation (sum of connection weights)
# Default activation is -0.1
def get_activation(graph, rest_act=-0.1):
    """Compute a per-node activation as resting level plus incident edge weights.

    Arguments
    - graph: object exposing ``nodes()`` and ``edges(data=True)``; every edge
          data dict must carry a ``"weight"`` entry
    - rest_act: starting activation assigned to every node
    Returns
    - (activations, rest_act): dict mapping node -> activation, and the
          resting level that was used
    """
    activations = dict.fromkeys(graph.nodes(), rest_act)
    for src, dst, attrs in graph.edges(data=True):
        weight = attrs["weight"]
        # a self-loop contributes its weight only once
        endpoints = (src,) if src == dst else (src, dst)
        for endpoint in endpoints:
            activations[endpoint] += weight
    return activations, rest_act

# Map activations to marker sizes (e.g. between 20 and 50)
# Makes all nodes 35 if equal
def activation_to_size(act, low_bound=20, high_bound=50):
    """Linearly rescale activations to marker sizes in [low_bound, high_bound].

    Arguments
    - act: dict mapping node -> activation
    - low_bound / high_bound: marker sizes given to the smallest / largest activation
    Returns
    - list of sizes, one per entry of ``act`` in iteration order. When every
      activation is identical each node gets the integer midpoint of the
      bounds; an empty mapping yields an empty list.
    """
    values = list(act.values())
    if not values:
        # nothing to scale; avoids max()/min() raising on an empty sequence
        return []
    max_act = max(values)
    min_act = min(values)
    if max_act == min_act:
        # always return one size per node so the result has the same shape
        # as in the general case
        return [(low_bound + high_bound)//2] * len(values)
    scale = (high_bound - low_bound)
    span = (max_act - min_act)
    return [low_bound + (x - min_act) * scale / span for x in values]

# Map connection strengths to linewidths (e.g. between 1 and 5)
# Makes all edges 3 (the integer midpoint of the bounds) if equal
def connection_to_lw(edges, low_bound=1, high_bound=5):
    """Linearly rescale absolute edge weights to line widths in [low_bound, high_bound].

    Arguments
    - edges: iterable of ``(u, v, data)`` triples where ``data['weight']`` is numeric
    - low_bound / high_bound: widths given to the weakest / strongest connection
    Returns
    - list of widths, one per edge in iteration order. When all absolute
      weights are equal every edge gets the integer midpoint of the bounds;
      no edges yields an empty list.
    """
    # materialise once so generators and other one-shot iterables are supported
    connections = [np.abs(d['weight']) for u, v, d in edges]
    if not connections:
        # a graph without edges has nothing to scale
        return []
    min_con = min(connections)
    max_con = max(connections)
    if min_con == max_con:
        return [(low_bound + high_bound)//2] * len(connections)
    dynam_lw = [low_bound + (x-min_con) * (high_bound-low_bound)/(max_con-min_con)
                for x in connections]
    return dynam_lw
    

## Dropdown Menu

Pros:
- Different colors for positive/negative interactions
- Different sizes for nodes based on their activation
- Different linewidths depending on strength of connection
- You can view the activation and the net input by hovering over a node
- Different color for selected node
- Only highlights the connections to the selected node

Cons:
- Current inability to select more than one node
- No current ability to click on the nodes for selection (dropdown navigation would be difficult with a larger network)
- No grouping of the nodes positionally

In [316]:
# Enable Panel's "plotly" extension before any Panel object is built
pn.extension("plotly")

# Create toy network (graph)
# `pools` maps each node ID (written as "<pool>: <label>") to its pool name
pools = {'Group1: A':'Group1', 'Group1: B':'Group1', 'Group1: C':'Group1',
         'Group2: D':'Group2', 'Group2: E':'Group2'}
nodes = pools.keys()
# (u, v, weight) triples; 'Group2: E' is deliberately left without any edge
edges = [('Group1: A', 'Group1: B', -2), ('Group1: A', 'Group1: C', -5), ('Group1: B', 'Group1: C', -3), ('Group1: C', 'Group2: D', 1), ('Group1: A', 'Group2: D', 1)]
# fixed (x, y) coordinate for every node
pos = {
    'Group1: A': (0, 0),
    'Group1: B': (0, -1),
    'Group1: C': (-1, 0),
    'Group2: D': (0, 1),
    'Group2: E': (1, 0),
}
# reverse lookup (x, y) -> node ID; only unambiguous while all positions are distinct
pos_map = {v:k for k,v in pos.items()}

# Undirected graph: nodes first (so isolated nodes are kept), then weighted edges
G = nx.Graph()
G.add_nodes_from(nodes)
for u, v, weight in edges:
    G.add_edge(u, v, weight=weight)
    
# node color palette (can change to another palette)
col_pal = px.colors.qualitative.Plotly

# Function to create the plot
def create_plot(hover_node=None):
    """Build the Plotly figure for the network stored in the notebook globals.

    Reads the globals ``G`` (graph), ``pos`` (node -> (x, y)), ``pools``
    (node -> pool) and ``col_pal`` (colour palette).

    Arguments
    - hover_node: node ID to highlight, or None for the plain network.
          When given, the node is filled yellow with a black outline, its
          incident edges are drawn, and each neighbour's outline takes the
          colour of the connecting edge.
    Returns
    - fig: plotly ``go.Figure``
    Raises
    - ValueError: if ``hover_node`` is not a node of ``G``
    """
    # creating figure
    fig = go.Figure()

    # node order is fixed once; the index map replaces repeated list scans
    node_list = list(G.nodes())
    node_index = {n: i for i, n in enumerate(node_list)}

    # getting node sizes, node colors, connection linewidths
    activations, rest_act = get_activation(G)
    sizes = activation_to_size(activations)
    pool_order = list(set(pools.values()))
    # modulo keeps the lookup valid when there are more pools than palette entries
    colors = [col_pal[pool_order.index(pools[n]) % len(col_pal)] for n in node_list]
    linewidths = connection_to_lw(G.edges(data=True))

    # setting up node outline colors
    outline_colors = colors.copy()

    if hover_node is not None:
        if hover_node not in node_index:
            raise ValueError(f"{hover_node!r} is not a node of the graph")
        hidden_ind = node_index[hover_node]

        for (u, v, d), line_width in zip(G.edges(data=True), linewidths):
            if hover_node not in (u, v):
                continue

            # connection line color: between-pool edges blue, within-pool edges red
            line_color = 'blue' if pools[u] != pools[v] else 'red'

            # the node at the other end of the edge gets the edge colour as outline
            neighbor = v if hover_node == u else u
            outline_colors[node_index[neighbor]] = line_color

            fig.add_trace(go.Scatter(
                        x=[pos[u][0], pos[v][0], None],
                        y=[pos[u][1], pos[v][1], None],
                        line=dict(width=line_width, color=line_color),
                        hoverinfo='none',
                        mode='lines',
                        showlegend=False,)
                      )

        # Applied after (not inside) the edge loop so that a node without any
        # edge is still highlighted when selected.
        outline_colors[hidden_ind] = 'black' # selected node line color
        colors[hidden_ind] = 'yellow' # selected node fill color

    # Label = node ID with its "<pool>: " prefix removed; unlike splitting on
    # spaces this survives pool names or values that contain spaces.
    labels = [n[len(f"{pools[n]}: "):] if n.startswith(f"{pools[n]}: ") else n
              for n in node_list]

    # Adding Nodes
    node_trace = go.Scatter(
        x=[pos[n][0] for n in node_list],
        y=[pos[n][1] for n in node_list],
        text=labels,
        marker=dict(size=sizes,
                    color=colors,
                    line=dict(width=2, color=outline_colors)
                   ),
        hoverinfo="text",
        hovertext=[f"Activation: {activations[n]}<br>Net Input: {activations[n]-rest_act}" for n in node_list],
        mode="markers+text",
        showlegend=False)
    fig.add_trace(node_trace)

    # Adding Legend Labels (dummy traces)
    fig.add_trace(go.Scatter(
        x=[None], y=[None],
        line=dict(width=2, color="blue"),
        mode="lines",
        name="excitatory",
        showlegend=True
    ))
    fig.add_trace(go.Scatter(
        x=[None], y=[None],
        line=dict(width=2, color="red"),
        mode="lines",
        name="inhibitory",
        showlegend=True
    ))

    # removing grid and axes
    fig.update_layout(
        showlegend=True,
        hovermode="closest",
        xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
        yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
        width=700,
        height=700,
        plot_bgcolor="white",
        title = "Interactive Activation and Competition Network"
    )

    return fig

def view(hover_node=None):
    """Panel-bindable entry point: render the network with ``hover_node`` highlighted."""
    figure = create_plot(hover_node)
    return figure

# Creating dropdown menu to select hover_node
# (the leading None option means "no node selected" and is the initial value)
dropdown = pn.widgets.Select(name="Select Node", options=[None] + list(G.nodes()), value=None)

# Binding dropdown menu to view()
# (the widget's current value is passed to view() as `hover_node`)
interactive_view = pn.bind(view, hover_node=dropdown)

# Show visualization
pn.Column(dropdown, interactive_view).servable()

## Hover and Click Selection

### With Dash:

References: https://dash.plotly.com/tutorial

- Hover is achieved by hoverData, which triggers an event each time the cursor hovers over a node.
- Click is achieved by storing selected node(s) so that we can toggle back and forth between selecting and unselecting.

FIX: hover and click events not behaving as expected. For connected nodes, you can only unselect them after selecting at least one other node.

In [317]:
# setting up Dash app
app = dash.Dash(__name__)

# App layout: the network figure, a reset button and a store for the selection.
# Style keys are written in camelCase throughout (as `borderRadius` already
# was), which is the form Dash documents for style dictionaries.
app.layout = html.Div([
    dcc.Graph(
        id = 'network-graph', # figure
        figure = create_plot(),
        style = {'marginTop': '50px'}
    ),
    html.Button( # reset button (Dash not compatible with ipywidgets, need html Button instead)
        id = 'reset-button',
        children = 'Reset Network',
        style = {
            'position': 'absolute',
            'top': '20px',
            'left': '20px',
            'backgroundColor': 'skyblue',
            'color': 'white',
            'border': 'none',
            'borderRadius': '5px',
            'padding': '5px 5px',
            'cursor': 'pointer'
        }
    ),
    dcc.Store(
        id = 'selected-nodes', # place to store selected nodes
        data = []
    ),
])

@app.callback(
     [Output('network-graph', 'figure'),   # Output for figure
      Output('selected-nodes', 'data')],   # Output for node storage
     [Input('network-graph', 'hoverData'), # Input for hover event
      Input('network-graph', 'clickData'), # Input for click event
      Input('reset-button', 'n_clicks')],  # Input for reset button click
     [State('selected-nodes', 'data'),     # State of selected-nodes
      State('network-graph', 'figure')],   # State of figure
)
def fig_update(hoverData, clickData, n_clicks, data, figure):
    '''
    Redraw the network after a hover, a click or a press of the reset button.

    - hover: the figure is rebuilt with the hovered node highlighted
    - click: the clicked node is toggled in the list of selected nodes
    - reset: the plain figure is restored and the selection is emptied
    Selected nodes are always painted yellow on the returned figure.

    The node is identified from the payload of the event that actually fired
    (clickData for clicks, hoverData for hovers). Events without a usable
    point, or at a position that is not in `pos_map`, leave the figure and
    the selection untouched.

    returns: fig, clicked_nodes
    '''
    fig = go.Figure(figure)
    # work on a private list; also tolerates an empty/None store
    data = list(data or [])
    node_index = {n: i for i, n in enumerate(G.nodes())}

    def paint(target, index, color):
        # set the fill colour of node `index` on every marker trace of `target`
        for trace in target.data:
            if 'marker' in trace and trace.mode and 'markers' in trace.mode:
                curr_colors = list(trace.marker.color)
                curr_colors[index] = color
                trace.marker.color = curr_colors

    triggered = dash.callback_context.triggered
    trig_event = triggered[0]['prop_id'] if triggered else ''

    if trig_event == 'reset-button.n_clicks':
        return create_plot(), []

    # pick the payload that belongs to the triggering event
    payload = {
        'network-graph.hoverData': hoverData,
        'network-graph.clickData': clickData,
    }.get(trig_event)

    selected_node = None
    if payload and payload.get('points'):
        # Identify selected node (using location in case node 'text' is not unique)
        point = payload['points'][0]
        selected_node = pos_map.get((point.get('x'), point.get('y')))

    if selected_node in node_index:
        if trig_event == 'network-graph.hoverData':
            fig = create_plot(hover_node=selected_node)
        elif selected_node in data:
            # toggle back to normal if already selected
            data.remove(selected_node)
            # resetting color back to the pool color
            pool_order = list(set(pools.values()))
            orig_color = col_pal[pool_order.index(pools[selected_node]) % len(col_pal)]
            paint(fig, node_index[selected_node], orig_color)
        else:
            # else add node to selected-nodes
            data.append(selected_node)

    for clicked_node in data:
        # ignore stale entries that are not nodes of the current graph
        if clicked_node in node_index:
            paint(fig, node_index[clicked_node], 'yellow')

    return fig, data

# Start serving `app` with the default arguments.
# NOTE(review): newer Dash releases document `app.run()` as the entry point;
# confirm `run_server` is still supported by the installed Dash version.
app.run_server()

### With ipywidgets:

references: 

https://plotly.com/python/click-events/

https://github.com/jonmmease/plotly_ipywidget_notebooks/tree/master/notebooks

Problem:
- Slow to update with larger network
- May be because I was unable to remove unseen traces, instead I could only set `trace.visible=False` or else I would get an `IndexError`

In [306]:
# INITIAL CONFIG
# Module-level state shared by the hover / unhover / click / reset handlers.
# It is derived from the globals G and pools as they are when this cell runs.
 
# node color palette (can change to another palette)
col_pal = px.colors.qualitative.Plotly
    
# getting node sizes, node colors, connection linewidths
activations, rest_act = get_activation(G)
sizes = activation_to_size(activations)
# one palette entry per pool, looked up by the pool's position in the set of pool names
colors = [col_pal[list(set(pools.values())).index(pools[n])] for n in G.nodes()]
linewidths = connection_to_lw(G.edges(data=True))
    
# setting up node outline colors
outline_colors = colors.copy()

# setting changeable colors (will retain clicked color state)
curr_colors = colors.copy()

# initializing click storage
# (one flag per node, in G.nodes() order)
clicked = [False]*len(G.nodes())

In [307]:
# HELPER FUNCTIONS

# --hover update--
def hover_node(trace, points, state):
    """Hover callback for the node trace: highlight hovered nodes and draw their edges.

    Reads the globals ``fig``, ``G``, ``pos``, ``pools``, ``linewidths``,
    ``outline_colors`` and ``curr_colors``. Each hovered node is filled
    yellow with a black outline, and each neighbour's outline takes the
    colour of the connecting edge.

    Edge traces left hidden by earlier hovers (everything after the node
    trace and the two legend traces) are reused before new ones are added,
    so the number of traces stays bounded instead of growing on every hover.
    """
    node_list = list(G.nodes())
    node_index = {n: i for i, n in enumerate(node_list)}
    temp_outlines = outline_colors.copy()
    temp_colors = curr_colors.copy()

    # (x, y, width, color) of every edge that has to be shown
    edge_specs = []
    for hidden_ind in points.point_inds:
        hovered = node_list[hidden_ind]
        temp_outlines[hidden_ind] = 'black' # selected node line color
        temp_colors[hidden_ind] = 'yellow' # selected node fill color
        for (u, v, d), line_width in zip(G.edges(data=True), linewidths):
            if hovered not in (u, v):
                continue

            # connection line color: between-pool edges blue, within-pool edges red
            line_color = 'blue' if pools[u] != pools[v] else 'red'

            # node outline colors
            neighbor = v if hovered == u else u
            temp_outlines[node_index[neighbor]] = line_color

            edge_specs.append(([pos[u][0], pos[v][0], None],
                               [pos[u][1], pos[v][1], None],
                               line_width, line_color))

    with fig.batch_update():
        fig.data[0].marker.color = temp_colors
        fig.data[0].marker.line = dict(width=2, color=temp_outlines)

        # first 3 traces are (nodes, legend1, legend2); the rest are edge traces
        spare = list(fig.data[3:])
        for slot, (x, y, width, color) in zip(spare, edge_specs):
            slot.x = x
            slot.y = y
            slot.line = dict(width=width, color=color)
            slot.visible = True
        # edge traces that are not needed for this hover stay hidden
        for slot in spare[len(edge_specs):]:
            slot.visible = False
        # only create new traces when there are not enough to reuse
        for x, y, width, color in edge_specs[len(spare):]:
            fig.add_trace(go.Scatter(
                            x=x, y=y,
                            line=dict(width=width, color=color),
                            hoverinfo='none',
                            mode='lines',
                            showlegend=False,)
                          )
        
# --unhover update--
def unhover_node(trace, points, state):
    """Unhover callback: restore node colours and hide every edge trace.

    The first 3 traces of ``fig`` are (nodes, legend1, legend2); all traces
    after them are hidden rather than removed.
    """
    restored_fill = curr_colors.copy()
    with fig.batch_update():
        fig.data[0].marker.color = restored_fill
        fig.data[0].marker.line = dict(width=2, color=outline_colors)
        # slicing past the end yields nothing, so no explicit length check is needed
        for edge_trace in fig.data[3:]:
            edge_trace.visible = False

# --click update--
def click_node(trace, points, state):
    """Click callback: toggle the selection state of every clicked node.

    Updates the global ``clicked`` flags and ``curr_colors`` in place
    (yellow when selected, the pool colour from ``colors`` otherwise) and
    pushes the colours to the node trace.
    """
    for node_ind in points.point_inds:
        now_selected = not clicked[node_ind]
        clicked[node_ind] = now_selected
        curr_colors[node_ind] = 'yellow' if now_selected else colors[node_ind]

    with fig.batch_update():
        fig.data[0].marker.color = curr_colors

In [308]:
# Interactive figure; trace order matters to the handlers:
# index 0 = nodes, 1-2 = legend entries, 3+ = edge traces added on hover.
fig = go.FigureWidget()

# Label = node ID with its "<pool>: " prefix removed; unlike splitting on
# spaces this survives pool names or values that contain spaces.
node_labels = [n[len(f"{pools[n]}: "):] if n.startswith(f"{pools[n]}: ") else n
               for n in G.nodes()]

# Adding Nodes
node_trace = go.Scatter(
        x=[pos[n][0] for n in G.nodes()],
        y=[pos[n][1] for n in G.nodes()],
        text=node_labels,
        marker=dict(size=sizes,
                    color=colors,
                    line=dict(width=2, color=outline_colors)
                   ),
        hoverinfo="text",
        hovertext=[f"Activation: {activations[n]}<br>Net Input: {activations[n]-rest_act}" for n in G.nodes()],
        mode="markers+text",
        showlegend=False)
fig.add_trace(node_trace)

# Adding Legend Labels (dummy traces)
fig.add_trace(go.Scatter(
        x=[None], y=[None],
        line=dict(width=2, color="blue"),
        mode="lines",
        name="excitatory",
        showlegend=True
))
fig.add_trace(go.Scatter(
        x=[None], y=[None],
        line=dict(width=2, color="red"),
        mode="lines",
        name="inhibitory",
        showlegend=True
))

# removing grid and axes
fig.update_layout(
        showlegend=True,
        hovermode="closest",
        xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
        yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
        width=700,
        height=700,
        plot_bgcolor="white",
        title = "Interactive Activation and Competition Network"
)

# Reset button functionality (make all colors default)
def reset_plot(button):
    """Button callback: hide all edge traces and clear every click selection.

    Rebinds the globals ``curr_colors`` and ``clicked`` to fresh defaults.
    """
    global curr_colors, clicked
    with fig.batch_update():
        # traces after (nodes, legend1, legend2) are the hover edge traces
        for edge_trace in fig.data[3:]:
            edge_trace.visible = False
        fig.data[0].marker.color = colors
        curr_colors = list(colors)
        clicked = [False for _ in range(len(G.nodes()))]
        
# Register the interaction handlers on the node trace (trace 0) only
fig.data[0].on_hover(hover_node)
fig.data[0].on_unhover(unhover_node)
fig.data[0].on_click(click_node)

# Create reset button
reset_button = widgets.Button(description="Reset Plot", button_style="primary")
reset_button.on_click(reset_plot)

# Display the plot and button
display(reset_button, fig)

Button(button_style='primary', description='Reset Plot', style=ButtonStyle())

FigureWidget({
    'data': [{'hoverinfo': 'text',
              'hovertext': [Activation: -6.1<br>Net Input: -6.0, Activation:
                            -5.1<br>Net Input: -5.0, Activation: -7.1<br>Net Input:
                            -7.0, Activation: 1.9<br>Net Input: 2.0, Activation:
                            -0.1<br>Net Input: 0.0],
              'marker': {'color': [#636EFA, #636EFA, #636EFA, #EF553B, #EF553B],
                         'line': {'color': ['#636EFA', '#636EFA', '#636EFA', '#EF553B', '#EF553B'], 'width': 2},
                         'size': [23.333333333333332, 26.666666666666668, 20.0,
                                  50.0, 43.33333333333333]},
              'mode': 'markers+text',
              'showlegend': False,
              'text': [A, B, C, D, E],
              'type': 'scatter',
              'uid': 'add584cd-f1f4-43fe-82e1-bf0a49acbdf2',
              'x': [0, 0, -1, 0, 1],
              'y': [0, -1, 0, 1, 0]},
             {'line': {'color': 'blue',

## Creating Graph from DataFrame object

Testing ability to create starting state graph from a DataFrame. This function should detect the relevant pools and create corresponding positions and colors to use in a plotly interactive visualization.

In [323]:
# Read the Jets/Sharks table and keep a seeded sample of 10 rows
# (drawn without replacement, so the same rows come back on every run)
test = pd.read_csv('data/jets_sharks.csv').sample(10, replace=False, random_state=0)
# bare expression so the notebook displays the sampled frame
test

Unnamed: 0,Name,Gang,Age,Edu,Mar,Occupation
2,Sam,Jets,20's,COL.,Sing.,Bookie
24,Ol,Sharks,30's,COL.,Mar.,Pusher
14,Ralph,Jets,30's,J.H.,Sing.,Pusher
17,Nick,Sharks,30's,H.S.,Sing.,Pusher
5,Jim,Jets,20's,J.H.,Div.,Burglar
11,Pete,Jets,20's,H.S.,Sing.,Bookie
23,Rick,Sharks,30's,H.S.,Div.,Burglar
13,Gene,Jets,20's,COL.,Sing.,Pusher
19,Ned,Sharks,30's,COL.,Mar.,Bookie
20,Karl,Sharks,40's,H.S.,Mar.,Bookie


In [324]:
def position_nodes(pools, interval=800, center_coord=(0, 0), radius=1000, radius_step=1000):
    """Place nodes on concentric rings, leaving a gap between consecutive pools.

    Reference: Axel Cleeremans

    Nodes are placed starting from the LAST entry of ``pools`` on the
    innermost ring and working backwards. Whenever the pool changes, one
    slot is skipped; if the new pool would not fit in what is left of the
    current ring, placement continues on the next ring instead.

    Arguments
    - pools: dict mapping node -> pool ID (any hashable value)
    - interval: approximate distance between neighbouring slots on a ring
    - center_coord: (x, y) centre shared by all rings
    - radius: radius of the innermost ring
    - radius_step: amount by which the radius grows for each further ring
    Returns
    - pos: dict mapping node -> [x, y]; empty when ``pools`` is empty
    Raises
    - ValueError: if interval, radius or radius_step is not positive
    """
    if interval <= 0 or radius <= 0 or radius_step <= 0:
        raise ValueError("interval, radius and radius_step must all be positive")

    pos = {}
    nodes = list(pools.keys())
    pool_ids = list(pools.values())
    if not nodes:
        return pos

    # Count pool sizes keyed by the pool ID itself, so the lookup below also
    # works for pool IDs that are not strings.
    pool_lens = {}
    for p in pool_ids:
        pool_lens[p] = pool_lens.get(p, 0) + 1

    i = len(nodes)-1 # start from hidden units
    prev_pool = pool_ids[-1]
    while i >= 0:
        circum = 2 * np.pi * radius
        # number of nodes per ring
        ring_size = int(circum / interval)

        for j in range(ring_size):
            if i < 0:
                break

            # changing pools
            curr_pool = pool_ids[i]
            skip = (curr_pool != prev_pool)
            prev_pool = curr_pool

            if skip:
                # moving to a new ring if pool won't fit
                if (j > 0) and (pool_lens[curr_pool] > (ring_size)-(j+1)):
                    break

                # else staying in current ring and moving a space
                else:
                    continue

            # placing node (slot 0 sits at 270 degrees, i.e. the bottom of the ring)
            angle = 270 + (j * (360/ring_size))
            x = center_coord[0] + (np.cos(np.radians(angle)) * radius)
            y = center_coord[1] + (np.sin(np.radians(angle)) * radius)
            pos[nodes[i]] = [x,y]
            i -= 1

        radius += radius_step

    return pos

def init_graph(df, hidden_state = None):
    '''
    Build the starting-state network for a table of instances.

    Every distinct value of every column becomes a node "<column>: <value>"
    in the pool named after the column, and every distinct value of the
    `hidden_state` column additionally becomes a node "hidden: <value>" in
    the pool 'hidden'. Edges (all with weight 0) connect
    - each value node to the hidden units of the rows that hold that value,
    - all pairs of nodes within the same pool, and
    - all pairs of hidden units.
    Rows containing nulls are dropped first.

    Arguments
    - df: DataFrame
          data to plot
    - hidden_state: string
          column to use for hidden nodes (defaults to the first column)
    Returns
    - pos: dict
          mapping from node to position
    - pools: dict
          mapping from node to pool ID
    - G: networkx Graph object
          graph of network (nodes and edges)
    '''
    df = df.dropna()
    assert len(df) > 0, f"DataFrame currently has shape {df.shape}. Must have more than 0 non-null rows."

    # default hidden state (first column)
    if hidden_state is None:
        hidden_state = df.columns[0]

    pools = {}
    # raw values behind each node ID, per column, in first-seen order
    column_values = {}
    for c in df.columns:
        members = {}
        for value in df[c].unique():
            # make unique column/node ID in case columns have identical entries
            node_id = f"{c}: {value}"
            pools[node_id] = c
            members.setdefault(node_id, []).append(value)
        column_values[c] = members
    for hidden_node in df[hidden_state].unique():
        pools[f"hidden: {hidden_node}"] = 'hidden'

    # getting nodes
    nodes = pools.keys()

    # getting edges
    edges = []
    # pool nodes only have edges to hidden units and within-pool nodes
    hidden_nodes = [k for k,v in pools.items() if v == 'hidden']
    for c in df.columns:
        pool_nodes = [k for k,v in pools.items() if v == c]

        # hidden unit connections
        # Rows are matched on the raw column value, not on text parsed back out
        # of the node ID, so numeric columns and names/values containing spaces
        # get their hidden-unit edges too.
        for n, raw_values in column_values[c].items():
            hidden_connections = df.loc[df[c].isin(raw_values), hidden_state].to_list()
            edges.extend([(f'hidden: {h}', n, 0) for h in hidden_connections])

        # within-pool connections
        edges.extend([(u, v, 0) for u, v in combinations(pool_nodes, 2)])

    # within-pool hidden unit connections
    edges.extend([(u, v, 0) for u, v in combinations(hidden_nodes, 2)])

    # creating graph
    G = nx.Graph()
    G.add_nodes_from(nodes)
    for u, v, weight in edges:
        G.add_edge(u, v, weight=weight)

    # positions
    pos = position_nodes(pools)

    return pos, pools, G

In [325]:
# Rebind the notebook-wide pos / pools / G to the DataFrame-based network;
# functions that read these globals (e.g. create_plot) use it from here on.
pos, pools, G = init_graph(test)
# positions come back as [x, y] lists, so convert to tuples to use them as dict keys
pos_map = {tuple(v):k for k,v in pos.items()}

Dropdown Visualization

In [326]:
# Creating dropdown menu to select hover_node
# (the leading None option means "no node selected" and is the initial value)
dropdown = pn.widgets.Select(name="Select Node", options=[None] + list(G.nodes()), value=None)

# Binding dropdown menu to view()
# (the widget's current value is passed to view() as `hover_node`)
interactive_view = pn.bind(view, hover_node=dropdown)

# Show visualization
pn.Column(dropdown, interactive_view).servable()

Click/Hover Visualization

In [313]:
# INITIAL CONFIG
# Module-level state shared by the hover / unhover / click / reset handlers.
# It is derived from the globals G and pools as they are when this cell runs.
 
# node color palette (can change to another palette)
col_pal = px.colors.qualitative.Plotly
    
# getting node sizes, node colors, connection linewidths
activations, rest_act = get_activation(G)
sizes = activation_to_size(activations)
# one palette entry per pool, looked up by the pool's position in the set of pool names
colors = [col_pal[list(set(pools.values())).index(pools[n])] for n in G.nodes()]
linewidths = connection_to_lw(G.edges(data=True))
    
# setting up node outline colors
outline_colors = colors.copy()

# setting changeable colors (will retain clicked color state)
curr_colors = colors.copy()

# initializing click storage
# (one flag per node, in G.nodes() order)
clicked = [False]*len(G.nodes())

In [314]:
# Interactive figure; trace order matters to the handlers:
# index 0 = nodes, 1-2 = legend entries, 3+ = edge traces added on hover.
fig = go.FigureWidget()

# Label = node ID with its "<pool>: " prefix removed. Splitting on spaces
# mislabels nodes whose pool name contains a space (e.g. "Unnamed: 0: 2").
node_labels = [n[len(f"{pools[n]}: "):] if n.startswith(f"{pools[n]}: ") else n
               for n in G.nodes()]

# Adding Nodes
node_trace = go.Scatter(
        x=[pos[n][0] for n in G.nodes()],
        y=[pos[n][1] for n in G.nodes()],
        text=node_labels,
        marker=dict(size=sizes,
                    color=colors,
                    line=dict(width=2, color=outline_colors)
                   ),
        hoverinfo="text",
        hovertext=[f"Activation: {activations[n]}<br>Net Input: {activations[n]-rest_act}" for n in G.nodes()],
        mode="markers+text",
        showlegend=False)
fig.add_trace(node_trace)

# Adding Legend Labels (dummy traces)
fig.add_trace(go.Scatter(
        x=[None], y=[None],
        line=dict(width=2, color="blue"),
        mode="lines",
        name="excitatory",
        showlegend=True
))
fig.add_trace(go.Scatter(
        x=[None], y=[None],
        line=dict(width=2, color="red"),
        mode="lines",
        name="inhibitory",
        showlegend=True
))

# removing grid and axes
fig.update_layout(
        showlegend=True,
        hovermode="closest",
        xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
        yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
        width=700,
        height=700,
        plot_bgcolor="white",
        title = "Interactive Activation and Competition Network",
)

# Reset button functionality (make all colors default)
def reset_plot(button):
    """Button callback: hide all edge traces and clear every click selection.

    Rebinds the globals ``curr_colors`` and ``clicked`` to fresh defaults.
    """
    global curr_colors, clicked
    with fig.batch_update():
        # traces after (nodes, legend1, legend2) are the hover edge traces
        for edge_trace in fig.data[3:]:
            edge_trace.visible = False
        fig.data[0].marker.color = colors
        curr_colors = list(colors)
        clicked = [False for _ in range(len(G.nodes()))]
        
# Register the interaction handlers on the node trace (trace 0) only
fig.data[0].on_hover(hover_node)
fig.data[0].on_unhover(unhover_node)
fig.data[0].on_click(click_node)

# Create reset button
reset_button = widgets.Button(description="Reset Plot", button_style="primary")
reset_button.on_click(reset_plot)

# Display the plot and button
display(reset_button, fig)

Button(button_style='primary', description='Reset Plot', style=ButtonStyle())

FigureWidget({
    'data': [{'hoverinfo': 'text',
              'hovertext': [Activation: -0.1<br>Net Input: 0.0, Activation:
                            -0.1<br>Net Input: 0.0, Activation: -0.1<br>Net Input:
                            0.0, Activation: -0.1<br>Net Input: 0.0, Activation:
                            -0.1<br>Net Input: 0.0, Activation: -0.1<br>Net Input:
                            0.0, Activation: -0.1<br>Net Input: 0.0, Activation:
                            -0.1<br>Net Input: 0.0, Activation: -0.1<br>Net Input:
                            0.0, Activation: -0.1<br>Net Input: 0.0, Activation:
                            -0.1<br>Net Input: 0.0, Activation: -0.1<br>Net Input:
                            0.0, Activation: -0.1<br>Net Input: 0.0, Activation:
                            -0.1<br>Net Input: 0.0, Activation: -0.1<br>Net Input:
                            0.0, Activation: -0.1<br>Net Input: 0.0, Activation:
                            -0.1<br>Net Input: 0.0, Ac

In [327]:
# setting up Dash app
app = dash.Dash(__name__)

# App layout: the network figure, a reset button and a store for the selection.
# Style keys are written in camelCase throughout (as `borderRadius` already
# was), which is the form Dash documents for style dictionaries.
app.layout = html.Div([
    dcc.Graph(
        id = 'network-graph', # figure
        figure = create_plot(),
        style = {'marginTop': '50px'}
    ),
    html.Button( # reset button (Dash not compatible with ipywidgets, need html Button instead)
        id = 'reset-button',
        children = 'Reset Network',
        style = {
            'position': 'absolute',
            'top': '20px',
            'left': '20px',
            'backgroundColor': 'skyblue',
            'color': 'white',
            'border': 'none',
            'borderRadius': '5px',
            'padding': '5px 5px',
            'cursor': 'pointer'
        }
    ),
    dcc.Store(
        id = 'selected-nodes', # place to store selected nodes
        data = []
    ),
])

@app.callback(
     [Output('network-graph', 'figure'),   # Output for figure
      Output('selected-nodes', 'data')],   # Output for node storage
     [Input('network-graph', 'hoverData'), # Input for hover event
      Input('network-graph', 'clickData'), # Input for click event
      Input('reset-button', 'n_clicks')],  # Input for reset button click
     [State('selected-nodes', 'data'),     # State of selected-nodes
      State('network-graph', 'figure')],   # State of figure
)
def fig_update(hoverData, clickData, n_clicks, data, figure):
    '''
    Redraw the network after a hover, a click or a press of the reset button.

    - hover: the figure is rebuilt with the hovered node highlighted
    - click: the clicked node is toggled in the list of selected nodes
    - reset: the plain figure is restored and the selection is emptied
    Selected nodes are always painted yellow on the returned figure.

    The node is identified from the payload of the event that actually fired
    (clickData for clicks, hoverData for hovers). Events without a usable
    point, or at a position that is not in `pos_map`, leave the figure and
    the selection untouched.

    returns: fig, clicked_nodes
    '''
    fig = go.Figure(figure)
    # work on a private list; also tolerates an empty/None store
    data = list(data or [])
    node_index = {n: i for i, n in enumerate(G.nodes())}

    def paint(target, index, color):
        # set the fill colour of node `index` on every marker trace of `target`
        for trace in target.data:
            if 'marker' in trace and trace.mode and 'markers' in trace.mode:
                curr_colors = list(trace.marker.color)
                curr_colors[index] = color
                trace.marker.color = curr_colors

    triggered = dash.callback_context.triggered
    trig_event = triggered[0]['prop_id'] if triggered else ''

    if trig_event == 'reset-button.n_clicks':
        return create_plot(), []

    # pick the payload that belongs to the triggering event
    payload = {
        'network-graph.hoverData': hoverData,
        'network-graph.clickData': clickData,
    }.get(trig_event)

    selected_node = None
    if payload and payload.get('points'):
        # Identify selected node (using location in case node 'text' is not unique)
        point = payload['points'][0]
        selected_node = pos_map.get((point.get('x'), point.get('y')))

    if selected_node in node_index:
        if trig_event == 'network-graph.hoverData':
            fig = create_plot(hover_node=selected_node)
        elif selected_node in data:
            # toggle back to normal if already selected
            data.remove(selected_node)
            # resetting color back to the pool color
            pool_order = list(set(pools.values()))
            orig_color = col_pal[pool_order.index(pools[selected_node]) % len(col_pal)]
            paint(fig, node_index[selected_node], orig_color)
        else:
            # else add node to selected-nodes
            data.append(selected_node)

    for clicked_node in data:
        # ignore stale entries that are not nodes of the current graph
        if clicked_node in node_index:
            paint(fig, node_index[clicked_node], 'yellow')

    return fig, data

# Start serving `app` with the default arguments.
# NOTE(review): newer Dash releases document `app.run()` as the entry point;
# confirm `run_server` is still supported by the installed Dash version.
app.run_server()