# Evaluate on contract nli

In [23]:
import json
import os
from datasets import Dataset, DatasetDict

In [24]:
ROOT_PATH = ".."

id2label = {0: "Entailment", 1: "Contradiction", 2: "NotMentioned"}
label2id = {v: k for k, v in id2label.items()}

def contract_nli_iterator(data):
    documents, labels = data['documents'], data['labels']
    for document in documents:
        id = document['id']
        file_name = document['file_name']
        text = document['text']
        spans = document['spans']
        annotation_sets = document['annotation_sets']
        document_type = document['document_type']
        url = document['url']
        for annotation_id, annotation_content in annotation_sets[0]['annotations'].items():
            hypothesis = labels[annotation_id]['hypothesis']
            choice = annotation_content['choice']
            yield {
                "id": id,
                "file_name": file_name,
                "text": text,
                "spans": spans,
                "document_type": document_type,
                "url": url,
                "hypothesis": hypothesis,
                "labels": label2id[choice],
            }            
base_filepath = os.path.join(ROOT_PATH, "ignored_dir/data/contract-nli")
train_filepath = os.path.join(base_filepath, "train.json")
validation_filepath = os.path.join(base_filepath, "dev.json")
test_filepath = os.path.join(base_filepath, "test.json")
with open(train_filepath) as f:
    train_data = json.load(f)
with open(validation_filepath) as f:
    validation_data = json.load(f)
with open(test_filepath) as f:
    test_data = json.load(f)
data = {
    "train": Dataset.from_generator(lambda: contract_nli_iterator(train_data)),
    "validation": Dataset.from_generator(lambda: contract_nli_iterator(validation_data)),
    "test": Dataset.from_generator(lambda: contract_nli_iterator(test_data)),
}
contract_nli_dataset = DatasetDict(data)

### Evaluate vanilla chunking on contract nli

In [25]:
contract_nli_dataset_validation = contract_nli_dataset['validation']

In [26]:
from llama_index.core import Document
from llama_index.core.node_parser import SentenceSplitter
from transformers import AutoTokenizer

In [27]:
for e in contract_nli_dataset_validation:
    print(e.keys())
    print('id', e['id'])
    print('file-name', e['file_name'])
    print('text', 'omitted') # e['text'])
    print('spans', e['spans'])
    print('document_type', e['document_type'])
    print('url', e['url'])
    print('hypothesis', e['hypothesis'])
    print('labels', e['labels'])
    break

dict_keys(['id', 'file_name', 'text', 'spans', 'document_type', 'url', 'hypothesis', 'labels'])
id 3
file-name 09-24-2019-04-25-05-3914910473.pdf
text omitted
spans [[0, 14], [15, 67], [68, 127], [128, 282], [283, 362], [362, 455], [455, 492], [493, 559], [559, 653], [653, 690], [691, 793], [794, 801], [802, 1278], [1278, 1358], [1359, 1383], [1384, 1416], [1417, 1490], [1491, 1907], [1908, 2178], [2178, 2475], [2475, 2638], [2639, 2795], [2796, 2927], [2928, 3074], [3075, 3088], [3089, 3202], [3203, 3344], [3345, 3544], [3545, 3725], [3726, 3932], [3933, 3998], [3999, 4125], [4126, 4171], [4172, 4214], [4215, 4311], [4312, 4457], [4458, 4569], [4570, 4738], [4739, 4867], [4868, 4965], [4966, 5059], [5060, 5092], [5093, 5298], [5298, 5743], [5743, 5857], [5858, 5895], [5896, 6027], [6027, 6271], [6271, 6610], [6610, 6749], [6750, 6770], [6771, 6949], [6949, 7020], [7020, 7162], [7163, 7179], [7180, 7351], [7351, 7556], [7557, 7566], [7567, 7939], [7940, 7954], [7955, 8100], [8100, 8222

In [28]:

tmp = None
for e in contract_nli_dataset_validation:
    tmp = e
    break

facebook_bart_large_mnli_autotokenizer = AutoTokenizer.from_pretrained('facebook/bart-large-mnli')
documents = [Document(text=e['text'])]
parser = SentenceSplitter(chunk_size=200, chunk_overlap=50, tokenizer=lambda x: facebook_bart_large_mnli_autotokenizer(x)['input_ids'])
nodes = parser.get_nodes_from_documents(documents, show_progress=True)

Parsing nodes:   0%|          | 0/1 [00:00<?, ?it/s]

Token indices sequence length is longer than the specified maximum sequence length for this model (2270 > 1024). Running this sequence through the model will result in indexing errors


In [29]:
for e in nodes:
    print(e.text)
    break

OISAIR PROJECT
TWO-WAY CONFIDENTIALITY AND NON-DISCLOSURE AGREEMENT
(TO BE SIGNED ELECTRONICALLY THROUGH THE INNOVAIR PLATFORM)
This Confidentiality and Non-Disclosure Agreement (hereinafter referred to as the “Agreement”) dated ………………………. (“Effective Date”) is made by and between:
1) <Research institution name> with registered offices located in ……………………………, Tax registration No ………, represented by ………………………………….., in the legal capacity as …………………….. Hereinafter referred to as “………………..”
2) <Company name> with registered offices located in ………………………….. Tax registration No.


In [30]:
import torch
from transformers import AutoModelForSequenceClassification

In [31]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"using device {device}")
facebook_bart_large_mnli_automodel_for_sequence_classification = AutoModelForSequenceClassification.from_pretrained('facebook/bart-large-mnli').to(device)

using device cuda


In [32]:
# prop = facebook_bart_large_mnli_autotokenizer.batch_encode_plus([[law_node.text, contract_node.text] for law_node in law_nodes_chunk], return_tensors='pt', padding='longest', truncation='only_first')['input_ids']


tokens = facebook_bart_large_mnli_autotokenizer.batch_encode_plus([[nodes[0].text, "haha"]], return_tensors='pt', padding='longest', truncation='only_first')['input_ids']
tokens = torch.tensor(tokens)
ret = facebook_bart_large_mnli_automodel_for_sequence_classification(tokens.to(device))

  tokens = torch.tensor(tokens)


In [33]:
print(ret.logits)

tensor([[ 1.0773,  0.5191, -1.5044]], device='cuda:0',
       grad_fn=<AddmmBackward0>)


In [34]:
import torch.nn as nn

In [35]:
test = torch.tensor([[1, 2], [-1, 2]], dtype=torch.float64)
nn.Softmax(dim=1)(test)
torch.mean(test, dim=0)

tensor([0., 2.], dtype=torch.float64)

In [36]:
from tqdm import tqdm

In [37]:
bart_large_mnli_id2labal = {0: "contradiction", 1: "neutral", 2: "entailment"}
bart_large_mnli_label2id = {v: k for k, v in bart_large_mnli_id2labal.items()}

In [38]:
def bart_label_to_contract_nli_label(id):
    if id == 0: return 1
    elif id == 1: return 2
    elif id == 2: return 0
    else: return None

In [39]:
'''
parser = SentenceSplitter(chunk_size=300, chunk_overlap=50, tokenizer=lambda x: facebook_bart_large_mnli_autotokenizer(x)['input_ids'])
#contract_nodes = parser.get_nodes_from_documents(contract_documents, show_progress=True)

chunk_record = dict()
softmax_f = nn.Softmax(dim=1)
averaged_probs_record = torch.tensor([[0.0, 0.0, 0.0] for _ in range(len(contract_nli_dataset_validation))])
accuracy_record = torch.tensor([0.0 for _ in range(len(contract_nli_dataset_validation))])
with torch.no_grad():
    for i, e in tqdm(enumerate(contract_nli_dataset_validation), total=len(contract_nli_dataset_validation)):
        premise = e['text']
        hypothesis = e['hypothesis']
        premise_chunks = chunk_record.get(e['file_name'], None)
        if premise_chunks is None:
            premise_chunks = parser.get_nodes_from_documents([Document(text=premise)])
            premise_chunks = [e.text for e in premise_chunks]
            chunk_record[e['file_name']] = premise_chunks
        e_tokens = facebook_bart_large_mnli_autotokenizer.batch_encode_plus([[premise_chunk, hypothesis] for premise_chunk in premise_chunks], return_tensors='pt', padding='longest', truncation='only_first')
        e_tokens = e_tokens['input_ids'].detach().cpu()
        logits = facebook_bart_large_mnli_automodel_for_sequence_classification(e_tokens.to(device)).logits
        averaged_probs = torch.mean(softmax_f(logits), dim=0)
        averaged_probs_record[i] = averaged_probs
        ans = torch.argmax(averaged_probs)
        label = e['labels']
        accuracy_record[i] = 1 if label == bart_label_to_contract_nli_label(ans) else 0
print(torch.mean(accuracy_record))
'''

"\nparser = SentenceSplitter(chunk_size=300, chunk_overlap=50, tokenizer=lambda x: facebook_bart_large_mnli_autotokenizer(x)['input_ids'])\n#contract_nodes = parser.get_nodes_from_documents(contract_documents, show_progress=True)\n\nchunk_record = dict()\nsoftmax_f = nn.Softmax(dim=1)\naveraged_probs_record = torch.tensor([[0.0, 0.0, 0.0] for _ in range(len(contract_nli_dataset_validation))])\naccuracy_record = torch.tensor([0.0 for _ in range(len(contract_nli_dataset_validation))])\nwith torch.no_grad():\n    for i, e in tqdm(enumerate(contract_nli_dataset_validation), total=len(contract_nli_dataset_validation)):\n        premise = e['text']\n        hypothesis = e['hypothesis']\n        premise_chunks = chunk_record.get(e['file_name'], None)\n        if premise_chunks is None:\n            premise_chunks = parser.get_nodes_from_documents([Document(text=premise)])\n            premise_chunks = [e.text for e in premise_chunks]\n            chunk_record[e['file_name']] = premise_chunks\

In [40]:
"""
Results analysis: Turn into two categories.
"""

'''
# bart_large_mnli_id2labal = {0: "contradiction", 1: "neutral", 2: "entailment"}
# id2label = {0: "Entailment", 1: "Contradiction", 2: "NotMentioned"}

tmp = torch.tensor([0.0 for _ in range(len(contract_nli_dataset_validation))])
contradiction_cnt = 0
label_sum = dict()
for i, e in enumerate(contract_nli_dataset_validation):
    averaged_probs = averaged_probs_record[i]
    ans = 1 if averaged_probs[0] > 0.9 else 0
    # ans = torch.argmax(averaged_probs)
    label = e['labels']
    label_sum[id2label[label]] = label_sum.get(id2label[label], 0) + 1
    # ans = bart_label_to_contract_nli_label(ans)
    ans = 1 if ans == 1 else 0
    label = 1 if label == 1 else 0
    tmp[i] = 1 if ans == label else 0
    if label == 1:
        contradiction_cnt += 1
print(torch.mean(tmp))
print(contradiction_cnt, len(contract_nli_dataset_validation))
print(label_sum)
'''

'\n# bart_large_mnli_id2labal = {0: "contradiction", 1: "neutral", 2: "entailment"}\n# id2label = {0: "Entailment", 1: "Contradiction", 2: "NotMentioned"}\n\ntmp = torch.tensor([0.0 for _ in range(len(contract_nli_dataset_validation))])\ncontradiction_cnt = 0\nlabel_sum = dict()\nfor i, e in enumerate(contract_nli_dataset_validation):\n    averaged_probs = averaged_probs_record[i]\n    ans = 1 if averaged_probs[0] > 0.9 else 0\n    # ans = torch.argmax(averaged_probs)\n    label = e[\'labels\']\n    label_sum[id2label[label]] = label_sum.get(id2label[label], 0) + 1\n    # ans = bart_label_to_contract_nli_label(ans)\n    ans = 1 if ans == 1 else 0\n    label = 1 if label == 1 else 0\n    tmp[i] = 1 if ans == label else 0\n    if label == 1:\n        contradiction_cnt += 1\nprint(torch.mean(tmp))\nprint(contradiction_cnt, len(contract_nli_dataset_validation))\nprint(label_sum)\n'

In [41]:
"""
Results analysis: Set threshold for neutral as well.
"""

'''

# bart_large_mnli_id2labal = {0: "contradiction", 1: "neutral", 2: "entailment"}
# id2label = {0: "Entailment", 1: "Contradiction", 2: "NotMentioned"}

tmp = torch.tensor([0.0 for _ in range(len(contract_nli_dataset_validation))])
contradiction_cnt = 0
label_sum = dict()
for i, e in enumerate(contract_nli_dataset_validation):
    averaged_probs = averaged_probs_record[i]
    ans = 1 if averaged_probs[0] > 0.9 else (0 if averaged_probs[1] > 0.9 else 2)
    # ans = torch.argmax(averaged_probs)
    label = e['labels']
    tmp[i] = 1 if ans == label else 0
print(torch.mean(tmp))
'''

'\n\n# bart_large_mnli_id2labal = {0: "contradiction", 1: "neutral", 2: "entailment"}\n# id2label = {0: "Entailment", 1: "Contradiction", 2: "NotMentioned"}\n\ntmp = torch.tensor([0.0 for _ in range(len(contract_nli_dataset_validation))])\ncontradiction_cnt = 0\nlabel_sum = dict()\nfor i, e in enumerate(contract_nli_dataset_validation):\n    averaged_probs = averaged_probs_record[i]\n    ans = 1 if averaged_probs[0] > 0.9 else (0 if averaged_probs[1] > 0.9 else 2)\n    # ans = torch.argmax(averaged_probs)\n    label = e[\'labels\']\n    tmp[i] = 1 if ans == label else 0\nprint(torch.mean(tmp))\n'

In [42]:
"""
Try on test set. Test set should be used for testing?
"""

'\nTry on test set. Test set should be used for testing?\n'

In [43]:
contract_nli_dataset_test = contract_nli_dataset['test']

In [44]:
"""
Try all data to see if results persist. 
"""
parser = SentenceSplitter(chunk_size=300, chunk_overlap=50, tokenizer=lambda x: facebook_bart_large_mnli_autotokenizer(x)['input_ids'])
#contract_nodes = parser.get_nodes_from_documents(contract_documents, show_progress=True)

softmax_f = nn.Softmax(dim=1)
chunk_record_test = dict()
averaged_probs_record_test = torch.tensor([[0.0, 0.0, 0.0] for _ in range(len(contract_nli_dataset_test))])
accuracy_record_test = torch.tensor([0.0 for _ in range(len(contract_nli_dataset_test))])
with torch.no_grad():
    for i, e in tqdm(enumerate(contract_nli_dataset_test), total=len(contract_nli_dataset_test)):
        premise = e['text']
        hypothesis = e['hypothesis']
        premise_chunks = chunk_record_test.get(e['file_name'], None)
        if premise_chunks is None:
            premise_chunks = parser.get_nodes_from_documents([Document(text=premise)])
            premise_chunks = [e.text for e in premise_chunks]
            chunk_record_test[e['file_name']] = premise_chunks
        e_tokens = facebook_bart_large_mnli_autotokenizer.batch_encode_plus([[premise_chunk, hypothesis] for premise_chunk in premise_chunks], return_tensors='pt', padding='longest', truncation='only_first')
        e_tokens = e_tokens['input_ids'].detach().cpu()
        logits = facebook_bart_large_mnli_automodel_for_sequence_classification(e_tokens.to(device)).logits
        averaged_probs = torch.mean(softmax_f(logits), dim=0)
        averaged_probs_record_test[i] = averaged_probs
        ans = torch.argmax(averaged_probs)
        label = e['labels']
        accuracy_record_test[i] = 1 if label == bart_label_to_contract_nli_label(ans) else 0
print(torch.mean(accuracy_record_test))

100%|██████████| 2091/2091 [11:06<00:00,  3.14it/s]

tensor(0.3128)





In [51]:
### Get confusion matrix

from sklearn.metrics import confusion_matrix

labels = []
preds = []

for i, e in enumerate(contract_nli_dataset_test):
    averaged_probs = averaged_probs_record_test[i]
    ans = torch.argmax(averaged_probs)
    ans = bart_label_to_contract_nli_label(ans)
    preds.append(ans)
    label = e['labels']
    labels.append(label)

conf_mat = confusion_matrix(labels, preds)
print(conf_mat)

[[181 341 446]
 [ 14 117  89]
 [ 80 467 356]]


In [55]:
### save the data
import pickle

save_dir_path = "../ignored_dir/results/vanilla_chunking_on_contract_nli"
if not os.path.exists(save_dir_path):
    os.mkdir(save_dir_path)
readme_path = os.path.join(save_dir_path, "README.txt")
with open(readme_path, "w") as f:
    f.write('{0: "Entailment", 1: "Contradiction", 2: "NotMentioned"}')
conf_mat_path = os.path.join(save_dir_path, "conf_mat.pkl")
conf_mat_filehandler = open(conf_mat_path, 'wb') 
pickle.dump(conf_mat, conf_mat_filehandler)
conf_mat_filehandler.close()

In [2]:
### Test 

import os
import pickle

save_dir_path = "../ignored_dir/results/vanilla_chunking_on_contract_nli"
conf_mat_path = os.path.join(save_dir_path, "conf_mat.pkl")

with open(conf_mat_path, 'rb') as f:
    conf_mat_loaded = pickle.load(f)
print(conf_mat_loaded) 

[[181 341 446]
 [ 14 117  89]
 [ 80 467 356]]


In [3]:
### Make table
conf_mat = conf_mat_loaded

In [9]:
accuracy = sum([conf_mat[i][i] for i in range(3)]) / sum(sum(conf_mat))
print(f"accuracy: {accuracy}")

accuracy: 0.3127690100430416


In [5]:
id2label = {0: "Entailment", 1: "Contradiction", 2: "NotMentioned"}
labels = [id2label[i] for i in range(3)]
print(labels)

['Entailment', 'Contradiction', 'NotMentioned']


In [6]:
latex_table = """\\begin{center}
\\begin{tabular}{ |c|c|c|c|c| } 
\hline
& \multicolumn{4}{|c|}{pred} \\\\
\hline
\multirow{4}{2em}{true} &  & """ + f"{labels[0]} & {labels[1]} & {labels[2]} \\\\\n" \
+ """\cline{2-5}
& """ + f"{labels[0]} & {conf_mat[0][0]} & {conf_mat[0][1]} & {conf_mat[0][2]}\\\\\n" \
+ """\cline{2-5}
& """ + f"{labels[1]} & {conf_mat[1][0]} & {conf_mat[1][1]} & {conf_mat[1][2]}\\\\\n" \
+ """\cline{2-5}
& """ + f"{labels[2]} & {conf_mat[2][0]} & {conf_mat[2][1]} & {conf_mat[2][2]}\\\\\n" \
+ """\hline
\end{tabular}
\end{center}"""

In [8]:
print(latex_table)

\begin{center}
\begin{tabular}{ |c|c|c|c|c| } 
\hline
& \multicolumn{4}{|c|}{pred} \\
\hline
\multirow{4}{2em}{true} &  & Entailment & Contradiction & NotMentioned \\
\cline{2-5}
& Entailment & 181 & 341 & 446\\
\cline{2-5}
& Contradiction & 14 & 117 & 89\\
\cline{2-5}
& NotMentioned & 80 & 467 & 356\\
\hline
\end{tabular}
\end{center}


In [7]:
"""
Results analysis: Turn into two categories.
"""

"""
from sklearn.metrics import confusion_matrix

# bart_large_mnli_id2labal = {0: "contradiction", 1: "neutral", 2: "entailment"}
# id2label = {0: "Entailment", 1: "Contradiction", 2: "NotMentioned"}

tmp = torch.tensor([0.0 for _ in range(len(contract_nli_dataset_test))])
contradiction_cnt = 0
label_sum = dict()
labels_test = []
preds_test = []
for i, e in enumerate(contract_nli_dataset_test):
    averaged_probs = averaged_probs_record_test[i]
    ans = 1 if averaged_probs[0] > 0.9 else 0
    # ans = torch.argmax(averaged_probs)
    label = e['labels']
    label_sum[id2label[label]] = label_sum.get(id2label[label], 0) + 1
    # ans = bart_label_to_contract_nli_label(ans)
    ans = 1 if ans == 1 else 0
    label = 1 if label == 1 else 0
    tmp[i] = 1 if ans == label else 0
    labels_test.append(label)
    preds_test.append(ans)
    if label == 1:
        contradiction_cnt += 1
conf_mat_test = confusion_matrix(labels_test, preds_test)
print("two category result")
print(torch.mean(tmp))
print(contradiction_cnt, len(contract_nli_dataset_test))
print(label_sum)
print(conf_mat_test)
"""

"""
Results analysis: Set threshold for neutral as well.
"""

"""

# bart_large_mnli_id2labal = {0: "contradiction", 1: "neutral", 2: "entailment"}
# id2label = {0: "Entailment", 1: "Contradiction", 2: "NotMentioned"}

tmp = torch.tensor([0.0 for _ in range(len(contract_nli_dataset_test))])
contradiction_cnt = 0
label_sum = dict()
labels_test = []
preds_test = []
for i, e in enumerate(contract_nli_dataset_test):
    averaged_probs = averaged_probs_record_test[i]
    ans = bart_label_to_contract_nli_label(ans)
    label = e['labels']
    labels_test.append(label)
    preds_test.append(ans)
    tmp[i] = 1 if ans == label else 0
print("three catogory with thresholding result")
print(torch.mean(tmp))
"""

KeyboardInterrupt: 

In [46]:
# import datasets

In [47]:
# contract_nli_dataset_all = datasets.concatenate_datasets([contract_nli_dataset['train'], contract_nli_dataset['validation'], contract_nli_dataset['test']])

In [48]:
"""
Try all data to see if results persist. 
"""

'''
parser = SentenceSplitter(chunk_size=300, chunk_overlap=50, tokenizer=lambda x: facebook_bart_large_mnli_autotokenizer(x)['input_ids'])
#contract_nodes = parser.get_nodes_from_documents(contract_documents, show_progress=True)

chunk_record_all = dict()
softmax_f_all = nn.Softmax(dim=1)
averaged_probs_record_all = torch.tensor([[0.0, 0.0, 0.0] for _ in range(len(contract_nli_dataset_all))])
accuracy_record_all = torch.tensor([0.0 for _ in range(len(contract_nli_dataset_all))])
with torch.no_grad():
    for i, e in tqdm(enumerate(contract_nli_dataset_all), total=len(contract_nli_dataset_all)):
        premise = e['text']
        hypothesis = e['hypothesis']
        premise_chunks = chunk_record_all.get(e['file_name'], None)
        if premise_chunks is None:
            premise_chunks = parser.get_nodes_from_documents([Document(text=premise)])
            premise_chunks = [e.text for e in premise_chunks]
            chunk_record_all[e['file_name']] = premise_chunks
        e_tokens = facebook_bart_large_mnli_autotokenizer.batch_encode_plus([[premise_chunk, hypothesis] for premise_chunk in premise_chunks], return_tensors='pt', padding='longest', truncation='only_first')
        e_tokens = e_tokens['input_ids'].detach().cpu()
        logits = facebook_bart_large_mnli_automodel_for_sequence_classification(e_tokens.to(device)).logits
        averaged_probs = torch.mean(softmax_f(logits), dim=0)
        averaged_probs_record_all[i] = averaged_probs
        ans = torch.argmax(averaged_probs)
        label = e['labels']
        accuracy_record_all[i] = 1 if label == bart_label_to_contract_nli_label(ans) else 0
print(torch.mean(accuracy_record_all))
'''

"\nparser = SentenceSplitter(chunk_size=300, chunk_overlap=50, tokenizer=lambda x: facebook_bart_large_mnli_autotokenizer(x)['input_ids'])\n#contract_nodes = parser.get_nodes_from_documents(contract_documents, show_progress=True)\n\nchunk_record_all = dict()\nsoftmax_f_all = nn.Softmax(dim=1)\naveraged_probs_record_all = torch.tensor([[0.0, 0.0, 0.0] for _ in range(len(contract_nli_dataset_all))])\naccuracy_record_all = torch.tensor([0.0 for _ in range(len(contract_nli_dataset_all))])\nwith torch.no_grad():\n    for i, e in tqdm(enumerate(contract_nli_dataset_all), total=len(contract_nli_dataset_all)):\n        premise = e['text']\n        hypothesis = e['hypothesis']\n        premise_chunks = chunk_record_all.get(e['file_name'], None)\n        if premise_chunks is None:\n            premise_chunks = parser.get_nodes_from_documents([Document(text=premise)])\n            premise_chunks = [e.text for e in premise_chunks]\n            chunk_record_all[e['file_name']] = premise_chunks\n   

In [49]:
"""
Results analysis: Turn into two categories.
"""

'''

# bart_large_mnli_id2labal = {0: "contradiction", 1: "neutral", 2: "entailment"}
# id2label = {0: "Entailment", 1: "Contradiction", 2: "NotMentioned"}

tmp = torch.tensor([0.0 for _ in range(len(contract_nli_dataset_all))])
contradiction_cnt = 0
contradiction_correct_cnt = 0
label_sum = dict()
for i, e in enumerate(contract_nli_dataset_all):
    averaged_probs = averaged_probs_record_all[i]
    ans = 1 if averaged_probs[0] > 0.6 else 0
    # ans = torch.argmax(averaged_probs)
    label = e['labels']
    label_sum[id2label[label]] = label_sum.get(id2label[label], 0) + 1
    # ans = bart_label_to_contract_nli_label(ans)
    ans = 1 if ans == 1 else 0
    label = 1 if label == 1 else 0
    tmp[i] = 1 if ans == label else 0
    if label == 1:
        contradiction_cnt += 1
    if ans == label == 1:
        contradiction_correct_cnt += 1
print(torch.mean(tmp))
print(f"contradiction correct cnt: {contradiction_correct_cnt}")
print(f"contradiction cnt: {contradiction_cnt}")
print(f"contradiction hit rate: {contradiction_correct_cnt / contradiction_cnt}")
print(label_sum)

"""
Results analysis: Set threshold for neutral as well.
"""

# bart_large_mnli_id2labal = {0: "contradiction", 1: "neutral", 2: "entailment"}
# id2label = {0: "Entailment", 1: "Contradiction", 2: "NotMentioned"}

tmp = torch.tensor([0.0 for _ in range(len(contract_nli_dataset_all))])
contradiction_cnt = 0
label_sum = dict()
for i, e in enumerate(contract_nli_dataset_all):
    averaged_probs = averaged_probs_record_all[i]
    ans = 1 if averaged_probs[0] > 0.9 else (0 if averaged_probs[1] > 0.9 else 2)
    # ans = torch.argmax(averaged_probs)
    label = e['labels']
    tmp[i] = 1 if ans == label else 0
print(torch.mean(tmp))
'''

'\n\n# bart_large_mnli_id2labal = {0: "contradiction", 1: "neutral", 2: "entailment"}\n# id2label = {0: "Entailment", 1: "Contradiction", 2: "NotMentioned"}\n\ntmp = torch.tensor([0.0 for _ in range(len(contract_nli_dataset_all))])\ncontradiction_cnt = 0\ncontradiction_correct_cnt = 0\nlabel_sum = dict()\nfor i, e in enumerate(contract_nli_dataset_all):\n    averaged_probs = averaged_probs_record_all[i]\n    ans = 1 if averaged_probs[0] > 0.6 else 0\n    # ans = torch.argmax(averaged_probs)\n    label = e[\'labels\']\n    label_sum[id2label[label]] = label_sum.get(id2label[label], 0) + 1\n    # ans = bart_label_to_contract_nli_label(ans)\n    ans = 1 if ans == 1 else 0\n    label = 1 if label == 1 else 0\n    tmp[i] = 1 if ans == label else 0\n    if label == 1:\n        contradiction_cnt += 1\n    if ans == label == 1:\n        contradiction_correct_cnt += 1\nprint(torch.mean(tmp))\nprint(f"contradiction correct cnt: {contradiction_correct_cnt}")\nprint(f"contradiction cnt: {contradic