## Community Sankey Plot

In this notebook, we mainly create Figure 5 and the analogous figures in the SI.
We further inspect the contents of the Miscellaneous category, which is also discussed in the SI.

### Preparations

In [1]:
from collections import Counter, defaultdict
import itertools
import matplotlib.patches as patches
import numpy as np
import regex
import matplotlib.pyplot as plt
import networkx as nx
import seaborn as sns
from matplotlib.colors import ListedColormap
from quantlaw.utils.networkx import hierarchy_graph
from legal_data_clustering.utils.graph_api import (
    cluster_families,
    add_community_to_graph,
    get_clustering_result,
    get_heading_path,
    add_headings_path
)

# Disable matplotlib's interactive mode; the figures below are written to PDF files and closed.
plt.ioff()

In [2]:
def cluster_family_plt_colors(dataset):
    """
    Return the colormap used for cluster families.

    Entries 0-19 are the colors of seaborn's "tab20" palette, made opaque
    (alpha 1). Entry 20 is a light grey. ``dataset`` is accepted for interface
    compatibility but is not used.
    """
    opaque_palette = [np.array([*rgb, 1]) for rgb in sns.color_palette("tab20")]
    light_grey = np.array([0.75, 0.75, 0.75, 1])
    return ListedColormap(np.array(opaque_palette + [light_grey]))


#### Plotting functions

In [3]:
def get_level_node_weights(G, level):
    """
    Return the ``weight`` of every node whose ``bipartite`` attribute equals ``level``.

    The weights follow the node iteration order of ``G``. Summing them gives the
    total weight of the level.
    """
    weights = []
    for _, attrs in G.nodes(data=True):
        if attrs['bipartite'] == level:
            weights.append(attrs['weight'])
    return weights

def calc_config(G, spline_node_ratio=0.5):
    """
    Calculate the basic layout parameters of the Sankey plot for graph ``G``.

    The plot spans the unit square. Each level (distinct ``bipartite`` value) is
    a horizontal row of nodes. Consecutive rows are connected by splines.

    Parameters
    ----------
    G : graph whose nodes carry ``bipartite`` (level) and ``weight`` attributes.
    spline_node_ratio : float
        Fraction of the plot height used by the spline bands between levels.
        The rest is shared equally by the node rows.

    Returns
    -------
    dict with keys
        levels : the level values sorted case-insensitively (must be strings).
        levels_node_weight_sum : total node weight of each level.
        x_weight_scale : weight represented by 100% of the plot width
            (the total weight of the heaviest level).
        node_height : height of one node row.
        spline_height : height of one spline band (0 if there is only one level).
        level_tick_positions : vertical center of each level's row, in the
            order of ``levels``.
    """
    config = dict()
    config['levels'] = sorted(set(nx.get_node_attributes(G, 'bipartite').values()), key=lambda x: x.lower())
    config['levels_node_weight_sum'] = {
        level: sum(get_level_node_weights(G, level))
        for level in config['levels']
    }
    # Weight (e.g. number of tokens) represented by 100% of the plot width.
    config['x_weight_scale'] = max(config['levels_node_weight_sum'].values())

    n_levels = len(config['levels'])
    if n_levels > 1:
        config['node_height'] = 1/n_levels*(1-spline_node_ratio)  # height of nodes
        config['spline_height'] = 1/(n_levels-1)*spline_node_ratio  # height of splines
    else:
        # A single level has no splines (and 1/(n_levels-1) would divide by zero),
        # so the node row takes the full height.
        config['node_height'] = 1
        config['spline_height'] = 0

    # Row centers counted from the bottom, then reversed so that index i is the
    # center of level i (level 0 is the top row).
    config['level_tick_positions'] = list(reversed([
        (config['node_height'] + config['spline_height']) * idx + config['node_height'] / 2
        for idx, level in enumerate(config['levels'])
    ]))
    return config

def calc_node_positions(G, config, level_node_orders):
    """
    Calculate the rectangle of every node.

    Levels are stacked from the top of the unit square in the order of
    ``config['levels']``. Within a level, nodes are placed left to right in the
    order of ``level_node_orders[level]``. Each width is ``weight / x_weight_scale``,
    and the row is centered horizontally.

    Returns
    -------
    dict mapping node -> dict(height, width, top, bottom, left)
    """
    positions = {}
    height = config['node_height']
    level_step = config['node_height'] + config['spline_height']

    for level_idx, level in enumerate(config['levels']):
        # The vertical extent is the same for every node of a level.
        top = 1 - level_step * level_idx
        bottom = top - height

        # Center the row: start after half of the unused width.
        left = (1 - config['levels_node_weight_sum'][level]/config['x_weight_scale']) / 2

        for node in level_node_orders[level]:
            width = G.nodes[node]['weight']/config['x_weight_scale']
            positions[node] = dict(
                height=height,
                width=width,
                top=top,
                bottom=bottom,
                left=left,
            )
            left += width

    return positions

def calc_edge_positions(G, node_positions, config, level_node_orders):
    """
    Calculate the band of every edge between two node rectangles.

    Edges are sorted by the global order of their source and then their target
    node. The global order is the node lists of ``level_node_orders``
    concatenated in the order of ``config['levels']``. Each edge occupies a
    horizontal slice of width ``weight / x_weight_scale`` at the bottom of its
    source node and at the top of its target node. Successive edges of the same
    node are placed left to right.

    Returns
    -------
    dict mapping (u, v) -> dict(x0, xn, y0, yn, width)
        (x0, y0): left end of the band at the source node's bottom.
        (xn, yn): left end of the band at the target node's top.
    """
    edge_positions = {}

    # Horizontal space already used by edges leaving / entering each node.
    edge_start_left_offset = {node: 0 for node in node_positions}
    edge_end_left_offset = {node: 0 for node in node_positions}

    # Global rank of each node. A dict avoids an O(n) list.index() per sort key;
    # setdefault keeps the first occurrence, like list.index() did.
    node_rank = {}
    global_node_order = itertools.chain.from_iterable(
        level_node_orders[level] for level in config['levels']
    )
    for rank, node in enumerate(global_node_order):
        node_rank.setdefault(node, rank)

    edges = sorted(
        ((u, v) for u, v in G.edges),
        key=lambda edge: (node_rank[edge[0]], node_rank[edge[1]]),
    )
    for u, v in edges:
        width = G.edges[u, v]['weight'] / config['x_weight_scale']
        edge_positions[(u, v)] = dict(
            x0=node_positions[u]['left'] + edge_start_left_offset[u],
            xn=node_positions[v]['left'] + edge_end_left_offset[v],
            y0=node_positions[u]['bottom'],
            yn=node_positions[v]['top'],
            width=width,
        )
        edge_start_left_offset[u] += width
        edge_end_left_offset[v] += width

    return edge_positions


def draw_node(ax, left, bottom, top, width, height, label=None, color='k', v_position='center', label_rotation=0):
    """
    Draw a node as a filled rectangle on ``ax``.

    The rectangle has its lower-left corner at ``(left, bottom)`` and size
    ``width`` x ``height``, filled with ``color``. Its outline is black with
    zero line width, so it is not visible. ``top``, ``label``, ``v_position``
    and ``label_rotation`` are unused. They are accepted so that a node position
    dict can be passed as ``**position``.
    """
    patch = patches.Rectangle((left, bottom), width, height, color=color)
    patch.set_edgecolor('k')
    # Line width is numeric (previously the string '0').
    patch.set_linewidth(0)
    ax.add_patch(patch)
    
def draw_label(ax, left, bottom, top, width, height, label=None, color='k', v_position='center', label_rotation=0, min_font_size=3):
    """
    Annotate a node rectangle with ``label`` in bold black text.

    Nothing is drawn when ``label`` is falsy. The text is centered on the node
    horizontally. Vertically it sits at 1/3 ('bottom'), 2/3 ('top') or 1/2 (any
    other ``v_position``) of the node height. Rotated labels get a font size
    that grows with the node width, capped at 8 and raised to ``min_font_size``
    when that is set. Unrotated labels use font size 6. ``top`` and ``color``
    are unused; they are accepted so that a node position dict can be passed
    as ``**position``.
    """
    if not label:
        return

    if v_position == 'bottom':
        offset = height/3
    elif v_position == 'top':
        offset = height/3*2
    else:
        offset = height/2
    y = bottom + offset

    if label_rotation:
        size = min(2 + 6 * width / 0.02, 8)
        if min_font_size:
            size = max(size, min_font_size)
        dx, dy = 0.0006, 0
    else:
        size, dx, dy = 6, 0, -0.0006

    ax.annotate(
        str(label),
        (left + width/2 + dx, y + dy),
        color='k',
        weight='bold',
        ha='center',
        va='center',
        fontsize=size,
        rotation=label_rotation,
    )
        
        
def get_blueprint_vals(resolution):
    """
    Sample the S-shaped curve that is stretched onto every edge band.

    A degree-4 polynomial x(y) is fitted through the five control points
    x = [0, 0.15, 0.5, 0.85, 1] at evenly spaced y in [0, 1]. It is then
    evaluated at ``resolution`` evenly spaced y values in [0, 1].

    Returns
    -------
    (x_vals, y_vals) : two numpy arrays of length ``resolution``.
    """
    control_x = np.array([0, 0.15, 0.5, 0.85, 1])
    control_y = np.linspace(0, 1, len(control_x))
    coefficients = np.polyfit(control_y, control_x, 4)
    y_vals = np.linspace(control_y[0], control_y[-1], resolution)
    x_vals = np.polyval(coefficients, y_vals)
    return x_vals, y_vals


# Default curve (50 samples) used by draw_edge for every edge band.
blueprint_x_vals, blueprint_y_vals = get_blueprint_vals(50)


def draw_edge(ax, x0, xn, y0, yn, width,
             blueprint_x_vals=blueprint_x_vals, blueprint_y_vals=blueprint_y_vals,
             color='k', alpha=0.5, hatch=None):
    """
    Draw an edge as a curved band (polygon) on ``ax``.

    The band's left boundary is the blueprint curve scaled to run from
    ``(x0, y0)`` to ``(xn, yn)``. Its right boundary is the same curve shifted
    right by ``width``. The polygon is filled with ``color`` at opacity
    ``alpha`` and drawn with zero line width. An optional matplotlib ``hatch``
    pattern can be applied.
    """
    # Map the unit blueprint curve onto the edge's start and end points.
    ty = blueprint_y_vals * (yn - y0) + y0
    tx = blueprint_x_vals * (xn - x0) + x0
    # Trace the left boundary from start to end, then the right boundary back.
    y_new = np.concatenate([ty, ty[::-1]])
    x_new = np.concatenate([tx, tx[::-1] + width])
    outline = np.array([x_new, y_new]).transpose()
    # NOTE(review): edgecolor='k' with lw=0 presumably only serves as the hatch color.
    ax.add_patch(
        patches.Polygon(outline, facecolor=color, alpha=alpha, lw=0, hatch=hatch, edgecolor='k')
    )

#### Preprocessing functions

##### Load graph

In [4]:
def filter_edges(G, threshold, attr='weight'):
    """
    Return a copy of ``G`` without insignificant edges.

    An edge (u, v) is insignificant if its ``attr`` value is below ``threshold``
    times the ``attr`` value of u or of v. ``G`` itself is not modified.
    """
    def is_insignificant(u, v, edge_value):
        return (
            edge_value < G.nodes[u][attr] * threshold
            or edge_value < G.nodes[v][attr] * threshold
        )

    weak_edges = [
        (u, v)
        for u, v, data in G.edges(data=True)
        if is_insignificant(u, v, data[attr])
    ]
    filtered = G.copy()
    filtered.remove_edges_from(weak_edges)
    return filtered

def load_graph(path, nodes_per_year, edge_threshold):
    """
    Load a cluster evolution graph and merge the small clusters of each snapshot.

    The ``tokens_n`` attribute of nodes and edges is copied to ``weight``, the
    quantity drawn in the plot. Edges are then filtered with ``filter_edges``.
    Within each snapshot (``bipartite`` value), the ``nodes_per_year`` heaviest
    nodes are kept. All other nodes of that snapshot are merged into one node
    ``misc_<snapshot>`` whose weight is their total weight. Edges touching a
    merged node are redirected to the corresponding misc node, and parallel
    redirected edges have their weights summed.

    Parameters
    ----------
    path : str
        Path of the graph, read with ``nx.read_gpickle``.
    nodes_per_year : int or None
        Number of nodes kept per snapshot. ``None`` keeps all nodes, so the misc
        nodes get weight 0.
    edge_threshold : float
        Relative threshold passed to ``filter_edges``.

    Returns
    -------
    (H, orig_G)
        ``H`` is the summarized graph; ``orig_G`` is the graph as loaded.
    """
    orig_G = nx.read_gpickle(path)
    G = orig_G.copy()

    # Use token counts as node and edge weights in the plot.
    nx.set_node_attributes(G, nx.get_node_attributes(G, 'tokens_n'), 'weight')
    nx.set_edge_attributes(G, nx.get_edge_attributes(G, 'tokens_n'), 'weight')

    # Remove edges that are small relative to their source or target node.
    G = filter_edges(G, threshold=edge_threshold)

    # Group nodes by snapshot (keeping the graph's node order within a snapshot).
    nodes_by_year = defaultdict(list)
    for node, data in G.nodes(data=True):
        nodes_by_year[data['bipartite']].append(node)
    years = sorted(nodes_by_year)

    # Split every snapshot into big nodes (drawn) and small nodes (merged into misc).
    big_nodes = set()
    small_nodes_mapper = {}  # small node -> key of its misc node
    misc_weights = {}  # snapshot -> total weight of its small nodes
    for year in years:
        nodes = sorted(nodes_by_year[year], key=lambda n: G.nodes[n]['weight'], reverse=True)
        # Slicing with None would put every node in both parts; None means "keep all".
        cutoff = len(nodes) if nodes_per_year is None else nodes_per_year
        big_nodes.update(nodes[:cutoff])
        misc_weights[year] = 0
        for node in nodes[cutoff:]:
            small_nodes_mapper[node] = f'misc_{year}'
            misc_weights[year] += G.nodes[node]['weight']

    # Keep the big nodes and the edges among them ...
    H = nx.subgraph(G, big_nodes).copy()

    # ... and add one misc node per snapshot, in the same level as its members.
    H.add_nodes_from([
        (f'misc_{year}', dict(weight=misc_weights[year], bipartite=year))
        for year in years
    ])

    # Redirect edges touching a small node to the misc nodes and sum their weights.
    for u, v, data in G.edges(data=True):
        if u in big_nodes and v in big_nodes:
            continue  # already part of H
        u_mapped = small_nodes_mapper.get(u, u)
        v_mapped = small_nodes_mapper.get(v, v)
        if H.has_edge(u_mapped, v_mapped):
            H.edges[u_mapped, v_mapped]['weight'] += data['weight']
        else:
            H.add_edge(u_mapped, v_mapped, weight=data['weight'])

    return H, orig_G

##### Order

In [5]:
def order_nodes_by_weight(config, G):
    """
    Order the nodes of each level by decreasing ``weight``.

    Returns a dict mapping every level in ``config['levels']`` to the list of
    nodes whose ``bipartite`` attribute equals that level, heaviest first.
    Nodes whose name starts with 'misc_' get the sort key -1 and therefore come
    last. Ties keep the node iteration order of ``G``.
    """
    def sort_key(node):
        if node.startswith('misc_'):
            return -1
        return G.nodes[node]['weight']

    ordering = {}
    for level in config['levels']:
        members = [node for node, attrs in G.nodes(data=True) if attrs['bipartite'] == level]
        ordering[level] = sorted(members, key=sort_key, reverse=True)
    return ordering

In [6]:
def order_edges_by_weight(edge_positions):
    """
    Return the ``(edge, position)`` items sorted by increasing band ``width``.

    Edges touching a node whose name starts with 'misc_' get the key -1, so they
    come first. Ties keep the dict order of ``edge_positions``.
    """
    def draw_priority(item):
        edge, position = item
        if edge[0].startswith('misc_') or edge[1].startswith('misc_'):
            return -1
        return position['width']

    items = list(edge_positions.items())
    items.sort(key=draw_priority)
    return items

##### Color and label position

In [7]:
def alternating_colors(edge_positions, node_positions, G, level_node_orders):
    """
    Assign alternating grey colors and label positions, in place.

    Within each level, nodes alternate between color '0.7' with the label at
    the 'top' and color '0.6' with the label at the 'bottom'. Each node's
    outgoing edges get the node's color. Afterwards, every node whose name
    starts with 'misc_' is set to color '0.8' with a 'center' label, and all of
    its incoming and outgoing edges are set to '0.8' as well. Nothing is
    returned.
    """
    styles = (('0.7', 'top'), ('0.6', 'bottom'))
    for level_nodes in level_node_orders.values():
        for position_in_level, node in enumerate(level_nodes):
            color, v_position = styles[position_in_level % 2]
            node_positions[node]['color'] = color
            node_positions[node]['v_position'] = v_position
            for edge in G.out_edges(node):
                edge_positions[edge]['color'] = color

    misc_color = '0.8'
    for node in node_positions:
        if not node.startswith('misc_'):
            continue
        node_positions[node]['color'] = misc_color
        node_positions[node]['v_position'] = 'center'
        for edge in [*G.out_edges(node), *G.in_edges(node)]:
            edge_positions[edge]['color'] = misc_color

In [8]:
def color_merges_splits(edge_positions, node_positions, H,):
    """
    Color edges according to the splits and merges they belong to, in place.

    For every edge (u, v) in ``edge_positions``:

    - edges between two misc nodes are left unchanged;
    - edges missing from ``H`` are colored 'orange' with hatch '|';
    - otherwise, u is splitting if it has more than one outgoing edge in ``H``,
      and v is merging if it has more than one incoming edge in ``H``. Misc
      nodes are never splitting or merging.

      * split and merge: 'magenta' with hatch 'X'
      * split only:      'red' with hatch '/'
      * merge only:      'blue' with a backslash hatch
      * neither:         color left unchanged

    Degrees are taken from ``H`` as given. Any selection of significant edges
    (e.g. ``filter_edges``) must be applied to ``H`` by the caller.
    ``node_positions`` is not used.
    """
    for edge_key in edge_positions:
        u, v = edge_key
        if u.startswith('misc_') and v.startswith('misc_'):
            continue
        style = edge_positions[edge_key]
        if not H.has_edge(u, v):
            # Does not occur when the positions were computed from H itself.
            style['color'] = 'orange'
            style['hatch'] = '|'
            continue
        is_split = H.out_degree(u) > 1 and not u.startswith('misc_')
        is_merge = H.in_degree(v) > 1 and not v.startswith('misc_')
        if is_split and is_merge:
            style['color'] = 'magenta'
            style['hatch'] = 'X'
        elif is_split:
            style['color'] = 'red'
            style['hatch'] = '/'
        elif is_merge:
            style['color'] = 'blue'
            style['hatch'] = '\\'

In [9]:
def color_births_deaths(node_positions, H):
    """
    Highlight clusters that appear ('gold') or disappear ('chocolate'), in place.

    A node is born if it is not in the first snapshot and has no incoming edge
    in ``H``. It dies if it is not in the last snapshot and has no outgoing
    edge. Nodes that are both born and die (they exist in a single snapshot
    only) keep their current color.
    """
    snapshots = sorted({
        attrs['bipartite']
        for _, attrs in H.nodes(data=True)
        if 'bipartite' in attrs
    })
    for node, position in node_positions.items():
        born = H.nodes[node]['bipartite'] != snapshots[0] and H.in_degree(node) == 0
        dies = H.nodes[node]['bipartite'] != snapshots[-1] and H.out_degree(node) == 0
        if born and not dies:
            position['color'] = 'gold'
        elif dies and not born:
            position['color'] = 'chocolate'

In [10]:
def color_by_cluster_family(orig_G, node_positions, edge_positions, dataset):
    """
    Color nodes and edges by cluster family, in place.

    Families come from ``cluster_families(orig_G, threshold=0.15)``. Only the
    first 20 are colored; family i gets color i of
    ``cluster_family_plt_colors(dataset)``. A node gets its family's color if it
    is present in ``node_positions``. An edge gets the color of a family that
    contains its source or target. Later families overwrite earlier ones.
    """
    families = cluster_families(orig_G, threshold=0.15)
    cmap = cluster_family_plt_colors(dataset)

    for family_idx, family in enumerate(families[:20]):
        color = cmap(family_idx)
        members = set(family)  # O(1) membership tests in the edge loop below
        for node in members:
            if node in node_positions:
                node_positions[node]['color'] = color
        for u, v in edge_positions:
            if u in members or v in members:
                edge_positions[(u, v)]['color'] = color

#### Runner function

In [11]:
def plot_evolution_graph(dataset, config_str, nodes_per_year=50, min_font_size=3):
    """
    Draw the Sankey plot of a cluster evolution graph and save it as PDF.

    The graph is loaded from
    ``../../legal-networks-data/<dataset>/13_cluster_evolution_graph/all_<config_str>.gpickle.gz``.
    Only ``nodes_per_year`` nodes are kept per snapshot; the rest form the misc
    nodes. Nodes and edges are colored by cluster family. Two files are
    written: ``../graphics/sankey_<dataset>_<config_str>[_miscafter<n>].pdf``
    without labels and the same name with ``_labels.pdf`` with labels. The
    ``_miscafter`` suffix is added when ``nodes_per_year != 50``.

    The graphs and positions are also stored in the module-level globals ``G``,
    ``orig_G``, ``node_positions`` and ``edge_positions``, presumably for later
    inspection in the notebook (e.g. ``inspect_misc``).
    """
    global G, orig_G, node_positions, edge_positions
    G, orig_G = load_graph(
        path=f'../../legal-networks-data/{dataset.lower()}/13_cluster_evolution_graph/all_{config_str}.gpickle.gz',
        nodes_per_year=nodes_per_year,
        edge_threshold=.15,
    )

    config = calc_config(G)
    level_node_orders = order_nodes_by_weight(config, G)
    node_positions = calc_node_positions(G, config, level_node_orders)
    edge_positions = calc_edge_positions(G, node_positions, config, level_node_orders)

    # Grey defaults (this also sets each node's label position), then family colors.
    # Alternatives defined further below: color_merges_splits, color_births_deaths,
    # categories_colors.
    alternating_colors(edge_positions, node_positions, G, level_node_orders)
    color_by_cluster_family(orig_G, node_positions, edge_positions, dataset)

    plt.rcParams['figure.figsize'] = 8, 11
    plt.rcParams['figure.constrained_layout.use'] = True
    fig = plt.figure()
    ax = fig.add_subplot(111)

    for position in node_positions.values():
        draw_node(ax=ax, **position)

    # Drawing order: misc edges first, then by increasing band width.
    for _, position in order_edges_by_weight(edge_positions):
        draw_edge(ax=ax, **position)

    plt.yticks(config['level_tick_positions'], [level[:4] for level in config['levels']], fontsize=12)
    plt.xticks([])

    filepath_base = f'../graphics/sankey_{dataset.lower()}_{config_str}'
    if nodes_per_year != 50:
        filepath_base += f'_miscafter{nodes_per_year or 0}'
    plt.savefig(filepath_base + '.pdf')

    # Add labels to the same figure and save a second version.
    for node, position in node_positions.items():
        if node.startswith('misc_'):
            label = 'Miscellaneous'
        else:
            # Node keys are '<snapshot>_<community id>'; take the part after the last '_'.
            label = node.rsplit('_', 1)[-1]
        draw_label(
            ax=ax, **position,
            label=label,
            label_rotation=0 if position['v_position'] == 'center' else 90,
            min_font_size=min_font_size,
        )
    plt.savefig(filepath_base + '_labels.pdf')
    plt.close()

### Drawing the Sankey plots

In [12]:
# Clustering configuration to plot; it is part of the input and output file names.
config_str = '0-0_1-0_-1_a-infomap_n100_m1-0_s0_c1000'

In [13]:
# Writes ../graphics/sankey_us_reg_<config_str>.pdf and the corresponding _labels.pdf.
plot_evolution_graph('us_reg', config_str, nodes_per_year=50)

In [14]:
# Writes ../graphics/sankey_de_reg_<config_str>.pdf and the corresponding _labels.pdf.
plot_evolution_graph('de_reg', config_str, nodes_per_year=50)

In [15]:
# for config in ['0-0_1-0_-1_a-infomap_m1-0_s0_c1000'] + [
#     f'0-0_1-0_-1_a-infomap_n{runs}_m1-0_s0_c1000' for runs in list(range(10,150+1,10)) + [200]
# ]:
#     plot_evolution_graph('us', config)
#     print(config, 'done')

In [16]:
# plot_evolution_graph('us_reg', config_str, nodes_per_year=500, min_font_size=None)

In [17]:
# plot_evolution_graph('de', config_str, nodes_per_year=500, min_font_size=None)

### Inspecting the Miscellaneous category

In [18]:
def inspect_misc(G, orig_G, dataset):
    """
    Write the contents of the misc nodes to ``../results/sankey_misc_inspection_<dataset>.txt``.

    Nodes of ``orig_G`` that are missing from ``G`` (the clusters merged into
    misc nodes) are grouped by snapshot, i.e. the node-name prefix before the
    first '_'. For each snapshot, the items listed in their comma-separated
    ``nodes_contained`` attribute are looked up in that snapshot's hierarchy
    graph. They are written out, largest first, with their token count and
    heading path, preceded by the total token count.
    """
    misc_content = defaultdict(list)
    for n in sorted(set(orig_G.nodes) - set(G.nodes)):
        year = n.split('_')[0]
        misc_content[year].extend(orig_G.nodes[n]['nodes_contained'].split(','))

    parts = []
    for year in sorted(misc_content):
        H = nx.read_gpickle(f'../../legal-networks-data/{dataset.lower()}/4_crossreference_graph/seqitems/{year}.gpickle.gz')
        H = hierarchy_graph(H)
        items = misc_content[year]
        sum_tokens = sum(H.nodes[n]['tokens_n'] for n in items)
        parts.append(f'{year} | Size of misc in tokens: {sum_tokens}\n\n')
        for n in sorted(items, key=lambda item: -H.nodes[item]['tokens_n']):
            parts.append(f"{H.nodes[n]['tokens_n']:10} | " + get_heading_path(H, n) + '\n')
        parts.append('\n\n\n')

    # Explicit encoding: heading paths may contain non-ASCII characters (e.g. German ones).
    with open(f'../results/sankey_misc_inspection_{dataset.lower()}.txt', 'w', encoding='utf-8') as f:
        f.write(''.join(parts))

# for dataset in ['us', 'de']:
#     G, orig_G = load_graph(
#         path=f'../../legal-networks-data/{dataset.lower()}/13_cluster_evolution_graph/all_{config_str}.gpickle.gz',
#         nodes_per_year=50,
#         edge_threshold = .15,
#     )
#     inspect_misc(G, orig_G, dataset)
    

### Further layout and coloring options (not used in the paper)

#### Coloring with tab20

In [19]:
def add_colors(G, node_positions, edge_positions, level_node_orders, config, cmap=None):
    """
    Color nodes by following their heaviest incoming edge, in place.

    Levels are processed in the order of ``config['levels']`` and nodes in the
    order of ``level_node_orders[level]``. A node without incoming edges gets
    the next color of ``cmap``, cycling through its ``N`` colors. Any other node
    takes the color of its incoming edge with the largest ``weight``. Every
    node's outgoing edges get the node's color.

    Parameters
    ----------
    cmap : callable mapping an int to a color, optional
        Defaults to matplotlib's 'tab20' colormap, looked up at call time
        rather than at import time. Callables without an ``N`` attribute are
        cycled with period 20.

    Returns
    -------
    (node_positions, edge_positions) : the same dicts, updated in place.

    Call e.g.
    > node_positions, edge_positions = add_colors(G, node_positions, edge_positions, level_node_orders, config, cmap=plt.get_cmap('tab20'))
    """
    if cmap is None:
        cmap = plt.get_cmap('tab20')
    n_colors = getattr(cmap, 'N', 20)
    next_color_idx = 0

    for level in config['levels']:
        for node in level_node_orders[level]:
            in_edges = G.in_edges(node)
            if not in_edges:
                # A new lineage starts here: take the next colormap entry.
                color = cmap(next_color_idx % n_colors)
                next_color_idx += 1
            else:
                # Inherit the color of the heaviest incoming edge. It is already
                # colored because its source is on an earlier level.
                strongest = max(in_edges, key=lambda edge: G.edges[edge]['weight'])
                color = edge_positions[strongest]['color']

            node_positions[node]['color'] = color
            for out_edge in G.out_edges(node):
                edge_positions[out_edge]['color'] = color
    return node_positions, edge_positions

#### Sugiyama nodes order

Note: This requires the `igraph` library to be installed, which is not among the repository requirements and hence may have to be installed separately.

In [20]:
def get_sugiyama_order(G):
    """
    Order the nodes of ``G`` using igraph's Sugiyama layered layout.

    Each node is assigned to the layer given by the rank of its ``bipartite``
    value. The layout is computed with the edges' ``weight`` attribute as edge
    weights. Nodes are returned layer by layer, each layer ordered by the
    layout's x coordinate.

    Use e.g.
    > ordering = get_sugiyama_order(G)
    > level_node_orders = {
    >     level:
    >     sorted([n for n, data in G.nodes(data=True) if data['bipartite'] == level], key=lambda n: ordering.index(n))
    >     for level in config['levels']
    > }
    """
    import igraph

    nodes = list(G.nodes)
    node_bipartite = [G.nodes[n]['bipartite'] for n in nodes]
    layers = sorted(set(node_bipartite))

    g = igraph.Graph(directed=True)
    g.add_vertices(nodes)
    edges = list(G.edges())
    g.add_edges(edges)

    # layout_sugiyama's ``weights`` are edge weights (one value per edge), so
    # pass edge weights in the order the edges were added, not node weights.
    edge_weights = [G.edges[e]['weight'] for e in edges]
    layout = g.layout_sugiyama(layers=[layers.index(x) for x in node_bipartite], weights=edge_weights)

    # (node, x, layer) for the original vertices.
    nodes_layout = [(n, *coords) for n, coords in zip(nodes, layout)]

    y_positions = list(zip(*list(layout)))[1]
    assert max(y_positions) == len(layers) - 1

    ordering = [[] for _ in range(len(layers))]
    for node, x, layer in sorted(nodes_layout, key=lambda tup: tup[1]):
        ordering[int(layer)].append(node)
    return list(itertools.chain(*ordering))

#### Coloring by content

In [21]:
def load_clusterings(config, path, snapshots, dataset):
    """
    Load and annotate the clustering result of every snapshot.

    For each snapshot, ``<path>/<snapshot>_<config>.json`` is read with
    ``get_clustering_result(..., dataset, 'seqitems')``. Then
    ``add_community_to_graph`` and ``add_headings_path`` are applied, and the
    snapshot name is printed over the previous one as progress output. The
    clusterings are returned in the order of ``snapshots``.
    """
    def load_one(snapshot):
        result = get_clustering_result(
            f'{path}/{snapshot}_{config}.json',
            dataset,
            'seqitems'
        )
        add_community_to_graph(result)
        add_headings_path(result.graph)
        print('\r' + snapshot, end='')
        return result

    return [load_one(snapshot) for snapshot in snapshots]


def get_cluster_weights_for_pattern(pattern_str, pattern_key, clusterings, snapshots, G):
    """
    Annotate the nodes of ``G`` with the token share matching a heading pattern.

    ``pattern_str`` is wrapped as ``[^|]*<pattern_str>[^|]*``, compiled
    case-insensitively with the ``regex`` module, and applied with ``match`` to
    each seqitem's ``heading_path``. As a result, only the part before the first
    '|' is searched (assuming ``pattern_str`` has no top-level '|').

    For every snapshot and community index ``idx``, the node ``<snapshot>_<idx>``
    of ``G`` (if present) receives:
    - ``pattern_key``: matching tokens / all tokens of the community's
      seqitems (0 if the community has no tokens);
    - ``labels``: a Counter of matching tokens per text key, i.e. the seqitem
      name before the first '_'. It is created if missing and accumulates
      across calls.
    Nothing is returned.
    """
    pattern = regex.compile(r'[^|]*' + pattern_str + r'[^|]*', flags=regex.IGNORECASE)

    for snapshot, clustering in zip(snapshots, clusterings):
        n_communities = len(clustering.communities)
        texts_matches = [Counter() for _ in range(n_communities)]
        tokens_n_sizes = [0] * n_communities
        tokens_n_matches = [0] * n_communities

        for node, data in clustering.graph.nodes(data=True):
            if data.get('type') != 'seqitem' or 'community' not in data:
                continue
            community_idx = data['community']
            heading_path = data['heading_path']
            tokens_n = data['tokens_n']
            tokens_n_sizes[community_idx] += tokens_n

            if pattern.match(heading_path):
                tokens_n_matches[community_idx] += tokens_n
                # NOTE(review): an earlier note says the text key is split part 1 for
                # de and 0 for us; part 0 is used here — confirm before running on de.
                text = node.split('_')[0]
                texts_matches[community_idx][text] += tokens_n

        for idx in range(n_communities):
            node = f'{snapshot}_{idx}'
            if not G.has_node(node):
                continue
            # ``G.nodes`` instead of ``G.node`` (the latter was removed in networkx 2.4).
            attrs = G.nodes[node]
            size = tokens_n_sizes[idx]
            attrs[pattern_key] = tokens_n_matches[idx] / size if size else 0
            attrs.setdefault('labels', Counter()).update(texts_matches[idx])

##### Sample usage (DE data)

- Straf
- Proze(ss|ß)
- steuer
- statistik
- umwelt
- haftung (not so good)
- gesellschaft
- kapital
- gewerbe
- arbeit

In [22]:
# snapshots = [f'{year}' for year in range(1994, 2019)]
# clusterings_path = '../US-data/cd_2_cluster_results'
# evolution_path = '../US-data/cd_4_cluster_evolution_graph/'
# dataset= 'us'

In [23]:
# snapshots = [f'{year}-01-01' for year in range(1994, 2019)]
# clusterings_path = '../DE-data/cd_2_cluster_results'
# evolution_path = '../DE-data/cd_4_cluster_evolution_graph/'
# dataset= 'de'

In [24]:
# clusterings = load_clusterings(cluster_config, clusterings_path, snapshots, dataset)

In [25]:
# # DE
# get_cluster_weights_for_pattern(
#     '(straf|ordnungsw)',
#     'crim',
#     clusterings, snapshots, G
# )
# get_cluster_weights_for_pattern(
#     'steuer',
#     'tax',
#     clusterings, snapshots, G
# )
# get_cluster_weights_for_pattern(
#     '(umwelt|engerie)',
#     'environment',
#     clusterings, snapshots, G
# )
# get_cluster_weights_for_pattern(
#     'sozial',
#     'social',
#     clusterings, snapshots, G
# )

In [26]:
# # US
# get_cluster_weights_for_pattern(
#     '(environ|conserva)',
#     'environment',
#     clusterings, snapshots, G
# )
# get_cluster_weights_for_pattern(
#     'Public\sHealth\s.{1,3}\sWelfare',
#     '42',
#     clusterings, snapshots, G
# )

In [27]:
# for node in node_positions:
#     colors = [
#         plt.get_cmap('Reds')(G.nodes[node]['crim']),
#         plt.get_cmap('Blues')(G.nodes[node].get('42', 0)),
#         plt.get_cmap('Oranges')(G.nodes[node].get('environment', 0)),
#         plt.get_cmap('Greens')(G.nodes[node]['social'])
#     ]
#     color = [
#         max(min(sum(c_channel)-(len(c_channel)-1), 1), 0)
#         for c_channel in zip(*colors)
#     ]
    
#     if node.startswith('misc_'):
#         color ='0.8'

#     node_positions[node]['color'] = color
#     for out_edge in G.out_edges(node):
#             edge_positions[out_edge]['color'] = color

#### Coloring by category

In [28]:
def categories_colors(edge_positions, node_positions, categeories_df):
    """
    Color edges by the event category recorded in ``categeories_df``, in place.

    ``categeories_df`` needs the columns ``from_cluster``, ``to_cluster`` and
    ``event_category``. Edges whose source or target starts with 'misc' are
    skipped. For any other edge (u, v), the row with ``from_cluster == u`` and
    ``to_cluster == v`` determines the color: 'split' -> 'red',
    'merge' -> 'green', 'splerge' -> 'yellow'. Other categories and edges
    without a row keep their color. More than one row for an edge raises
    AssertionError. ``node_positions`` is not used.
    """
    # Index the frame once instead of filtering all rows for every edge.
    categories_by_edge = defaultdict(list)
    for from_cluster, to_cluster, event_category in zip(
        categeories_df['from_cluster'],
        categeories_df['to_cluster'],
        categeories_df['event_category'],
    ):
        categories_by_edge[(from_cluster, to_cluster)].append(event_category)

    category_colors = {'split': 'red', 'merge': 'green', 'splerge': 'yellow'}
    for edge_key in edge_positions:
        u, v = edge_key
        if u.startswith('misc') or v.startswith('misc'):
            continue

        categories = categories_by_edge.get((u, v), [])
        assert len(categories) <= 1, f'multiple event categories for edge {edge_key}'
        if categories and categories[0] in category_colors:
            edge_positions[edge_key]['color'] = category_colors[categories[0]]

### End