In [None]:
import json
import jsonlines
from tqdm.auto import tqdm
from collections import defaultdict
from copy import deepcopy

import matplotlib.pyplot as plt
import numpy as np

from cooccurrence_matrix import CooccurrenceMatrix

In [None]:
# Co-occurrence statistics for the two pretraining corpora used below:
# the Pile (paired with GPT-family models) and the BERT pretraining data.
pile_coo_matrix, bert_coo_matrix = (
    CooccurrenceMatrix(corpus_name)
    for corpus_name in ('pile', 'bert_pretraining_data')
)

In [None]:
from nltk.corpus import stopwords
from nltk import word_tokenize

# English stopwords as shipped with NLTK.
stopword_list = stopwords.words("english")

# Sentence punctuation that should never count as part of an entity mention.
punctuations = {mark: mark for mark in "?:!.,;"}

# Lookup table of tokens to discard: every stopword plus the punctuation above.
# Only key membership is consulted (see `filtering`); values mirror the keys.
# NOTE(review): this name shadows the built-in `filter` for the whole notebook.
filter = dict(zip(stopword_list, stopword_list))
filter.update(punctuations)
def filtering(text):
    """Tell whether a token should be dropped during normalization.

    Args:
        text: A single (already lowercased) token.

    Returns:
        bool: True if ``text`` is a key of the module-level ``filter`` table
        (stopwords and punctuation), False otherwise. The result is always a
        real bool rather than True or an implicit None.
    """
    return text in filter

def text_normalization_without_lemmatization(text):
    """Tokenize ``text``, lowercase each token, and drop filtered tokens.

    Args:
        text: Raw string to normalize.

    Returns:
        list: Lowercased tokens, in their original order, for which
        ``filtering`` is falsy. No lemmatization or stemming is applied.
    """
    lowered_tokens = (token.lower() for token in word_tokenize(text))
    return [token for token in lowered_tokens if not filtering(token)]

In [None]:
# Benchmark and split to analyse; both are interpolated into the data and
# prediction file paths used by the cells below.
dataset_name, dataset_type = 'ConceptNet', 'test'

# Training regime; part of the results directory name and output file names.
training_type = 'prompt_tuning'

In [None]:
# Load every example of the dataset. JSON text is UTF-8, so decode it
# explicitly instead of relying on the platform's default encoding.
with open(f"../../../data/{dataset_name}/all.json", 'r', encoding='utf-8') as fin:
    f_all = json.load(fin)

# Index subject, relation id and gold object by example uid so predictions
# (which are looked up by 'uid' below) can be joined back to their triple.
uid_rel_map, uid_subj_map, uid_obj_map = {}, {}, {}
for example in f_all:
    uid = example['uid']
    uid_subj_map[uid] = example['subj']
    uid_rel_map[uid] = example['rel_id']
    uid_obj_map[uid] = example['output']

In [None]:
# Lower edges of the frequency buckets: [0, 10), [10, 100), ..., [100000, inf).
bins = [0] + [10 ** exponent for exponent in range(1, 6)]

def frequency_to_section(value):
    """Map a frequency to the 1-based index of the bucket in ``bins`` containing it.

    Args:
        value: A count (scalar or array-like).

    Returns:
        The ``np.digitize`` result: 1 for values in [0, 10), ..., 6 for values
        of at least 100000, and 0 for negative values.
    """
    section_index = np.digitize(value, bins, right=False)
    return section_index

def frequency_section_to_string(section):
    """Return the default string formatting of a section index (used in result keys)."""
    return '{}'.format(section)

In [None]:
import matplotlib.pyplot as plt
import matplotlib

# Common multiplier applied to every matplotlib text size.
scale_factor = 1.5

# Unscaled point sizes, keyed by the rcParams entry they control.
base_font_sizes = {
    'font.size': 12,
    'axes.labelsize': 14,    # x and y axis labels
    'axes.titlesize': 16,    # axes title
    'xtick.labelsize': 12,   # x tick labels
    'ytick.labelsize': 12,   # y tick labels
    'legend.fontsize': 12,   # legend entries
    'figure.titlesize': 18,  # suptitle
}
plt.rcParams.update({key: size * scale_factor for key, size in base_font_sizes.items()})

######################################################################

# Corpus whose counts are plotted; switch by swapping the commented line.
# dataset = 'bert'
dataset = 'pile'

# Decade edges 10^0 ... 10^6 delimiting the six histogram bars.
bin_edges = np.logspace(0, 6, num=7)

# Hard-coded number of triples per decade for each corpus.
bert_counts = [111695, 26484, 14963, 4586, 1256, 560]
pile_counts = [24663, 30469, 35552, 32754, 22536, 13570]

counts = dict(bert=bert_counts, pile=pile_counts)

# Bar chart of the number of triples per joint-frequency decade.
fig, ax = plt.subplots(figsize=(10, 6))

# Each bar starts slightly right of its decade's lower edge (10^(k+0.1)) and
# spans 80% of the decade's linear width, so bars stay inside their decade.
decade_exponents = np.arange(len(bin_edges[:-1]))
bar_left_edges = np.power(10, decade_exponents + 0.1)
bars = ax.bar(
    bar_left_edges,
    counts[dataset],
    width=np.diff(bin_edges) * 0.8,
    align='edge',
    color='#17a2b8',
)

# Logarithmic x-axis with one labelled tick per decade edge and no minor ticks.
ax.set_xscale('log')
ax.set_xticks(bin_edges)
ax.set_xticklabels(['$10^%d$' % exponent for exponent in range(len(bin_edges))])
ax.tick_params(axis='x', which='minor', length=0)

ax.set_xlabel('Joint frequency of subject and object')
ax.set_ylabel('Number of triples')

# Write each count above its bar, centred at the decade's geometric midpoint.
label_positions = np.power(10, decade_exponents + 0.5)
for bar, x_pos in zip(bars, label_positions):
    height = bar.get_height()
    ax.text(x_pos, height, f'{height:,}', va='bottom', ha='center', wrap=True)


def _with_thousands_separator(value, _position):
    """Format a y tick value as an integer with comma thousands separators."""
    return format(int(value), ',')


ax.yaxis.set_major_formatter(matplotlib.ticker.FuncFormatter(_with_thousands_separator))

# Leave 10% headroom above the tallest bar for its annotation.
ax.set_ylim(0, np.max(counts[dataset]) * 1.1)

fig.tight_layout()
fig.savefig(f'results/number_of_samples_conceptnet_{dataset}.pdf', bbox_inches='tight', format='pdf')
plt.show()


In [None]:
# Training regime of the models plotted next; it is interpolated into the
# prediction path and the output file name. Re-assigns the same value as the
# configuration cell so this section can be switched independently.
training_type = 'prompt_tuning'

In [None]:
# Models to plot, mapped to their legend labels. Comment/uncomment entries to
# change the selection.
model_name_dict = {
    'bert-base-uncased': 'BERT$_{base}$',
    'bert-large-uncased': 'BERT$_{large}$',
    # 'albert-base-v1': 'ALBERT1$_{base}$',
    # 'albert-large-v1': 'ALBERT1$_{large}$',
    # 'albert-xlarge-v1': 'ALBERT1$_{xlarge}$',
    # 'albert-base-v2': 'ALBERT2$_{base}$',
    # 'albert-large-v2': 'ALBERT2$_{large}$',
    # 'albert-xlarge-v2': 'ALBERT2$_{xlarge}$',
    # 'roberta-base': 'RoBERTa$_{base}$',
    # 'roberta-large': 'RoBERTa$_{large}$',
    # 'gpt-neo-125m': 'GPT-Neo 125M',
    # 'gpt-neo-1.3B': 'GPT-Neo 1.3B',
    # 'gpt-neo-2.7B': 'GPT-Neo 2.7B',
    'gpt-j-6b': 'GPT-J 6B',
    # 'gpt-3.5-turbo-0125': 'ChatGPT-3.5',
    # 'gpt-4-0125-preview': 'ChatGPT-4'
}

# (line colour, marker) per model; only these four models have a style assigned.
line_styles = {
    'bert-base-uncased': ('tab:blue', 'o'),
    'bert-large-uncased': ('tab:green', '^'),
    'gpt-neo-125m': ('tab:orange', 's'),
    'gpt-j-6b': ('tab:red', 'D'),
}
colors = {name: color for name, (color, _marker) in line_styles.items()}
markers = {name: marker for name, (_color, marker) in line_styles.items()}

# Common multiplier applied to every matplotlib text size.
scale_factor = 1.5

# Unscaled point sizes, keyed by the rcParams entry they control.
base_font_sizes = {
    'font.size': 12,
    'axes.labelsize': 14,    # x and y axis labels
    'axes.titlesize': 16,    # axes title
    'xtick.labelsize': 12,   # x tick labels
    'ytick.labelsize': 12,   # y tick labels
    'legend.fontsize': 12,   # legend entries
    'figure.titlesize': 18,  # suptitle
}
plt.rcParams.update({key: size * scale_factor for key, size in base_font_sizes.items()})

# Decade edges at which the x tick labels are placed.
x_tick_labels = [1, 10, 100, 1000, 10000, 100000, 1000000]
# Geometric midpoint of each pair of consecutive edges, i.e. the centre of each
# decade on a log axis; one data point per frequency section is drawn there.
# The products reach 1e11, so compute in float64: an integer array would
# overflow wherever NumPy's default integer is 32-bit.
tick_positions = np.asarray(x_tick_labels, dtype=np.float64)
x_values = np.sqrt(tick_positions[:-1] * tick_positions[1:])

# Single figure/axes pair (10 x 6 inches) onto which the loop below draws one
# Hits@100 curve per model.
fig, ax1 = plt.subplots(figsize=(10, 6))

# For every selected model: bucket its test predictions by the joint
# subject/object frequency in the pretraining corpus, average the hit scores
# per bucket, and draw the Hits@100 curve on ax1.
for model_name in model_name_dict.keys():
    print('='*30)
    print('='*30)
    print('Model:', model_name)

    # GPT-family models are paired with the Pile statistics, every other model
    # with the BERT pretraining-data statistics.
    coo_matrix = pile_coo_matrix if 'gpt' in model_name else bert_coo_matrix

    # hits@10 / hits@100 are only read from predictions of non-OpenAI-API models.
    openai_api = 'gpt-3.5-turbo' in model_name or 'gpt-4-0125' in model_name

    # section -> list of per-example scores (replaced by (mean, std) further down)
    results_hits_1, results_hits_10, results_hits_100 = defaultdict(list), defaultdict(list), defaultdict(list)
    # relation id -> section -> list of per-example scores
    rel_results_hits_1, rel_results_hits_10, rel_results_hits_100 = defaultdict(dict), defaultdict(dict), defaultdict(dict)

    pred_path = f'../../../results/{dataset_name}/{model_name}_{dataset_name}_{training_type}/pred_{dataset_name}_{dataset_type}.jsonl'
    # A missing or unreadable file surfaces as the original error (e.g.
    # FileNotFoundError) and the reader is closed when the block exits.
    with jsonlines.open(pred_path) as reader:
        for pred in tqdm(reader.iter()):
            uid = pred['uid']
            subj = uid_subj_map[uid]
            rel = uid_rel_map[uid]
            obj = uid_obj_map[uid]
            subj = ' '.join(text_normalization_without_lemmatization(subj))
            obj = ' '.join(text_normalization_without_lemmatization(obj))

            subj_obj_count = coo_matrix.coo_count(subj, obj)

            # skip if the count is -1 (unknown)
            if subj_obj_count < 0:
                continue

            section = frequency_to_section(subj_obj_count)

            results_hits_1[section].append(pred['hits@1_remove_stopwords'])
            if not openai_api:
                results_hits_10[section].append(pred['hits@10_remove_stopwords'])
                results_hits_100[section].append(pred['hits@100_remove_stopwords'])

            # Per-relation scores are collected for optional inspection; they
            # are not used by the plot below.
            if section not in rel_results_hits_1[rel]:
                rel_results_hits_1[rel][section] = []
                rel_results_hits_10[rel][section] = []
                rel_results_hits_100[rel][section] = []
            rel_results_hits_1[rel][section].append(pred['hits@1_remove_stopwords'])
            if not openai_api:
                rel_results_hits_10[rel][section].append(pred['hits@10_remove_stopwords'])
                rel_results_hits_100[rel][section].append(pred['hits@100_remove_stopwords'])

    # np.digitize over `bins` yields sections 1..len(bins) for non-negative counts.
    sections = range(1, len(bins)+1)
    num_samples = {section: len(results_hits_1[section]) for section in sections}

    # Collapse each bucket to (mean, std). Empty buckets become (nan, nan)
    # explicitly, which leaves a gap in the curve without triggering NumPy's
    # "mean of empty slice" warning.
    for section in sections:
        for bucket in (results_hits_1, results_hits_10, results_hits_100):
            scores = bucket[section]
            bucket[section] = (np.mean(scores), np.std(scores)) if scores else (np.nan, np.nan)

    # Human-readable "mean +- std" summary per section (hits@1 first, then hits@100).
    result = {}
    for metric, bucket in (('hits@1', results_hits_1), ('hits@100', results_hits_100)):
        for section in sections:
            result[f'{metric}_remove_stopwords_section_{frequency_section_to_string(section)}'] = '%.2f +- %.2f' % bucket[section]

    print(num_samples)
    # print(json.dumps(result, indent=4))

    hits_100_mean = [results_hits_100[section][0] for section in sections]

    # Plotting line plots for Hits@100. `colors` / `markers` cover fewer models
    # than can be enabled in model_name_dict, so fall back to matplotlib's
    # colour cycle and a round marker instead of raising KeyError.
    line_kwargs = {'linestyle': '-', 'label': model_name_dict[model_name]}
    if model_name in colors:
        line_kwargs['color'] = colors[model_name]
    line_kwargs['marker'] = markers.get(model_name, 'o')
    ax1.plot(x_values, hits_100_mean, **line_kwargs)
    
# Logarithmic x-axis with one labelled tick per decade edge and no minor ticks.
ax1.set_xscale('log')
ax1.set_xticks(x_tick_labels)
ax1.set_xticklabels(['$10^%d$' % exponent for exponent in range(len(x_tick_labels))])
ax1.tick_params(axis='x', which='minor', length=0)

# Axis labels; Hits@100 is a rate, so the y-axis is fixed to [0, 1].
ax1.set_xlabel('Joint frequency of subject and object')
ax1.set_ylabel('Hits@100', color='black')
ax1.set_ylim(0, 1)

# One legend entry per plotted model.
ax1.legend()

# Save as PDF next to the other result figures, then display.
filename = f'results/{dataset_name}_{dataset_type}_{training_type}_hits@100_against_jointprob.pdf'
fig.tight_layout()  # Adjust layout to fit all labels
fig.savefig(filename, format='pdf')
plt.show()

In [None]:
# Training regime of the model shown in the heatmap below; it is interpolated
# into the prediction path and the output file name.
training_type = 'prompt_tuning'

In [None]:
import seaborn as sns

# Model(s) to draw in the heatmap, mapped to display labels. This replaces the
# selection made for the line plot. The loop below draws every enabled model
# onto the same figure and the output file name uses the last model only, so
# enable one entry at a time.
model_name_dict = {
    'bert-base-uncased': 'BERT$_{base}$',
    # 'bert-large-uncased': 'BERT$_{large}$',
    # 'albert-base-v1': 'ALBERT1$_{base}$',
    # 'albert-large-v1': 'ALBERT1$_{large}$',
    # 'albert-xlarge-v1': 'ALBERT1$_{xlarge}$',
    # 'albert-base-v2': 'ALBERT2$_{base}$',
    # 'albert-large-v2': 'ALBERT2$_{large}$',
    # 'albert-xlarge-v2': 'ALBERT2$_{xlarge}$',
    # 'roberta-base': 'RoBERTa$_{base}$',
    # 'roberta-large': 'RoBERTa$_{large}$',
    # 'gpt-neo-125m': 'GPT-Neo 125M',
    # 'gpt-neo-1.3B': 'GPT-Neo 1.3B',
    # 'gpt-neo-2.7B': 'GPT-Neo 2.7B',
    # 'gpt-j-6b': 'GPT-J 6B',
    # 'gpt-3.5-turbo-0125': 'ChatGPT-3.5',
    # 'gpt-4-0125-preview': 'ChatGPT-4'
}

# Common multiplier applied to every matplotlib text size.
scale_factor = 1.5

# Unscaled point sizes, keyed by the rcParams entry they control.
base_font_sizes = {
    'font.size': 12,
    'axes.labelsize': 14,    # x and y axis labels
    'axes.titlesize': 16,    # axes title
    'xtick.labelsize': 12,   # x tick labels
    'ytick.labelsize': 12,   # y tick labels
    'legend.fontsize': 12,   # legend entries
    'figure.titlesize': 18,  # suptitle
}
plt.rcParams.update({key: size * scale_factor for key, size in base_font_sizes.items()})

# Mathtext tick labels for the heatmap axes: 10^0..10^6 along the joint
# frequency axis and 10^1..10^6 along the subject frequency axis.
joint_freq_bins = ['$10^%d$' % exponent for exponent in range(0, 7)]
subject_freq_bins = ['$10^%d$' % exponent for exponent in range(1, 7)]

# New 10 x 8 inch figure; it becomes the current figure that the heatmap in
# the loop below is drawn into.
plt.figure(figsize=(10, 8))

# For every selected model: bucket its test predictions by (joint frequency,
# subject frequency), average the hit scores per bucket, and draw the Hits@100
# means as a heatmap on the current figure.
for model_name in model_name_dict.keys():
    print('='*30)
    print('='*30)
    print('Model:', model_name)

    # GPT-family models are paired with the Pile statistics, every other model
    # with the BERT pretraining-data statistics.
    coo_matrix = pile_coo_matrix if 'gpt' in model_name else bert_coo_matrix

    # hits@10 / hits@100 are only read from predictions of non-OpenAI-API models.
    openai_api = 'gpt-3.5-turbo' in model_name or 'gpt-4-0125' in model_name

    # '<joint section>_<subject section>' -> list of per-example scores
    # (replaced by (mean, std) further down)
    results_hits_1, results_hits_10, results_hits_100 = defaultdict(list), defaultdict(list), defaultdict(list)
    # relation id -> section key -> list of per-example scores
    rel_results_hits_1, rel_results_hits_10, rel_results_hits_100 = defaultdict(dict), defaultdict(dict), defaultdict(dict)

    pred_path = f'../../../results/{dataset_name}/{model_name}_{dataset_name}_{training_type}/pred_{dataset_name}_{dataset_type}.jsonl'
    # A missing or unreadable file surfaces as the original error (e.g.
    # FileNotFoundError) and the reader is closed when the block exits.
    with jsonlines.open(pred_path) as reader:
        for pred in tqdm(reader.iter()):
            uid = pred['uid']
            subj = uid_subj_map[uid]
            rel = uid_rel_map[uid]
            obj = uid_obj_map[uid]
            subj = ' '.join(text_normalization_without_lemmatization(subj))
            obj = ' '.join(text_normalization_without_lemmatization(obj))

            subj_count = coo_matrix.count(subj)
            subj_obj_count = coo_matrix.coo_count(subj, obj)

            # skip if the count is -1 (unknown)
            if subj_obj_count < 0:
                continue

            joint_section = frequency_to_section(subj_obj_count)
            subj_section = frequency_to_section(subj_count)
            section = f'{joint_section}_{subj_section}'

            results_hits_1[section].append(pred['hits@1_remove_stopwords'])
            if not openai_api:
                results_hits_10[section].append(pred['hits@10_remove_stopwords'])
                results_hits_100[section].append(pred['hits@100_remove_stopwords'])

            # Per-relation scores are collected for optional inspection; they
            # are not used by the heatmap below.
            if section not in rel_results_hits_1[rel]:
                rel_results_hits_1[rel][section] = []
                rel_results_hits_10[rel][section] = []
                rel_results_hits_100[rel][section] = []
            rel_results_hits_1[rel][section].append(pred['hits@1_remove_stopwords'])
            if not openai_api:
                rel_results_hits_10[rel][section].append(pred['hits@10_remove_stopwords'])
                rel_results_hits_100[rel][section].append(pred['hits@100_remove_stopwords'])

    # np.digitize over `bins` yields sections 1..len(bins) for non-negative counts.
    joint_sections = range(1, len(bins)+1)
    subj_sections = range(1, len(bins)+1)
    grid_sections = [
        f'{joint_section}_{subj_section}'
        for joint_section in joint_sections
        for subj_section in subj_sections
    ]
    num_samples = {section: len(results_hits_1[section]) for section in grid_sections}

    # Collapse each grid cell to (mean, std). Many cells are empty; they become
    # (nan, nan) explicitly instead of going through np.mean([]), which would
    # emit a "mean of empty slice" warning per cell.
    for section in grid_sections:
        for bucket in (results_hits_1, results_hits_10, results_hits_100):
            scores = bucket[section]
            bucket[section] = (np.mean(scores), np.std(scores)) if scores else (np.nan, np.nan)

    # Human-readable "mean +- std" summary per cell (hits@1 first, then hits@100).
    result = {}
    for metric, bucket in (('hits@1', results_hits_1), ('hits@100', results_hits_100)):
        for section in grid_sections:
            result[f'{metric}_remove_stopwords_section_{section}'] = '%.2f +- %.2f' % bucket[section]

    print(num_samples)
    # print(json.dumps(result, indent=4))

    # Rows index the subject section, columns the joint section.
    hits_100_mean = [[results_hits_100[f'{joint_section}_{subj_section}'][0] for joint_section in joint_sections] for subj_section in subj_sections]
    data = np.array(hits_100_mean)

    # Hide the cells where the joint section exceeds the subject section. The
    # mask is built as "upper triangle visible" and rotated so that it lines up
    # with the vertically flipped data (largest subject section in the top row).
    mask = np.ones_like(data.T, dtype='bool')
    mask[np.triu_indices_from(mask)] = False
    mask = np.rot90(mask, 1)

    data = np.flipud(data)

    ax = sns.heatmap(data, mask=mask, annot=True, fmt=".2f", linewidth=0.5, cmap='Blues',
                     cbar_kws={'label': 'Hits@100'})
    ax.set_facecolor("white")
    
# Label the cell boundaries of the current (heatmap) axes: integer positions
# are the edges between cells. The y labels are reversed because the data was
# flipped so that the largest subject section sits in the top row.
heatmap_ax = plt.gca()
heatmap_ax.set_xticks(range(len(joint_freq_bins)))
heatmap_ax.set_xticklabels(joint_freq_bins, rotation=0, ha='right')
heatmap_ax.set_yticks(range(len(subject_freq_bins)))
heatmap_ax.set_yticklabels(subject_freq_bins[::-1], rotation=0)

heatmap_ax.set_xlabel('Joint frequency of subject and object')
heatmap_ax.set_ylabel('Subject frequency')

# Save as PDF, then display. `model_name` is the loop variable left over from
# the cell's for-loop, i.e. the last model that was processed.
filename = f'results/{dataset_name}_{dataset_type}_{model_name}_{training_type}_hits@100_against_condprob.pdf'
plt.tight_layout()  # Adjust layout to fit all labels
plt.savefig(filename, format='pdf')
plt.show()