In [2]:
from datasets import load_dataset as lds
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from torch.utils.data import DataLoader
from sklearn.metrics import accuracy_score, f1_score
import torch
from tqdm import tqdm
import numpy as np
from dataset.loader import ContractNLIExample
import json

In [3]:
# Load dataset and model
dataset = ContractNLIExample.load(json.load(open('../dataset/contract-nli/dev.json','r')))
model_name = "microsoft/deberta-large-mnli"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)


  torch.utils._pytree._register_pytree_node(
  torch.utils._pytree._register_pytree_node(
Some weights of the model checkpoint at microsoft/deberta-large-mnli were not used when initializing DebertaForSequenceClassification: ['config']
- This IS expected if you are initializing DebertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DebertaForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [4]:
ds = lds("snli")
# Filter out examples with label -1
ds = ds.filter(lambda example: example['label'] != -1)

In [5]:
def tokenize_function(examples):
    # Tokenize the inputs
    tokenized = tokenizer(
        examples["premise"],
        examples["hypothesis"],
        padding="max_length",
        truncation=True,
    )
    
    # Remap labels to match the model's expected label mapping
    label_remap = {0: 2, 1: 1, 2: 0}  # From SNLI labels to DeBERTa labels
    tokenized["labels"] = [label_remap[label] for label in examples["label"]]
    return tokenized

tokenized_ds = ds.map(tokenize_function, batched=True)
tokenized_ds.set_format("torch", columns=['input_ids', 'attention_mask', 'labels'])

Map:   0%|          | 0/549367 [00:00<?, ? examples/s]

Map:   0%|          | 0/9824 [00:00<?, ? examples/s]

In [6]:
# Create DataLoader
test_dataset = tokenized_ds["test"]
test_dataloader = DataLoader(test_dataset, batch_size=16)

In [7]:
def interpret_logits(logits):
    # Convert logits to CPU and extract values
    logits_cpu = logits.cpu().tolist()
    
    # If the first logit (index 0) is highest, it's entailment
    if logits_cpu[0] > logits_cpu[1] and logits_cpu[0] > logits_cpu[2]:
        return 0  # entailment
    # If the second logit is highest, it's neutral
    elif logits_cpu[1] > logits_cpu[0] and logits_cpu[1] > logits_cpu[2]:
        return 1  # neutral
    # Otherwise, it's contradiction
    else:
        return 2  # contradiction

In [8]:
def evaluate(model, dataloader, original_dataset):
    model.eval()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    
    all_predictions = []
    all_labels = []
    
    # Model's label mapping
    model_label_map = {0: "contradiction", 1: "neutral", 2: "entailment"}
    # Original label mapping from the dataset
    original_label_map = {0: "entailment", 1: "neutral", 2: "contradiction"}
    
    progress_bar = tqdm(enumerate(dataloader), total=len(dataloader), desc="Evaluating")
    
    with torch.no_grad():
        for batch_idx, batch in progress_bar:
            inputs = {
                'input_ids': batch['input_ids'].to(device),
                'attention_mask': batch['attention_mask'].to(device),
            }
            labels = batch['labels'].to(device)
            
            outputs = model(**inputs)
            logits = outputs.logits
            
            batch_preds = torch.argmax(logits, dim=-1)
            
            all_predictions.extend(batch_preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
            
            # Calculate live metrics
            current_accuracy = accuracy_score(all_labels, all_predictions)
            current_f1 = f1_score(all_labels, all_predictions, average='weighted')
            
            progress_bar.set_postfix({'Accuracy': f'{current_accuracy:.4f}', 'F1': f'{current_f1:.4f}'})
            
            # Print examples for the first batch
            if batch_idx == 0:
                for j in range(min(10, len(batch['input_ids']))):
                    idx = batch_idx * dataloader.batch_size + j
                    if idx < len(original_dataset):
                        premise = original_dataset[idx]['premise']
                        hypothesis = original_dataset[idx]['hypothesis']
                        true_label_idx = original_dataset[idx]['label']
                        true_label = original_label_map[true_label_idx]
                        pred_label = model_label_map[batch_preds[j].item()]
                        
                        print(f"\nExample {idx + 1}:")
                        print(f"Premise: {premise}")
                        print(f"Hypothesis: {hypothesis}")
                        print(f"True label: {true_label}")
                        print(f"Predicted label: {pred_label}")
                        print(f"Logits: {logits[j].cpu().tolist()}")
                        print("-" * 50)
    
    # Calculate final metrics
    accuracy = accuracy_score(all_labels, all_predictions)
    f1_macro = f1_score(all_labels, all_predictions, average='macro')
    f1_weighted = f1_score(all_labels, all_predictions, average='weighted')
    
    print(f"\nFinal Results:")
    print(f"Accuracy: {accuracy:.4f}")
    print(f"F1 Score (Macro): {f1_macro:.4f}")
    print(f"F1 Score (Weighted): {f1_weighted:.4f}")
    
    return all_predictions, all_labels

In [9]:
# Run evaluation
predictions, true_labels = evaluate(model, test_dataloader, ds["test"])

Evaluating:   0%|          | 1/614 [00:02<26:25,  2.59s/it, Accuracy=0.8750, F1=0.8818]


Example 1:
Premise: This church choir sings to the masses as they sing joyous songs from the book at a church.
Hypothesis: The church has cracks in the ceiling.
True label: neutral
Predicted label: neutral
Logits: [-0.23674659430980682, 3.084219217300415, -2.8140945434570312]
--------------------------------------------------

Example 2:
Premise: This church choir sings to the masses as they sing joyous songs from the book at a church.
Hypothesis: The church is filled with song.
True label: entailment
Predicted label: entailment
Logits: [-3.5898513793945312, 1.5226945877075195, 2.2780914306640625]
--------------------------------------------------

Example 3:
Premise: This church choir sings to the masses as they sing joyous songs from the book at a church.
Hypothesis: A choir singing at a baseball game.
True label: contradiction
Predicted label: contradiction
Logits: [5.034609317779541, -2.1406774520874023, -2.6638550758361816]
--------------------------------------------------

Exam

Evaluating:   1%|          | 6/614 [00:14<24:11,  2.39s/it, Accuracy=0.8854, F1=0.8849]


KeyboardInterrupt: 

In [11]:
print(model)

dict_keys(['input_ids', 'token_type_ids', 'attention_mask'])


In [None]:
print(model.config)
