In [26]:
from glob import glob
import json
from collections import Counter
import pickle
import numpy as np
import random
from tqdm import tqdm

import spacy
#spacy.require_gpu()
import en_core_web_sm
nlp = en_core_web_sm.load()
import neuralcoref
neuralcoref.add_to_pipe(nlp)

from spacy import displacy
import itertools

import textacy
from textacy.extract import subject_verb_object_triples

In [35]:
def extract_events(text, num_candidates=5, min_chain=3):
    doc = nlp(text)
    if len(doc._.coref_clusters) == 0:
        raise ValueError("Failed to find any coref clusters")
        return None
    svo_triples = list(subject_verb_object_triples(doc))
    if len(svo_triples) == 0:
        raise ValueError("No subject-verb-object triples found")
    coref_subj = [(svo[0][0]._.coref_clusters[0].main, svo[1], svo[2])
                  for svo in svo_triples 
                  if svo[0][0]._.in_coref]
    coref_obj = [(svo[0], svo[1], svo[2][0]._.coref_clusters[0].main)
                  for svo in svo_triples 
                  if svo[2][0]._.in_coref]
    cluster_counter = Counter(itertools.chain((triple[0] for triple in coref_subj), (triple[2] for triple in coref_obj)))
    if len(cluster_counter) == 0:
        raise ValueError("No coreference clusters found")
    chosen_cluster = cluster_counter.most_common(1)[0][0]
    subj_triples = [(None, triple[1], triple[2])  for triple in coref_subj if triple[0] == chosen_cluster]
    obj_triples = [(triple[0], triple[1], None) for triple in coref_obj if triple[2] == chosen_cluster]
    triples = sorted(itertools.chain(subj_triples, obj_triples), key=lambda x: x[1][0].i)
    
    sent_starts = [triple[1][0].sent.start for triple in triples]
    first_in_sents_i = np.unique(sent_starts, return_index=True)[1]
    triples = [triples[i] for i in first_in_sents_i]
    
    if not triples:
        raise ValueError("No svo triples found for coreferenced entities")
    #Extract entities
    pattern = r'<VERB>?<ADV>*<VERB>+'
    entities = list(set([str(chunk) for chunk in doc.noun_chunks if not chunk[0]._.in_coref]))
    verb_phrases = list(set((str(vf) for vf in textacy.extract.pos_regex_matches(doc, pattern))))
    
    if len(verb_phrases) < num_candidates or len(entities) < num_candidates:
        raise ValueError("Not enough coreferenced verb phrases or entities found")
    
    triple_tuples = [tuple(map(lambda x: x if x is None else str(x), triple)) for triple in triples]
    final_phrase = triple_tuples[-1]
    triple_tuples = triple_tuples[:-1]
    
    if len(triple_tuples) < min_chain:
        raise ValueError("Chain of length %d not long enough" %len(triple_tuples))
    
    verb_choices = [verb_phrases[i] for i in np.random.randint(0, len(verb_phrases), size=num_candidates - 1)]
    entity_choices = [entities[i] for i in np.random.randint(0, len(entities), size=num_candidates - 1)]
    candidate_choices = [(e, v, None) if sv else (None, v, e)
                         for v, e, sv in 
                         zip(verb_choices, 
                             entity_choices, 
                             np.random.choice([True, False], size=num_candidates-1))
                        ]
    candidate_choices.append(final_phrase)
    shuffle = list(range(len(candidate_choices)))
    random.shuffle(shuffle)
    candidate_choices = [candidate_choices[i] for i in shuffle]
    
    return {'entity':str(chosen_cluster), 
            'triples':triple_tuples,
            'sentences': [str(triple[1][0].sent) for triple in triples[:-1]],
            'final_sentence': str(triples[-1][1][0].sent),
            'candidates': candidate_choices,
            'correct': shuffle.index(num_candidates - 1)
           }

## 1.zip extractions

In [36]:
random.seed('triple_extraction')
out_file = "dataset/extractions.pickle"
with open(out_file, "wb") as pickle_f:
    num_saved = 0
    with tqdm(glob("dataset/1/*.ta.xml"), bar_format="{l_bar}{bar}{r_bar} {postfix[0]} saved", postfix=[0]) as ext_prog:
        for f_path in ext_prog:
            with open(f_path) as f:
                try:
                    parse_dict = json.load(f)
                    extraction = extract_events(parse_dict['text'])
                    #print(extraction['triples'])
                    #print(extraction['sentences'])
                    pickle.dump(extraction, pickle_f)
                    num_saved += 1
                    ext_prog.postfix[0] += 1
                except ValueError as e:
                    #print("ERROR", e)
                    continue
        print("Saved %d event chains" %num_saved)

100%|██████████| 946/946 [08:05<00:00,  1.05it/s[185]] 185 saved

Saved 185 event chains





## Gigaword Extractions

In [None]:
random.seed('triple_extraction_gw')
out_file = "gw_extractions.pickle"
with open(out_file, "wb") as pickle_f:
    num_saved = 0
    num_docs = len(glob('temp_g_extract/gigaword-nyt/input/text/nyt_eng_*/*'))
    for i, f_path in enumerate(glob('temp_g_extract/gigaword-nyt/input/text/nyt_eng_*/*')):
        if i % 10 == 0:
            print('%d/%d' %(i, num_docs))
        with open(f_path) as f:
            try:
                extraction = extract_events(f.read())
                #print(extraction['entity'])
                pickle.dump(extraction, pickle_f)
                num_saved += 1
            except ValueError as e:
                #print("ERROR", e)
                continue
    print("Saved %d event chains")

0/1043814
10/1043814
20/1043814
30/1043814
40/1043814
50/1043814
60/1043814
70/1043814
80/1043814
90/1043814
100/1043814
110/1043814
120/1043814
130/1043814
140/1043814
150/1043814
160/1043814
170/1043814
180/1043814
190/1043814
200/1043814
210/1043814
220/1043814
230/1043814
240/1043814
250/1043814
260/1043814
270/1043814
280/1043814
290/1043814
300/1043814
310/1043814
320/1043814
330/1043814
340/1043814
350/1043814
360/1043814
370/1043814
380/1043814
390/1043814
400/1043814
410/1043814
420/1043814
430/1043814
440/1043814
450/1043814
460/1043814
470/1043814
480/1043814
490/1043814
500/1043814
510/1043814
520/1043814
530/1043814
540/1043814
550/1043814
560/1043814
570/1043814
580/1043814
590/1043814
600/1043814
610/1043814
620/1043814
630/1043814
640/1043814
650/1043814
660/1043814
670/1043814
680/1043814
690/1043814
700/1043814
710/1043814
720/1043814
730/1043814
740/1043814
750/1043814
760/1043814
770/1043814
780/1043814
790/1043814
800/1043814
810/1043814
820/1043814
830/1043814
840

6390/1043814
6400/1043814
6410/1043814
6420/1043814
6430/1043814
6440/1043814
6450/1043814
6460/1043814
6470/1043814
6480/1043814
6490/1043814
6500/1043814
6510/1043814
6520/1043814
6530/1043814
6540/1043814
6550/1043814
6560/1043814
6570/1043814
6580/1043814
6590/1043814
6600/1043814
6610/1043814
6620/1043814
6630/1043814
6640/1043814
6650/1043814
6660/1043814
6670/1043814
6680/1043814
6690/1043814
6700/1043814
6710/1043814
6720/1043814
6730/1043814
6740/1043814
6750/1043814
6760/1043814
6770/1043814
6780/1043814
6790/1043814
6800/1043814
6810/1043814
6820/1043814
6830/1043814
6840/1043814
6850/1043814
6860/1043814
6870/1043814
6880/1043814
6890/1043814
6900/1043814
6910/1043814
6920/1043814
6930/1043814
6940/1043814
6950/1043814
6960/1043814
6970/1043814
6980/1043814
6990/1043814
7000/1043814
7010/1043814
7020/1043814
7030/1043814
7040/1043814
7050/1043814
7060/1043814
7070/1043814
7080/1043814
7090/1043814
7100/1043814
7110/1043814
7120/1043814
7130/1043814
7140/1043814
7150/1043814

In [13]:
def read_data_iterator(pickle_file):
    with open(pickle_file, 'rb') as f:
        for i in itertools.count():
            try:
                yield pickle.load(f)
            except EOFError:
                raise StopIteration
print(list(read_data_iterator('extractions.pickle')))

[{'entity': 'Mr. Seignious', 'correct': 4, 'triples': [(None, 'was wearing', 'pants'), ('who', 'confronted', None), ('officers', 'chased', None), ('officers', 'ordered', None), (None, 'took', 'step'), (None, 'was wearing', 'police badge'), (None, 'might have been impersonating', 'officer')], 'candidates': [(None, 'was carrying was', 'identification'), ('that time', 'was fired', None), (None, 'lived', 'a police badge'), (None, 'ordered', "a friend's apartment application"), ('Correction Department', 'placed', None)], 'sentences': ['He was wearing a turtleneck and cargo pants that resembled a police uniform.\n', 'The security officer flagged down a passing traffic officer, who confronted Mr. Seignious and saw what seemed to be a bulge from a weapon in his pants, the police said.', 'When Mr. Seignious emerged from Sears, the five officers chased him a few blocks to East 187th Street and Webster Avenue, where he drew what appeared to be a handgun, the police said.', 'The officers ordered h