# Fine-tuning a model with the Trainer API

🤗 Transformers provides a Trainer class to help you fine-tune any of the pretrained models it provides on your dataset. Once you’ve done all the data preprocessing work in the last section, you have just a few steps left to define the Trainer. The code below recaps that preprocessing:

In [1]:
from datasets import load_dataset
from transformers import AutoTokenizer, DataCollatorWithPadding

raw_datasets = load_dataset("glue", "mrpc")
checkpoint = "bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

def tokenize_function(example):
    return tokenizer(example["sentence1"], example["sentence2"], truncation=True)

tokenized_datasets = raw_datasets.map(tokenize_function, batched=True)
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

### Training
The first step before we can define our Trainer is to create a TrainingArguments object that will contain all the hyperparameters the Trainer will use for training and evaluation. The only argument you have to provide is a directory where the trained model will be saved, as well as the checkpoints along the way. All other hyperparameters keep their default values:

In [2]:
from transformers import TrainingArguments

training_args = TrainingArguments("test-trainer")

The second step is to define our model. We use the AutoModelForSequenceClassification class, with two labels:

In [3]:
from transformers import AutoModelForSequenceClassification

model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=2)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


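You will notice that you get a warning after instantiating this pretrained model. This is because BERT has not been pretrained on classifying pairs of sentences, so the head of the pretrained model has been discarded and a new head suitable for sequence classification has been added instead. The warning indicates that the weights of the new head (classifier.weight and classifier.bias) were randomly initialized, and it concludes by encouraging you to train the model, which is exactly what we are going to do now. Depending on the library version, you may also see a message that some checkpoint weights (those of the discarded pretraining head) were not used.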

Once we have our model, we can define a Trainer by passing it all the objects constructed up to now (the model, the training_args, the training and validation datasets, our data_collator, and our tokenizer):

In [4]:
from transformers import Trainer

trainer = Trainer(
    model,
    training_args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    data_collator=data_collator,
    tokenizer=tokenizer,
)
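
As a rough illustration of what the data_collator passed above is for: tokenized examples have different lengths, and the collator pads each batch to the length of its longest example when the batch is built. The sketch below mimics that idea in plain Python. It is not the library's implementation, and the helper name `pad_batch` and the token IDs are made up for illustration. The padding ID of 0 is an assumption chosen to match the [PAD] token of bert-base-uncased.

```python
def pad_batch(batch, pad_id=0):
    """Pad every sequence in `batch` to the length of the longest one."""
    max_len = max(len(ids) for ids in batch)
    input_ids = [ids + [pad_id] * (max_len - len(ids)) for ids in batch]
    # The attention mask marks real tokens with 1 and padding with 0.
    attention_mask = [[1] * len(ids) + [0] * (max_len - len(ids)) for ids in batch]
    return {"input_ids": input_ids, "attention_mask": attention_mask}


# Hypothetical token IDs for two examples of different lengths.
toy_batch = [[101, 2023, 102], [101, 2023, 2003, 1037, 102]]
padded = pad_batch(toy_batch)
print(padded["input_ids"])
print(padded["attention_mask"])
```

Because padding is decided per batch, a batch of short sentences stays short instead of being padded to the longest sentence in the whole dataset.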

To fine-tune the model on our dataset, we just have to call the train() method of our Trainer:

In [5]:
trainer.train()

  0%|          | 0/1377 [00:00<?, ?it/s]
You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
 36%|███▋      | 500/1377 [01:10<02:07,  6.89it/s]

{'loss': 0.5716, 'learning_rate': 3.184458968772695e-05, 'epoch': 1.09}


 73%|███████▎  | 1000/1377 [02:24<00:52,  7.18it/s]

{'loss': 0.3506, 'learning_rate': 1.3689179375453886e-05, 'epoch': 2.18}


100%|██████████| 1377/1377 [03:18<00:00,  6.94it/s]

{'train_runtime': 198.3724, 'train_samples_per_second': 55.471, 'train_steps_per_second': 6.941, 'train_loss': 0.39796607322942196, 'epoch': 3.0}





TrainOutput(global_step=1377, training_loss=0.39796607322942196, metrics={'train_runtime': 198.3724, 'train_samples_per_second': 55.471, 'train_steps_per_second': 6.941, 'train_loss': 0.39796607322942196, 'epoch': 3.0})

This will start the fine-tuning (which should take a couple of minutes on a GPU) and report the training loss every 500 steps. It won’t, however, tell you how well (or badly) your model is performing. This is because:

1. We didn’t tell the Trainer to evaluate during training by setting evaluation_strategy to either "steps" (evaluate every eval_steps) or "epoch" (evaluate at the end of each epoch).
2. We didn’t provide the Trainer with a compute_metrics() function to calculate a metric during said evaluation (otherwise the evaluation would just have printed the loss, which is not a very intuitive number).
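
The numbers in the training log above can be reproduced with a little arithmetic. The sketch below assumes values that are not stated in this section: 3,668 examples in the MRPC training split, and the TrainingArguments defaults of a batch size of 8 per device, 3 epochs, an initial learning rate of 5e-5, and a linear decay to zero without warmup, all on a single device. Under those assumptions it recovers the 1,377 total steps and the learning rate and epoch logged at step 500.

```python
import math

# Assumed values (not stated in this section): MRPC training split size and
# the TrainingArguments defaults, on a single device.
num_train_examples = 3668
batch_size = 8
num_epochs = 3
initial_lr = 5e-5

steps_per_epoch = math.ceil(num_train_examples / batch_size)
total_steps = steps_per_epoch * num_epochs


def linear_lr(step):
    """Learning rate after `step` updates under linear decay, no warmup."""
    return initial_lr * (total_steps - step) / total_steps


print(steps_per_epoch, total_steps)     # 459 1377
print(linear_lr(500))                   # close to the logged 3.1845e-05
print(round(500 / steps_per_epoch, 2))  # 1.09
```

If you change the batch size, the number of epochs, or the number of devices, the step counts and logged learning rates change accordingly.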

### Evaluation
To get some predictions from our model, we can use the Trainer.predict() method:

In [15]:
predictions = trainer.predict(tokenized_datasets['validation'])
print(predictions.predictions.shape, predictions.label_ids.shape)

100%|██████████| 51/51 [00:01<00:00, 30.32it/s]

(408, 2) (408,)





The output of predict() is a named tuple with three fields: predictions, label_ids, and metrics. As you can see, predictions.predictions is a two-dimensional array with shape 408 x 2 (408 being the number of elements in the dataset we used). Those are the logits for each element of the dataset we passed to predict(). To transform them into predictions that we can compare to our labels, we need to take the index with the maximum value on the second axis:

In [16]:
import numpy as np

preds = np.argmax(predictions.predictions, axis=1)
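
To make the argmax step concrete, here is a small sketch on made-up logits (three examples, two classes). The values are hypothetical, not real model output. It also checks that applying a softmax first would not change the predicted class, which is why the raw logits can be used directly.

```python
import numpy as np

# Hypothetical logits: one row per example, one column per class.
toy_logits = np.array([[2.0, -1.0], [0.3, 0.8], [-0.5, 1.5]])

# The predicted class is the column index of the largest logit in each row.
toy_preds = np.argmax(toy_logits, axis=-1)
print(toy_preds)

# Softmax is monotonic within a row, so the argmax of the probabilities
# is the same as the argmax of the logits.
exp = np.exp(toy_logits - toy_logits.max(axis=-1, keepdims=True))
toy_probs = exp / exp.sum(axis=-1, keepdims=True)
same_after_softmax = bool((np.argmax(toy_probs, axis=-1) == toy_preds).all())
print(same_after_softmax)
```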

We can now compare those preds to the labels. To build our compute_metrics() function, we will rely on the metrics from the 🤗 Evaluate library. We can load the metrics associated with the MRPC dataset as easily as we loaded the dataset, this time with the evaluate.load() function. The object returned has a compute() method we can use to do the metric calculation:

In [20]:
import evaluate

metric = evaluate.load("glue", "mrpc")
metric.compute(predictions=preds, references=predictions.label_ids)

{'accuracy': 0.8455882352941176, 'f1': 0.8944723618090452}
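
For intuition about the two numbers returned, here is a plain-Python sketch of accuracy and F1 for a binary task, treating label 1 as the positive class. It is an illustration on made-up labels, not the code the 🤗 Evaluate library runs, and the function name `accuracy_and_f1` is hypothetical.

```python
def accuracy_and_f1(predictions, references):
    """Accuracy and F1 (positive class = 1) for binary labels."""
    pairs = list(zip(predictions, references))
    correct = sum(p == r for p, r in pairs)
    tp = sum(p == 1 and r == 1 for p, r in pairs)  # true positives
    fp = sum(p == 1 and r == 0 for p, r in pairs)  # false positives
    fn = sum(p == 0 and r == 1 for p, r in pairs)  # false negatives
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return {"accuracy": correct / len(pairs), "f1": f1}


# Made-up predictions and labels for six examples.
example_preds = [1, 0, 1, 1, 0, 1]
example_labels = [1, 0, 0, 1, 1, 1]
example_scores = accuracy_and_f1(example_preds, example_labels)
print(example_scores)
```

Accuracy counts every correct prediction equally, while F1 balances precision and recall on the positive class, so the two can differ noticeably when the classes are imbalanced.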

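Here the model reaches an accuracy of about 84.6% and an F1 score of about 0.894 on the validation set. The exact numbers vary between runs because the classification head is randomly initialized. Wrapping everything together, we get our compute_metrics() function: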

In [21]:
def compute_metrics(eval_preds):
    metric = evaluate.load("glue", "mrpc")
    logits, labels = eval_preds
    predictions = np.argmax(logits, axis=-1)
    return metric.compute(predictions=predictions, references=labels)

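To see it in action and report metrics at the end of each epoch, we define a new Trainer with this compute_metrics() function. Note that we also create a new TrainingArguments with evaluation_strategy set to "epoch" and a new model; otherwise, we would just continue training the model we have already trained: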

In [22]:
training_args = TrainingArguments("test-trainer", evaluation_strategy="epoch")
model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=2)

trainer = Trainer(
    model,
    training_args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [24]:
trainer.train()

                                                  
 33%|███▎      | 459/1377 [01:08<01:52,  8.14it/s]

{'eval_loss': 0.5080621838569641, 'eval_accuracy': 0.8088235294117647, 'eval_f1': 0.8602150537634409, 'eval_runtime': 5.7137, 'eval_samples_per_second': 71.407, 'eval_steps_per_second': 8.926, 'epoch': 1.0}


 36%|███▋      | 500/1377 [01:14<01:59,  7.32it/s]
Checkpoint destination directory test-trainer\checkpoint-500 already exists and is non-empty.Saving will proceed but saved results may be invalid.


{'loss': 0.5848, 'learning_rate': 3.184458968772695e-05, 'epoch': 1.09}


                                                  
 67%|██████▋   | 918/1377 [02:19<00:59,  7.68it/s]

{'eval_loss': 0.44530072808265686, 'eval_accuracy': 0.8333333333333334, 'eval_f1': 0.8885245901639345, 'eval_runtime': 5.6809, 'eval_samples_per_second': 71.82, 'eval_steps_per_second': 8.977, 'epoch': 2.0}


 73%|███████▎  | 1000/1377 [02:31<00:50,  7.42it/s]
Checkpoint destination directory test-trainer\checkpoint-1000 already exists and is non-empty.Saving will proceed but saved results may be invalid.


{'loss': 0.4022, 'learning_rate': 1.3689179375453886e-05, 'epoch': 2.18}


                                                   
100%|██████████| 1377/1377 [03:36<00:00,  6.37it/s]

{'eval_loss': 0.55373215675354, 'eval_accuracy': 0.8382352941176471, 'eval_f1': 0.886986301369863, 'eval_runtime': 5.4956, 'eval_samples_per_second': 74.241, 'eval_steps_per_second': 9.28, 'epoch': 3.0}
{'train_runtime': 216.1336, 'train_samples_per_second': 50.913, 'train_steps_per_second': 6.371, 'train_loss': 0.4314105020715607, 'epoch': 3.0}





TrainOutput(global_step=1377, training_loss=0.4314105020715607, metrics={'train_runtime': 216.1336, 'train_samples_per_second': 50.913, 'train_steps_per_second': 6.371, 'train_loss': 0.4314105020715607, 'epoch': 3.0})
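
This time the log reports the validation loss and metrics at the end of each epoch, on top of the training loss. The interleaving of the two kinds of entries follows from two schedules: evaluation at the end of each epoch, and loss logging every 500 steps. The sketch below reproduces that order, assuming 459 steps per epoch (the 1,377 total steps in the progress bar divided by 3 epochs) and a logging interval of 500 steps.

```python
steps_per_epoch = 459  # assumed: 1377 total steps / 3 epochs
num_epochs = 3
logging_steps = 500    # the training loss is reported every 500 steps
total_steps = steps_per_epoch * num_epochs

events = []
for step in range(1, total_steps + 1):
    if step % steps_per_epoch == 0:
        events.append((step, "evaluate"))
    if step % logging_steps == 0:
        events.append((step, "log loss"))

for step, kind in events:
    print(step, kind)
```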