# __Generate PyVis Network Visualization from Edge List__

## __Introduction__
This notebook explores how we can generate [PyVis](https://pyvis.readthedocs.io/en/latest/tutorial.html) network visualizations from an edge list. We will use __the Niger Delta GDELT data__ for this purpose. It contains the relationships between entities, measured by the tone of the articles in which the entities are co-mentioned, over six consecutive years, and we will generate a separate graph (an HTML file) for each year. We are particularly interested in the relationships involving __Exxon, Shell, and Chevron__. The edge list is preprocessed so that the tone scores between the entities we are interested in are precomputed. However, we still need to filter out low-confidence co-mentions, i.e., entity pairs that are co-mentioned only a few times.

__IMPORTANT:__ We assume the folders are structured as follows,
```
project_home
  |-- data
  |     |-- processed
  |     |-- raw
  |-- examples
  |-- pictures
        |-- ni_sec_visualization
```
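The layout above can be checked, and the output folder created, before any file is read or written. A minimal sketch using only `pathlib`, assuming the notebook runs from `examples/` (the `..` prefix in the data paths below suggests so); the helper `ensure_layout` and the two folder lists are hypothetical and not part of the original notebook.

```python
from pathlib import Path
from typing import List

# Input folders the notebook reads from, relative to project_home
REQUIRED_DIRS = ['data/processed', 'data/raw']
# Output folder the notebook writes the HTML files to
OUTPUT_DIRS = ['pictures/ni_sec_visualization']


def ensure_layout(project_home: Path) -> List[Path]:
    """Create any missing output folders and return the required input
    folders that do not exist (an empty list means the layout is OK)."""
    for rel in OUTPUT_DIRS:
        (project_home / rel).mkdir(parents=True, exist_ok=True)
    return [project_home / rel for rel in REQUIRED_DIRS
            if not (project_home / rel).is_dir()]
```

From `examples/`, calling `ensure_layout(Path('..'))` would create `pictures/ni_sec_visualization` if needed and report any missing `data` subfolder.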

In [7]:
# Load libraries and define data paths
import os
from typing import List, Iterator, Union
from collections import namedtuple

import pandas as pd
import numpy as np

import matplotlib
import matplotlib.cm as color_maps

from pyvis.network import Network as VisNetwork

# Data path
EDGE_LIST_FILE = '../data/processed/ni_sec_edge_only.csv'
PPL_LIST_FILE  = '../data/raw/ni_top_ppl.csv'   # Entity name and number of mentions
ORG_LIST_FILE  = '../data/raw/ni_top_org.csv'   # Entity name and number of mentions
OUT_FOLDER     = '../pictures/ni_sec_visualization'

## __Brief Overview of the Edge List__

All code below assumes the edge list is of the following schema,

```python
cols_expected = {
    'year':                int,   
    'entity1':             str, 
    'entity2':             str, 
    'co_mention_count':    int,   # Number of co-mentions
    'tone_sum':            float, # Sum of article tones over all co-mentions
    'co_mention_tone_avg': float, # tone_sum / co_mention_count; used for visualization
    'org_flag':            bool,  # At least one of the two entities is an org
    'person_flag':         bool   # At least one of the two entities is a person
}
```
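Because `org_flag` and `person_flag` each only say whether *at least one* of the two entities is of that type, the type of the pair follows from the two flags together; this is the logic behind the `mask_ppl`, `mask_org`, and `mask_mix` properties defined later. A minimal stdlib sketch; the function `pair_type` is hypothetical.

```python
def pair_type(org_flag: bool, person_flag: bool) -> str:
    """Classify an entity pair from its two flags (hypothetical helper).

    org_flag:    at least one of the two entities is an organization
    person_flag: at least one of the two entities is a person
    """
    if person_flag and not org_flag:
        return 'person-person'   # no org involved, so both are people
    if org_flag and not person_flag:
        return 'org-org'         # no person involved, so both are orgs
    if org_flag and person_flag:
        return 'person-org'      # one of each
    return 'unknown'             # neither flag set: not covered by the schema
```

For example, a row with `org_flag=1` and `person_flag=0` pairs two organizations.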

In [3]:
# Peek into the edge list
pd.read_csv(EDGE_LIST_FILE).head()

| Unnamed: 0 | year | entity1 | entity2 | co_mention_count | tone_sum | co_mention_tone_avg | org_flag | person_flag |
|---|---|---|---|---|---|---|---|---|
| 0 | 2015 | Exxon | Ayodele Fayose | 1 | -2.261445 | -2.261445 | 1 | 1 |
| 1 | 2015 | Ogba Egbema Ndoni | Chevron | 1 | -2.904267 | -2.904267 | 1 | 1 |
| 2 | 2015 | Shell | Ahmad Lawan | 1 | -0.211516 | -0.211516 | 1 | 1 |
| 3 | 2015 | Exxon | Divisional Police | 7 | -41.098485 | -5.871212 | 1 | 0 |
| 4 | 2015 | Exxon | Donald Duke | 2 | 5.55973 | 2.779865 | 1 | 1 |
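The rows above show how the two tone columns relate: `co_mention_tone_avg` equals `tone_sum / co_mention_count` (e.g., -41.098485 / 7 ≈ -5.871212). A quick check on the five rows, with the values copied from the output; it also shows that, with the `EdgeList` default `n_low_conf=10` defined below, all five rows would be dropped as low-confidence.

```python
# (co_mention_count, tone_sum, co_mention_tone_avg) copied from the peek above
rows = [
    (1, -2.261445, -2.261445),
    (1, -2.904267, -2.904267),
    (1, -0.211516, -0.211516),
    (7, -41.098485, -5.871212),
    (2, 5.55973, 2.779865),
]

for count, tone_sum, tone_avg in rows:
    # The average tone is the summed tone divided by the number of co-mentions
    assert abs(tone_sum / count - tone_avg) < 1e-6, (count, tone_sum, tone_avg)

N_LOW_CONF = 10  # default threshold of `EdgeList`
kept = [r for r in rows if r[0] >= N_LOW_CONF]  # every row here has < 10 co-mentions
```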


Next, we define two helpers that organize the preprocessing and data-slicing code: the `Entity` container and the `EdgeList` class. Please refer to the docstrings for more details.

In [4]:
# Define a helper container that holds the node properties
#   for visualizations. The shape is determined by whether an
#   entity is a person (dot) or an organization (square)
Entity = namedtuple(
    typename='Entity', 
    field_names=['name', 'label', 'shape']
)


class EdgeList:
    """This is a helper class encapsulating preprocessing 
        code, helper functions, and shortcuts for preparing 
        data before generating the plot.
        
    The general workflow is,
        1. Preprocess the list as a whole, e.g., title-case entity names
            and remove low-confidence edges (fewer than `n_low_conf`
            co-mentions). This is done in `__init__()` via `preprocess()`.
        2. Use `get_data_of_year()` to subset the data to a single year,
            which returns a new `EdgeList` object.
    """
    
    def __init__(self, edge_list: Union[str, pd.DataFrame], n_low_conf: int = 10) -> None:
        """Load and pre-process edge list, also collect metadata 
            for downstream processing"""
            
        # Minimum number of co-mentions for an edge to be kept
        self.n_low_conf = n_low_conf

        # Expected columns in source edge list df
        cols_expected = {
            'year':                int, 
            'entity1':             str, 
            'entity2':             str, 
            'co_mention_count':    int, 
            'tone_sum':            float, 
            'co_mention_tone_avg': float, 
            'org_flag':            bool, 
            'person_flag':         bool
        }
        
        # Initialize edge list data frame
        if isinstance(edge_list, str):
            if not os.path.exists(edge_list):
                raise FileNotFoundError(f'Cannot find edge list source file @ <{edge_list}>')
            df_edge = pd.read_csv(edge_list, usecols=list(cols_expected.keys()), dtype=cols_expected)
        elif isinstance(edge_list, pd.DataFrame):
            try:
                df_edge = edge_list.astype(cols_expected)
            except KeyError:
                raise KeyError(f'Expecting <edge_list> pd.DataFrame with columns <{", ".join(cols_expected.keys())}>; ' + 
                    f'got <{", ".join(edge_list.columns)}>')
        else:
            raise ValueError(f'<edge_list> should either be filename str or pd.DataFrame; got <{type(edge_list)}>')

        # Preprocess pipeline
        self.df_edge = self.preprocess(df_edge)
        
    def preprocess(self, df_edge: pd.DataFrame) -> pd.DataFrame:
        """Edge list preprocessing pipeline. Called during initialization"""
        
        # Make sure entity names are titled
        df_edge.loc[:, 'entity1'] = df_edge.entity1.str.title()
        df_edge.loc[:, 'entity2'] = df_edge.entity2.str.title()

        # Filter edges with low confidence
        df_edge = df_edge.loc[df_edge.co_mention_count >= self.n_low_conf]
        return df_edge.sort_values('year').reset_index(drop=True)
    
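    def get_data_of_year(self, year: int) -> 'EdgeList':
        """Subset the data and return a new `EdgeList` with only that year's edges"""
        
        # Pass the threshold on so the subset is not re-filtered with the default value
        return EdgeList(self.df.loc[self.df.year == year], n_low_conf=self.n_low_conf)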

    # Helpers/Shortcuts ////////////////////////////////////////////////////////////
    @property
    def df(self) -> pd.DataFrame:
        return self.df_edge
    
    @property
    def years(self) -> np.ndarray:
        return np.sort(self.df.year.unique())
    
    @property
    def entities(self) -> List[str]:
        ret = set(self.df.entity1).union(set(self.df.entity2))
        return list(ret)

    @property
    def comentions(self) -> Iterator[tuple]:
        """Return each row as a named tuple (`Comention`)"""
        
        return self.df.itertuples(index=False, name='Comention')
        
    @property
    def mask_ppl(self) -> pd.Series:
        """Boolean mask of rows where both entities are people"""
        
        return np.logical_and(
            self.df.person_flag, 
            np.logical_not(self.df.org_flag)
        )

    @property
    def mask_org(self) -> pd.Series:
        """Boolean mask of rows where both entities are organizations"""
        
        return np.logical_and(
            np.logical_not(self.df.person_flag), 
            self.df.org_flag
        )
        
    @property
    def mask_mix(self) -> pd.Series:
        """Boolean mask of rows pairing a person with an organization"""
        
        return np.logical_and(
            self.df.person_flag, 
            self.df.org_flag
        )
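Stripped of pandas, the `EdgeList` workflow (title-case names, drop edges with fewer than `n_low_conf` co-mentions, sort by year, then slice one year at a time) reduces to a few lines. A minimal stdlib sketch on plain dicts; the names `Edge`, `preprocess_edges`, and `edges_of_year` are hypothetical.

```python
from typing import Any, Dict, List

# Hypothetical plain-dict stand-in for one edge-list row
Edge = Dict[str, Any]


def preprocess_edges(edges: List[Edge], n_low_conf: int = 10) -> List[Edge]:
    """Title-case entity names, drop edges with fewer than `n_low_conf`
    co-mentions, and sort the rest by year (mirrors `EdgeList.preprocess`)."""
    kept = [
        {**e, 'entity1': str(e['entity1']).title(), 'entity2': str(e['entity2']).title()}
        for e in edges
        if e['co_mention_count'] >= n_low_conf
    ]
    return sorted(kept, key=lambda e: e['year'])


def edges_of_year(edges: List[Edge], year: int) -> List[Edge]:
    """Keep only the edges of a single year (mirrors `get_data_of_year`)."""
    return [e for e in edges if e['year'] == year]
```

Note that `str.title()` also lowercases interior capitals (`'ExxonMobil'` becomes `'Exxonmobil'`); this is harmless here because the entity lists below are title-cased the same way, so names still match.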

## __Draw Visualization__

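Please refer to the PyVis [tutorial](https://pyvis.readthedocs.io/en/latest/tutorial.html) for more details. The `Canvas` class below wraps a PyVis `Network`: all entities are added as nodes up front, and each co-mention becomes an edge whose width is the absolute average tone and whose color runs from pink (negative) to green (positive) on Matplotlib's `PiYG` colormap.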

In [18]:
class Canvas:
    
    @staticmethod
    def edge_color(weight: float) -> str:
        """Map an edge weight (average tone) to a hex color on the PiYG scale."""
            
        if np.isnan(weight):
            return '#f0f8ff'  # Alice Blue for placeholder edges without a tone
        
        # NOTE: Most edges are negative, so we shift the max score from 10 to 5
        #   to make positive edges more visible.
        norm = matplotlib.colors.Normalize(vmin=-10, vmax=5)
        cmap = matplotlib.colormaps['PiYG']  # `cm.get_cmap` was removed in Matplotlib 3.9
        rgb = cmap(norm(weight))[:3]  # cmap returns RGBA; keep only the first 3 (RGB)
        return matplotlib.colors.rgb2hex(rgb)

    def __init__(self, entities: List[Entity], **canvas_kwargs) -> None:
        
        self.canvas = VisNetwork('1000px', '1000px', **canvas_kwargs)
        self.canvas.repulsion(central_gravity=0.1, spring_length=512)
        self.canvas.show_buttons(filter_=['nodes', 'edges', 'physics'])

        self.entities = set()
        self.comentions = set()

        # Add nodes
        for ent in entities:
            self.entities.add(ent.name)
            self.canvas.add_node(
                ent.name,
                value=2,
                shape=ent.shape,
                label=ent.label,
                title=ent.label
            )

    def show(self, fname: str):
        """Dump visualization to the designated HTML file"""
        
        self.canvas.show(fname)
        
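    def draw_comention_single(self, entity1: str, entity2: str, weight: float) -> bool:
        """Add a single ad-hoc edge (e.g., a placeholder) that is not in the edge list.

        Returns False if either entity is not a node on the canvas or the
            (undirected) pair already has an edge; True otherwise."""
        
        # Skip pairs whose nodes were never added (pyvis `add_edge` would fail)
        if entity1 not in self.entities or entity2 not in self.entities:
            return False
        
        # Check if co-mention exists; pairs are stored sorted, i.e. undirected
        key = tuple(sorted([entity1, entity2]))
        if key in self.comentions:
            return False
        self.comentions.add(key)
        
        width = abs(weight)
        self.canvas.add_edge(
            entity1, entity2,
            title=round(weight, 2),
            value=1 if np.isnan(width) else width,
            color=Canvas.edge_color(weight)
        )
        return True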
        
    def draw_comentions(self, comentions) -> None:
        """Add one edge per co-mention whose two entities are both nodes
            on the canvas; width is |average tone|, color from `edge_color()`."""

        # Add edges
        for cmt in comentions:
            if (cmt.entity1 in self.entities) and (cmt.entity2 in self.entities):
                self.comentions.add(tuple(sorted([cmt.entity1, cmt.entity2])))
                
                width = abs(cmt.co_mention_tone_avg)
                self.canvas.add_edge(
                    cmt.entity1, cmt.entity2,
                    title=round(cmt.co_mention_tone_avg, 2),
                    value=1 if np.isnan(width) else width,
                    color=Canvas.edge_color(cmt.co_mention_tone_avg)
                )
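`edge_color()` relies on `matplotlib.colors.Normalize`, which maps a tone linearly from `[vmin, vmax] = [-10, 5]` onto `[0, 1]` before the colormap lookup; tones outside that range map below 0 or above 1 and are drawn with the colormap's under/over colors, which by default are its end colors. A stdlib sketch of that arithmetic (no Matplotlib needed), together with the edge-width rule; the function names are hypothetical.

```python
import math


def tone_to_unit(weight: float, vmin: float = -10.0, vmax: float = 5.0) -> float:
    """Linear normalization as in `Normalize(vmin, vmax)`, clamped to [0, 1]
    to mimic the colormap falling back to its end colors."""
    t = (weight - vmin) / (vmax - vmin)
    return min(max(t, 0.0), 1.0)


def edge_width(weight: float) -> float:
    """Edge width is |tone|; NaN (placeholder) edges get width 1."""
    width = abs(weight)
    return 1.0 if math.isnan(width) else width
```

Because the midpoint 0.5 corresponds to a tone of -2.5 rather than 0, a neutral tone of 0 already lands at about 0.67, on the green side of `PiYG`.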

## __Run Visualization Code__

In [19]:
# Get set of entities
all_entities = {}

# People (keep top 32)
all_entities.update({
    ent.name: ent for ent in pd.read_csv(PPL_LIST_FILE, usecols=['persons', 'n'])
        .sort_values('n', ascending=False)
        .head(32)
        .persons.str.title()
        .map(lambda n: Entity(n, n, 'dot'))
})

# Organizations (keep top 32)
all_entities.update({
    ent.name: ent for ent in pd.read_csv(ORG_LIST_FILE, usecols=['organizations', 'n'])
        .sort_values('n', ascending=False)
        .head(32)
        .organizations.str.title()
        .map(lambda n: Entity(n, n, 'square'))
})

# Load full edge list for 6 years
all_edge_list = EdgeList(EDGE_LIST_FILE)

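# Make sure the output folder exists before writing the HTML files
os.makedirs(OUT_FOLDER, exist_ok=True)

# Plot by year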
for year in all_edge_list.years:
    canvas = Canvas(all_entities.values(), notebook=True)
    
    # Plot subset
    sub_edge_list = all_edge_list.get_data_of_year(year)
    canvas.draw_comentions(sub_edge_list.comentions)
    
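    # Add placeholder edges (NaN tone: width 1, Alice Blue) from each oil major
    #   to every other entity it has no co-mention edge with this year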
    for e1 in ['Chevron', 'Exxon', 'Shell']:
        for e2 in all_entities.keys():
            if e1 != e2:
                canvas.draw_comention_single(e1, e2, np.nan)
            
    # Dump file
    canvas.show(os.path.join(OUT_FOLDER, f'ni_sec_vis_{year}.html'))
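The placeholder loop only adds an edge for pairs that do not already have one, because `Canvas` keys undirected pairs by `tuple(sorted(...))`. A stdlib sketch of that bookkeeping, independent of PyVis; `pair_key` and `placeholder_pairs` are hypothetical names, and the entities in the usage check are for illustration only.

```python
from typing import Iterable, Set, Tuple

Pair = Tuple[str, str]


def pair_key(a: str, b: str) -> Pair:
    """Order-independent key for an undirected edge (same idea as tuple(sorted(...)))."""
    return (a, b) if a <= b else (b, a)


def placeholder_pairs(hubs: Iterable[str], entities: Iterable[str], existing: Set[Pair]) -> Set[Pair]:
    """Pairs that would receive a placeholder edge: each hub to every other
    entity, skipping hubs that are not nodes, self-loops, and existing edges."""
    nodes = set(entities)
    pairs: Set[Pair] = set()
    for hub in hubs:
        if hub not in nodes:
            continue  # hub was not among the selected nodes
        for other in nodes:
            key = pair_key(hub, other)
            if hub != other and key not in existing:
                pairs.add(key)  # the set also absorbs hub-hub duplicates
    return pairs
```

If all three companies are nodes, N nodes give at most 3N - 6 placeholder pairs (each company links to N - 1 others, and the three company-company pairs are counted twice), minus the pairs that already have a co-mention edge.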