In [1]:
import json
from tqdm.auto import tqdm
from collections import defaultdict
from copy import deepcopy
from scipy import stats
import random

import numpy as np

from cooccurrence_matrix import CooccurrenceMatrix

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Helper object from the project-local `cooccurrence_matrix` module, created
# with its default constructor arguments. Later cells query it through
# `coo_matrix.coo_count(subj, obj)` and `coo_matrix.count(subj)`.
# NOTE(review): what the constructor loads (and from where) is not visible
# in this notebook -- confirm against `cooccurrence_matrix.py`.
coo_matrix = CooccurrenceMatrix()

In [3]:
from nltk.corpus import stopwords
from nltk import word_tokenize

# English stopword list shipped with NLTK.
stopword_list = stopwords.words("english")

# Punctuation marks that are dropped alongside the stopwords.
punctuations = {mark: mark for mark in ("?", ":", "!", ".", ",", ";")}

# Lookup table of every token to discard (token -> token). Only key
# membership is consulted by `filtering`; stopwords are inserted first,
# then the punctuation marks.
filter = {word: word for word in stopword_list}
filter.update(punctuations)
def filtering(text):
    """Tell whether ``text`` is a token that should be discarded.

    A token is discarded when it is a key of the module-level ``filter``
    table (NLTK English stopwords plus a handful of punctuation marks).
    The lookup is an exact match, so callers are expected to pass the token
    in the same case as the table entries (the caller in this notebook
    lower-cases first).

    Args:
        text: A single token.

    Returns:
        bool: ``True`` if the token is in ``filter``, ``False`` otherwise.
        (The function used to fall through and return ``None`` for kept
        tokens; an explicit ``False`` has the same truthiness.)
    """
    return text in filter

def text_normalization_without_lemmatization(text):
    """Tokenize ``text``, lower-case the tokens and drop filtered ones.

    Tokenization is done with NLTK's ``word_tokenize``. Each token is
    lower-cased and kept only if ``filtering`` does not flag it. No
    lemmatization or stemming is applied.

    Args:
        text: The raw string to normalize.

    Returns:
        list[str]: The surviving lower-cased tokens, in their original order.
    """
    lowered_tokens = (raw_token.lower() for raw_token in word_tokenize(text))
    return [tok for tok in lowered_tokens if not filtering(tok)]

In [4]:
# Read every LAMA TREx example into memory.
with open("../data/LAMA_TREx/all.json", 'r') as fin:
    f_all = json.load(fin)

# Per-example lookups keyed by uid, plus the set of lower-cased objects seen
# for each "<rel_id>_<subj>" pair.
uid_rel_map, uid_subj_map = {}, {}
rel_subj_objects = defaultdict(set)
for example in f_all:
    subj, rel, obj = example['subj'], example['rel_id'], example['output']
    uid = example['uid']

    uid_subj_map[uid] = subj
    uid_rel_map[uid] = rel
    rel_subj_objects[rel + '_' + subj].add(obj.lower())

# Freeze each object set into a list. The container itself stays a
# defaultdict(set), so only the stored values change type.
for key, objects in rel_subj_objects.items():
    rel_subj_objects[key] = list(objects)

In [5]:
# Models whose predictions are read from
# "../results/<name>/pred_factual_probing_test.json" by the evaluation loop.
model_names = [
    'gpt_neo_125M',
    'gpt_neo_125M_TREx',
    'gpt_neo_1_3B',
    'gpt_neo_1_3B_TREx',
    'gpt_neo_2_7B',
    'gpt_neo_2_7B_TREx',
    'gpt_j_6B',
    'gpt_j_6B_TREx',
]
# Models whose predictions are read from
# "../results/<name>/pred_possible_only.json" instead.
openai_model_names = [
    'text-davinci-003',
    'gpt-3.5-turbo-0301',
    'gpt-4-0314',
]

# Number of log-scale sections; valid capped indices run up to num_sections - 1.
num_sections = 8
def prob_value_to_section(value):
    """Map a probability-like value to a log2-scale section index.

    The index is ``ceil(-log2(value + 1e-6))``, capped from above at
    ``num_sections - 1``. Larger values therefore land in lower sections
    (e.g. 1.0 -> 0, 0.5 -> 1, 0.25 -> 2), and very small values collapse
    into the last section. The ``1e-6`` offset keeps ``log2`` finite when
    ``value`` is 0. There is no lower cap, so values above 1 yield negative
    indices.

    Args:
        value: A non-negative number, normally in [0, 1].

    Returns:
        int: The section index.
    """
    smoothed = value + 0.000001
    section = int(np.ceil(-np.log2(smoothed)))
    last_section = num_sections - 1
    return section if section < last_section else last_section
    
# Fixed seed for the shuffled-baseline control below, so the "random" column
# of the report is identical on every Restart & Run All.
SHUFFLE_SEED = 0

# For every model: correlate per-example accuracy (hits@1 / MRR) with the
# conditional probability cond_prob_gt = coo_count(subj, obj) / count(subj)
# taken from `coo_matrix`, both raw and after log2 binning, and against a
# shuffled copy of the probabilities as a control.
for model_name in model_names + openai_model_names:
    # The OpenAI models and the local models store predictions under
    # different file names.
    if model_name in openai_model_names:
        pred_test_filename = "../results/" + model_name + "/pred_possible_only.json"
    else:
        pred_test_filename = "../results/" + model_name + "/pred_factual_probing_test.json"

    print('='*30)
    print('='*30)
    print('Model:', model_name)

    with open(pred_test_filename, "r") as fin:
        preds = json.load(fin)

    print('Test')

    hits_1 = []
    mrr = []
    cond_probs = []
    bins = []

    for pred in tqdm(preds):
        subj = uid_subj_map[pred['uid']]
        rel = uid_rel_map[pred['uid']]
        label_text = pred['label_text'].lower()
        # Consistency check: the gold label must be one of the objects
        # recorded for this (relation, subject) pair. This replaces a
        # per-prediction deepcopy + .remove() whose result was never used
        # but which failed in the same situation.
        if label_text not in rel_subj_objects[rel+'_'+subj]:
            raise ValueError(
                f"label {label_text!r} of uid {pred['uid']!r} is not a known "
                f"object for {rel+'_'+subj!r}"
            )

        subj = ' '.join(text_normalization_without_lemmatization(subj))
        obj_gt = ' '.join(text_normalization_without_lemmatization(label_text))
        joint_freq_gt = coo_matrix.coo_count(subj, obj_gt)
        # skip if the entities are composed of more than 3 tokens, or are stopwords
        # (original author's note; in code terms: skip when no positive
        # joint count is available for the pair)
        if joint_freq_gt <= 0:
            continue
        subj_freq = coo_matrix.count(subj)
        cond_prob_gt = joint_freq_gt / subj_freq if subj_freq > 0 else 0

        # Named `section` rather than `bin` so the builtin is not shadowed
        # at notebook scope.
        section = prob_value_to_section(cond_prob_gt)

        hits_1.append(pred['hits@1_remove_stopwords'])
        # Predictions without an MRR field fall back to hits@1, which makes
        # the two report lines identical for those models.
        if 'mrr_remove_stopwords' in pred:
            mrr.append(pred['mrr_remove_stopwords'])
        else:
            mrr.append(pred['hits@1_remove_stopwords'])
        cond_probs.append(cond_prob_gt)
        bins.append(section)

    without_binning = stats.pearsonr(hits_1, cond_probs)
    with_binning = stats.pearsonr(hits_1, bins)
    mrr_without_binning = stats.pearsonr(mrr, cond_probs)
    mrr_with_binning = stats.pearsonr(mrr, bins)

    # Control: break the pairing between accuracy and probability. A private,
    # seeded generator keeps the result reproducible and leaves the global
    # `random` state untouched.
    rand_cond_probs = list(cond_probs)
    random.Random(SHUFFLE_SEED).shuffle(rand_cond_probs)
    without_binning_random = stats.pearsonr(hits_1, rand_cond_probs)
    mrr_without_binning_random = stats.pearsonr(mrr, rand_cond_probs)

    # Report format: r (raw) / r (binned) / r (shuffled) /// matching p-values.
    # First line is hits@1, second line is MRR.
    print(f"{round(without_binning.statistic, 2)} / {round(with_binning.statistic, 2)} / {round(without_binning_random.statistic, 2)} /// {without_binning.pvalue} / {with_binning.pvalue} / {without_binning_random.pvalue}")
    print(f"{round(mrr_without_binning.statistic, 2)} / {round(mrr_with_binning.statistic, 2)} / {round(mrr_without_binning_random.statistic, 2)} /// {mrr_without_binning.pvalue} / {mrr_with_binning.pvalue} / {mrr_without_binning_random.pvalue}")

Model: gpt_neo_125M
Test


100%|██████████| 8856/8856 [00:01<00:00, 7439.55it/s]


0.22 / -0.2 / 0.01 /// 1.1397398619484083e-93 / 2.775762805404166e-76 / 0.4370290297062604
0.23 / -0.23 / 0.01 /// 4.977877270809893e-104 / 1.0239910559297291e-96 / 0.40498484681004615
Model: gpt_neo_125M_TREx
Test


100%|██████████| 8856/8856 [00:01<00:00, 7418.51it/s]


0.35 / -0.35 / -0.01 /// 2.1955324688252254e-238 / 7.945094184582398e-245 / 0.5823302815066153
0.38 / -0.41 / -0.01 /// 1.565625273172974e-288 / 0.0 / 0.4396947470969913
Model: gpt_neo_1_3B
Test


100%|██████████| 8856/8856 [00:01<00:00, 7442.50it/s]


0.21 / -0.21 / 0.02 /// 2.4716382325051189e-85 / 2.2312264792659015e-80 / 0.16455429009143466
0.26 / -0.26 / 0.01 /// 3.722179165546501e-126 / 7.996857678860146e-124 / 0.21534224411649433
Model: gpt_neo_1_3B_TREx
Test


100%|██████████| 8856/8856 [00:01<00:00, 7394.41it/s]


0.35 / -0.36 / -0.01 /// 9.112245426681854e-234 / 2.898140799098851e-250 / 0.25400566699525146
0.4 / -0.43 / -0.01 /// 2.5726e-319 / 0.0 / 0.1915752302154333
Model: gpt_neo_2_7B
Test


100%|██████████| 8856/8856 [00:01<00:00, 7451.55it/s]


0.21 / -0.21 / -0.02 /// 3.3737819012337525e-84 / 4.720919317879833e-80 / 0.11379673090262642
0.26 / -0.26 / -0.02 /// 6.692195336818876e-132 / 3.826945925784204e-128 / 0.0802481979571421
Model: gpt_neo_2_7B_TREx
Test


100%|██████████| 8856/8856 [00:01<00:00, 7457.81it/s]


0.35 / -0.36 / 0.01 /// 9.029723248441689e-238 / 1.23232041575901e-254 / 0.28339167242622504
0.4 / -0.43 / 0.02 /// 1.1364e-320 / 0.0 / 0.16751144351006717
Model: gpt_j_6B
Test


100%|██████████| 8856/8856 [00:01<00:00, 7404.03it/s]


0.24 / -0.24 / -0.01 /// 4.3241639482454355e-107 / 4.182139153425179e-106 / 0.6406932401649517
0.29 / -0.29 / -0.01 /// 5.357333939435433e-164 / 9.505979065991612e-166 / 0.5323561992988697
Model: gpt_j_6B_TREx
Test


100%|██████████| 8856/8856 [00:01<00:00, 7323.16it/s]


0.35 / -0.36 / -0.02 /// 4.320316945232727e-232 / 3.618007885726267e-254 / 0.05945469068305892
0.4 / -0.43 / -0.01 /// 9.012e-320 / 0.0 / 0.26012075414909097
Model: text-davinci-003
Test


100%|██████████| 4149/4149 [00:00<00:00, 7414.65it/s]


0.25 / -0.23 / 0.01 /// 2.18628861601269e-56 / 1.0324329451855343e-45 / 0.5601093281775997
0.25 / -0.23 / 0.01 /// 2.18628861601269e-56 / 1.0324329451855343e-45 / 0.5601093281775997
Model: gpt-3.5-turbo-0301
Test


100%|██████████| 4149/4149 [00:00<00:00, 7096.84it/s]


0.21 / -0.21 / -0.02 /// 6.569170608324027e-41 / 3.482007750886441e-38 / 0.19354403512890325
0.21 / -0.21 / -0.02 /// 6.569170608324027e-41 / 3.482007750886441e-38 / 0.19354403512890325
Model: gpt-4-0314
Test


100%|██████████| 4149/4149 [00:00<00:00, 7211.34it/s]

0.22 / -0.2 / -0.01 /// 1.1353469735347542e-43 / 2.2795861983220642e-37 / 0.4532476412737201
0.22 / -0.2 / -0.01 /// 1.1353469735347542e-43 / 2.2795861983220642e-37 / 0.4532476412737201



