In [None]:
# Directory holding the HPSG neural parser output files. File names are appended
# to it by plain string concatenation further down, so the trailing slash is required.
# NOTE(review): this directory is named "test/" while the TSV and the output pickle
# are both "train" - confirm they refer to the same data split.
hpsg_dependency_parsing_results = './results/HPSG_neural_parser/test/'
# Tab-separated data file; the second column of each line is the sentence carrying
# the [E11]..[E12] / [E21]..[E22] entity markers.
default_semeval_tsv_data_path = './R-BERT/data/train.tsv'


In [2]:
import os
import re
import pickle

def return_entities(text):
    """Extract the two marked entity mentions from a sentence.

    Parameters
    ----------
    text : str
        Sentence in which the first entity is wrapped in ``[E11] ... [E12]``
        and the second one in ``[E21] ... [E22]``.

    Returns
    -------
    tuple of (str, str)
        The first and the second entity text, stripped of surrounding
        whitespace.

    Raises
    ------
    ValueError
        If either marker pair cannot be found in ``text``.
    """
    # Raw strings: '\[' inside a normal literal is an invalid escape sequence.
    # The leading greedy '.*' is kept on purpose so that matching behaves
    # exactly as before when a marker occurs more than once.
    e1_match = re.match(r'.*\[E11\](.*)\[E12\]', text)
    e2_match = re.match(r'.*\[E21\](.*)\[E22\]', text)
    if e1_match is None or e2_match is None:
        raise ValueError(
            'entity markers [E11]..[E12] / [E21]..[E22] not found in: %r' % (text,))
    return (e1_match.group(1).strip(), e2_match.group(1).strip())


def generate_paths_to_root(entity_tokens,
                          term_ids_map,
                          id_dep_id_map):
    """Collect, for every occurrence of every entity token, its path to the root.

    Parameters
    ----------
    entity_tokens : iterable
        Tokens of one entity; each must be a key of ``term_ids_map``.
    term_ids_map : dict
        Maps a token to the list of token ids at which it occurs.
    id_dep_id_map : dict
        Maps a token id to the id of its head; a head id of ``0`` ends the walk.

    Returns
    -------
    list of list
        One path per (token, occurrence), in iteration order. Each path starts
        at the occurrence's id and follows head links; the terminating ``0``
        is not included.

    Raises
    ------
    KeyError
        If a token is missing from ``term_ids_map`` or an id is missing from
        ``id_dep_id_map``.
    ValueError
        If the head links contain a cycle, which would otherwise make the walk
        loop forever.
    """
    paths = []
    for term in entity_tokens:
        for start_id in term_ids_map[term]:
            path = []
            visited = set()
            current = start_id
            while current != 0:
                if current in visited:
                    # Head links that never reach 0 cannot form a tree.
                    raise ValueError(
                        'cycle in dependency heads while walking up from id %r '
                        '(revisited id %r)' % (start_id, current))
                visited.add(current)
                path.append(current)
                current = id_dep_id_map[current]
            paths.append(path)
    return paths


def find_shortest_path_between_entities(e1_paths,
                                        e2_paths):
    """Return the shortest id path linking any e1 path to any e2 path.

    Both arguments are lists of paths, each path being a list of ids. For
    every (e1 path, e2 path) pair a candidate is built by walking the e1 path
    up to and including the first id that also occurs in the e2 path, then
    appending the e2 path reversed, starting at its first id that is not
    shared. The shortest candidate is returned; on ties the first one built
    wins.

    Raises ``ValueError`` (from ``min``) when no candidate exists, i.e. when
    either argument is empty.
    """
    e1_id_sets = [set(path) for path in e1_paths]
    e2_id_sets = [set(path) for path in e2_paths]

    candidates = []
    for up_path, up_ids in zip(e1_paths, e1_id_sets):
        for down_path, down_ids in zip(e2_paths, e2_id_sets):
            shared = up_ids & down_ids

            # Climb from the first entity until a shared id is reached.
            joined = []
            for node in up_path:
                joined.append(node)
                if node in shared:
                    break

            # Walk the second path in reverse, skipping its leading shared ids.
            top_down = down_path[::-1]
            first_private = next(
                (pos for pos, node in enumerate(top_down) if node not in shared),
                None)
            if first_private is not None:
                joined.extend(top_down[first_private:])

            candidates.append(joined)

    return min(candidates, key=len)


def calculate_sdp(sample_index,sample,exceptional_case_ids):
    """Return the tokens on the shortest dependency path between two entities.

    Parameters
    ----------
    sample_index : int
        Index of the sample; used to look up its entity tokens in the
        module-level ``entities_tokenized`` list.
    sample : list
        Rows whose items 0, 1 and 2 are the token id, the token text and the
        id of the token's head.
    exceptional_case_ids : container
        Sample indices whose entity tokens get rewritten before the lookup.

    Returns
    -------
    list
        Token texts along the shortest path from the first to the second
        entity.

    Notes
    -----
    For indices in ``exceptional_case_ids`` the entity token lists stored in
    ``entities_tokenized`` are modified in place: an entity token that starts
    with one sample token and ends with the following sample token is replaced
    by the former.
    """
    # Build all three lookup tables in a single pass over the rows.
    id_term_map = {}
    id_dep_id_map = {}
    term_ids_map = {}
    for row in sample:
        token_id = int(row[0])
        id_term_map[token_id] = row[1]
        id_dep_id_map[token_id] = row[2]
        term_ids_map.setdefault(row[1], []).append(token_id)

    entity_pair = entities_tokenized[sample_index]

    if sample_index in exceptional_case_ids:
        # Pairing each row with its successor skips the last row, which has
        # no following token to compare against.
        for current_row, next_row in zip(sample, sample[1:]):
            first_piece = current_row[1]
            second_piece = next_row[1]
            for entity in entity_pair:
                for position, token in enumerate(entity):
                    if token.startswith(first_piece) and token.endswith(second_piece):
                        entity[position] = first_piece

    e1 = entity_pair[0]
    e2 = entity_pair[1]

    try:
        e1_paths = generate_paths_to_root(e1, term_ids_map, id_dep_id_map)
    except BaseException:
        # Dump the inputs for debugging, then let the original error propagate.
        print(sample, e1, term_ids_map, id_dep_id_map)
        raise
    e2_paths = generate_paths_to_root(e2, term_ids_map, id_dep_id_map)

    shortest_path = find_shortest_path_between_entities(e1_paths, e2_paths)
    return [id_term_map[token_id] for token_id in shortest_path]

In [40]:
def _natural_sort_key(file_name):
    """Sort key that orders embedded digit runs numerically (..._2 before ..._10)."""
    # re.split with a capturing group alternates text / digits / text / ...,
    # so the odd positions are always the digit runs.
    return [int(part) if position % 2 else part
            for position, part in enumerate(re.split(r'(\d+)', file_name))]


def _read_parser_output_lines(prefix):
    """Return the lines of every parser output file whose name starts with ``prefix``."""
    collected = []
    # os.listdir returns names in arbitrary order; sorting keeps the three
    # kinds of output files concatenated in the same, reproducible order.
    file_names = sorted(os.listdir(hpsg_dependency_parsing_results),
                        key=_natural_sort_key)
    for file_name in file_names:
        if file_name.startswith(prefix):
            # Context manager so the handle is closed right after reading.
            with open(hpsg_dependency_parsing_results + file_name, 'r') as handle:
                collected.extend(handle.readlines())
    return collected


dp_lines = _read_parser_output_lines('output_synconst')

# Each match is a pair of whitespace-separated items directly followed by ')';
# the second item is used as the token text below.
# NOTE(review): the second character class only admits word characters and
# . ; " ' : _ - so other tokens (e.g. a bare comma) are not matched, which
# would shift token ids relative to the head lists - confirm against the data.
tokens_with_pos = [re.findall(r'''([^\s\(]+) ([\w.;"':_-]+)\)''', x) for x in dp_lines]

# Head lines look like "[2, 0, 2]"; strip() first so a final line without a
# trailing newline is parsed the same way as all the others.
dp_token_parents = [list(map(int, x.strip()[1:-1].split(', ')))
                    for x in _read_parser_output_lines('output_syndephead')]

# Label lines look like "['det', 'root']"; k[1:-1] drops the quotes around each label.
dp_token_relations = [[k[1:-1] for k in x.strip()[1:-1].split(', ')]
                      for x in _read_parser_output_lines('output_syndeplabel')]

# One row per token: (1-based id, token text with dots removed, head id, relation).
# zip() stops at the shortest input, so a token/head count mismatch is silent.
dep_parser_res = [list(zip(range(1, len(tokens) + 1),
                           [t[1].replace('.', '') for t in tokens],
                           dp_token_parents[k],
                           dp_token_relations[k]))
                  for k, tokens in enumerate(tokens_with_pos)]

# Read the TSV inside a context manager so the file handle is closed, and
# bind `lines` to the list of lines only (never to the file object itself).
with open(default_semeval_tsv_data_path, 'r') as tsv_file:
    lines = tsv_file.readlines()
example_list = [line.split('\t') for line in lines]

# The second column holds the sentence with the entity markers.
entities = [return_entities(example[1]) for example in example_list]
# Note: calculate_sdp rewrites these token lists in place for exceptional cases.
entities_tokenized = [[e1.split(), e2.split()] for e1, e2 in entities]
entities_tokens_flat = [e1.split() + e2.split() for e1, e2 in entities]

dep_parser_token_sets = [set(row[1] for row in sample) for sample in dep_parser_res]

# Samples where at least one entity token does not appear among the parser's tokens.
exceptional_cases = [
    (index, tokens)
    for index, tokens in enumerate(entities_tokens_flat)
    if any(token not in dep_parser_token_sets[index] for token in tokens)
]
exceptional_case_ids = set(case[0] for case in exceptional_cases)

res = [calculate_sdp(index, sample, exceptional_case_ids)
       for index, sample in enumerate(dep_parser_res)]


### Additional Step
# Replace the first and last item of each path with the full entity strings.
# Paths whose first token is not one of the first entity's tokens are left
# untouched and their index is printed for manual inspection.
for sample_index, path in enumerate(res):
    if path[0] in entities_tokenized[sample_index][0]:
        path[0] = entities[sample_index][0]
        path[-1] = entities[sample_index][-1]
    else:
        print(sample_index)

# Context manager so the pickle file is flushed and closed deterministically.
with open('./hpsg_parser_train_sdp_results.pkl', 'wb') as pickle_file:
    pickle.dump(res, pickle_file)

In [50]:
# Show the computed paths: one list of tokens per sample, with the first and
# last items replaced by the full entity strings where the match succeeded.
res

[['configuration', 'in', 'has', 'elements'],
 ['child', 'wrapped', 'bound', 'into', 'cradle'],
 ['author', 'uses', 'disassembler'],
 ['ridge', 'uprises', 'surge'],
 ['student', 'association'],
 ['complex', 'producer'],
 ['inflammation', 'caused', 'by', 'infection'],
 ['people', 'moving', 'downtown'],
 ['lawsonite', 'contained', 'platinum crucible'],
 ['solvent', 'of', 'ml', 'pipetted', 'into', 'flask'],
 ['essays', 'collected', 'in', 'volume'],
 ['composer', 'sunk', 'oblivion'],
 ['citation', 'explaining', 'reasons'],
 ['burst', 'caused', 'pressure'],
 ['networks', 'moved', 'high - definition broadcast'],
 ['call', 'remind', 'about', 'bill'],
 ['virtuoso', 'finds', 'instrument'],
 ['factory', 'products', 'included', 'pots', 'trays'],
 ['tree', 'blossom'],
 ['battalion', 'by', 'backed', 'columns', 'with', 'tried', 'grenadiers'],
 ['knowledge', 'gained', 'from', 'recruits'],
 ['stable', 'had', 'hounds'],
 ['singer', 'caused', 'commotion'],
 ['essays', 'books', 'pertinent', 'history'],
 [