# Fine-tuning a Sequence Classification Model

In [1]:
from datasets import load_dataset
from transformers import AutoTokenizer, DataCollatorWithPadding

# Pretrained checkpoint whose tokenizer (and, later, model weights) we fine-tune.
checkpoint = "bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

# GLUE MRPC sentence-pair task; its splits and columns are shown by In [5] below.
raw_datasets = load_dataset("glue", "mrpc")

In [2]:
def tokenize_function(example):
    """Tokenize MRPC sentence pairs into a single encoded sequence per pair.

    Used with ``Dataset.map(..., batched=True)``, so ``example`` is a batch of
    rows with "sentence1" and "sentence2" entries. Both sentences of each pair
    are passed to the tokenizer together, with ``truncation=True``. No padding
    is applied here; ``DataCollatorWithPadding`` pads each batch later.
    """
    first_sentences = example["sentence1"]
    second_sentences = example["sentence2"]
    return tokenizer(first_sentences, second_sentences, truncation=True)


# Padding is deferred to the collator, which pads every batch to its longest
# sequence (see the [PAD] tokens in the In [3] output).
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
tokenized_datasets = raw_datasets.map(tokenize_function, batched=True)

In [3]:
# Build a two-example batch; the raw sentence strings cannot be collated into
# tensors, so they are removed from each example first.
samples = []
for idx in range(2):
    record = tokenized_datasets["train"][idx]
    for text_field in ("sentence1", "sentence2"):
        record.pop(text_field)
    samples.append(record)

padded_input_ids = data_collator(samples)["input_ids"]
for ids in padded_input_ids:
    print(f"\n>>> {tokenizer.decode(ids)}")
    print(f"Inputs IDs: {ids.tolist()}")
    print(f"Number of tokens: {len(ids)}")


>>> [CLS] amrozi accused his brother, whom he called " the witness ", of deliberately distorting his evidence. [SEP] referring to him as only " the witness ", amrozi accused his brother of deliberately distorting his evidence. [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]
Inputs IDs: [101, 2572, 3217, 5831, 5496, 2010, 2567, 1010, 3183, 2002, 2170, 1000, 1996, 7409, 1000, 1010, 1997, 9969, 4487, 23809, 3436, 2010, 3350, 1012, 102, 7727, 2000, 2032, 2004, 2069, 1000, 1996, 7409, 1000, 1010, 2572, 3217, 5831, 5496, 2010, 2567, 1997, 9969, 4487, 23809, 3436, 2010, 3350, 1012, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0]
Number of tokens: 59

>>> [CLS] yucaipa owned dominick's before selling the chain to safeway in 1998 for $ 2. 5 billion. [SEP] yucaipa bought dominick's in 1995 for $ 693 million and sold it to safeway for $ 1. 8 billion in 1998. [SEP]
Inputs IDs: [101, 9805, 3540, 11514, 2050, 3079, 11282, 2243, 1005, 1055, 2077, 4855, 1996, 4677, 2000, 3647, 4576, 1999, 2687, 2005, 100

In [4]:
# Show which raw fields feed the model and which label it is trained against.
sample = tokenized_datasets["train"][0]
print(
    "The model will input the sentence1 and sentence2 strings tokenized together",
    f"Sentence1: {sample['sentence1']}",
    f"Sentence2: {sample['sentence2']}",
    f"And expects in the batch a 'label'. In this case: {sample['label']}",
    sep="\n",
)

The model will input the sentence1 and sentence2 strings tokenized together
Sentence1: Amrozi accused his brother , whom he called " the witness " , of deliberately distorting his evidence .
Sentence2: Referring to him as only " the witness " , Amrozi accused his brother of deliberately distorting his evidence .
And expects in the batch a 'label'. In this case: 1


In [5]:
# Inspect the splits and the columns added by tokenization
# (input_ids, token_type_ids, attention_mask).
tokenized_datasets

DatasetDict({
    train: Dataset({
        features: ['sentence1', 'sentence2', 'label', 'idx', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 3668
    })
    validation: Dataset({
        features: ['sentence1', 'sentence2', 'label', 'idx', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 408
    })
    test: Dataset({
        features: ['sentence1', 'sentence2', 'label', 'idx', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 1725
    })
})

In [6]:
# Remove the raw-text and index columns: the collator and the model only need
# the tokenizer outputs and the label.
# `remove_columns` raises if asked to drop a column that no longer exists, so
# only the columns still present are removed. This keeps the cell safe to re-run.
columns_to_remove = ["sentence1", "sentence2", "idx"]
tokenized_datasets = tokenized_datasets.remove_columns(
    [
        column
        for column in columns_to_remove
        if column in tokenized_datasets["train"].column_names
    ]
)

# Evaluation

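Let’s see how we can build a useful `compute_metrics()` function and use it when we train. The function must take an `EvalPrediction` object (which is a named tuple with a `predictions` field and a `label_ids` field) and **return a dictionary mapping strings to floats** (the strings being the names of the metrics returned, and the floats their values). Below, we first try the GLUE MRPC metric on a few mock-up predictions; the training loop later feeds it real predictions batch by batch with `metric.add_batch()` and calls `metric.compute()` once per epoch.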

In [7]:
# Toy predictions vs. gold labels: the third example is misclassified.
mockup_predictions, mockup_labels = [1, 0, 0], [1, 0, 1]

In [8]:
import evaluate

# The GLUE MRPC metric reports both accuracy and F1 (see output below).
metric = evaluate.load("glue", "mrpc")
metric.compute(predictions=mockup_predictions, references=mockup_labels)

{'accuracy': 0.6666666666666666, 'f1': 0.6666666666666666}

# Training

With the high-level `Trainer` API, the first step would be to define a `TrainingArguments` class containing all the hyperparameters the `Trainer` uses for training and evaluation. Its only required argument is the directory where the trained model, as well as the checkpoints along the way, will be saved, and the defaults work well for a basic fine-tuning. In this notebook we instead write the training loop ourselves with Hugging Face Accelerate. The equivalent settings (output directory, optimizer, learning-rate schedule, number of epochs, batch size) are therefore defined explicitly in the cells below.

In [9]:
import evaluate

# Metric used by the evaluation step of the training loop below.
# Predictions are accumulated with `add_batch()` and summarised with
# `compute()` once per epoch. Loading a fresh instance keeps this section
# independent of the mock-up demo in the Evaluation section.
metric = evaluate.load("glue", "mrpc")

In [None]:
from transformers import AutoModelForSequenceClassification, set_seed

# Seed Python, NumPy and PyTorch RNGs before any newly initialised weights
# (e.g. the classification head) are created and before the training
# DataLoader shuffles. This makes re-runs reproducible, up to nondeterministic
# GPU kernels.
SEED = 42
set_seed(SEED)

# Two output classes for the binary MRPC sentence-pair task.
model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=2)

In [11]:
from torch.optim import AdamW

# AdamW over every model parameter, with a small fine-tuning learning rate.
optimizer = AdamW(params=model.parameters(), lr=2e-5)

In [12]:
from torch.utils.data import DataLoader

BATCH_SIZE = 8

# Options shared by both loaders: dynamic padding via the collator and the
# same batch size.
shared_loader_options = {"collate_fn": data_collator, "batch_size": BATCH_SIZE}

# Only the training split is shuffled; metrics do not depend on eval order.
train_dataloader = DataLoader(tokenized_datasets["train"], shuffle=True, **shared_loader_options)
eval_dataloader = DataLoader(tokenized_datasets["validation"], **shared_loader_options)

In [13]:
# Directory where the fine-tuned model and tokenizer are saved after each epoch.
output_dir = "tmp/seq_classification-" + checkpoint

In [14]:
from accelerate import Accelerator

# Let Accelerate wrap the model, optimizer and dataloaders for the available
# device(s). prepare() returns the wrapped objects in the order they are passed.
accelerator = Accelerator()
prepared = accelerator.prepare(model, optimizer, train_dataloader, eval_dataloader)
model, optimizer, train_dataloader, eval_dataloader = prepared

In [15]:
from transformers import get_scheduler

num_train_epochs = 3
# One optimizer step per (prepared) training batch, repeated every epoch.
num_update_steps_per_epoch = len(train_dataloader)
num_training_steps = num_update_steps_per_epoch * num_train_epochs

# "linear" schedule with no warm-up, spanning every training step.
lr_scheduler = get_scheduler(
    name="linear",
    optimizer=optimizer,
    num_training_steps=num_training_steps,
    num_warmup_steps=0,
)

In [16]:
from tqdm.auto import tqdm
import torch

# Checkpoints are written to `output_dir`, defined in In [13] for this
# sequence-classification run.
# With several processes, only the local main process draws the progress bar.
progress_bar = tqdm(range(num_training_steps), disable=not accelerator.is_local_main_process)

for epoch in range(num_train_epochs):
    ###############################################################
    ######################### TRAINING ############################
    ###############################################################
    model.train()
    for batch in train_dataloader:
        outputs = model(**batch)
        loss = outputs.loss
        accelerator.backward(loss)

        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
        progress_bar.update(1)

    ###############################################################
    ######################### EVALUATION ##########################
    ###############################################################
    model.eval()
    for batch in eval_dataloader:
        with torch.no_grad():
            outputs = model(**batch)

        # One predicted class per example, so predictions and labels are
        # 1-D and have the same length on every process: no padding is
        # needed before gathering.
        predictions = outputs.logits.argmax(dim=-1)
        labels = batch["labels"]

        # gather_for_metrics also drops the samples Accelerate duplicates to
        # even out the last batch across processes. Plain gather() would
        # count those samples twice in the metric.
        predictions_gathered, labels_gathered = accelerator.gather_for_metrics(
            (predictions, labels)
        )

        metric.add_batch(predictions=predictions_gathered, references=labels_gathered)

    results = metric.compute()
    # accelerator.print only prints once, on the local main process.
    accelerator.print(
        f"epoch {epoch}:",
        {
            key: results[key]
            for key in ["accuracy", "f1"]
        },
    )

    ###############################################################
    ######################### SAVE MODEL ##########################
    ###############################################################
    accelerator.wait_for_everyone()  # Make sure everyone has finished training
    unwrapped_model = accelerator.unwrap_model(model)
    unwrapped_model.save_pretrained(
        output_dir,
        is_main_process=accelerator.is_main_process,
        save_function=accelerator.save,
    )
    if accelerator.is_main_process:
        tokenizer.save_pretrained(output_dir)



  0%|          | 0/1377 [00:00<?, ?it/s]

epoch 0: {'accuracy': 0.8333333333333334, 'f1': 0.8781362007168458}
epoch 1: {'accuracy': 0.8553921568627451, 'f1': 0.8998302207130731}
epoch 2: {'accuracy': 0.8480392156862745, 'f1': 0.89419795221843}
