In [1]:
#pip3 install notebook ipywidgets plotly
#jupyter nbextension enable --py widgetsnbextension
#jupyter nbextension enable --py plotlywidget

In [2]:
import igraph as ig
import pandas as pd
import webbrowser
import json
import urllib.request
import chart_studio.plotly as py
import plotly.graph_objs as go
from ipywidgets import widgets

from utils import *

import sys
sys.path.append("..")
from predict import *
from config import *

import os
parent_dir = os.path.dirname(os.getcwd())

This visualization is inspired from plotly officiel tutorials: https://plot.ly/python/v3/3d-network-graph/

In [3]:
#Load data for the graph
df_node = pd.read_csv(os.path.join(parent_dir,DF_NODE_FILENAME)).drop('Unnamed: 0',axis=1)
df_edge = pd.read_csv(os.path.join(parent_dir,DF_EDGE_FILENAME)).drop('Unnamed: 0',axis=1)

In [4]:
df_node.head()

Unnamed: 0,name,url
0,space research,https://en.wikipedia.org/wiki/space_research
1,space exploration,https://en.wikipedia.org/wiki/space_exploration
2,space race,https://en.wikipedia.org/wiki/space_race
3,space probe,https://en.wikipedia.org/wiki/space_probe
4,u.s. space exploration history on u.s. stamps,https://en.wikipedia.org/wiki/u.s._space_explo...


In [5]:
df_edge.head()

Unnamed: 0,source,target
0,0,431
1,0,432
2,0,429
3,0,82
4,0,433


In [6]:
# Create placeholders for graph information
labels = df_node['name'].tolist()
color_node = ['rgba(128,128,128,0.01)' for i in range(len(labels))]
color_node_original = color_node.copy()
texts_to_show = [None for i in range(len(labels))]
texts_to_show_original = texts_to_show.copy()
size_node = [15 for i in range(len(labels))]
size_node_original = size_node.copy()
color_edge = ['rgba(128,128,128,0.7)' for _ in range(3*len(df_edge))]
color_edge_original = color_edge.copy()

# Create plotting items
# title, annotations, textbox_query, selector, query_answer, g = create_plot_items(df_node,df_edge,labels,color_node,texts_to_show,color_edge)

In [7]:
# import pickle
# with open('graph,pkl','wb') as f:
#     pickle.dump(g,f)

In [8]:
# Load graph and models
title, annotations, textbox_query, selector, query_answer, g = load_graph('graph,pkl')
bert_embedder_spectral, bert_embedder_mean, node2vec_embed, df_node = load_models(parent_dir, bert_mean_filename = 'bert_title.npy')

In [9]:
node2vec_embed.wv.vocab

{'lockheed star clipper': <gensim.models.keyedvectors.Vocab at 0x7f7afeb9c2e8>,
 'space shuttle': <gensim.models.keyedvectors.Vocab at 0x7f7ac63972b0>,
 'criticism of the space shuttle program': <gensim.models.keyedvectors.Vocab at 0x7f7ac63970b8>,
 'space accidents and incidents': <gensim.models.keyedvectors.Vocab at 0x7f7ac63971d0>,
 'dream chaser': <gensim.models.keyedvectors.Vocab at 0x7f7ac63972e8>,
 'hermes (shuttle)': <gensim.models.keyedvectors.Vocab at 0x7f7ac6397320>,
 'eads phoenix': <gensim.models.keyedvectors.Vocab at 0x7f7ac6397358>,
 'hope-x': <gensim.models.keyedvectors.Vocab at 0x7f7ac63973c8>,
 'buran (spacecraft)': <gensim.models.keyedvectors.Vocab at 0x7f7ac6397400>,
 'robotic spacecraft': <gensim.models.keyedvectors.Vocab at 0x7f7ac6397240>,
 'pioneer 10': <gensim.models.keyedvectors.Vocab at 0x7f7ac6397128>,
 'juno (spacecraft)': <gensim.models.keyedvectors.Vocab at 0x7f7ac6397278>,
 'cassini–huygens': <gensim.models.keyedvectors.Vocab at 0x7f7ac6397198>,
 'pionee

In [10]:
def find_index_edges(nodes_ls):
    edges_idx = []
    for source in nodes_ls:
        for target in nodes_ls:
            filtered = df_edge[((df_edge.source == source) & (df_edge.target == target) ) | ( (df_edge.source == target) & (df_edge.target == source) )]
            if len(filtered)!=0:
                edges_idx.append(filtered.index.values[0])
                
    return edges_idx

In [11]:
current_selector = selector.value

# Choose the model
def response(change):
    global current_selector
    current_selector = selector.value
    make_query(textbox_query)
    
# On a new query, compute the predictions and color the nodes accordingly
def make_query(user_query):
    global current_selector
    global color_node_original
    global texts_to_show_original
    global size_node_original
    global color_edge_original
    texts_to_show = texts_to_show_original.copy()
    color_node = color_node_original.copy()
    size_node = size_node_original.copy()
    color_edge = color_edge_original.copy()
    
    query = user_query.value
    if query != "":
        preds = make_prediction(query, current_selector, bert_embedder_spectral, bert_embedder_mean, node2vec_embed, df_node)
        #preds = {10:1,11:0.8,12:0.6,13:0.4,14:0.3}
        if preds == None:
            html_text = create_text([],df_node)
            query_answer.value = html_text
        else:
            dict_colors = compute_color(preds)

            html_text = create_text(list(dict_colors.keys()),df_node)
            query_answer.value = html_text

            for k,v in dict_colors.items():
                size_node[k] = 30
                color_node[k] = v
                texts_to_show[k] = labels[k]
            
            edges_idx = find_index_edges(list(dict_colors.keys()))
            for idx in edges_idx:
                color_edge[3*idx] = 'rgb(255,0,0,1)'
            

        with g.batch_update():
            g.data[1].marker.size = size_node
            g.data[1].marker.color = color_node
            g.data[1].text = texts_to_show
            g.data[0].line.color = color_edge
            
# Open url when clicking on node
def update_point(trace, points, selector):
    if len(points.point_inds) != 0:
        url = g.data[1].customdata[points.point_inds[0]]
        webbrowser.open_new_tab(url)
    
textbox_query.on_submit(make_query)
selector.observe(response, names="value")
g.data[1].on_click(update_point)

with g.batch_update():
    g.data[1].marker.color = color_node_original.copy()
    g.data[1].marker.size = size_node_original.copy()
    g.data[1].text = texts_to_show_original.copy()
    g.data[0].line.color = color_edge_original.copy()

In [12]:
container = widgets.HBox(children=[textbox_query, selector])
widgets.VBox([title, annotations, container, query_answer, g])

VBox(children=(HTML(value='<h3> Wikipedia Recommender System </h3>'), HTML(value='<h4> By clicking on a node, …