In [None]:
import json
from collections import Counter
from pathlib import Path
from pprint import pprint

import hypernetx as hnx
import jsonlines
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np

from refined.inference.processor import Refined


In [None]:
# Load the pretrained ReFinED entity linker once; `link_entities` calls its
# `process_text` on every document, so it lives as a module-level global.
refined = Refined.from_pretrained(
    model_name='wikipedia_model_with_numbers',
    entity_set="wikipedia",
)

In [None]:
def save_json(data, filepath=r'new_data.json'):
    """Write ``data`` to ``filepath`` as JSON indented with 4 spaces.

    Missing parent directories are created first. Output paths used in this
    notebook (e.g. ``data/result/RAMS/biHgraph_dev/...``) are nested and
    would otherwise raise ``FileNotFoundError`` on a fresh checkout.

    Args:
        data: any JSON-serializable object.
        filepath: destination path (str or path-like); overwritten if present.
    """
    path = Path(filepath)
    path.parent.mkdir(parents=True, exist_ok=True)
    # json.dump escapes non-ASCII by default, so the encoding only fixes the
    # platform default explicitly; the bytes written are unchanged.
    with open(path, 'w', encoding='utf-8') as fp:
        json.dump(data, fp, indent=4)

In [None]:
# One lazy jsonlines reader per RAMS split, opened in dev/test/train order.
# NOTE(review): each reader is a one-shot iterator over an open file handle
# that is never closed here; any cell that iterates one consumes it.
dev_reader, test_reader, train_reader = (
    jsonlines.open(r'data/raw/RAMS/dev.jsonlines'),
    jsonlines.open(r'data/raw/RAMS/test.jsonlines'),
    jsonlines.open(r'data/raw/RAMS/train.jsonlines'),
)

In [None]:
def update_event_span(link, start_span, end_span):
    """Widen ``[start_span, end_span]`` so it also covers the spans in ``link``.

    Args:
        link: a gold-link entry whose items 0 and 1 are the trigger and the
            argument ``[start, end]`` spans.
        start_span: current lower bound of the running span.
        end_span: current upper bound of the running span.

    Returns:
        ``(start_span, end_span)`` extended to include both spans of ``link``.
    """
    for span in (link[0], link[1]):
        start_span = min(start_span, span[0])
        end_span = max(end_span, span[1])
    return start_span, end_span

In [None]:
def merge_sentences(datum):
    """Return the record's tokenized ``sentences`` as one space-separated string.

    Tokens inside a sentence and consecutive sentences are both joined by a
    single space, so the result lines up with the space-joined spans built
    elsewhere from the same tokens.
    """
    return " ".join(" ".join(tokens) for tokens in datum['sentences'])

def merge_events(datum):
    """Group a RAMS record's gold argument links into one event per trigger text.

    Args:
        datum: one RAMS jsonlines record. Reads ``sentences`` (a list of token
            lists), ``evt_triggers`` (entries ``[start, end, [[type, ...], ...]]``)
            and ``gold_evt_links`` (entries ``[trigger_span, argument_span, role]``).
            Spans are inclusive ``[start, end]`` token offsets into the
            flattened document.

    Returns:
        A list of event dicts with keys ``trigger``, ``trigger_span``,
        ``trigger_type``, ``arguments`` and ``paragraph`` (the record's raw
        ``sentences``), in first-seen trigger order. Links whose triggers have
        the same surface text are merged into one event. Each argument's
        ``argument_id`` starts out equal to its surface text.

    Raises:
        KeyError: if a link's trigger matches no ``evt_triggers`` entry by span
            or by surface text.
    """
    words = [word for sentence in datum['sentences'] for word in sentence]

    def span_text(span):
        # Spans are inclusive on both ends, hence the +1 on the slice stop.
        return " ".join(words[span[0]:span[1] + 1])

    # Trigger types are indexed by exact span so two triggers with the same
    # surface text keep their own types (a word-keyed map lets the last one
    # silently win). The word-keyed map is kept only as a fallback.
    type_by_span = {}
    type_by_word = {}
    for trigger_datum in datum['evt_triggers']:
        trigger_type = trigger_datum[2][0][0]  # first type label listed for the trigger
        type_by_span[tuple(trigger_datum[:2])] = trigger_type
        type_by_word[span_text(trigger_datum[:2])] = trigger_type

    events = {}
    for link in datum['gold_evt_links']:
        trigger_span, argument_span, argument_role = link[0], link[1], link[2]
        trigger_word = span_text(trigger_span)
        argument_word = span_text(argument_span)
        span_key = tuple(trigger_span)
        if span_key in type_by_span:
            trigger_type = type_by_span[span_key]
        else:
            trigger_type = type_by_word[trigger_word]
        if trigger_word not in events:
            events[trigger_word] = {
                "trigger": trigger_word,
                "trigger_span": trigger_span,
                "trigger_type": trigger_type,
                "arguments": [],
                "paragraph": datum['sentences'],
            }
        events[trigger_word]['arguments'].append({
            "argument_id": argument_word,  # replaced by a Wikidata id when linked
            "argument_word": argument_word,
            "argument_role": argument_role,
            "argument_span": argument_span,
        })
    return list(events.values())

def link_entities(events, paragraph, linker=None):
    """Attach entity-linking results to event arguments whose text matches a span.

    Every span returned by ``linker.process_text(paragraph)`` is compared by
    exact text against each argument's ``argument_word``. For each match:

    - ``entity_type`` is set to the span's ``coarse_mention_type``.
    - If the span has a predicted entity with a Wikidata id,
      ``argument_id`` is replaced by that id and ``entity_title`` is set
      to the entity's Wikipedia title.

    Args:
        events: event dicts as produced by ``merge_events``; mutated in place.
        paragraph: the document text passed to the linker.
        linker: object exposing ``process_text(text)`` that returns spans. It
            defaults to the module-level ``refined`` model.

    Returns:
        The same ``events`` list, for chaining.
    """
    if linker is None:
        linker = refined
    for span in linker.process_text(paragraph):
        predicted = span.predicted_entity
        has_wikidata_id = predicted is not None and predicted.wikidata_entity_id is not None
        for event in events:
            for argument in event['arguments']:
                if argument['argument_word'] != span.text:
                    continue
                if has_wikidata_id:
                    argument['argument_id'] = predicted.wikidata_entity_id
                    argument['entity_title'] = predicted.wikipedia_entity_title
                argument['entity_type'] = span.coarse_mention_type
    return events

def transform_dataset(dataset):
    """Convert RAMS records into per-document entries holding entity-linked events.

    Records for which ``merge_events`` yields nothing are skipped. Records
    sharing a ``doc_key`` are accumulated into one entry; the entry's
    ``source_url`` comes from the first such record. A ``"{i}/{n}"`` progress
    line is printed for every record.

    Args:
        dataset: a sized sequence of RAMS records (``len`` is taken).

    Returns:
        A list of ``{"doc_id", "source_url", "events"}`` dicts in first-seen
        ``doc_key`` order.
    """
    docs_by_key = {}
    total = len(dataset)
    for position, datum in enumerate(dataset):
        print("{}/{}".format(position, total))
        paragraph = merge_sentences(datum)
        merged = merge_events(datum)
        if not merged:
            continue
        linked = link_entities(merged, paragraph)
        entry = docs_by_key.setdefault(datum['doc_key'], {
            "doc_id": datum['doc_key'],
            "source_url": datum['source_url'],
            "events": [],
        })
        entry['events'].extend(linked)
    return list(docs_by_key.values())

In [None]:
# Pretty-print the merged events of one dev document for inspection.
# A local, context-managed reader is used on purpose. Re-binding and then
# partially consuming (via `break`) the module-level `dev_reader` would make any
# later iteration of it silently start after this document, dropping every
# record up to and including it. The `with` block also closes the file handle.
SAMPLE_DOC_ID = "nw_RC0da9ca01673da1e2a47f6ccf9d239cbde98f30122f50c5ced8fa4743"
with jsonlines.open(r'data/raw/RAMS/dev.jsonlines') as sample_reader:
    for datum in sample_reader:
        if datum['doc_key'] == SAMPLE_DOC_ID:
            pprint(merge_events(datum))
            break


In [None]:
# Build and save the transformed dataset from a fresh read of each split.
# Reading from disk here, rather than iterating a reader created in another
# cell, keeps the result independent of whether that reader was already
# (partially) consumed.
# Use ["dev", "test", "train"] to process the full RAMS corpus.
RAMS_SPLITS = ["dev"]
dataset = []
for split in RAMS_SPLITS:
    with jsonlines.open(f'data/raw/RAMS/{split}.jsonlines') as split_reader:
        dataset.extend(split_reader)
transformed_dataset = transform_dataset(dataset)
save_json(transformed_dataset, r'data/result/RAMS/events.json')

In [None]:
def disambiguate(docs):
    """Merge argument mentions and events across documents into graph records.

    Entity nodes are keyed by ``argument_id``, which is a Wikidata id when
    linking succeeded and otherwise the surface text. Every mention of the
    same id is collected on one node. An event becomes a hyperedge keyed by
    ``"<trigger>-<sorted argument ids joined by '-'>"``. Identical hyperedges
    from different documents or events are merged by appending mentions.

    Args:
        docs: documents as produced by ``transform_dataset``. Each has ``doc_id``
            and ``events``; each event has ``trigger``, ``trigger_type``,
            ``paragraph`` and ``arguments``.

    Returns:
        ``(nodes_dict, hyper_edges_dict, links)``:

        - ``nodes_dict`` holds the entity node records.
        - ``hyper_edges_dict`` holds the hyperedge records.
        - ``links`` has one ``(hyper_edge_id, argument_id)`` pair per argument
          occurrence, in encounter order.

        A node's ``title``, ``entity_type`` and ``argument_role`` come from its
        first mention.
    """
    nodes_dict = {}
    hyper_edges_dict = {}
    links = []
    for doc in docs:
        doc_id = doc['doc_id']
        for event in doc['events']:
            arguments = event['arguments']
            # One entity node per distinct argument id; later occurrences only add mentions.
            for argument in arguments:
                argument_id = argument['argument_id']
                argument_word = argument['argument_word']
                argument_span = argument['argument_span']
                argument_role = argument['argument_role']
                if argument_id not in nodes_dict:
                    nodes_dict[argument_id] = {
                        "id": argument_id,
                        "title": argument.get('entity_title', argument_word),
                        "entity_type": argument.get('entity_type', "None"),
                        "type": "entity",
                        "argument_role": argument_role,
                        "mentions": [],
                    }
                nodes_dict[argument_id]['mentions'].append({
                    "doc_id": doc_id,
                    "mention": argument_word,
                    "span": {'start': argument_span[0], 'end': argument_span[1]},
                })
            argument_ids = [argument['argument_id'] for argument in arguments]
            if any(argument_id is None for argument_id in argument_ids):
                # Debug aid: a None id cannot be joined into the hyperedge id below.
                print(doc_id, argument_ids)
            sorted_argument_ids = sorted(argument_ids)
            # Sorting makes the hyperedge id independent of argument order.
            trigger_id = event['trigger']
            trigger_type = event['trigger_type']
            hyper_edge_id = trigger_id + "-" + "-".join(sorted_argument_ids)
            if hyper_edge_id not in hyper_edges_dict:
                hyper_edges_dict[hyper_edge_id] = {
                    'id': hyper_edge_id,
                    'type': "hyper_edge",
                    "trigger": trigger_id,
                    "trigger_type": trigger_type,
                    "arguments": sorted_argument_ids,
                    "mentions": [],
                }
            hyper_edges_dict[hyper_edge_id]['mentions'].append({
                "doc_id": doc_id,
                "paragraph": event['paragraph'],
            })
            links.extend((hyper_edge_id, argument_id) for argument_id in argument_ids)
    return nodes_dict, hyper_edges_dict, links

def merge_RAMS(dataset):
    """Build the bipartite event/entity graph for a transformed RAMS dataset.

    Hyperedge ids are added with ``bipartite=0``, entity ids with
    ``bipartite=1``, and each ``(hyper_edge_id, argument_id)`` link becomes an
    undirected edge.

    Returns:
        ``(graph, nodes_dict, hyper_edges_dict, links)``, where the last three
        are passed through from ``disambiguate``.
    """
    nodes_dict, hyper_edges_dict, links = disambiguate(dataset)
    bipartite_graph = nx.Graph()
    # Iterating a dict yields its keys, i.e. the node ids.
    bipartite_graph.add_nodes_from(hyper_edges_dict, bipartite=0)
    bipartite_graph.add_nodes_from(nodes_dict, bipartite=1)
    bipartite_graph.add_edges_from(links)
    return bipartite_graph, nodes_dict, hyper_edges_dict, links


In [None]:
# Assemble the bipartite event/entity graph plus its lookup tables, then wrap
# it as a hypernetx Hypergraph. The cell displays H.shape as a list.
B, nodes_dict, hyper_edges_dict, links = merge_RAMS(transformed_dataset)
H = hnx.Hypergraph.from_bipartite(B)
[*H.shape]

In [None]:
# Persist the bipartite graph (node-link form) and both lookup tables, in that order.
event_hgraph_data = nx.node_link_data(B)
for payload, out_path in (
    (event_hgraph_data, r'data/result/RAMS/biHgraph_dev/hgraph.json'),
    (nodes_dict, r'data/result/RAMS/biHgraph_dev/nodes.json'),
    (hyper_edges_dict, r'data/result/RAMS/biHgraph_dev/hyperedges.json'),
):
    save_json(payload, out_path)

In [None]:
def plot_degree_distribution(HG, fit_line=False, filter_degree=15, log_scale=False):
    """Scatter-plot how many nodes of ``HG`` have each degree.

    Args:
        HG: graph-like object exposing an iterable ``.nodes`` and
            ``.degree(node)``, such as the hypernetx Hypergraph built in this
            notebook.
        fit_line: if True, fit a straight line to log10(count) vs log10(degree)
            over degrees in ``(0, filter_degree)``. The slope and intercept are
            printed and the fitted curve is drawn dashed.
        filter_degree: exclusive upper bound on the degrees used for the fit;
            the default 15 is the previously hard-coded cutoff.
        log_scale: if True, draw both axes on a log scale.

    Returns:
        None. Nothing is plotted when ``HG`` has no nodes.
    """
    # Counter is O(n); the previous list.count per distinct degree was O(n * k).
    degree_counts = Counter(HG.degree(node) for node in HG.nodes)
    if not degree_counts:
        print("Graph has no nodes; nothing to plot.")
        return
    x, y = zip(*sorted(degree_counts.items()))

    if fit_line:
        # log10 is undefined at degree 0, and a line needs two distinct points.
        fit_points = sorted(
            (degree, count) for degree, count in degree_counts.items()
            if 0 < degree < filter_degree
        )
        if len(fit_points) < 2:
            print("Not enough distinct degrees below", filter_degree, "to fit a line.")
        else:
            fit_x, fit_y = zip(*fit_points)
            slope, intercept = np.polyfit(np.log10(fit_x), np.log10(fit_y), 1)
            print("slope:", slope, "intercept:", intercept)
            x_vals = np.array([min(fit_x), max(fit_x)])
            y_vals = 10 ** (intercept + slope * np.log10(x_vals))
            plt.plot(x_vals, y_vals, '--')

    plt.scatter(x, y)
    if log_scale:
        plt.xscale("log")
        plt.yscale("log")
    plt.xlabel('Degree')
    # y holds raw node counts, not normalized probabilities.
    plt.ylabel('Count')
    plt.title('Degree distribution')
    plt.show()

In [None]:
# Drop nodes whose hypergraph degree is 1, i.e. entities contained in only one
# event hyperedge, and report the (nodes, edges) shape before and after.
# Only the shape is printed: dumping every node id floods the output.
# NOTE(review): depending on the hypernetx version, remove_nodes may also
# modify H in place; SH is the filtered hypergraph either way — confirm.
print("before:", H.shape)
removed_nodes = [node for node in H.nodes if H.degree(node) == 1]
SH = H.remove_nodes(removed_nodes)
print(SH.shape)

In [None]:
def transform_frontend(nodes, links, nodes_dict, hyper_edges_dict):
    """Build the ``{"nodes": [...], "links": [...]}`` payload for the frontend.

    Args:
        nodes: node ids to export. Each id is looked up in ``nodes_dict`` first
            and otherwise in ``hyper_edges_dict``; a ``KeyError`` is raised if it
            is in neither.
        links: pairs whose first two items are the source and target node ids.
        nodes_dict: entity-node records keyed by id.
        hyper_edges_dict: hyperedge records keyed by id.

    Returns:
        A dict with the looked-up records (the same objects, not copies) under
        ``"nodes"`` and ``{"source", "target"}`` dicts under ``"links"``. The
        number of exported nodes is printed as a side effect.
    """
    def record_for(node_id):
        if node_id in nodes_dict:
            return nodes_dict[node_id]
        return hyper_edges_dict[node_id]

    exported_nodes = [record_for(node_id) for node_id in nodes]
    exported_links = [{"source": edge[0], "target": edge[1]} for edge in links]
    print(len(exported_nodes))
    return {"nodes": exported_nodes, "links": exported_links}

In [None]:
# Export the degree-filtered hypergraph SH, not the full H. The filtering cell
# computes SH for this purpose, and the output file is named "subgraph".
BH = SH.bipartite()
network = transform_frontend(list(BH.nodes), list(BH.edges), nodes_dict, hyper_edges_dict)
save_json(network, 'data/result/RAMS/dev_subgraph.json')

In [None]:
# Also export the node-link form of the bipartite graph BH.
event_hgraph_data = nx.node_link_data(BH)
save_json(event_hgraph_data, filepath=r'event_network_data.json')