### Testing Entailment & Contradiction models on the same data

In [None]:
import torch
from transformers import DistilBertForSequenceClassification, Trainer, TrainingArguments
import wandb 
import csv
import json
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, classification_report
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
# from transformers import DistilBertTokenizer
import numpy as np
import torch.nn as nn
# import torch.optim as optim
# from tabulate import tabulate
# from torch.utils.data import DataLoader
from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
import torch

In [None]:
from utils import ContractNLIDataset, ContractNLIDatasetTest

In [None]:
# Integer ids of the three ContractNLI classes as used throughout this notebook.
ENTAILMENT, CONTRADICTION, NOT_MENTIONED = 0, 1, 2

In [4]:
# Dataset locations, relative to the notebook's working directory.
# NOTE(review): json_folder is not read by any visible cell; presumably the raw ContractNLI JSON — confirm.
json_folder = "../dataset/contract-nli"
# Folder the test split CSV is loaded from.
csv_folder = "../dataset/csv/all_labels"

In [None]:
# Load the test split. Later cells read its 'label', 'concatenated_spans'
# and 'hypothesis' columns.
test_df = pd.read_csv(f'{csv_folder}/test.csv')

In [13]:
# Show how many test examples carry each label id.
test_df['label'].value_counts()

label
0    968
2    903
1    220
Name: count, dtype: int64

In [None]:
# Hub id; in this cell it is only used to fetch the tokenizer.
model_id = 'sentence-transformers/all-MiniLM-L6-v2'

# Prefixes of the local checkpoint directories for the two binary classifiers.
folder_entailment_model = 'all-MiniLM-L6-v2'
folder_contradiction_model = 'miniLM'

entailment_checkpoint = f'./{folder_entailment_model}_entailment'
contradiction_checkpoint = f'./{folder_contradiction_model}_contradiction'

entailment_model = AutoModelForSequenceClassification.from_pretrained(
    entailment_checkpoint
)
contradiction_model = AutoModelForSequenceClassification.from_pretrained(
    contradiction_checkpoint
)

# NOTE(review): the tokenizer is loaded from `model_id`, not from the
# checkpoint directories — assumes both models were fine-tuned with this
# tokenizer; confirm.
tokenizer = AutoTokenizer.from_pretrained(model_id)

In [7]:
# Use the tokenizer's own maximum sequence length; tokenize_data() pads and
# truncates every pair to this length.
max_length = tokenizer.model_max_length
print(max_length)

512


In [None]:
# Ground-truth label ids; the combined predictions are scored against these.
test_labels = test_df['label']

In [9]:
def tokenize_data(data):
    """Tokenize (concatenated_spans, hypothesis) pairs.

    Args:
        data: Table-like object whose 'concatenated_spans' and 'hypothesis'
            columns support ``.tolist()`` (the notebook passes a pandas
            DataFrame).

    Returns:
        The output of the module-level ``tokenizer`` for the pairs, with
        every sequence truncated and padded to the module-level
        ``max_length``.
    """
    premises = data['concatenated_spans'].tolist()
    hypotheses = data['hypothesis'].tolist()
    return tokenizer(
        text=premises,
        text_pair=hypotheses,
        truncation=True,
        padding="max_length",
        max_length=max_length,
    )

In [11]:
# Tokenize the test split once; the resulting dataset is fed to both models.
test_encodings = tokenize_data(test_df)
# ContractNLIDatasetTest is imported from the project's utils module; it is
# constructed from the encodings only (no labels are passed here).
test_dataset = ContractNLIDatasetTest(test_encodings)

In [None]:
# One Trainer per binary classifier, used only for prediction. No
# TrainingArguments are passed, so the Trainer's defaults apply.
# NOTE(review): recent transformers releases deprecate `tokenizer=` in favour
# of `processing_class=` — confirm against the installed version.
trainer_entailment, trainer_contradiction = (
    Trainer(model=binary_model, tokenizer=tokenizer)
    for binary_model in (entailment_model, contradiction_model)
)

In [None]:
# Run both binary classifiers on the same test set.
# Trainer.predict() returns a PredictionOutput named tuple
# (predictions, label_ids, metrics). The logits live in `.predictions`, so the
# argmax must be taken on that field — the tuple itself has no `.argmax`.
predictions_entailment = trainer_entailment.predict(test_dataset)
pred_labels_entailment = predictions_entailment.predictions.argmax(axis=1)

predictions_contradiction = trainer_contradiction.predict(test_dataset)
pred_labels_contradiction = predictions_contradiction.predictions.argmax(axis=1)

# Label ids in the order used for every report below, and the key under which
# each true label is tallied in `classes`.
label_ids = [ENTAILMENT, CONTRADICTION, NOT_MENTIONED]
true_label_keys = {
    ENTAILMENT: "ENTAILMENT",
    CONTRADICTION: "CONTRADICTION",
    NOT_MENTIONED: "NOT_MENTIONED",
}

# Cross-tabulation: outer key is "<entailment prediction><contradiction
# prediction>", inner key is the true label of the example.
classes = {
    "00": {"ENTAILMENT": 0, "CONTRADICTION": 0, "NOT_MENTIONED": 0},
    "01": {"ENTAILMENT": 0, "CONTRADICTION": 0, "NOT_MENTIONED": 0},
    "10": {"ENTAILMENT": 0, "CONTRADICTION": 0, "NOT_MENTIONED": 0},
    "11": {"ENTAILMENT": 0, "CONTRADICTION": 0, "NOT_MENTIONED": 0},
}

for pred_entailment, pred_contradiction, true_label in zip(
    pred_labels_entailment, pred_labels_contradiction, test_labels
):
    true_name = true_label_keys.get(true_label)
    # Labels outside the three known ids are not tallied.
    if true_name is not None:
        classes[f"{pred_entailment}{pred_contradiction}"][true_name] += 1


for key, value in classes.items():
    print(f"Predicted: {key}")
    print(f"ENTAILMENT: {value['ENTAILMENT']}")
    print(f"CONTRADICTION: {value['CONTRADICTION']}")
    print(f"NOT_MENTIONED: {value['NOT_MENTIONED']}")
    print("\n")


# Combine the two binary decisions into one three-way label. Keys are
# (entailment prediction, contradiction prediction); when both models fire,
# entailment wins.
# NOTE(review): assumes class 1 is the positive class of each binary model —
# confirm against the training notebooks.
label_mapping = {
    (0, 0): NOT_MENTIONED,
    (1, 0): ENTAILMENT,
    (0, 1): CONTRADICTION,
    (1, 1): ENTAILMENT,
}

final_predictions = [
    label_mapping[(pred_entailment, pred_contradiction)]
    for pred_entailment, pred_contradiction in zip(
        pred_labels_entailment, pred_labels_contradiction
    )
]


# Compute and display the confusion matrix. Passing `labels` explicitly keeps
# the rows/columns in label-id order even if a class is never predicted or
# absent from the data.
confusion_mat = confusion_matrix(test_labels, final_predictions, labels=label_ids)
print("Confusion Matrix:\n", confusion_mat)

# Display classification report; `class_names` is in the same order as
# `label_ids`.
class_names = ["ENTAILMENT", "CONTRADICTION", "NOT MENTIONED"]
print("\nClassification Report:\n")
print(
    classification_report(
        test_labels,
        final_predictions,
        labels=label_ids,
        target_names=class_names,
    )
)

# Positions (0-based, in test-set order) of the misclassified examples.
incorrect_predictions = [
    i
    for i, (true, pred) in enumerate(zip(test_labels, final_predictions))
    if true != pred
]

print("\nError Analysis:")
print(f"Number of incorrect predictions: {len(incorrect_predictions)} out of {len(test_labels)}")
