# Parse NOW corpus text files

In [None]:
import glob
import re
import spacy
import pickle
from collections import defaultdict
from nltk.stem.wordnet import WordNetLemmatizer
from nltk.corpus import wordnet
from multiprocessing import Pool

filter_strs defines the scope of our dataset.

Processing US news from January 2016 to June 2018 would take about 1 hour.

Processing all news from January 2016 to June 2018 would take half a day.

In [None]:
# Directory that holds the raw NOW corpus text files.
load_path = '/data/NOW/text/'
# Directory that receives the per-file pickle outputs.
write_path = '/data/ent2ent/han/incremental/'
# Glob prefixes ('.txt' is appended when the files are listed) selecting which corpus files to process.
filter_strs = ['18*', 'text_18*', '17*', 'text_17*', '16*', 'text_16*'] # change this to '*' for all files
# Two-digit year subtracted from a file's year when its name is converted to a month index.
year_base = 16

Here we manually define some entities and their aliases.

We consider every unordered pair that can be formed from these entities.

In [None]:
# Canonical entity name -> set of surface forms that count as a mention of it.
entity_aliases_dict = {
    'U.S.': {'U.S.', 'US', 'USA', 'Trump', 'Obama'},
    'China': {'China', 'Chinese', 'Xi'},
    'Syria': {'Syria', 'Syrian', 'Assad'},
    'France': {'France', 'French', 'Macron', 'Hollande'},
    'Germany': {'Germany', 'German', 'Merkel'},
    'Canada': {'Canada', 'Canadian', 'Trudeau'},
    'Russia': {'Russia', 'Russian', 'Putin'},
    'India': {'India', 'Indian', 'Modi'},
    'U.K.': {'U.K.', 'UK', 'British', 'Britain', 'Cameron'},
    'Japan': {'Japan', 'Japanese', 'Abe'},
    'Iran': {'Iran', 'Iranian', 'Khamenei', 'Rouhani'},
    'Israel': {'Israel', 'Israeli', 'Netanyahu'},
}
entities_list = list(entity_aliases_dict)

# Every unordered pair of entities, each listed once in dictionary order.
interest_pair_list = [
    (first, second)
    for position, first in enumerate(entities_list)
    for second in entities_list[position + 1:]
]

In [None]:
# Small English spaCy pipeline; process_news_article() runs it on each relevant paragraph.
nlp = spacy.load('en_core_web_sm')
# NOTE(review): wnl is not referenced by any code visible in this file — confirm before removing.
wnl = WordNetLemmatizer()

Below are some util functions for Subject-Verb-Object (SVO) extraction.

Code adapted from: https://github.com/NSchrading/intro-spacy-nlp/blob/master/subject_object_extraction.py

In [None]:
# Dependency labels that mark a token as a subject candidate.
SUBJECTS = ["nsubj", "nsubjpass", "csubj", "csubjpass", "agent", "expl"]
# Dependency labels that mark a token as an object candidate.
OBJECTS = ["dobj", "dative", "attr", "oprd"]
# Lowercased words that are never accepted as a subject or direct object.
REL_PRONS = ["that", "who", "which", "whom", "whose", "where", "when", "what", "why"]

def getSubsFromConjunctions(subs):
    """Collect extra subjects joined to ``subs`` by the word "and".

    For every subject whose right children include "and", the right children
    that carry a subject label or are nouns are added, and the search then
    recurses on everything collected so far.

    Returns the list of additional subject tokens (possibly empty).
    """
    collected = []
    for subject in subs:
        # rights is a generator, so materialise it before scanning it twice
        right_children = list(subject.rights)
        if not any(child.lower_ == "and" for child in right_children):
            continue
        collected += [child for child in right_children
                      if child.dep_ in SUBJECTS or child.pos_ == "NOUN"]
        if collected:
            collected += getSubsFromConjunctions(collected)
    return collected

def getObjsFromConjunctions(objs):
    """Collect extra objects joined to ``objs`` by the word "and".

    For every object whose right children include "and", the right children
    that carry an object label or are nouns are added, and the search then
    recurses on everything collected so far.

    Returns the list of additional object tokens (possibly empty).
    """
    collected = []
    for target in objs:
        # rights is a generator, so materialise it before scanning it twice
        right_children = list(target.rights)
        if not any(child.lower_ == "and" for child in right_children):
            continue
        collected += [child for child in right_children
                      if child.dep_ in OBJECTS or child.pos_ == "NOUN"]
        if collected:
            collected += getObjsFromConjunctions(collected)
    return collected

def getVerbsFromConjunctions(verbs):
    """Collect extra verbs joined to ``verbs`` by the word "and".

    For every verb whose right children include "and", its right children
    tagged VERB are added, and the search then recurses on everything
    collected so far.

    Returns the list of additional verb tokens (possibly empty).
    """
    collected = []
    for verb in verbs:
        if not any(child.lower_ == "and" for child in verb.rights):
            continue
        collected += [child for child in verb.rights if child.pos_ == "VERB"]
        if collected:
            collected += getVerbsFromConjunctions(collected)
    return collected

def findSubs(tok):
    """Find subjects for ``tok`` by looking at its ancestors.

    Climbs from ``tok.head`` until a VERB, a NOUN, or the root (a token that
    is its own head) is reached.

    * VERB ancestor: its left children carrying a subject label (plus any
      subjects conjoined to them) are returned together with that verb's
      negation flag. If it has none and is not the root, the search continues
      from that verb.
    * NOUN ancestor: the noun itself is returned, with the negation flag of
      ``tok``.

    Returns:
        (subjects, negated) -- ``([], False)`` when nothing is found.
    """
    head = tok.head
    while head.pos_ not in ("VERB", "NOUN") and head.head != head:
        head = head.head
    if head.pos_ == "VERB":
        # The original compared dep_ against the literal "SUB", which is not
        # one of the subject labels used anywhere else in this file, so this
        # branch could never find a subject. Use the shared SUBJECTS list.
        subs = [child for child in head.lefts if child.dep_ in SUBJECTS]
        if subs:
            verbNegated = isNegated(head)
            subs.extend(getSubsFromConjunctions(subs))
            return subs, verbNegated
        if head.head != head:
            return findSubs(head)
    elif head.pos_ == "NOUN":
        return [head], isNegated(tok)
    return [], False

def isNegated(tok):
    """Return True if any direct left or right child of ``tok`` is a negation word."""
    negations = {"no", "not", "n't", "never", "none"}
    children = [*tok.lefts, *tok.rights]
    return any(child.lower_ in negations for child in children)

def getObjsFromPrepositions(deps):
    """Collect objects hanging off the prepositions found in ``deps``.

    Only tokens tagged ADP with dependency label "prep" are inspected; from
    their right children, those with an object label (or the pronoun "me")
    are returned.
    """
    found = []
    for dep in deps:
        if dep.pos_ != "ADP" or dep.dep_ != "prep":
            continue
        for child in dep.rights:
            if child.dep_ in OBJECTS or (child.pos_ == "PRON" and child.lower_ == "me"):
                found.append(child)
    return found

def getObjsFromAttrs(deps):
    """Look for a verb (with objects) attached to a noun attribute in ``deps``.

    For each NOUN token labelled "attr", its right-hand VERB children are
    tried in order; the first verb that has direct or prepositional objects
    is returned together with those objects.

    Returns:
        (verb, objects), or (None, None) when no such verb exists.
    """
    for dep in deps:
        if dep.pos_ != "NOUN" or dep.dep_ != "attr":
            continue
        candidate_verbs = [child for child in dep.rights if child.pos_ == "VERB"]
        for verb in candidate_verbs:
            verb_rights = list(verb.rights)
            found = [child for child in verb_rights if child.dep_ in OBJECTS]
            found += getObjsFromPrepositions(verb_rights)
            if found:
                return verb, found
    return None, None

def getObjFromXComp(deps):
    """Look for an open clausal complement verb (with objects) in ``deps``.

    The first VERB token labelled "xcomp" that has direct or prepositional
    objects among its right children is returned together with those objects.

    Returns:
        (verb, objects), or (None, None) when no such verb exists.
    """
    for dep in deps:
        if dep.pos_ != "VERB" or dep.dep_ != "xcomp":
            continue
        complement_rights = list(dep.rights)
        found = [child for child in complement_rights if child.dep_ in OBJECTS]
        found += getObjsFromPrepositions(complement_rights)
        if found:
            return dep, found
    return None, None

def getAllSubs(v):
    """Return the subjects of verb ``v`` and whether the predicate is negated.

    Direct subjects are the left children of ``v`` that carry a subject
    label, are not determiners and are not in REL_PRONS; conjoined subjects
    are added to them. When ``v`` has no direct subject, both the subjects
    and the negation flag are taken from findSubs(v) instead.
    """
    negated = isNegated(v)
    subjects = [
        child for child in v.lefts
        if child.dep_ in SUBJECTS and child.pos_ != "DET" and child.lower_ not in REL_PRONS
    ]
    if subjects:
        subjects.extend(getSubsFromConjunctions(subjects))
        return subjects, negated

    inherited, negated = findSubs(v)
    subjects.extend(inherited)
    return subjects, negated

def getAllObjs(v):
    """Return the effective verb and the objects governed by verb ``v``.

    Objects come from the right children of ``v``: tokens with an object
    label (excluding REL_PRONS) and objects of prepositions. If an "xcomp"
    verb with objects is found, its objects are added and that verb replaces
    ``v`` in the result. Conjoined objects are appended last.

    Returns:
        (verb, objects) -- objects may be empty.
    """
    # rights is a generator, so materialise it once for the three scans below
    right_children = list(v.rights)
    found = [child for child in right_children
             if child.dep_ in OBJECTS and child.lower_ not in REL_PRONS]
    found += getObjsFromPrepositions(right_children)

    complement_verb, complement_objs = getObjFromXComp(right_children)
    if complement_verb is not None and complement_objs is not None and complement_objs:
        found += complement_objs
        v = complement_verb

    if found:
        found += getObjsFromConjunctions(found)
    return v, found

findSVOs() does the following things:
1. find the subject, predicate, and object of an input sentence
2. if the predicate contains negation, we replace it with its antonym (if no antonym found, we report no SVO)
3. return the subject, predicate, and object, all in lemma forms

In [None]:
def _lemma_or_surface(tok):
    """Return the token's lemma, or its lowercased text for spaCy's '-PRON-' placeholder."""
    if tok.lemma_ == '-PRON-':  # spacy's pronoun lemma hack
        return tok.lower_
    return tok.lemma_


def _verb_antonym(word):
    """Return a WordNet antonym for the verb ``word``, or None if there is none.

    Verb synsets are scanned in order; the first one containing any lemma
    with antonyms decides the result, and within it the last such lemma wins.
    """
    for synset in wordnet.synsets(word):
        if synset.pos() != 'v':
            continue
        antonym = None
        for lemma in synset.lemmas():
            opposites = lemma.antonyms()
            if opposites:
                antonym = opposites[0].name()
        if antonym is not None:
            return antonym
    return None


def findSVOs(tokens):
    """Extract (subject, predicate, object) lemma triples from ``tokens``.

    Every non-auxiliary verb that has at least one subject contributes one
    triple per subject/object combination. When the verb or the object is
    negated, the predicate is replaced by a WordNet antonym of the verb, and
    the triple is dropped if no antonym exists.
    """
    triples = []
    main_verbs = [tok for tok in tokens if tok.pos_ == "VERB" and tok.dep_ != "aux"]
    for candidate in main_verbs:
        subjects, verb_negated = getAllSubs(candidate)
        if not subjects:
            continue
        verb, objects = getAllObjs(candidate)
        for subject in subjects:
            for target in objects:
                if verb_negated or isNegated(target):
                    # if negative word, get the antonym
                    predicate = _verb_antonym(verb.lower_)
                    if predicate is None:
                        continue
                else:
                    predicate = verb.lemma_
                triples.append((_lemma_or_surface(subject), predicate, _lemma_or_surface(target)))
    return triples

findGenerals() does the following things:
1. find all nouns (including proper nouns) in the input sentence
2. find all verbs, except auxiliary verbs, in the input sentence
3. find all adjectives in the input sentence
4. find all adverbs in the input sentence

In [None]:
def findGenerals(tokens):
    """Bucket the lemmas of ``tokens`` by part of speech.

    Returns:
        (nouns, verbs, adjs, advs, all_words) -- nouns include proper nouns,
        verbs exclude auxiliaries, and all_words holds every token's lemma in
        order. Pronouns use their lowercased text in place of the '-PRON-'
        placeholder lemma.
    """
    nouns, verbs, adjs, advs, all_words = [], [], [], [], []
    bucket_by_pos = {"NOUN": nouns, "PROPN": nouns, "ADJ": adjs, "ADV": advs}
    for tok in tokens:
        word = tok.lower_ if tok.lemma_ == '-PRON-' else tok.lemma_
        all_words.append(word)
        if tok.pos_ == "VERB":
            if tok.dep_ != "aux":
                verbs.append(word)
            continue
        bucket = bucket_by_pos.get(tok.pos_)
        if bucket is not None:
            bucket.append(word)
    return nouns, verbs, adjs, advs, all_words

find_related_entities() takes in the whole news article and determines the list of entity pairs that appear in it.

Note that here we need both entities in the entity pair to appear in the article.

In [None]:
def find_related_entities(line, interest_pair_list):
    """Return the pairs from ``interest_pair_list`` whose two entities both occur in ``line``.

    An entity occurs when any of its aliases in entity_aliases_dict is a
    substring of ``line``. Matching is plain, case-sensitive substring
    search, so an alias also matches inside a longer word.
    """
    def mentioned(entity):
        return any(alias in line for alias in entity_aliases_dict[entity])

    return [
        pair for pair in interest_pair_list
        if mentioned(pair[0]) and mentioned(pair[1])
    ]

process_news_article() is the main function for processing every news article.

Before introducing how this function works, we should first look at a typical news article in the NOW corpus:

"@@12345 <h\> HEADER OF THE NEWS <p\> First paragraph's text <p\> Second paragraph's text <p\> ...", where "12345" is the id of the news article. The larger the id, the later the article was crawled.

process_news_article() does the following things:
1. split a news article to news id, header, and paragraphs
2. for the header and each paragraph, we first check if it contains any entity pair (need both entities within the pair to appear)
3. if true, we segment the paragraph into sentences and check every sentence if it contains any entity pair (also need both entities to appear)
4. if true, we find the SVOs, nouns, verbs, adjectives, and adverbs in the sentence
5. finally we aggregate all the SVOs, nouns, verbs, adjectives, and adverbs found in the article and return them

In [None]:
def process_news_article(line, interest_pair_list):
    """Extract lexical information for every entity pair mentioned in one article.

    The article is split on single-character tags such as <h> and <p>. The
    first piece holds the article id after a two-character prefix; each later
    piece is a header or paragraph. A paragraph is parsed only if it mentions
    both entities of some pair, and a sentence contributes only to the pairs
    whose two entities both appear in that sentence.

    Args:
        line: raw text of one article.
        interest_pair_list: entity pairs (tuples of keys of
            entity_aliases_dict) to look for.

    Returns:
        An 8-tuple (news_index, svo, nouns, verbs, adjectives, adverbs,
        all_words, samples). Every element after the index is a
        defaultdict(list) keyed by entity pair. news_index is 0 and all
        dicts are empty when the line has no tags or the id is not an
        integer.
    """
    split_list = re.split(r'<.>', line) # split to paragraphs using <h> and <p>

    news_svo_info = defaultdict(list)
    news_nn_info = defaultdict(list)
    news_vb_info = defaultdict(list)
    news_jj_info = defaultdict(list)
    news_rb_info = defaultdict(list)
    news_all_info = defaultdict(list)
    news_samples = defaultdict(list)
    empty_result = (0, news_svo_info, news_nn_info, news_vb_info, news_jj_info,
                    news_rb_info, news_all_info, news_samples)

    if len(split_list) <= 1:
        return empty_result

    try:
        news_index = int(split_list[0][2:])
    except ValueError:  # id field is not a number; int() on a str raises nothing else
        return empty_result

    for split in split_list[1:]:
        # paragraph-level filter: skip the parser unless both entities of a pair occur
        related_interest_pairs = find_related_entities(split, interest_pair_list)
        if not related_interest_pairs:
            continue

        doc = nlp(split)
        for sent in doc.sents:
            sent_text = sent.text
            if '@ @' in sent_text: # broken sentence
                continue

            # sentence-level filter, restricted to the pairs found in the paragraph
            sent_related_interest_pairs = find_related_entities(sent_text, related_interest_pairs)
            if not sent_related_interest_pairs:
                continue

            # The extraction does not depend on the pair, so run it once per sentence.
            svos = findSVOs(sent)
            nouns, verbs, adjs, advs, all_words = findGenerals(sent)
            for rip in sent_related_interest_pairs:
                news_svo_info[rip].extend(svos)
                news_nn_info[rip].extend(nouns)
                news_vb_info[rip].extend(verbs)
                news_jj_info[rip].extend(adjs)
                news_rb_info[rip].extend(advs)
                news_all_info[rip].extend(all_words)
                news_samples[rip].append(sent_text)

    return news_index, news_svo_info, news_nn_info, news_vb_info, news_jj_info, news_rb_info,\
        news_all_info, news_samples

Run process_news_article() on all articles in our dataset.

In [None]:
def process_news_file(filename):
    """Process one corpus file and pickle the per-entity-pair results.

    The part of ``filename`` after load_path (and after an optional five
    character 'text_' prefix) starts with a five-character stamp whose first
    two characters are the year and last two the month. The stamp becomes a
    month index relative to year_base. Each line of the file is one article;
    articles mentioning at least one entity pair go through
    process_news_article(), and results are stored under (month index, news
    index).

    Seven pickles are written to write_path, named
    'large_<kind>_<short name>.pkl' with kind in svo, nn, vb, jj, rb, all,
    sample.

    Args:
        filename: path of the file; expected to start with load_path and end
            in a four-character extension.
    """
    # Skip the optional 'text_' prefix so both naming schemes share one code path.
    name_start = len(load_path) + (5 if 'text_' in filename else 0)
    stamp = filename[name_start: name_start + 5]
    time = (int(stamp[:2]) - year_base) * 12 + int(stamp[-2:])
    short_filename = filename[name_start: -4]

    # Output kind -> {entity pair -> {(month index, news index) -> data}}
    kinds = ('svo', 'nn', 'vb', 'jj', 'rb', 'all')
    outputs = {kind: defaultdict(dict) for kind in kinds}
    entity_sample_dict = defaultdict(dict)

    # The context manager closes the input even if an article fails to parse.
    with open(filename, 'r') as f:
        for news in f:
            related_pairs = find_related_entities(news, interest_pair_list)
            if not related_pairs:
                continue
            news_index, *news_infos, news_samples = process_news_article(news, related_pairs)
            key = (time, news_index)
            for kind, info in zip(kinds, news_infos):
                for pair, words in info.items():
                    outputs[kind][pair][key] = words
            for pair, sentences in news_samples.items():
                entity_sample_dict[pair][key] = '\n'.join(sentences)

    outputs['sample'] = entity_sample_dict
    for kind, data in outputs.items():
        # Close each pickle explicitly so it is flushed before the next one is opened.
        with open(write_path + 'large_' + kind + '_' + short_filename + '.pkl', 'wb') as out:
            pickle.dump(data, out)
    print(short_filename, 'finished')

In [None]:
%%capture process_time
%%time

# Every corpus file matched by any of the configured filters.
filename_list = [
    filename
    for filter_str in filter_strs
    for filename in glob.glob(load_path + filter_str + '.txt')
]

with Pool(40) as p: # fork 40 processes
    p.map(process_news_file, filename_list)

In [None]:
# Display the output stored in process_time by the %%capture cell above.
process_time.show()

# Generate model input

In [None]:
import spacy
import pickle
from collections import defaultdict
from collections import Counter
import numpy as np
import glob

Load the entity pairs and the lexical data of all articles processed in the previous section.

In [None]:
# The pickles written by the parsing section are read from, and the model input written to, the same directory.
load_path = '/data/ent2ent/han/incremental/'
write_path = '/data/ent2ent/han/incremental/'
meta_field_list = ['Internation']

# Canonical entity name -> set of surface forms that count as a mention of it.
entity_aliases_dict = {
    'U.S.': {'U.S.', 'US', 'USA', 'Trump', 'Obama'},
    'China': {'China', 'Chinese', 'Xi'},
    'Syria': {'Syria', 'Syrian', 'Assad'},
    'France': {'France', 'French', 'Macron', 'Hollande'},
    'Germany': {'Germany', 'German', 'Merkel'},
    'Canada': {'Canada', 'Canadian', 'Trudeau'},
    'Russia': {'Russia', 'Russian', 'Putin'},
    'India': {'India', 'Indian', 'Modi'},
    'U.K.': {'U.K.', 'UK', 'British', 'Britain', 'Cameron'},
    'Japan': {'Japan', 'Japanese', 'Abe'},
    'Iran': {'Iran', 'Iranian', 'Khamenei', 'Rouhani'},
    'Israel': {'Israel', 'Israeli', 'Netanyahu'},
}
entities_list = list(entity_aliases_dict)

# Every unordered pair of entities, each listed once in dictionary order.
entity_pair_list = [
    (first, second)
    for position, first in enumerate(entities_list)
    for second in entities_list[position + 1:]
]

# Two pipelines are loaded: sm_nlp serves the stop-word checks and lg_nlp the embedding lookups below.
sm_nlp = spacy.load('en_core_web_sm') # due to spacy's bug, only small version handles stop words
lg_nlp = spacy.load('en_core_web_lg') # large version handles word embedding vectors

In [None]:
def _load_merged_pickles(prefix):
    """Merge every '<load_path><prefix>*.pkl' file into one nested dict.

    Each pickle maps entity pair -> {metadata key -> data}; entries for the
    same pair coming from different files are combined with dict.update().
    """
    merged = defaultdict(dict)
    for filename in glob.glob(load_path + prefix + '*.pkl'):
        # NOTE: pickle.load executes arbitrary code; only point load_path at files this pipeline produced.
        with open(filename, 'rb') as fh:  # close the handle as soon as the file is read
            per_file = pickle.load(fh)
        for ep, ed in per_file.items():
            merged[ep].update(ed)
    return merged


entity_svo_dict = _load_merged_pickles('large_svo')
entity_nn_dict = _load_merged_pickles('large_nn')
entity_vb_dict = _load_merged_pickles('large_vb')
entity_jj_dict = _load_merged_pickles('large_jj')
entity_rb_dict = _load_merged_pickles('large_rb')
entity_all_dict = _load_merged_pickles('large_all')
entity_sample_dict = _load_merged_pickles('large_sample')

The below blocks generate the input for our entity relationship model.

The input mainly contains the following things:
1. "span_data" that contains trajectories of entity pair's lexical data
2. "target_verbs_ix_set" that limits the word choice of our model's relationship descriptors
3. "We" that saves trained GloVe word embeddings for all words that appear in the dataset

In [None]:
margin_words = set()

# preprocessing for RMN
upper_bound = 500   # number of most frequent words to exclude
lower_bound = 5000  # number of least frequent words to exclude

# Corpus-wide frequency of every lemma collected for any entity pair.
all_words_counter = Counter()
for articles in entity_all_dict.values():
    for article_words in articles.values():
        all_words_counter.update(article_words)
print(len(all_words_counter))

# Words at either end of the frequency ranking are the "margin" words.
cnts = all_words_counter.most_common()
margin_words.update(word for word, _ in cnts[:upper_bound])
margin_words.update(word for word, _ in cnts[-lower_bound:])

In [None]:
raw_span_data = []

# Part-of-speech word lists and the mask id each one is tagged with.
pos_sources = (
    (entity_nn_dict, 4),  # nn -> 4
    (entity_vb_dict, 5),  # vb -> 5
    (entity_jj_dict, 6),  # jj -> 6
    (entity_rb_dict, 7),  # rb -> 7
)

for ep in entity_pair_list:
    # news index reflects time order
    metadata_ordered_list = sorted(entity_svo_dict[ep], key=lambda meta: meta[1])
    raw_spans, raw_masks, raw_months, raw_samples = [], [], [], []

    for mp in metadata_ordered_list:
        # (word, mask id) candidates that still have to pass the stop-word filter
        candidates = []
        for svo in entity_svo_dict[ep][mp]:
            candidates.extend(zip(svo, (1, 2, 3)))  # s -> 1, v -> 2, o -> 3
        for source, mask_id in pos_sources:
            candidates.extend((word, mask_id) for word in source[ep][mp])

        kept = [(word, mask_id) for word, mask_id in candidates
                if not sm_nlp.vocab[word].is_stop]
        # processed general words -> 8 (filtered by margin_words rather than stop words)
        kept.extend((word, 8) for word in entity_all_dict[ep][mp] if word not in margin_words)

        raw_spans_buf = [word for word, _ in kept]
        raw_masks_buf = [mask_id for _, mask_id in kept]
        raw_samples_buf = [entity_sample_dict[ep][mp]]

        if 2 not in raw_masks_buf:  # no predicate survived for this article: discard it
            continue
        raw_spans.append(raw_spans_buf)
        raw_masks.append(raw_masks_buf)
        raw_months.append(mp[0])
        raw_samples.append(raw_samples_buf)

    raw_span_data.append(([meta_field_list[0]], [ep[0], ep[1]], raw_spans, raw_masks, raw_months, raw_samples))

In [None]:
meta_field_set = set()
entity_set = set()
word_set = set()
max_len = 0  # length of the longest word span; later used as the padded width

for rsd in raw_span_data:
    meta_field_set.update(rsd[0])
    entity_set.update(rsd[1])
    for rs in rsd[2]:
        max_len = max(max_len, len(rs))
        word_set.update(rs)

# Iteration order of a set of strings changes between interpreter runs, so
# enumerate the sorted contents to get the same indices on every run.
mf2ix = {word: i for i, word in enumerate(sorted(meta_field_set))}
e2ix = {word: i for i, word in enumerate(sorted(entity_set))}
w2ix = {word: i for i, word in enumerate(sorted(word_set))}
# Inverse lookups: index -> meta field / entity / word.
ix2mf = {i: word for word, i in mf2ix.items()}
ix2e = {i: word for word, i in e2ix.items()}
ix2w = {i: word for word, i in w2ix.items()}

In [None]:
span_data = []
target_verb_ix_set = set()

# How often each word index occurs as a predicate (mask id 2).
predicate_ix_counter = Counter()
for rsd in raw_span_data:
    mf_list = [mf2ix[mf] for mf in rsd[0]]
    ent_list = [e2ix[ent] for ent in rsd[1]]
    n_articles = len(rsd[2])
    # One row per article, zero-padded on the right up to max_len.
    spans = np.zeros([n_articles, max_len], dtype=int)
    masks = np.zeros([n_articles, max_len], dtype=float)
    months = np.array(rsd[4], dtype=int)
    samples = rsd[5]
    for row, (words, word_masks) in enumerate(zip(rsd[2], rsd[3])):
        for col, word in enumerate(words):
            spans[row][col] = w2ix[word]
            masks[row][col] = word_masks[col]
            if masks[row][col] == 2: # all predicates
                predicate_ix_counter[spans[row][col]] += 1
    span_data.append((mf_list, ent_list, spans, masks, months, samples))

# Keep the 500 most frequent predicates as the target verbs.
target_verb_ix_set.update(wix for wix, _ in predicate_ix_counter.most_common()[:500])

In [None]:
# Number of distinct word indices that occur as a predicate.
print(len(predicate_ix_counter))

In [None]:
# Embedding matrix: row i holds the lg_nlp vector of word ix2w[i].
# The width of 300 is hard-coded; it must match the vector length of lg_nlp.
We = np.zeros([len(ix2w), 300], dtype=float)

for i, word in ix2w.items():
    We[i] = lg_nlp.vocab[word].vector

# Scale every non-zero row to unit L2 norm in one vectorised step; all-zero
# rows are left untouched so that nothing is divided by zero.
norms = np.linalg.norm(We, axis=1)
nonzero_rows = norms > 0
We[nonzero_rows] /= norms[nonzero_rows, np.newaxis]

In [None]:
# Bundle everything into one tuple; the unpacking cell further down relies on this field order.
model_input = (ix2mf, ix2e, w2ix, ix2w, span_data, max_len, target_verb_ix_set, We)

In [None]:
# Write the model input to disk; the context manager flushes and closes the file.
with open(write_path + 'large_model_input_gs1.pkl', 'wb') as model_file:
    pickle.dump(model_input, model_file)

### Distribution of entity pair occurrence

In [None]:
# Read the saved model input back; the context manager closes the file after loading.
with open(write_path + 'large_model_input_gs1.pkl', 'rb') as model_file:
    model_input = pickle.load(model_file)

In [None]:
# Unpack in the same field order in which the model_input tuple was built.
ix2mf, ix2e, w2ix, ix2w, span_data, max_len, target_verb_ix_set, We = model_input

In [None]:
# article
# One (entity, entity, number of span rows) entry per pair, largest count first.
ent_dist = sorted(
    ((ix2e[pair[0]], ix2e[pair[1]], len(span)) for _, pair, span, _, _, _ in span_data),
    key=lambda entry: -entry[2],
)
num_nation_pair_related_articles = sum(count for _, _, count in ent_dist)
print(num_nation_pair_related_articles)

In [None]:
# Each article sample is a one-element list holding newline-joined sentences,
# hence the [0] before splitting.
num_nation_pair_related_sentences = sum(
    len(article_sample[0].split('\n'))
    for *_, article_samples in span_data
    for article_sample in article_samples
)
print(num_nation_pair_related_sentences)

In [None]:
# Print one (entity, entity, number of span rows) line per pair, largest count first.
pair_counts = [(ix2e[pair[0]], ix2e[pair[1]], len(span)) for _, pair, span, _, _, _ in span_data]
ent_dist = sorted(pair_counts, key=lambda entry: -entry[2])
for entry in ent_dist:
    print(entry)

### Average number of predicates per span

In [None]:
# Number of predicate positions (mask id 2) in every span row.
avg_list = [
    sum(1 for m in mask if m == 2)
    for _, _, _, masks, _, _ in span_data
    for mask in masks
]
print(1.0 * sum(avg_list) / len(avg_list))

### Average number of nouns per span

In [None]:
# Number of noun positions (mask id 4) in every span row.
avg_list = [
    sum(1 for m in mask if m == 4)  # nouns
    for _, _, _, masks, _, _ in span_data
    for mask in masks
]
print(1.0 * sum(avg_list) / len(avg_list))