# Fine-tuning a pretrained model for text classification

In this notebook, we learn how to fine-tune a pretrained language model on our own dataset. In this case, we are using the IMDB dataset for sentiment analysis. You can find more info about the dataset here: https://huggingface.co/datasets/imdb.

The model we are using is DistilBERT, which is a significantly smaller and faster version of BERT, produced through a process called knowledge distillation. It is reported to retain around 97% of BERT's language understanding capabilities.

If you are using Google Colab, make sure that you are using a GPU (Runtime > Change runtime type > Hardware accelerator > GPU).

In [1]:
# Install the required libraries
!pip install transformers
!pip install datasets




If using Google Colab: Mount Google Drive to save the fine-tuned model.

In [2]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


Define where to save the fine-tuned model. If you are using Colab, the model needs to be saved to Google Drive (as specified below). Otherwise, you can use a local dir.

In [3]:
import os
output_dir = os.path.join('drive', 'My Drive', 'distilbert-finetuned-imdb')

In [4]:
# Import torch and check if GPU is available
import torch
train_on_gpu = torch.cuda.is_available()
print('Train on GPU: ', train_on_gpu)

Train on GPU:  True


## 1 Data preparation

We use the Datasets library to download the data. We further split the data into traininig, validation and test sets. We only use 3000 out of 25000 training examples because otherwise fine-tuning would take too much time.

In [5]:
# Load the dataset and create the data splits
from datasets import load_dataset

imdb = load_dataset("imdb")
imdb = imdb.shuffle(seed=42)

# We use a small subset of the dataset to decrease training time: 3000 training examples and 300 validation/test examples.
train_dataset = imdb["train"].select(range(3000))
val_dataset = imdb["train"].select(range(3000, 3300))
test_dataset = imdb["test"].select(range(300))

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


Inspect the data to see if it looks as you expect.

We preprocess the data using a Transformers Tokenizer, which tokenizes the data and formats it for input to the model. Transformer models use sub-word tokenizers, meaning that a token can be a whole word or a part of a word. This process varies across different tokenizers, so it is important to use the correct tokenizer for your chosen model. Typically, the tokenizer name will be the same as model name. If this does not work, you can find the correct tokenizer name on the model card of your chosen model.

In this case, the tokenizer we use is distilbert-base-uncased (same as the model), and we specify that we want to use the fast version of the tokenizer.



In [6]:
# Instantiate the tokenizer
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased", use_fast=True)

We apply the map method to tokenize the entire dataset at once. The data is passed in batches for faster tokenization.

In [7]:
# Prepare the text inputs for the model
def preprocess_function(examples):
    return tokenizer(examples["text"], truncation=True)

tokenized_train = train_dataset.map(preprocess_function, batched=True)
tokenized_val = val_dataset.map(preprocess_function, batched=True)
tokenized_test = test_dataset.map(preprocess_function, batched=True)

Map:   0%|          | 0/300 [00:00<?, ? examples/s]

Inspect also the tokenized data to see the transformations applied. You should see lists of token IDs and attention masks.

In [14]:
# Inspect tokenized data
for i in range(5):  # Print the first 5 examples
    print("Example", i+1)
    print("Token IDs:", tokenized_train[i]['input_ids'])
    print("Attention Mask:", tokenized_train[i]['attention_mask'])
    print()


Example 1
Token IDs: [101, 2045, 2003, 2053, 7189, 2012, 2035, 2090, 3481, 3771, 1998, 6337, 2099, 2021, 1996, 2755, 2008, 2119, 2024, 2610, 2186, 2055, 6355, 6997, 1012, 6337, 2099, 3504, 15594, 2100, 1010, 3481, 3771, 3504, 4438, 1012, 6337, 2099, 14811, 2024, 3243, 3722, 1012, 3481, 3771, 1005, 1055, 5436, 2024, 2521, 2062, 8552, 1012, 1012, 1012, 3481, 3771, 3504, 2062, 2066, 3539, 8343, 1010, 2065, 2057, 2031, 2000, 3962, 12319, 1012, 1012, 1012, 1996, 2364, 2839, 2003, 5410, 1998, 6881, 2080, 1010, 2021, 2031, 1000, 17936, 6767, 7054, 3401, 1000, 1012, 2111, 2066, 2000, 12826, 1010, 2000, 3648, 1010, 2000, 16157, 1012, 2129, 2055, 2074, 9107, 1029, 6057, 2518, 2205, 1010, 2111, 3015, 3481, 3771, 3504, 2137, 2021, 1010, 2006, 1996, 2060, 2192, 1010, 9177, 2027, 9544, 2137, 2186, 1006, 999, 999, 999, 1007, 1012, 2672, 2009, 1005, 1055, 1996, 2653, 1010, 2030, 1996, 4382, 1010, 2021, 1045, 2228, 2023, 2186, 2003, 2062, 2394, 2084, 2137, 1012, 2011, 1996, 2126, 1010, 1996, 5889, 2024

In [8]:
# Use data_collector to convert our samples to PyTorch tensors and concatenate them with the correct amount of padding
from transformers import DataCollatorWithPadding
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

## 2 Training the model

In [9]:
# Define DistilBERT as our base model and ensure the utilization of the GPU
from transformers import AutoModelForSequenceClassification
model = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased", num_labels=2)

if train_on_gpu:
  model = model.to('cuda')

Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [10]:
# Define the evaluation metrics
import numpy as np
from datasets import load_metric

def compute_metrics(eval_pred):
    load_accuracy = load_metric("accuracy")
    load_f1 = load_metric("f1")

    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    accuracy = load_accuracy.compute(predictions=predictions, references=labels)["accuracy"]
    f1 = load_f1.compute(predictions=predictions, references=labels)["f1"]
    return {"accuracy": accuracy, "f1": f1}

We use the Trainer class for fine-tuning. Trainer is specifically optimised for training models from the Transformers library. If you prefer to write your own training loop, that is also possible. More info here: https://huggingface.co/docs/transformers/training.

We also specify the training arguments, which define some hyperparameters and strategies. Since we are only training for two epochs with a modest number of training examples, we set evaluation after every 50 steps so that we can monitor the progress. Take a look at the documentation if you want to understand the arguments better.

In [11]:
# Define a new Trainer with all the objects we constructed so far
from transformers import TrainingArguments, Trainer

!pip install transformers[torch]
!pip install accelerate -U
training_args = TrainingArguments(
    output_dir=output_dir,
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=2,
    weight_decay=0.01,
    evaluation_strategy='steps',
    logging_steps=50,
    eval_steps=50,
    save_steps=200
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_train,
    eval_dataset=tokenized_val,
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)



dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)


In [12]:
# Train and save the model
trainer.train()
trainer.save_model(output_dir=output_dir)

Step,Training Loss,Validation Loss,Accuracy,F1
50,0.5839,0.347112,0.86,0.86
100,0.3593,0.317346,0.86,0.862745
150,0.2688,0.278797,0.876667,0.877076
200,0.2901,0.275342,0.88,0.882353
250,0.1738,0.312191,0.893333,0.885714
300,0.2255,0.300177,0.893333,0.894737
350,0.1837,0.271061,0.896667,0.895623


  load_accuracy = load_metric("accuracy")
You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.


Downloading builder script:   0%|          | 0.00/1.65k [00:00<?, ?B/s]

You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.


Downloading builder script:   0%|          | 0.00/2.32k [00:00<?, ?B/s]

You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.
You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.
You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.
You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.
You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datase

## 3 Testing the model

In [13]:
# Compute the evaluation metrics
trainer.evaluate(tokenized_test)

You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.
You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.


{'eval_loss': 0.35880711674690247,
 'eval_accuracy': 0.87,
 'eval_f1': 0.8745980707395499,
 'eval_runtime': 7.0059,
 'eval_samples_per_second': 42.821,
 'eval_steps_per_second': 2.712,
 'epoch': 2.0}

## 4 Improving the results

When fine-tuning a transformer model, we have a lot less flexibility compared to training a neural network from scratch. That is because we are taking an existing model whose architecture has already been defined, but we can still change some hyperparameters.

Try to see if you can get better results by varying training parameters like the learning rate, weight decay or the number of epochs. You can also try changing the batch size, but increasing it significantly might cause a memory crash.

Once you find the best combination of hyperparameters, try training on more data (remember: we only used a subset for faster processing.) Does more data improve the results?

In [15]:
from transformers import TrainingArguments, Trainer

# Define hyperparameters to search over
learning_rates = [1e-5, 2e-5, 3e-5]
weight_decays = [0.01, 0.001, 0.0001]
num_epochs = [2, 3, 4]

best_f1_score = 0.0
best_hyperparameters = {}

# Perform grid search over hyperparameters
for lr in learning_rates:
    for wd in weight_decays:
        for epochs in num_epochs:
            # Define training arguments
            training_args = TrainingArguments(
                output_dir=output_dir,
                learning_rate=lr,
                per_device_train_batch_size=16,
                per_device_eval_batch_size=16,
                num_train_epochs=epochs,
                weight_decay=wd,
                evaluation_strategy='steps',
                logging_steps=50,
                eval_steps=50,
                save_steps=200
            )

            # Initialize Trainer with current hyperparameters
            trainer = Trainer(
                model=model,
                args=training_args,
                train_dataset=tokenized_train,
                eval_dataset=tokenized_val,
                tokenizer=tokenizer,
                data_collator=data_collator,
                compute_metrics=compute_metrics,
            )

            # Train the model
            trainer.train()

            # Evaluate the model on validation data
            eval_results = trainer.evaluate()

            # Update best hyperparameters if current F1 score is higher
            if eval_results['eval_f1'] > best_f1_score:
                best_f1_score = eval_results['eval_f1']
                best_hyperparameters = {
                    'learning_rate': lr,
                    'weight_decay': wd,
                    'num_epochs': epochs
                }

# Print the best hyperparameters and corresponding F1 score
print("Best F1 Score:", best_f1_score)
print("Best Hyperparameters:", best_hyperparameters)


dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)


Step,Training Loss,Validation Loss,Accuracy,F1
50,0.1825,0.376591,0.873333,0.88125
100,0.1079,0.353259,0.896667,0.894198
150,0.0889,0.358879,0.9,0.899329
200,0.1292,0.371495,0.9,0.897959
250,0.0594,0.400575,0.903333,0.903654
300,0.067,0.402541,0.896667,0.89701
350,0.1126,0.411624,0.896667,0.89769


You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.
You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.
You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.
You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.
You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datase

You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.
You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.
dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)


Step,Training Loss,Validation Loss,Accuracy,F1
50,0.0998,0.519543,0.873333,0.875
100,0.0353,0.532015,0.893333,0.893333
150,0.041,0.466428,0.896667,0.894915
200,0.0931,0.443213,0.906667,0.903448
250,0.0142,0.46077,0.896667,0.894915
300,0.0328,0.502619,0.9,0.9
350,0.0942,0.481213,0.9,0.89726
400,0.0726,0.491225,0.9,0.899329
450,0.0523,0.490769,0.903333,0.902357
500,0.05,0.480278,0.893333,0.891156


You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.
You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.
You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.
You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.
You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datase

You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.
You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.
dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)


Step,Training Loss,Validation Loss,Accuracy,F1
50,0.0346,0.501738,0.9,0.89726
100,0.0146,0.553904,0.893333,0.892617
150,0.0184,0.562054,0.9,0.89726
200,0.0681,0.513977,0.903333,0.901024
250,0.0091,0.575301,0.896667,0.896321
300,0.0209,0.629953,0.88,0.883871
350,0.0227,0.535592,0.913333,0.909091
400,0.0517,0.513679,0.91,0.908475
450,0.0357,0.538136,0.903333,0.899654
500,0.018,0.568836,0.896667,0.895623


You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.
You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.
You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.
You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.
You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datase

You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.
You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.
dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)


Step,Training Loss,Validation Loss,Accuracy,F1
50,0.0226,0.550184,0.906667,0.900709
100,0.008,0.922014,0.866667,0.874214
150,0.0228,0.698894,0.9,0.899329
200,0.0278,0.63331,0.886667,0.885906
250,0.0022,0.720481,0.88,0.880795
300,0.0075,0.663373,0.89,0.888889
350,0.0025,0.668409,0.89,0.888889


You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.
You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.
You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.
You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.
You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datase

You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.
You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.
dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)


Step,Training Loss,Validation Loss,Accuracy,F1
50,0.0071,0.634488,0.896667,0.888889
100,0.0156,0.753161,0.893333,0.89404
150,0.0074,0.850678,0.88,0.882353
200,0.0068,0.828243,0.89,0.891089
250,0.0003,0.73316,0.9,0.899329
300,0.0002,0.760336,0.903333,0.902357
350,0.0004,0.754144,0.906667,0.905405
400,0.0027,0.747253,0.903333,0.902357
450,0.0011,0.830303,0.896667,0.89769
500,0.0012,0.789553,0.896667,0.895623


You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.
You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.
You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.
You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.
You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datase

You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.
You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.
dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)


Step,Training Loss,Validation Loss,Accuracy,F1
50,0.0004,0.811772,0.893333,0.892617
100,0.0,0.737134,0.923333,0.919861
150,0.0,0.849201,0.903333,0.901695
200,0.0,1.164783,0.87,0.875399
250,0.001,0.900215,0.9,0.896552
300,0.0015,0.968899,0.886667,0.886667
350,0.0033,0.835501,0.91,0.904594
400,0.0,0.818069,0.91,0.906574
450,0.0029,1.096574,0.89,0.893204
500,0.0006,1.098458,0.883333,0.885246


You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.
You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.
You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.
You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.
You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datase

KeyboardInterrupt: 

**Ik heb de code hier gestopt omdat het anders veel te lang duurde, dit was al een uur aan runtime**

## BONUS: Fine-tune a model on a different dataset/task of your choice.