In [9]:
from abc import ABC
from abc import abstractmethod
from dataclasses import dataclass
from enum import auto
from enum import Enum
from itertools import combinations
from pprint import pprint
from unicodedata import normalize

import nltk
import numpy as np
from nltk.tokenize import sent_tokenize
from datasets import load_dataset
from sentence_transformers import CrossEncoder
from sklearn.metrics import accuracy_score
from sklearn.metrics import precision_recall_fscore_support
from tqdm import tqdm

In [10]:
nltk.download("punkt", download_dir='.nltk/')

[nltk_data] Downloading package punkt to .nltk/...
[nltk_data]   Package punkt is already up-to-date!


True

In [11]:
print('Loading datasets...')
dataset_snli = load_dataset('stanfordnlp/snli', split='test')
dataset_mnli = load_dataset('nyu-mll/multi_nli', split='validation_matched')
dataset_anli = load_dataset('facebook/anli', split='test_r3')
print('Loading datasets done.')

Loading datasets...
Loading datasets done.


In [12]:
class Label(Enum):
    ENTAILMENT = auto()
    NOT_ENTAILMENT = auto()
    NOT_AVAILABLE = auto()


@dataclass
class Example:
    premise: str
    hypothesis: str
    label: Label


class ExampleSNLI(Example):
    def __init__(self, example: dict):
        label = self._to_binary_class(example['label'])
        super().__init__(example['premise'], example['hypothesis'], label)
    
    def _to_binary_class(self, label: int) -> Label:
        return Label.ENTAILMENT if label == 0 else Label.NOT_ENTAILMENT


class ExampleMNLI(Example):
    def __init__(self, example: dict):
        label = self._to_binary_class(example['label'])
        super().__init__(example['premise'], example['hypothesis'], label)
    
    def _to_binary_class(self, label: int) -> Label:
        return Label.ENTAILMENT if label == 0 else Label.NOT_ENTAILMENT


class ExampleANLI(Example):
    def __init__(self, example: dict):
        label = self._to_binary_class(example['label'])
        super().__init__(example['premise'], example['hypothesis'], label)
    
    def _to_binary_class(self, label: int) -> Label:
        return Label.ENTAILMENT if label == 0 else Label.NOT_ENTAILMENT

In [13]:
examples_snli = [ExampleSNLI(example) for example in dataset_snli]
examples_mnli = [ExampleMNLI(example) for example in dataset_mnli]
examples_anli = [ExampleANLI(example) for example in dataset_anli]

In [14]:
class Model(ABC):
    @abstractmethod
    def predict(self, example: Example) -> int:
        pass
    
    @abstractmethod
    def predict_batch(self, examples: list[Example]) -> list[int]:
        pass


class SBERT(Model):
    def __init__(self, pretrained_model_name: str):
        self.model = CrossEncoder(pretrained_model_name)
        self.ENTAILMENT = 1 # see https://huggingface.co/cross-encoder/nli-deberta-v3-small
    
    def predict(self, example: Example) -> int:
        score = self.model.predict((example.premise, example.hypothesis))[0]
        return Label.ENTAILMENT if np.argmax(score) == self.ENTAILMENT else Label.NOT_ENTAILMENT
    
    def predict_batch(self, examples: list[Example]) -> list[int]:
        scores = self.model.predict([(example.premise, example.hypothesis) for example in examples])
        return [Label.ENTAILMENT if np.argmax(score) == self.ENTAILMENT else Label.NOT_ENTAILMENT for score in scores]

In [15]:
def select_k_sentences(text: str, k: int) -> list[str]:
    sentences = sent_tokenize(normalize(text))
    return [' '.join(comb) for comb in combinations(sentences, k)]


def sub_example(tokenizer, example: Example, l: int) -> list[Example]:
    assert l > 0, 'The number of selected sentences must be greater than 0'
    max_length = tokenizer.model_max_length

    sub_examples = []
    for k in range(1, l + 1):
        sub_premises = select_k_sentences(example.premise, k)
        for sub_premise in sub_premises:
            if len(tokenizer(sub_premise, example.hypothesis)['input_ids']) > max_length:
                continue
            sub_examples.append(Example(sub_premise, example.hypothesis, example.label))
    return sub_examples


In [None]:
def metrics(y_true: list[Label], y_pred: list[Label]) -> tuple[float, float, float, float]:
    y_true = [0 if label == Label.ENTAILMENT else 1 for label in y_true]
    y_pred = [0 if label == Label.ENTAILMENT else 1 for label in y_pred]
    accuracy = accuracy_score(y_true, y_pred)
    precision, recall, f1, _ = precision_recall_fscore_support(y_true, y_pred, average='binary', zero_division=0)
    return accuracy, precision, recall, f1


def evaluate_model_metrics(model: Model, examples: list[Example], use_batch: bool = True) -> dict:
    y_true = [example.label for example in examples]
    y_pred = model.predict_batch(examples) if use_batch else [model.predict(example) for example in tqdm(examples)]
    accuracy, precision, recall, f1 = metrics(y_true, y_pred)
    return { 'accuracy': accuracy, 'precision': precision, 'recall': recall, 'f1': f1 }

In [19]:
model = SBERT('cross-encoder/nli-deberta-v3-small')

print('Evaluating model on SNLI...')
metrics_snli = evaluate_model_metrics(model, examples_snli)
pprint(metrics_snli)

print('Evaluating model on MNLI...')
metrics_mnli = evaluate_model_metrics(model, examples_mnli)
pprint(metrics_mnli)

print('Evaluating model on ANLI...')
metrics_anli = evaluate_model_metrics(model, examples_anli)
pprint(metrics_anli)

Evaluating model on SNLI...
{'accuracy': 0.9413,
 'f1': 0.9554120774781618,
 'precision': 0.9626511556712077,
 'recall': 0.9482810615199035}
Evaluating model on MNLI...
{'accuracy': 0.9273560876209883,
 'f1': 0.9441878669275929,
 'precision': 0.9366361236216804,
 'recall': 0.9518623737373737}
Evaluating model on ANLI...
{'accuracy': 0.6066666666666667,
 'f1': 0.7386489479512736,
 'precision': 0.6617063492063492,
 'recall': 0.8358395989974937}
