In [None]:
import networkx as nx
import quantlaw.utils.networkx
import pandas as pd
import seaborn as sns
import lxml
import json
import copy
from collections import defaultdict

# Preprocessing

- Take a cross-reference graph in which cross-references are modeled at the lowest possible level at both source and target. If a detailed citekey cannot be resolved, the last component of the citekey is removed and we try to resolve the broader reference. E.g., if `26_7604_h_3` cannot be found, we try to match `26_7604_h` and then `26_7604`.
- Remove all containment edges whose targets are not subseqitems. The edges in the remaining graph then represent the hierarchy between seqitems and subseqitems, plus the cross-references. (We don't need higher-level containment edges to follow cross-references.)
- Remove all nodes larger than 1000 tokens. This way we ignore overbroad references.

In [None]:
def subseqitem_mapping(hierarchy_g):
    '''
    Build a lookup from every node at or below a seqitem to that seqitem.

    Each node whose `type` attribute is 'seqitem' maps to itself, and every
    node reachable from it in `hierarchy_g` maps to it as well. Seqitems are
    processed in sorted order, so if a node is reachable from several
    seqitems the last one in that order wins.
    '''
    seqitems = [
        node
        for node, kind in sorted(hierarchy_g.nodes(data='type'))
        if kind == 'seqitem'
    ]

    mapping = {}
    for seqitem in seqitems:
        mapping[seqitem] = seqitem
        # All nodes below the seqitem point back to it.
        mapping.update(dict.fromkeys(nx.descendants(hierarchy_g, seqitem), seqitem))
    return mapping

def filtered_graph(g, token_threshold):
    '''
    Prune `g` in place and annotate its remaining edges.

    1. Drop every edge that is neither a 'reference' edge nor points to a
       node of type 'subseqitem' (i.e. containment above seqitem level).
    2. Drop every node whose `tokens_n` is larger than `token_threshold`.
    3. Set `dist_weight` on each remaining edge: 1 for 'reference' edges,
       0 for all others.

    `g` is iterated with edge keys, so it has to be a multigraph.
    Nothing is returned.
    '''
    type_of = nx.get_node_attributes(g, 'type')

    droppable_edges = []
    for src, dst, key, kind in g.edges(keys=True, data='edge_type'):
        if type_of[dst] != 'subseqitem' and kind != 'reference':
            droppable_edges.append((src, dst, key))
    g.remove_edges_from(droppable_edges)

    # Nodes without a (truthy) token count are always kept.
    oversized = [
        node
        for node, size in g.nodes(data='tokens_n')
        if size and size > token_threshold
    ]
    g.remove_nodes_from(oversized)

    weights = {}
    for src, dst, key, kind in g.edges(keys=True, data='edge_type'):
        weights[(src, dst, key)] = 1 if kind == 'reference' else 0
    nx.set_edge_attributes(g, weights, 'dist_weight')

    
def direct_path_graph(g):
    '''
    Derive a simple directed graph of "who cites what" from `g`.

    Every 'reference' edge of `g` is copied and tagged with a running
    number (`edge_key`). Each reference is then propagated along the
    'containment' edges of `g`: whoever cites a node also gets an edge,
    with the same `edge_key`, to the nodes contained in it. Containment
    edges themselves are not copied and self-loops are removed at the end.

    The result only contains nodes that take part in at least one of these
    edges; their `tokens_n` attribute is carried over from `g`.
    '''
    h = nx.DiGraph()

    # Number the reference edges in iteration order; that number is the edge_key.
    reference_pairs = (
        (src, dst)
        for src, dst, kind in g.edges(data='edge_type')
        if kind == 'reference'
    )
    for idx, (src, dst) in enumerate(reference_pairs):
        h.add_edge(src, dst, edge_key=idx)

    def source_of(edge):
        return edge[0]

    # Walk edges ordered by source node. Presumably a parent's key sorts before
    # its children's keys, so references trickle down level by level.
    for parent, child, kind in sorted(g.edges(data='edge_type'), key=source_of):
        if kind != 'containment':
            continue
        for citing, _target, key in h.in_edges(parent, data='edge_key'):
            h.add_edge(citing, child, edge_key=key)

    sizes = nx.get_node_attributes(g, 'tokens_n')
    nx.set_node_attributes(h, {node: sizes[node] for node in h.nodes}, 'tokens_n')

    self_loops = [(src, dst) for src, dst in h.edges() if src == dst]
    h.remove_edges_from(self_loops)
    return h

def remove_contained_nodes(nodes, descendants_cache, contained_edges):
    '''
    Reduce `nodes` to those members that have no ancestor (in the hierarchy)
    inside `nodes`.

    `contained_edges` maps a node to its direct children and
    `descendants_cache` memoizes descendant sets across calls.
    '''
    covered = set()
    for candidate in nodes:
        covered |= get_descendants(candidate, descendants_cache, contained_edges)
    return nodes - covered


def get_containment_edges(g):
    '''
    Collect the 'containment' edges of `g` as a plain dict that maps each
    parent node to the set of its direct children.
    Nodes without children do not appear as keys.
    '''
    children_of = {}
    for parent, child, kind in g.edges(data='edge_type'):
        if kind != 'containment':
            continue
        children_of.setdefault(parent, set()).add(child)
    return children_of


def get_descendants(node, descendants_cache, contained_edges):
    '''
    Return the set of all nodes below `node` in the hierarchy.

    `contained_edges` maps a node to its direct children. Results are stored
    in `descendants_cache` (for `node` and for every node visited on the way),
    so repeated calls that share the cache are cheap. The returned set is the
    cached object itself.
    '''
    if node in descendants_cache:
        return descendants_cache[node]

    collected = set()
    # Leaves have no entry in contained_edges and end up with an empty set.
    for child in contained_edges.get(node, ()):
        collected.add(child)
        collected |= get_descendants(child, descendants_cache, contained_edges)

    descendants_cache[node] = collected
    return collected

# Overlapping sets

In [None]:
def generate_stats(
    nodes_without_contained, ref_edges, descendants_cache, contained_edges, subseqitems2seqitems, g_tokens_n,
    *, root=None, year=None
):
    '''
    Summarize one reference set as a flat dict (one row of the result table).

    Parameters:
        nodes_without_contained: nodes of the set, with hierarchically
            contained nodes already removed. Node keys are strings whose
            part before the first '_' identifies the title.
        ref_edges: collection of reference-edge keys inside the set.
        descendants_cache, contained_edges: unused here; kept so existing
            callers do not break.
        subseqitems2seqitems: mapping from each node to its seqitem.
        g_tokens_n: mapping from each node to its token count.
        root: node the set was grown from (keyword-only, optional).
        year: snapshot year of the graph (keyword-only, optional).

    When `root` or `year` is omitted, the value is taken from the
    module-level variables `n` / `year`. This is what the function did
    implicitly before and keeps the existing notebook loop working; new
    callers should pass both explicitly.

    Returns a dict with the keys root, year, nodes, ref_edges, seqitems,
    titles and tokens_n (in this order).
    '''
    # Backward-compatible fallback to the notebook's loop variables.
    if root is None:
        root = globals()['n']
    if year is None:
        year = globals()['year']

    seqitems = {subseqitems2seqitems[node] for node in nodes_without_contained}
    titles = {node.split('_')[0] for node in nodes_without_contained}
    tokens_n = int(sum(g_tokens_n[node] for node in nodes_without_contained))

    return dict(
        root=root,
        year=year,
        nodes=len(nodes_without_contained),
        ref_edges=len(ref_edges),
        seqitems=len(seqitems),
        titles=len(titles),
        tokens_n=tokens_n
    )

In [None]:
# Driver: for every yearly snapshot, compute for each base node the set of nodes
# reachable via references and write one CSV of summary statistics per radius.
# NOTE: generate_stats picks up the loop variables `n` and `year` by name from
# this scope, so they must not be renamed here.
for year in range(1998, 2020):
    year = str(year)
    
    g = quantlaw.utils.networkx.load_graph_from_csv_files(
        '../../legal-networks-data/us/4_crossreference_graph/detailed/',
        year,
        filter=None
    )
    # Built from the unfiltered graph, i.e. before filtered_graph prunes g.
    subseqitems2seqitems = subseqitem_mapping(quantlaw.utils.networkx.hierarchy_graph(g))

    # Prunes g in place (edges above seqitem level, nodes over the token threshold).
    filtered_graph(g, token_threshold=1000)
    h = quantlaw.utils.networkx.hierarchy_graph(g)
    h.remove_node('root')
    # Shared memo for get_descendants; reused for all nodes of this year.
    descendants_cache = {}
    contained_edges = get_containment_edges(g)
    # From here on `g` is the simple DiGraph of (propagated) reference edges.
    g = direct_path_graph(g)

    # Nodes without a parent in the hierarchy graph (not used below in this cell).
    root_nodes = {n for n, degree in h.in_degree() if degree == 0}
    
    # Start nodes: nodes of g that are not contained in another node of g.
    base_nodes = remove_contained_nodes(
        g.nodes, descendants_cache, contained_edges
    )
    
    # Nested lookup source -> target -> edge_key for the edges of g.
    edge_key_dict = defaultdict(dict)
    for (u, v), k in nx.get_edge_attributes(g, 'edge_key').items():
        edge_key_dict[u][v] = k
    edge_key_dict = dict(edge_key_dict)

    # radius None means "no limit on path length".
    for radius in [
        None, 
#         3, 
#         6
    ]:
        g_tokens_n = nx.get_node_attributes(g, 'tokens_n')
        data = []
        for i, n in enumerate(base_nodes):
            
            # Maps every node reachable from n to a shortest path (list of nodes) from n.
            shortest_path_dict = nx.shortest_path(g, n)
            
            # A path with `radius` edges lists `radius + 1` nodes, hence the + 1.
            reachable_nodes_and_self = {
                path_root 
                for path_root, path in shortest_path_dict.items()
                if not radius or len(path) <= radius + 1
            } | {n}
            
            # Avoid double counting: drop nodes whose ancestor is also in the set.
            nodes_without_contained = remove_contained_nodes(
                reachable_nodes_and_self, descendants_cache, contained_edges
            ) 
            
            # Keys of all edges of g whose both endpoints lie in the reachable set.
            ref_edges = {
                k
                for u in reachable_nodes_and_self
                for v, k in edge_key_dict.get(u, {}).items()
                if v in reachable_nodes_and_self
            }
            
            if nodes_without_contained:
                n_data = generate_stats(
                    nodes_without_contained, ref_edges, descendants_cache, contained_edges, 
                    subseqitems2seqitems, g_tokens_n
                )
                data.append(n_data)
        # One row per start node, ordered by the start node's key.
        df = pd.DataFrame(data).sort_values('root')
        # The unlimited-radius result gets the file name without a radius suffix.
        df.to_csv(
            f'../data/reference_sets_{year}_radius_{radius}.csv'
            if radius else
            f'../data/reference_sets_{year}.csv'
        )
    
    print('Done', year)