# Visualization Module

This notebook contains utilities for visualizing entity relationship graphs using NetworkX and PyVis.

In [None]:
import networkx as nx
from pyvis.network import Network
import pandas as pd
from typing import Optional, Dict, Any
from pathlib import Path
import logging

# Configure the root logger so INFO-level (and higher) records are emitted.
logging.basicConfig(level=logging.INFO)
# Module-scoped logger used by the functions defined below.
logger = logging.getLogger(__name__)

In [None]:
# Hex colors used when drawing the graph, keyed by entity type (for nodes)
# and by "<SOURCE_TYPE>-<TARGET_TYPE>" relationship type (for edges).
_NODE_COLORS = {
    'PERSON': '#4CAF50',  # Green
    'ORG': '#2196F3',  # Blue
}

_EDGE_COLORS = {
    # Both directions of a person/organisation link share one color.
    'PERSON-ORG': '#FF9800',  # Orange
    'ORG-PERSON': '#FF9800',  # Orange
    'PERSON-PERSON': '#9C27B0',  # Purple
}

# Single lookup table consulted for both node and edge colors.
ENTITY_COLORS = {**_NODE_COLORS, **_EDGE_COLORS}

# Default visualization settings; callers may override any subset of keys.
DEFAULT_CONFIG = dict(
    # Canvas appearance
    height='600px',
    width='100%',
    bgcolor='#222222',
    font_color='white',
    # Repulsion physics parameters
    node_distance=420,
    central_gravity=0.33,
    spring_length=110,
    spring_strength=0.10,
    damping=0.95,
)

In [None]:
def create_networkx_graph(edges_df: pd.DataFrame) -> nx.Graph:
    """
    Build a NetworkX graph from an edge-list DataFrame.

    Args:
        edges_df: DataFrame with 'source', 'target', and 'type' columns,
            one row per edge.

    Returns:
        NetworkX Graph whose edges carry the row's 'type' value as an edge
        attribute. An empty DataFrame produces an empty graph.
    """
    if not edges_df.empty:
        graph = nx.from_pandas_edgelist(
            edges_df,
            source='source',
            target='target',
            edge_attr='type',
        )
        node_count = graph.number_of_nodes()
        edge_count = graph.number_of_edges()
        logger.info(f"Created graph with {node_count} nodes and {edge_count} edges")
        return graph

    # Nothing to convert: warn and hand back an empty graph.
    logger.warning("Empty edges DataFrame, creating empty graph")
    return nx.Graph()

In [None]:
def infer_node_type(edges_df: pd.DataFrame) -> Dict[str, str]:
    """
    Infer node types from edge types.

    Each edge type is expected to look like '<SOURCE_TYPE>-<TARGET_TYPE>'
    (for example 'PERSON-ORG'). The first part is assigned to the row's
    source node and the second part to its target node. When a node appears
    in several rows, the type from the first row that mentions it wins.

    Rows are skipped, rather than raising, when the edge type is missing,
    is not a string (e.g. None/NaN), or does not split into exactly two
    parts around '-'.

    Args:
        edges_df: DataFrame with 'source', 'target', and 'type' columns

    Returns:
        Dictionary mapping node names to their types
    """
    node_types: Dict[str, str] = {}

    for _, row in edges_df.iterrows():
        edge_type = row.get('type', '')

        # Missing values come back from pandas as None/NaN, which would
        # break string operations below.
        if not isinstance(edge_type, str):
            continue

        parts = edge_type.split('-')
        # Exactly two parts are required to tell source type from target
        # type; anything else ('ORG', 'A-B-C') is ambiguous and ignored.
        if len(parts) != 2:
            continue

        source_type, target_type = parts
        # setdefault keeps the first type seen for each node.
        node_types.setdefault(row['source'], source_type)
        node_types.setdefault(row['target'], target_type)

    return node_types

In [None]:
def create_pyvis_network(
    G: nx.Graph,
    edges_df: pd.DataFrame,
    config: Optional[Dict[str, Any]] = None
) -> Network:
    """
    Build a PyVis Network from a NetworkX graph.

    Nodes are colored by the entity type inferred from the edge types in
    ``edges_df``; edges are colored by their own 'type' attribute.

    Args:
        G: NetworkX Graph object
        edges_df: Original edges DataFrame (used to infer node types)
        config: Optional overrides for the default visualization settings

    Returns:
        PyVis Network object with nodes, edges and physics configured
    """
    # Start from the defaults and layer caller overrides on top.
    settings = dict(DEFAULT_CONFIG)
    if config:
        settings.update(config)

    network = Network(
        height=settings['height'],
        width=settings['width'],
        bgcolor=settings['bgcolor'],
        font_color=settings['font_color'],
        notebook=True,
        cdn_resources='in_line',
    )

    # Nodes: unknown nodes are treated as 'PERSON'; unknown types are white.
    inferred_types = infer_node_type(edges_df)
    for name in G.nodes():
        kind = inferred_types.get(name, 'PERSON')
        network.add_node(
            name,
            label=name,
            color=ENTITY_COLORS.get(kind, '#FFFFFF'),
            title=f"{name} ({kind})",
        )

    # Edges: unknown or missing edge types fall back to grey.
    for src, dst, attrs in G.edges(data=True):
        relation = attrs.get('type', '')
        network.add_edge(src, dst, color=ENTITY_COLORS.get(relation, '#888888'))

    # Physics: forward the repulsion-related settings as keyword arguments.
    physics_keys = (
        'node_distance',
        'central_gravity',
        'spring_length',
        'spring_strength',
        'damping',
    )
    network.repulsion(**{key: settings[key] for key in physics_keys})

    return network

In [None]:
def save_graph_html(
    net: Network,
    output_path: str = '/tmp/pyvis_graph.html'
) -> str:
    """
    Save the PyVis network as an HTML file.

    Missing parent directories of the target file are created first.

    Args:
        net: PyVis Network object
        output_path: Path for the output HTML file. A plain string or any
            ``os.PathLike`` object (such as ``pathlib.Path``) is accepted.

    Returns:
        Path to the saved file, as a string
    """
    import os  # local import: only needed to normalise the path argument

    # PyVis checks the file name with string methods, so a pathlib.Path has
    # to be converted to a plain string first. A str argument is returned
    # unchanged by os.fspath.
    html_path = os.fspath(output_path)

    # Ensure parent directory exists
    Path(html_path).parent.mkdir(parents=True, exist_ok=True)

    net.save_graph(html_path)
    logger.info("Graph saved to %s", html_path)

    return html_path

In [None]:
def visualize_graph(
    edges_df: pd.DataFrame,
    output_path: str = '/tmp/pyvis_graph.html',
    config: Optional[Dict[str, Any]] = None
) -> str:
    """
    Run the full pipeline: edge list -> graph -> PyVis network -> HTML file.

    Args:
        edges_df: DataFrame with 'source', 'target', and 'type' columns
        output_path: Path for the output HTML file
        config: Optional visualization configuration

    Returns:
        Path to the saved HTML file
    """
    graph = create_networkx_graph(edges_df)
    network = create_pyvis_network(G=graph, edges_df=edges_df, config=config)
    saved_to = save_graph_html(net=network, output_path=output_path)
    return saved_to

In [None]:
def get_graph_statistics(G: nx.Graph) -> Dict[str, Any]:
    """
    Summarize a graph with a few basic statistics.

    Args:
        G: NetworkX Graph object

    Returns:
        Dictionary with graph statistics. For a graph without nodes only
        'nodes' and 'edges' (both 0) are present. Otherwise the dictionary
        also holds 'density', 'connected_components' and
        'top_nodes_by_degree' (up to 10 (node, degree) pairs, highest
        degree first).
    """
    node_count = G.number_of_nodes()
    if node_count == 0:
        return {'nodes': 0, 'edges': 0}

    # Rank nodes by degree, highest first, and keep the top ten.
    degrees = dict(G.degree())
    ranked = sorted(degrees.items(), key=lambda item: item[1], reverse=True)

    return {
        'nodes': node_count,
        'edges': G.number_of_edges(),
        'density': nx.density(G),
        'connected_components': nx.number_connected_components(G),
        'top_nodes_by_degree': ranked[:10],
    }

## Example Usage

In [None]:
# Example: Visualize sample graph
# sample_edges = pd.DataFrame([
#     {'source': 'Elon Musk', 'target': 'Tesla', 'type': 'PERSON-ORG'},
#     {'source': 'Elon Musk', 'target': 'SpaceX', 'type': 'PERSON-ORG'},
#     {'source': 'Jeff Bezos', 'target': 'Amazon', 'type': 'PERSON-ORG'},
# ])
# output_path = visualize_graph(sample_edges)
# print(f"Graph saved to: {output_path}")