# Interactive Plots

In [1]:
import pandas as pd
import numpy as np
from matplotlib.pyplot import figure
import networkx as nx
import matplotlib.pyplot as plt


### Reading Wiki Entity and Relation data with score >= 0.9

In [2]:
# Minimum model confidence for a predicted relation to be kept (inclusive).
score_threshold = 0.9

# Load the predictions, keep only the columns used downstream, drop rows below
# the threshold, and (re)compute the co-occurrence key.  Everything is chained
# so that `wiki_data` is always a fresh frame: assigning a column onto a
# boolean-filtered slice is what raises pandas' SettingWithCopyWarning.
wiki_data = (
    pd.read_csv('../data/model_predictions_wiki_text_0.8_results.csv', index_col=0)
    .filter(items=['source_term', 'target_term', 'relation', 'score', 'coocc_term'], axis=1)
    .loc[lambda d: d.score >= score_threshold]
    # Entity co-occurrence term: a direction-independent key for an entity pair,
    # built as "<larger>-<smaller>" by plain string comparison so that A->B and
    # B->A share the same key.  np.where is vectorized and also copes with an
    # empty frame, which a row-wise apply does not.
    .assign(
        coocc_term=lambda d: np.where(
            d.source_term > d.target_term,
            d.source_term + '-' + d.target_term,
            d.target_term + '-' + d.source_term,
        )
    )
)

wiki_data.head(), len(wiki_data)

(  source_term target_term         relation     score      coocc_term
 0     Machine      Energy    has parameter  0.928493  Machine-Energy
 1      Energy     Machine  is parameter of  0.942900  Machine-Energy
 2       Brake        Axle   interacts with  0.994778      Brake-Axle
 3       Brake    Friction    has parameter  0.973993  Friction-Brake
 4        Axle       Brake   interacts with  0.996486      Brake-Axle,
 23889)

### Entity pair-wise Relation Distribution (directed)

In [3]:
# Collapse identical (pair, direction, relation, score) rows into one row and
# record how many times each occurred in `count`; highest scores first.
wiki_hg_score_grp = (
    wiki_data
    .assign(count=1)  # helper column: one per input row, counted per group below
    .groupby(['coocc_term', 'source_term', 'target_term', 'relation', 'score'])
    .count()
    .reset_index()
    .sort_values(by='score', ascending=False)
)
wiki_hg_score_grp.head()

Unnamed: 0,coocc_term,source_term,target_term,relation,score,count
4495,Lever-Bicycle,Lever,Bicycle,is part of,0.999907,6
4791,Locomotive-Brake shoe,Brake shoe,Locomotive,is part of,0.999903,2
10565,Valve-Train,Valve,Train,is part of,0.999891,2
9561,Train-Cable television,Cable television,Train,is part of,0.999883,2
4490,Lever-Bicycle,Bicycle,Lever,consists of,0.999872,6


In [5]:
# Number of rows in the aggregated high-score table
# (one row per distinct pair / direction / relation / score combination).
wiki_hg_score_grp.shape[0]

12297

### Interactivity through plotly

In [8]:
import plotly.graph_objects as go
import networkx as nx

def get_graph(df):
    """Turn an edge table into a directed multigraph.

    Each row of ``df`` becomes one edge from ``source_term`` to ``target_term``;
    every other column is attached to the edge as an attribute
    (``edge_attr=True``).  A ``MultiDiGraph`` is used so that parallel edges
    between the same two entities are kept rather than merged.
    """
    multigraph = nx.from_pandas_edgelist(
        df,
        source="source_term",
        target="target_term",
        edge_attr=True,
        create_using=nx.MultiDiGraph(),
    )
    return multigraph

In [10]:
def get_entity_vocab(path='../data/wiki_entity_vocab.csv'):
    """Load the entity -> entity-type lookup table.

    Parameters
    ----------
    path : str or path-like, optional
        CSV file with at least the columns ``entity`` and ``type``.
        Defaults to the project's wiki entity vocabulary file.

    Returns
    -------
    dict
        Maps each value of the ``entity`` column to the matching ``type``
        value.  If an entity appears on several rows, the last row wins.
    """
    entity_type_df = pd.read_csv(path)
    # Column-wise zip replaces the row-by-row iterrows loop; bracket access
    # avoids any clash between column names and DataFrame attributes.
    return dict(zip(entity_type_df['entity'], entity_type_df['type']))

In [11]:
def add_entity_types(df, entity_vocab):
    """Return ``df`` with ``source_type`` and ``target_type`` columns added.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain the columns ``source_term`` and ``target_term``.
    entity_vocab : dict
        Entity name -> entity type.

    Returns
    -------
    pandas.DataFrame
        A new frame; the input is left untouched so that passing a filtered
        slice neither mutates the caller's data nor raises
        SettingWithCopyWarning.  Entities missing from ``entity_vocab`` are
        labelled ``'unknown'``.
    """
    def lookup(entity):
        return entity_vocab.get(entity, 'unknown')

    return df.assign(
        source_type=df.source_term.map(lookup),
        target_type=df.target_term.map(lookup),
    )

In [12]:
def create_node_characteristics(df):
    """Summarise every entity as one row with its type and total weight.

    Parameters
    ----------
    df : pandas.DataFrame
        Edge table with the columns ``source_term``, ``source_type``,
        ``target_term`` and ``target_type``.

    Returns
    -------
    pandas.DataFrame
        Indexed by ``entity`` with the columns ``type`` and ``weight``, where
        ``weight`` is the number of rows in which the (entity, type) pair
        appears as a source plus the number in which it appears as a target.
    """
    def _count_terms(term_col, type_col):
        # Rows per (term, type) pair, with columns renamed to a common schema.
        return (
            df.groupby([term_col, type_col])
            .size()
            .reset_index(name='weight')
            .rename(columns={term_col: 'entity', type_col: 'type'})
        )

    src_terms_df = _count_terms('source_term', 'source_type')
    tgt_terms_df = _count_terms('target_term', 'target_type')

    # pd.concat replaces DataFrame.append, which was removed in pandas 2.0.
    terms_df = (
        pd.concat([src_terms_df, tgt_terms_df])
        .groupby(['entity', 'type'])
        .sum()
        .reset_index(level=1)
    )
    return terms_df

In [13]:
def get_node_sizes(df):
    """Compute a display size for every entity from how often it occurs.

    The entity vocabulary is loaded, entity types are attached to ``df`` and
    the per-entity weights are taken from ``create_node_characteristics``.

    Size buckets (by weight):
        >= 3000 -> 5000,  >= 500 -> 3500,  >= 100 -> 2000,  otherwise 1000.

    The top bucket is open-ended.  It used to be ``range(3000, 5000)``, which
    sent the most frequent entities (weight >= 5000) to the *smallest* size.

    Parameters
    ----------
    df : pandas.DataFrame
        Edge table with ``source_term`` and ``target_term`` columns.

    Returns
    -------
    list of int
        One size per row of the node-characteristics table, in that table's
        order.  NOTE(review): that order is not necessarily the order of
        ``graph.nodes()``; realign before using the sizes for plotting.
    """
    vocab = get_entity_vocab()
    ent_type_df = add_entity_types(df, vocab)

    # node characteristics: one row per entity with its total weight
    terms_df = create_node_characteristics(ent_type_df)

    def _size_for(weight):
        # Threshold comparisons instead of `in range(...)`: no accidental
        # upper bound, and no reliance on the weight being an exact integer.
        if weight >= 3000:
            return 5000
        if weight >= 500:
            return 3500
        if weight >= 100:
            return 2000
        return 1000

    # Set node size by weight
    node_sizes = [_size_for(weight) for weight in terms_df.weight]
    return node_sizes

In [14]:
def get_edge_widths(df, graph):
    """Compute a line width for every edge of ``graph``.

    The width depends on how many rows of ``df`` share the edge's
    direction-independent ``coocc_term`` key:

        >= 300 -> 6,  >= 200 -> 3,  >= 100 -> 1.5,  otherwise 0.5.

    The top bucket is open-ended.  It used to be ``range(300, 500)``, which
    gave the most frequently related pairs (>= 500 rows) the thinnest line.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain the columns ``coocc_term`` and ``relation``.
    graph : graph-like
        Object whose ``edges()`` yields tuples starting with
        ``(source, target)``.

    Returns
    -------
    list
        One width per edge, in ``graph.edges()`` order.

    Raises
    ------
    KeyError
        If an edge's pair key does not occur in ``df.coocc_term``.
    """
    # edge width by the relationship weight (count of relations per pair)
    entity_edges = df.groupby('coocc_term')['relation'].count()

    def _width_for(weight):
        if weight >= 300:
            return 6
        if weight >= 200:
            return 3
        if weight >= 100:
            return 1.5
        return 0.5

    edge_widths = []
    for edge in graph.edges():
        src, tgt = edge[0], edge[1]
        # Same "<larger>-<smaller>" key construction as the coocc_term column.
        ent_pair = src + '-' + tgt if src > tgt else tgt + '-' + src
        edge_widths.append(_width_for(entity_edges.loc[ent_pair]))
    return edge_widths

In [15]:
def create_edge(x, y, text, width):
    """Build a Plotly line trace for one edge.

    ``x`` and ``y`` are the coordinate sequences of the line, ``text`` is the
    hover label (wrapped in a one-element list) and ``width`` is the line
    width.  The line is always drawn in sky blue.
    """
    line_style = dict(width=width, color='skyblue')
    hover_labels = [text]
    return go.Scatter(
        x=x,
        y=y,
        line=line_style,
        hoverinfo='text',
        text=hover_labels,
        mode='lines',
    )

In [16]:
def create_edge_trace(graph, pos, edge_widths):
    """Build one Plotly line trace containing every edge of ``graph``.

    For each edge the two endpoint coordinates are looked up in ``pos``
    (node -> (x, y)) and appended to the x / y lists, followed by ``None`` so
    that Plotly breaks the line between consecutive edges.

    ``edge_widths`` is accepted but not used: all edges are drawn with a
    fixed width of 0.5 in grey.
    """
    xs, ys = [], []
    for edge in graph.edges():
        (x_start, y_start), (x_end, y_end) = pos[edge[0]], pos[edge[1]]
        xs.extend((x_start, x_end, None))
        ys.extend((y_start, y_end, None))

    return go.Scatter(
        x=xs,
        y=ys,
        line=dict(width=0.5, color='#888'),
        hoverinfo='none',
        mode='lines',
    )

In [17]:
def create_node_trace(graph, pos, node_sizes):
    """Build the Plotly marker trace that draws the nodes of ``graph``.

    Parameters
    ----------
    graph : graph-like
        Must provide ``nodes()`` and ``adjacency()`` (as networkx graphs do).
    pos : mapping
        node -> (x, y) layout coordinates.
    node_sizes :
        Accepted but not used; every marker is drawn at size 10.

    Returns
    -------
    plotly.graph_objects.Scatter
        One marker per node.  The hover text is the node itself and the marker
        colour is the number of entries in that node's adjacency mapping
        (shown on the colour bar as 'Degree of Nodes').
    """
    node_x = []
    node_y = []
    node_text = []
    for node in graph.nodes():
        x, y = pos[node]
        node_x.append(x)
        node_y.append(y)
        node_text.append(node)

    # Colour values come from the `graph` argument.  The previous version read
    # a notebook-level variable `G`, which fails on a fresh kernel and can
    # silently colour the plot using a different (stale) graph.
    node_adjacencies = [len(neighbours) for _, neighbours in graph.adjacency()]

    node_trace = go.Scatter(
        x=node_x, y=node_y,
        mode='markers',
        hoverinfo='text',
        marker=dict(
            showscale=True,
            # colorscale options
            #'Greys' | 'YlGnBu' | 'Greens' | 'YlOrRd' | 'Bluered' | 'RdBu' |
            #'Reds' | 'Blues' | 'Picnic' | 'Rainbow' | 'Portland' | 'Jet' |
            #'Hot' | 'Blackbody' | 'Earth' | 'Electric' | 'Viridis' |
            colorscale='YlGnBu',
            reversescale=True,
            color=[],
            size=10,
            colorbar=dict(
                thickness=15,
                # Nested title dict replaces the deprecated `titleside`
                # attribute, which current Plotly releases no longer accept.
                title=dict(text='Degree of Nodes', side='right'),
                xanchor='left'
            ),
            line_width=2))

    node_trace.marker.color = node_adjacencies
    node_trace.text = node_text
    return node_trace

In [18]:
def plot_graph(edge_trace, node_trace,
               title='<br>Knowledge Graph for Brake System in Automotive domain'):
    """Assemble the edge and node traces into a figure and display it.

    Parameters
    ----------
    edge_trace, node_trace :
        Plotly traces; the edge trace is drawn first so nodes sit on top.
    title : str, optional
        Figure title (HTML allowed).  Defaults to the original hard-coded
        title, so existing two-argument calls behave as before.

    Returns
    -------
    None
        The figure is shown via ``fig.show()``.
    """
    # Both axes are purely positional for a network layout: hide everything.
    hidden_axis = dict(showgrid=False, zeroline=False, showticklabels=False)

    layout = go.Layout(
        # Nested title dict replaces the deprecated `titlefont_size`
        # shortcut, which current Plotly releases no longer accept.
        title=dict(text=title, font=dict(size=16)),
        showlegend=False,
        hovermode='closest',
        margin=dict(b=20, l=5, r=5, t=40),
        annotations=[dict(
            text="MADS CAPSTONE Project (Team Connect)",
            showarrow=False,
            xref="paper", yref="paper",
            x=0.005, y=-0.002)],
        xaxis=hidden_axis,
        yaxis=hidden_axis,
    )
    fig = go.Figure(data=[edge_trace, node_trace], layout=layout)
    fig.show()

### Plotting Interactive Network Graph

In [25]:
def draw(df, seed=42):
    """Render the edge table ``df`` as an interactive network graph.

    Parameters
    ----------
    df : pandas.DataFrame
        Edge table with ``source_term`` / ``target_term`` columns.
    seed : int or None, optional
        Random seed passed to the spring layout so the same data always
        produces the same picture.  Pass ``None`` for a fresh random layout on
        every call (the previous behaviour).

    Returns
    -------
    None
        The figure is displayed as a side effect.
    """
    G = get_graph(df)
    # The spring layout starts from random positions; seeding it keeps the
    # figure reproducible across kernel restarts.
    pos = nx.spring_layout(G, k=0.8, seed=seed)

    # edges (widths are not wired in yet, so a fixed width is used)
    edge_trace = create_edge_trace(G, pos, edge_widths=None)

    # nodes (sizes are not wired in yet, so a fixed size is used)
    node_trace = create_node_trace(G, pos, node_sizes=None)

    # plot
    plot_graph(edge_trace, node_trace)

In [27]:
import plotly.graph_objects as go
import networkx as nx

# Full graph: every aggregated high-score relation
dd = wiki_hg_score_grp.copy()

# Subset 1: parameters of the Brake ('has parameter' edges leaving Brake)
dd1 = wiki_hg_score_grp.loc[
    wiki_hg_score_grp['relation'].eq('has parameter')
    & wiki_hg_score_grp['source_term'].eq('Brake')
]

# Subset 2: structural elements of the Brake ('consists of' edges leaving Brake)
dd2 = wiki_hg_score_grp.loc[
    wiki_hg_score_grp['relation'].eq('consists of')
    & wiki_hg_score_grp['source_term'].eq('Brake')
]

# Only the full graph is plotted here
draw(dd)


In [22]:
# Number of edges and nodes in the network graph.
# `draw` keeps its graph in a local variable, so build the graph explicitly
# here instead of relying on a leftover notebook-level `G`.
G = get_graph(dd)
len(G.edges()), len(G.nodes())

(12297, 1499)