In [1]:
import bokeh as bk
from bokeh.io import show
import networkx
from networkx.drawing.nx_agraph import graphviz_layout
from bokeh.models import ColumnDataSource
import bokeh.plotting 
from bokeh.models.graphs import from_networkx, NodesAndLinkedEdges, EdgesAndLinkedNodes, StaticLayoutProvider
from bokeh.models.tools import *
from bokeh.palettes import Spectral4, Spectral8, Paired8
import numpy as np
import pandas as pd
import pickle

In [2]:
# Load the preprocessed lookup tables and the raw paper table.
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only load pickles produced by this project's own preprocessing.
def _load_pickle(path):
    """Unpickle and return the object stored at ``path``, closing the file afterwards."""
    with open(path, 'rb') as handle:
        return pickle.load(handle)


users_cleaned = _load_pickle('./pickles/paper_auth_name.pickle')  # passed to author_graph as papers_user_map
user_id_map = _load_pickle('./pickles/user_id_maps.pickle')
user_map = _load_pickle('./pickles/user_mappings.pickle')  # passed to author_graph as user_papers_map
data_original = pd.read_csv('./pickles/data_pd.csv')

In [3]:
def author_graph(author, user_papers_map, papers_user_map, user_id_map, limit = 1000, depth = 5, cap_per_paper = 5):
    """Crawl the co-authorship network breadth-first, starting from ``author``.

    Work items are ``(paper_id, author_id)`` pairs: the papers of every
    discovered author. Visiting a pair looks at up to ``cap_per_paper`` of the
    paper's listed authors and records any not seen before. Crawling stops
    when ``limit`` authors are collected, when BFS level ``depth`` has been
    fully expanded, or when nothing is left to explore.

    Parameters
    ----------
    author : str
        Name of the starting author, matched case-insensitively.
    user_papers_map : mapping
        Lower-cased author name -> iterable of paper ids.
    papers_user_map : mapping
        Paper id -> sequence of author names listed on that paper.
    user_id_map : sequence
        ``user_id_map[0]`` maps lower-cased author name -> author id.
    limit : int
        Maximum number of authors to collect, including ``author`` itself.
    depth : int
        Last BFS level whose authors' papers are expanded (level 0 is the
        starting author), so returned hop counts are at most ``depth + 1``.
    cap_per_paper : int
        Maximum number of a paper's listed authors to consider.

    Returns
    -------
    dict or None
        Author id -> ``(paper_id, source_author_id, hops)``. ``paper_id`` is
        the paper through which the author was first reached (``-1`` for the
        starting author), ``source_author_id`` is the author whose paper that
        was, and ``hops`` is the BFS level at which the author was found.
        Returns ``None`` (after printing a message) if ``author`` is unknown.
    """
    from collections import deque

    level_end = '--'  # sentinel separating one BFS level from the next in the queue
    name_to_id = user_id_map[0]
    root_name = author.lower()
    if root_name not in name_to_id:
        print('Author', author, 'not found')
        return None

    root_id = name_to_id[root_name]
    auths = {root_id: (-1, root_id, 0)}
    to_do_q = deque((paper, root_id) for paper in user_papers_map[root_name])
    to_do_q.append(level_end)
    done = set()
    cur_depth = 0
    while to_do_q and len(auths) < limit and cur_depth <= depth:
        item = to_do_q.popleft()
        if item == level_end:
            if not to_do_q:
                break  # the previous level produced nothing new to expand
            cur_depth += 1
            print('Current depth:', cur_depth)
            to_do_q.append(level_end)
            # Go back to the loop test so nothing past ``depth`` is expanded.
            continue
        if item in done:
            continue
        done.add(item)
        paper, source_id = item
        colleagues = papers_user_map[paper]
        if len(colleagues) > cap_per_paper:
            colleagues = colleagues[:cap_per_paper]
        for colleague in colleagues:
            if len(auths) >= limit:
                break
            colleague_name = colleague.lower()
            colleague_id = name_to_id[colleague_name]
            if colleague_id in auths:
                continue  # already recorded; its papers were queued when it was found
            auths[colleague_id] = (paper, source_id, cur_depth + 1)
            to_do_q.extend((p, colleague_id) for p in user_papers_map[colleague_name])
    if cur_depth > depth: print('Max depth reached')
    if len(auths) >= limit: print('Max limit reached')
    return auths

In [4]:
# Crawl co-authors of the seed author: at most 100 authors, BFS depth 4.
auths = author_graph('Yann Lecun', user_map, users_cleaned, user_id_map, limit=100, depth=4)
# author_graph returns None for an unknown name; fail here with a clear message
# rather than with an AttributeError in the graph-building cell.
assert auths is not None, 'seed author not found; check the name passed to author_graph'

Current depth: 1
Max limit reached


In [5]:
# Configure bokeh so that show() renders plots inline in this notebook.
bokeh.io.output_notebook()

In [6]:
# Flatten the crawl result into per-node lists (all in `keys` order) plus a
# directed graph with an edge from each source author to the author it led to.
keys = list(auths.keys())
names = [user_id_map[1][idx].title() for idx in keys]
colours = [0] * len(keys)
papers = [0] * len(keys)
search = ""
root_node = 0
G = networkx.DiGraph()
G.add_nodes_from(keys)
for idx, key in enumerate(keys):
    paper_row, source_author, hops = auths[key]
    if paper_row == -1:
        # The seed author: no incoming edge, labelled with the search query.
        root_node = key
        papers[idx] = 'Input search query'
        search = names[idx]
    else:
        G.add_edge(source_author, key)
        papers[idx] = data_original.iloc[paper_row, :]['title']
    # Cycle through the palette so hop counts >= len(Paired8) don't raise IndexError.
    colours[idx] = Paired8[hops % len(Paired8)]

In [7]:
# Radial ("twopi") layout centred on the seed author.
pos = graphviz_layout(G, prog='twopi', root = root_node)
AXIS_PADDING = 20  # margin, in layout units, around the outermost nodes
xs = [x for x, _ in pos.values()]
ys = [y for _, y in pos.values()]
min_x, max_x = min(xs), max(xs)
min_y, max_y = min(ys), max(ys)

hover = bokeh.models.tools.HoverTool(tooltips=[("Author", "@author"), ("Source", "@source")])
plot = bokeh.plotting.figure(
    plot_width=1000,
    plot_height=1000,
    x_range=bokeh.models.ranges.Range1d(min_x - AXIS_PADDING, max_x + AXIS_PADDING),
    y_range=bokeh.models.ranges.Range1d(min_y - AXIS_PADDING, max_y + AXIS_PADDING),
    tools=[hover, bokeh.models.tools.PanTool(), bokeh.models.tools.WheelZoomTool(),
           bokeh.models.tools.TapTool(), bokeh.models.tools.ResetTool()],
    title="Network Graph for " + search.title(),
)

# Plain lists (not networkx view objects) for the bokeh data sources.
nodes = list(G.nodes())
edges = list(G.edges())
# colours/names/papers were built in `keys` order; the index column must line up with them.
assert nodes == keys, 'graph node order no longer matches the per-node attribute lists'
edges_start = [start for start, _ in edges]
edges_end = [end for _, end in edges]
node_source = ColumnDataSource(data=dict(index=nodes,
                                         fill_color=colours,
                                         author=names,
                                         source=papers))
edge_source = ColumnDataSource(data=dict(start=edges_start, end=edges_end))

# Axes and grid lines carry no meaning for a network layout.
for element in (plot.xaxis, plot.xgrid, plot.yaxis, plot.ygrid):
    element.visible = False

graph_renderer = bokeh.models.renderers.GraphRenderer()
# NOTE(review): the saved output of this cell logged bokeh E-1001 (BAD_COLUMN_NAME)
# for `fill_color` even though node_source defines that column; the cause is not
# visible here — re-check against the installed bokeh version.
graph_renderer.node_renderer.data_source.data = node_source.data
graph_renderer.node_renderer.glyph = bokeh.models.Circle(size=15, fill_color="fill_color")
graph_renderer.node_renderer.selection_glyph = bokeh.models.Circle(size=15, fill_color=Spectral8[2])
graph_renderer.edge_renderer.data_source.data = edge_source.data
graph_renderer.edge_renderer.glyph = bokeh.models.MultiLine(line_color="#CCCCCC", line_alpha=0.8, line_width=5)
graph_renderer.edge_renderer.selection_glyph = bokeh.models.MultiLine(line_color=Spectral4[2], line_width=5)
graph_renderer.layout_provider = StaticLayoutProvider(graph_layout=pos)

# Selecting a node (e.g. with the TapTool) also selects its linked edges.
graph_renderer.selection_policy = NodesAndLinkedEdges()

plot.renderers.append(graph_renderer)
show(plot)

E-1001 (BAD_COLUMN_NAME): Glyph refers to nonexistent column name: fill_color [renderer: GlyphRenderer(id='68fff9ca-23d6-41f9-ade3-a1bf46c5d760', ...)]
