In [1]:
import vertexai
from vertexai.generative_models import GenerativeModel

In [2]:
from datasets import load_dataset, DatasetDict, Dataset
import json
from transformers import PerceiverTokenizer, PerceiverModel, PerceiverConfig, PerceiverPreTrainedModel, PerceiverForSequenceClassification, TrainingArguments, Trainer, \
    DataCollatorWithPadding
import re
import os
from tqdm import tqdm
import torch

# Relative path one level up; not referenced by the cells visible here.
ROOT_PATH = ".."

# Integer class id <-> class name. `label2id` is simply the inverse of `id2label`.
id2label = {0: "entailment", 1: "neutral", 2: "contradiction"}
label2id = dict(zip(id2label.values(), id2label.keys()))

snli_dataset = load_dataset("stanfordnlp/snli")


def _has_label(example):
    """Return True when the example's label is anything other than -1."""
    return example['label'] != -1


# Drop every example labelled -1 from each split
# (presumably the dataset's marker for "no gold label" — TODO confirm).
for split_name in ('train', 'validation', 'test'):
    snli_dataset[split_name] = snli_dataset[split_name].filter(_has_label)

In [3]:
import time

# Pause (seconds) before retrying a failed request.
# Presumably sized to outlast a per-minute quota window — TODO confirm.
RETRY_SLEEP_SECONDS = 70

# NOTE(review): left as None; presumably the project is then resolved from the
# environment's default credentials — confirm before running elsewhere.
project_id = None

vertexai.init(project=project_id, location="europe-west2")

model = GenerativeModel("gemini-1.5-flash-001")

# Running tallies:
#   accurate_cnt  - responses whose text maps to the gold label id
#   total_cnt     - responses whose text could be read (correct or not)
#   exception_cnt - responses whose text could not be read at all
accurate_cnt = 0
total_cnt = 0
exception_cnt = 0
responses = []  # stripped model answer per example, or "exception"
labels = []     # gold label id per example, aligned with `responses`
for test_e in (pbar:=tqdm(snli_dataset['test'], total=len(snli_dataset['test']))):
    premise = test_e['premise']
    hypothesis = test_e['hypothesis']
    label = test_e['label']

    # Retry the request until it succeeds. Only `Exception` is caught so that
    # KeyboardInterrupt / SystemExit still stop the loop; a bare `except:` here
    # would turn a kernel interrupt into yet another sleep-and-retry.
    while True:
        try:
            response = model.generate_content(
f"""You will be given a premise and a hypothesis. Output the relationship between the premise and the hypothesis. Your answer should be either 'entailment', 'neutral', or 'contradiction'. Do not include anything else in your answer, including special symbols.
Premise: {premise}
Hypothesis: {hypothesis}
""")
            break
        except Exception:
            pbar.set_postfix_str(f"sleeping {accurate_cnt} {total_cnt} {0 if not total_cnt else accurate_cnt / total_cnt } {exception_cnt}")
            time.sleep(RETRY_SLEEP_SECONDS)

    # Reading `.text` can itself fail; such examples are recorded as
    # "exception" and are excluded from `total_cnt` (and so from the accuracy).
    try:
        response = response.text.strip()
    except Exception:
        response = "exception"
        exception_cnt += 1
    else:
        # Any answer that is not exactly one of the three class names maps to
        # -1, which never equals a gold label, so it counts as wrong.
        response_id = label2id.get(response, -1)
        if response_id == label:
            accurate_cnt += 1
        total_cnt += 1

    responses.append(response)
    labels.append(label)
    pbar.set_postfix_str(f"{response} {accurate_cnt} {total_cnt} {0 if not total_cnt else accurate_cnt / total_cnt } {exception_cnt}")

# Same zero guard as the progress postfix, so an empty or all-exception run
# prints 0 instead of raising ZeroDivisionError.
print(accurate_cnt, total_cnt, 0 if not total_cnt else accurate_cnt / total_cnt)

I0000 00:00:1724075497.002477   29801 config.cc:230] gRPC experiments enabled: call_status_override_on_cancellation, event_engine_dns, event_engine_listener, http2_stats_fix, monitoring_experiment, pick_first_new, trace_record_callops, work_serializer_clears_time_cache
100%|██████████| 9824/9824 [1:20:17<00:00,  2.04it/s, neutral 7094 9808 0.7232871125611745 16]         

7094 9808 0.7232871125611745





In [4]:
### Get confusion matrix

from sklearn.metrics import confusion_matrix

# Keep only the examples whose model answer is exactly one of the class names;
# anything else (including the "exception" placeholder) is left out of the matrix.
labels_ = []
preds_ = []
for gold_id, answer in zip(labels, responses):
    if answer in label2id:
        labels_.append(gold_id)
        preds_.append(label2id[answer])

# Pin the class order to the ids in `id2label` (0, 1, 2) so the matrix is always
# 3x3 with rows = true class and columns = predicted class, even if some class
# never appears among the kept labels or predictions. Later cells index
# conf_mat[0..2][0..2] and rely on that shape.
conf_mat = confusion_matrix(labels_, preds_, labels=sorted(id2label))
print(conf_mat)

[[3101  258    7]
 [ 374 2817   18]
 [  20 2037 1176]]


In [5]:
### save the data
import pickle

save_dir_path = "../ignored_dir/results/llm_contradiction_detection_on_snli"
# makedirs creates any missing parent directories and is a no-op when the
# directory already exists (os.mkdir fails if a parent such as `results/` is absent).
os.makedirs(save_dir_path, exist_ok=True)

# README records the class-id -> name mapping used for the matrix rows/columns.
readme_path = os.path.join(save_dir_path, "README.txt")
with open(readme_path, "w") as f:
    f.write('{0: "entailment", 1: "neutral", 2: "contradiction"}')

# Persist the confusion matrix; the context manager closes the file even if
# pickling raises.
conf_mat_path = os.path.join(save_dir_path, "conf_mat.pkl")
with open(conf_mat_path, 'wb') as conf_mat_filehandler:
    pickle.dump(conf_mat, conf_mat_filehandler)

In [6]:
### Test load
# Reads the pickled confusion matrix back from disk and prints it, so the cells
# below can run from the saved file.

import pickle
import os

save_dir_path = "../ignored_dir/results/llm_contradiction_detection_on_snli"
conf_mat_path = os.path.join(save_dir_path, "conf_mat.pkl")

# NOTE: unpickling can execute arbitrary code; only load files this notebook wrote.
with open(conf_mat_path, mode='rb') as pkl_file:
    conf_mat_loaded = pickle.load(pkl_file)

print(conf_mat_loaded)

[[3101  258    7]
 [ 374 2817   18]
 [  20 2037 1176]]


In [7]:
# Accuracy = sum of the three diagonal entries / sum of every entry in the matrix.
diag_total = sum(conf_mat_loaded[k][k] for k in range(3))
grand_total = sum(sum(conf_mat_loaded))
acc = diag_total / grand_total
print(f"accuracy: {acc}")

accuracy: 0.7232871125611745


In [8]:
### Make table
# Renders the 3x3 confusion matrix (rows = true class, columns = predicted class)
# as a LaTeX tabular and prints it.
conf_mat = conf_mat_loaded

id2label = {0: "entailment", 1: "neutral", 2: "contradiction"}
# NOTE(review): this rebinds `labels`, which an earlier cell uses for the gold
# label ids; re-running the confusion-matrix cell after this one would see the
# class names instead. Kept as-is so anything reading `labels` afterwards is unaffected.
labels = [id2label[i] for i in range(3)]
print(labels)

# Raw strings keep every LaTeX backslash literal. The previous non-raw literals
# relied on invalid escape sequences (\h, \m, \c, \e) being passed through,
# which Python now warns about.
table_lines = [
    r"\begin{center}",
    # The trailing space after the closing brace is part of the original output.
    r"\begin{tabular}{ |c|c|c|c|c| } ",
    r"\hline",
    r"& \multicolumn{4}{|c|}{pred} \\",
    r"\hline",
    r"\multirow{4}{2em}{true} &  & " + f"{labels[0]} & {labels[1]} & {labels[2]} " + r"\\",
]

# One "\cline + data row" pair per true class.
for row_idx in range(3):
    table_lines.append(r"\cline{2-5}")
    table_lines.append(
        f"& {labels[row_idx]} & {conf_mat[row_idx][0]} & {conf_mat[row_idx][1]} & {conf_mat[row_idx][2]}"
        + r"\\"
    )

table_lines += [r"\hline", r"\end{tabular}", r"\end{center}"]

latex_table = "\n".join(table_lines)

print(latex_table)

['entailment', 'neutral', 'contradiction']
\begin{center}
\begin{tabular}{ |c|c|c|c|c| } 
\hline
& \multicolumn{4}{|c|}{pred} \\
\hline
\multirow{4}{2em}{true} &  & entailment & neutral & contradiction \\
\cline{2-5}
& entailment & 3101 & 258 & 7\\
\cline{2-5}
& neutral & 374 & 2817 & 18\\
\cline{2-5}
& contradiction & 20 & 2037 & 1176\\
\hline
\end{tabular}
\end{center}
