In [1]:
from datasets import load_dataset as lds
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from torch.utils.data import DataLoader
from sklearn.metrics import accuracy_score, f1_score
import torch
from tqdm import tqdm
import numpy as np
from dataset.loader import ContractNLIExample
import json 

In [2]:


# Load the ContractNLI dev split. The `with` block closes the file handle as
# soon as the JSON has been parsed (a bare `open(...)` passed straight to
# `json.load` is never closed), and the explicit encoding keeps the read
# independent of the platform's default text encoding.
with open('../dataset/contract-nli/dev.json', 'r', encoding='utf-8') as dev_file:
    dataset = ContractNLIExample.load(json.load(dev_file))

# Use BERT base model without fine-tuning. The 3-way classification head on
# top of it is newly initialised (transformers warns about this on load), so
# this model is an untrained baseline, not a usable NLI classifier.
model_name = "bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=3)

  torch.utils._pytree._register_pytree_node(
  torch.utils._pytree._register_pytree_node(
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [3]:
# Load SNLI. The import cell binds `datasets.load_dataset` to the alias `lds`,
# so the unaliased name `load_dataset` is undefined on a fresh kernel
# (NameError under Restart & Run All); call it through the alias instead.
ds = lds("snli")
# Filter out examples with label -1 (that id has no entry in the 3-class
# label map used when evaluating).
ds = ds.filter(lambda example: example['label'] != -1)

In [4]:
def tokenize_function(examples):
    """Tokenize a batch of premise/hypothesis pairs and attach their labels.

    Args:
        examples: A batch mapping with "premise", "hypothesis" and "label"
            entries.

    Returns:
        The tokenizer output with an extra "labels" entry holding the
        unmodified values of ``examples["label"]``.
    """
    premises = examples["premise"]
    hypotheses = examples["hypothesis"]

    # Encode each premise together with its hypothesis as one sequence pair,
    # padded with the "max_length" strategy and truncated when too long.
    encoded = tokenizer(premises, hypotheses, padding="max_length", truncation=True)

    # Labels are passed through as-is; no remapping is applied.
    encoded["labels"] = examples["label"]
    return encoded

# Tokenize the dataset in batches, then expose only the model inputs and the
# labels as torch tensors.
tokenized_ds = ds.map(function=tokenize_function, batched=True)
tokenized_ds.set_format(type="torch", columns=['input_ids', 'attention_mask', 'labels'])

Map:   0%|          | 0/9824 [00:00<?, ? examples/s]

Map:   0%|          | 0/9842 [00:00<?, ? examples/s]

Map:   0%|          | 0/549367 [00:00<?, ? examples/s]

In [5]:
# Wrap the tokenized test split in a DataLoader that yields batches of 16.
# No shuffle argument is passed, so the loader's default ordering applies.
test_dataset = tokenized_ds["test"]
test_dataloader = DataLoader(dataset=test_dataset, batch_size=16)

In [6]:
def _print_examples(batch_preds, logits, original_dataset, label_map, limit):
    """Print the first ``limit`` rows of the first batch next to their source text.

    Row ``j`` of the batch is paired with ``original_dataset[j]``; that pairing
    is only meaningful when the dataloader yields ``original_dataset`` in order.
    """
    def label_name(label_id):
        # Fall back to a readable placeholder instead of raising KeyError for
        # ids that have no entry in the label map (e.g. -1).
        return label_map.get(label_id, f"unknown({label_id})")

    for j in range(min(limit, len(batch_preds), len(original_dataset))):
        row = original_dataset[j]
        print(f"\nExample {j + 1}:")
        print(f"Premise: {row['premise']}")
        print(f"Hypothesis: {row['hypothesis']}")
        print(f"True label: {label_name(row['label'])}")
        print(f"Predicted label: {label_name(batch_preds[j].item())}")
        print(f"Logits: {logits[j].cpu().tolist()}")
        print("-" * 50)


def evaluate(model, dataloader, original_dataset, num_examples=10, metrics_every=10):
    """Run ``model`` over ``dataloader`` and report accuracy and F1 scores.

    The model is switched to eval mode and moved to CUDA when available
    (otherwise CPU). Readable examples are printed for the first batch and the
    final accuracy, macro F1 and weighted F1 are printed at the end.

    Args:
        model: Callable as ``model(**inputs)``; its output must expose ``.logits``.
        dataloader: Sized iterable of batches. Each batch is a mapping with
            'input_ids', 'attention_mask' and 'labels' tensors, and optionally
            'token_type_ids'.
        original_dataset: Indexable rows with 'premise', 'hypothesis' and
            'label' keys, in the order the dataloader yields them. Used only
            for the printed examples.
        num_examples: Maximum number of first-batch examples to print.
        metrics_every: Refresh the progress-bar metrics every this many
            batches (and always on the last batch).

    Returns:
        ``(all_predictions, all_labels)``: two parallel lists holding one
        predicted and one gold label id per evaluated example.
    """
    model.eval()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    all_predictions = []
    all_labels = []

    # Original label mapping from the dataset
    original_label_map = {0: "entailment", 1: "neutral", 2: "contradiction"}

    total_batches = len(dataloader)
    refresh_every = max(1, int(metrics_every))
    progress_bar = tqdm(enumerate(dataloader), total=total_batches, desc="Evaluating")

    with torch.no_grad():
        for batch_idx, batch in progress_bar:
            inputs = {
                'input_ids': batch['input_ids'].to(device),
                'attention_mask': batch['attention_mask'].to(device),
            }
            # Forward segment ids when the batch carries them, so sentence-pair
            # inputs keep their premise/hypothesis segment information.
            if 'token_type_ids' in batch:
                inputs['token_type_ids'] = batch['token_type_ids'].to(device)
            # Labels are only consumed on the CPU, so they are not moved to the device.
            labels = batch['labels']

            logits = model(**inputs).logits
            batch_preds = torch.argmax(logits, dim=-1)

            all_predictions.extend(batch_preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

            # Live metrics are recomputed from the full history, so refreshing
            # them on every batch makes the total work grow quadratically with
            # the number of batches; refresh periodically instead.
            if batch_idx % refresh_every == 0 or batch_idx == total_batches - 1:
                current_accuracy = accuracy_score(all_labels, all_predictions)
                current_f1 = f1_score(all_labels, all_predictions, average='weighted')
                progress_bar.set_postfix({'Accuracy': f'{current_accuracy:.4f}', 'F1': f'{current_f1:.4f}'})

            # Print examples for the first batch
            if batch_idx == 0:
                _print_examples(batch_preds, logits, original_dataset, original_label_map, num_examples)

    # Nothing was evaluated: skip the metric calls rather than scoring empty lists.
    if not all_labels:
        print("\nNo examples were evaluated.")
        return all_predictions, all_labels

    # Calculate final metrics
    accuracy = accuracy_score(all_labels, all_predictions)
    f1_macro = f1_score(all_labels, all_predictions, average='macro')
    f1_weighted = f1_score(all_labels, all_predictions, average='weighted')

    print(f"\nFinal Results:")
    print(f"Accuracy: {accuracy:.4f}")
    print(f"F1 Score (Macro): {f1_macro:.4f}")
    print(f"F1 Score (Weighted): {f1_weighted:.4f}")

    return all_predictions, all_labels


In [7]:
# Run evaluation on the test split; the untokenized rows are passed along so
# the first few examples can be printed as text.
predictions, true_labels = evaluate(
    model=model,
    dataloader=test_dataloader,
    original_dataset=ds["test"],
)

Evaluating:   0%|          | 1/614 [00:01<15:53,  1.56s/it, Accuracy=0.3125, F1=0.1786]


Example 1:
Premise: This church choir sings to the masses as they sing joyous songs from the book at a church.
Hypothesis: The church has cracks in the ceiling.
True label: neutral
Predicted label: contradiction
Logits: [-0.15072835981845856, -0.055166445672512054, 0.017433125525712967]
--------------------------------------------------

Example 2:
Premise: This church choir sings to the masses as they sing joyous songs from the book at a church.
Hypothesis: The church is filled with song.
True label: entailment
Predicted label: contradiction
Logits: [-0.16454318165779114, -0.07607656717300415, 0.033789150416851044]
--------------------------------------------------

Example 3:
Premise: This church choir sings to the masses as they sing joyous songs from the book at a church.
Hypothesis: A choir singing at a baseball game.
True label: contradiction
Predicted label: contradiction
Logits: [-0.086623415350914, -0.09536000341176987, 0.1066807210445404]
------------------------------------

Evaluating:  21%|██▏       | 131/614 [00:52<03:14,  2.48it/s, Accuracy=0.3230, F1=0.1640]


KeyboardInterrupt: 