In [None]:
import json
import jsonlines
from tqdm.auto import tqdm
from collections import defaultdict
from copy import deepcopy

import numpy as np

from cooccurrence_matrix import CooccurrenceMatrix

In [None]:
# Co-occurrence statistics for the two pretraining corpora. The analysis cells
# below pick `pile_coo_matrix` for model ids containing 'gpt' and
# `bert_coo_matrix` for everything else.
pile_coo_matrix, bert_coo_matrix = (
    CooccurrenceMatrix(corpus_name)
    for corpus_name in ('pile', 'bert_pretraining_data')
)

In [None]:
from nltk.corpus import stopwords
from nltk import word_tokenize

stopword_list = stopwords.words("english")

# Tokens dropped during normalization: NLTK's English stopwords plus a few
# punctuation marks, stored as an identity mapping (token -> token). Only
# membership is ever checked (see `filtering`).
# NOTE: this name shadows the built-in `filter` for the rest of the notebook;
# it is kept because `filtering` looks the table up under this name.
filter = {word: word for word in stopword_list}
punctuations = {mark: mark for mark in "?:!.,;"}
filter.update(punctuations)
def filtering(text, vocab=None):
    """Return True if `text` is a token that normalization should drop.

    Args:
        text: the token to check (callers pass it already lower-cased).
        vocab: container of tokens to drop. Defaults to the module-level
            `filter` table (stopwords + punctuation).

    Returns:
        bool: True when `text` is in `vocab`, otherwise False (never None).
    """
    if vocab is None:
        vocab = filter
    return text in vocab

def text_normalization_without_lemmatization(text):
    """Tokenize `text` with NLTK, lower-case each token, and drop filtered ones.

    A token is dropped when `filtering` reports it (stopwords and the
    punctuation in the `filter` table). No lemmatization or stemming is done.

    Returns:
        list[str]: the surviving lower-cased tokens, in their original order.
    """
    lowered = (token.lower() for token in word_tokenize(text))
    return [token for token in lowered if not filtering(token)]

In [None]:
# Model id (used to build the results directory path in the analysis cells)
# -> display label with LaTeX math markup. Only uncommented entries are
# analysed; uncomment a line to include that model.
model_name_dict = dict([
    ('bert-base-uncased', 'BERT$_{base}$'),
    # ('bert-large-uncased', 'BERT$_{large}$'),
    # ('albert-base-v1', 'ALBERT1$_{base}$'),
    # ('albert-large-v1', 'ALBERT1$_{large}$'),
    # ('albert-xlarge-v1', 'ALBERT1$_{xlarge}$'),
    # ('albert-base-v2', 'ALBERT2$_{base}$'),
    # ('albert-large-v2', 'ALBERT2$_{large}$'),
    # ('albert-xlarge-v2', 'ALBERT2$_{xlarge}$'),
    # ('roberta-base', 'RoBERTa$_{base}$'),
    # ('roberta-large', 'RoBERTa$_{large}$'),
    # ('gpt-neo-125m', 'GPT-Neo 125M'),
    # ('gpt-neo-1.3B', 'GPT-Neo 1.3B'),
    # ('gpt-neo-2.7B', 'GPT-Neo 2.7B'),
    ('gpt-j-6b', 'GPT-J 6B'),
    # ('gpt-3.5-turbo-0125', 'ChatGPT-3.5'),
    # ('gpt-4-0125-preview', 'ChatGPT-4'),
])

In [None]:
# Evaluation setting used by the following cells: which dataset and split to
# read, and which prediction run (training regime) to analyse.
dataset_name, dataset_type = 'ConceptNet', 'test'
training_type = 'zeroshot'

In [None]:
# Load every example of `dataset_name` from all.json and build lookup tables:
#   uid_subj_map[uid]              -> subject string
#   uid_rel_map[uid]               -> relation id
#   rel_subj_objects[rel + '_' + subj] -> list of distinct lower-cased objects
# JSON text is UTF-8 by specification; passing the encoding explicitly avoids
# decode errors where the platform's default (locale) encoding differs.
with open(f"../../../data/{dataset_name}/all.json", 'r', encoding='utf-8') as fin:
    f_all = json.load(fin)

uid_rel_map = {}
uid_subj_map = {}
# Stays a defaultdict, so looking up an unseen key yields an empty set.
# NOTE(review): keys are joined with '_'; if relation ids or subjects can
# themselves contain '_', distinct pairs could collide — confirm.
rel_subj_objects = defaultdict(set)
for example in f_all:
    subj = example['subj']
    rel = example['rel_id']
    obj = example['output']

    uid_subj_map[example['uid']] = subj
    uid_rel_map[example['uid']] = rel
    rel_subj_objects[rel + '_' + subj].add(obj.lower())
# Sets were only needed for de-duplication; downstream code receives lists.
for key in rel_subj_objects:
    rel_subj_objects[key] = list(rel_subj_objects[key])

In [None]:
num_sections = 8


def prob_value_to_section(value):
    """Map a probability to one of `num_sections` logarithmic bins.

    The raw index is ceil(-log2(value + 1e-6)): bin 0 holds values of (almost)
    exactly 1, bin k (1 <= k < num_sections - 1) roughly holds values in
    [2**-k, 2**-(k-1)), and the last bin absorbs everything smaller. The
    epsilon keeps log2 finite when value == 0.

    The index is clamped to [0, num_sections - 1]. Values of about 2 or more
    would otherwise get negative indices, which the per-bin reports (bins
    0..num_sections-1) never print.
    """
    section = int(np.ceil(-np.log2(value + 0.000001)))
    return max(0, min(section, num_sections - 1))
    
# When True, top-k predictions that are *other* gold objects for the same
# (relation, subject) pair are skipped before choosing the top-1 prediction.
# The original analysis hard-disabled this filter (via `... or True`), so it is
# off by default to keep the reported numbers unchanged.
filter_other_valid_objects = False

for model_name in model_name_dict:
    pred_path = f'../../../results/{dataset_name}/{model_name}_{dataset_name}_{training_type}/pred_{dataset_name}_{dataset_type}.jsonl'

    # `with` closes the predictions file when done. A missing or unreadable file
    # now surfaces its own error (with the path) instead of being replaced by a
    # bare, message-less `Exception`.
    with jsonlines.open(pred_path) as data:
        # Model ids containing 'gpt' use the Pile statistics; all others use
        # the BERT pretraining-data statistics.
        coo_matrix = pile_coo_matrix if 'gpt' in model_name else bert_coo_matrix

        print('='*30)
        print('='*30)
        print('Model:', model_name)

        # Conditional probabilities keyed by bin index and by 'total'. Only the
        # failure lists are reported below; the total/success lists are kept
        # for interactive inspection.
        condprob_gt_bin_total = defaultdict(list)
        condprob_pred_bin_total = defaultdict(list)
        condprob_gt_bin_success = defaultdict(list)
        condprob_pred_bin_success = defaultdict(list)
        condprob_gt_bin_failure = defaultdict(list)
        condprob_pred_bin_failure = defaultdict(list)
        # Failure cases only: 1 if the wrong prediction's conditional probability
        # given the subject exceeds the gold object's, else 0.
        count_bin_failure = defaultdict(list)

        for pred in tqdm(data.iter()):
            uid = pred['uid']
            subj = uid_subj_map[uid]
            rel = uid_rel_map[uid]
            label_text = pred['label_text'].lower()
            # Gold objects for this (relation, subject) pair other than the one
            # under test. A set difference never raises, unlike list.remove()
            # when the label is absent from the gold list.
            other_valid_objects = set(rel_subj_objects[rel + '_' + subj]) - {label_text}

            if 'top_100_text_remove_stopwords' in pred:
                pred_top_k_remove_stopwords = pred['top_100_text_remove_stopwords']
            else:
                pred_top_k_remove_stopwords = pred['top_5_text_remove_stopwords']

            # Top-1 prediction, optionally skipping other valid objects.
            pred_top_1_remove_stopwords = None
            for w in pred_top_k_remove_stopwords:
                w = w.lower().strip()
                if not filter_other_valid_objects or w not in other_valid_objects:
                    pred_top_1_remove_stopwords = w
                    break
            # No usable prediction: skip rather than silently reuse the value
            # left over from the previous example (or hit a NameError).
            if pred_top_1_remove_stopwords is None:
                continue

            subj_norm = ' '.join(text_normalization_without_lemmatization(subj))
            obj_gt = ' '.join(text_normalization_without_lemmatization(label_text))
            obj_pred = ' '.join(text_normalization_without_lemmatization(pred_top_1_remove_stopwords))
            joint_freq_gt = coo_matrix.coo_count(subj_norm, obj_gt)
            joint_freq_pred = coo_matrix.coo_count(subj_norm, obj_pred)
            subj_freq = coo_matrix.count(subj_norm)

            # Skip pairs with any non-positive count. NOTE(review): per the
            # original comment this drops entities longer than 3 tokens or made
            # only of stopwords, presumably because CooccurrenceMatrix reports
            # <= 0 for them — confirm against its implementation.
            if joint_freq_gt <= 0 or joint_freq_pred <= 0 or subj_freq <= 0:
                continue
            # subj_freq > 0 is guaranteed here, so no division guard is needed.
            cond_prob_gt = joint_freq_gt / subj_freq
            cond_prob_pred = joint_freq_pred / subj_freq

            # Bin by the gold object's conditional probability given the subject.
            section = prob_value_to_section(cond_prob_gt)
            bin_keys = (section, 'total')

            for bin_key in bin_keys:
                condprob_gt_bin_total[bin_key].append(cond_prob_gt)
                condprob_pred_bin_total[bin_key].append(cond_prob_pred)
            # NOTE(review): hits@1 is presumably 0/1; > 0.5 treats it as a boolean.
            if pred['hits@1_remove_stopwords'] > 0.5:
                for bin_key in bin_keys:
                    condprob_gt_bin_success[bin_key].append(cond_prob_gt)
                    condprob_pred_bin_success[bin_key].append(cond_prob_pred)
            else:
                pred_more_likely = int(cond_prob_pred > cond_prob_gt)
                for bin_key in bin_keys:
                    condprob_gt_bin_failure[bin_key].append(cond_prob_gt)
                    condprob_pred_bin_failure[bin_key].append(cond_prob_pred)
                    count_bin_failure[bin_key].append(pred_more_likely)

    bins_to_report = ['total'] + list(range(num_sections))

    print('Count in failure cases')
    for bin_key in bins_to_report:
        flags = count_bin_failure.get(bin_key, [])
        if not flags:
            print(bin_key)  # empty bin: nothing to average
            continue
        print(f"{bin_key} / {int(np.mean(flags)*100)}% / {len(flags)}")

    print('Failure cases')
    for bin_key in bins_to_report:
        pred_probs = condprob_pred_bin_failure.get(bin_key, [])
        gt_probs = condprob_gt_bin_failure.get(bin_key, [])
        if not gt_probs:
            print(bin_key)  # empty bin: mean/std are undefined
            continue
        print(f"{bin_key} / {round(np.mean(pred_probs), 2)} +- {round(np.std(pred_probs), 2)} / {round(np.mean(gt_probs), 2)} +- {round(np.std(gt_probs), 2)} / {len(gt_probs)}")

In [None]:
# Evaluation setting used by the following cells: same dataset and split as
# before, but analysing the prompt-tuned prediction run.
dataset_name, dataset_type = 'ConceptNet', 'test'
training_type = 'prompt_tuning'

In [None]:
# NOTE(review): this re-defines the binning helper from the zero-shot cell;
# consider defining it once so the two analyses cannot drift apart.
num_sections = 8


def prob_value_to_section(value):
    """Map a probability to one of `num_sections` logarithmic bins.

    The raw index is ceil(-log2(value + 1e-6)): bin 0 holds values of (almost)
    exactly 1, bin k (1 <= k < num_sections - 1) roughly holds values in
    [2**-k, 2**-(k-1)), and the last bin absorbs everything smaller. The
    epsilon keeps log2 finite when value == 0.

    The index is clamped to [0, num_sections - 1]. Values of about 2 or more
    would otherwise get negative indices, which the per-bin reports (bins
    0..num_sections-1) never print.
    """
    section = int(np.ceil(-np.log2(value + 0.000001)))
    return max(0, min(section, num_sections - 1))
    
# When True, top-k predictions that are *other* gold objects for the same
# (relation, subject) pair are skipped before choosing the top-1 prediction.
# The original analysis hard-disabled this filter (via `... or True`), so it is
# off by default to keep the reported numbers unchanged.
filter_other_valid_objects = False

for model_name in model_name_dict:
    pred_path = f'../../../results/{dataset_name}/{model_name}_{dataset_name}_{training_type}/pred_{dataset_name}_{dataset_type}.jsonl'

    # `with` closes the predictions file when done. A missing or unreadable file
    # now surfaces its own error (with the path) instead of being replaced by a
    # bare, message-less `Exception`.
    with jsonlines.open(pred_path) as data:
        # Model ids containing 'gpt' use the Pile statistics; all others use
        # the BERT pretraining-data statistics.
        coo_matrix = pile_coo_matrix if 'gpt' in model_name else bert_coo_matrix

        print('='*30)
        print('='*30)
        print('Model:', model_name)

        # Conditional probabilities keyed by bin index and by 'total'. Only the
        # failure lists are reported below; the total/success lists are kept
        # for interactive inspection.
        condprob_gt_bin_total = defaultdict(list)
        condprob_pred_bin_total = defaultdict(list)
        condprob_gt_bin_success = defaultdict(list)
        condprob_pred_bin_success = defaultdict(list)
        condprob_gt_bin_failure = defaultdict(list)
        condprob_pred_bin_failure = defaultdict(list)
        # Failure cases only: 1 if the wrong prediction's conditional probability
        # given the subject exceeds the gold object's, else 0.
        count_bin_failure = defaultdict(list)

        for pred in tqdm(data.iter()):
            uid = pred['uid']
            subj = uid_subj_map[uid]
            rel = uid_rel_map[uid]
            label_text = pred['label_text'].lower()
            # Gold objects for this (relation, subject) pair other than the one
            # under test. A set difference never raises, unlike list.remove()
            # when the label is absent from the gold list.
            other_valid_objects = set(rel_subj_objects[rel + '_' + subj]) - {label_text}

            if 'top_100_text_remove_stopwords' in pred:
                pred_top_k_remove_stopwords = pred['top_100_text_remove_stopwords']
            else:
                pred_top_k_remove_stopwords = pred['top_5_text_remove_stopwords']

            # Top-1 prediction, optionally skipping other valid objects.
            pred_top_1_remove_stopwords = None
            for w in pred_top_k_remove_stopwords:
                w = w.lower().strip()
                if not filter_other_valid_objects or w not in other_valid_objects:
                    pred_top_1_remove_stopwords = w
                    break
            # No usable prediction: skip rather than silently reuse the value
            # left over from the previous example (or hit a NameError).
            if pred_top_1_remove_stopwords is None:
                continue

            subj_norm = ' '.join(text_normalization_without_lemmatization(subj))
            obj_gt = ' '.join(text_normalization_without_lemmatization(label_text))
            obj_pred = ' '.join(text_normalization_without_lemmatization(pred_top_1_remove_stopwords))
            joint_freq_gt = coo_matrix.coo_count(subj_norm, obj_gt)
            joint_freq_pred = coo_matrix.coo_count(subj_norm, obj_pred)
            subj_freq = coo_matrix.count(subj_norm)

            # Skip pairs with any non-positive count. NOTE(review): per the
            # original comment this drops entities longer than 3 tokens or made
            # only of stopwords, presumably because CooccurrenceMatrix reports
            # <= 0 for them — confirm against its implementation.
            if joint_freq_gt <= 0 or joint_freq_pred <= 0 or subj_freq <= 0:
                continue
            # subj_freq > 0 is guaranteed here, so no division guard is needed.
            cond_prob_gt = joint_freq_gt / subj_freq
            cond_prob_pred = joint_freq_pred / subj_freq

            # Bin by the gold object's conditional probability given the subject.
            section = prob_value_to_section(cond_prob_gt)
            bin_keys = (section, 'total')

            for bin_key in bin_keys:
                condprob_gt_bin_total[bin_key].append(cond_prob_gt)
                condprob_pred_bin_total[bin_key].append(cond_prob_pred)
            # NOTE(review): hits@1 is presumably 0/1; > 0.5 treats it as a boolean.
            if pred['hits@1_remove_stopwords'] > 0.5:
                for bin_key in bin_keys:
                    condprob_gt_bin_success[bin_key].append(cond_prob_gt)
                    condprob_pred_bin_success[bin_key].append(cond_prob_pred)
            else:
                pred_more_likely = int(cond_prob_pred > cond_prob_gt)
                for bin_key in bin_keys:
                    condprob_gt_bin_failure[bin_key].append(cond_prob_gt)
                    condprob_pred_bin_failure[bin_key].append(cond_prob_pred)
                    count_bin_failure[bin_key].append(pred_more_likely)

    bins_to_report = ['total'] + list(range(num_sections))

    print('Count in failure cases')
    for bin_key in bins_to_report:
        flags = count_bin_failure.get(bin_key, [])
        if not flags:
            print(bin_key)  # empty bin: nothing to average
            continue
        print(f"{bin_key} / {int(np.mean(flags)*100)}% / {len(flags)}")

    print('Failure cases')
    for bin_key in bins_to_report:
        pred_probs = condprob_pred_bin_failure.get(bin_key, [])
        gt_probs = condprob_gt_bin_failure.get(bin_key, [])
        if not gt_probs:
            print(bin_key)  # empty bin: mean/std are undefined
            continue
        print(f"{bin_key} / {round(np.mean(pred_probs), 2)} +- {round(np.std(pred_probs), 2)} / {round(np.mean(gt_probs), 2)} +- {round(np.std(gt_probs), 2)} / {len(gt_probs)}")