
# 1. Introduction
BERT (Bidirectional Encoder Representations from Transformers) is a machine learning model based on Transformers, i.e. attention-based components that learn contextual relations between words.  
The Natural Language Processing (NLP) community can leverage powerful tools like BERT in (at least) two ways:  

1. **Feature-based approach**  
  1.1 Download a pre-trained BERT model.   
  1.2 Use BERT to turn natural language sentences into a vector representation.  
  1.3 Feed the pre-trained vector representations into a model for a downstream task (such as text classification).
  
2. **Fine-tuning approach**  
  2.1 Download a pre-trained BERT model.  
  2.2 Update the model weights on the downstream task.  
  
  
  In this post, we follow the fine-tuning approach on a binary text classification example.

# 2. Environment setup

In [None]:
!pip install transformers



In [None]:
import torch
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
from transformers import BertTokenizer, BertForSequenceClassification
from sklearn.model_selection import train_test_split

import pandas as pd
import numpy as np

from tabulate import tabulate
from tqdm import trange
import random

# 3. Dataset
We use the public SMS Spam Collection Data Set⁵ from the UCI Machine Learning Repository⁶. The data is a single text file of SMS messages, each labeled as either spam or ham (legitimate).

In [None]:
!wget 'https://archive.ics.uci.edu/ml/machine-learning-databases/00228/smsspamcollection.zip'

--2023-12-22 10:16:07--  https://archive.ics.uci.edu/ml/machine-learning-databases/00228/smsspamcollection.zip
Resolving archive.ics.uci.edu (archive.ics.uci.edu)... 128.195.10.252
Connecting to archive.ics.uci.edu (archive.ics.uci.edu)|128.195.10.252|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: unspecified
Saving to: ‘smsspamcollection.zip’

smsspamcollection.z     [ <=>                ] 198.65K  --.-KB/s    in 0.08s   

2023-12-22 10:16:07 (2.43 MB/s) - ‘smsspamcollection.zip’ saved [203415]



In [None]:
!unzip -o smsspamcollection.zip

Archive:  smsspamcollection.zip
  inflating: SMSSpamCollection       
  inflating: readme                  


In [None]:
!head -10  SMSSpamCollection

ham	Go until jurong point, crazy.. Available only in bugis n great world la e buffet... Cine there got amore wat...
ham	Ok lar... Joking wif u oni...
spam	Free entry in 2 a wkly comp to win FA Cup final tkts 21st May 2005. Text FA to 87121 to receive entry question(std txt rate)T&C's apply 08452810075over18's
ham	U dun say so early hor... U c already then say...
ham	Nah I don't think he goes to usf, he lives around here though
spam	FreeMsg Hey there darling it's been 3 week's now and no word back! I'd like some fun you up for it still? Tb ok! XxX std chgs to send, £1.50 to rcv
ham	Even my brother is not like to speak with me. They treat me like aids patent.
ham	As per your request 'Melle Melle (Oru Minnaminunginte Nurungu Vettam)' has been set as your callertune for all Callers. Press *9 to copy your friends Callertune
spam	WINNER!! As a valued network customer you have been selected to receivea £900 prize reward! To claim call 09061701461. Claim code KL341. Valid 12 hours only.
spam	H

Each line contains the label, followed by a tab character and the raw text message. We parse the file into a `pandas.DataFrame`, as it is a common starting point in data science experiments:

In [None]:
file_path = '/content/SMSSpamCollection'
rows = []
with open(file_path) as f:
  for line in f:
    # Each line is "<label>\t<message>"; strip the newline and split on the first tab only
    label, message = line.rstrip('\r\n').split('\t', 1)
    rows.append({'label': 1 if label == 'spam' else 0, 'text': message})

# Build the DataFrame once instead of concatenating row by row (quadratic time)
df = pd.DataFrame(rows)
df.head()

| | label | text |
|---|---|---|
| 0 | 0 | Go until jurong point, crazy.. Available only ... |
| 1 | 0 | Ok lar... Joking wif u oni... |
| 2 | 1 | Free entry in 2 a wkly comp to win FA Cup fina... |
| 3 | 0 | U dun say so early hor... U c already then say... |
| 4 | 0 | Nah I don't think he goes to usf, he lives aro... |


In [None]:
text = df.text.values
labels = df.label.values

# 4. Preprocessing
Before feeding the text to BERT, we need to preprocess it. To do so, we download the `BertTokenizer`:

In [None]:
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased',
                                          do_lower_case = True) # Construct a BERT tokenizer. Based on WordPiece.


Let us observe how the tokenizer splits a random message into tokens and maps them to their respective IDs in the BERT vocabulary:

In [None]:
def print_rand_sentence():
  '''Displays the tokens and respective IDs of a random text sample'''
  index = random.randint(0, len(text) -1)
  print(text[index])
  tokens = tokenizer.tokenize(text[index])
  table = np.array([tokens,
                    tokenizer.convert_tokens_to_ids(tokens)]).T
  print(tabulate(table,
                 headers = ['Tokens', 'Token IDs'],
                 tablefmt = 'fancy_grid'))

print_rand_sentence()

Buy one egg for me da..please:)

╒══════════╤═════════════╕
│ Tokens   │   Token IDs │
╞══════════╪═════════════╡
│ buy      │        4965 │
├──────────┼─────────────┤
│ one      │        2028 │
├──────────┼─────────────┤
│ egg      │        8288 │
├──────────┼─────────────┤
│ for      │        2005 │
├──────────┼─────────────┤
│ me       │        2033 │
├──────────┼─────────────┤
│ da       │        4830 │
├──────────┼─────────────┤
│ .        │        1012 │
├──────────┼─────────────┤
│ .        │        1012 │
├──────────┼─────────────┤
│ please   │        3531 │
├──────────┼─────────────┤
│ :        │        1024 │
├──────────┼─────────────┤
│ )        │        1007 │
╘══════════╧═════════════╛
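In the example above every token happens to be a whole word in the vocabulary. Before WordPiece runs, the tokenizer lowercases the text and splits on whitespace and punctuation, which is why `.`, `:` and `)` appear as separate tokens. Words missing from the vocabulary are then broken into sub-word pieces, with continuation pieces prefixed by `##`. Below is a simplified sketch of this greedy longest-match-first split; `wordpiece_tokenize` and the toy vocabulary are hypothetical stand-ins for BERT's 30,522-token vocabulary, and the real tokenizer has extra rules (e.g. a maximum word length):

```python
def wordpiece_tokenize(word, vocab, unk_token='[UNK]'):
    """Greedy longest-match-first WordPiece split of a single word.

    `vocab` is a hypothetical toy vocabulary; continuation pieces carry a
    '##' prefix, as in BERT's WordPiece vocabulary.
    """
    pieces = []
    start = 0
    while start < len(word):
        end = len(word)
        match = None
        # Try the longest remaining substring first, then shrink it from the right
        while start < end:
            candidate = word[start:end]
            if start > 0:
                candidate = '##' + candidate
            if candidate in vocab:
                match = candidate
                break
            end -= 1
        if match is None:
            # No piece matches: the whole word maps to the unknown token
            return [unk_token]
        pieces.append(match)
        start = end
    return pieces

toy_vocab = {'call', 'caller', '##tune', '##s', 'un', '##want', '##ed'}
print(wordpiece_tokenize('callertune', toy_vocab))  # ['caller', '##tune']
print(wordpiece_tokenize('unwanted', toy_vocab))    # ['un', '##want', '##ed']
print(wordpiece_tokenize('xyz', toy_vocab))         # ['[UNK]']
```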


BERT requires the following preprocessing steps:
1. Add **special tokens**:
- [CLS]: at the beginning of each sequence (ID 101)
- [SEP]: at the end of each sentence (ID 102)
2. Make all sequences the **same length**:
- This is achieved by *padding*, i.e. appending [PAD] tokens to shorter sequences until they reach the desired length. Longer sequences are truncated.
- The padding ([PAD]) tokens have ID 0.
- The maximum sequence length BERT accepts is 512 tokens.
3. Create an **attention mask**:
- A list of 0/1 values indicating whether the model should attend to each token when computing contextual representations. [PAD] tokens get the value 0.

The process can be represented as follows:
![image.png](https://miro.medium.com/v2/resize:fit:1400/format:webp/1*vaw98m1VVncgKxNFWI0d2Q.png)
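To make the three steps concrete, here is a plain-Python sketch that applies them to a list of token IDs. It assumes the `bert-base-uncased` special-token IDs listed above and right-side padding; `build_bert_inputs` is a hypothetical helper for illustration, while in practice the tokenizer performs these steps for us:

```python
CLS_ID, SEP_ID, PAD_ID = 101, 102, 0  # bert-base-uncased special-token IDs

def build_bert_inputs(token_ids, max_length=32):
    """Add [CLS]/[SEP], truncate or pad to max_length, and build the attention mask.

    `token_ids` are WordPiece IDs *without* special tokens, e.g. the output of
    tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text)).
    """
    # Leave room for the two special tokens when truncating
    body = token_ids[:max_length - 2]
    input_ids = [CLS_ID] + body + [SEP_ID]
    # Real tokens (including [CLS]/[SEP]) get mask 1
    attention_mask = [1] * len(input_ids)
    # Pad on the right with [PAD]; padded positions get mask 0
    n_pad = max_length - len(input_ids)
    input_ids = input_ids + [PAD_ID] * n_pad
    attention_mask = attention_mask + [0] * n_pad
    return input_ids, attention_mask

# Token IDs of "buy one egg" from the table above
demo_ids, demo_mask = build_bert_inputs([4965, 2028, 8288], max_length=8)
print(demo_ids)   # [101, 4965, 2028, 8288, 102, 0, 0, 0]
print(demo_mask)  # [1, 1, 1, 1, 1, 0, 0, 0]
```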

We can perform all the needed steps with the `tokenizer.encode_plus` method. It returns a `transformers.tokenization_utils_base.BatchEncoding` object with the following fields:
- input_ids: list of token IDs.
- token_type_ids: list of token type (segment) IDs, which distinguish the first and second sentence of a pair; for a single sentence they are all 0.
- attention_mask: list of 0/1 indicating which tokens should be considered by the model (return_attention_mask = True).

In [None]:
token_id = []
attention_masks = []

def preprocessing(input_text, tokenizer):
  '''
  Returns <class transformers.tokenization_utils_base.BatchEncoding> with the following fields:
    - input_ids: list of token ids
    - token_type_ids: list of token type ids
    - attention_mask: list of indices (0,1) specifying which tokens should be considered by the model (return_attention_mask = True).
  '''
  return tokenizer.encode_plus(
      input_text,
      add_special_tokens = True,
      max_length = 32,
      padding = 'max_length',
      truncation=True,
      return_attention_mask = True,
      return_tensors = 'pt'
  )

for sample in text:
  encoding_dict = preprocessing(sample, tokenizer)
  token_id.append(encoding_dict['input_ids'])
  attention_masks.append(encoding_dict['attention_mask'])

token_id = torch.cat(token_id, dim = 0)
attention_masks = torch.cat(attention_masks, dim = 0)
labels = torch.tensor(labels)

Note: BERT is a model with absolute position embeddings, so it is usually advised to pad the inputs on the right (end of the sequence) rather than the left (beginning of the sequence).

# 5. Data split
We split the dataset into train (80%) and validation (20%) sets, stratified by label, and wrap each of them in a `torch.utils.data.DataLoader` object. `DataLoader` provides an iterable over mini-batches of the given dataset.

In [None]:
val_ratio = 0.2
# Recommended batch size: 16, 32. See: https://arxiv.org/pdf/1810.04805.pdf
batch_size = 16

# Indices of the train and validation splits stratified by labels
train_idx, val_idx = train_test_split(
    np.arange(len(labels)),
    test_size = val_ratio,
    shuffle = True,
    stratify = labels)

# Train and validation sets
train_set = TensorDataset(token_id[train_idx],
                          attention_masks[train_idx],
                          labels[train_idx])

val_set = TensorDataset(token_id[val_idx],
                        attention_masks[val_idx],
                        labels[val_idx])

# Prepare DataLoader
train_dataloader = DataLoader(
    train_set,
    sampler = RandomSampler(train_set),
    batch_size = batch_size
)

validation_dataloader = DataLoader(
    val_set,
    sampler = SequentialSampler(val_set),
    batch_size = batch_size
)
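`stratify = labels` asks `train_test_split` to keep the spam/ham proportions roughly the same in both splits, which matters here because spam is the minority class. The sketch below is a hand-rolled NumPy illustration of the idea (shuffle and slice each class separately), not scikit-learn's actual implementation; `stratified_split_indices` and the toy labels are hypothetical:

```python
import numpy as np

def stratified_split_indices(labels, val_ratio=0.2, seed=0):
    """Split indices so each class keeps roughly the same share in both sets."""
    rng = np.random.default_rng(seed)
    labels = np.asarray(labels)
    train_idx, val_idx = [], []
    for cls in np.unique(labels):
        # Shuffle the indices of this class, then cut off the validation share
        cls_idx = np.flatnonzero(labels == cls)
        rng.shuffle(cls_idx)
        n_val = int(round(len(cls_idx) * val_ratio))
        val_idx.extend(cls_idx[:n_val])
        train_idx.extend(cls_idx[n_val:])
    return np.array(sorted(train_idx)), np.array(sorted(val_idx))

# Hypothetical toy labels: 8 ham (0) and 2 spam (1)
toy_labels = np.array([0, 0, 0, 0, 1, 0, 0, 0, 1, 0])
tr, va = stratified_split_indices(toy_labels, val_ratio=0.5)
print(toy_labels[tr].mean(), toy_labels[va].mean())  # 0.2 0.2
```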

# 6. Train
It is time for the fine-tuning task. We select hyperparameters based on the recommendations from the BERT paper:
>The optimal hyperparameter values are task-specific, but we found the following range of possible values to work well across all tasks:
- Batch size: 16, 32
- Learning rate (Adam): 5e-5, 3e-5, 2e-5
- Number of epochs: 2, 3, 4
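This notebook picks a single point from these ranges (batch size 16, learning rate 5e-5, 2 epochs). A full sweep would train one model per combination, 2 × 3 × 3 = 18 runs, and keep the one with the best validation metrics. A minimal sketch of enumerating that grid (the variable names are illustrative):

```python
from itertools import product

# Search space recommended in the BERT paper (quoted above)
batch_sizes = [16, 32]
learning_rates = [5e-5, 3e-5, 2e-5]
epoch_options = [2, 3, 4]

grid = [
    {'batch_size': bs, 'lr': lr, 'epochs': ep}
    for bs, lr, ep in product(batch_sizes, learning_rates, epoch_options)
]
print(len(grid))  # 18
print(grid[0])    # {'batch_size': 16, 'lr': 5e-05, 'epochs': 2}
```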

Define some functions to assess validation metrics (accuracy, precision, recall and specificity) during the training process:
![](https://upload.cc/i1/2023/12/22/tOex6S.png)

In [None]:
def b_tp(preds, labels):
  '''True positives: predicted spam (1) and actually spam'''
  return sum([p == l and p == 1 for p, l in zip(preds, labels)])

def b_fp(preds, labels):
  '''False positives: predicted spam (1) but actually ham'''
  return sum([p != l and p == 1 for p, l in zip(preds, labels)])

def b_tn(preds, labels):
  '''True negatives: predicted ham (0) and actually ham'''
  return sum([p == l and p == 0 for p, l in zip(preds, labels)])

def b_fn(preds, labels):
  '''False negatives: predicted ham (0) but actually spam'''
  return sum([p != l and p == 0 for p, l in zip(preds, labels)])

def b_metrics(preds, labels):
  '''
  Returns the following metrics:
    - accuracy    = (TP + TN) / N
    - precision   = TP / (TP + FP)
    - recall      = TP / (TP + FN)
    - specificity = TN / (TN + FP)
  '''
  preds = np.argmax(preds, axis = 1).flatten()
  labels = labels.flatten()
  tp = b_tp(preds, labels)
  tn = b_tn(preds, labels)
  fp = b_fp(preds, labels)
  fn = b_fn(preds, labels)
  b_accuracy = (tp + tn) / len(labels)
  b_precision = tp / (tp + fp) if (tp + fp) > 0 else 'nan'
  b_recall = tp / (tp + fn) if (tp + fn) > 0 else 'nan'
  b_specificity = tn / (tn + fp) if (tn + fp) > 0 else 'nan'
  return b_accuracy, b_precision, b_recall, b_specificity
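The helper functions above loop over the predictions in Python. Since `b_metrics` works on NumPy arrays, the same four counts can also be computed with vectorized boolean masks, which is shorter and faster on large arrays. A minimal sketch, assuming binary 0/1 labels as in this notebook (`confusion_counts` is a hypothetical helper, not a library function):

```python
import numpy as np

def confusion_counts(preds, labels):
    """Vectorized TP/FP/TN/FN for binary labels (1 = spam, 0 = ham)."""
    preds = np.asarray(preds)
    labels = np.asarray(labels)
    tp = int(np.sum((preds == 1) & (labels == 1)))
    fp = int(np.sum((preds == 1) & (labels == 0)))
    tn = int(np.sum((preds == 0) & (labels == 0)))
    fn = int(np.sum((preds == 0) & (labels == 1)))
    return tp, fp, tn, fn

# Hypothetical predictions and ground truth for 6 messages
toy_preds = np.array([1, 0, 1, 0, 0, 1])
toy_truth = np.array([1, 0, 0, 0, 1, 1])
tp, fp, tn, fn = confusion_counts(toy_preds, toy_truth)
print(tp, fp, tn, fn)  # 2 1 2 1
```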

Download `transformers.BertForSequenceClassification`, which is a BERT model with a linear layer for sentence classification (or regression) on top of the pooled output:

In [None]:
# Load the BertForSequenceClassification model
model = BertForSequenceClassification.from_pretrained(
    'bert-base-uncased',
    num_labels = 2,
    output_attentions = False,
    output_hidden_states = False,
)

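# Run on GPU if available, otherwise fall back to CPU.
# Move the model before creating the optimizer, as recommended by PyTorch.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

# Recommended learning rates (Adam): 5e-5, 3e-5, 2e-5
optimizer = torch.optim.AdamW(model.parameters(),
                              lr = 5e-5,
                              eps = 1e-08)

# Display the model architecture
model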


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12,
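The warning above reports that `classifier.weight` and `classifier.bias` are newly initialized: the encoder comes from the pre-trained checkpoint, while the classification head on top of the pooled output is a fresh linear layer. When `labels` are passed to the forward call, as in the training loop below, the model also returns a cross-entropy loss computed from the logits. A minimal NumPy sketch of the head and the loss; the random pooled output, the weights `W`, `b` and the targets are hypothetical, and the dropout the real model applies during training is omitted:

```python
import numpy as np

rng = np.random.default_rng(0)
hidden_size, num_labels, toy_batch = 768, 2, 4

# Stand-in for BERT's pooled output (random here; the real pooler applies a
# dense layer + tanh to the final hidden state of the [CLS] token)
toy_pooled = np.tanh(rng.standard_normal((toy_batch, hidden_size)))

# Classification head: one linear layer mapping hidden_size -> num_labels
W = rng.normal(0.0, 0.02, size=(hidden_size, num_labels))
b = np.zeros(num_labels)
toy_logits = toy_pooled @ W + b          # shape (toy_batch, num_labels)

def cross_entropy(logits, targets):
    """Mean softmax cross-entropy over a batch of integer class targets."""
    # Subtract each row's max before exponentiating, for numerical stability
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    # Pick the log-probability of the correct class for each example
    return -log_probs[np.arange(len(targets)), targets].mean()

toy_targets = np.array([0, 1, 0, 0])     # hypothetical ham/spam labels
print(toy_logits.shape)                  # (4, 2)
print(cross_entropy(toy_logits, toy_targets))
```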

Perform the training procedure:

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Recommended number of epochs: 2, 3, 4
epochs = 2

for _ in trange(epochs, desc = 'Epoch'):

  # ========== Training ==========

  # Set model to training mode
  model.train()

  # Tracking variables
  tr_loss = 0
  nb_tr_examples, nb_tr_steps = 0, 0

  for step, batch in enumerate(train_dataloader):
    batch = tuple(t.to(device) for t in batch)
    b_input_ids, b_input_mask, b_labels = batch
    optimizer.zero_grad()
    # Forward pass
    train_output = model(b_input_ids,
                         token_type_ids = None,
                         attention_mask = b_input_mask,
                         labels = b_labels)
    # Backward pass
    train_output.loss.backward()
    optimizer.step()
    # Update tracking variables
    tr_loss += train_output.loss.item()
    nb_tr_examples += b_input_ids.size(0)
    nb_tr_steps += 1

  # ========== Validation ==========

  # Set model to evaluation mode
  model.eval()

  # Tracking variables
  val_accuracy = []
  val_precision = []
  val_recall = []
  val_specificity = []

  for batch in validation_dataloader:
    batch = tuple(t.to(device) for t in batch)
    b_input_ids, b_input_mask, b_labels = batch
    with torch.no_grad():
      # Forward pass
      eval_output = model(b_input_ids,
                          token_type_ids = None,
                          attention_mask = b_input_mask)
    logits = eval_output.logits.detach().cpu().numpy()
    label_ids = b_labels.to('cpu').numpy()
    # Calculate validation metrics
    b_accuracy, b_precision, b_recall, b_specificity = b_metrics(logits, label_ids)
    val_accuracy.append(b_accuracy)
    # Update precision only when (tp + fp) !=0; ignore nan
    if b_precision != 'nan': val_precision.append(b_precision)
    # Update recall only when (tp + fn) !=0; ignore nan
    if b_recall != 'nan': val_recall.append(b_recall)
    # Update specificity only when (tn + fp) !=0; ignore nan
    if b_specificity != 'nan': val_specificity.append(b_specificity)

  print('\n\t - Train loss: {:.4f}'.format(tr_loss / nb_tr_steps))
  print('\t - Validation Accuracy: {:.4f}'.format(sum(val_accuracy)/len(val_accuracy)))
  print('\t - Validation Precision: {:.4f}'.format(sum(val_precision)/len(val_precision)) if len(val_precision)>0 else '\t - Validation Precision: NaN')
  print('\t - Validation Recall: {:.4f}'.format(sum(val_recall)/len(val_recall)) if len(val_recall)>0 else '\t - Validation Recall: NaN')
  print('\t - Validation Specificity: {:.4f}\n'.format(sum(val_specificity)/len(val_specificity)) if len(val_specificity)>0 else '\t - Validation Specificity: NaN')

Epoch:  50%|█████     | 1/2 [00:32<00:32, 32.64s/it]


	 - Train loss: 0.0875
	 - Validation Accuracy: 0.9830
	 - Validation Precision: 0.9299
	 - Validation Recall: 0.9358
	 - Validation Specificity: 0.9909



Epoch: 100%|██████████| 2/2 [01:04<00:00, 32.43s/it]


	 - Train loss: 0.0271
	 - Validation Accuracy: 0.9786
	 - Validation Precision: 0.8939
	 - Validation Recall: 0.9672
	 - Validation Specificity: 0.9816
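Note that the validation loop computes precision, recall and specificity per batch and then averages them. With batches of 16 messages, a batch holding only one or two spam messages can swing its precision a lot, so this average can differ from the same metric computed once over the whole validation set. Accumulating TP/FP/TN/FN across batches and dividing once avoids this. A small illustration with hypothetical per-batch counts:

```python
def precision(tp, fp):
    return tp / (tp + fp) if (tp + fp) > 0 else float('nan')

# Hypothetical (tp, fp) counts from two validation batches
batch_counts = [(1, 1), (9, 0)]

# Batch-averaged precision, as computed in the validation loop above
batch_avg = sum(precision(tp, fp) for tp, fp in batch_counts) / len(batch_counts)

# Pooled precision: sum the counts over all batches first, then divide once
total_tp = sum(tp for tp, _ in batch_counts)
total_fp = sum(fp for _, fp in batch_counts)
pooled = precision(total_tp, total_fp)

print(round(batch_avg, 4), round(pooled, 4))  # 0.75 0.9091
```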






# 7. Predict

In [None]:
new_sentence = 'WINNER!! As a valued network customer you have been selected to receivea £900 prize reward! To claim call 09061701461. Claim code KL341. Valid 12 hours only.'

# Evaluation mode disables dropout (the validation loop already set it; we make it explicit)
model.eval()

# We need the token IDs and attention mask of the new sentence for inference
encoding = preprocessing(new_sentence, tokenizer)
test_ids = encoding['input_ids'].to(device)                # shape (1, 32)
test_attention_mask = encoding['attention_mask'].to(device)

# Forward pass, calculate logit predictions
with torch.no_grad():
  output = model(test_ids, token_type_ids = None, attention_mask = test_attention_mask)

# Class 1 = spam, class 0 = ham
prediction = 'Spam' if output.logits.argmax(dim = -1).item() == 1 else 'Ham'

print('Input Sentence: ', new_sentence)
print('Predicted Class: ', prediction)

Input Sentence:  WINNER!! As a valued network customer you have been selected to receivea £900 prize reward! To claim call 09061701461. Claim code KL341. Valid 12 hours only.
Predicted Class:  Spam
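The prediction above keeps only the arg-max class. To get a score as well, apply a softmax to the logits; this also makes it possible to flag a message as spam only above a chosen probability threshold. A minimal NumPy sketch with hypothetical logits (on the model output, `torch.softmax(output.logits, dim = -1)` computes the same thing):

```python
import numpy as np

def softmax(logits):
    """Convert a vector of logits into class probabilities."""
    shifted = logits - np.max(logits)   # subtract the max for numerical stability
    exp = np.exp(shifted)
    return exp / exp.sum()

# Hypothetical logits for [ham, spam]; in the notebook they would come
# from output.logits[0].cpu().numpy()
sms_logits = np.array([-2.0, 2.5])
probs = softmax(sms_logits)
print(probs.round(3))          # [0.011 0.989]
print(int(np.argmax(probs)))   # 1, i.e. 'Spam'
```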


# 8. Conclusions
