In [1]:
import ast
import itertools
import json
from collections import Counter

import jsonlines
import matplotlib.pyplot as plt
import nltk
import numpy as np
import pandas as pd
import sklearn
from sklearn import metrics

# `None` disables column-width truncation; negative values such as -1 have been
# deprecated since pandas 1.0 and are rejected by newer releases.
pd.set_option('display.max_colwidth', None)

# quiet=True keeps the "[nltk_data] ..." download log out of the notebook output.
nltk.download('stopwords', quiet=True)
from nltk.corpus import stopwords
stop_words = set(stopwords.words('english'))

# ### BOOTLEG ###
# Project-local module; importing it loads the Bootleg entity symbols, which takes
# minutes (see the "FINISHED LOADING IN ..." output), so keep it last.
import load_entity_profiles


[nltk_data] Downloading package stopwords to
[nltk_data]     /lfs/1/simran/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!
[nltk_data] Downloading package stopwords to
[nltk_data]     /lfs/1/simran/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


loaded state dict_keys(['people_qids', 'singer_qids', 'contextual_rel_vocab', 'contextual_rel_vocab_inv', 'contextual_rel_pairs'])
Reading in vocab from /dfs/scratch0/lorr1/data_prep/data/wiki0516/entity_dump/entity_all_words/all_words_vocab.marisa
Loaded entity symbols.
Found 506 hyena types
Found 67110 wikidata types
FINISHED LOADING IN 93.61913561820984


In [2]:
from load_entity_profiles import es

In [3]:
def qid_to_mention(qid, entity_symbols=None):
    """Return the title that the entity-symbols object has for ``qid``, or 'NULL'.

    Args:
        qid: entity identifier, e.g. 'Q7804'.
        entity_symbols: object exposing ``get_title(qid)``. Defaults to the Bootleg
            ``es`` object imported from ``load_entity_profiles``.

    Returns:
        The title string, or 'NULL' when the lookup raises an ``Exception``
        (e.g. an unknown QID). KeyboardInterrupt/SystemExit are no longer swallowed.
    """
    symbols = es if entity_symbols is None else entity_symbols
    try:
        return symbols.get_title(qid)
    except Exception:
        return 'NULL'

In [4]:
# Bootleg utility functions:
# BY ALIAS:
def get_candidates(alias, entity_symbols=None):
    """Print, and return, the candidate QIDs and their titles for ``alias``.

    Uses the Bootleg entity-symbols object ``es`` imported from
    ``load_entity_profiles`` (``LoadEntityProfiles`` is not a name defined in
    this notebook, so the lookup has to go through ``es``).

    Args:
        alias: alias string to look up.
        entity_symbols: object exposing ``get_qid_cands(alias)`` and
            ``get_title(qid)``; defaults to ``es``.

    Returns:
        The candidate QIDs, or ``[]`` if the lookup fails (the failure is printed
        instead of being silently ignored).
    """
    symbols = es if entity_symbols is None else entity_symbols
    try:
        # To get qid candidates of an alias
        cands = symbols.get_qid_cands(alias)
        print(f"Cands {cands}")
        print([symbols.get_title(qid) for qid in cands])
    except Exception as exc:  # best-effort debugging helper: report, don't crash
        print(f"Could not look up candidates for {alias!r}: {exc!r}")
        return []
    return cands

In [5]:
# TACRED relation labels in class-id order: a label's position in this list is its id.
LABEL_LST = [
    'no_relation', 'per:title', 'org:top_members/employees', 'per:employee_of',
    'org:alternate_names', 'org:country_of_headquarters', 'per:countries_of_residence',
    'org:city_of_headquarters', 'per:cities_of_residence', 'per:age',
    'per:stateorprovinces_of_residence', 'per:origin', 'org:subsidiaries',
    'org:parents', 'per:spouse', 'org:stateorprovince_of_headquarters', 'per:children',
    'per:other_family', 'per:alternate_names', 'org:members', 'per:siblings',
    'per:schools_attended', 'per:parents', 'per:date_of_death', 'org:member_of',
    'org:founded_by', 'org:website', 'per:cause_of_death',
    'org:political/religious_affiliation', 'org:founded', 'per:city_of_death',
    'org:shareholders', 'org:number_of_employees/members', 'per:date_of_birth',
    'per:city_of_birth', 'per:charges', 'per:stateorprovince_of_death', 'per:religion',
    'per:stateorprovince_of_birth', 'per:country_of_birth', 'org:dissolved',
    'per:country_of_death',
]
# Label -> integer class id (0 .. 41), in the same order as LABEL_LST.
LABEL_TO_ID = {label: class_id for class_id, label in enumerate(LABEL_LST)}

# Coarse entity types produced by the Stanford NER tagger.
STANFORD_NER_TYPES = ['DATE', 'LOCATION', 'MONEY', 'ORGANIZATION', 'PERCENT', 'PERSON', 'TIME']

In [6]:
# Stringified entity-id lists made only of 'UNK', of length 1 to 5:
# "['UNK']", "['UNK', 'UNK']", ..., "['UNK', 'UNK', 'UNK', 'UNK', 'UNK']".
nomention = [str(['UNK'] * n_unk) for n_unk in range(1, 6)]


def is_nomention(mention):
    """Return True if any all-'UNK' pattern from ``nomention`` occurs in ``mention``.

    For a string ``mention`` this is a substring test; for a container it is a
    membership test.
    """
    for pattern in nomention:
        if pattern in mention:
            return True
    return False

In [7]:
# text cleanup
def normalize_glove(tokens):
    """Replace Penn-Treebank-style bracket escapes (e.g. '-LRB-') with literal brackets.

    ``tokens`` is modified in place, and the same object is returned for convenience.
    """
    escapes = ['-LRB-', '-RRB-', '-LSB-', '-RSB-', '-LCB-', '-RCB-']
    brackets = ['(', ')', '[', ']', '{', '}']
    bracket_for = dict(zip(escapes, brackets))
    for pos, tok in enumerate(tokens):
        replacement = bracket_for.get(tok)
        if replacement is not None:
            tokens[pos] = replacement
    return tokens

In [8]:
# Fields kept from every JSON-lines record, in DataFrame column order.
_MENTION_FIELDS = ('id', 'sentence', 'aliases', 'spans', 'qids', 'anchor', 'sent_idx_unq')


def load_mentions(file):
    """Read a Bootleg-style JSON-lines file into a DataFrame, one row per record.

    Only the keys listed in ``_MENTION_FIELDS`` are kept. A record missing any of
    them raises ``KeyError``.
    """
    with jsonlines.open(file) as reader:
        records = [{field: record[field] for field in _MENTION_FIELDS} for record in reader]
    return pd.DataFrame(records)

# Alternate-names coverage: final data, post-Bootleg eval

In [9]:
import json
base_data = '/dfs/scratch1/simran/tacred/tacred-relation-bootleg/dataset_bootleg_cidr_model/bootleg_09132020/basic_full_sentences/static_remap_embs'


def load_split(path):
    """Load one TACRED split stored as JSON into a DataFrame (one row per example).

    The JSON is presumably a list of per-example dicts (the printed shapes are
    (n_examples, 20)). It is read as UTF-8 explicitly, so the result does not
    depend on the platform's default locale encoding.
    """
    with open(path, encoding='utf-8') as fh:
        return pd.DataFrame.from_dict(json.load(fh), orient='columns')


train_file = "{}/train_ent.json".format(base_data)
df_train = load_split(train_file)
print(df_train.shape)

dev_file = "{}/dev_rev_ent.json".format(base_data)
df_dev = load_split(dev_file)
print(df_dev.shape)

test_file = "{}/test_rev_ent.json".format(base_data)
df_test = load_split(test_file)
print(df_test.shape)

(68124, 20)
(22631, 20)
(15509, 20)


In [10]:
print(df_train.columns.values)
# Split name -> DataFrame; the per-split analyses below iterate over this mapping.
dfs = dict(train=df_train, dev=df_dev, test=df_test)

['id' 'docid' 'relation' 'token' 'subj_start' 'subj_end' 'obj_start'
 'obj_end' 'subj_type' 'obj_type' 'stanford_pos' 'stanford_ner'
 'stanford_head' 'stanford_deprel' 'entity_emb_id' 'entity_emb_id_first'
 'ent_id' 'ent_id_first' 'static_ent_emb_id' 'static_ent_emb_id_first']


In [75]:
# ALL ALTERNATE NAME EXAMPLES -- post-eval Bootleg predictions (per-token `ent_id` column)


def _posteval_alt_name_pairs(alt_names_df):
    """Compare the post-eval Bootleg QIDs of the subject and object of each example.

    The QID at the first token of the subject span and at the first token of the
    object span is read from the per-token ``ent_id`` list. If either side is
    'UNK', the example cannot be compared and is counted as unlinked. This is
    consistent with the pre-eval analyses, which also skip examples lacking a
    subject or object QID.

    Otherwise, when the subject or the object is a single token (acronym-style
    alternate name), the pair is recorded as a match (same QID) or a mismatch
    (different QIDs).

    Returns:
        (matched_qids, mismatched_pairs, n_unlinked), where ``mismatched_pairs``
        holds (subject tokens, object tokens) tuples.
    """
    matched_qids, mismatched_pairs = [], []
    n_unlinked = 0
    for _, row in alt_names_df.iterrows():
        tokens, ent_ids = row['token'], row['ent_id']
        subj_start, subj_end = row['subj_start'], row['subj_end']
        obj_start, obj_end = row['obj_start'], row['obj_end']
        subj = tokens[subj_start:subj_end + 1]
        obj = tokens[obj_start:obj_end + 1]
        subj_qid, obj_qid = ent_ids[subj_start], ent_ids[obj_start]

        if subj_qid == 'UNK' or obj_qid == 'UNK':
            n_unlinked += 1
        elif len(subj) == 1 or len(obj) == 1:
            if subj_qid == obj_qid:
                matched_qids.append(subj_qid)
            else:
                mismatched_pairs.append((subj, obj))
    return matched_qids, mismatched_pairs, n_unlinked


# No per-split JSON dump is produced here: opening `{split}_alternate_names_all.json`
# in 'w' mode without writing anything would only truncate an existing file.
match_lists_posteval = {'train': [], 'dev': [], 'test': []}
mis_match_lists_posteval = {'train': [], 'dev': [], 'test': []}
unlinked_counts_posteval = {}
for k, df in dfs.items():
    # Exact label match; 'org:alternate_names' is the only TACRED label containing this string.
    alternate_names_df = df[df['relation'] == 'org:alternate_names']
    print("{}".format(k), " has ", alternate_names_df.shape[0], " total alternate name examples.")
    (match_lists_posteval[k],
     mis_match_lists_posteval[k],
     unlinked_counts_posteval[k]) = _posteval_alt_name_pairs(alternate_names_df)
    

train  has  808  total alternate name examples.
dev  has  348  total alternate name examples.
test  has  245  total alternate name examples.


In [76]:
# # MATCHES IN POSTEVAL BOOTLEG
# for k, v in match_lists_posteval.items():
#     print("LENGTH OF matched qids list is: ", len(v))
#     for item in v:
#         print(item,qid_to_mention(item))
#     print()

In [77]:
# # MISMATCHES IN POSTEVAL BOOTLEG
# for k, v in mis_match_lists_posteval.items():
#     print("LENGTH OF mis-matched qids list is: ", len(v))
#     for item in v:
#         print(item)
#     print()

# Alternate-names coverage: pre-Bootleg eval with candidate generation

In [78]:
_CANDGEN_DIR = ('/dfs/scratch1/simran/tacred/tacred-relation-bootleg/'
                'dataset_bootleg_cidr_model/bootleg_09132020/subjobj_candgen_only')
pre_eval_candgen_file = _CANDGEN_DIR + '/candgen_subjobj_w_bootmentions.jsonl'
# Candidate-generation output, presumably restricted to each example's subject/object
# mentions (per the `subjobj` directory and file names).
pre_eval_df = load_mentions(pre_eval_candgen_file)


In [79]:
# for ind, row in pre_eval_df.iterrows():
#     print(row)
#     break

In [80]:
# ALL ALTERNATE NAME EXAMPLES -- pre-eval candidate generation (uses `pre_eval_df` loaded above)


def _candgen_qids_by_id(mentions_df):
    """Map example id -> list of cand-gen QIDs, for ids that occur exactly once.

    Missing or duplicated ids are left out, so callers treat them as having no
    usable cand-gen output, just as a per-id ``.loc[...].item()`` lookup would
    reject them. Building the map once avoids re-scanning ``mentions_df`` for
    every example.
    """
    id_counts = mentions_df['id'].value_counts().to_dict()
    return {
        ex_id: qids
        for ex_id, qids in zip(mentions_df['id'], mentions_df['qids'])
        if id_counts.get(ex_id) == 1
    }


candgen_qids_by_id = _candgen_qids_by_id(pre_eval_df)

match_lists_preeval_candgen = {'train': [], 'dev': [], 'test': []}
mis_match_lists_preeval_candgen = {'train': [], 'dev': [], 'test': []}
missing_counts_preeval_candgen = {}
for k, df in dfs.items():
    # Exact label match; 'org:alternate_names' is the only TACRED label containing this string.
    alternate_names_df = df[df['relation'] == 'org:alternate_names']
    print("{}".format(k), " has ", alternate_names_df.shape[0], " total alternate name examples.")

    split_matched_qids, split_mismatched_pairs = [], []
    n_missing = 0  # examples where cand gen did not yield both a subj and an obj qid
    for _, row in alternate_names_df.iterrows():
        qids = candgen_qids_by_id.get(row['id'])
        if qids is None or len(qids) < 2:
            n_missing += 1
            continue
        # Only equality of the first two QIDs is tested, so their order does not matter.
        qid1, qid2 = qids[0], qids[1]
        tokens = row['token']
        subj = tokens[row['subj_start']:row['subj_end'] + 1]
        obj = tokens[row['obj_start']:row['obj_end'] + 1]
        if len(subj) == 1 or len(obj) == 1:  # acronym-style: one side is a single token
            if qid1 == qid2:
                split_matched_qids.append(qid1)
            else:
                split_mismatched_pairs.append((subj, obj))

    match_lists_preeval_candgen[k] = split_matched_qids
    mis_match_lists_preeval_candgen[k] = split_mismatched_pairs
    missing_counts_preeval_candgen[k] = n_missing
    

train  has  808  total alternate name examples.




dev  has  348  total alternate name examples.
test  has  245  total alternate name examples.


In [82]:
# # ACRONYM MATCHES IN CANDGEN 
# for k, v in match_lists_preeval_candgen.items():
#     print("LENGTH OF matched qids list is: ", len(v))
#     for item in v:
#         print(item,qid_to_mention(item))
#     print()

In [83]:
# # ACRONYM MISMATCHES IN CANDGEN
# for k, v in mis_match_lists_preeval_candgen.items():
#     print("LENGTH OF mis-matched qids list is: ", len(v))
#     for item in v:
#         print(item)
#     print()

# Alternate-names coverage: pre-Bootleg eval with regular mention extraction

In [84]:
def get_span(span):
    """Parse a "start:end" span string into an ``(int, int)`` tuple.

    Fields after the second ':' are ignored; a string without ':' raises IndexError.
    """
    fields = span.split(":")
    return int(fields[0]), int(fields[1])

In [85]:
_EXTRACT_DIR = ('/dfs/scratch1/simran/tacred/tacred-relation-bootleg/'
                'dataset_bootleg_cidr_model/bootleg_09132020/basic_full_sentences')
# NOTE: the variable name is reused from the cand-gen cell; this file holds the
# regular mention-extraction output, and it replaces `pre_eval_df`.
pre_eval_candgen_file = _EXTRACT_DIR + '/all_tacred_w_bootoutput.jsonl'
pre_eval_df = load_mentions(pre_eval_candgen_file)


In [86]:
# for ind, row in pre_eval_df.iterrows():
#     print(row)
#     break

In [87]:
# ALL ALTERNATE NAME EXAMPLES -- pre-eval regular mention extraction (uses `pre_eval_df` loaded above)


def _extracted_mentions_by_id(mentions_df):
    """Map example id -> (spans, qids) for ids that occur exactly once in ``mentions_df``.

    Built once, so that each example costs a dict lookup instead of a full scan
    of ``mentions_df``. Missing or duplicated ids are omitted.
    """
    id_counts = mentions_df['id'].value_counts().to_dict()
    return {
        ex_id: (spans, qids)
        for ex_id, spans, qids in zip(mentions_df['id'], mentions_df['spans'], mentions_df['qids'])
        if id_counts.get(ex_id) == 1
    }


def _qids_at_subj_obj_starts(spans, qids, subj_start, obj_start):
    """Return (subj_qid, obj_qid): the QIDs of the mentions covering each start token.

    A span string "s:e" (parsed with ``get_span``) covers token t when s <= t <= e.
    A span that covers the subject start is not also tested against the object
    start, and a later covering span overrides an earlier one. A side with no
    covering span is returned as None.
    """
    # NOTE(review): the span end is treated as inclusive; if Bootleg spans are
    # end-exclusive this should be `s <= t < e` -- confirm against the data.
    subj_qid = obj_qid = None
    for span, qid in zip(spans, qids):  # assumes spans and qids are parallel lists
        s, e = get_span(span)
        if s <= subj_start <= e:
            subj_qid = qid
        elif s <= obj_start <= e:
            obj_qid = qid
    return subj_qid, obj_qid


mentions_by_id = _extracted_mentions_by_id(pre_eval_df)

match_lists_preeval_extract = {'train': [], 'dev': [], 'test': []}
mis_match_lists_preeval_extract = {'train': [], 'dev': [], 'test': []}
missing_counts_preeval_extract = {}
for k, df in dfs.items():
    # Exact label match; 'org:alternate_names' is the only TACRED label containing this string.
    alternate_names_df = df[df['relation'] == 'org:alternate_names']
    print("{}".format(k), " has ", alternate_names_df.shape[0], " total alternate name examples.")

    split_matched_qids, split_mismatched_pairs = [], []
    n_missing = 0  # no extracted mention covers the subj start and/or the obj start
    for _, row in alternate_names_df.iterrows():
        tokens = row['token']
        subj_start, subj_end = row['subj_start'], row['subj_end']
        obj_start, obj_end = row['obj_start'], row['obj_end']
        subj = tokens[subj_start:subj_end + 1]
        obj = tokens[obj_start:obj_end + 1]

        # An example absent from the extraction output counts as missing instead of crashing.
        spans, qids = mentions_by_id.get(row['id'], ([], []))
        qid_subj, qid_obj = _qids_at_subj_obj_starts(spans, qids, subj_start, obj_start)
        if qid_subj is None or qid_obj is None:
            n_missing += 1
            continue
        if len(subj) == 1 or len(obj) == 1:  # acronym-style: one side is a single token
            if qid_subj == qid_obj:
                split_matched_qids.append(qid_subj)
            else:
                split_mismatched_pairs.append((subj, obj))

    match_lists_preeval_extract[k] = split_matched_qids
    mis_match_lists_preeval_extract[k] = split_mismatched_pairs
    missing_counts_preeval_extract[k] = n_missing

train  has  808  total alternate name examples.




dev  has  348  total alternate name examples.
test  has  245  total alternate name examples.


In [88]:
# ACRONYM MATCHES IN EXTRACTION (preeval) 

for split_name, split_qids in match_lists_preeval_extract.items():
    # Prefix each listing with its split; otherwise the train/dev/test sections
    # cannot be told apart in the output.
    print("[{}] LENGTH OF matched qids list is: ".format(split_name), len(split_qids))
    for qid in split_qids:
        print(qid, qid_to_mention(qid))
    print()

LENGTH OF matched qids list is:  173
Q19446 Blackburn Rovers F.C.
Q466587 American Psychological Association
Q466587 American Psychological Association
Q217595 Oxford University Press
Q3091339 Fyffes
Q1072857 International Crisis Group
Q1626261 Inter American Press Association
Q696658 Jewish National Fund
Q19446 Blackburn Rovers F.C.
Q326804 Korean Central News Agency
Q1072857 International Crisis Group
Q659854 Awami National Party
Q17427 Communist Party of China
Q1072857 International Crisis Group
Q659854 Awami National Party
Q466587 American Psychological Association
Q1626261 Inter American Press Association
Q1072857 International Crisis Group
Q466587 American Psychological Association
Q659854 Awami National Party
Q1072857 International Crisis Group
Q7804 International Monetary Fund
Q659854 Awami National Party
Q659854 Awami National Party
Q1072857 International Crisis Group
Q1056839 DirecTV
Q463465 American Bar Association
Q851756 United States National Security Council
Q326804 Kore

In [89]:
# ACRONYM MISMATCHES IN EXTRACTION (preeval) 
for split_name, split_pairs in mis_match_lists_preeval_extract.items():
    # Prefix each listing with its split; otherwise the train/dev/test sections
    # cannot be told apart in the output.
    print("[{}] LENGTH OF mis-matched qids list is: ".format(split_name), len(split_pairs))
    for pair in split_pairs:
        print(pair)
    print()

LENGTH OF mis-matched qids list is:  393
(['MECO'], ['Manila', 'Economic', 'and', 'Cultural', 'Office'])
(['KKL'], ['Kinmen', 'Kaoliang', 'Liquor'])
(['Justice', 'and', 'Equality', 'Movement'], ['JEM'])
(['FAW'], ['First', 'Automobile', 'Works'])
(['Covidien'], ['Tyco', 'International'])
(['Iranian', 'Revolutionary', 'Guards', 'Corps-Quds'], ['IRGC-QF'])
(['ANP'], ['National', 'People', "'s", 'Army'])
(['ABA'], ['Major', 'League', 'Baseball'])
(['Scorpions'], ['DSO'])
(['Australian', 'Broadcasting', 'Corporation'], ['ABC'])
(['MAPS'], ['Plasma', 'Science'])
(['OIC'], ['Oracle', 'Incentive', 'Compensation'])
(['Covidien'], ['Tyco', 'Healthcare'])
(['Afghan', 'National', 'Police'], ['ANP'])
(['Justice', 'and', 'Equality', 'Movement'], ['JEM'])
(['All', 'Basotho', 'Convention'], ['ABC'])
(['Taiwan', 'Research', 'Institute'], ['TRI'])
(['CNB'], ['Czech', 'National', 'Bank'])
(['Chunghwa', 'Telecom'], ['CHT'])
(['NSG'], ['Nuclear', 'Suppliers', 'Group'])
(['Aerolineas', 'Argentinas'], ['AA'

# Comparisons

### Compare pre-eval candidate generation with pre-eval mention extraction

In [94]:
# Per split, list the QIDs matched by pre-eval candidate generation but not by
# pre-eval mention extraction, and vice versa.
# - Each side is turned into a set once, instead of rebuilding a set for every
#   membership test.
# - Differences are printed sorted, because str hashing (and hence set iteration
#   order) is randomized per process; unsorted output would change across kernel
#   restarts.
# - Local names avoid overwriting the per-split globals set by earlier cells.
for key in match_lists_preeval_candgen:
    print("KEY:", key)
    candgen_qids = set(match_lists_preeval_candgen[key])
    extract_qids = set(match_lists_preeval_extract[key])

    print("These were qid matches in preeval candgen, but not in extract:")
    for item in sorted(candgen_qids - extract_qids):
        print(item, qid_to_mention(item))
    print("\n")
    print("These were qid matches in extract, but not in preeval candgen:")
    for item in sorted(extract_qids - candgen_qids):
        print(item, qid_to_mention(item))
    print("\n\n")

KEY: train
These were qid matches in preeval candgen, but not in extract:
Q190023 Central American Parliament
Q28699 UNITA
Q5440305 Federal Motor Carrier Safety Administration
Q213959 Boston Red Sox
Q140258 Zagat


These were qid matches in extract, but not in preeval candgen:
Q509404 Alcatel-Lucent
Q908172 Covidien
Q672885 Tyco International
Q1371852 MBIA
Q503673 Pentax
Q3091339 Fyffes



KEY: dev
These were qid matches in preeval candgen, but not in extract:
Q233358 PLOS


These were qid matches in extract, but not in preeval candgen:



KEY: test
These were qid matches in preeval candgen, but not in extract:
Q1186190 National Development and Reform Commission
Q190023 Central American Parliament


These were qid matches in extract, but not in preeval candgen:





### Compare pre-eval mention extraction with post-eval Bootleg

In [95]:
# Per split, list the QIDs matched by post-eval Bootleg but not by pre-eval mention
# extraction, and vice versa.
# - Each side is turned into a set once, instead of rebuilding a set for every
#   membership test.
# - Differences are printed sorted, because str hashing (and hence set iteration
#   order) is randomized per process; unsorted output would change across kernel
#   restarts.
# - Local names avoid overwriting the per-split globals set by earlier cells.
for key in match_lists_posteval:
    print("KEY:", key)
    posteval_qids = set(match_lists_posteval[key])
    extract_qids = set(match_lists_preeval_extract[key])

    print("These were qid matches in posteval, but not in extract:")
    for item in sorted(posteval_qids - extract_qids):
        print(item, qid_to_mention(item))
    print("\n")
    print("These were qid matches in extract, but not in posteval:")
    for item in sorted(extract_qids - posteval_qids):
        print(item, qid_to_mention(item))
    print("\n\n")

KEY: train
These were qid matches in posteval, but not in extract:
Q3656204 Superior Electoral Court
Q1411292 Fidelity Investments
Q1160682 Mexican Football Federation
Q466524 American Psychiatric Association
Q656554 Carnival Cruise Line
Q5097775 Child Protective Services
Q1024804 United National Congress
Q7435413 Scorpions (South Africa)
Q6934663 Multidisciplinary Association for Psychedelic Studies
Q8568 Afghan National Police
Q2004149 Japanese Nuclear Safety Commission
Q463486 American Basketball Association
Q6978340 National Security Council (Pakistan)
Q465750 American Motorcyclist Association
Q1146616 Movement for Democratic Change – Tsvangirai
Q219203 NEC
Q781365 Australian Broadcasting Corporation
Q628461 Popular Resistance Committees
Q543115 International Skating Union
Q10876348 National Security Council (Taiwan)
Q852370 Justice and Equality Movement
Q20746304 National Electoral Commission (Tanzania)
Q680662 Austria Press Agency
Q1480793 Nuclear Suppliers Group
Q463436 American

### Compare post-eval Bootleg with pre-eval candidate generation

In [96]:
# Per split, list the QIDs matched by post-eval Bootleg but not by pre-eval
# candidate generation, and vice versa.
# - Each side is turned into a set once, instead of rebuilding a set for every
#   membership test.
# - Differences are printed sorted, because str hashing (and hence set iteration
#   order) is randomized per process; unsorted output would change across kernel
#   restarts.
# - Local names avoid overwriting the per-split globals set by earlier cells.
for key in match_lists_posteval:
    print("KEY:", key)
    posteval_qids = set(match_lists_posteval[key])
    candgen_qids = set(match_lists_preeval_candgen[key])

    print("These were qid matches in posteval, but not in _preeval_candgen:")
    for item in sorted(posteval_qids - candgen_qids):
        print(item, qid_to_mention(item))
    print("\n")
    print("These were qid matches in _preeval_candgen, but not in posteval:")
    for item in sorted(candgen_qids - posteval_qids):
        print(item, qid_to_mention(item))
    print("\n\n")

KEY: train
These were qid matches in posteval, but not in _preeval_candgen:
Q3656204 Superior Electoral Court
Q509404 Alcatel-Lucent
Q908172 Covidien
Q1411292 Fidelity Investments
Q1160682 Mexican Football Federation
Q466524 American Psychiatric Association
Q656554 Carnival Cruise Line
Q672885 Tyco International
Q5097775 Child Protective Services
Q1024804 United National Congress
Q7435413 Scorpions (South Africa)
Q6934663 Multidisciplinary Association for Psychedelic Studies
Q1371852 MBIA
Q8568 Afghan National Police
Q2004149 Japanese Nuclear Safety Commission
Q463486 American Basketball Association
Q6978340 National Security Council (Pakistan)
Q465750 American Motorcyclist Association
Q1146616 Movement for Democratic Change – Tsvangirai
Q219203 NEC
Q781365 Australian Broadcasting Corporation
Q628461 Popular Resistance Committees
Q543115 International Skating Union
Q10876348 National Security Council (Taiwan)
Q503673 Pentax
Q852370 Justice and Equality Movement
Q20746304 National Elect