In [36]:
# Fine-tune pretrained RoBERTa for sequence classification on NLI
# (binary labels: entailment vs. not_entailment).
# Training data: SNLI, MNLI, and ANLI.
# code ref: https://github.com/huggingface/transformers/blob/main/examples/pytorch/text-classification/run_classification.py

In [37]:
PRETRAINED_MODEL_NAME = 'roberta-large'  # Hugging Face Hub id used for config, weights and tokenizer
DATASET_CACHE_DIR = '.datasets/'  # passed as `cache_dir` to `load_dataset`
TRAINER_OUTPUT_DIR = '.checkpoints/'  # Trainer output / checkpoint directory
# Misspelled alias kept so cells that still reference the old name keep working.
TRAINER_OUTPUR_DIR = TRAINER_OUTPUT_DIR

In [38]:
from datasets import load_dataset

# Load the three NLI corpora from the Hub, each cached under DATASET_CACHE_DIR.
snli, mnli, anli = (
    load_dataset(hub_id, cache_dir=DATASET_CACHE_DIR)
    for hub_id in ('stanfordnlp/snli', 'nyu-mll/multi_nli', 'facebook/anli')
)

In [39]:
from typing import Any, Dict, List
from transformers import PreTrainedTokenizer

def tokenize_premises_and_hypotheses(
      batch: Dict[str, List]
    , tokenizer: PreTrainedTokenizer
    , label_to_id: Dict[Any, int]
):
    """Encode a batch of premise/hypothesis pairs and map its labels to class ids.

    Args:
        batch: column-oriented batch with 'premise', 'hypothesis' and 'label'
            lists (as passed by `datasets` `.map(..., batched=True)`).
        tokenizer: tokenizer applied to each (premise, hypothesis) pair.
        label_to_id: maps each raw dataset label to a model class id.

    Returns:
        The tokenizer output, plus a 'label' entry holding the mapped ids.

    Raises:
        KeyError: if a label in the batch is missing from `label_to_id`
            (rows should be filtered with `are_labels_available` first).
    """
    premises, hypotheses = batch['premise'], batch['hypothesis']
    encoded = tokenizer(
        premises,
        hypotheses,
        truncation=True,
        # every pair is truncated/padded to the tokenizer's maximum length
        max_length=tokenizer.model_max_length,
        padding='max_length',
        return_attention_mask=True,
        return_token_type_ids=True,
    )
    encoded['label'] = list(map(label_to_id.__getitem__, batch['label']))
    return encoded

def are_labels_available(
      batch: Dict[str, List]
    , label_to_id: Dict[Any, int]
):
    """Return a per-row keep mask: True where the row's label has a mapping.

    Used as a batched `datasets` filter predicate so that
    `tokenize_premises_and_hypotheses` never sees an unmapped label.

    Args:
        batch: column-oriented batch containing a 'label' list.
        label_to_id: maps raw dataset labels to model class ids.

    Returns:
        List of bools, one per entry of `batch['label']`.
    """
    # Membership test instead of a `.get(..., -1) != -1` sentinel, so a label
    # that legitimately maps to -1 is not mistaken for a missing one.
    return [label in label_to_id for label in batch['label']]


In [40]:
from transformers import RobertaConfig, RobertaForSequenceClassification

# Binary NLI label space; a label's position in `label_list` is its class id.
label_list = ['entailment', 'not_entailment']
id_to_label = dict(enumerate(label_list))
label_to_id = {label: class_id for class_id, label in id_to_label.items()}

config = RobertaConfig.from_pretrained(
    pretrained_model_name_or_path=PRETRAINED_MODEL_NAME,
    num_labels=len(label_list),
    finetuning_task='text-classification',
    problem_type='single_label_classification',
)

model = RobertaForSequenceClassification.from_pretrained(
    pretrained_model_name_or_path=PRETRAINED_MODEL_NAME,
    config=config,
)
# Attach the label-name mappings to the model's config.
model.config.label2id = label_to_id
model.config.id2label = id_to_label

loading configuration file config.json from cache at /Users/r-okamot/.cache/huggingface/hub/models--roberta-large/snapshots/722cf37b1afa9454edce342e7895e588b6ff1d59/config.json
Model config RobertaConfig {
  "architectures": [
    "RobertaForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "bos_token_id": 0,
  "classifier_dropout": null,
  "eos_token_id": 2,
  "finetuning_task": "text-classification",
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 1024,
  "initializer_range": 0.02,
  "intermediate_size": 4096,
  "layer_norm_eps": 1e-05,
  "max_position_embeddings": 514,
  "model_type": "roberta",
  "num_attention_heads": 16,
  "num_hidden_layers": 24,
  "pad_token_id": 1,
  "position_embedding_type": "absolute",
  "problem_type": "single_label_classification",
  "transformers_version": "4.52.3",
  "type_vocab_size": 1,
  "use_cache": true,
  "vocab_size": 50265
}

loading weights file model.safetensors from cache at /Users/r-okamot/.cache/huggingface/hub/

In [41]:
from transformers import RobertaTokenizer

tokenizer = RobertaTokenizer.from_pretrained(
      pretrained_model_name_or_path=PRETRAINED_MODEL_NAME
)

# All three corpora share one raw-label -> model-id mapping: raw 0 becomes
# `entailment`; raw 1 and 2 both collapse into `not_entailment`. Rows whose raw
# label is not a key here are dropped by the filter below.
# NOTE(review): raw 0/1/2 are presumably entailment/neutral/contradiction per the
# dataset cards — confirm.
nli_label_to_id = {
    0: label_to_id['entailment'],
    1: label_to_id['not_entailment'],
    2: label_to_id['not_entailment'],
}
# Per-corpus names kept (as independent copies) for cells that refer to them.
snli_label_to_id = dict(nli_label_to_id)
mnli_label_to_id = dict(nli_label_to_id)
anli_label_to_id = dict(nli_label_to_id)


def _filter_and_tokenize(dataset_dict, dataset_label_to_id, nli_tokenizer):
    """Drop rows with unmapped labels, then tokenize every split of `dataset_dict`."""
    kept = dataset_dict.filter(
        are_labels_available,
        batched=True,
        fn_kwargs={'label_to_id': dataset_label_to_id},
    )
    return kept.map(
        tokenize_premises_and_hypotheses,
        batched=True,
        fn_kwargs={'tokenizer': nli_tokenizer, 'label_to_id': dataset_label_to_id},
    )


snli_tokenized = _filter_and_tokenize(snli, snli_label_to_id, tokenizer)
mnli_tokenized = _filter_and_tokenize(mnli, mnli_label_to_id, tokenizer)
anli_tokenized = _filter_and_tokenize(anli, anli_label_to_id, tokenizer)


loading file vocab.json from cache at /Users/r-okamot/.cache/huggingface/hub/models--roberta-large/snapshots/722cf37b1afa9454edce342e7895e588b6ff1d59/vocab.json
loading file merges.txt from cache at /Users/r-okamot/.cache/huggingface/hub/models--roberta-large/snapshots/722cf37b1afa9454edce342e7895e588b6ff1d59/merges.txt
loading file added_tokens.json from cache at None
loading file special_tokens_map.json from cache at None
loading file tokenizer_config.json from cache at /Users/r-okamot/.cache/huggingface/hub/models--roberta-large/snapshots/722cf37b1afa9454edce342e7895e588b6ff1d59/tokenizer_config.json
loading file tokenizer.json from cache at /Users/r-okamot/.cache/huggingface/hub/models--roberta-large/snapshots/722cf37b1afa9454edce342e7895e588b6ff1d59/tokenizer.json
loading file chat_template.jinja from cache at None
loading configuration file config.json from cache at /Users/r-okamot/.cache/huggingface/hub/models--roberta-large/snapshots/722cf37b1afa9454edce342e7895e588b6ff1d59/con

In [42]:
from datasets import concatenate_datasets

snli_train = snli_tokenized['train']
snli_eval = snli_tokenized['validation']
snli_test = snli_tokenized['test']

mnli_train = mnli_tokenized['train']
# MNLI evaluation pools the matched and mismatched validation splits;
# no MNLI test set is assembled here.
mnli_eval = concatenate_datasets(
    [mnli_tokenized[split] for split in ('validation_matched', 'validation_mismatched')]
)

# ANLI: for each of train / dev / test, pool rounds 1-3 into one dataset.
anli_train, anli_eval, anli_test = (
    concatenate_datasets([anli_tokenized[f'{prefix}_r{round_no}'] for round_no in (1, 2, 3)])
    for prefix in ('train', 'dev', 'test')
)

In [43]:
from transformers import  EvalPrediction, TrainingArguments, Trainer
from transformers.data import default_data_collator
import evaluate
import numpy

training_args = TrainingArguments(
    output_dir=TRAINER_OUTPUR_DIR,
    eval_strategy='epoch',          # evaluate at the end of every epoch
    learning_rate=5e-5,             # default
    num_train_epochs=3.0,           # default
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
)

# Metrics combined into one object and used by `compute_metrics`.
# NOTE(review): precision/recall/f1 run with `evaluate`'s defaults — presumably
# binary averaging with pos_label=1, i.e. 'not_entailment' under label_to_id;
# confirm that is the intended positive class.
metrics = evaluate.combine(
    [evaluate.load(metric_name) for metric_name in ('accuracy', 'precision', 'recall', 'f1')]
)

def compute_metrics(preds: EvalPrediction):
    """Score one evaluation pass for the Trainer.

    Args:
        preds: the Trainer's `EvalPrediction`. `predictions` holds per-class
            scores (or a tuple whose first element does) — presumably logits
            of shape (n_examples, num_labels); `label_ids` holds the reference
            class ids.

    Returns:
        Dict of the combined `metrics` scores, plus 'combined_score' (their
        mean) when more than one metric is reported.
    """
    # Do not rebind `preds`: `preds.label_ids` is still needed after the argmax.
    scores = preds.predictions[0] if isinstance(preds.predictions, tuple) else preds.predictions
    predicted_ids = numpy.argmax(scores, axis=1)
    result = metrics.compute(predictions=predicted_ids, references=preds.label_ids)
    if len(result) > 1:
        result['combined_score'] = numpy.mean(list(result.values())).item()
    return result

# Trainer for the SNLI stage: trains on SNLI train, evaluates on SNLI validation.
snli_trainer = Trainer(
    model=model,
    args=training_args,
    processing_class=tokenizer,
    data_collator=default_data_collator,
    train_dataset=snli_train,
    eval_dataset=snli_eval,
    compute_metrics=compute_metrics,
)

PyTorch: setting up devices
The default value for the training argument `--report_to` will change in v5 (from all installed integrations to none). In v5, you will need to use `--report_to all` to get the same behavior as now. You should start updating your code and make this info disappear :-).


In [None]:
from os import listdir
from os.path import isdir
from transformers.trainer_utils import get_last_checkpoint

# Resume SNLI training from the newest checkpoint in output_dir, if one exists.
# Refuse to write into a non-empty output_dir that holds no checkpoint, unless
# `training_args.overwrite_output_dir` is set (then training starts fresh).
last_checkpoint = None
if isdir(training_args.output_dir) and not training_args.overwrite_output_dir:
    last_checkpoint = get_last_checkpoint(training_args.output_dir)
    if last_checkpoint is None and len(listdir(training_args.output_dir)) > 0:
        raise ValueError(
            'Output directory ({}) already exists and is not empty. '
            'Set overwrite_output_dir=True in TrainingArguments to overcome.'.format(training_args.output_dir)
        )

snli_trainer.train(resume_from_checkpoint=last_checkpoint)

SyntaxError: invalid syntax. Perhaps you forgot a comma? (3916212365.py, line 10)