In [1]:
import vertexai
from vertexai.generative_models import GenerativeModel

In [2]:
from datasets import load_dataset, DatasetDict, Dataset
import json
from transformers import PerceiverTokenizer, PerceiverModel, PerceiverConfig, PerceiverPreTrainedModel, PerceiverForSequenceClassification, TrainingArguments, Trainer, \
    DataCollatorWithPadding
import re
import os
from tqdm import tqdm
import torch

# Base directory that the dataset path in load_dataset_custom is joined onto.
ROOT_PATH = ".."

# Class index -> class name; the inverse map is derived from it so the two
# lookups always agree.
id2label = dict(enumerate(("Entailment", "Contradiction", "NotMentioned")))
label2id = {name: index for index, name in id2label.items()}

def load_dataset_custom(dataset_name):
    """Load a locally stored dataset as a ``DatasetDict``.

    Only ``"contract-nli"`` is handled. Its ``train.json``, ``dev.json`` and
    ``test.json`` files are read from ``<ROOT_PATH>/ignored_dir/data/contract-nli``
    and flattened into one example per (document, annotation) pair.

    Args:
        dataset_name: Name of the dataset to load.

    Returns:
        A ``DatasetDict`` with ``train``, ``validation`` and ``test`` splits,
        or ``None`` when ``dataset_name`` is not ``"contract-nli"``.

    Raises:
        OSError: If one of the split files cannot be opened.
        KeyError: If a document or annotation lacks an expected field, or an
            annotation's ``choice`` is not a key of ``label2id``.
    """
    if dataset_name == "contract-nli":
        def contract_nli_iterator(data):
            """Yield one flat example per annotation of every document in ``data``."""
            hypotheses = data['labels']
            for document in data['documents']:
                # Fields shared by every example generated from this document.
                shared = {
                    "id": document['id'],
                    "file_name": document['file_name'],
                    "text": document['text'],
                    "spans": document['spans'],
                    "document_type": document['document_type'],
                    "url": document['url'],
                }
                # Only the first annotation set of each document is used.
                annotations = document['annotation_sets'][0]['annotations']
                for annotation_id, annotation_content in annotations.items():
                    yield {
                        **shared,
                        "hypothesis": hypotheses[annotation_id]['hypothesis'],
                        "labels": label2id[annotation_content['choice']],
                    }

        base_filepath = os.path.join(ROOT_PATH, "ignored_dir/data/contract-nli")

        def read_split(file_name):
            """Parse one split file located under ``base_filepath``."""
            # JSON is UTF-8 by specification; do not rely on the platform's
            # default text encoding.
            with open(os.path.join(base_filepath, file_name), encoding="utf-8") as f:
                return json.load(f)

        # Read all three files up front so a missing file fails before any
        # dataset generation starts.
        train_data = read_split("train.json")
        validation_data = read_split("dev.json")
        test_data = read_split("test.json")
        return DatasetDict({
            "train": Dataset.from_generator(lambda: contract_nli_iterator(train_data)),
            "validation": Dataset.from_generator(lambda: contract_nli_iterator(validation_data)),
            "test": Dataset.from_generator(lambda: contract_nli_iterator(test_data)),
        })
    return None

# Build the train/validation/test DatasetDict for ContractNLI from the local JSON files.
contract_nli_dataset = load_dataset_custom("contract-nli")

In [3]:
import time

# Passed straight to vertexai.init; None presumably lets the SDK pick the
# environment's default project -- TODO confirm.
project_id = None

vertexai.init(project=project_id, location="europe-west2")

model = GenerativeModel("gemini-1.5-flash-001")

# Seconds to wait before retrying a failed request.
RETRY_SLEEP_SECONDS = 70


def build_prompt(premise, hypothesis):
    """Return the zero-shot NLI prompt for one premise/hypothesis pair."""
    return (
f"""You will be given a premise and a hypothesis. Output the relationship between the premise and the hypothesis. Your answer should be either 'Contradiction', 'Entailment', or 'NotMentioned'. Do not include anything else in your answer, including special symbols.
Premise: {premise}
Hypothesis: {hypothesis}
""")


accurate_cnt = 0   # answers equal to the gold label
total_cnt = 0      # answers whose text could be read (scored examples)
exception_cnt = 0  # responses whose text could not be read
responses = []     # raw answer text per test example ("exception" on failure)
labels = []        # gold label id per test example, aligned with `responses`


def progress_stats():
    """Format the running counters shown next to the progress bar."""
    running_accuracy = accurate_cnt / total_cnt if total_cnt else 0
    return f"{accurate_cnt} {total_cnt} {running_accuracy} {exception_cnt}"


test_split = contract_nli_dataset['test']
for test_e in (pbar := tqdm(test_split, total=len(test_split))):
    label = test_e['labels']
    prompt = build_prompt(test_e['text'], test_e['hypothesis'])
    # Retry until the request succeeds. Catching `Exception` rather than using a
    # bare `except` keeps Ctrl-C / kernel interrupt able to stop a stuck run,
    # and the error type is surfaced so a permanent failure is visible.
    while True:
        try:
            response = model.generate_content(prompt)
            break
        except Exception as err:
            pbar.set_postfix_str(f"sleeping ({type(err).__name__}) {progress_stats()}")
            time.sleep(RETRY_SLEEP_SECONDS)
    try:
        response_text = response.text.strip()
    except Exception:
        # The response carried no readable text; record it but do not score it.
        response_text = "exception"
        exception_cnt += 1
    else:
        # Anything that is not exactly one of the three class names maps to -1,
        # which never matches a gold label and therefore counts as wrong.
        if label2id.get(response_text, -1) == label:
            accurate_cnt += 1
        total_cnt += 1
    responses.append(response_text)
    labels.append(label)
    pbar.set_postfix_str(f"{response_text} {progress_stats()}")
# Guard the final ratio so a run with no scored examples does not crash.
print(accurate_cnt, total_cnt, accurate_cnt / total_cnt if total_cnt else 0)

I0000 00:00:1724084993.490352  105917 config.cc:230] gRPC experiments enabled: call_status_override_on_cancellation, event_engine_dns, event_engine_listener, http2_stats_fix, monitoring_experiment, pick_first_new, trace_record_callops, work_serializer_clears_time_cache
100%|██████████| 2091/2091 [12:05<00:00,  2.88it/s, Entailment 1382 2091 0.6609277857484457 0]   

1382 2091 0.6609277857484457





In [4]:
# Dump the raw answer recorded for every test example (one list entry per example).
print(responses)

['NotMentioned', 'Entailment', 'Entailment', 'NotMentioned', 'NotMentioned', 'NotMentioned', 'Entailment', 'NotMentioned', 'Entailment', 'NotMentioned', 'NotMentioned', 'NotMentioned', 'NotMentioned', 'Entailment', 'NotMentioned', 'Entailment', 'Entailment', 'Entailment', 'Entailment', 'Entailment', 'NotMentioned', 'NotMentioned', 'NotMentioned', 'Entailment', 'NotMentioned', 'Entailment', 'Entailment', 'NotMentioned', 'NotMentioned', 'Entailment', 'Entailment', 'NotMentioned', 'Entailment', 'Entailment', 'NotMentioned', 'Entailment', 'Entailment', 'NotMentioned', 'NotMentioned', 'NotMentioned', 'Entailment', 'NotMentioned', 'NotMentioned', 'NotMentioned', 'NotMentioned', 'NotMentioned', 'Entailment', 'NotMentioned', 'NotMentioned', 'Contradiction', 'Entailment', 'NotMentioned', 'NotMentioned', 'Entailment', 'NotMentioned', 'NotMentioned', 'NotMentioned', 'Entailment', 'NotMentioned', 'NotMentioned', 'Entailment', 'Entailment', 'NotMentioned', 'NotMentioned', 'Entailment', 'NotMentione

In [5]:
### Get confusion matrix

from sklearn.metrics import confusion_matrix

# Keep only the examples whose answer is exactly one of the known class names;
# anything else cannot be placed in the matrix.
labels_ = []
preds_ = []
for gold_label, answer in zip(labels, responses):
    if answer in label2id:
        labels_.append(gold_label)
        preds_.append(label2id[answer])

# Pin the class order explicitly so the matrix is always 3x3 with rows/columns
# in id order, even if some class never occurs in the gold labels or predictions.
conf_mat = confusion_matrix(labels_, preds_, labels=sorted(id2label))
print(conf_mat)

[[622  10 336]
 [ 22   5 193]
 [121  27 755]]


In [6]:
### save the data
import pickle

save_dir_path = "../ignored_dir/results/llm_contradiction_detection_on_contract_nli"
# makedirs also creates missing parent directories and does not fail when the
# directory already exists.
os.makedirs(save_dir_path, exist_ok=True)

# The README records which class each row/column index of the matrix stands for.
readme_path = os.path.join(save_dir_path, "README.txt")
with open(readme_path, "w") as f:
    f.write('{0: "Entailment", 1: "Contradiction", 2: "NotMentioned"}')

conf_mat_path = os.path.join(save_dir_path, "conf_mat.pkl")
# The context manager closes the file even if pickling raises.
with open(conf_mat_path, 'wb') as conf_mat_filehandler:
    pickle.dump(conf_mat, conf_mat_filehandler)

In [7]:
### Test load
# Reads the pickled confusion matrix back from disk. The imports are repeated
# here so this cell does not rely on imports made by earlier cells.

import os
import pickle

save_dir_path = "../ignored_dir/results/llm_contradiction_detection_on_contract_nli"
conf_mat_path = os.path.join(save_dir_path, "conf_mat.pkl")

# NOTE: unpickling can execute arbitrary code; only load files this notebook wrote.
with open(conf_mat_path, mode='rb') as f:
    conf_mat_loaded = pickle.loads(f.read())
print(conf_mat_loaded)

[[622  10 336]
 [ 22   5 193]
 [121  27 755]]


In [8]:
# Accuracy = diagonal entries (gold == predicted) divided by the sum of all entries.
n_correct = sum(conf_mat_loaded[k][k] for k in range(3))
n_total = sum(sum(conf_mat_loaded))
accuracy = n_correct / n_total
print(f"accuracy: {accuracy}")

accuracy: 0.6609277857484457


In [9]:
### Make table
conf_mat = conf_mat_loaded

# Use a dedicated name here: `labels` holds the gold labels collected by the
# evaluation loop, and rebinding it would break re-running the
# confusion-matrix cell.
label_names = [id2label[i] for i in range(3)]
print(label_names)

# Raw strings keep every LaTeX backslash literal, avoiding invalid escape
# sequences such as "\h" or "\c" in ordinary string literals.
row_end = r"\\"
table_lines = [
    r"\begin{center}",
    r"\begin{tabular}{ |c|c|c|c|c| } ",
    r"\hline",
    r"& \multicolumn{4}{|c|}{pred} \\",
    r"\hline",
    r"\multirow{4}{2em}{true} &  & " + " & ".join(label_names) + r" \\",
]
# One table row per true class: the class name followed by its predicted counts.
for row_index, row_name in enumerate(label_names):
    counts = " & ".join(str(conf_mat[row_index][col_index]) for col_index in range(3))
    table_lines.append(r"\cline{2-5}")
    table_lines.append(f"& {row_name} & {counts}" + row_end)
table_lines += [r"\hline", r"\end{tabular}", r"\end{center}"]
latex_table = "\n".join(table_lines)

print(latex_table)

['Entailment', 'Contradiction', 'NotMentioned']
\begin{center}
\begin{tabular}{ |c|c|c|c|c| } 
\hline
& \multicolumn{4}{|c|}{pred} \\
\hline
\multirow{4}{2em}{true} &  & Entailment & Contradiction & NotMentioned \\
\cline{2-5}
& Entailment & 622 & 10 & 336\\
\cline{2-5}
& Contradiction & 22 & 5 & 193\\
\cline{2-5}
& NotMentioned & 121 & 27 & 755\\
\hline
\end{tabular}
\end{center}
