# Fine-tune ALBERT for sentence-pair classification

## Introduction 

In this notebook you will learn how to fine-tune ALBERT and other BERT-based models for the **sentence-pair classification** task. This PyTorch implementation relies on the Hugging Face *transformers* and *datasets* libraries to download pre-trained models, run quick research experiments, and access datasets and evaluation metrics.

This task is part of the semantic textual similarity problem: given a pair of sentences, you want to model the textual interaction between them.

The dataset used in this notebook is the Microsoft Research Paraphrase Corpus (MRPC), which is part of the GLUE benchmark: given two sentences, you want to predict whether one is a paraphrase of the other. The evaluation metrics are F1 and accuracy.

On the validation set you should be able to reach an F1 score of **91.19** (the score reported in the ALBERT paper is 90.9) and an accuracy of **87.5**. With the GPU used here, fine-tuning takes about 35 seconds per epoch and inference about 2 seconds.

The main features of this tutorial are:
- End-to-end ML implementation (training, validation, prediction, evaluation)
- Easy adaptability to your own datasets
- Quick experiments with other BERT-based models (BERT, ALBERT, ...)
- Quick training with limited computational resources (mixed precision, gradient accumulation, ...)
- Multi-GPU execution
- Choice of the threshold for the classification decision (not necessarily 0.5)
- Option to freeze the encoder layers and update only the classification layer weights, or to update all the weights
- Reproducible results with seed settings

#### Sections

1. [Installation of libraries and imports](#section01)

2. [Loading the dataset](#section02)

3. [Classes and functions](#section03)

4. [Parameters](#section04)

5. [Training and validation](#section05)

6. [Prediction](#section06)

7. [Evaluation](#section07)

8. [Experiments' ideas](#section08)

9. [Limitations](#section09)

10. [Future works](#section10)



## Installation of libraries and imports

In [None]:
!pip3 install datasets
!pip3 install transformers

In [10]:
import torch
import torch.nn as nn
import os
import matplotlib.pyplot as plt
import copy
import torch.optim as optim
import random
import numpy as np
import pandas as pd
from torch.utils.data import DataLoader, Dataset
from torch.cuda.amp import autocast, GradScaler
from tqdm import tqdm
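from torch.optim import AdamW  # the AdamW shipped with transformers is deprecated in favour of the PyTorch implementation
from transformers import AutoTokenizer, AutoModel, get_linear_schedule_with_warmup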
from datasets import load_dataset, load_metric

os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [11]:
# Check how much RAM and GPU memory is free before training
# Helper code adapted from https://github.com/patrickvonplaten/notebooks/blob/master/PyTorch_Reformer.ipynb
!ln -sf /opt/bin/nvidia-smi /usr/bin/nvidia-smi
!pip -q install gputil
!pip -q install psutil
!pip -q install humanize
import psutil
import humanize
import os
import GPUtil as GPU
GPUs = GPU.getGPUs()
# XXX: only one GPU on Colab and isn’t guaranteed
gpu = GPUs[0]
def printm():
    process = psutil.Process(os.getpid())
    print("Gen RAM Free: " + humanize.naturalsize( psutil.virtual_memory().available ), " | Proc size: " + humanize.naturalsize( process.memory_info().rss))
    print("GPU RAM Free: {0:.0f}MB | Used: {1:.0f}MB | Util {2:3.0f}% | Total {3:.0f}MB".format(gpu.memoryFree, gpu.memoryUsed, gpu.memoryUtil*100, gpu.memoryTotal))
printm()

Gen RAM Free: 19.9 GB  | Proc size: 171.4 MB
GPU RAM Free: 9479MB | Used: 587MB | Util   6% | Total 10240MB



If GPU utilisation (Util) is not at 0%, you can uncomment and run the following line to kill all processes and free the whole GPU. Make sure to comment the line out again afterwards so that the notebook does not keep crashing on purpose.

In [None]:
# !kill -9 -1

## Loading the dataset

In [None]:
# Load the MRPC dataset (train, validation and test)
dataset = load_dataset('glue', 'mrpc')

In [None]:
split = dataset['train'].train_test_split(test_size=0.1, seed=1)  # split the original training data for validation
train = split['train']  # 90 % of the original training data
val = split['test']   # 10 % of the original training data
test = dataset['validation']  # the original validation data is used as test data because the test labels are not available with the datasets library

# Transform data into pandas dataframes
df_train = pd.DataFrame(train)
df_val = pd.DataFrame(val)
df_test = pd.DataFrame(test)

If you want to use your own dataset, you can upload it using the file browser on the left of the notebook.

For now, only CSV files are handled. You need to upload three files: training data, validation data and test data (with or without labels).

Here is a script to load data from CSV files with the following headers:
- sentence1
- sentence2 
- label

In [None]:
# Load your dataset from csv files

# Some useful UNIX commands : 
# !pwd -> print working directory
# !ls -> list directory contents of files and directories
# %cd -> change the directory/folder of the terminal's shell


# path_to_train_data = '/content/...'
# path_to_val_data = '/content/...'
# path_to_test_data = '/content/...'

# delimiter = ";" 

# df_train = pd.read_csv(path_to_train_data, delimiter=delimiter)
# df_val = pd.read_csv(path_to_val_data, delimiter=delimiter)
# df_test = pd.read_csv(path_to_test_data, delimiter=delimiter)
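
To make the expected file layout concrete, here is a minimal sketch that builds a tiny semicolon-delimited sample in memory and checks that the three required headers are present. The two rows are made up for illustration, and `check_csv_layout` is a hypothetical helper, not part of the notebook. It uses the standard `csv` module so that it runs without any dependency; a file with the same content can be read with the `pd.read_csv(..., delimiter=";")` calls above.

```python
import csv
import io

REQUIRED_COLUMNS = ["sentence1", "sentence2", "label"]

# Two made-up rows, only to show the layout (";" matches the delimiter used above)
sample_csv = (
    "sentence1;sentence2;label\n"
    "The cat sat on the mat.;A cat was sitting on the mat.;1\n"
    "The cat sat on the mat.;Stock prices fell sharply today.;0\n"
)

def check_csv_layout(text, delimiter=";"):
    """Return the parsed rows, raising ValueError if a required column is missing."""
    reader = csv.DictReader(io.StringIO(text), delimiter=delimiter)
    missing = [c for c in REQUIRED_COLUMNS if c not in (reader.fieldnames or [])]
    if missing:
        raise ValueError("Missing columns: {}".format(missing))
    return list(reader)

rows = check_csv_layout(sample_csv)
print(len(rows), rows[0]["label"])
```

For a test file without labels, the `label` column can simply be absent, in which case the check above would have to be relaxed accordingly.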

In [None]:
print(df_train.shape)
print(df_val.shape)
print(df_test.shape)

(3301, 4)
(367, 4)
(408, 4)


In [None]:
df_train.head()

| | idx | label | sentence1 | sentence2 |
|---|---|---|---|---|
| 0 | 3349 | 1 | The American troops , who also defend the mayo... | The American troops from 3-15 infantry , who a... |
| 1 | 1806 | 0 | Last week , his lawyers asked Warner to grant ... | Last week , his lawyers asked Gov. Mark R. War... |
| 2 | 2681 | 1 | Their election increases the board from seven ... | Their appointments increase Berkshire 's board... |
| 3 | 2001 | 1 | Dolores Mahoy , 68 , of Colorado Springs , Col... | Dolores E. Mahoy , 68 , of Colorado Springs is... |
| 4 | 2993 | 1 | " They were an inspirational couple , selfless... | He said : “ They were an inspirational couple ... |


## Classes and functions

In [None]:
class CustomDataset(Dataset):

    def __init__(self, data, maxlen, with_labels=True, bert_model='albert-base-v2'):

        self.data = data  # pandas dataframe
        #Initialize the tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(bert_model)  

        self.maxlen = maxlen
        self.with_labels = with_labels 

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):

        # Selecting sentence1 and sentence2 at the specified index in the data frame
        sent1 = str(self.data.loc[index, 'sentence1'])
        sent2 = str(self.data.loc[index, 'sentence2'])

        # Tokenize the pair of sentences to get token ids, attention masks and token type ids
        encoded_pair = self.tokenizer(sent1, sent2, 
                                      padding='max_length',  # Pad to max_length
                                      truncation=True,  # Truncate to max_length
                                      max_length=self.maxlen,  
                                      return_tensors='pt')  # Return torch.Tensor objects
        
        token_ids = encoded_pair['input_ids'].squeeze(0)  # tensor of token ids
        attn_masks = encoded_pair['attention_mask'].squeeze(0)  # binary tensor with "0" for padded values and "1" for the other values
        token_type_ids = encoded_pair['token_type_ids'].squeeze(0)  # binary tensor with "0" for the 1st sentence tokens & "1" for the 2nd sentence tokens

        if self.with_labels:  # True if the dataset has labels
            label = self.data.loc[index, 'label']
            return token_ids, attn_masks, token_type_ids, label  
        else:
            return token_ids, attn_masks, token_type_ids
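
The three tensors returned by `__getitem__` are easier to picture on a toy example. The sketch below is not the tokenizer's implementation: `encode_pair_layout` is a hypothetical helper that works on already-split tokens, ignores truncation, and uses `<pad>` as a placeholder padding token. It only illustrates the `[CLS] sentence1 [SEP] sentence2 [SEP]` layout together with the matching attention mask and token type ids.

```python
def encode_pair_layout(tokens_a, tokens_b, max_length):
    """Toy illustration of the [CLS] A [SEP] B [SEP] layout (truncation is not handled)."""
    tokens = ["[CLS]"] + tokens_a + ["[SEP]"] + tokens_b + ["[SEP]"]
    if len(tokens) > max_length:
        raise ValueError("pair does not fit; the real tokenizer would truncate it")
    # Segment ids: 0 for [CLS], sentence 1 and its [SEP]; 1 for sentence 2 and its [SEP]
    token_type_ids = [0] * (len(tokens_a) + 2) + [1] * (len(tokens_b) + 1)
    # Attention mask: 1 for real tokens, 0 for padding
    attention_mask = [1] * len(tokens)
    n_pad = max_length - len(tokens)
    tokens = tokens + ["<pad>"] * n_pad
    attention_mask = attention_mask + [0] * n_pad
    token_type_ids = token_type_ids + [0] * n_pad
    return tokens, attention_mask, token_type_ids

tokens, attention_mask, token_type_ids = encode_pair_layout(["a", "b"], ["c"], max_length=8)
print(tokens)
print(attention_mask)   # [1, 1, 1, 1, 1, 1, 0, 0]
print(token_type_ids)   # [0, 0, 0, 0, 1, 1, 0, 0]
```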

In [None]:
class SentencePairClassifier(nn.Module):

    def __init__(self, bert_model="albert-base-v2", freeze_bert=False):
        super(SentencePairClassifier, self).__init__()
        #  Instantiating BERT-based model object
        self.bert_layer = AutoModel.from_pretrained(bert_model)

        #  Read the hidden-state size of the encoder outputs from the model config, so that any
        #  BERT-based checkpoint works without a hard-coded lookup table. For reference:
        #  albert-base-v2: 768 (12M parameters), albert-large-v2: 1024 (18M), albert-xlarge-v2: 2048 (60M),
        #  albert-xxlarge-v2: 4096 (235M), bert-base-uncased: 768 (110M)
        hidden_size = self.bert_layer.config.hidden_size

        # Freeze bert layers and only train the classification layer weights
        if freeze_bert:
            for p in self.bert_layer.parameters():
                p.requires_grad = False

        # Classification layer
        self.cls_layer = nn.Linear(hidden_size, 1)

        self.dropout = nn.Dropout(p=0.1)

    @autocast()  # run in mixed precision
    def forward(self, input_ids, attn_masks, token_type_ids):
        '''
        Inputs:
            -input_ids : Tensor  containing token ids
            -attn_masks : Tensor containing attention masks to be used to focus on non-padded values
            -token_type_ids : Tensor containing token type ids to be used to identify sentence1 and sentence2
        '''

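        # Feeding the inputs to the BERT-based model to obtain contextualized representations.
        # Indexing the output works whether the model returns a tuple or a ModelOutput object
        # (recent transformers releases return a ModelOutput, which cannot be unpacked into two tensors).
        outputs = self.bert_layer(input_ids=input_ids, attention_mask=attn_masks, token_type_ids=token_type_ids)
        pooler_output = outputs[1]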

        # Feeding to the classifier layer the last layer hidden-state of the [CLS] token further processed by a
        # Linear Layer and a Tanh activation. The Linear layer weights were trained from the sentence order prediction (ALBERT) or next sentence prediction (BERT)
        # objective during pre-training.
        logits = self.cls_layer(self.dropout(pooler_output))

        return logits

In [None]:
def set_seed(seed):
    """ Set all seeds to make results reproducible """
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(seed)
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    

def evaluate_loss(net, device, criterion, dataloader):
    net.eval()

    mean_loss = 0
    count = 0

    with torch.no_grad():
        for it, (seq, attn_masks, token_type_ids, labels) in enumerate(tqdm(dataloader)):
            seq, attn_masks, token_type_ids, labels = \
                seq.to(device), attn_masks.to(device), token_type_ids.to(device), labels.to(device)
            logits = net(seq, attn_masks, token_type_ids)
            mean_loss += criterion(logits.squeeze(-1), labels.float()).item()
            count += 1

    return mean_loss / count

In [None]:
print("Creation of the models' folder...")
!mkdir models

Creation of the models' folder...


Reference for mixed-precision training, gradient scaling and gradient accumulation: https://pytorch.org/docs/stable/notes/amp_examples.html#amp-examples

To learn more about training neural networks with larger batches, I suggest reading this post by Thomas Wolf:
https://medium.com/huggingface/training-larger-batches-practical-tips-on-1-gpu-multi-gpu-distributed-setups-ec88c3e51255
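
To see why the training loop divides the loss by `iters_to_accumulate`, here is a small framework-free sketch. It assumes a one-parameter model `y_hat = w * x` with a mean-squared-error loss and made-up data, and checks that summing the scaled gradients of equally sized micro-batches gives the same gradient as one large batch.

```python
def mean_grad(w, batch):
    """Gradient w.r.t. w of the mean squared error of the toy model y_hat = w * x."""
    return sum(2 * x * (w * x - y) for x, y in batch) / len(batch)

data = [(1.0, 2.0), (2.0, 3.0), (3.0, 7.0), (4.0, 9.0)]  # made-up (x, y) pairs
w = 0.5
iters_to_accumulate = 2
micro_bs = len(data) // iters_to_accumulate

# Gradient computed on the whole batch at once
full_batch_grad = mean_grad(w, data)

# Gradient accumulated over micro-batches
accumulated_grad = 0.0
for i in range(iters_to_accumulate):
    micro_batch = data[i * micro_bs:(i + 1) * micro_bs]
    # Dividing by iters_to_accumulate mirrors `loss = loss / iters_to_accumulate`
    accumulated_grad += mean_grad(w, micro_batch) / iters_to_accumulate

print(full_batch_grad, accumulated_grad)
```

The equality holds exactly only when all micro-batches have the same size; with a smaller last batch the accumulated gradient is a slightly reweighted average.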

In [None]:
def train_bert(net, criterion, opti, lr, lr_scheduler, train_loader, val_loader, epochs, iters_to_accumulate):

    best_loss = np.inf  # np.Inf was removed in NumPy 2.0
    best_ep = 1
    nb_iterations = len(train_loader)
    print_every = nb_iterations // 5  # print the training loss 5 times per epoch
    iters = []
    train_losses = []
    val_losses = []

    scaler = GradScaler()

    for ep in range(epochs):

        net.train()
        running_loss = 0.0
        for it, (seq, attn_masks, token_type_ids, labels) in enumerate(tqdm(train_loader)):

            # Converting to cuda tensors
            seq, attn_masks, token_type_ids, labels = \
                seq.to(device), attn_masks.to(device), token_type_ids.to(device), labels.to(device)
    
            # Enables autocasting for the forward pass (model + loss)
            with autocast():
                # Obtaining the logits from the model
                logits = net(seq, attn_masks, token_type_ids)

                # Computing loss
                loss = criterion(logits.squeeze(-1), labels.float())
                loss = loss / iters_to_accumulate  # Normalize the loss because it is averaged

            # Backpropagating the gradients
            # Scales loss.  Calls backward() on scaled loss to create scaled gradients.
            scaler.scale(loss).backward()

            if (it + 1) % iters_to_accumulate == 0:
                # Optimization step
                # scaler.step() first unscales the gradients of the optimizer's assigned params.
                # If these gradients do not contain infs or NaNs, opti.step() is then called,
                # otherwise, opti.step() is skipped.
                scaler.step(opti)
                # Updates the scale for next iteration.
                scaler.update()
                # Adjust the learning rate based on the number of iterations.
                lr_scheduler.step()
                # Clear gradients
                opti.zero_grad()


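            # Note: loss was divided by iters_to_accumulate above, so the training loss
            # printed below is the per-batch loss scaled by 1 / iters_to_accumulate
            running_loss += loss.item()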

            if (it + 1) % print_every == 0:  # Print training loss information
                print()
                print("Iteration {}/{} of epoch {} complete. Loss : {} "
                      .format(it+1, nb_iterations, ep+1, running_loss / print_every))

                running_loss = 0.0


        val_loss = evaluate_loss(net, device, criterion, val_loader)  # Compute validation loss
        print()
        print("Epoch {} complete! Validation Loss : {}".format(ep+1, val_loss))

        if val_loss < best_loss:
            print("Best validation loss improved from {} to {}".format(best_loss, val_loss))
            print()
            net_copy = copy.deepcopy(net)  # save a copy of the model
            best_loss = val_loss
            best_ep = ep + 1

    # Saving the model
    path_to_model='models/{}_lr_{}_val_loss_{}_ep_{}.pt'.format(bert_model, lr, round(best_loss, 5), best_ep)
    torch.save(net_copy.state_dict(), path_to_model)
    print("The model has been saved in {}".format(path_to_model))

    del loss
    torch.cuda.empty_cache()

## Parameters

In [None]:
bert_model = "albert-base-v2"  # 'albert-base-v2', 'albert-large-v2', 'albert-xlarge-v2', 'albert-xxlarge-v2', 'bert-base-uncased', ...
freeze_bert = False  # if True, freeze the encoder weights and only update the classification layer weights
maxlen = 128  # maximum length of the tokenized input sentence pair : if greater than "maxlen", the input is truncated and else if smaller, the input is padded
bs = 16  # batch size
iters_to_accumulate = 2  # the gradient accumulation adds gradients over an effective batch of size : bs * iters_to_accumulate. If set to "1", you get the usual batch size
lr = 2e-5  # learning rate
epochs = 4  # number of training epochs
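
As a quick sanity check on these settings, the arithmetic below derives the effective batch size and the number of optimizer updates. It assumes the 3301 training rows printed earlier and the default `DataLoader` behaviour of keeping the last partial batch.

```python
import math

n_train = 3301            # rows in df_train, as printed earlier
bs = 16
iters_to_accumulate = 2
epochs = 4

effective_bs = bs * iters_to_accumulate
batches_per_epoch = math.ceil(n_train / bs)  # the last, smaller batch is kept
optimizer_steps = (batches_per_epoch // iters_to_accumulate) * epochs

print(effective_bs, batches_per_epoch, optimizer_steps)  # 32 207 412
```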

## Training and validation

Link for the AdamW optimizer and the learning rate scheduler :
https://huggingface.co/transformers/main_classes/optimizer_schedules.html
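
The linear schedule can be summarised by a single multiplicative factor applied to the base learning rate. The function below is a sketch of that rule as I understand the documented behaviour (linear warm-up from 0, then linear decay to 0); `linear_schedule_factor` is a hypothetical name, not a library function. The value 412 is the number of optimizer updates for this notebook's settings, that is `(207 // 2) * 4`.

```python
def linear_schedule_factor(step, num_warmup_steps, num_training_steps):
    """Multiplier applied to the base learning rate at a given optimizer step."""
    if step < num_warmup_steps:
        # Linear warm-up from 0 to 1
        return step / max(1, num_warmup_steps)
    # Linear decay from 1 to 0 over the remaining steps
    return max(0.0, (num_training_steps - step) / max(1, num_training_steps - num_warmup_steps))

base_lr = 2e-5
total_steps = 412
for step in (0, 206, 412):
    print(step, base_lr * linear_schedule_factor(step, 0, total_steps))
```

With `num_warmup_steps = 0`, as in this notebook, the learning rate starts at its full value and decays linearly to zero at the last update.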

In [None]:
#  Set all seeds to make reproducible results
set_seed(1)

# Creating instances of training and validation set
print("Reading training data...")
train_set = CustomDataset(df_train, maxlen, bert_model=bert_model)  # keyword required: the third positional parameter is with_labels
print("Reading validation data...")
val_set = CustomDataset(df_val, maxlen, bert_model=bert_model)
# Creating instances of training and validation dataloaders
train_loader = DataLoader(train_set, batch_size=bs, num_workers=5)
val_loader = DataLoader(val_set, batch_size=bs, num_workers=5)


device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
net = SentencePairClassifier(bert_model, freeze_bert=freeze_bert)

if torch.cuda.device_count() > 1:  # if multiple GPUs
    print("Let's use", torch.cuda.device_count(), "GPUs!")
    net = nn.DataParallel(net)

net.to(device)

criterion = nn.BCEWithLogitsLoss()
opti = AdamW(net.parameters(), lr=lr, weight_decay=1e-2)
num_warmup_steps = 0 # The number of steps for the warmup phase.
num_training_steps = epochs * len(train_loader)  # The total number of training steps
t_total = (len(train_loader) // iters_to_accumulate) * epochs  # Necessary to take into account Gradient accumulation
lr_scheduler = get_linear_schedule_with_warmup(optimizer=opti, num_warmup_steps=num_warmup_steps, num_training_steps=t_total)

train_bert(net, criterion, opti, lr, lr_scheduler, train_loader, val_loader, epochs, iters_to_accumulate)

Reading training data...




Reading validation data...






 20%|██        | 42/207 [00:07<00:27,  6.03it/s]


Iteration 41/207 of epoch 1 complete. Loss : 0.3171079242374839 


 40%|████      | 83/207 [00:14<00:20,  6.00it/s]


Iteration 82/207 of epoch 1 complete. Loss : 0.28335858190931923 


 60%|█████▉    | 124/207 [00:20<00:13,  5.97it/s]


Iteration 123/207 of epoch 1 complete. Loss : 0.2626715171627882 


 80%|███████▉  | 165/207 [00:27<00:07,  5.90it/s]


Iteration 164/207 of epoch 1 complete. Loss : 0.22988062079359844 


100%|█████████▉| 206/207 [00:34<00:00,  5.94it/s]


Iteration 205/207 of epoch 1 complete. Loss : 0.22290468016048756 


100%|██████████| 207/207 [00:34<00:00,  5.93it/s]
100%|██████████| 23/23 [00:01<00:00, 12.30it/s]
  0%|          | 0/207 [00:00<?, ?it/s]


Epoch 1 complete! Validation Loss : 0.35299917602020764
Best validation loss improved from inf to 0.35299917602020764



 20%|██        | 42/207 [00:07<00:28,  5.77it/s]


Iteration 41/207 of epoch 2 complete. Loss : 0.19836873188614845 


 40%|████      | 83/207 [00:14<00:21,  5.78it/s]


Iteration 82/207 of epoch 2 complete. Loss : 0.16286275281411847 


 60%|█████▉    | 124/207 [00:21<00:14,  5.76it/s]


Iteration 123/207 of epoch 2 complete. Loss : 0.15275843041699108 


 80%|███████▉  | 165/207 [00:28<00:07,  5.74it/s]


Iteration 164/207 of epoch 2 complete. Loss : 0.13304086556521857 


100%|█████████▉| 206/207 [00:35<00:00,  5.69it/s]


Iteration 205/207 of epoch 2 complete. Loss : 0.13053428790554766 


100%|██████████| 207/207 [00:36<00:00,  5.74it/s]
100%|██████████| 23/23 [00:01<00:00, 11.73it/s]
  0%|          | 0/207 [00:00<?, ?it/s]


Epoch 2 complete! Validation Loss : 0.43609277385732403


 20%|██        | 42/207 [00:07<00:28,  5.69it/s]


Iteration 41/207 of epoch 3 complete. Loss : 0.13481714503794182 


 40%|████      | 83/207 [00:14<00:21,  5.78it/s]


Iteration 82/207 of epoch 3 complete. Loss : 0.08921206479028958 


 60%|█████▉    | 124/207 [00:21<00:14,  5.78it/s]


Iteration 123/207 of epoch 3 complete. Loss : 0.08746417448288057 


 80%|███████▉  | 165/207 [00:28<00:07,  5.79it/s]


Iteration 164/207 of epoch 3 complete. Loss : 0.0712746214348732 


100%|█████████▉| 206/207 [00:36<00:00,  5.81it/s]


Iteration 205/207 of epoch 3 complete. Loss : 0.06408452383446984 


100%|██████████| 207/207 [00:36<00:00,  5.71it/s]
100%|██████████| 23/23 [00:01<00:00, 11.89it/s]
  0%|          | 0/207 [00:00<?, ?it/s]


Epoch 3 complete! Validation Loss : 0.35006705865911814
Best validation loss improved from 0.35299917602020764 to 0.35006705865911814



 20%|██        | 42/207 [00:07<00:28,  5.73it/s]


Iteration 41/207 of epoch 4 complete. Loss : 0.06930875144444587 


 40%|████      | 83/207 [00:14<00:21,  5.77it/s]


Iteration 82/207 of epoch 4 complete. Loss : 0.04216044627856917 


 60%|█████▉    | 124/207 [00:21<00:14,  5.75it/s]


Iteration 123/207 of epoch 4 complete. Loss : 0.038601048358875074 


 80%|███████▉  | 165/207 [00:28<00:07,  5.82it/s]


Iteration 164/207 of epoch 4 complete. Loss : 0.041417054536684254 


100%|█████████▉| 206/207 [00:35<00:00,  5.78it/s]


Iteration 205/207 of epoch 4 complete. Loss : 0.03673171890308944 


100%|██████████| 207/207 [00:36<00:00,  5.75it/s]
100%|██████████| 23/23 [00:01<00:00, 11.96it/s]



Epoch 4 complete! Validation Loss : 0.36528593172197754
The model has been saved in models/albert-base-v2_lr_2e-05_val_loss_0.35007_ep_3.pt


You can download the model saved in the *models* folder by browsing the files on the left of the Colab notebook.

In [None]:
# If you encounter a CUDA out of memory error:
# - uncomment the "kill" command below, run it, then comment it out again
# - reduce the batch size
# - then run all cells from the beginning

# If the tqdm output looks ugly (every iteration is shown), follow the first and last steps above

printm()
# !kill -9 -1

## Prediction

In [None]:
print("Creation of the results' folder...")
!mkdir results

Creation of the results' folder...


In [None]:
def get_probs_from_logits(logits):
    """
    Converts a tensor of logits into an array of probabilities by applying the sigmoid function
    """
    probs = torch.sigmoid(logits.unsqueeze(-1))
    return probs.detach().cpu().numpy()

def test_prediction(net, device, dataloader, with_labels=True, result_file="results/output.txt"):
    """
    Predict the probabilities on a dataset with or without labels and print the result in a file
    """
    net.eval()
    w = open(result_file, 'w')
    probs_all = []

    with torch.no_grad():
        if with_labels:
            for seq, attn_masks, token_type_ids, _ in tqdm(dataloader):
                seq, attn_masks, token_type_ids = seq.to(device), attn_masks.to(device), token_type_ids.to(device)
                logits = net(seq, attn_masks, token_type_ids)
                probs = get_probs_from_logits(logits.squeeze(-1)).squeeze(-1)
                probs_all += probs.tolist()
        else:
            for seq, attn_masks, token_type_ids in tqdm(dataloader):
                seq, attn_masks, token_type_ids = seq.to(device), attn_masks.to(device), token_type_ids.to(device)
                logits = net(seq, attn_masks, token_type_ids)
                probs = get_probs_from_logits(logits.squeeze(-1)).squeeze(-1)
                probs_all += probs.tolist()

    w.writelines(str(prob)+'\n' for prob in probs_all)
    w.close()
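
For readers who want to see what the sigmoid step does to individual values, here is a dependency-free sketch. The logits are made up, and the branch on the sign is just one common way to avoid overflow for large negative inputs; it is not taken from the notebook.

```python
import math

def sigmoid(logit):
    """Numerically stable logistic function for a single float."""
    if logit >= 0:
        return 1.0 / (1.0 + math.exp(-logit))
    z = math.exp(logit)  # safe: logit < 0, so exp cannot overflow
    return z / (1.0 + z)

logits = [-2.0, 0.0, 3.0]                 # made-up model outputs
probs = [sigmoid(z) for z in logits]      # probabilities of the "paraphrase" class
preds = [int(p >= 0.5) for p in probs]    # labels with a 0.5 threshold
print(probs, preds)
```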

Below I share a fine-tuned ALBERT model (45 MB) so that you can reproduce my results on the MRPC validation set (an F1 score of **91.19** and an accuracy of **87.5**). It is provided just in case: if all the code runs as expected, you should find the same model in the *models* folder after training.

You can download it and upload it (about 3 minutes) to the *models* folder using the file browser on the left of the Colab notebook:

https://drive.google.com/file/d/1AcRLGvALAH3BVSiDVjY_b8CggJgVfksp/view?usp=sharing

In [None]:
path_to_model = '/content/models/albert-base-v2_lr_2e-05_val_loss_0.35007_ep_3.pt'  
# path_to_model = '/content/models/...'  # You can add here your trained model

path_to_output_file = 'results/output.txt'

print("Reading test data...")
test_set = CustomDataset(df_test, maxlen, bert_model=bert_model)  # keyword required: the third positional parameter is with_labels
test_loader = DataLoader(test_set, batch_size=bs, num_workers=5)

model = SentencePairClassifier(bert_model)
if torch.cuda.device_count() > 1:  # if multiple GPUs
    print("Let's use", torch.cuda.device_count(), "GPUs!")
    model = nn.DataParallel(model)

print()
print("Loading the weights of the model...")
model.load_state_dict(torch.load(path_to_model, map_location=device))  # map_location lets a GPU-trained checkpoint load on CPU too
model.to(device)

print("Predicting on test data...")
test_prediction(net=model, device=device, dataloader=test_loader, with_labels=True,  # set with_labels to False if you want predictions on a dataset without labels
                result_file=path_to_output_file)
print()
print("Predictions are available in : {}".format(path_to_output_file))

Reading test data...


  0%|          | 0/26 [00:00<?, ?it/s]


Loading the weights of the model...
Predicting on test data...


100%|██████████| 26/26 [00:02<00:00, 12.07it/s]


Predictions are available in : results/output.txt





You can download the predictions saved in the *results* folder by browsing the files on the left of the Colab notebook.

## Evaluation

In [None]:
path_to_output_file = 'results/output.txt'  # path to the file with prediction probabilities

labels_test = df_test['label']  # true labels

probs_test = pd.read_csv(path_to_output_file, header=None)[0]  # prediction probabilities
threshold = 0.5   # you can adjust this threshold for your own dataset
preds_test=(probs_test>=threshold).astype('uint8') # predicted labels using the above fixed threshold

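metric = load_metric("glue", "mrpc")  # in recent releases, metrics live in the separate `evaluate` library: evaluate.load("glue", "mrpc")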

Link for the threshold choice problem : https://machinelearningmastery.com/threshold-moving-for-imbalanced-classification/
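
A simple way to pick the threshold is to sweep candidate values and keep the one that maximizes the metric of interest. The sketch below does this for F1 in plain Python. The labels and probabilities are made up, and `best_threshold` is a hypothetical helper. In practice the sweep should be run on the validation split (`df_val`), not on the data used for the final evaluation.

```python
def f1_score(labels, preds):
    """F1 of the positive class from two lists of 0/1 values."""
    tp = sum(1 for y, p in zip(labels, preds) if y == 1 and p == 1)
    fp = sum(1 for y, p in zip(labels, preds) if y == 0 and p == 1)
    fn = sum(1 for y, p in zip(labels, preds) if y == 1 and p == 0)
    denom = 2 * tp + fp + fn
    return 2 * tp / denom if denom else 0.0

def best_threshold(labels, probs, candidates):
    """Return (threshold, f1) maximizing F1; ties go to the first candidate."""
    best = (None, -1.0)
    for t in candidates:
        preds = [int(p >= t) for p in probs]
        score = f1_score(labels, preds)
        if score > best[1]:
            best = (t, score)
    return best

# Made-up validation labels and probabilities, for illustration only
val_labels = [1, 1, 1, 0, 0, 1]
val_probs = [0.9, 0.6, 0.45, 0.4, 0.2, 0.8]
candidates = [i / 20 for i in range(1, 20)]  # 0.05, 0.10, ..., 0.95
threshold, f1 = best_threshold(val_labels, val_probs, candidates)
print(threshold, f1)
```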

In [None]:
# Compute the accuracy and F1 scores
# Use the public compute() method rather than the private _compute()
metric.compute(predictions=preds_test.tolist(), references=labels_test.tolist())

{'accuracy': 0.875, 'f1': 0.911917098445596}

## Experiments' ideas

- Try other pre-trained models: https://huggingface.co/models
- Try other optimizers and learning rate schedulers
- Tune the hyperparameters : batch size, gradient accumulation parameter (iters_to_accumulate), number of epochs, learning rate
- Change the *maxlen* parameter (max : 512). If you increase it, the training will take longer
- Observe the influence of freezing the encoder weights and only updating the classifier weights
- Use other metrics (Precision, Recall, ROC AUC, Precision-recall AUC, etc.) depending on the task and the dataset
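
As a starting point for the last idea, here is a small sketch of precision, recall and ROC AUC written in plain Python on made-up data. The AUC is computed from its pairwise definition (the fraction of positive/negative pairs ranked correctly), which is fine for small sets; for real experiments a library implementation such as `sklearn.metrics.roc_auc_score` is the more practical choice.

```python
def precision_recall(labels, preds):
    """Precision and recall of the positive class from two lists of 0/1 values."""
    tp = sum(1 for y, p in zip(labels, preds) if y == 1 and p == 1)
    fp = sum(1 for y, p in zip(labels, preds) if y == 0 and p == 1)
    fn = sum(1 for y, p in zip(labels, preds) if y == 1 and p == 0)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    return precision, recall

def roc_auc(labels, probs):
    """Probability that a random positive is scored above a random negative (ties count 0.5)."""
    pos = [p for y, p in zip(labels, probs) if y == 1]
    neg = [p for y, p in zip(labels, probs) if y == 0]
    if not pos or not neg:
        raise ValueError("ROC AUC needs at least one positive and one negative example")
    wins = sum(1.0 if a > b else 0.5 if a == b else 0.0 for a in pos for b in neg)
    return wins / (len(pos) * len(neg))

# Made-up labels and probabilities, for illustration only
labels = [1, 0, 1, 0, 1]
probs = [0.9, 0.7, 0.6, 0.3, 0.8]
preds = [int(p >= 0.5) for p in probs]
print(precision_recall(labels, preds), roc_auc(labels, probs))
```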


## Limitations

- As stated in Google Research's BERT GitHub repository (https://github.com/google-research/bert), "Small sets like MRPC have a high variance in the Dev set accuracy, even when starting from the same pre-training checkpoint." I suggest taking this into account if you want to compare models on this dataset.
- Distinct random seeds for models trained on GLUE datasets including MRPC can have a significant impact on results : for more details, you can read the paper *Fine-Tuning Pretrained Language Models: Weight Initializations, Data Orders, and Early Stopping* by Dodge et al. (https://arxiv.org/abs/2002.06305)
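
One practical consequence of both points is to report a metric as a mean and standard deviation over several seeds rather than as a single number. The sketch below shows the bookkeeping only: `summarize_runs` is a hypothetical helper, and the F1 values are placeholders, not results obtained with this notebook.

```python
import statistics

def summarize_runs(scores):
    """Return (mean, sample standard deviation) of one metric over several seeds."""
    return statistics.mean(scores), statistics.stdev(scores)

# Placeholder F1 values for five hypothetical seeds; NOT results from this notebook
f1_by_seed = [0.90, 0.91, 0.89, 0.92, 0.88]
mean_f1, std_f1 = summarize_runs(f1_by_seed)
print("F1 = {:.3f} +/- {:.3f}".format(mean_f1, std_f1))
```

In this notebook, each run would correspond to calling `set_seed` with a different value before training.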



## Future works

- Adapt the code so that other BERT-based models such as RoBERTa and DistilBERT can also be fine-tuned for this task
- Experiment with feeding the classification layer the average of the last-layer hidden states over all input tokens, or other combinations of several encoder layers, instead of the pooler output
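
The averaging idea in the last item has one subtlety: padded positions must be left out of the mean. Here is a minimal sketch on plain Python lists for a single sequence; `masked_mean_pool` is a hypothetical helper and the numbers are made up.

```python
def masked_mean_pool(hidden_states, attention_mask):
    """Average the token vectors whose attention mask is 1.

    hidden_states: list of seq_len vectors (lists of floats)
    attention_mask: list of seq_len values in {0, 1}
    """
    kept = [vec for vec, m in zip(hidden_states, attention_mask) if m == 1]
    if not kept:
        raise ValueError("attention_mask has no non-padded position")
    dim = len(kept[0])
    return [sum(vec[d] for vec in kept) / len(kept) for d in range(dim)]

hidden = [[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]]  # last row stands for a padded position
mask = [1, 1, 0]
print(masked_mean_pool(hidden, mask))  # [2.0, 3.0]
```

On batched PyTorch tensors the same idea is usually written as `(h * mask.unsqueeze(-1)).sum(1) / mask.sum(1, keepdim=True)`, with `h` of shape `(batch, seq_len, hidden_size)` and a float `mask` of shape `(batch, seq_len)`.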

