In [63]:
import random
from math import ceil
from typing import Any, Dict, List

from datasets import Dataset, concatenate_datasets, load_dataset
from tqdm.contrib import tenumerate
from transformers import PreTrainedTokenizer


def binarize_labels(
      dataset: Dataset
    , labels_to_pos: List[Any]
    , labels_to_neg: List[Any]
    , pos_label: int = 1
    , neg_label: int = 0
    , sample_seed: int = 42
    , shuffle_seed: int = 42
    , num_proc: int = 4
) -> Dataset:
    """Collapse a multi-class dataset into a class-balanced binary dataset.

    Rows whose ``label`` is listed in ``labels_to_pos`` are relabelled
    ``pos_label``; rows whose ``label`` is listed in ``labels_to_neg`` are
    relabelled ``neg_label``; rows carrying any other label are dropped.
    The larger of the two groups is down-sampled, separately for each original
    label, by the ratio ``min(pos, neg) / max(pos, neg)``. Per-label sample
    sizes are rounded up, so the down-sampled side can end up a few rows
    larger than the other side.

    Args:
        dataset: Dataset with a ``label`` column.
        labels_to_pos: Original label values that become ``pos_label``.
        labels_to_neg: Original label values that become ``neg_label``.
        pos_label: Value written to the ``label`` column of positive rows.
        neg_label: Value written to the ``label`` column of negative rows.
        sample_seed: Seed for the down-sampling RNG.
        shuffle_seed: Seed passed to the final ``Dataset.shuffle``.
        num_proc: Number of worker processes used by each ``Dataset.map``.

    Returns:
        The relabelled, down-sampled and shuffled dataset.

    Raises:
        AssertionError: If ``dataset`` has no ``label`` column or the two
            label lists overlap.
        ValueError: If no row carries a label from either list.
    """
    assert 'label' in dataset.features, "dataset must have a 'label' column"
    assert set(labels_to_pos).isdisjoint(labels_to_neg), \
        'labels_to_pos and labels_to_neg must not overlap'

    # A private generator yields the same samples as seeding the module-level
    # RNG would, without clobbering global `random` state for other cells.
    rng = random.Random(sample_seed)

    # Row indices grouped by original label, in order of first appearance.
    pos_label2indices: Dict[Any, List] = {}
    neg_label2indices: Dict[Any, List] = {}
    for index, label in tenumerate(dataset['label']):
        if label in labels_to_pos:
            pos_label2indices.setdefault(label, []).append(index)
        elif label in labels_to_neg:
            neg_label2indices.setdefault(label, []).append(index)

    pos_num = sum(len(indices) for indices in pos_label2indices.values())
    neg_num = sum(len(indices) for indices in neg_label2indices.values())
    if max(pos_num, neg_num) == 0:
        raise ValueError(
            'no rows match labels_to_pos or labels_to_neg; nothing to binarize'
        )
    sample_ratio = min(pos_num, neg_num) / max(pos_num, neg_num)

    # Down-sample whichever side is larger (the positive side on a tie, where
    # the ratio is 1 and sampling only permutes the indices).
    majority = neg_label2indices if pos_num < neg_num else pos_label2indices
    for label, indices in majority.items():
        sample_size = ceil(sample_ratio * len(indices))
        majority[label] = rng.sample(indices, sample_size)

    def _relabel_as(new_label):
        """Build a batched map function that overwrites every label."""
        def _apply(batch):
            batch['label'] = [new_label] * len(batch['label'])
            return batch
        return _apply

    relabelled_parts = [
        dataset.select(indices)
               .map(_relabel_as(pos_label), batched=True, num_proc=num_proc)
        for indices in pos_label2indices.values()
    ] + [
        dataset.select(indices)
               .map(_relabel_as(neg_label), batched=True, num_proc=num_proc)
        for indices in neg_label2indices.values()
    ]

    return concatenate_datasets(relabelled_parts).shuffle(seed=shuffle_seed)


def tokenize_premises_and_hypotheses(
      batch: Dict[str, List]
    , tokenizer: PreTrainedTokenizer
):
    """Encode a batch of premise/hypothesis pairs as sentence-pair inputs.

    Reads the ``premise`` and ``hypothesis`` columns of ``batch`` and passes
    them to ``tokenizer`` as ``text`` / ``text_pair``. Every pair is truncated
    and padded to ``tokenizer.model_max_length``, and the attention mask and
    token type ids are requested alongside the input ids.

    Args:
        batch: Column-name -> list-of-values mapping containing at least the
            ``premise`` and ``hypothesis`` columns.
        tokenizer: Tokenizer used to encode the pairs.

    Returns:
        Whatever ``tokenizer(...)`` returns for the batch.
    """
    premises = batch['premise']
    hypotheses = batch['hypothesis']

    # Fixed-length encoding: truncate and pad every pair to the model maximum.
    encoding_options = dict(
        truncation=True,
        max_length=tokenizer.model_max_length,
        padding='max_length',
        return_attention_mask=True,
        return_token_type_ids=True,
    )
    return tokenizer(text=premises, text_pair=hypotheses, **encoding_options)



# Binary label vocabulary: a label's position in `label_list` is its class id.
label_list = ['not_entailment', 'entailment']
# Forward lookup: label name -> class id.
label_to_id = {name: idx for idx, name in enumerate(label_list)}
# Reverse lookup: class id -> label name.
id_to_label = dict(enumerate(label_list))

# Load the dataset registered as 'stanfordnlp/snli', caching files under
# '.datasets/'. Later cells read its 'test' split via snli['test'].
snli = load_dataset('stanfordnlp/snli', cache_dir='.datasets/')

In [64]:
# load without fine-tuning

from transformers import AutoConfig, AutoTokenizer, AutoModelForSequenceClassification

# Baseline: the stock pretrained checkpoint with no task fine-tuning. Its
# classification head is newly initialised (see the warning printed by this
# cell), so the metrics computed from it serve as an untrained reference.
PRETRAINED_MODEL_NAME = 'roberta-large'
MODEL_CACHE_DIR = '.model/'

# Model config with a two-way classification head (one output per label_list entry).
config_wo_ft = AutoConfig.from_pretrained(
      pretrained_model_name_or_path=PRETRAINED_MODEL_NAME
    , num_labels=len(label_list)
    , finetuning_task='text-classification'
    , cache_dir=MODEL_CACHE_DIR
    , revision='main'
)

# `use_fast` is the option AutoTokenizer.from_pretrained understands;
# `use_fast_tokenizer` is only the argument name used by the HF example
# scripts and is not an AutoTokenizer option.
tokenizer_wo_ft = AutoTokenizer.from_pretrained(
      pretrained_model_name_or_path=PRETRAINED_MODEL_NAME
    , cache_dir=MODEL_CACHE_DIR
    , revision='main'
    , use_fast=True
)

model_wo_ft = AutoModelForSequenceClassification.from_pretrained(
      pretrained_model_name_or_path=PRETRAINED_MODEL_NAME
    , config=config_wo_ft
    , cache_dir=MODEL_CACHE_DIR
    , revision='main'
)

Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at roberta-large and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [65]:
# Build the balanced binary test set for the baseline model: original label 0
# becomes 1 ('entailment' in label_list), original labels 1 and 2 become 0
# ('not_entailment'). The result is then tokenized with the baseline tokenizer.
snli_test_wo_ft = (
    binarize_labels(snli['test'], labels_to_pos=[0], labels_to_neg=[1, 2])
    .map(
        lambda batch: tokenize_premises_and_hypotheses(batch, tokenizer_wo_ft),
        batched=True,
        num_proc=4,
    )
)

  0%|          | 0/10000 [00:00<?, ?it/s]

In [66]:
import torch
from tqdm.contrib import tenumerate
from sklearn.metrics import precision_recall_fscore_support, accuracy_score


# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model_wo_ft.to(device)
model_wo_ft.eval()  # switch the model to evaluation mode

# Gold labels and predicted class ids, accumulated over all batches.
truth = []
pred = []

with torch.no_grad():  # inference only, no gradients needed
    for i, batch in tenumerate(snli_test_wo_ft.batch(batch_size=32)):
        # Every 30 batches, report running metrics over the batches scored so far.
        if i > 0 and i % 30 == 0:
            p, r, f, _ = precision_recall_fscore_support(
                y_true=truth,
                y_pred=pred,
                pos_label=1,
                average='binary',
                zero_division=0,
            )
            a = accuracy_score(y_true=truth, y_pred=pred)
            print('A={}, P={}, R={}, F={}'.format(a, p, r, f))

        input_ids = torch.tensor(batch['input_ids'], device=device)
        attention_mask = torch.tensor(batch['attention_mask'], device=device)
        truth.extend(batch['label'])

        # Third positional argument is passed as None, as in the original call.
        output = model_wo_ft(input_ids, attention_mask, None)
        # Predicted class = index of the largest logit in each row.
        pred.extend(output.logits.argmax(dim=1).tolist())

# Final metrics over the whole test set.
p, r, f, _ = precision_recall_fscore_support(
    y_true=truth,
    y_pred=pred,
    pos_label=1,
    average='binary',
    zero_division=0,
)
a = accuracy_score(y_true=truth, y_pred=pred)
print('A={}, P={}, R={}, F={}'.format(a, p, r, f))

  0%|          | 0/211 [00:00<?, ?it/s]

A=0.5166666666666667, P=0.5166666666666667, R=1.0, F=0.6813186813186813
A=0.5114583333333333, P=0.5114583333333333, R=1.0, F=0.6767746381805652
A=0.503125, P=0.503125, R=1.0, F=0.6694386694386695
A=0.5015625, P=0.5015625, R=1.0, F=0.668054110301769
A=0.49854166666666666, P=0.49854166666666666, R=1.0, F=0.6653691088558321
A=0.4982638888888889, P=0.4982638888888889, R=1.0, F=0.6651216685979142
A=0.5001488095238096, P=0.5001488095238096, R=1.0, F=0.6667989286777105
A=0.49992578298946116, P=0.49992578298946116, R=1.0, F=0.6666006927263731


In [68]:
# load with fine-tuning

from transformers import AutoConfig, AutoTokenizer, AutoModelForSequenceClassification

# Fine-tuned model: config, tokenizer and weights are all read from the local
# checkpoint directory. Note that this rebinds PRETRAINED_MODEL_NAME.
PRETRAINED_MODEL_NAME = '.checkpoints/save/'
MODEL_CACHE_DIR = '.model/'

config_w_ft = AutoConfig.from_pretrained(
      pretrained_model_name_or_path=PRETRAINED_MODEL_NAME
    , num_labels=len(label_list)
    , finetuning_task='text-classification'
    , cache_dir=MODEL_CACHE_DIR
    , revision='main'
)

# `use_fast` is the option AutoTokenizer.from_pretrained understands;
# `use_fast_tokenizer` is only the argument name used by the HF example
# scripts and is not an AutoTokenizer option.
tokenizer_w_ft = AutoTokenizer.from_pretrained(
      pretrained_model_name_or_path=PRETRAINED_MODEL_NAME
    , cache_dir=MODEL_CACHE_DIR
    , revision='main'
    , use_fast=True
)

# Use the config loaded from the checkpoint in this cell (config_w_ft), not
# the baseline's config, so the cell does not depend on the baseline cell
# having been run.
model_w_ft = AutoModelForSequenceClassification.from_pretrained(
      pretrained_model_name_or_path=PRETRAINED_MODEL_NAME
    , config=config_w_ft
    , cache_dir=MODEL_CACHE_DIR
    , revision='main'
)

In [69]:
# Build the balanced binary test set for the fine-tuned model: original label 0
# becomes 1 ('entailment' in label_list), original labels 1 and 2 become 0
# ('not_entailment'). The result is then tokenized with the checkpoint's tokenizer.
snli_test_w_ft = (
    binarize_labels(snli['test'], labels_to_pos=[0], labels_to_neg=[1, 2])
    .map(
        lambda batch: tokenize_premises_and_hypotheses(batch, tokenizer_w_ft),
        batched=True,
        num_proc=4,
    )
)

  0%|          | 0/10000 [00:00<?, ?it/s]

Map (num_proc=4):   0%|          | 0/6737 [00:00<?, ? examples/s]

In [70]:
import torch
from tqdm.contrib import tenumerate
from sklearn.metrics import precision_recall_fscore_support, accuracy_score


# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model_w_ft.to(device)
model_w_ft.eval()  # switch the model to evaluation mode

# Gold labels and predicted class ids, accumulated over all batches.
truth = []
pred = []

with torch.no_grad():  # inference only, no gradients needed
    for i, batch in tenumerate(snli_test_w_ft.batch(batch_size=32)):
        # Every 30 batches, report running metrics over the batches scored so far.
        if i != 0 and i % 30 == 0:
            # zero_division=0 matches the baseline evaluation cell, so both
            # runs report 0 (without a warning) when a metric is undefined.
            p, r, f, _ = precision_recall_fscore_support(
                y_true=truth, y_pred=pred, pos_label=1, average='binary', zero_division=0
            )
            a = accuracy_score(y_true=truth, y_pred=pred)
            print('A={}, P={}, R={}, F={}'.format(a, p, r, f))
        input_ids = torch.tensor(batch['input_ids']).to(device)
        attention_mask = torch.tensor(batch['attention_mask']).to(device)
        truth += batch['label']
        # Keyword arguments avoid relying on the positional order of forward().
        output = model_w_ft(input_ids=input_ids, attention_mask=attention_mask)
        # Predicted class = index of the largest logit in each row.
        pred += torch.argmax(output.logits, dim=1).tolist()

# Final metrics over the whole test set.
p, r, f, _ = precision_recall_fscore_support(
    y_true=truth, y_pred=pred, pos_label=1, average='binary', zero_division=0
)
a = accuracy_score(y_true=truth, y_pred=pred)
print('A={}, P={}, R={}, F={}'.format(a, p, r, f))

Batching examples:   0%|          | 0/6737 [00:00<?, ? examples/s]

  0%|          | 0/211 [00:00<?, ?it/s]

A=0.9260416666666667, P=0.9363449691991786, R=0.9193548387096774, F=0.9277721261444557
A=0.9317708333333333, P=0.9364102564102564, R=0.929735234215886, F=0.9330608073582013
A=0.9350694444444444, P=0.9369806094182825, R=0.9337474120082816, F=0.935361216730038
A=0.9338541666666667, P=0.9322647362978284, R=0.9361370716510904, F=0.9341968911917099
A=0.934375, P=0.9321963394342762, R=0.9364814040952779, F=0.9343339587242027
A=0.9340277777777778, P=0.9304979253112033, R=0.937630662020906, F=0.9340506768483166
A=0.934970238095238, P=0.931777909037212, R=0.9387087176435585, F=0.9352304728027271
A=0.9351343327890752, P=0.9319186560565871, R=0.9388361045130641, F=0.9353645910368289
