In [9]:
from transformers import DistilBertConfig, DistilBertForSequenceClassification, DistilBertTokenizer
import torch
from torch.utils.data import TensorDataset, DataLoader
from transformers import glue_convert_examples_to_features
from transformers import glue_processors
from typing import List, Optional, Union
from dataclasses import dataclass
import numpy as np

In [2]:
# Fine-tuned MNLI DistilBERT checkpoint whose predictions are probed below.
classifier_path = './checkpoints/mnli_baseline_distilbert-2023-04-07_10-48-14/checkpoint-last'

# DistilBERT's config names its dropout knobs `dropout` and `attention_dropout`;
# the BERT-style `hidden_dropout_prob` / `attention_probs_dropout_prob` keys are
# not read by DistilBERT, so passing them had no effect. Dropout is inactive
# once the model is put in eval mode, so this only matters if it is ever trained.
config = DistilBertConfig.from_pretrained(
    classifier_path,
    num_labels=3,
    finetuning_task='mnli',
    attention_dropout=0.0,
    dropout=0.1,
)
tokenizer = DistilBertTokenizer.from_pretrained(
    classifier_path,
    do_lower_case=True,
)
# NOTE(review): ignore_mismatched_sizes=True re-initialises any checkpoint weight
# whose shape disagrees with `config` (e.g. a classifier head trained with a
# different number of labels) instead of raising, which would silently produce
# random predictions. Kept for compatibility; consider removing it so such a
# mismatch fails loudly.
model = DistilBertForSequenceClassification.from_pretrained(
    classifier_path,
    config=config,
    ignore_mismatched_sizes=True,
)

In [3]:
@dataclass
class InputExample:
    """A single text pair to be classified.

    Fields are positional in the order (guid, text_a, text_b, label).
    """

    guid: str  # identifier for the example, e.g. 'test-0'
    text_a: str  # first sequence of the pair (the premise in this notebook)
    text_b: str  # second sequence of the pair (the hypothesis in this notebook)
    label: Union[str, None] = None  # label name, or None when unlabeled


@dataclass(frozen=True)
class InputFeatures:
    """Tokenizer output for one example; immutable once constructed."""

    input_ids: List[int]  # token ids produced by the tokenizer
    attention_mask: Union[List[int], None] = None  # tokenizer attention mask, if any
    token_type_ids: Union[List[int], None] = None  # tokenizer segment ids, if any
    label: Union[int, float, None] = None  # numeric label id, or None

In [30]:
# The premise, its reference hypothesis and reference label, and the candidate
# hypotheses that are classified against the same premise further down.
premise = "i'm not sure what the overnight low was"
orig_hypothesis, orig_label = "I don't know how cold it got last night.", "entailment"
hypotheses = [
    "They didn't see how long it got last day.",
    "I don't know how cold it went last night.",
    "I don't know how it had gone last night.",
    "I don't know how it stayed the last night.",
    "I knew how so it was a last night.",
]

In [5]:
def load_data(premise, hypotheses, tokenizer, max_length=128, label='contradiction'):
    """Tokenize `premise` against each hypothesis and pack the result into a TensorDataset.

    Args:
        premise: Text used as the first sequence of every pair.
        hypotheses: Iterable of texts; each is paired with `premise` as the second sequence.
        tokenizer: Callable (e.g. a Hugging Face tokenizer) that accepts a list of
            (text_a, text_b) pairs plus `max_length`, `padding`, `truncation` and
            `return_token_type_ids` keywords, and returns a mapping from feature name
            to a per-example list.
        max_length: Length every pair is padded / truncated to (default 128).
        label: Placeholder label name assigned to every example. It must be one of
            "contradiction", "entailment" or "neutral". The true labels are unknown
            here, so the dataset's label column is NOT ground truth.

    Returns:
        TensorDataset of (input_ids, attention_mask, token_type_ids, labels), all
        int64, with one row per hypothesis.

    Raises:
        ValueError: If `label` is not one of the known label names.
    """
    label_list = ["contradiction", "entailment", "neutral"]
    label_map = {name: idx for idx, name in enumerate(label_list)}
    if label not in label_map:
        raise ValueError(f"label must be one of {label_list}, got {label!r}")

    examples = [
        InputExample(guid=f'test-{i}', text_a=premise, text_b=hypothesis, label=label)
        for i, hypothesis in enumerate(hypotheses)
    ]
    labels = [label_map[example.label] for example in examples]

    batch_encoding = tokenizer(
        [(example.text_a, example.text_b) for example in examples],
        max_length=max_length,
        padding='max_length',
        truncation=True,
        return_token_type_ids=True,
    )

    # One feature record per example; every key the tokenizer returned becomes a field.
    features = [
        InputFeatures(**{key: batch_encoding[key][i] for key in batch_encoding}, label=labels[i])
        for i in range(len(examples))
    ]

    # Convert to int64 tensors and build the dataset.
    all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
    all_attention_mask = torch.tensor([f.attention_mask for f in features], dtype=torch.long)
    all_token_type_ids = torch.tensor([f.token_type_ids for f in features], dtype=torch.long)
    all_labels = torch.tensor([f.label for f in features], dtype=torch.long)

    return TensorDataset(all_input_ids, all_attention_mask, all_token_type_ids, all_labels)

In [31]:
dataset = load_data(premise=premise, hypotheses=hypotheses, tokenizer=tokenizer)  # one row per hypothesis

In [32]:
# Run the classifier over every example and collect the argmax class per row.
eval_dataloader = DataLoader(dataset, batch_size=16)
model.eval()  # disable dropout; needed once, not per batch

batch_preds = []
# The forward pass runs inside no_grad so no autograd graph is built during inference.
with torch.no_grad():
    for batch in eval_dataloader:
        # Columns follow load_data: (input_ids, attention_mask, token_type_ids, labels).
        # token_type_ids are not passed (DistilBERT's forward does not take them), and
        # the labels column holds only placeholder values, so no loss is requested.
        outputs = model(
            input_ids=batch[0].to(model.device),
            attention_mask=batch[1].to(model.device),
        )
        logits = outputs[0]  # with no labels given, the first output is the logits
        batch_preds.append(np.argmax(logits.cpu().numpy(), axis=1))

# Concatenate so predictions from every batch are kept, not just the last one.
preds = np.concatenate(batch_preds) if batch_preds else np.empty(0, dtype=np.int64)
print(preds.tolist())

[2, 1, 1, 1, 0]


In [16]:
# Class index -> label name, in the same order load_data uses for label ids.
# NOTE(review): presumably matches the checkpoint's id2label; confirm against its config.
label_list = [
    "contradiction",  # 0
    "entailment",     # 1
    "neutral",        # 2
]

In [33]:
# Show the premise, its reference hypothesis and label, then each candidate
# hypothesis with the label the model predicted for it.
if len(preds) != len(hypotheses):
    # zip() would silently truncate the longer sequence, hiding a stale or
    # partial `preds` left over from an earlier run.
    raise ValueError(
        f'Got {len(preds)} predictions for {len(hypotheses)} hypotheses; '
        're-run the data-loading and evaluation cells.'
    )

print(f'Premise: {premise}')
print(f'Original hypothesis: {orig_hypothesis}')
print(f'Label: {orig_label}')
print('--------------------')
for sentence, pred in zip(hypotheses, preds):
    print(f'{sentence} --> {label_list[pred]}')

Premise: i'm not sure what the overnight low was
Original hypothesis: I don't know how cold it got last night.
Label: entailment
--------------------
They didn't see how long it got last day. --> neutral
I don't know how cold it went last night. --> entailment
I don't know how it had gone last night. --> entailment
I don't know how it stayed the last night. --> entailment
I knew how so it was a last night. --> contradiction
