# Graph Building Module

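This notebook contains utilities for building an entity relationship graph from NLP-processed article data. Two entities are connected by an undirected edge when they appear in the same sentence, or in sentences linked by a coreference chain.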

In [None]:
import pandas as pd
from itertools import combinations
from typing import List, Dict, Set, Tuple, Any
from dataclasses import dataclass
import logging

# Module-level logging setup: configure the root logger at INFO so the
# logger.info(...) progress messages from the graph builders below are shown.
# (logging.basicConfig is a no-op if the root logger already has handlers.)
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

In [None]:
@dataclass
class Edge:
    """An undirected edge between two named entities.

    ``source``/``target`` record the order the edge was created with, and
    ``edge_type`` carries a descriptive label. Equality and hashing ignore
    both direction and ``edge_type``: two edges are equal when they join the
    same pair of names. A ``set`` of edges therefore deduplicates A-B and B-A,
    keeping whichever instance was inserted first.
    """
    source: str
    target: str
    edge_type: str

    def to_dict(self) -> dict:
        """Return the edge as a row dict with keys 'source', 'target', 'type'."""
        return {
            'source': self.source,
            'target': self.target,
            'type': self.edge_type
        }

    def _endpoints(self) -> frozenset:
        """Unordered pair of endpoint names; the edge's identity."""
        return frozenset((self.source, self.target))

    # @dataclass keeps explicitly defined __hash__/__eq__ rather than
    # generating its own, so these direction-agnostic versions take effect.
    def __hash__(self):
        return hash(self._endpoints())

    def __eq__(self, other):
        if not isinstance(other, Edge):
            # Per the data model, defer to the other operand / identity
            # fallback instead of claiming inequality outright.
            return NotImplemented
        return self._endpoints() == other._endpoints()

In [None]:
def is_valid_entity_pair(entity1: dict, entity2: dict) -> bool:
    """
    Decide whether two entities deserve an edge between them.

    A pair is rejected when both entities share a name, when either name is
    empty/falsy, or when both carry the 'ORG' label.

    Args:
        entity1: First entity dictionary (reads 'name', then 'label')
        entity2: Second entity dictionary (reads 'name', then 'label')

    Returns:
        True if the pair should be linked, False otherwise
    """
    first_name = entity1['name']
    second_name = entity2['name']

    # Identical names would be a self-loop; an empty name has no node to link.
    if first_name == second_name or not (first_name and second_name):
        return False

    # Organisation-to-organisation co-occurrences are deliberately excluded.
    return not (entity1['label'] == 'ORG' and entity2['label'] == 'ORG')

In [None]:
def create_edge(entity1: dict, entity2: dict) -> Edge:
    """
    Build an Edge linking two entity dictionaries.

    The edge runs from ``entity1`` to ``entity2``. Its type is the two labels
    joined by a hyphen in argument order (e.g. "PERSON-ORG").

    Args:
        entity1: Entity dict supplying the source name and first label
        entity2: Entity dict supplying the target name and second label

    Returns:
        Edge object
    """
    label_pair = f"{entity1['label']}-{entity2['label']}"
    return Edge(entity1['name'], entity2['name'], label_pair)

In [None]:
def extract_sentence_edges(entities: List[dict]) -> Set[Edge]:
    """
    Link every valid pair of entities that share a sentence.

    Entities are bucketed by their ('urls', 'sent_idx') values. Each pair
    within a bucket that passes ``is_valid_entity_pair`` becomes an edge.

    Args:
        entities: List of entity dictionaries

    Returns:
        Set of Edge objects (undirected duplicates collapse to the first seen)
    """
    by_sentence: Dict[Tuple[str, int], List[dict]] = {}
    for ent in entities:
        by_sentence.setdefault((ent['urls'], ent['sent_idx']), []).append(ent)

    # combinations() yields nothing for buckets with fewer than two entities,
    # so single-entity sentences contribute no edges.
    return {
        create_edge(left, right)
        for bucket in by_sentence.values()
        for left, right in combinations(bucket, 2)
        if is_valid_entity_pair(left, right)
    }

In [None]:
def extract_coreference_edges(
    entities: List[dict],
    coref_chains: List[dict]
) -> Set[Edge]:
    """
    Link entities whose sentences are tied together by a coreference chain.

    For each chain, all entities found in the chain's sentences (matched on
    the chain's 'url' and each index in its 'sentences') are pooled. Every
    valid pair in that pool becomes an edge.

    Args:
        entities: List of entity dictionaries
        coref_chains: List of coreference chain dictionaries

    Returns:
        Set of Edge objects
    """
    by_sentence: Dict[Tuple[str, int], List[dict]] = {}
    for ent in entities:
        by_sentence.setdefault((ent['urls'], ent['sent_idx']), []).append(ent)

    edges: Set[Edge] = set()
    for chain in coref_chains:
        url = chain['url']
        sentence_ids = chain['sentences']

        # Sentences with no recorded entities simply contribute nothing.
        members = [
            ent
            for sent_idx in sentence_ids
            for ent in by_sentence.get((url, sent_idx), ())
        ]

        edges.update(
            create_edge(left, right)
            for left, right in combinations(members, 2)
            if is_valid_entity_pair(left, right)
        )

    return edges

In [None]:
def build_graph_from_processed_data(processed_articles: List[Dict[str, Any]]) -> pd.DataFrame:
    """
    Build an undirected entity graph from NLP-processed article data.

    Entities and coreference chains from every article are pooled. They are
    then linked by two passes: same-sentence co-occurrence, and shared
    coreference chains. A pair found by both passes appears once, using the
    edge produced by the sentence pass.

    Args:
        processed_articles: List of processed article dictionaries. Each may
            carry 'entities' and 'coreference_chains' lists; missing keys are
            treated as empty.

    Returns:
        DataFrame with columns: source, target, type. There is one row per
        unordered entity pair, sorted by (source, target, type) so the row
        order is reproducible across runs.
    """
    all_entities: List[dict] = []
    all_coref_chains: List[dict] = []
    for article in processed_articles:
        all_entities.extend(article.get('entities', []))
        all_coref_chains.extend(article.get('coreference_chains', []))

    # Extract edges from same-sentence co-occurrences
    sentence_edges = extract_sentence_edges(all_entities)
    logger.info("Found %d edges from sentence co-occurrences", len(sentence_edges))

    # Extract edges from coreference chains. A chain's sentences can include
    # pairs the sentence pass already found, so only count genuinely new ones.
    coref_edges = extract_coreference_edges(all_entities, all_coref_chains)
    new_coref_edges = coref_edges - sentence_edges
    logger.info("Found %d additional edges from coreference chains", len(new_coref_edges))

    # Sentence edges come first so they win for pairs found by both passes.
    all_edges = sentence_edges | new_coref_edges

    columns = ['source', 'target', 'type']
    if not all_edges:
        return pd.DataFrame(columns=columns)

    # Set iteration order depends on string hashing, which is randomized per
    # process by default, so sort for deterministic output. str() keeps the
    # sort safe even if a caller supplies non-string entity names.
    ordered = sorted(
        all_edges,
        key=lambda edge: (str(edge.source), str(edge.target), edge.edge_type),
    )
    df = pd.DataFrame([edge.to_dict() for edge in ordered], columns=columns)
    logger.info("Built graph with %d total edges", len(df))

    return df

In [None]:
def get_entity_statistics(processed_articles: List[Dict[str, Any]]) -> pd.DataFrame:
    """
    Count how often each (name, label) entity occurs across articles.

    Args:
        processed_articles: List of processed article dictionaries. Each may
            carry an 'entities' list; a missing key is treated as empty.

    Returns:
        DataFrame with columns name, label, count, sorted by count descending.
        Entities with equal counts stay in (name, label) order. When there are
        no entities, an empty DataFrame with those same columns is returned.
    """
    columns = ['name', 'label', 'count']

    all_entities = [
        entity
        for article in processed_articles
        for entity in article.get('entities', [])
    ]
    df = pd.DataFrame(all_entities)

    if df.empty:
        # Keep the documented schema so callers can rely on the columns.
        return pd.DataFrame(columns=columns)

    # groupby sorts groups by (name, label); a stable sort preserves that
    # order among equal counts instead of leaving ties in arbitrary order.
    counts = df.groupby(['name', 'label']).size().reset_index(name='count')
    return counts.sort_values('count', ascending=False, kind='mergesort')

## Example Usage

In [None]:
# Example: Build graph from sample data
# sample_processed = [
#     {
#         'url': 'https://example.com',
#         'entities': [
#             {'name': 'Elon Musk', 'label': 'PERSON', 'sent_idx': 0, 'start': 0, 'end': 2, 'urls': 'https://example.com'},
#             {'name': 'Tesla', 'label': 'ORG', 'sent_idx': 0, 'start': 4, 'end': 5, 'urls': 'https://example.com'},
#         ],
#         'coreference_chains': []
#     }
# ]
# edges_df = build_graph_from_processed_data(sample_processed)
# print(edges_df)