In [1]:
import json

In [42]:
train = json.load(open('./tacred/data/json/train.json','r',encoding='utf-8'))
dev = json.load(open('./tacred/data/json/dev.json','r',encoding='utf-8'))
test = json.load(open('./tacred/data/json/test.json','r',encoding='utf-8'))

In [40]:
train

<_io.TextIOWrapper name='./tacred/data/tsv_cased/train.tsv' mode='r' encoding='utf-8'>

In [23]:
dataset = test

In [24]:
def get_entities(sample):
    if sample['obj_start'] < sample['subj_start']:
        e1 = ' '.join(sample['token'][sample['obj_start']:sample['obj_end']+1])
        e2 = ' '.join(sample['token'][sample['subj_start']:sample['subj_end']+1])
    else:
        e2 = ' '.join(sample['token'][sample['obj_start']:sample['obj_end']+1])
        e1 = ' '.join(sample['token'][sample['subj_start']:sample['subj_end']+1])
    return (e1,e2)

def get_entity_tokens(sample):
    if sample['obj_start'] < sample['subj_start']:
        e1 = sample['token'][sample['obj_start']:sample['obj_end']+1]
        e2 =sample['token'][sample['subj_start']:sample['subj_end']+1]
    else:
        e2 = sample['token'][sample['obj_start']:sample['obj_end']+1]
        e1 = sample['token'][sample['subj_start']:sample['subj_end']+1]
    return (e1,e2)

In [25]:
entities = [get_entities(x) for x in dataset]
entities_tokenized = [get_entity_tokens(x) for x in dataset]
entities_tokens_flat = [x+y for x,y in entities_tokenized]

In [26]:
dep_parser_res = [list(zip(range(1,len(sample['token'])+1),
                          sample['token'],
                          sample['stanford_head'],
                          sample['stanford_deprel'])) for sample in dataset]

In [27]:
dep_parser_token_sets = [set(k[1] for k in x) for x in dep_parser_res]

In [28]:
exceptional_cases = list(filter(lambda x: any(map(lambda k: k not in dep_parser_token_sets[x[0]],
                              x[1])), 
                          enumerate(entities_tokens_flat)))

In [29]:
exceptional_case_ids = set(x[0] for x in exceptional_cases)

In [30]:
def generate_paths_to_root(entity_tokens,
                          term_ids_map,
                          id_dep_id_map):
    paths = []
    for term in entity_tokens:
        t_ids = term_ids_map[term]
        for t_id in t_ids:
            temp = t_id
            temp_path = []
            while temp != 0:
                temp_path.append(temp)
                temp = id_dep_id_map[temp]
            paths.append(temp_path)
    return paths

def find_shortest_path_between_entities(e1_paths,
                                        e2_paths):
    paths = []
    e1_path_sets = list(map(set,e1_paths))
    e2_path_sets = list(map(set,e2_paths))
    #print(e1_paths)
    #print(e2_paths)
    
    for i1, p1 in enumerate(e1_paths):
        for i2, p2 in enumerate(e2_paths):
            common_ids = e1_path_sets[i1] & e2_path_sets[i2]
            #print('common', common_ids)
            temp_path = []
            for t_id in p1:
                temp_path.append(t_id)
                if t_id in common_ids:
                    break
            #print('p1 end')
            #print(temp_path)
            r_p2 = list(reversed(p2))
            for index,t_id in enumerate(r_p2):
                if t_id not in common_ids:
                    temp_path.extend(r_p2[index:])
                    break
            paths.append(temp_path)
    
    return min(paths,key=len)

In [31]:
def calculate_sdp(sample_index,sample,exceptional_case_ids):
    res = ['Error']
    
    
    id_term_map = {int(x[0]):x[1] for x in sample}
    term_ids_map = {}
    for x in sample:
        term_ids_map[x[1]] = term_ids_map.get(x[1],[])
        term_ids_map[x[1]].append(int(x[0]))

    id_dep_id_map = {int(x[0]):x[2] for x in sample}
    temp_entities = entities_tokenized[sample_index]

    if sample_index in exceptional_case_ids:
        for i,k in enumerate(sample):
            for ent in temp_entities:
                for ind,e in enumerate(ent):
                    if e.startswith(k[1]) and \
                    (i+1 != len(sample)) and \
                     e.endswith(sample[i+1][1]):
                        ent[ind] = k[1]
                        
    e1 = entities_tokenized[sample_index][0]
    #print(e1)
    e2 = entities_tokenized[sample_index][1]
    #print(e2)
    try:
        e1_paths = generate_paths_to_root(e1,
                                      term_ids_map,
                                      id_dep_id_map)
    except:
        print(sample, e1, term_ids_map, id_dep_id_map)
        raise
    e2_paths = generate_paths_to_root(e2,
                                      term_ids_map,
                                      id_dep_id_map)

    shortest_path =find_shortest_path_between_entities(e1_paths,
                                                       e2_paths)
    res = list(id_term_map[x] for x in shortest_path)

    return res

In [32]:
res = [calculate_sdp(index,sample,exceptional_case_ids) for index,sample in enumerate(dep_parser_res)]

In [33]:
res

[['of', 'department'],
 ['his', 'power', 'impose', 'Palermo'],
 ['Vagni', 'Notter', 'Switzerland'],
 ['he', 'calls', 'repository', 'journalists', 'Western'],
 ['Association', 'percent', 'named', 'destination', 'PATA', 'Bangkok-based'],
 ['Helen', 'organization', 'won', 'foundation'],
 ['ADF', 'said', 'responding', 'group', 'Timor'],
 ['He', 'had', 'him'],
 ['missionary', 'Silsby'],
 ['Television', 'AIA', 'Richard'],
 ['Lange', 'met', 'Berkeley', 'he'],
 ['journalist',
  'owner',
  'Lomax',
  'story',
  'shares',
  'taking',
  'family',
  'her'],
 ['Katrina', 'name', 'Ramon'],
 ['Ramon', 'promoted', 'lieutenant', 'captain'],
 ['her', 'husband', 'maiming', 'friend', 'Nash'],
 ['he', 'received', 'credentials', 'his'],
 ['Cain', 'tenure', 'NRA'],
 ['Miami',
  'BC-FLA-SCIENTOLOGY-DENTIST',
  'skilled',
  'selling',
  'services',
  'his'],
 ['http://www.adb.org/media/Articles/2007/12155-asian-poverties-reductions',
  'becomes',
  '*',
  'http://www.adb.org/ADF',
  '>',
  'ADF'],
 ['Cain', 'r

### Additional part

In [34]:
for i,r in enumerate(res):
    if r[0] in entities_tokenized[i][0]:
        r[0] = entities[i][0]
        r[-1] = entities[i][-1]
    else:
        print(i)

### Save

In [36]:
import pickle

In [37]:
pickle.dump(res,open('./stanford_tacred_test_sdp_results.pkl','wb'))