In [None]:
import os, sys 
import pandas as pd
pd.options.display.max_columns=500
pd.options.display.max_colwidth=1000
import matplotlib.pyplot as plt
from tqdm import tqdm 
tqdm.pandas()
import re 
from collections import Counter, defaultdict, OrderedDict
import seaborn as sns
from copy import deepcopy
from tabulate import tabulate, simple_separated_format
from inflection import singularize, pluralize
from df_to_latex import DataFrame2Latex
from collections import Counter, defaultdict
from inflection import singularize, pluralize

from utils_path import dataset_to_respath
# # Using WN, WN.taxonomy to retrieve path distance 
# - Pointers: 
#     - https://wn.readthedocs.io/en/latest/setup.html
#     - https://wn.readthedocs.io/en/latest/api/wn.html
# 
# - Installation: 
# ```
# !pip install wn
# !pip install wn[web]
# wn.download('ewn:2020')
# ```

# Get the co-hyponyms from WordNet (Schick and Schütze, 2020)
# - get the hypernyms; maximum distance d(x, y) is 2
# - get the top 2 most frequent senses of each hypernym
# - get the hyponyms of each hypernym; maximum distance d(y, z) is 4
# - constrain the depth of hypernyms to be at least 6

# In[201]:


import re 
import wn, wn.taxonomy
# Open the 'ewn:2020' lexicon.
# NOTE(review): `ewn` is not referenced by the functions visible in this
# section; they call the module-level `wn.synsets(...)` instead — confirm
# which of the two is intended.
ewn = wn.Wordnet('ewn:2020')

def test_min_depth():
    '''
    Print the minimum taxonomy depth of a handful of very general nouns.

    For every probe word the first noun synset returned by `wn.synsets` is
    taken and its `wn.taxonomy.min_depth` (without a simulated root) is
    printed next to the word.
    '''
    general_nouns = (
        'concept', 'thought', 'living thing', 'whole', 'psychological feature',
        'unit', 'artifact', 'abstraction', 'object', 'physical entity', 'entity',
    )
    for noun in general_nouns:
        first_sense = wn.synsets(noun, pos='n')[0]
        print(noun, wn.taxonomy.min_depth(first_sense, simulate_root=False))


def get_inherited_hypernyms(word, k_synset, max_path_hyper, min_taxo_depth=6, print_flag=False):
    '''
    Collect the hypernym synsets of `word` that are specific enough to keep.

    word: the lemma to look up (noun synsets only)
    k_synset: how many leading noun synsets of `word` are considered
        (presumably the most frequent senses come first — confirm against the wn docs)
    max_path_hyper: how many leading entries of each hypernym path are considered
    min_taxo_depth: hypernyms whose minimum taxonomy depth is below this value
        are dropped; meant to exclude very general concepts such as 'entity'
    print_flag: when True, print every kept hypernym, indented by its position in the path

    return: a list of hypernym synsets (duplicates across paths are kept)
    '''
    kept = []
    senses = wn.synsets(word, pos='n')[:k_synset]
    for sense in senses:
        for hyper_path in wn.taxonomy.hypernym_paths(sense):
            for level, hyper in enumerate(hyper_path[:max_path_hyper]):
                depth = wn.taxonomy.min_depth(hyper, simulate_root=False)
                if depth >= min_taxo_depth:
                    kept.append(hyper)
                    if print_flag:
                        print(' ' * level, hyper, hyper.lemmas()[0], depth)
    return kept

def get_direct_hyonyms(synsets):
    '''
    Flatten the direct hyponyms of every synset in `synsets` into one list.

    The result follows the order of `synsets`; duplicates are kept.
    '''
    return [child for parent in synsets for child in parent.hyponyms()]


def get_inherited_hyponyms(initial_synsets, max_path_hypo):
    '''
    Gather the hyponyms reachable from `initial_synsets` in at most
    `max_path_hypo` steps.

    Each round replaces the current frontier by its direct hyponyms and adds
    that new frontier to the result. The initial synsets themselves are not
    included; duplicates are kept.
    '''
    collected = []
    frontier = initial_synsets
    for _ in range(max_path_hypo):
        frontier = get_direct_hyonyms(frontier)
        collected.extend(frontier)
    return collected


def filter_cohyponyms(word, synsets_cohyponyms, top_k=50):
    '''
    Count the single-token lemmas of the candidate synsets and rank them.

    word: the query word; it is never reported as its own co-hyponym
    synsets_cohyponyms: synsets whose lemmas are the co-hyponym candidates
    top_k: number of most common lemmas to return; None returns all of them

    return: a list of (lower-cased lemma, count) pairs, most common first
    '''
    # Lemmas are lower-cased before counting, so the query word has to be
    # compared case-insensitively too; otherwise 'Corn' would be counted as a
    # co-hyponym of 'corn'.
    target = word.lower()
    counts = Counter()
    for synset in synsets_cohyponyms:
        for lemma in synset.lemmas():
            form = lemma.lower()
            if form == target:
                continue
            # keep single tokens only: no multi-word or hyphenated lemmas
            if " " in lemma or "-" in lemma:
                continue
            counts[form] += 1
    return counts.most_common(top_k)

    
def get_cohyponyms(word, top_k_cohyonyms=50, top_k_word_synset=2, max_path_hyper=2, max_path_hypo =4, print_flag=False):
    '''
    Return the WordNet co-hyponyms of `word`, most frequent lemma first.

    Pipeline: hypernyms of the leading `top_k_word_synset` noun synsets (first
    `max_path_hyper` entries of each hypernym path), then their hyponyms down
    to `max_path_hypo` steps, then the `top_k_cohyonyms` most common
    single-token lemmas (None keeps all). With `print_flag` the hypernym
    synsets and their lemmas are printed.
    '''
    hypernyms = get_inherited_hypernyms(word, k_synset=top_k_word_synset, max_path_hyper=max_path_hyper)
    if print_flag:
        for hyper in hypernyms:
            print(hyper, hyper.lemmas())

    descendants = get_inherited_hyponyms(hypernyms, max_path_hypo=max_path_hypo)
    ranked = filter_cohyponyms(word, descendants, top_k=top_k_cohyonyms)
    return [lemma for lemma, _count in ranked]
   

def test_get_cohyponyms(word, test_cohyponyms):
    '''
    Print, for each expected co-hyponym, whether it was retrieved for `word`,
    followed by the number of retrieved co-hyponyms and the list itself.

    Example inputs:
        word = 'corn',  test_cohyponyms = ['bean', 'potato', 'barley', 'wheat', 'pea']
        word = 'train', test_cohyponyms = ['bus', 'plane', 'car', 'tram', 'truck']
    '''
    # uncapped retrieval: 2 senses, 2 hypernym levels, 4 hyponym levels
    concept_cohyponyms = get_cohyponyms(
        word,
        top_k_cohyonyms=None,
        top_k_word_synset=2,
        max_path_hyper=2,
        max_path_hypo=4,
    )

    for query in test_cohyponyms:
        print(query, 'yes' if query in concept_cohyponyms else 'no')
    print(len(concept_cohyponyms), concept_cohyponyms)



# # Evaluation 

# In[225]:


def merge_predictions_in_concept_level(words, uniform_funcion=None, top_k=None ):
    '''
    Normalise the predictions and drop repeats, keeping first-seen order.

    words: the predicted words
    uniform_funcion: optional callable applied to every word first
        (e.g. singularize or pluralize)
    top_k: optional cap on the number of concepts returned

    return: the list of distinct (normalised) words
    '''
    if uniform_funcion is None:
        normalised = words
    else:
        normalised = list(map(uniform_funcion, words))

    unique_in_order = list(OrderedDict.fromkeys(normalised))
    if top_k is None:
        return unique_in_order
    return unique_in_order[:top_k]

def concept_evaluation(label, pred):
    '''
    Check whether any gold label occurs among the predictions.

    label: the acceptable answers, e.g. the singular and plural form
        (['tool', 'tools']); a list/tuple/set of strings, or the string repr
        of one (as read back from a CSV)
    pred: the top K predictions, in the same accepted forms as `label`

    return:
        1 if `label` and `pred` share at least one element, else 0
    '''
    # Only strings need parsing. Other containers (tuple, set, ...) are used
    # as they are instead of being handed to eval(), which would raise
    # TypeError on them.
    # NOTE(review): eval() executes arbitrary code. It is kept so existing
    # inputs keep working, but ast.literal_eval would be the safer choice if
    # these strings can come from an untrusted file.
    if isinstance(label, str):
        label = eval(label)
    if isinstance(pred, str):
        pred = eval(pred)

    return 0 if set(label).isdisjoint(pred) else 1
    

def get_precision_at_k_concept(df, relation, pred_cols, label_col, k_list, pred_col_suffix='obj_mask_'):
    '''
    Evaluate model predictions at the concept level, ignoring morphology
    (singular vs. plural forms).

    df: DataFrame holding the labels and predictions; a 0/1 hit column named
        `p{k}_{mask_type}` is added to it in place for every prediction
        column and every k
    relation: value written to the 'relation' column of the result
    pred_cols: names of the prediction columns; each cell is a list of
        predictions or the string repr of one
    label_col: name of the column with the gold labels
    k_list: the cut-offs k to evaluate
    pred_col_suffix: substring removed from a prediction column name to obtain
        its 'mask_type'

    return: a DataFrame with one row per prediction column, holding
        'mask_type', one 'p@{k}' column per k (percentage, one decimal) and
        'relation'
    '''
    rows = []
    for pred_col in pred_cols:
        mask_type = pred_col.replace(pred_col_suffix, "")
        row = {'mask_type': mask_type}
        for k in k_list:
            hit_col = f'p{k}_{mask_type}'
            # Walk the two columns in lockstep rather than apply(axis=1) with
            # x[0]/x[1]: positional indexing of a label-indexed row Series is
            # deprecated in pandas 2.x.
            # NOTE(review): eval() on CSV content executes arbitrary code;
            # acceptable only for trusted result files.
            df[hit_col] = [
                concept_evaluation(label, (eval(pred) if isinstance(pred, str) else pred)[:k])
                for label, pred in zip(df[label_col], df[pred_col])
            ]
            # Scale first, then round: round(x, 3) * 100 reintroduces float
            # noise such as 28.599999999999998.
            row[f'p@{k}'] = round(df[hit_col].mean() * 100, 1)
        rows.append(row)

    df_res = pd.DataFrame(rows)
    df_res['relation'] = [relation] * len(df_res)
    return df_res

def get_highest_mrr_among_labels(label, pred):
    '''
    Reciprocal rank of the best-ranked label among the candidates.

    label: the true labels as a list (different forms of one word, e.g.
        'animal' and 'animals'); a single label works too when wrapped in a list
    pred: the ranked list of candidate words, or None

    return: 1 / rank of the earliest label found in `pred` (ranks start at 1),
        or 0 when `pred` is None or contains none of the labels
    '''
    if pred is None:
        return 0

    best_rank = min(
        (pred.index(item) + 1 for item in label if item in pred),
        default=None,
    )
    return 0 if best_rank is None else 1 / best_rank


def get_mrr(df, relation, pred_cols, label_col, pred_col_suffix):
    '''
    Mean reciprocal rank of every prediction column, computed over the full
    ranked prediction list of each row.

    df: DataFrame holding the labels and predictions; a per-row column
        `mrr_{mask_type}` is added to it in place for every prediction column
    relation: value written to the 'relation' column of the result
    pred_cols: names of the prediction columns; each cell is a ranked list of
        candidates or the string repr of one
    label_col: name of the column with the gold labels (list, or the string
        repr of one)
    pred_col_suffix: substring removed from a prediction column name to obtain
        its 'mask_type'

    return: a DataFrame with one row per prediction column, holding
        'mask_type', 'mrr' (percentage, one decimal) and 'relation'
    '''
    import ast

    def _as_list(cell):
        # Cells read back from a CSV are string reprs of lists. Without
        # parsing, ranks would be computed on characters/substrings.
        if not isinstance(cell, str):
            return cell
        try:
            parsed = ast.literal_eval(cell)
        except (ValueError, SyntaxError):
            return [cell]  # a bare word is treated as a single item
        return parsed if isinstance(parsed, (list, tuple)) else [cell]

    rows = []
    for pred_col in pred_cols:
        mask_type = pred_col.replace(pred_col_suffix, "")
        mrr_col = f'mrr_{mask_type}'
        # Walk the two columns in lockstep rather than apply(axis=1) with
        # x[0]/x[1]: positional indexing of a label-indexed row Series is
        # deprecated in pandas 2.x.
        df[mrr_col] = [
            get_highest_mrr_among_labels(_as_list(label), _as_list(pred))
            for label, pred in zip(df[label_col], df[pred_col])
        ]
        # Scale first, then round, to avoid float noise in the percentage.
        rows.append({'mask_type': mask_type, 'mrr': round(df[mrr_col].mean() * 100, 1)})

    mrr_df = pd.DataFrame(data=rows)
    mrr_df['relation'] = relation
    return mrr_df


# In[202]:


def get_dataset_to_respath(dataset_to_respath, print_flag=False):
    '''
    Map every dataset to the local copy of its result file, keyed by the
    dataset's short name.

    dataset_to_respath: dict from long dataset name (e.g. 'hypernymsuite-BLESS')
        to the result path relative to the project root; a '.tsv' in the path
        is replaced by '.csv'
    print_flag: when True, print the `!scp` command that copies each file
        from the remote host ('spartan:~/cogsci/DAP/') to its local path

    return: a mapping from short dataset name (e.g. 'BLESS') to local path
        ('../../' + result path)
    raises: KeyError if a dataset name is missing from the rename table
    '''
    source_dir = 'spartan:~/cogsci/DAP/'
    target_dir = '../../'
    dataset_rename = {
        'hypernymsuite-BLESS': 'BLESS',
        'lm_diagnostic_extended-singular': 'DIAG',
        'clsb-singular': 'CLSB',
        'hypernymsuite-LEDS': 'LEDS',
        'hypernymsuite-EVAL': 'EVAL',
        'hypernymsuite-SHWARTZ': 'SHWARTZ',
    }

    # defaultdict without a factory behaves like a plain dict; the type is
    # kept so the return value is unchanged for callers.
    dataset_to_localpath = defaultdict()
    for dataset, path in dataset_to_respath.items():
        path = path.replace(".tsv", ".csv")
        target_path = target_dir + path
        if print_flag:
            print(f"!scp {source_dir + path} {target_path}")
            print()
        dataset_to_localpath[dataset_rename[dataset]] = target_path
    return dataset_to_localpath
# Short dataset name -> local path of its result CSV.
dataset_to_localpath = get_dataset_to_respath(dataset_to_respath)


# All lemma names of NLTK's WordNet corpus, held in a set for fast membership
# tests.
from nltk.corpus import wordnet 
wn_lemmas = set(wordnet.all_lemma_names())
def check_word_in_wordnet(word, wn_lemmas):
    '''
    Membership flag: 1 when `word` is contained in `wn_lemmas`, 0 otherwise.
    '''
    return int(word in wn_lemmas)


def get_all_vocab(dataset_to_localpath):
    '''
    Collect the distinct values of the 'sub_label_sg' column over all dataset
    CSV files.

    dataset_to_localpath: mapping from dataset name to CSV path

    Reads the module-level `debug` flag: when it is truthy only the 'DIAG'
    dataset is read. Prints each dataset name and the running vocabulary size.

    return: the distinct values as a list (set order, i.e. arbitrary)
    '''
    subjects = set()
    for dataset, path in dataset_to_localpath.items():
        if debug and dataset != 'DIAG':
            continue
        print("dataset", dataset)
        frame = pd.read_csv(path)
        subjects.update(frame['sub_label_sg'].to_list())
        print(len(subjects))
    return list(subjects)




# --- WordNet co-hyponym retrieval settings ---
top_k_cohyonyms = None   # None: keep every co-hyponym (no cap)
top_k_word_synset = 2    # leading noun synsets considered per word
max_path_hyper = 2       # entries taken from each hypernym path
max_path_hypo = 4        # hyponym expansion steps

# --- evaluation settings ---
pred_col_suffix = ''
label_col = 'sub_sister_new'
pred_cols = ['subj_anchors_all_sg']
relation = 'co-hyponyms'
debug = False
print("debug", debug)

# dataset = sys.argv[2]
# path = sys.argv[3]

# Retrieve the WordNet co-hyponyms of every subject word and save them as one
# CSV per processed dataset.
for dataset, path in dataset_to_localpath.items():
    # Only the SHWARTZ dataset is processed by this cell.
    if dataset != 'SHWARTZ':
        continue
    # NOTE(review): together with the filter above, debug mode (DIAG only)
    # processes nothing — confirm which dataset a debug run should use.
    if debug and dataset != 'DIAG':
        continue
    print("dataset", dataset)
    df = pd.read_csv(path)
    vocab_sub = set(df['sub_label_sg'].to_list())
    print(len(vocab_sub))

    word_to_cohyponyms = defaultdict(list)
    for i, word in enumerate(vocab_sub):
        if i % 10 == 0:
            print(i)
        if check_word_in_wordnet(word, wn_lemmas):
            cohyponyms = get_cohyponyms(word, top_k_cohyonyms=top_k_cohyonyms,
                                        top_k_word_synset=top_k_word_synset,
                                        max_path_hyper=max_path_hyper,
                                        max_path_hypo=max_path_hypo)
        else:
            # Words unknown to WordNet get an empty list. Assigning it here
            # also keeps the debug print from showing the previous word's
            # result (or failing on an undefined name).
            cohyponyms = []
        word_to_cohyponyms[word] = cohyponyms
        if debug:
            print(word, cohyponyms)

    # Save unconditionally once the dataset is done. The former
    # `if i%100==0` guard sat outside the word loop, so the file was written
    # only when the last index happened to be a multiple of 100.
    df = pd.DataFrame(word_to_cohyponyms.items(), columns=['word', 'cohyponyms'])
    output_path = f'../log/{dataset}_word_to_cohyponyms.txt'
    df.to_csv(output_path, index=False)
    print(f"save {output_path}")


debug False
dataset SHWARTZ
11061
0
10
20
30
40
50
60
70
80
90
100
110
120
130
140
150
160
170
180
190
200
210
220
230
240
250
260
270
280
290
300
310
320
330
340
350
360
370
380
390
400
410
420
430
440
450
460
470
480
490
500
510
520
530
540
550
560
570
580
590
600
610
620
630
640
650
660
670
680
690
700
710
720
730
740
750
760
770
780
790
800
810
820
830
840
850
860
870
880
890
900
910
920
930
940
950
960
970
980
990
1000
1010
1020
1030
1040
1050
1060
1070
1080
1090
1100
1110
1120
1130
1140
1150
1160
1170
1180
1190
1200
1210
1220
1230
1240
1250
1260
1270
1280
1290
1300
1310
1320
1330
1340
1350
1360
1370
1380
1390
1400
1410
1420
1430
1440
1450
1460
1470
1480
1490
1500
1510
1520
1530
1540
1550
1560
1570
1580
1590
1600
1610
1620
1630
1640
1650
1660
1670
1680
1690
1700
1710
1720
1730
1740
1750
1760
1770
1780
1790
1800
1810
1820
1830
1840
1850
1860
1870
1880
1890
1900
1910
1920
1930
1940
1950
1960
1970
1980
1990
2000
2010
2020
2030
2040
2050
2060
2070
2080
2090
2100
2110
2120
2130
2140
21

In [None]:
# Write the word -> co-hyponyms mapping of the current `dataset` to a
# two-column CSV under ../log/.
output_path = f'../log/{dataset}_word_to_cohyponyms.txt'
df = pd.DataFrame(
    word_to_cohyponyms.items(),
    columns=['word', 'cohyponyms'],
)
df.to_csv(output_path, index=False)
print(f"save {output_path}")


In [33]:
# Print one `get_cohyponyms_dataset.py` command line per dataset, passing the
# dataset's short name and its local result path.
for dataset, path in dataset_to_localpath.items(): 
    print("python -u get_cohyponyms_dataset.py ", dataset, path)

python -u get_cohyponyms_dataset.py  BLESS ../../log/bert-large-uncased/hypernymsuite/BLESS/exp_data_results_anchor_type_Coordinate_remove_Y_PUNC_FULL_concate_or_single_max_anchor_num_5_anchor_scorer_probAvg_filter_obj_True_filter_objects_with_input_True_wnp_False_cpt_False.HYPERNYMSUITE.csv
python -u get_cohyponyms_dataset.py  DIAG ../../log/bert-large-uncased/lm_diagnostic_extended/singular/exp_data_results_anchor_type_Coordinate_remove_Y_PUNC_FULL_concate_or_single_max_anchor_num_5_anchor_scorer_probAvg_filter_obj_True_filter_objects_with_input_True_wnp_False_cpt_False.LM_DIAGNOSTIC_EXTENDED.csv
python -u get_cohyponyms_dataset.py  CLSB ../../log/bert-large-uncased/clsb/singular/exp_data_results_anchor_type_Coordinate_remove_Y_PUNC_FULL_concate_or_single_max_anchor_num_5_anchor_scorer_probAvg_filter_obj_True_filter_objects_with_input_True_wnp_False_cpt_False.CLSB.csv
python -u get_cohyponyms_dataset.py  LEDS ../../log/bert-large-uncased/hypernymsuite/LEDS/exp_data_results_anchor_typ

In [29]:
from utils_path import dataset_to_respath 

def get_dataset_to_respath(dataset_to_respath, print_flag=False):
    '''
    Map every dataset to the local copy of its result file, keyed by the
    dataset's short name.

    dataset_to_respath: dict from long dataset name (e.g. 'hypernymsuite-BLESS')
        to the result path relative to the project root; a '.tsv' in the path
        is replaced by '.csv'
    print_flag: when True, print the `!scp` command that copies each file
        from the remote host ('spartan:~/cogsci/DAP/') to its local path

    return: a mapping from short dataset name (e.g. 'BLESS') to local path
        ('../../' + result path)
    raises: KeyError if a dataset name is missing from the rename table
    '''
    source_dir = 'spartan:~/cogsci/DAP/'
    target_dir = '../../'
    dataset_rename = {
        'hypernymsuite-BLESS': 'BLESS',
        'lm_diagnostic_extended-singular': 'DIAG',
        'clsb-singular': 'CLSB',
        'hypernymsuite-LEDS': 'LEDS',
        'hypernymsuite-EVAL': 'EVAL',
        'hypernymsuite-SHWARTZ': 'SHWARTZ',
    }

    # defaultdict without a factory behaves like a plain dict; the type is
    # kept so the return value is unchanged for callers.
    dataset_to_localpath = defaultdict()
    for dataset, path in dataset_to_respath.items():
        path = path.replace(".tsv", ".csv")
        target_path = target_dir + path
        if print_flag:
            print(f"!scp {source_dir + path} {target_path}")
            print()
        dataset_to_localpath[dataset_rename[dataset]] = target_path
    return dataset_to_localpath
# Short dataset name -> local path of its result CSV.
dataset_to_localpath = get_dataset_to_respath(dataset_to_respath)


# All lemma names of NLTK's WordNet corpus, held in a set for fast membership
# tests.
from nltk.corpus import wordnet 
wn_lemmas = set(wordnet.all_lemma_names())
def check_word_in_wordnet(word, wn_lemmas):
    '''
    Membership flag: 1 when `word` is contained in `wn_lemmas`, 0 otherwise.
    '''
    return int(word in wn_lemmas)


def get_all_vocab(dataset_to_localpath):
    '''
    Collect the distinct anchor words found in the 'subj_anchors' column over
    all dataset CSV files.

    dataset_to_localpath: mapping from dataset name to CSV path

    Every 'subj_anchors' cell is a string that is evaluated into a collection
    of words. Reads the module-level `debug` flag: when it is truthy only the
    'DIAG' dataset is read. Prints each dataset name and the running
    vocabulary size.

    return: the distinct anchor words as a list (set order, i.e. arbitrary)
    '''
    anchor_vocab = set()
    for dataset, path in dataset_to_localpath.items():
        if debug and dataset != 'DIAG':
            continue
        print("dataset", dataset)
        frame = pd.read_csv(path)
        # NOTE(review): eval() runs whatever the CSV cell contains — only
        # safe for trusted result files.
        for cell in frame['subj_anchors'].to_list():
            anchor_vocab.update(eval(cell))
        print(len(anchor_vocab))
    return list(anchor_vocab)

# --- evaluation settings ---
pred_col_suffix = ''
label_col = 'sub_sister_new'
pred_cols = ['subj_anchors_all_sg']
relation = 'co-hyponyms'
debug = False
print("debug", debug)
from inflection import singularize, pluralize 

# Singularise every anchor word of every dataset and save the mapping as a
# two-column CSV.
vocab = get_all_vocab(dataset_to_localpath)
print("#vocab", len(vocab), vocab[:5])

# The values stored are strings; the `list` factory is never triggered here.
# The type is kept so later lookups behave as before.
word_to_singular = defaultdict(list)
for i, word in enumerate(vocab):
    if i % 1000 == 0:
        print(i)
    word_to_singular[word] = singularize(word)
    if debug:
        # show only the current pair; printing the whole mapping for every
        # word made the debug output grow quadratically
        print(word, word_to_singular[word])


df = pd.DataFrame(word_to_singular.items(), columns=['word', 'singular'])
output_path = '../log/word_to_singular.txt'
df.to_csv(output_path, index=False)
print(f"save {output_path}")

debug False
dataset BLESS
626
dataset DIAG
1561
dataset CLSB
1962
dataset LEDS
3782
dataset EVAL
4250
dataset SHWARTZ
8351
#vocab 8351 ['spa', 'poles', 'lovers', 'moderate', 'talk']
0
1000
2000
3000
4000
5000
6000
7000
8000
save ../log/word_to_singular.txt
