# BERT Fine-Tuning for Semantic Question Matching

In [1]:
# !pip install transformers
# !pip install torchdata

In [2]:
import os
import time
import torch
import numpy as np
import torch.nn as nn
# import multiprocessing
from torch.optim import Adam
from torchtext.datasets import QQP
# from torch.nn import functional as F
from torch.utils.data import DataLoader
from torch.nn.utils import clip_grad_norm_
from transformers import BertTokenizerFast
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data.dataset import random_split
from transformers import BertForSequenceClassification
from torchtext.data.functional import to_map_style_dataset

### Preprocessing

In [3]:
# Tokenizer for the same checkpoint the model is loaded from later in the
# notebook; input text is lower-cased before tokenization.
tokenizer = BertTokenizerFast.from_pretrained(
    'bert-base-uncased',
    do_lower_case=True,
)

# QQP dataset from torchtext; records are unpacked downstream as
# (label, question1, question2).
data_iter = QQP()

In [4]:
# Seed the splits explicitly so the train/valid/test partition is the same on
# every Restart & Run All; without it the held-out test set changes per run.
SPLIT_SEED = 42
split_generator = torch.Generator().manual_seed(SPLIT_SEED)

data = to_map_style_dataset(data_iter)

# First split: 95% kept for training + validation, 5% held out for testing.
# NOTE: despite its name, `split_ratio` holds an example count, not a ratio.
split_ratio = int(len(data) * 0.95)
data, test_data = random_split(
    data, [split_ratio, len(data) - split_ratio], generator=split_generator
)

# Second split: 95% of the remainder for training, 5% for validation.
split_ratio = int(len(data) * 0.95)
train_data, valid_data = random_split(
    data, [split_ratio, len(data) - split_ratio], generator=split_generator
)

In [5]:
# --- Optimisation hyper-parameters ---
EPOCHS = 10
BATCH_SIZE = 64
LR = 2e-5

# --- Input / model sizes ---
# Every tokenized question pair is padded or truncated to MAX_LEN tokens.
MAX_LEN = 100
# NOTE(review): EMBED_SIZE is not referenced anywhere else in the visible notebook.
EMBED_SIZE = 64

# Number of distinct labels in the training split; each record unpacks as
# (label, question1, question2). This walks the whole training split once.
NUM_CLASSES = len({label for label, _, _ in train_data})

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
  device = torch.device("cuda")
else:
  device = torch.device("cpu")

In [6]:
def collate_fn(batch):
  """Tokenize a batch of question pairs and return tensors on `device`.

  Args:
    batch: non-empty sequence of (label, question1, question2) records.

  Returns:
    Tuple (labels, input_ids, attention_mask). `input_ids` and
    `attention_mask` have shape (len(batch), MAX_LEN) because every pair is
    padded/truncated to MAX_LEN; `labels` has one entry per record.
  """
  labels, first_questions, second_questions = zip(*batch)

  # One batched tokenizer call replaces the per-example `encode_plus` (which
  # the transformers docs mark as deprecated in favour of `__call__`) and
  # returns tensors directly instead of nested Python lists.
  encoded = tokenizer(
      list(first_questions), list(second_questions),
      padding='max_length', max_length=MAX_LEN, truncation=True,
      return_tensors='pt'
  )

  # NOTE(review): token_type_ids (segment ids) are produced by the tokenizer
  # but not returned, so the model is never told which tokens belong to which
  # question. Returning them would change the 3-tuple that callers unpack.
  labels = torch.tensor(labels).to(device)
  input_ids = encoded['input_ids'].to(device)
  attention_mask = encoded['attention_mask'].to(device)
  return labels, input_ids, attention_mask

In [7]:
# n_cores = multiprocessing.cpu_count()

# Only the training loader shuffles. The evaluation loaders keep a fixed order
# so that batch composition is identical on every pass and the reported
# validation/test metrics are comparable between epochs and runs.
train_dataloader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)
valid_dataloader = DataLoader(valid_data, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn)
test_dataloader = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn)

### Model Definition

In [8]:
# Size the classification head from the labels observed in the training split
# (NUM_CLASSES) instead of silently relying on the library's default head size.
model = BertForSequenceClassification.from_pretrained(
    'bert-base-uncased', num_labels=NUM_CLASSES
)
model = model.to(device)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at

In [9]:
# No separate criterion object is needed: the training step passes `labels`
# to the model and reads the loss from the model output.
optimizer = Adam(params=model.parameters(), lr=LR)

### Model Training

In [10]:
def train(dataloader, max_grad_norm=0.1):
  """Run one training epoch and return the mean per-batch loss.

  Uses the notebook-level `model` and `optimizer`.

  Args:
    dataloader: iterable yielding (labels, input_ids, attention_mask) batches.
    max_grad_norm: gradient-norm clipping threshold applied before each
      optimizer step. Defaults to 0.1, the value previously hard-coded here.

  Returns:
    Sum of the batch losses divided by the number of batches.

  Raises:
    ZeroDivisionError: if `dataloader` yields no batches.
  """
  model.train()
  total_loss, total_count = 0.0, 0
  for labels, input_ids, attention_mask in dataloader:
    optimizer.zero_grad()
    # Passing `labels` makes the model compute the loss itself.
    outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
    loss = outputs.loss
    loss.backward()
    # Clip after backward and before the step so the update uses clipped grads.
    clip_grad_norm_(model.parameters(), max_grad_norm)
    optimizer.step()
    total_loss += loss.item()
    total_count += 1
  return total_loss/total_count

In [11]:
def evaluate(dataloader):
  """Evaluate the notebook-level `model` without updating its weights.

  Args:
    dataloader: iterable yielding (labels, input_ids, attention_mask) batches.

  Returns:
    Tuple (mean_loss, accuracy): `mean_loss` is the sum of batch losses
    divided by the number of batches (same convention as `train`);
    `accuracy` is the fraction of correctly classified examples.

  Raises:
    ZeroDivisionError: if `dataloader` yields no batches.
  """
  model.eval()
  total_loss, total_correct = 0.0, 0
  total_batches, total_samples = 0, 0
  with torch.no_grad():
    for labels, input_ids, attention_mask in dataloader:
      # Read fields by attribute, as `train` does; tuple-unpacking the output
      # object would not give the loss and logits tensors.
      outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
      # Stay in torch on the labels' device: no .numpy() call (which fails for
      # CUDA tensors) and no numpy-vs-tensor comparison.
      y_pred = outputs.logits.argmax(dim=1)
      total_loss += outputs.loss.item()
      total_correct += (y_pred == labels).sum().item()
      total_batches += 1
      # Accuracy must be normalised by examples, not by batches.
      total_samples += labels.size(0)
  return total_loss/total_batches, total_correct/total_samples

In [None]:
# Single definition of where the best checkpoint is written.
MODEL_PATH = './../models/sqm-bert.pt'
# exist_ok avoids the check-then-create race of `if not exists: mkdir`.
os.makedirs(os.path.dirname(MODEL_PATH), exist_ok=True)

best_valid_loss = float('inf')
for epoch in range(EPOCHS):
  start_time = time.time()
  train_loss = train(train_dataloader)
  val_loss, val_acc = evaluate(valid_dataloader)
  # Keep only the weights with the lowest validation loss seen so far.
  if val_loss < best_valid_loss:
    best_valid_loss = val_loss
    torch.save(model.state_dict(), MODEL_PATH)
  # Validation accuracy was computed but previously never reported.
  print(f'Epoch: {epoch+1:02} | Time: {time.time()-start_time} | Train Loss: {train_loss:.3f} | Val Loss: {val_loss:.3f} | Val Acc: {val_acc:.3f}')

### Model Testing

In [None]:
# Score the checkpoint with the lowest validation loss (written by the
# training loop) rather than whatever weights the final epoch left in memory.
# If no checkpoint exists, the in-memory weights are evaluated as before.
best_checkpoint = './../models/sqm-bert.pt'
if os.path.exists(best_checkpoint):
  model.load_state_dict(torch.load(best_checkpoint, map_location=device))

test_loss, test_acc = evaluate(test_dataloader)
print(f'Test Loss: {test_loss:.3f} | Test Accuracy: {test_acc:.3f}')

### References

- [BERT Fine-Tuning Tutorial with PyTorch](https://mccormickml.com/2019/07/22/BERT-fine-tuning/)
- [Training and fine-tuning](https://huggingface.co/transformers/v3.3.1/training.html)