In [None]:
# Fine-tune a pretrained RoBERTa model for sequence classification (NLI).
# Training datasets: SNLI, MNLI, ANLI.
# code ref: https://github.com/huggingface/transformers/blob/main/examples/pytorch/text-classification/run_classification.py

In [None]:
# Checkpoint identifier of the pretrained model that gets fine-tuned.
PRETRAINED_MODEL_NAME = 'roberta-large'

# Working directories (relative to the notebook):
# model/tokenizer cache, dataset cache, and Trainer checkpoints/outputs.
MODEL_CACHE_DIR, DATASET_CACHE_DIR, TRAINER_OUTPUT_DIR = (
    '.model/',
    '.datasets/',
    '.checkpoints/',
)

In [None]:
from datasets import load_dataset

# Load the SNLI dataset, using DATASET_CACHE_DIR as its cache directory.
snli = load_dataset(
    path='stanfordnlp/snli',
    cache_dir=DATASET_CACHE_DIR,
)

In [None]:
import random
from math import ceil
from typing import Any, Dict, List

from datasets import Dataset, concatenate_datasets
from tqdm.contrib import tenumerate
from transformers import PreTrainedTokenizer


def binarize_labels(
      dataset: Dataset
    , labels_to_pos: List[Any]
    , labels_to_neg: List[Any]
    , pos_label: int = 1
    , neg_label: int = 0
    , sample_seed: int = 42
    , shuffle_seed: int = 42
) -> Dataset:
    """Collapse a multi-class dataset into a (roughly) balanced binary one.

    Examples whose `label` is listed in `labels_to_pos` are relabelled to
    `pos_label`, examples whose `label` is listed in `labels_to_neg` are
    relabelled to `neg_label`, and examples with any other label are dropped.

    The larger of the two groups is downsampled, separately for each of its
    original labels, by the ratio `smaller / larger`. Per-label sample sizes
    are rounded up, so the downsampled group can end up a few examples larger
    than the other one.

    Args:
        dataset: dataset with a `label` feature.
        labels_to_pos: original label values that become the positive class.
        labels_to_neg: original label values that become the negative class.
        pos_label: label value written for the positive class.
        neg_label: label value written for the negative class.
        sample_seed: seed of the random generator used for downsampling.
        shuffle_seed: seed passed to the final `shuffle`.

    Returns:
        The relabelled, downsampled and shuffled dataset.

    Raises:
        AssertionError: if `dataset` has no `label` feature or the two label
            lists overlap.
        ValueError: if no example carries a label from either list.
    """
    assert 'label' in dataset.features, '`dataset` must have a `label` feature'

    # Sets make the per-example membership tests below O(1).
    pos_sources = set(labels_to_pos)
    neg_sources = set(labels_to_neg)
    assert pos_sources.isdisjoint(neg_sources), \
        '`labels_to_pos` and `labels_to_neg` must not overlap'

    # Private generator: yields the same samples as seeding the global
    # `random` module would, without clobbering the process-wide RNG state.
    rng = random.Random(sample_seed)

    # Group example indices by their ORIGINAL label, separately per target class.
    pos_label2indices: Dict[Any, List[int]] = {}
    neg_label2indices: Dict[Any, List[int]] = {}
    for index, label in tenumerate(dataset['label']):
        if label in pos_sources:
            pos_label2indices.setdefault(label, []).append(index)
        elif label in neg_sources:
            neg_label2indices.setdefault(label, []).append(index)

    pos_num = sum(len(indices) for indices in pos_label2indices.values())
    neg_num = sum(len(indices) for indices in neg_label2indices.values())
    if pos_num == 0 and neg_num == 0:
        raise ValueError(
                'No example in `dataset` has a label listed in '
                '`labels_to_pos` or `labels_to_neg`.'
            )
    sample_ratio = min(pos_num, neg_num) / max(pos_num, neg_num)

    # Downsample the larger group label by label, so the proportions between
    # its original labels are preserved. On a tie the positive group is
    # "sampled" at ratio 1.0, which only permutes its indices.
    majority = neg_label2indices if pos_num < neg_num else pos_label2indices
    for label, indices in majority.items():
        sample_size = ceil(sample_ratio * len(indices))
        majority[label] = rng.sample(indices, sample_size)

    def _relabel_to(value):
        # Build a batched `map` callback overwriting every label with `value`.
        def _apply(batch):
            batch['label'] = [value] * len(batch['label'])
            return batch
        return _apply

    relabelled_parts = [
        dataset.select(indices)
               .map(_relabel_to(pos_label), batched=True, num_proc=4)
        for indices in pos_label2indices.values()
    ] + [
        dataset.select(indices)
               .map(_relabel_to(neg_label), batched=True, num_proc=4)
        for indices in neg_label2indices.values()
    ]
    dataset_balanced_binarized = concatenate_datasets(relabelled_parts)

    return dataset_balanced_binarized.shuffle(seed=shuffle_seed)


def tokenize_premises_and_hypotheses(
      batch: Dict[str, List]
    , tokenizer: PreTrainedTokenizer
):
    """Tokenize a batch of premise/hypothesis pairs as paired sequences.

    Args:
        batch: batch with a `premise` and a `hypothesis` column.
        tokenizer: tokenizer called on the pairs.

    Returns:
        Whatever `tokenizer` returns for the batch. Inputs are truncated to
        `tokenizer.model_max_length` and left unpadded, since padding is done
        dynamically later by the data collator.
    """
    premises = batch['premise']
    hypotheses = batch['hypothesis']

    encode_options = {
        'truncation': True,
        'max_length': tokenizer.model_max_length,
        'padding': False,  # pad later, dynamically, with the collator
        'return_attention_mask': True,
        'return_token_type_ids': True,
    }
    return tokenizer(text=premises, text_pair=hypotheses, **encode_options)

In [None]:
from transformers import AutoConfig, AutoTokenizer, AutoModelForSequenceClassification

# `entailment` must be the SECOND entry so that it gets id 1, the positive class.
label_list = [ 'not_entailment', 'entailment' ]
label_to_id = { v: i for i, v in enumerate(label_list) }
id_to_label = { v: k for k, v in label_to_id.items() }

# Model configuration with a classification head sized for `label_list`.
config = AutoConfig.from_pretrained(
      pretrained_model_name_or_path=PRETRAINED_MODEL_NAME
    , num_labels=len(label_list)
    , finetuning_task='text-classification'
    , cache_dir=MODEL_CACHE_DIR
    , revision='main'
)

# NOTE: the option is called `use_fast`; `use_fast_tokenizer` is not an
# `AutoTokenizer.from_pretrained` parameter.
tokenizer = AutoTokenizer.from_pretrained(
      pretrained_model_name_or_path=PRETRAINED_MODEL_NAME
    , cache_dir=MODEL_CACHE_DIR
    , revision='main'
    , use_fast=True
)

model = AutoModelForSequenceClassification.from_pretrained(
      pretrained_model_name_or_path=PRETRAINED_MODEL_NAME
    , config=config
    , cache_dir=MODEL_CACHE_DIR
    , revision='main'
)
# Store the readable label names in the model config.
model.config.label2id = label_to_id
model.config.id2label = id_to_label

In [None]:
snli_labels_to_pos = [0]     # `entailment`
snli_labels_to_neg = [1, 2]  # `neutral`, `contradiction`


def _prepare_snli_split(split_name):
    """Binarize, balance and tokenize one SNLI split.

    Args:
        split_name: key of the split in `snli`
            (`'train'`, `'validation'` or `'test'`).

    Returns:
        The split relabelled to `entailment` / `not_entailment` by
        `binarize_labels`, with the tokenizer outputs added as columns.
    """
    binarized = binarize_labels(
          snli[split_name]
        , labels_to_pos=snli_labels_to_pos
        , labels_to_neg=snli_labels_to_neg
        , pos_label=label_to_id['entailment']
        , neg_label=label_to_id['not_entailment']
    )
    return binarized.map(
          lambda batch: tokenize_premises_and_hypotheses(batch, tokenizer)
        , batched=True
        , num_proc=4
    )


# All three splits go through exactly the same pipeline.
snli_train = _prepare_snli_split('train')
snli_eval = _prepare_snli_split('validation')
snli_test = _prepare_snli_split('test')

In [None]:
# check dataset balance

from collections import Counter
# One line per split (train, eval, test): number of examples per label.
for split in (snli_train, snli_eval, snli_test):
    label_counts = Counter(split['label'])
    print(label_counts)

In [None]:
import evaluate
import numpy
import torch
from transformers import EvalPrediction, Trainer, TrainingArguments


# Mixed precision (fp16) is requested whenever a CUDA device is available.
use_mixed_precision = torch.cuda.is_available()
print(f'Using mixed precision: {use_mixed_precision}')

training_args = TrainingArguments(
    output_dir=TRAINER_OUTPUT_DIR,
    overwrite_output_dir=True,       # overwrite the output directory
    do_train=True,
    do_eval=True,
    eval_strategy='epoch',           # evaluate once per epoch
    save_strategy='epoch',           # save a checkpoint once per epoch
    learning_rate=1e-5,              # same as DocNLI
    num_train_epochs=5.0,            # same as DocNLI
    per_device_train_batch_size=16,
    gradient_accumulation_steps=1,   # batch_size ~ this * per_device_train_batch_size
    per_device_eval_batch_size=16,
    fp16=use_mixed_precision,        # mixed precision training
)

# Bundle the evaluation metrics so one `compute` call reports all of them.
metrics = evaluate.combine([
    evaluate.load(metric_name)
    for metric_name in ('accuracy', 'precision', 'recall', 'f1')
])

def compute_metrics(p: EvalPrediction):
    """Score predictions against the reference labels.

    Args:
        p: evaluation output; `p.predictions` holds the per-class scores
            (or a tuple whose first element does) and `p.label_ids` the
            reference labels.

    Returns:
        The result of `metrics.compute` for the argmax predictions.
    """
    scores = p.predictions
    if isinstance(scores, tuple):
        # Only the first element of a tuple output is scored.
        scores = scores[0]

    predicted_ids = numpy.argmax(scores, axis=1)
    return metrics.compute(predictions=predicted_ids, references=p.label_ids)

In [None]:
# from os import listdir
# from os.path import isdir

# from transformers.trainer_utils import get_last_checkpoint


# last_checkpoint = None
# if isdir(training_args.output_dir):
#     last_checkpoint = get_last_checkpoint(training_args.output_dir)
#     if last_checkpoint is None and len(listdir(training_args.output_dir)) > 0:
#         raise ValueError(
#                 'Output directory ({}) already exists and is not empty. ' \
#                 'Use --overwrite_output_dir to overcome.'
#                 .format(training_args.output_dir)
#             )

In [None]:
from transformers.data import DataCollatorWithPadding

# With `data_collator=None` and a tokenizer given as `processing_class`,
# Trainer already pads dynamically with `DataCollatorWithPadding`. Under fp16,
# sequence lengths are additionally padded to a multiple of 8, as in the
# referenced `run_classification.py` script.
data_collator = None
if training_args.fp16:
    data_collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8)

snli_trainer = Trainer(
          model=model
        , args=training_args
        , train_dataset=snli_train
        , eval_dataset=snli_eval
        , compute_metrics=compute_metrics
        , processing_class=tokenizer
        , data_collator=data_collator
    )

In [None]:
import torch

import gc

try:
    # Train from scratch (no checkpoint resume), then save the final model
    # and the training metrics.
    train_result = snli_trainer.train(resume_from_checkpoint=None)
    snli_trainer.save_model()
    snli_trainer.save_metrics('train', train_result.metrics)

except KeyboardInterrupt:
    # HACK: when training is interrupted, GPU memory may not be released
    # properly. Drop the references to the model and the trainer, collect
    # garbage so objects kept alive only by reference cycles are freed, and
    # only then ask PyTorch to release its cached GPU memory.
    del model
    del snli_trainer
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    raise KeyboardInterrupt('Training interrupted by user.')

In [None]:
import numpy as np

# Predict on the test split and convert the per-class scores to label ids.
preds = snli_trainer.predict(snli_test).predictions
if isinstance(preds, tuple):
    # Use the first element of a tuple output, as `compute_metrics` does.
    preds = preds[0]
preds = np.argmax(preds, axis=1)

# (row index, label name) pairs, in test-set order.
results = [(index, label_list[item]) for index, item in enumerate(preds)]

# One `index,label` record per line; without the trailing newline all records
# would be written back to back on a single line.
with open('preds.csv', mode='w', encoding='utf-8') as f:
    for index, label in results:
        f.write('{},{}\n'.format(index, label))