In [1]:
import json
from tqdm.auto import tqdm
from collections import defaultdict
from copy import deepcopy

import numpy as np

from cooccurrence_matrix import CooccurrenceMatrix

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Project-local co-occurrence statistics object, built with its default arguments.
# It is queried later in this notebook through `coo_count(subj, obj)` and `count(subj)`.
# NOTE(review): where the underlying counts are loaded from is not visible here — confirm
# in cooccurrence_matrix.py.
coo_matrix = CooccurrenceMatrix()

In [3]:
from nltk.corpus import stopwords
from nltk import word_tokenize

# Tokens that are discarded during normalization: the NLTK English stopword list
# plus a handful of punctuation marks. Each token maps to itself; only key
# membership is consulted (see `filtering`).
# NOTE(review): the name `filter` shadows the Python builtin for the rest of the
# notebook; it is kept because `filtering` looks it up by this name.
stopword_list = stopwords.words("english")

punctuations = {mark: mark for mark in ("?", ":", "!", ".", ",", ";")}

filter = {
    **{word: word for word in stopword_list},
    **punctuations,
}
def filtering(text):
    """Tell whether a token should be dropped during normalization.

    Args:
        text: A single (already lowercased) token.

    Returns:
        bool: True if `text` is a key of the module-level `filter` table
        (stopwords and punctuation), False otherwise. The previous version
        returned an implicit None instead of False for tokens that are kept.
    """
    return text in filter

def text_normalization_without_lemmatization(text):
    """Tokenize `text`, lowercase every token and drop the filtered ones.

    Tokens come from `word_tokenize`; a lowercased token is discarded when
    `filtering` reports it as truthy. No lemmatization or stemming is applied.

    Args:
        text: The raw string to normalize.

    Returns:
        list: The surviving lowercased tokens, in their original order.
    """
    lowered_tokens = (piece.lower() for piece in word_tokenize(text))
    return [piece for piece in lowered_tokens if not filtering(piece)]

In [4]:
# Load all examples and build three lookup tables used by the analysis below:
#   uid_subj_map     : example uid -> subject string
#   uid_rel_map      : example uid -> relation id
#   rel_subj_objects : "<rel_id>_<subj>" -> every (lowercased) object seen for that pair
# The encoding is given explicitly so that decoding does not depend on the
# platform's locale default.
with open("../data/LAMA_TREx/all.json", 'r', encoding='utf-8') as fin:
    f_all = json.load(fin)

uid_rel_map = {}
uid_subj_map = {}
rel_subj_objects = defaultdict(set)
for example in f_all:
    subj = example['subj']
    rel = example['rel_id']
    obj = example['output']

    uid_subj_map[example['uid']] = subj
    uid_rel_map[example['uid']] = rel
    # A set de-duplicates objects that repeat for the same relation/subject pair.
    rel_subj_objects[rel+'_'+subj].add(obj.lower())

# Replace each set by a list. `sorted` (instead of `list`) makes the element
# order independent of string-hash randomization, so reruns are reproducible.
# Only values of existing keys are reassigned, which is safe while iterating.
# Note: the container itself remains a defaultdict(set).
for key in rel_subj_objects:
    rel_subj_objects[key] = sorted(rel_subj_objects[key])

In [7]:
# Models whose predictions are read from
# "../results/<name>/pred_factual_probing_test.json" by the loop below.
model_names = [
    'gpt_neo_125M',
    'gpt_neo_125M_TREx',
    'gpt_neo_1_3B',
    'gpt_neo_1_3B_TREx',
    'gpt_neo_2_7B',
    'gpt_neo_2_7B_TREx',
    'gpt_j_6B',
    'gpt_j_6B_TREx',
]

# Models whose predictions are read from
# "../results/<name>/pred_possible_only.json" instead.
openai_model_names = [
    'text-davinci-003',
    'gpt-3.5-turbo-0301',
    'gpt-4-0314',
]

# Number of sections produced by `prob_value_to_section`
# (valid section indices are 0 .. num_sections - 1).
num_sections = 8
def prob_value_to_section(value, n_sections=None, eps=0.000001):
    """Map a probability-like value to a log2-spaced section index.

    The index is ceil(-log2(value + eps)), clamped to [0, n_sections - 1]:
    values near 1 land in section 0, values in about (0.25, 0.5] in section 2,
    and everything small enough in the last section.

    The lower clamp is new. Previously a value of 2 or more produced a negative
    index, which was counted under 'total' but never shown in any per-section
    report. Results for values below 2 are unchanged.

    Args:
        value: Non-negative number, normally a conditional probability in [0, 1].
        n_sections: Number of sections; defaults to the module-level
            `num_sections` when omitted.
        eps: Small offset that keeps log2 finite for value == 0.

    Returns:
        int: Section index in [0, n_sections - 1].

    Raises:
        ValueError / OverflowError: propagated from int() when value + eps <= 0,
        because the logarithm is then nan or infinite.
    """
    if n_sections is None:
        n_sections = num_sections
    raw_section = int(np.ceil(-np.log2(value + eps)))
    return max(0, min(raw_section, n_sections - 1))
    
# For every model: load its predictions, compute the corpus conditional probability
# P(object | subject) = coo_count(subj, obj) / count(subj) for both the ground-truth
# object and the model's top-1 prediction, group examples by the section of the
# ground-truth probability, and report statistics over the failure cases.

# When True, top-100 candidates that are *other* valid objects for the same
# relation/subject pair are skipped when picking the prediction.
# NOTE(review): the original condition was `w not in rel_subj_object or True`, i.e. the
# exclusion was switched off and the very first candidate was always taken. That
# behavior is preserved here (False); confirm whether the exclusion should be enabled.
exclude_other_valid_objects = False

for model_name in model_names + openai_model_names:
    if model_name in openai_model_names:
        pred_test_filename = "../results/" + model_name + "/pred_possible_only.json"
    else:
        pred_test_filename = "../results/" + model_name + "/pred_factual_probing_test.json"

    print('='*30)
    print('='*30)
    print('Model:', model_name)

    # Explicit encoding so decoding does not depend on the platform locale.
    with open(pred_test_filename, "r", encoding="utf-8") as fin:
        preds = json.load(fin)

    print('Test')

    # Each accumulator maps a section index (and the extra key 'total') to a list
    # of conditional probabilities, split by outcome.
    condprob_gt_bin_total = defaultdict(list)
    condprob_pred_bin_total = defaultdict(list)
    condprob_gt_bin_success = defaultdict(list)
    condprob_pred_bin_success = defaultdict(list)
    condprob_gt_bin_failure = defaultdict(list)
    condprob_pred_bin_failure = defaultdict(list)

    # 1 when the predicted object is more probable than the ground truth, else 0
    # (failure cases only).
    count_bin_failure = defaultdict(list)

    for pred in tqdm(preds):
        subj = uid_subj_map[pred['uid']]
        rel = uid_rel_map[pred['uid']]
        label_text = pred['label_text'].lower()

        if model_name in openai_model_names:
            pred_top_1_remove_stopwords = pred['top_1_text_remove_stopwords']
        else:
            # Other valid objects for this relation/subject pair, excluding the one
            # under test. A set difference replaces the former deepcopy + remove;
            # `.get` avoids inserting empty entries into the defaultdict.
            other_valid_objects = set(rel_subj_objects.get(rel+'_'+subj, ())) - {label_text}

            # Reset per example: previously an empty candidate list left this name
            # unbound (first example) or silently reused the previous example's value.
            pred_top_1_remove_stopwords = None
            for candidate in pred['top_100_text_remove_stopwords']:
                candidate = candidate.lower().strip()
                if exclude_other_valid_objects and candidate in other_valid_objects:
                    continue
                pred_top_1_remove_stopwords = candidate
                break
            if pred_top_1_remove_stopwords is None:
                # No usable candidate for this example; nothing to compare.
                continue

        subj = ' '.join(text_normalization_without_lemmatization(subj))
        obj_gt = ' '.join(text_normalization_without_lemmatization(label_text))
        obj_pred = ' '.join(text_normalization_without_lemmatization(pred_top_1_remove_stopwords))
        joint_freq_gt = coo_matrix.coo_count(subj, obj_gt)
        joint_freq_pred = coo_matrix.coo_count(subj, obj_pred)
        # Skip examples without a positive co-occurrence count. Per the original
        # comment this covers entities of more than 3 tokens and stopwords
        # (presumably coo_count returns a non-positive value for those — confirm).
        if joint_freq_gt <= 0 or joint_freq_pred <= 0:
            continue
        subj_freq = coo_matrix.count(subj)
        cond_prob_gt = joint_freq_gt / subj_freq if subj_freq > 0 else 0
        cond_prob_pred = joint_freq_pred / subj_freq if subj_freq > 0 else 0

        # Examples are grouped by the section of the *ground-truth* probability.
        section = prob_value_to_section(cond_prob_gt)
        is_success = pred['hits@1_remove_stopwords'] > 0.5
        pred_more_probable = int(cond_prob_pred > cond_prob_gt)

        # Every value is recorded under its own section and under 'total'.
        for bucket in (section, 'total'):
            condprob_gt_bin_total[bucket].append(cond_prob_gt)
            condprob_pred_bin_total[bucket].append(cond_prob_pred)
            if is_success:
                condprob_gt_bin_success[bucket].append(cond_prob_gt)
                condprob_pred_bin_success[bucket].append(cond_prob_pred)
            else:
                condprob_gt_bin_failure[bucket].append(cond_prob_gt)
                condprob_pred_bin_failure[bucket].append(cond_prob_pred)
                count_bin_failure[bucket].append(pred_more_probable)

    report_sections = ['total'] + list(range(num_sections))

    # print('Total')
    # for section in report_sections:
    #     print(f"{section} / {round(np.mean(condprob_pred_bin_total[section]), 2)} +- {round(np.std(condprob_pred_bin_total[section]), 2) } / {round(np.mean(condprob_gt_bin_total[section]), 2)} +- {round(np.std(condprob_gt_bin_total[section]), 2)} / {len(condprob_pred_bin_total[section])}")

    # Share of failures where the prediction is more probable than the ground truth.
    print('Count in failure cases')
    for section in report_sections:
        flags = count_bin_failure[section]
        if flags:
            print(f"{section} / {int(np.mean(flags)*100)}% / {len(flags)}")
        else:
            # np.mean([]) is nan and int(nan) raises ValueError; report the empty
            # section explicitly instead of crashing.
            print(f"{section} / n/a / 0")

    # Mean +- std of predicted and ground-truth probabilities over failures.
    print('Failure cases')
    for section in report_sections:
        pred_probs = condprob_pred_bin_failure[section]
        gt_probs = condprob_gt_bin_failure[section]
        if gt_probs:
            print(f"{section} / {round(np.mean(pred_probs), 2)} +- {round(np.std(pred_probs), 2) } / {round(np.mean(gt_probs), 2)} +- {round(np.std(gt_probs), 2)} / {len(gt_probs)}")
        else:
            print(f"{section} / n/a / n/a / 0")
    

Model: gpt_neo_125M
Test


100%|██████████| 8856/8856 [00:01<00:00, 5466.48it/s]


Count in failure cases
total / 37% / 6094
0 / 0% / 647
1 / 14% / 2074
2 / 39% / 1487
3 / 63% / 888
4 / 79% / 481
5 / 87% / 244
6 / 94% / 147
7 / 97% / 126
Failure cases
total / 0.35 +- 0.26 / 0.48 +- 0.32 / 6094
0 / 0.37 +- 0.28 / 1.0 +- 0.0 / 647
1 / 0.37 +- 0.27 / 0.72 +- 0.14 / 2074
2 / 0.35 +- 0.24 / 0.37 +- 0.07 / 1487
3 / 0.34 +- 0.25 / 0.18 +- 0.04 / 888
4 / 0.31 +- 0.26 / 0.09 +- 0.02 / 481
5 / 0.3 +- 0.28 / 0.05 +- 0.01 / 244
6 / 0.31 +- 0.27 / 0.02 +- 0.0 / 147
7 / 0.32 +- 0.31 / 0.01 +- 0.0 / 126
Model: gpt_neo_125M_TREx
Test


100%|██████████| 8856/8856 [00:01<00:00, 5565.74it/s]


Count in failure cases
total / 31% / 3906
0 / 0% / 91
1 / 6% / 1319
2 / 18% / 935
3 / 39% / 601
4 / 60% / 414
5 / 79% / 255
6 / 89% / 164
7 / 91% / 127
Failure cases
total / 0.23 +- 0.26 / 0.4 +- 0.29 / 3906
0 / 0.4 +- 0.32 / 1.0 +- 0.0 / 91
1 / 0.24 +- 0.25 / 0.71 +- 0.14 / 1319
2 / 0.23 +- 0.24 / 0.37 +- 0.07 / 935
3 / 0.23 +- 0.26 / 0.18 +- 0.03 / 601
4 / 0.21 +- 0.24 / 0.09 +- 0.02 / 414
5 / 0.23 +- 0.28 / 0.05 +- 0.01 / 255
6 / 0.23 +- 0.27 / 0.02 +- 0.0 / 164
7 / 0.22 +- 0.28 / 0.01 +- 0.0 / 127
Model: gpt_neo_1_3B
Test


100%|██████████| 8856/8856 [00:01<00:00, 5494.96it/s]


Count in failure cases
total / 39% / 5715
0 / 0% / 634
1 / 14% / 1897
2 / 41% / 1329
3 / 63% / 829
4 / 82% / 467
5 / 85% / 257
6 / 90% / 172
7 / 97% / 130
Failure cases
total / 0.36 +- 0.27 / 0.47 +- 0.32 / 5715
0 / 0.37 +- 0.28 / 1.0 +- 0.0 / 634
1 / 0.38 +- 0.27 / 0.72 +- 0.14 / 1897
2 / 0.36 +- 0.25 / 0.37 +- 0.07 / 1329
3 / 0.34 +- 0.26 / 0.18 +- 0.04 / 829
4 / 0.31 +- 0.27 / 0.09 +- 0.02 / 467
5 / 0.33 +- 0.3 / 0.05 +- 0.01 / 257
6 / 0.34 +- 0.33 / 0.02 +- 0.0 / 172
7 / 0.3 +- 0.3 / 0.01 +- 0.0 / 130
Model: gpt_neo_1_3B_TREx
Test


100%|██████████| 8856/8856 [00:01<00:00, 5465.39it/s]


Count in failure cases
total / 28% / 3525
0 / 0% / 104
1 / 6% / 1108
2 / 19% / 854
3 / 33% / 542
4 / 50% / 378
5 / 65% / 243
6 / 75% / 166
7 / 83% / 130
Failure cases
total / 0.2 +- 0.23 / 0.39 +- 0.3 / 3525
0 / 0.35 +- 0.29 / 1.0 +- 0.0 / 104
1 / 0.23 +- 0.25 / 0.71 +- 0.14 / 1108
2 / 0.22 +- 0.23 / 0.37 +- 0.07 / 854
3 / 0.18 +- 0.2 / 0.18 +- 0.03 / 542
4 / 0.15 +- 0.18 / 0.09 +- 0.02 / 378
5 / 0.13 +- 0.18 / 0.05 +- 0.01 / 243
6 / 0.12 +- 0.19 / 0.02 +- 0.0 / 166
7 / 0.13 +- 0.23 / 0.01 +- 0.0 / 130
Model: gpt_neo_2_7B
Test


100%|██████████| 8856/8856 [00:01<00:00, 5480.78it/s]


Count in failure cases
total / 37% / 5801
0 / 0% / 610
1 / 14% / 1971
2 / 37% / 1385
3 / 61% / 812
4 / 77% / 463
5 / 85% / 264
6 / 94% / 169
7 / 96% / 127
Failure cases
total / 0.35 +- 0.28 / 0.47 +- 0.32 / 5801
0 / 0.4 +- 0.28 / 1.0 +- 0.0 / 610
1 / 0.37 +- 0.27 / 0.72 +- 0.14 / 1971
2 / 0.35 +- 0.26 / 0.37 +- 0.07 / 1385
3 / 0.33 +- 0.27 / 0.18 +- 0.04 / 812
4 / 0.31 +- 0.27 / 0.09 +- 0.02 / 463
5 / 0.33 +- 0.3 / 0.05 +- 0.01 / 264
6 / 0.33 +- 0.34 / 0.02 +- 0.0 / 169
7 / 0.32 +- 0.33 / 0.01 +- 0.0 / 127
Model: gpt_neo_2_7B_TREx
Test


100%|██████████| 8856/8856 [00:01<00:00, 5498.96it/s]


Count in failure cases
total / 29% / 3351
0 / 0% / 79
1 / 7% / 1041
2 / 20% / 808
3 / 35% / 528
4 / 51% / 379
5 / 58% / 243
6 / 68% / 155
7 / 83% / 118
Failure cases
total / 0.21 +- 0.24 / 0.38 +- 0.29 / 3351
0 / 0.39 +- 0.32 / 1.0 +- 0.0 / 79
1 / 0.24 +- 0.26 / 0.72 +- 0.13 / 1041
2 / 0.23 +- 0.24 / 0.37 +- 0.07 / 808
3 / 0.2 +- 0.23 / 0.18 +- 0.03 / 528
4 / 0.17 +- 0.22 / 0.09 +- 0.02 / 379
5 / 0.12 +- 0.18 / 0.05 +- 0.01 / 243
6 / 0.1 +- 0.19 / 0.02 +- 0.0 / 155
7 / 0.08 +- 0.17 / 0.01 +- 0.0 / 118
Model: gpt_j_6B
Test


100%|██████████| 8856/8856 [00:01<00:00, 5299.83it/s]


Count in failure cases
total / 38% / 5252
0 / 0% / 479
1 / 15% / 1742
2 / 42% / 1251
3 / 56% / 782
4 / 70% / 440
5 / 78% / 262
6 / 85% / 167
7 / 95% / 129
Failure cases
total / 0.35 +- 0.29 / 0.46 +- 0.32 / 5252
0 / 0.42 +- 0.31 / 1.0 +- 0.0 / 479
1 / 0.38 +- 0.28 / 0.72 +- 0.14 / 1742
2 / 0.37 +- 0.27 / 0.37 +- 0.07 / 1251
3 / 0.31 +- 0.26 / 0.18 +- 0.04 / 782
4 / 0.29 +- 0.29 / 0.09 +- 0.02 / 440
5 / 0.3 +- 0.31 / 0.05 +- 0.01 / 262
6 / 0.26 +- 0.32 / 0.02 +- 0.0 / 167
7 / 0.26 +- 0.3 / 0.01 +- 0.0 / 129
Model: gpt_j_6B_TREx
Test


100%|██████████| 8856/8856 [00:01<00:00, 5329.41it/s]


Count in failure cases
total / 29% / 3267
0 / 0% / 80
1 / 6% / 1013
2 / 20% / 770
3 / 35% / 519
4 / 53% / 363
5 / 63% / 237
6 / 71% / 157
7 / 71% / 128
Failure cases
total / 0.21 +- 0.24 / 0.38 +- 0.29 / 3267
0 / 0.4 +- 0.31 / 1.0 +- 0.0 / 80
1 / 0.24 +- 0.25 / 0.71 +- 0.14 / 1013
2 / 0.23 +- 0.23 / 0.37 +- 0.07 / 770
3 / 0.2 +- 0.23 / 0.18 +- 0.04 / 519
4 / 0.16 +- 0.19 / 0.09 +- 0.02 / 363
5 / 0.13 +- 0.18 / 0.05 +- 0.01 / 237
6 / 0.13 +- 0.21 / 0.02 +- 0.0 / 157
7 / 0.09 +- 0.2 / 0.01 +- 0.0 / 128
Model: text-davinci-003
Test


100%|██████████| 4149/4149 [00:00<00:00, 5593.12it/s]


Count in failure cases
total / 34% / 1786
0 / 0% / 135
1 / 18% / 655
2 / 41% / 473
3 / 50% / 300
4 / 64% / 130
5 / 66% / 59
6 / 80% / 25
7 / 100% / 9
Failure cases
total / 0.35 +- 0.27 / 0.48 +- 0.3 / 1786
0 / 0.4 +- 0.3 / 1.0 +- 0.0 / 135
1 / 0.42 +- 0.28 / 0.72 +- 0.13 / 655
2 / 0.35 +- 0.26 / 0.37 +- 0.07 / 473
3 / 0.26 +- 0.23 / 0.19 +- 0.04 / 300
4 / 0.24 +- 0.24 / 0.09 +- 0.02 / 130
5 / 0.17 +- 0.21 / 0.05 +- 0.01 / 59
6 / 0.11 +- 0.11 / 0.02 +- 0.0 / 25
7 / 0.08 +- 0.07 / 0.01 +- 0.0 / 9
Model: gpt-3.5-turbo-0301
Test


100%|██████████| 4149/4149 [00:00<00:00, 5514.72it/s]


Count in failure cases
total / 26% / 1785
0 / 0% / 127
1 / 11% / 581
2 / 28% / 464
3 / 39% / 344
4 / 45% / 153
5 / 58% / 67
6 / 62% / 35
7 / 78% / 14
Failure cases
total / 0.23 +- 0.26 / 0.44 +- 0.3 / 1785
0 / 0.32 +- 0.26 / 1.0 +- 0.0 / 127
1 / 0.28 +- 0.3 / 0.71 +- 0.13 / 581
2 / 0.23 +- 0.25 / 0.37 +- 0.07 / 464
3 / 0.21 +- 0.23 / 0.18 +- 0.04 / 344
4 / 0.16 +- 0.21 / 0.09 +- 0.02 / 153
5 / 0.13 +- 0.17 / 0.05 +- 0.01 / 67
6 / 0.08 +- 0.09 / 0.02 +- 0.0 / 35
7 / 0.03 +- 0.03 / 0.01 +- 0.0 / 14
Model: gpt-4-0314
Test


100%|██████████| 4149/4149 [00:00<00:00, 5371.08it/s]

Count in failure cases
total / 28% / 1748
0 / 0% / 121
1 / 15% / 573
2 / 25% / 468
3 / 43% / 355
4 / 52% / 136
5 / 56% / 55
6 / 80% / 26
7 / 92% / 14
Failure cases
total / 0.26 +- 0.3 / 0.45 +- 0.3 / 1748
0 / 0.33 +- 0.28 / 1.0 +- 0.0 / 121
1 / 0.29 +- 0.33 / 0.72 +- 0.13 / 573
2 / 0.26 +- 0.31 / 0.37 +- 0.07 / 468
3 / 0.23 +- 0.26 / 0.19 +- 0.04 / 355
4 / 0.26 +- 0.31 / 0.09 +- 0.02 / 136
5 / 0.2 +- 0.25 / 0.05 +- 0.01 / 55
6 / 0.17 +- 0.19 / 0.02 +- 0.0 / 26
7 / 0.11 +- 0.19 / 0.01 +- 0.0 / 14



