In [1]:
import json
import jsonlines
from tqdm.auto import tqdm
from collections import defaultdict
from copy import deepcopy

import numpy as np

from cooccurrence_matrix import CooccurrenceMatrix

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
pile_coo_matrix = CooccurrenceMatrix('pile')
bert_coo_matrix = CooccurrenceMatrix('bert_pretraining_data')

In [3]:
from nltk.corpus import stopwords
from nltk import word_tokenize

stopword_list = stopwords.words("english")

filter = {}
for w in stopword_list:
    filter[w] = w
punctuations = {
    "?": "?",
    ":": ":",
    "!": "!",
    ".": ".",
    ",": ",",
    ";": ";"
}
filter.update(punctuations)
def filtering(text):
    if text in filter:
        return True

def text_normalization_without_lemmatization(text):
    result = []
    tokens = word_tokenize(text)
    
    for token in tokens:
        token_low = token.lower()
        if filtering(token_low):
            continue
        result.append(token_low)
    return result

In [8]:
dataset_name = 'LAMA_TREx'
dataset_type = 'test'

training_type = 'zeroshot'

In [11]:
with open(f"../../../data/{dataset_name}/all.json", 'r') as fin:
    f_all = json.load(fin)

uid_rel_map, uid_subj_map, uid_obj_map = {}, {}, {}
for example in f_all:
    uid_subj_map[example['uid']] = example['subj']
    uid_rel_map[example['uid']] = example['rel_id']
    uid_obj_map[example['uid']] = example['output']

In [37]:
num_sections = 10

def prob_value_to_section(value):
    return min(int(np.ceil(-np.log10(value+0.000001))), num_sections - 1)

def prob_section_to_string(section):
        denominator = str(10**section) if section < num_sections - 1 else 'inf'
        return '1/'+denominator

In [43]:
model_name_dict = {
    # 'bert-base-uncased': 'BERT%_{base}%',
    # 'bert-large-uncased': 'BERT$_{large}$',
    # 'albert-base-v1': 'ALBERT1$_{base}$',
    # 'albert-large-v1': 'ALBERT1$_{large}$',
    # 'albert-xlarge-v1': 'ALBERT1$_{xlarge}$',
    # 'albert-base-v2': 'ALBERT2$_{base}$',
    # 'albert-large-v2': 'ALBERT2$_{large}$',
    # 'albert-xlarge-v2': 'ALBERT2$_{xlarge}$',
    # 'roberta-base': 'RoBERTa$_{base}$',
    # 'roberta-large': 'RoBERTa$_{large}$',
    'gpt-neo-125m': 'GPT-Neo 125M',
    # 'gpt-neo-1.3B': 'GPT-Neo 1.3B',
    # 'gpt-neo-2.7B': 'GPT-Neo 2.7B',
    # 'gpt-j-6b': 'GPT-J 6B',
    # 'gpt-3.5-turbo-0125': 'ChatGPT-3.5',
    # 'gpt-4-0125-preview': 'ChatGPT-4'
}

for model_name in model_name_dict.keys():
    print('='*30)
    print('='*30)
    print('Model:', model_name)

    try:
        data = jsonlines.open(f'../../../results/{dataset_name}/{model_name}_{dataset_name}_{training_type}/pred_{dataset_name}_{dataset_type}.jsonl')
    except:
        raise Exception
        # continue

    if 'gpt' in model_name:
        coo_matrix = pile_coo_matrix
        num_total_samples = 254188957
    else:
        coo_matrix = bert_coo_matrix
        num_total_samples = 158887337

    openai_api = True if 'gpt-3.5-turbo' in model_name or 'gpt-4-0125' in model_name else False

    results_hits_1, results_hits_10, results_hits_100 = defaultdict(list), defaultdict(list), defaultdict(list)
    rel_results_hits_1, rel_results_hits_10, rel_results_hits_100 = defaultdict(dict), defaultdict(dict), defaultdict(dict)

    for pred in tqdm(data.iter()):
        subj = uid_subj_map[pred['uid']]
        rel = uid_rel_map[pred['uid']]
        obj = uid_obj_map[pred['uid']]
        subj = ' '.join(text_normalization_without_lemmatization(subj))
        obj = ' '.join(text_normalization_without_lemmatization(obj))
        
        subj_count = coo_matrix.count(subj)
        obj_count = coo_matrix.count(obj)
        subj_obj_count = coo_matrix.coo_count(subj, obj)

        # skip if the count is -1 (unknown)
        if subj_obj_count < 0:
            continue

        subj_prob = subj_count / num_total_samples
        joint_prob = subj_obj_count / num_total_samples
        cond_prob = subj_obj_count / subj_count if subj_count > 0 else 0

        prob = joint_prob

        section = prob_value_to_section(prob)

        results_hits_1[section].append(pred['hits@1_remove_stopwords'])
        if not openai_api:
            results_hits_10[section].append(pred['hits@10_remove_stopwords'])
            results_hits_100[section].append(pred['hits@100_remove_stopwords'])

        if section not in rel_results_hits_1[rel]:
            rel_results_hits_1[rel][section] = []
            rel_results_hits_10[rel][section] = []
            rel_results_hits_100[rel][section] = []
        rel_results_hits_1[rel][section].append(pred['hits@1_remove_stopwords'])
        if not openai_api:
            rel_results_hits_10[rel][section].append(pred['hits@10_remove_stopwords'])
            rel_results_hits_100[rel][section].append(pred['hits@100_remove_stopwords'])

    num_samples = {}
    sections = range(num_sections)
    sorted_rels = sorted(list(rel_results_hits_1.keys()))
    for section in sections:
        num_samples[section] = len(results_hits_1[section])

        if section in results_hits_1:
            results_hits_1[section] = np.mean(results_hits_1[section]), np.std(results_hits_1[section])
            results_hits_10[section] = np.mean(results_hits_10[section]), np.std(results_hits_10[section])
            results_hits_100[section] = np.mean(results_hits_100[section]), np.std(results_hits_100[section])

        # for rel in rel_results_hits_1:
        #     if section in rel_results_hits_1[rel]:
        #         rel_results_hits_1[rel][section] = np.mean(rel_results_hits_1[rel][section]), np.std(rel_results_hits_1[rel][section])
        #         rel_results_hits_10[rel][section] = np.mean(rel_results_hits_10[rel][section]), np.std(rel_results_hits_10[rel][section])
        #         rel_results_hits_100[rel][section] = np.mean(rel_results_hits_100[rel][section]), np.std(rel_results_hits_100[rel][section])

    result = {}
    for section in sections:
        if section in results_hits_1:
            result[f'hits@1_remove_stopwords_section_{prob_section_to_string(section)}'] = f'%.2f +- %.2f' % results_hits_1[section]
    
    # for section in sections:
    #     if section in results_hits_10:
    #         result[f'hits@10_remove_stopwords_section_{prob_section_to_string(section)}'] = f'%.2f +- %.2f' % results_hits_10[section]

    for section in sections:
        if section in results_hits_100:
            result[f'hits@100_remove_stopwords_section_{prob_section_to_string(section)}'] = f'%.2f +- %.2f' % results_hits_100[section]

    # for section in sections:
    #     for rel in sorted_rels:
    #         if section in rel_results_hits_1[rel]:
    #             result[f'hits_1_remove_stopwords_{rel}_section_{prob_section_to_string(section)}'] = f'%.2f +- %.2f' % rel_results_hits_1[rel][section]

    # for section in sections:
    #     for rel in sorted_rels:
    #         if section in rel_results_hits_10[rel]:
    #             result[f'hits_10_remove_stopwords_{rel}_section_{prob_section_to_string(section)}'] = f'%.2f +- %.2f' % rel_results_hits_10[rel][section]

    # for section in sections:
    #     for rel in sorted_rels:
    #         if section in rel_results_hits_100[rel]:
    #             result[f'hits_100_remove_stopwords_{rel}_section_{prob_section_to_string(section)}'] = f'%.2f +- %.2f' % rel_results_hits_100[rel][section]
    
    print(num_samples)
    print(json.dumps(result, indent=4))


Model: gpt-neo-125m


0it [00:00, ?it/s]

8824it [00:04, 2150.79it/s]

{0: 0, 1: 0, 2: 3, 3: 119, 4: 707, 5: 1352, 6: 6114, 7: 0, 8: 0, 9: 0}
{
    "hits@1_remove_stopwords_section_1/1": "nan +- nan",
    "hits@1_remove_stopwords_section_1/10": "nan +- nan",
    "hits@1_remove_stopwords_section_1/100": "0.00 +- 0.00",
    "hits@1_remove_stopwords_section_1/1000": "0.17 +- 0.37",
    "hits@1_remove_stopwords_section_1/10000": "0.09 +- 0.28",
    "hits@1_remove_stopwords_section_1/100000": "0.15 +- 0.36",
    "hits@1_remove_stopwords_section_1/1000000": "0.10 +- 0.30",
    "hits@1_remove_stopwords_section_1/10000000": "nan +- nan",
    "hits@1_remove_stopwords_section_1/100000000": "nan +- nan",
    "hits@1_remove_stopwords_section_1/inf": "nan +- nan",
    "hits@100_remove_stopwords_section_1/1": "nan +- nan",
    "hits@100_remove_stopwords_section_1/10": "nan +- nan",
    "hits@100_remove_stopwords_section_1/100": "1.00 +- 0.00",
    "hits@100_remove_stopwords_section_1/1000": "0.82 +- 0.39",
    "hits@100_remove_stopwords_section_1/10000": "0.77 +- 0.42"


