In [None]:
import os
from pathlib import Path
import pandas as pd
from collections import defaultdict
from utils.get_processed_dataset import get_processed_dataset
from utils.utils import read_jsonl, get_major_entities
from omegaconf import OmegaConf
from configs.config import PRONOUNS_GROUPS, PLURAL_PRONOUNS, dataset_yaml, selected_keys, entity_gender_metadata, pronoun_dialogue_metadata
from configs.config_gen import NAME_TO_PREFIX
from tqdm.auto import tqdm
import hydra
from utils.qa_utils import write_qa_to_jsonl, get_mention_info, add_copelands_count
import dtale
import jsonlines
import matplotlib.pyplot as plt
import seaborn as sns
from collections import Counter
from hydra.core.global_hydra import GlobalHydra
import regex as re

In [None]:
# Make the project root the working directory so the relative paths used below
# ("analysis_data/...", "datasets.yaml", "data/qas/...") resolve.
try:
    # Running as a script/module: the root is one level above this file's directory.
    _project_root = Path(__file__).resolve().parents[1]
except NameError:
    # Interactive kernels (Jupyter/IPython) do not define __file__. Walk up from the
    # current directory to the first ancestor that holds datasets.yaml (which is
    # loaded later from the working directory); stay put if none is found, so that
    # re-running this cell never climbs further.
    _cwd = Path.cwd()
    _project_root = next(
        (p for p in (_cwd, *_cwd.parents) if (p / "datasets.yaml").is_file()),
        _cwd,
    )
os.chdir(_project_root)

In [48]:
def wrap_df_to_heatmap(df, title):
    """Draw one annotated heatmap per row of ``df``.

    Every column other than ``"model_name"`` must be named ``"<from> - <to>"``;
    its value in a row is written to cell ``[<from>, <to>]`` of that row's
    square matrix. Cells with no corresponding column stay 0.

    Args:
        df: DataFrame with a ``"model_name"`` column plus the pair columns.
        title: Prefix of each figure title; the row's model name is appended.

    Raises:
        ValueError: If ``df`` has no ``"model_name"`` column, or (while a row is
            being processed) a pair column does not split into exactly two labels.
    """
    if "model_name" not in df.columns:
        raise ValueError('df must contain a "model_name" column')

    # Split every pair column once up front instead of once per row.
    pair_parts = {col: col.split(" - ") for col in df.columns if col != "model_name"}

    # Take labels from BOTH sides of each pair. Using only the "from" side would
    # leave out labels that occur solely as a "to" label; assigning to such a
    # missing column through .at enlarges the frame with NaN cells and the matrix
    # stops being square.
    labels = sorted({label for parts in pair_parts.values() for label in parts[:2]})

    for _, row in df.iterrows():
        confusion_matrix = pd.DataFrame(0, index=labels, columns=labels)

        # Populate the matrix from the row data.
        for col, parts in pair_parts.items():
            option_from, option_to = parts  # ValueError unless exactly two labels
            confusion_matrix.at[option_from, option_to] = row[col]

        fig, ax = plt.subplots(figsize=(10, 7))
        sns.heatmap(confusion_matrix, annot=True, fmt="d", cmap="YlGnBu", ax=ax)
        ax.set_title(f"{title} - {row['model_name']}")
        plt.show()
        # One figure is created per row; close it so they do not accumulate.
        plt.close(fig)

In [None]:
## Find good examples?

In [49]:
# Locations of the analysis spreadsheets, relative to the current working directory.
_analysis_root = "analysis_data"
_summary_dir = f"{_analysis_root}/summary"
_gender_dir = f"{_analysis_root}/gender"

# Summary workbooks.
accuracy_path = f"{_summary_dir}/summary_accuracy.xlsx"
count_path = f"{_summary_dir}/summary_count.xlsx"

# Gender workbooks.
gender_info_path = f"{_gender_dir}/gender_info.xlsx"
gender_unnested_info_path = f"{_gender_dir}/gender_unnested_info.xlsx"

# Options and nested workbooks.
options_path = f"{_analysis_root}/options/options_info.xlsx"
nested_path = f"{_analysis_root}/nested/nested_info.xlsx"

In [50]:
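# One-off fix-up, kept commented out for reference: rename the "Unnamed: 0" column of
# the nested-info workbook to "model_name" and overwrite the file in place.
# df = pd.read_excel(nested_path)
# df.rename(columns={"Unnamed: 0": "model_name"}, inplace=True)
# df.to_excel(nested_path, index=False)

# print(pd.read_excel(nested_path))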

In [None]:
# Read the un-nested gender workbook and draw one heatmap per row of it.
gender_unnested_df = pd.read_excel(io=gender_unnested_info_path)
wrap_df_to_heatmap(df=gender_unnested_df, title="Gender -- Not Nested -- Info")

In [None]:
# Read the mention-info CSV and open it in D-Tale for interactive browsing.
mention_info_path = "analysis_data/info/llama3_instruct_mention_info.csv"
mention_info_df = pd.read_csv(filepath_or_buffer=mention_info_path)
# NOTE(review): subprocess=False presumably keeps D-Tale in the foreground — confirm
# it does not block the cells below on "Run All".
dtale.show(mention_info_df, host="localhost", subprocess=False)

In [None]:
def _random_accuracy(frame):
    """Return the mean of ``1 / num_options`` over the rows of ``frame``.

    This is the expected accuracy of choosing one option uniformly at random
    for every row.

    Args:
        frame: DataFrame with a numeric ``"num_options"`` column.

    Returns:
        The mean as a float; NaN when ``frame`` has no rows (instead of a
        division by zero) or when any ``num_options`` value is missing.
    """
    if frame.empty:
        return float("nan")
    # skipna=False: a missing num_options makes the result NaN rather than
    # being silently dropped from the average.
    return (1 / frame["num_options"]).mean(skipna=False)


# Per-category subsets of the mention table.
mention_info_nom_df = mention_info_df[mention_info_df["category"] == "NOM"]
mention_info_pron_df = mention_info_df[mention_info_df["category"] == "PRON"]

for _label, _frame in (
    ("Random Accuracy: ", mention_info_df),
    ("Random Accuracy for Nominals: ", mention_info_nom_df),
    ("Random Accuracy for Pronouns: ", mention_info_pron_df),
):
    # `random_accuracy` is rebound on every pass, so after the loop it holds
    # the pronoun figure.
    random_accuracy = _random_accuracy(_frame)
    print(_label, random_accuracy)

In [None]:
## Number of documents
dev_qa_path = "data/qas/data/qas_dev.jsonl"
test_qa_path = "data/qas/data/qas_test.jsonl"
# NOTE(review): this rebinds `dataset_yaml`, which is also imported from
# configs.config at the top of the file — confirm the hard-coded path is intended.
dataset_yaml = "datasets.yaml"
dataset = OmegaConf.load(dataset_yaml)
litbank_train_path = dataset.litbank.train_file
fantasy_train_path = dataset.fantasy.train_file


def _read_doc_keys(path):
    """Return the ``doc_key`` of every record in the JSONL file at ``path``, in file order."""
    with jsonlines.open(path) as reader:
        return [obj["doc_key"] for obj in reader]


litbank_train = _read_doc_keys(litbank_train_path)
fantasy_train = _read_doc_keys(fantasy_train_path)

# The full QA records stay available as dev_qas / test_qas; doc_keys lists the
# dev keys first, then the test keys (duplicates included).
with jsonlines.open(dev_qa_path) as reader:
    dev_qas = list(reader)
with jsonlines.open(test_qa_path) as reader:
    test_qas = list(reader)
doc_keys = [qa["doc_key"] for qa in dev_qas + test_qas]

unique_doc_keys = set(doc_keys)
print("Number of documents: ", len(unique_doc_keys))

# Count the distinct QA documents that also occur in each train file. A set
# intersection gives the same counts as testing every key against the lists,
# without scanning a list once per key.
litbank_count = len(unique_doc_keys & set(litbank_train))
fantasy_count = len(unique_doc_keys & set(fantasy_train))
print("Litbank documents: ", litbank_count)
print("Fantasy documents: ", fantasy_count)