In [135]:
import os
import re
from itertools import islice

import pandas as pd
from sklearn.metrics import classification_report, accuracy_score

In [172]:
# Dataset name; it selects which results file (and generated-scripts folder) is used below.
dataset = "RTE_Quant"
# Repository root -- assumes the notebook is run from a direct subdirectory of it.
root_path = os.path.dirname(os.getcwd())
results_path = os.path.join(
    root_path,
    "data",
    "equate_labelled",
    "{}_gpt4.csv".format(dataset),
)

In [173]:
# Load the GPT-4 labelling results and preview the first rows.
# NOTE(review): the preview shows an "Unnamed: 0" column, which looks like a previously saved
# index; consider index_col=0 if downstream consumers of the cleaned CSV don't need it.
res = pd.read_csv(filepath_or_buffer=results_path)
res.head(n=5)

Unnamed: 0,sample_index,generated_label,error_message,golden_label,premise,hypothesis
0,89,entailment,,neutral,One of seven people killed in an explosion at ...,The teenage bomber behind a blast was one of t...
1,159,neutral,,neutral,The law once required settlers to clear 80 % o...,Before the late 1980 's some nations even gave...
2,108,neutral,,neutral,The deal comes some two weeks after a $ 16 bil...,The two firms sealed the deal despite what was...
3,118,neutral,,neutral,"General Dynamics , the maker of land combat sy...",General Dynamics loses $ 374m .
4,149,entailment,,entailment,Police in Rio de Janeiro arrested five men and...,"Millions of dollars of art were recovered , in..."


In [174]:
# (rows, columns) of the results as loaded from the CSV.
res.shape

(166, 6)

In [175]:
# How many rows repeat a sample_index that already appeared earlier in the file?
int(res.duplicated(subset=["sample_index"]).sum())

0

In [176]:
# Keep the first row for each sample_index, then renumber the index 0..n-1.
res = (
    res.drop_duplicates(subset="sample_index", keep="first")
    .reset_index(drop=True)
)
res.shape

(166, 6)

In [177]:
def clean_text(text: str) -> str:
    """Normalise a string for duplicate detection.

    Lower-cases the text, collapses every run of whitespace (spaces, tabs,
    newlines) into a single space and trims both ends.

    Args:
        text: the raw premise/hypothesis string (must be a ``str``).

    Returns:
        The normalised string.
    """
    # Newlines must become spaces, not be deleted: removing them first would
    # glue the words on either side of a line break together ("foo\nbar" -> "foobar").
    # `\s+` already matches newlines, so no separate replace is needed.
    return re.sub(r"\s+", " ", text.lower()).strip()

In [178]:
# Normalised copies of the text columns, used as the duplicate-detection key.
for col in ["premise", "hypothesis"]:
    res[f"clean_{col}"] = res[col].apply(clean_text)

# keep == 0 marks a correctly classified sample, 1 a misclassified one. Sorting on it
# (stably, so ties keep their original relative order) puts correct samples first, so
# drop_duplicates(keep="first") retains a correct copy whenever a duplicate pair exists.
res["keep"] = (res["golden_label"] != res["generated_label"]).astype(int)
res = (
    res.sort_values(by="keep", kind="stable")
    .drop_duplicates(subset=["clean_premise", "clean_hypothesis"], keep="first")
    .sort_index()  # restore the original row order so the saved CSV isn't shuffled by correctness
)
res.shape

(165, 9)

In [179]:
# Remove the temporary sort-key column before saving. `columns=` already targets the
# column axis (the previous `axis=0` was ignored and contradictory); errors="ignore"
# and reassignment instead of inplace make the cell safe to re-run.
res = res.drop(columns=["keep"], errors="ignore")

In [180]:
# Persist the de-duplicated results next to the raw file, with a "cleaned_" prefix.
res.to_csv(
    path_or_buf=os.path.join(root_path, "data", "equate_labelled", f"cleaned_{dataset}_gpt4.csv"),
    index=False,
)

Did any of the generated scripts raise an error?

In [124]:
# Number of rows whose generated script recorded an error message.
# NOTE(review): the outputs saved for this and the following cells (In [124]-[130],
# [111]-[113], [155]-[159]) predate In [172]-[180] and mention far more samples
# (e.g. 6929 supports, indices up to 7241) than the 166 rows loaded above --
# they look like leftovers from an earlier run; re-run before trusting them.
int(res["error_message"].notna().sum())

9

In [125]:
# sample_index of every row whose script raised an error.
list(res.loc[res["error_message"].notna(), "sample_index"].values)

[6672, 6536, 7241, 5630, 3724, 5544, 3550, 5840, 4937]

Filter out the samples whose scripts raised an error, if there are any.

In [126]:
# Keep only the rows whose script ran without recording an error message.
valid_res = res.loc[~res["error_message"].notna()]

In [127]:
# Class balance of the gold labels among valid rows.
valid_res.loc[:, "golden_label"].value_counts()

golden_label
entailment       2314
neutral          2313
contradiction    2302
Name: count, dtype: int64

In [128]:
# Distribution of the labels produced by GPT-4 among valid rows.
valid_res.loc[:, "generated_label"].value_counts()

generated_label
entailment       3330
neutral          1806
contradiction    1793
Name: count, dtype: int64

In [129]:
# Overall accuracy of the generated labels against the gold labels (positional: y_true, y_pred).
accuracy_score(valid_res["golden_label"], valid_res["generated_label"])

0.649155722326454

In [130]:
# Per-class precision / recall / F1 (positional: y_true, y_pred).
print(
    classification_report(
        valid_res["golden_label"],
        valid_res["generated_label"],
    )
)

               precision    recall  f1-score   support

contradiction       0.83      0.65      0.73      2302
   entailment       0.57      0.82      0.68      2314
      neutral       0.61      0.47      0.53      2313

     accuracy                           0.65      6929
    macro avg       0.67      0.65      0.65      6929
 weighted avg       0.67      0.65      0.65      6929



In [111]:
# The ten valid rows with the smallest sample_index.
valid_res.sort_values("sample_index").iloc[:10]

Unnamed: 0,sample_index,label,error_message,golden_label
134,0,neutral,,neutral
148,1,entailment,,neutral
8,2,entailment,,neutral
11,3,entailment,,entailment
140,4,neutral,,neutral
145,5,entailment,,neutral
5,6,entailment,,entailment
17,7,neutral,,neutral
39,8,entailment,,entailment
54,9,entailment,,entailment


In [112]:
# Distinct sample indices where the generated label disagrees with the gold label, ascending.
misclassified_samples_indices = sorted(
    valid_res.loc[
        valid_res["generated_label"] != valid_res["golden_label"], "sample_index"
    ].unique()
)
misclassified_samples_indices

[1,
 2,
 5,
 10,
 20,
 26,
 31,
 37,
 39,
 42,
 47,
 54,
 55,
 65,
 66,
 77,
 81,
 82,
 87,
 89,
 92,
 93,
 94,
 103,
 106,
 109,
 110,
 139,
 142,
 151,
 162]

Manual review notes on misclassified samples (sample indices), per dataset:

AWPNLI:
approximation issues: 58, 60, 500, 602, 614, 648, 686
ambiguity: 62, 82, 304, 305, 482
wrong label in EQUATE: 107, 106, 109, 138, 550, 674
RedditNLI:
not sure how to correct: 24
wrong label: 36
RTE_Quant:
wrong label: 26, 31
not sure how to correct: 142

In [113]:
# Number of distinct misclassified sample indices.
len(misclassified_samples_indices)

31

In [155]:
# Collect the sample indices whose generated script mentions a comparative quantifier
# ("more than" / "less than") in its first few lines, which hold the script's inputs.
scripts_path = os.path.join(root_path, "data", "generated", dataset, "script_with_cot_vars_first")
QUANTIFIER_PHRASES = ("more than", "less than")
N_INPUT_LINES = 3  # number of leading script lines inspected

sample_indices = []
for script_file in os.listdir(scripts_path):
    stem, ext = os.path.splitext(script_file)
    if ext != ".py":
        continue
    # Script names end in "_<sample_index>.py".
    idx = int(stem.split("_")[-1])
    # Read only the lines we inspect; explicit encoding avoids platform-dependent decoding.
    with open(os.path.join(scripts_path, script_file), "r", encoding="utf-8") as f:
        inputs = "".join(islice(f, N_INPUT_LINES))
    if any(phrase in inputs for phrase in QUANTIFIER_PHRASES):
        sample_indices.append(idx)
sorted(sample_indices)

[0,
 1,
 2,
 3,
 4,
 6,
 7,
 8,
 9,
 10,
 12,
 13,
 14,
 15,
 16,
 18,
 19,
 20,
 21,
 22,
 24,
 26,
 28,
 29,
 30,
 32,
 34,
 35,
 36,
 38,
 39,
 40,
 41,
 42,
 44,
 46,
 47,
 50,
 51,
 52,
 53,
 54,
 55,
 56,
 57,
 58,
 60,
 62,
 64,
 66,
 67,
 68,
 70,
 71,
 72,
 73,
 74,
 76,
 78,
 79,
 80,
 82,
 84,
 86,
 88,
 90,
 92,
 93,
 94,
 96,
 97,
 98,
 99,
 100,
 101,
 102,
 103,
 104,
 106,
 108,
 110,
 111,
 112,
 113,
 116,
 118,
 119,
 120,
 121,
 123,
 124,
 126,
 128,
 130,
 132,
 133,
 134,
 135,
 136,
 138,
 139,
 140,
 142,
 143,
 144,
 146,
 148,
 149,
 150,
 152,
 154,
 155,
 156,
 157,
 158,
 160,
 161,
 162,
 164,
 166,
 167,
 168,
 169,
 170,
 171,
 172,
 173,
 174,
 176,
 177,
 178,
 180,
 182,
 183,
 184,
 186,
 187,
 188,
 190,
 192,
 194,
 195,
 196,
 197,
 198,
 200,
 201,
 202,
 203,
 204,
 205,
 206,
 208,
 209,
 210,
 212,
 214,
 216,
 217,
 218,
 219,
 220,
 221,
 222,
 224,
 226,
 227,
 228,
 230,
 231,
 232,
 234,
 236,
 237,
 238,
 240,
 242,
 243,
 244,
 245,
 2

In [158]:
# Misclassified samples whose script inputs contain a comparative quantifier.
# A set gives O(1) membership checks instead of scanning the whole list per index;
# order follows misclassified_samples_indices.
quantifier_sample_set = set(sample_indices)
misclassified_with_quantifier = [
    idx for idx in misclassified_samples_indices if idx in quantifier_sample_set
]
len(misclassified_with_quantifier)

130

In [159]:
# Display the misclassified sample indices whose script inputs mention "more than"/"less than".
misclassified_with_quantifier

[1,
 7,
 13,
 15,
 21,
 29,
 35,
 39,
 47,
 55,
 57,
 67,
 68,
 73,
 93,
 96,
 97,
 99,
 101,
 103,
 111,
 113,
 119,
 123,
 124,
 126,
 133,
 135,
 143,
 149,
 155,
 161,
 171,
 173,
 177,
 183,
 190,
 195,
 197,
 203,
 205,
 209,
 217,
 221,
 227,
 237,
 243,
 249,
 259,
 267,
 269,
 271,
 273,
 281,
 283,
 287,
 295,
 311,
 321,
 325,
 327,
 331,
 335,
 337,
 340,
 347,
 368,
 371,
 373,
 383,
 385,
 389,
 392,
 393,
 405,
 415,
 427,
 431,
 439,
 451,
 452,
 473,
 477,
 489,
 491,
 495,
 499,
 501,
 505,
 506,
 509,
 511,
 525,
 527,
 533,
 549,
 577,
 579,
 585,
 589,
 595,
 598,
 601,
 602,
 605,
 611,
 613,
 623,
 629,
 641,
 647,
 649,
 653,
 661,
 663,
 669,
 680,
 683,
 685,
 691,
 693,
 699,
 1168,
 1483,
 2386,
 2727,
 3937,
 4319,
 4873,
 4945]