# Evaluate semantic chunking on contract nli

In [1]:
from datasets import Dataset, DatasetDict
import json
from llama_index.core import Document
from llama_index.core.node_parser import SentenceSplitter, SemanticSplitterNodeParser
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
import os
import torch
import torch.nn as nn
from tqdm import tqdm
from transformers import AutoModelForSequenceClassification, AutoTokenizer

In [2]:
ROOT_PATH = ".."

id2label = {0: "Entailment", 1: "Contradiction", 2: "NotMentioned"}
label2id = {v: k for k, v in id2label.items()}

def contract_nli_iterator(data):
    documents, labels = data['documents'], data['labels']
    for document in documents:
        id = document['id']
        file_name = document['file_name']
        text = document['text']
        spans = document['spans']
        annotation_sets = document['annotation_sets']
        document_type = document['document_type']
        url = document['url']
        for annotation_id, annotation_content in annotation_sets[0]['annotations'].items():
            hypothesis = labels[annotation_id]['hypothesis']
            choice = annotation_content['choice']
            yield {
                "id": id,
                "file_name": file_name,
                "text": text,
                "spans": spans,
                "document_type": document_type,
                "url": url,
                "hypothesis": hypothesis,
                "labels": label2id[choice],
            }            
base_filepath = os.path.join(ROOT_PATH, "ignored_dir/data/contract-nli")
train_filepath = os.path.join(base_filepath, "train.json")
validation_filepath = os.path.join(base_filepath, "dev.json")
test_filepath = os.path.join(base_filepath, "test.json")
with open(train_filepath) as f:
    train_data = json.load(f)
with open(validation_filepath) as f:
    validation_data = json.load(f)
with open(test_filepath) as f:
    test_data = json.load(f)
data = {
    "train": Dataset.from_generator(lambda: contract_nli_iterator(train_data)),
    "validation": Dataset.from_generator(lambda: contract_nli_iterator(validation_data)),
    "test": Dataset.from_generator(lambda: contract_nli_iterator(test_data)),
}
contract_nli_dataset = DatasetDict(data)

### Vanilla Chunking Methods

In [3]:
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
print(f"using device {device}")
vanilla_chunking_facebook_bart_large_mnli_autotokenizer = AutoTokenizer.from_pretrained('facebook/bart-large-mnli')
vanilla_chunking_facebook_bart_large_mnli_automodel_for_sequence_classification = AutoModelForSequenceClassification.from_pretrained('facebook/bart-large-mnli').to(device)
vanilla_chunking_bart_large_mnli_id2labal = {0: "contradiction", 1: "neutral", 2: "entailment"}
vanilla_chunking_bart_large_mnli_label2id = {v: k for k, v in vanilla_chunking_bart_large_mnli_id2labal.items()}
vanilla_chunking_parser = SentenceSplitter(chunk_size=300, chunk_overlap=50, tokenizer=lambda x: vanilla_chunking_facebook_bart_large_mnli_autotokenizer(x)['input_ids'])
vanilla_chunking_softmax_f = nn.Softmax(dim=1)

def calculate_prob_with_vanilla_chunking(premise, hypothesis):
    ans = None
    with torch.no_grad():
        premise = e['text']
        hypothesis = e['hypothesis']
        premise_chunks = vanilla_chunking_parser.get_nodes_from_documents([Document(text=premise)])
        premise_chunks = [e.text for e in premise_chunks]
        e_tokens = vanilla_chunking_facebook_bart_large_mnli_autotokenizer.batch_encode_plus([[premise_chunk, hypothesis] for premise_chunk in premise_chunks], \
                                                                                             return_tensors='pt', padding='longest', truncation='only_first')
        e_tokens = e_tokens['input_ids'].detach().cpu()
        logits = vanilla_chunking_facebook_bart_large_mnli_automodel_for_sequence_classification(e_tokens.to(device)).logits
        averaged_probs = torch.mean(vanilla_chunking_softmax_f(logits), dim=0)
        # ans = torch.argmax(averaged_probs)
        ans = averaged_probs
    return ans

using device cuda:0


### Evaluate vanilla chunking on contract nli

In [4]:
contract_nli_dataset_test = contract_nli_dataset['test']

In [5]:
print("loading semantic chunking embedding model to cuda:1")
semantic_chunking_embed_model = HuggingFaceEmbedding(model_name="BAAI/bge-small-en-v1.5", device='cuda:1')

loading semantic chunking embedding model to cuda:1


In [6]:
semantic_chunking_splitter = SemanticSplitterNodeParser(
    buffer_size=1, breakpoint_percentile_threshold=95, embed_model=semantic_chunking_embed_model
)

In [7]:
def semantic_chunking_bart_label_to_contract_nli_label(id):
    if id == 0: return 1
    elif id == 1: return 2
    elif id == 2: return 0
    else: return None

semantic_chunking_softmax_f = nn.Softmax(dim=1)
semantic_chunking_chunk_record_test = dict()
semantic_chunking_averaged_probs_record_test = torch.tensor([[0.0, 0.0, 0.0] for _ in range(len(contract_nli_dataset_test))])
semantic_chunking_accuracy_record_test = torch.tensor([0.0 for _ in range(len(contract_nli_dataset_test))])
with torch.no_grad():
    for i, e in tqdm(enumerate(contract_nli_dataset_test), total=len(contract_nli_dataset_test)):
        premise = e['text']
        hypothesis = e['hypothesis']
        premise_chunks = semantic_chunking_chunk_record_test.get(e['file_name'], None)
        if premise_chunks is None:
            premise_chunks = semantic_chunking_splitter.get_nodes_from_documents([Document(text=premise)])
            premise_chunks = [e.text for e in premise_chunks]
            semantic_chunking_chunk_record_test[e['file_name']] = premise_chunks
        premise_chunk_prob_results = torch.tensor([[0.0, 0.0, 0.0] for _ in range(len(premise_chunks))])
        for j, premise_chunk in enumerate(premise_chunks):
            premise_chunk_prob = calculate_prob_with_vanilla_chunking(premise_chunk, hypothesis)
            premise_chunk_prob_results[j] = premise_chunk_prob
        averaged_probs = torch.mean(premise_chunk_prob_results, dim=0)
        semantic_chunking_averaged_probs_record_test[i] = averaged_probs
        ans = torch.argmax(averaged_probs)
        label = e['labels']
        semantic_chunking_accuracy_record_test[i] = 1 if label == semantic_chunking_bart_label_to_contract_nli_label(ans) else 0
print(torch.mean(semantic_chunking_accuracy_record_test))

  0%|          | 0/2091 [00:00<?, ?it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (3207 > 1024). Running this sequence through the model will result in indexing errors
100%|██████████| 2091/2091 [54:34<00:00,  1.57s/it] 

tensor(0.3128)





In [8]:
### Get confusion matrix

from sklearn.metrics import confusion_matrix

labels = []
preds = []

for i, e in enumerate(contract_nli_dataset_test):
    averaged_probs = semantic_chunking_averaged_probs_record_test[i]
    ans = torch.argmax(averaged_probs)
    ans = semantic_chunking_bart_label_to_contract_nli_label(ans)
    preds.append(ans)
    label = e['labels']
    labels.append(label)

conf_mat = confusion_matrix(labels, preds)
print(conf_mat)

[[181 341 446]
 [ 14 117  89]
 [ 80 467 356]]


In [9]:
### save the data
import pickle

save_dir_path = "../ignored_dir/results/semantic_chunking_on_contract_nli"
if not os.path.exists(save_dir_path):
    os.mkdir(save_dir_path)
readme_path = os.path.join(save_dir_path, "README.txt")
with open(readme_path, "w") as f:
    f.write('{0: "Entailment", 1: "Contradiction", 2: "NotMentioned"}')
conf_mat_path = os.path.join(save_dir_path, "conf_mat.pkl")
conf_mat_filehandler = open(conf_mat_path, 'wb') 
pickle.dump(conf_mat, conf_mat_filehandler)
conf_mat_filehandler.close()

In [1]:
### Test load

import pickle
import os

save_dir_path = "../ignored_dir/results/semantic_chunking_on_contract_nli"
conf_mat_path = os.path.join(save_dir_path, "conf_mat.pkl")

with open(conf_mat_path, 'rb') as f:
    conf_mat_loaded = pickle.load(f)
print(conf_mat_loaded) 

[[181 341 446]
 [ 14 117  89]
 [ 80 467 356]]


In [2]:
### Make table
conf_mat = conf_mat_loaded

In [6]:
accuracy = sum([conf_mat[i][i] for i in range(3)]) / sum(sum(conf_mat))
print(f"accuracy: {accuracy}")

accuracy: 0.3127690100430416


In [3]:
id2label = {0: "Entailment", 1: "Contradiction", 2: "NotMentioned"}
labels = [id2label[i] for i in range(3)]
print(labels)

['Entailment', 'Contradiction', 'NotMentioned']


In [4]:
latex_table = """\\begin{center}
\\begin{tabular}{ |c|c|c|c|c| } 
\hline
& \multicolumn{4}{|c|}{pred} \\\\
\hline
\multirow{4}{2em}{true} &  & """ + f"{labels[0]} & {labels[1]} & {labels[2]} \\\\\n" \
+ """\cline{2-5}
& """ + f"{labels[0]} & {conf_mat[0][0]} & {conf_mat[0][1]} & {conf_mat[0][2]}\\\\\n" \
+ """\cline{2-5}
& """ + f"{labels[1]} & {conf_mat[1][0]} & {conf_mat[1][1]} & {conf_mat[1][2]}\\\\\n" \
+ """\cline{2-5}
& """ + f"{labels[2]} & {conf_mat[2][0]} & {conf_mat[2][1]} & {conf_mat[2][2]}\\\\\n" \
+ """\hline
\end{tabular}
\end{center}"""

In [5]:
print(latex_table)

\begin{center}
\begin{tabular}{ |c|c|c|c|c| } 
\hline
& \multicolumn{4}{|c|}{pred} \\
\hline
\multirow{4}{2em}{true} &  & Entailment & Contradiction & NotMentioned \\
\cline{2-5}
& Entailment & 181 & 341 & 446\\
\cline{2-5}
& Contradiction & 14 & 117 & 89\\
\cline{2-5}
& NotMentioned & 80 & 467 & 356\\
\hline
\end{tabular}
\end{center}


In [15]:
"""
Results analysis: Turn into two categories.
"""

"""
from sklearn.metrics import confusion_matrix
"""

# bart_large_mnli_id2labal = {0: "contradiction", 1: "neutral", 2: "entailment"}
# id2label = {0: "Entailment", 1: "Contradiction", 2: "NotMentioned"}

"""
tmp = torch.tensor([0.0 for _ in range(len(contract_nli_dataset_test))])
contradiction_cnt = 0
label_sum = dict()
labels_test = []
preds_test = []
for i, e in enumerate(contract_nli_dataset_test):
    averaged_probs = semantic_chunking_averaged_probs_record_test[i]
    ans = 1 if averaged_probs[0] > 0.9 else 0
    # ans = torch.argmax(averaged_probs)
    label = e['labels']
    label_sum[id2label[label]] = label_sum.get(id2label[label], 0) + 1
    # ans = bart_label_to_contract_nli_label(ans)
    ans = 1 if ans == 1 else 0
    label = 1 if label == 1 else 0
    tmp[i] = 1 if ans == label else 0
    labels_test.append(label)
    preds_test.append(ans)
    if label == 1:
        contradiction_cnt += 1
conf_mat_test = confusion_matrix(labels_test, preds_test)
print("two category result")
print(torch.mean(tmp))
print(contradiction_cnt, len(contract_nli_dataset_test))
print(label_sum)
print(conf_mat_test)
"""

"""
Results analysis: Set threshold for neutral as well.
"""

# bart_large_mnli_id2labal = {0: "contradiction", 1: "neutral", 2: "entailment"}
# id2label = {0: "Entailment", 1: "Contradiction", 2: "NotMentioned"}

"""
tmp = torch.tensor([0.0 for _ in range(len(contract_nli_dataset_test))])
contradiction_cnt = 0
label_sum = dict()
labels_test = []
preds_test = []
for i, e in enumerate(contract_nli_dataset_test):
    averaged_probs = semantic_chunking_averaged_probs_record_test[i]
    ans = semantic_chunking_bart_label_to_contract_nli_label(ans)
    label = e['labels']
    labels_test.append(label)
    preds_test.append(ans)
    tmp[i] = 1 if ans == label else 0
print("three catogory with thresholding result")
print(torch.mean(tmp))
"""

two category result
tensor(0.8867)
220 2091
{'NotMentioned': 903, 'Entailment': 968, 'Contradiction': 220}
[[1854   17]
 [ 220    0]]
three catogory with thresholding result
tensor(0.3462)


In [16]:
# import datasets

In [17]:
# contract_nli_dataset_all = datasets.concatenate_datasets([contract_nli_dataset['train'], contract_nli_dataset['validation'], contract_nli_dataset['test']])

In [18]:
"""
Try all data to see if results persist. 
"""

'''
parser = SentenceSplitter(chunk_size=300, chunk_overlap=50, tokenizer=lambda x: facebook_bart_large_mnli_autotokenizer(x)['input_ids'])
#contract_nodes = parser.get_nodes_from_documents(contract_documents, show_progress=True)

chunk_record_all = dict()
softmax_f_all = nn.Softmax(dim=1)
averaged_probs_record_all = torch.tensor([[0.0, 0.0, 0.0] for _ in range(len(contract_nli_dataset_all))])
accuracy_record_all = torch.tensor([0.0 for _ in range(len(contract_nli_dataset_all))])
with torch.no_grad():
    for i, e in tqdm(enumerate(contract_nli_dataset_all), total=len(contract_nli_dataset_all)):
        premise = e['text']
        hypothesis = e['hypothesis']
        premise_chunks = chunk_record_all.get(e['file_name'], None)
        if premise_chunks is None:
            premise_chunks = parser.get_nodes_from_documents([Document(text=premise)])
            premise_chunks = [e.text for e in premise_chunks]
            chunk_record_all[e['file_name']] = premise_chunks
        e_tokens = facebook_bart_large_mnli_autotokenizer.batch_encode_plus([[premise_chunk, hypothesis] for premise_chunk in premise_chunks], return_tensors='pt', padding='longest', truncation='only_first')
        e_tokens = e_tokens['input_ids'].detach().cpu()
        logits = facebook_bart_large_mnli_automodel_for_sequence_classification(e_tokens.to(device)).logits
        averaged_probs = torch.mean(softmax_f(logits), dim=0)
        averaged_probs_record_all[i] = averaged_probs
        ans = torch.argmax(averaged_probs)
        label = e['labels']
        accuracy_record_all[i] = 1 if label == bart_label_to_contract_nli_label(ans) else 0
print(torch.mean(accuracy_record_all))
'''

"\nparser = SentenceSplitter(chunk_size=300, chunk_overlap=50, tokenizer=lambda x: facebook_bart_large_mnli_autotokenizer(x)['input_ids'])\n#contract_nodes = parser.get_nodes_from_documents(contract_documents, show_progress=True)\n\nchunk_record_all = dict()\nsoftmax_f_all = nn.Softmax(dim=1)\naveraged_probs_record_all = torch.tensor([[0.0, 0.0, 0.0] for _ in range(len(contract_nli_dataset_all))])\naccuracy_record_all = torch.tensor([0.0 for _ in range(len(contract_nli_dataset_all))])\nwith torch.no_grad():\n    for i, e in tqdm(enumerate(contract_nli_dataset_all), total=len(contract_nli_dataset_all)):\n        premise = e['text']\n        hypothesis = e['hypothesis']\n        premise_chunks = chunk_record_all.get(e['file_name'], None)\n        if premise_chunks is None:\n            premise_chunks = parser.get_nodes_from_documents([Document(text=premise)])\n            premise_chunks = [e.text for e in premise_chunks]\n            chunk_record_all[e['file_name']] = premise_chunks\n   

In [19]:
"""
Results analysis: Turn into two categories.
"""

'''

# bart_large_mnli_id2labal = {0: "contradiction", 1: "neutral", 2: "entailment"}
# id2label = {0: "Entailment", 1: "Contradiction", 2: "NotMentioned"}

tmp = torch.tensor([0.0 for _ in range(len(contract_nli_dataset_all))])
contradiction_cnt = 0
contradiction_correct_cnt = 0
label_sum = dict()
for i, e in enumerate(contract_nli_dataset_all):
    averaged_probs = averaged_probs_record_all[i]
    ans = 1 if averaged_probs[0] > 0.6 else 0
    # ans = torch.argmax(averaged_probs)
    label = e['labels']
    label_sum[id2label[label]] = label_sum.get(id2label[label], 0) + 1
    # ans = bart_label_to_contract_nli_label(ans)
    ans = 1 if ans == 1 else 0
    label = 1 if label == 1 else 0
    tmp[i] = 1 if ans == label else 0
    if label == 1:
        contradiction_cnt += 1
    if ans == label == 1:
        contradiction_correct_cnt += 1
print(torch.mean(tmp))
print(f"contradiction correct cnt: {contradiction_correct_cnt}")
print(f"contradiction cnt: {contradiction_cnt}")
print(f"contradiction hit rate: {contradiction_correct_cnt / contradiction_cnt}")
print(label_sum)

"""
Results analysis: Set threshold for neutral as well.
"""

# bart_large_mnli_id2labal = {0: "contradiction", 1: "neutral", 2: "entailment"}
# id2label = {0: "Entailment", 1: "Contradiction", 2: "NotMentioned"}

tmp = torch.tensor([0.0 for _ in range(len(contract_nli_dataset_all))])
contradiction_cnt = 0
label_sum = dict()
for i, e in enumerate(contract_nli_dataset_all):
    averaged_probs = averaged_probs_record_all[i]
    ans = 1 if averaged_probs[0] > 0.9 else (0 if averaged_probs[1] > 0.9 else 2)
    # ans = torch.argmax(averaged_probs)
    label = e['labels']
    tmp[i] = 1 if ans == label else 0
print(torch.mean(tmp))
'''

'\n\n# bart_large_mnli_id2labal = {0: "contradiction", 1: "neutral", 2: "entailment"}\n# id2label = {0: "Entailment", 1: "Contradiction", 2: "NotMentioned"}\n\ntmp = torch.tensor([0.0 for _ in range(len(contract_nli_dataset_all))])\ncontradiction_cnt = 0\ncontradiction_correct_cnt = 0\nlabel_sum = dict()\nfor i, e in enumerate(contract_nli_dataset_all):\n    averaged_probs = averaged_probs_record_all[i]\n    ans = 1 if averaged_probs[0] > 0.6 else 0\n    # ans = torch.argmax(averaged_probs)\n    label = e[\'labels\']\n    label_sum[id2label[label]] = label_sum.get(id2label[label], 0) + 1\n    # ans = bart_label_to_contract_nli_label(ans)\n    ans = 1 if ans == 1 else 0\n    label = 1 if label == 1 else 0\n    tmp[i] = 1 if ans == label else 0\n    if label == 1:\n        contradiction_cnt += 1\n    if ans == label == 1:\n        contradiction_correct_cnt += 1\nprint(torch.mean(tmp))\nprint(f"contradiction correct cnt: {contradiction_correct_cnt}")\nprint(f"contradiction cnt: {contradic