In [None]:
import networkx as nx
from community import community_louvain
import json
from pybrat.parser import BratParser, Entity, Event, Example, Relation
import matplotlib.pyplot as plt
from pprint import pprint
import numpy as np
from collections import defaultdict

In [None]:
def read_brat_data(data_path):
    """Parse the brat corpus found under ``data_path`` with pybrat.

    The parser is created with ``error="ignore"``; presumably this makes pybrat
    skip annotations it cannot parse rather than raise (see pybrat docs).
    Returns whatever ``BratParser.parse`` yields, i.e. the parsed documents.
    """
    return BratParser(error="ignore").parse(data_path)

In [None]:
def brat_data_to_network(data_path, brat_data=None):
    """Build an entity / hyper-edge graph from parsed brat documents.

    * Each entity becomes an ``"entity"`` node. Its id is the ``'ui'`` of the
      first UMLS candidate listed for it in ``<data_path><doc.id>.json``, or
      ``"<doc.id>-<entity.id>"`` when that list is empty, so entities sharing a
      CUI are merged across documents (their mentions are accumulated).
    * Each event becomes a ``"hyper_edge"`` node whose id is its trigger node
      id followed by its sorted argument node ids, joined with ``"-"``. It is
      linked to each (de-duplicated) argument node.
    * Nodes that end up without any edge are removed.

    Prints the arguments-per-event histogram and the final node/edge counts.

    Parameters
    ----------
    data_path : str
        Directory prefix of the per-document UMLS ``.json`` files. It is
        concatenated directly with the document id, so it must end with a path
        separator.
    brat_data : iterable, optional
        Parsed documents, e.g. the result of ``read_brat_data(data_path)``.
        When omitted, ``data_path`` is parsed with ``read_brat_data``; pass it
        explicitly to avoid parsing the corpus twice.

    Returns
    -------
    networkx.Graph
        Undirected graph whose node attributes are the dicts built below.
    """
    if brat_data is None:
        brat_data = read_brat_data(data_path)

    nodes_dict = {}
    # hyper_edge_node_id -> {argument_node_id: 1}; the inner dict de-duplicates
    # links while preserving first-seen order.
    links = defaultdict(dict)
    # number of arguments per event -> number of events with that many
    argument_num_dict = defaultdict(int)

    for doc in brat_data:
        with open(data_path + doc.id + '.json', encoding='utf-8') as umls_file:
            entity_umls = json.load(umls_file)
        # brat annotation id -> graph node id, scoped to this document
        origin_id_to_new_id_dict = {}

        for entity in doc.entities:
            candidates = entity_umls[entity.id]
            # TODO: add candidate selection. For now just choosing the first one.
            # TODO: consider partial matching to disambiguate entities without CUIs.
            entity_ui = candidates[0]['ui'] if candidates else doc.id + "-" + entity.id
            origin_id_to_new_id_dict[entity.id] = entity_ui

            mention = {
                "doc_id": doc.id,
                "mention": entity.mention,
                # only the first span of the entity is recorded
                "span": {'start': entity.spans[0].start, 'end': entity.spans[0].end},
            }
            if entity_ui in nodes_dict:
                nodes_dict[entity_ui]["mentions"].append(mention)
            else:
                nodes_dict[entity_ui] = {"id": entity_ui, "type": "entity", "mentions": [mention]}

        # Each event is treated as a hyper-edge: a dedicated node linked to all
        # of its arguments. The node id is unique per (trigger, argument set).
        for event in doc.events:
            trigger_id = origin_id_to_new_id_dict[event.trigger.id]
            # Let later events use this event as an argument (nested events).
            # NOTE(review): an event referenced before it appears in doc.events
            # raises KeyError below -- assumes nested events come first.
            origin_id_to_new_id_dict[event.id] = trigger_id
            argument_ids = [origin_id_to_new_id_dict[argument.id] for argument in event.arguments]

            sorted_argument_ids = sorted(argument_ids)
            # NOTE(review): node ids containing "-" (e.g. "<doc.id>-<entity.id>")
            # can make two different argument lists produce the same id.
            hyper_edge_node_id = trigger_id + "-" + "-".join(sorted_argument_ids)

            event_mention = {"doc_id": doc.id}  # TODO: add sentence span
            if hyper_edge_node_id in nodes_dict:
                nodes_dict[hyper_edge_node_id]["mentions"].append(event_mention)
            else:
                nodes_dict[hyper_edge_node_id] = {
                    "id": hyper_edge_node_id,
                    "type": "hyper_edge",
                    "trigger": trigger_id,
                    "arguments": sorted_argument_ids,
                    "mentions": [event_mention],
                }

            for argument_id in argument_ids:
                links[hyper_edge_node_id][argument_id] = 1
            argument_num_dict[len(argument_ids)] += 1

    pprint(argument_num_dict)

    G = nx.Graph()
    G.add_nodes_from(nodes_dict.items())
    # Flatten the de-duplicated hyper-edge -> argument links into edge tuples.
    G.add_edges_from(
        (hyper_edge_node_id, argument_id)
        for hyper_edge_node_id, argument_ids in links.items()
        for argument_id in argument_ids
    )
    # Only hyper-edge -> argument links exist, so e.g. entities used only as
    # event triggers stay isolated; drop every node without edges.
    G.remove_nodes_from([node for node, degree in G.degree() if degree == 0])

    print(G.number_of_nodes(), G.number_of_edges())
    return G


In [None]:
# Parse the brat corpus and build the event network.
# NOTE(review): a later cell repeats exactly this load; one of the two could be
# removed to avoid parsing the corpus twice.
data_path = 'brat-1.3p1/data/all-brat/'
brat_data = read_brat_data(data_path)
event_network = brat_data_to_network(data_path)

In [None]:
def draw_network(G, seed=None):
    """Draw G with a spring layout: entity nodes blue, all other nodes black.

    Node size is ``100 + degree``. Edges are drawn 0.5 wide, without labels.

    Parameters
    ----------
    G : networkx.Graph
        Every node must carry a ``'type'`` attribute.
    seed : int, optional
        Seed for ``nx.spring_layout``; set it to get a reproducible layout.
        ``None`` keeps the previous (random) behaviour.
    """
    colors = ['blue' if attrs['type'] == 'entity' else 'black' for _, attrs in G.nodes(data=True)]
    node_sizes = [100 + G.degree(node) for node in G.nodes()]

    # A fresh figure each call instead of re-using figure number 1, which could
    # draw on top of a figure left open by another plot.
    fig, ax = plt.subplots(figsize=(12, 12), dpi=60)
    nx.draw(
        G,
        pos=nx.spring_layout(G, k=0.15, seed=seed),
        ax=ax,
        node_color=colors,
        node_size=node_sizes,
        width=0.5,
        with_labels=False,
    )
    plt.show()


In [None]:
def plot_degree_distribution(G, fit_line=True, max_fit_degree=15):
    """Scatter the degree histogram of G on log-log axes, optionally with a power-law fit.

    Parameters
    ----------
    G : networkx.Graph
    fit_line : bool
        If True, fit ``log10(count) = slope * log10(degree) + intercept`` by
        least squares over degrees in ``[1, max_fit_degree)``, print the
        coefficients and draw the fitted line dashed.
    max_fit_degree : int
        Exclusive upper bound on the degrees used for the fit, so the sparse
        high-degree tail does not dominate it (previously hard-coded to 15).
    """
    degree_sequence = [G.degree(node) for node in G.nodes()]
    if not degree_sequence:
        print("Graph has no nodes; nothing to plot.")
        return
    # Distinct degrees (sorted) and how many nodes have each one.
    degrees, counts = np.unique(degree_sequence, return_counts=True)

    if fit_line:
        # log10(0) is -inf, so degree-0 nodes are excluded from the fit.
        in_range = (degrees > 0) & (degrees < max_fit_degree)
        fit_x, fit_y = degrees[in_range], counts[in_range]
        if len(fit_x) >= 2:
            slope, intercept = np.polyfit(np.log10(fit_x), np.log10(fit_y), 1)
            print("slope:", slope, "intercept:", intercept)
            x_vals = np.array([fit_x.min(), fit_x.max()])
            y_vals = 10 ** (intercept + slope * np.log10(x_vals))
            plt.plot(x_vals, y_vals, '--')
        else:
            print(f"Skipping fit: need at least 2 distinct degrees in [1, {max_fit_degree}).")

    plt.scatter(degrees, counts)
    plt.xscale("log")
    plt.yscale("log")
    plt.xlabel('Degree')
    # The plotted values are raw node counts, not normalised probabilities.
    plt.ylabel('Count')
    plt.show()

In [None]:
def get_k_highest_degree_node(G, k=0):
    """Print the node at 0-based rank ``k`` when nodes are ordered by descending degree.

    ``k=0`` reports the highest-degree node; negative ``k`` counts from the
    lowest-degree end, as with list indexing. Ties keep G's node iteration
    order (Python's sort is stable, also with ``reverse=True``).

    Raises
    ------
    IndexError
        If ``k`` is not a valid rank for G's nodes.
    """
    degrees = dict(G.degree())
    ranked_nodes = sorted(degrees, key=degrees.get, reverse=True)
    if not -len(ranked_nodes) <= k < len(ranked_nodes):
        raise IndexError(f"k={k} is out of range for a graph with {len(ranked_nodes)} nodes")
    node = ranked_nodes[k]
    print(f"The node with the {k} highest degree is {node}, with degree {degrees[node]}")

In [None]:
def plot_degree_list_bar(G, k=50):
    """Bar-plot the degrees of the ``k`` highest-degree nodes, coloured by node type.

    Also prints ``id type degree`` for each of those nodes. Entity nodes are
    red, hyper-edge nodes blue, and any other node type grey.

    Parameters
    ----------
    G : networkx.Graph
        Every node must carry ``'id'`` and ``'type'`` attributes.
    k : int
        Number of top nodes to show (previously fixed at 50).
    """
    # Keys must match the 'type' values assigned in brat_data_to_network.
    color_map = {'entity': 'r', 'hyper_edge': 'b'}
    fallback_color = 'gray'
    degrees = dict(G.degree())

    top_nodes = sorted(degrees, key=degrees.get, reverse=True)[:k]
    for node in top_nodes:
        node_data = G.nodes[node]
        print(node_data['id'], node_data['type'], degrees[node])

    top_degrees = [degrees[node] for node in top_nodes]
    bar_colors = [color_map.get(G.nodes[node]['type'], fallback_color) for node in top_nodes]

    fig, ax = plt.subplots()
    ax.bar(top_nodes, top_degrees, color=bar_colors, edgecolor='white', linewidth=1)
    ax.xaxis.set_visible(False)  # hide the node-id tick labels

    # Legend: one coloured patch per known node type.
    legend_handles = [plt.Rectangle((0, 0), 1, 1, color=color) for color in color_map.values()]
    ax.legend(legend_handles, list(color_map.keys()))

    plt.show()

In [None]:
def run_community_detection(G, alg='louvain', random_state=None):
    """Partition G into communities.

    Parameters
    ----------
    G : networkx.Graph
    alg : str
        Algorithm name; only ``'louvain'`` (python-louvain ``best_partition``)
        is supported.
    random_state : int or None
        Forwarded to ``best_partition`` when given, for a reproducible
        partition. Omitted from the call when ``None``.

    Returns
    -------
    dict
        Node -> community id, as returned by ``best_partition``.

    Raises
    ------
    ValueError
        If ``alg`` is not a supported algorithm (previously this silently
        returned None).
    """
    if alg == 'louvain':
        extra = {} if random_state is None else {'random_state': random_state}
        # NOTE(review): brat_data_to_network adds edges without a 'strength'
        # attribute; python-louvain presumably treats missing weights as 1 --
        # confirm this is the intended weighting.
        return community_louvain.best_partition(G, weight='strength', **extra)
    raise ValueError(f"Unsupported community detection algorithm: {alg!r}")

In [None]:
def community_layout(g, partition):
    """
    Lay out a graph so that nodes cluster by community.

    Community centres come from a spring layout of the community graph
    (scale 3); each node is then offset from its centre by a spring layout of
    its own community's subgraph (scale 1).

    Arguments:
    ----------
    g -- networkx.Graph or networkx.DiGraph instance
        graph to plot

    partition -- dict mapping node -> community id
        graph partition; must contain every node of g

    Returns:
    --------
    pos -- dict mapping node -> (x, y) position array
        node positions (community centre + in-community offset)
    """
    centres = _position_communities(g, partition, scale=3.)
    offsets = _position_nodes(g, partition, scale=1.)
    return {node: centres[node] + offsets[node] for node in g.nodes()}

def _position_communities(g, partition, **kwargs):
    """
    Give every node the position of its community's centre.

    Builds a directed community graph (one node per community, one edge per
    ordered community pair that g's edges cross, weighted by how many edges
    cross it), lays it out with ``nx.spring_layout(**kwargs)``, and maps each
    node in ``partition`` to its community's position.
    """
    crossing = _find_between_community_edges(g, partition)

    community_graph = nx.DiGraph()
    community_graph.add_nodes_from(set(partition.values()))
    for (source_comm, target_comm), edge_list in crossing.items():
        community_graph.add_edge(source_comm, target_comm, weight=len(edge_list))

    centres = nx.spring_layout(community_graph, **kwargs)
    return {node: centres[community] for node, community in partition.items()}

def _find_between_community_edges(g, partition):

    edges = dict()

    for (ni, nj) in g.edges():
        ci = partition[ni]
        cj = partition[nj]

        if ci != cj:
            try:
                edges[(ci, cj)] += [(ni, nj)]
            except KeyError:
                edges[(ci, cj)] = [(ni, nj)]

    return edges

def _position_nodes(g, partition, **kwargs):
    """
    Lay out each community's induced subgraph independently.

    ``kwargs`` are forwarded to ``nx.spring_layout``. Returns a dict
    node -> position for the nodes of g that appear in ``partition``.
    """
    members = {}
    for node, community in partition.items():
        members.setdefault(community, []).append(node)

    positions = {}
    for community_nodes in members.values():
        positions.update(nx.spring_layout(g.subgraph(community_nodes), **kwargs))
    return positions

def visualize_community(G, partition=None):
    """Draw G with nodes grouped (via ``community_layout``) and coloured by community.

    Parameters
    ----------
    G : networkx.Graph
    partition : dict, optional
        Node -> community id, e.g. from ``run_community_detection``. If given,
        it is used as-is; if omitted, a Louvain partition is computed with
        ``community_louvain.best_partition(G)``.

    Note: python-louvain must be a networkx-2.x compatible release, e.g.
    ``pip install -U git+https://github.com/taynaud/python-louvain.git@networkx2``.
    """
    if partition is None:
        partition = community_louvain.best_partition(G)
    pos = community_layout(G, partition)

    fig, ax = plt.subplots(figsize=(12, 12), dpi=60)
    # nx.draw colours nodes in G.nodes() order, so build the colour list in
    # that order rather than relying on the partition dict's insertion order.
    node_colors = [partition[node] for node in G.nodes()]
    nx.draw(G, pos, ax=ax, node_color=node_colors, node_size=40)
    plt.show()

In [None]:
# Parse the brat corpus, build the event network and report its size.
data_path = 'brat-1.3p1/data/all-brat/'
brat_data = read_brat_data(data_path)
event_network = brat_data_to_network(data_path)
print(event_network.number_of_nodes(), event_network.number_of_edges())


In [None]:
# Drop hyper-edge nodes linked to at most one distinct argument node: they do
# not connect anything. (This filters by degree, NOT by how often an event
# occurs in the corpus.) Nodes left isolated by that removal are dropped too.
MAX_DROPPED_HYPER_EDGE_DEGREE = 1
filtered_nodes = [
    n for n, attrs in event_network.nodes(data=True)
    if attrs['type'] == 'hyper_edge' and event_network.degree(n) <= MAX_DROPPED_HYPER_EDGE_DEGREE
]
event_network.remove_nodes_from(filtered_nodes)
event_network.remove_nodes_from([n for n, degree in event_network.degree() if degree == 0])

print(event_network.number_of_nodes(), event_network.number_of_edges())

# Optional analyses -- uncomment as needed:
# draw_network(event_network)
# plot_degree_distribution(event_network)
# plot_degree_list_bar(event_network)
# communities = run_community_detection(event_network)
# visualize_community(event_network, communities)
# get_k_highest_degree_node(event_network, 5)

In [None]:
def save_json(data, filepath=r'new_data.json'):
    """Write ``data`` to ``filepath`` as 4-space-indented JSON, overwriting the file."""
    with open(filepath, 'w') as out_file:
        json.dump(data, fp=out_file, indent=4)
# Export the filtered graph in networkx node-link format.
event_network_data = nx.node_link_data(event_network)
save_json(event_network_data, r'event_network_data.json')
# Print a short summary instead of pretty-printing the whole graph, which
# floods the notebook output.
print(f"Saved {len(event_network_data['nodes'])} nodes and "
      f"{event_network.number_of_edges()} edges to event_network_data.json")

In [None]:
def save_hyper_edges(G, filepath=r'hyper_edges.txt', index_filepath='node_to_index.json'):
    """Export the hyper-edges of G as lines of 1-based entity indices.

    Entity nodes are numbered 1..N in ``G.nodes()`` order and that mapping is
    written to ``index_filepath`` with ``save_json``. Each hyper-edge node then
    becomes one line of ``filepath``: the comma-separated indices of its entity
    neighbours. Prints how many hyper-edges have more than two entity members.

    Parameters
    ----------
    G : networkx.Graph
        Graph whose nodes carry a ``'type'`` of ``'entity'`` or ``'hyper_edge'``.
    filepath : str
        Output text file for the hyper-edges (UTF-8).
    index_filepath : str
        Output JSON file for the entity -> index mapping.
    """
    hyper_edges = [n for n, attrs in G.nodes(data=True) if attrs['type'] == 'hyper_edge']
    entities = [n for n, attrs in G.nodes(data=True) if attrs['type'] == 'entity']
    # Only entity nodes are indexed; assumes no hyper-edge links to another hyper-edge.
    node_to_index = {node: i for i, node in enumerate(entities, start=1)}
    save_json(node_to_index, index_filepath)

    wider_than_pair = 0
    with open(filepath, 'w', encoding='utf-8') as f:
        for hyper_edge in hyper_edges:
            edge_nodes = [
                str(node_to_index[n]) for n in G.neighbors(hyper_edge)
                if G.nodes[n]['type'] == 'entity'
            ]
            if len(edge_nodes) > 2:
                wider_than_pair += 1
            f.write(','.join(edge_nodes) + '\n')
    print(f"{wider_than_pair} of {len(hyper_edges)} hyper-edges have more than 2 entity members")
        

In [None]:
# Export the filtered network's hyper-edges (one line of 1-based entity indices
# per hyper-edge) together with the entity -> index mapping.
save_hyper_edges(G=event_network)