In [38]:
import json
import re

def read_jsonl(file_name):
    with open(file_name, "r") as file:
        tokens = []
        for line in file:
            tokens += [json.loads(line)]

    return tokens

def load_data(file_name):
    data = {}
    with open(file_name, "r") as file:
        for line in file:
            aux_dict = json.loads(line)
            dict_result = {
                "id": aux_dict["id"],
                "label": aux_dict["result"]["label"],
                "prob": aux_dict["result"]["prob"],
                "all_prob": aux_dict["result"]["all_prob"],
            }
            data[dict_result["id"]] = dict_result
            
    return data

models = [
    "cross_encoder_nli_distilroberta_base",
    "cross_encoder_nli_MiniLM2_L6_H768",
    "ynie_bart_large_snli_mnli_fever_anli_R1_R2_R3_nli"
]

data_paths = [
    "../data/multinli_1.0/multinli_1.0_dev_mismatched.jsonl",
    "../data/multinli_1.0/multinli_1.0_dev_matched.jsonl",
    "../data/snli_1.0/snli_1.0_test.jsonl",
]

MODEL_MAP = {
    "cross_encoder_nli_distilroberta_base": "roberta",
    "cross_encoder_nli_MiniLM2_L6_H768": "minilm",
    "ynie_bart_large_snli_mnli_fever_anli_R1_R2_R3_nli": "bart"
}

In [39]:
undertrained_presence = {
    "dataset": [],
    "model": [],
    "sentence_idx": [],
    "shared_tokens": [],
    "sentence": [],
    "undertrained_tokens": [],
    "original_label": [],
    "predicted_label": []
}

for data_path in data_paths:
    data = {
        "label": [],
        "sent1": [],
        "sent2": [],
        "sent_combined": []
    }
    dataset_name = data_path.split("/")[-1][:-6]
    with open(data_path, "r") as file:
        for line in file:
            aux_dict = json.loads(line)
            data["label"] += [aux_dict["gold_label"]]
            data["sent1"] += [aux_dict["sentence1"]] # premise
            data["sent2"] += [aux_dict["sentence2"]] # hypothesis
            data["sent_combined"] += [aux_dict["sentence1"] + " " + aux_dict["sentence2"]]
    results_file = f"../results/{MODEL_MAP[model]}/{dataset_name}-standard-{MODEL_MAP[model]}.jsonl"
    predicted_results = load_data(results_file)
    for model in models:
        file_name = f"undertrained_{model}.jsonl"    
        tokens = read_jsonl(file_name)
        tokens = set([tok["decoded"] for tok in tokens])
        for idx, (sentence, original_label) in enumerate(zip(data["sent_combined"], data["label"])):
            sentence_text = sentence
            sentence = re.split(r"( ?[\w]+)([!\"\#$%&\'()*+,-\.\/:;<=>?@[\\\]^_`{|}~])?", sentence)
            sentence = set([tok.strip() for tok in sentence if tok not in  ("", " ", None)])
            if len(sentence & tokens) > 0:
                undertrained_presence["dataset"] += [dataset_name]
                undertrained_presence["sentence"] += [sentence_text]
                undertrained_presence["undertrained_tokens"] += [tokens]
                undertrained_presence["shared_tokens"] += [sentence & tokens]       
                undertrained_presence["sentence_idx"] += [idx] 
                undertrained_presence["model"] += [model] 
                undertrained_presence["original_label"] += [original_label]
                undertrained_presence["predicted_label"] += [predicted_results[idx]["label"]]
            

In [None]:
import pandas as pd

df = pd.DataFrame(undertrained_presence)

In [44]:
sum(df["original_label"] == df["predicted_label"])/len(df)

0.8322851153039832

In [49]:
for model in MODEL_MAP.keys():
    for label in df["predicted_label"].unique():
        aux_df = df[(df["model"] == model) & (df["original_label"] == label)]
        print(f"{MODEL_MAP[model]} for {label} accuracy={sum(aux_df['original_label'] == aux_df['predicted_label'])/len(aux_df)}")
    print("---------------")

roberta for contradiction accuracy=0.8235294117647058
roberta for neutral accuracy=0.8476190476190476
roberta for entailment accuracy=0.8715596330275229
---------------
minilm for contradiction accuracy=0.8648648648648649
minilm for neutral accuracy=0.926829268292683
minilm for entailment accuracy=0.85
---------------
bart for contradiction accuracy=0.7777777777777778
bart for neutral accuracy=0.9090909090909091
bart for entailment accuracy=0.7272727272727273
---------------
