In [None]:
import pandas as pd
import networkx as nx
import json
from pathlib import Path
from unidecode import unidecode
from networkx.drawing.nx_pydot import graphviz_layout

prefix = 'data/umimematikucz-system_'


def _read_table(name):
    """Read one ';'-separated export table (file `prefix + name`), first column as index."""
    return pd.read_csv(Path(prefix + name), sep=';', header=0, index_col=0)


kc = _read_table("kc.csv")
follow = _read_table("kc_follow.csv")
# ps = _read_table("ps.csv")  # not loaded; currently unused

# `size` graph attribute given to the DiGraphs built in this notebook.
SIZE = (100, 100)
def recursively_change_shape(node, shape, graph=None):
    """Set the ``'shape'`` attribute on ``node`` and on every node that can reach it.

    Follows predecessor edges transitively. The hierarchy edges built below point
    from a KC to its parent, so starting at a KC covers its whole subtree.

    Args:
        node: start node; must exist in the graph.
        shape: value written to each visited node's ``'shape'`` attribute.
        graph: graph to modify; defaults to the module-level ``G``.

    The name is kept for existing callers, but the walk is iterative. An explicit
    stack plus a visited set means cycles terminate. Nodes reachable through
    several paths are updated once, and deep hierarchies cannot exceed Python's
    recursion limit.
    """
    graph = G if graph is None else graph
    stack = [node]
    seen = set()
    while stack:
        current = stack.pop()
        if current in seen:
            continue
        seen.add(current)
        graph.nodes[current]['shape'] = shape
        stack.extend(graph.predecessors(current))

# Cytoscape node shapes used to tell the top-level subtrees apart.
shapes = ['diamond', 'triangle', 'rectangle', 'ellipse', 'octagon', 'v', 'hexagon', 'parallelogram', 'round-rectangle']
G = nx.DiGraph(size=SIZE)

# Node 0 is a synthetic root. Every row of `kc` becomes a node labelled with its
# ASCII-transliterated name. Labels carry literal double quotes, presumably so
# pydot writes them as quoted DOT strings -- TODO confirm.
G.add_node(0, label='"Root"', color='black', shape='ellipse')
G.add_nodes_from([(int(idx), {'label': f'"{unidecode(name)}"', 'color': 'black', 'shape': 'ellipse'}) for idx, name in zip(kc.index, kc['name'])])
# Hierarchy edges (black) point from each KC to its parent.
G.add_edges_from([(int(idx), int(parent)) for idx, parent in zip(kc.index, kc['parent'])], color='black')

# Give each subtree hanging off the root its own shape. The list is reused
# cyclically: zip() would stop after len(shapes) subtrees and silently leave
# any further subtrees with the default 'ellipse'.
for i, idx in enumerate(G.predecessors(0)):
    recursively_change_shape(idx, shapes[i % len(shapes)])

# "Follow" edges (silver) are added only after the shapes are assigned, so the
# shape propagation above walks hierarchy edges only.
G.add_edges_from([(int(first), int(second)) for first, second in follow[['kc1', 'kc2']].to_numpy()], color='silver')
pos = graphviz_layout(G, prog="dot")

In [None]:
def is_element_edge(element):
    """Tell whether a Cytoscape element is an edge, i.e. its ``data`` has a ``'source'``."""
    payload = element['data']
    return 'source' in payload

def graph_index_from_elements(element):
    """Recover the integer graph index of a Cytoscape node.

    Accepts either a full element (``{'data': {'id': 'n12', ...}}``) or a bare
    data dict (``{'id': 'n12', ...}``). The one-character id prefix is stripped
    before ``int()``.
    """
    data = element['data'] if 'data' in element else element
    return int(data['id'][1:])

def edge_pair_from_elements(element):
    """Return the ``(source, target)`` integer graph indices of an edge element.

    Endpoint ids such as ``'n3'`` lose their one-character prefix before ``int()``.
    """
    data = element['data']
    return tuple(int(data[end][1:]) for end in ('source', 'target'))

def garph_index_to_elements(idx, edge=False):
    """Build a Cytoscape element id for a graph index: ``'e<idx>'`` for edges, ``'n<idx>'`` otherwise.

    The misspelt name is kept because other cells call it by this name.
    """
    kind = 'e' if edge else 'n'
    return f'{kind}{idx}'


def graph_edge_as_elements(source, target, params):
    """Build a Cytoscape edge element between two graph indices.

    Only ``params['color']`` is copied. No ``id`` is set on the edge.
    """
    data = {}
    data['source'] = garph_index_to_elements(source)
    data['target'] = garph_index_to_elements(target)
    data['color'] = params['color']
    return {'data': data}

def graph_as_elements(G, pos):
    """Convert a networkx graph plus layout positions into Cytoscape elements.

    Args:
        G: graph whose node attributes are copied into each node element's ``data``.
        pos: mapping node -> coordinates; indices ``[0]`` and ``[1]`` become x and y
            (e.g. the result of ``graphviz_layout``).

    Returns:
        list of node elements (each with ``data`` and ``position``) followed by
        edge elements.

    Nodes lacking ``label`` or ``color`` get the node index as label and
    ``'black'`` as colour. Edges lacking ``color`` get ``'black'``. Without these
    defaults, such nodes or edges (for example, nodes created implicitly by
    ``add_edges_from``) raised ``KeyError``. Attributes present on the node or
    edge always override the defaults.
    """
    nodes = []
    for idx, params in G.nodes(data=True):
        defaults = {
            'id': garph_index_to_elements(idx),
            'label': str(idx),
            'color': 'black',
        }
        nodes.append({
            'data': defaults | params,
            'position': {'x': pos[idx][0], 'y': pos[idx][1]},
        })
    edges = [graph_edge_as_elements(s, t, {'color': 'black'} | params)
             for s, t, params in G.edges(data=True)]
    return nodes + edges

def graph_from_elements(elements):
    """Rebuild a networkx DiGraph from a list of Cytoscape elements.

    Node elements (no ``'source'`` in ``data``) become nodes whose attributes are
    the element's full ``data`` dict, including its ``'id'``. Edge elements become
    edges that keep their own ``'color'`` (``'black'`` when absent). If the same
    source/target pair appears more than once, the later element wins.

    Edges of every colour are kept. Filtering to ``'black'``/``'silver'`` dropped
    the ``'blue'`` edges that ``update_graph`` adds in edit mode, so they vanished
    on the next re-layout.
    """
    G = nx.DiGraph(size=SIZE)
    G.add_nodes_from(
        (graph_index_from_elements(elem), elem['data'])
        for elem in elements if not is_element_edge(elem)
    )
    G.add_edges_from(
        (*edge_pair_from_elements(elem), {'color': elem['data'].get('color', 'black')})
        for elem in elements if is_element_edge(elem)
    )
    return G

In [None]:
# Sanity check: node 1's attributes after a graph -> elements -> graph round trip.
roundtrip_graph = graph_from_elements(graph_as_elements(G, pos))
roundtrip_graph.nodes[1]

In [None]:
import dash_cytoscape as cyto
import dash_bootstrap_components as dbc
import src.custom_components as cc
from dash import dcc, html, Input, Output
from dash import callback_context as ctx
from jupyter_dash import JupyterDash
from dash.exceptions import PreventUpdate

app = JupyterDash(__name__, external_stylesheets=[dbc.themes.FLATLY])

# Styles for the Cytoscape container and its element stylesheet. The `with`
# block closes the file handle; the old bare open() left it to the garbage
# collector. JSON is read as UTF-8 regardless of the platform default encoding.
with open('styles.json', 'r', encoding='utf-8') as styles_file:
    styles = json.load(styles_file)

# Graphviz programs offered in the layout selector.
options = [{'label': val, 'value': val} for val in ['dot', 'fdp', 'twopi', 'circo']]
tool_panel = dbc.Card(dbc.CardBody(dbc.Row([
    dbc.Col(dbc.Select(
        id='layout_select',
        placeholder=f"Layout ({options[0]['label']})",
        options=options,
    ), width=2),
    dbc.Col(dbc.RadioItems(
        id="mode_btn",
        className="btn-group",
        inputClassName="btn-check",
        labelClassName="btn btn-outline-primary",
        labelCheckedClassName="active",
        # Values are the list positions: 0 = Explore, 1 = Add/Remove edges, 2 = TBD.
        options=[
            {"label": name, "value": val} for val, name in enumerate(['Explore', 'Add/Remove edges', 'TBD'])
        ],
        value=0
    ), width=6),
    dbc.Col(
        dcc.Upload(dbc.Button('Upload Graph'), id='upload_graph'),
        width=2
    ),
    dbc.Col(
        dbc.Button('Download Graph', id='download_btn'),
        width=2
    ),
])))

# The graph starts empty; elements are filled in by the callbacks.
main_graph = dbc.Row(dbc.Col(
    html.Div(cyto.Cytoscape(
        id='cytoscape',
        elements=[],
        style=styles['cytoscape'],
        layout={'name': 'preset', 'directed': True},
        responsive=False,
        autoRefreshLayout=False,
        stylesheet=styles['stylesheet']
    ))
))

# Non-visual helpers: debug output target, download sink, and the element store.
meta = html.Div([
    html.Div(id='test'),
    dcc.Download(id='download'),
    dcc.Store(id='graph_elems'),
])

app.layout = html.Div(dbc.Container([
    tool_panel,
    html.Br(),
    main_graph,
    html.Br(),
    meta,
]))

In [None]:
from base64 import b64decode

@app.callback(
    Output('graph_elems', 'data'),
    Input('upload_graph', 'contents'),
    prevent_initial_call=True,
)
def upload_graph(upload):
    """Fill the ``graph_elems`` store from an uploaded graph file.

    ``upload`` is the ``contents`` of ``dcc.Upload``, a data URL of the form
    ``data:<mime>;base64,<payload>``. If the payload decodes (UTF-8) to a
    non-empty JSON list of element dicts that each have ``'data'`` (the format
    written by ``download_graph``), that list is stored as-is.

    Anything else falls back to the graph built in the first cell. This keeps the
    earlier behaviour, where uploading any file loaded the default graph.
    """
    try:
        _header, payload = upload.split(',', 1)
        uploaded = json.loads(b64decode(payload).decode('utf-8'))
    except (AttributeError, ValueError):
        # None contents, no comma, bad base64, non-UTF-8 or invalid JSON
        # (binascii.Error, UnicodeDecodeError and JSONDecodeError are ValueErrors).
        uploaded = None
    if (isinstance(uploaded, list) and uploaded
            and all(isinstance(elem, dict) and 'data' in elem for elem in uploaded)):
        return uploaded
    return graph_as_elements(G, pos)


from dash import State


@app.callback(
    Output('download', 'data'),
    Input('graph_elems', 'data'),
    Input('download_btn', 'n_clicks'),
    State('cytoscape', 'elements'),
    prevent_initial_call=True,
)
def download_graph(data, _, current=None):
    """Send the graph as ``updated_graph.json`` when the Download button is clicked.

    Args:
        data: contents of the ``graph_elems`` store, which only holds what was uploaded.
        _: ``download_btn.n_clicks``; used only as a trigger.
        current: elements currently shown in ``cytoscape``, read as State.

    The export prefers ``current``, so edits made in the viewer (recolouring,
    added or removed edges) are included. It falls back to ``data`` while the
    viewer is still empty. Changes to ``graph_elems`` also fire this callback;
    those are ignored.
    """
    if ctx.triggered_id != 'download_btn':
        raise PreventUpdate
    exported = current if current else data
    return dcc.send_string(json.dumps(exported), filename='updated_graph.json')


def filter_for_edge(edge):
    """Build a ``filter()`` predicate that rejects elements with ``edge``'s endpoints.

    Node elements, and edges whose ``source`` or ``target`` differs from
    ``edge``'s, yield True (kept). An edge element with the same ``source`` and
    the same ``target`` yields False (dropped).
    """
    def keep(e):
        data = e['data']
        if 'source' not in data or 'target' not in data:
            return True
        same_endpoints = (data['source'] == edge['data']['source']
                          and data['target'] == edge['data']['target'])
        return not same_endpoints
    return keep


@app.callback(
    Output('test', 'children'),
    Input('cytoscape', 'elements'),
)
def test_out(inp):
    """Debug hook on ``cytoscape.elements``; always renders an empty string into ``#test``.

    ``inp`` is ignored.
    """
    rendered = ''
    return rendered


def modify_by_id(elements, id, key, value):
    """Return a new element list in which the element with ``data['id'] == id`` has ``data[key] = value``.

    The updated element is moved to the end of the list. Elements without an
    ``'id'`` in their data are never matched; edges built by
    ``graph_edge_as_elements`` have no id and previously caused a ``KeyError``.
    The matched element is copied rather than mutated, so the caller's list and
    dicts are left untouched.

    Raises:
        RuntimeError: if zero or more than one element carries ``id``.
    """
    def has_id(elem):
        return 'id' in elem['data'] and elem['data']['id'] == id

    elems = [elem for elem in elements if has_id(elem)]
    if len(elems) != 1:
        raise RuntimeError(f'Detected {len(elems)} elements for {id=}. Expected 1.')
    original = elems[0]
    updated = {**original, 'data': {**original['data'], key: value}}
    return [elem for elem in elements if not has_id(elem)] + [updated]

def _relayout(elements, layout):
    """Rebuild a graph from Cytoscape ``elements`` and recompute node positions.

    ``layout`` is the Graphviz program chosen in ``layout_select``. It is ``None``
    until the user picks one; ``'dot'`` (the layout named in the select's
    placeholder) is then used explicitly instead of passing ``prog=None``.
    """
    graph = graph_from_elements(elements)
    position = graphviz_layout(graph, prog=layout or 'dot')
    return graph_as_elements(graph, position)


def _highlight_selection(elements, graph, selected):
    """Recolour node elements in place relative to the selected nodes.

    Selected nodes become red, successors of any selected node orange,
    predecessors yellow, and all other nodes black. Edge elements are untouched.
    Neighbour sets are computed once instead of once per element.
    """
    selected_ids = {s['id'] for s in selected}
    selected_idx = [graph_index_from_elements(s) for s in selected]
    successors = {n for i in selected_idx for n in graph.successors(i)}
    predecessors = {n for i in selected_idx for n in graph.predecessors(i)}
    for elem in elements:
        if is_element_edge(elem):
            continue
        data = elem['data']
        if data['id'] in selected_ids:
            data['color'] = 'red'
            continue
        node_idx = graph_index_from_elements(elem)
        if node_idx in successors:
            data['color'] = 'orange'
        elif node_idx in predecessors:
            data['color'] = 'yellow'
        else:
            data['color'] = 'black'


def _toggle_edge(elements, selected):
    """Remove the edge from the second-to-last to the last selected node, or add it (blue) if absent."""
    edge = graph_edge_as_elements(graph_index_from_elements(selected[-2]),
                                  graph_index_from_elements(selected[-1]),
                                  {'color': 'blue'})
    kept = list(filter(filter_for_edge(edge), elements))
    if len(kept) == len(elements):
        kept.append(edge)
    return kept


@app.callback(
        Output('cytoscape', 'elements'),
        Input('mode_btn', 'value'),
        Input('cytoscape', 'selectedNodeData'),
        Input('cytoscape', 'elements'),
        Input('graph_elems', 'data'),
        Input('layout_select', 'value'),
        prevent_initial_call=True,
)
def update_graph(mode, selected, elements, new_elements, layout):
    """Rewrite ``cytoscape.elements`` depending on what triggered the callback.

    - ``graph_elems`` changed (an upload): lay out the stored elements.
    - ``layout_select`` changed: lay out the currently shown elements again.
    - ``selectedNodeData`` changed: recolour nodes around the selection. In
      mode 1 ("Add/Remove edges"), also toggle the edge between the last two
      selected nodes.

    Any other trigger (``mode_btn``, or ``elements`` echoing back) raises
    ``PreventUpdate``.

    NOTE(review): the toggle runs on every selection change with at least two
    nodes selected, including deselecting a node, which can undo an edge that
    was just added.
    """
    if ctx.triggered_id == 'graph_elems':
        return _relayout(new_elements, layout)

    if ctx.triggered_id == 'layout_select':
        return _relayout(elements, layout)

    if ctx.triggered[0]['prop_id'].split('.')[1] == 'selectedNodeData':
        # selectedNodeData may be None; treat that as "nothing selected".
        selected = selected or []
        graph = graph_from_elements(elements)
        _highlight_selection(elements, graph, selected)
        if mode == 1 and len(selected) >= 2:
            elements = _toggle_edge(elements, selected)
        return elements

    raise PreventUpdate


In [None]:
def _main():
    """Start the JupyterDash server in debug mode."""
    app.run_server(debug=True)


if __name__ == "__main__":
    _main()