In [1]:
import json
import transformers
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from tqdm import tqdm
import copy

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f'Running on {device}')

data_path = "./data"
model_path = "./models"

  from .autonotebook import tqdm as notebook_tqdm


Running on cuda:0


### Load the dataset back

In [2]:
def load_from_jsonlines(filename):
    with open(filename, 'r', encoding='utf-8') as json_file:
        dataset = [json.loads(line) for line in json_file]
    return dataset

### Import model: BERT<sub>size</sub>
*   Choose size from:
    *   Tiny
    *   Mini
    *   Small
    *   Medium

In [3]:
model_size = "tiny"
# model_size = "mini"
# model_size = "small"
# model_size = "medium"

In [4]:
from transformers import BertTokenizerFast, AutoModel

tokenizer = BertTokenizerFast.from_pretrained('csarron/bert-base-uncased-squad-v1')
model = AutoModel.from_pretrained(f"prajjwal1/bert-{model_size}")

*   Take a look at the structure

In [5]:
model

BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(30522, 128, padding_idx=0)
    (position_embeddings): Embedding(512, 128)
    (token_type_embeddings): Embedding(2, 128)
    (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0-1): 2 x BertLayer(
        (attention): BertAttention(
          (self): BertSdpaSelfAttention(
            (query): Linear(in_features=128, out_features=128, bias=True)
            (key): Linear(in_features=128, out_features=128, bias=True)
            (value): Linear(in_features=128, out_features=128, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=128, out_features=128, bias=True)
            (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)


*   Take a look at input requirements

In [6]:
# help(model.forward)

*   Create a wrapper for SQuAD

In [7]:
class BERTForSQuAD(nn.Module):
    def __init__(self, bert_model=None):
        super(BERTForSQuAD, self).__init__()
        self.bert = bert_model
        self.qa_outputs = None
        if bert_model:
            self.qa_outputs = nn.Linear(bert_model.config.hidden_size, 2, bias=True)

    def forward(self, input_ids, attention_mask=None, token_type_ids=None):
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        last_hidden_state = outputs.last_hidden_state  # Shape: [batch_size, sequence_length, hidden_size]

        logits = self.qa_outputs(last_hidden_state)  # Shape: [batch_size, sequence_length, 2]
        start_logits, end_logits = logits.split(1, dim=-1)  # Split into start and end logits
        start_logits = start_logits.squeeze(-1)  # Shape: [batch_size, sequence_length]
        end_logits = end_logits.squeeze(-1)      # Shape: [batch_size, sequence_length]

        return start_logits, end_logits
    
    def save(self, save_path):
        self.bert.save_pretrained(f"{save_path}/bert-pod")
        torch.save(self.qa_outputs.state_dict(), f"{save_path}/linear_adapter.pth")
        torch.save(self.state_dict(), f"{save_path}/full_model.pth")
    
    def load(self, load_path):
        self.bert = AutoModel.from_pretrained(f"{load_path}/bert-pod")
        self.qa_outputs = nn.Linear(self.bert.config.hidden_size, 2, bias=True)
        self.qa_outputs.load_state_dict(torch.load(f"{load_path}/linear_adapter.pth", weights_only=True))
        self.load_state_dict(torch.load(f"{load_path}/full_model.pth", weights_only=True))


In [8]:
bert_squad = BERTForSQuAD(model)

In [9]:
total_params = sum(p.numel() for p in bert_squad.parameters())
print(f'BERT-tiny classifier contains {total_params} parameters.')

BERT-tiny classifier contains 4386178 parameters.


### Prepare training data

*   Retrieve data in the form of lists

In [10]:
def retrieve_lists(squad_train):
    question_list = []
    context_list = []
    answer_start_list = []
    answer_end_list = []

    for squad_instance in squad_train:
        question_list.append(squad_instance["question"])
        context_list.append(squad_instance["context"])

        answer_start = squad_instance["answers"]["answer_start"][0]
        answer_end = answer_start + len(squad_instance["answers"]["text"][0])
        
        answer_start_list.append(answer_start)
        answer_end_list.append(answer_end)

    return question_list, context_list, answer_start_list, answer_end_list

*   Tokenize inputs and convert the starting index of characters to tokens

In [11]:
def tokenize_squad_data(squad_train, tokenizer):
    question_list, context_list, answer_start_list, answer_end_list = retrieve_lists(squad_train)
    # Merging the question and answers together
    tokenized_data = tokenizer(question_list, context_list, padding = True, truncation = True)

    start_token_list = []
    end_token_list = []

    for i in range(len(squad_train)):
    # Skip the question section by setting `sequence_index`=1
        start_token = tokenized_data.char_to_token(i, answer_start_list[i], 1)
        end_token = tokenized_data.char_to_token(i, answer_end_list[i] - 1, 1)

        # if return is None, the answer passage containing the answer is truncated.
        if start_token is None:
            start_token_list.append(tokenizer.model_max_length-1)
        else:
            start_token_list.append(start_token)
        if end_token is None:
            end_token_list.append(tokenizer.model_max_length-1)
        else:
            end_token_list.append(end_token)
        

    tokenized_data.update({'answer_start':start_token_list, 'answer_end':end_token_list})
    return tokenized_data

*   Create a `Dataset` and convert it to a `DataLoader`

In [12]:
class squad_dataset(torch.utils.data.Dataset):
    def __init__(self, tokenized_data):
        self.tokenized_data = tokenized_data

    def __getitem__(self, idx):
        tokenized_index_data = {}

        for key, val in self.tokenized_data.items():
            tokenized_index_data.update({key: torch.tensor(val[idx])})

        return tokenized_index_data
    
    def __len__(self):
        return len(self.tokenized_data.answer_start)

*   Create a function to create dataloaders from filename

In [13]:
def dataloader_from_filename(filename, batch_size):
    squad_jsonl = load_from_jsonlines(filename)
    tokenized = tokenize_squad_data(squad_jsonl, tokenizer)
    dataset = squad_dataset(tokenized)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    return dataloader

### Prepare for training

*   Compute loss and other metrics

In [14]:
def compute_loss(start_logits, end_logits, answer_start, answer_end):
    loss_fn = nn.CrossEntropyLoss()
    start_loss = loss_fn(start_logits, answer_start)
    end_loss = loss_fn(end_logits, answer_end)
    return (start_loss + end_loss) / 2

In [15]:
def compute_metric(start_logits, end_logits, answer_start, answer_end):
    batch_size = start_logits.shape[0]
    
    pred_start_list = torch.argmax(start_logits, dim=1)
    pred_end_list = torch.argmax(end_logits, dim=1)
    
    F1_list = []
    EM_list = []

    for i in range(batch_size):
        pred_start, pred_end = pred_start_list[i].item(), pred_end_list[i].item()
        true_start, true_end = answer_start[i].item(), answer_end[i].item()

        # Make sure the answer is valid
        if pred_start > pred_end:
            pred_end = pred_start

        # Calculate F1-score
        pred_tokens = set(range(pred_start, pred_end + 1))
        true_tokens = set(range(true_start, true_end + 1))

        common_tokens = pred_tokens.intersection(true_tokens)

        if len(common_tokens) == 0:
            f1 = 0.0
        else:
            precision = len(common_tokens) / len(pred_tokens)
            recall = len(common_tokens) / len(true_tokens)
            f1 = (2 * precision * recall) / (precision + recall)

        F1_list.append(f1)

        # Calculate Exact Match (EM)
        if (pred_start == true_start and pred_end == true_end):
            EM_list.append(1)
        else:
            EM_list.append(0)

    # Sum and average later
    return F1_list, EM_list

*   Write a function for validation

In [16]:
def valid_model(model, valid_loader):
    model.eval()
    loss_list = []
    f1_list = []
    em_list = []

    for batch in valid_loader:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        token_type_ids = batch['token_type_ids'].to(device)
        answer_start = batch['answer_start'].to(device)
        answer_end = batch['answer_end'].to(device)

        start_logits, end_logits = model(input_ids, 
                                         attention_mask=attention_mask, 
                                         token_type_ids=token_type_ids)
        
        loss = compute_loss(start_logits, end_logits, answer_start, answer_end)
        f1_sublist, em_sublist = compute_metric(start_logits, end_logits, answer_start, answer_end)

        loss_list.append(loss.item())
        f1_list.extend(f1_sublist)
        em_list.extend(em_sublist)
    
    loss = sum(loss_list) / len(loss_list)
    f1 = sum(f1_list) / len(f1_list) * 100
    em = sum(em_list) / len(em_list) * 100

    return loss, f1, em

In [17]:
def torch_train(model, optimizer, train_loader_list, train_schedule, valid_loader, epochs=1, save_path = f"{model_path}model"):
    model.train()
    model.to(device)
    last_valid_loss = torch.inf
    model_to_save = copy.deepcopy(model)
    model_to_save.load_state_dict(model.state_dict())
    for i in range(epochs):
        train_idx = train_schedule[i % len(train_schedule)]
        loop = train_loader_list[train_idx]
        loss_list = []
        for batch in tqdm(loop):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            token_type_ids = batch['token_type_ids'].to(device)
            answer_start = batch['answer_start'].to(device)
            answer_end = batch['answer_end'].to(device)

            start_logits, end_logits = model(input_ids, 
                                             attention_mask=attention_mask, 
                                             token_type_ids=token_type_ids)

            loss = compute_loss(start_logits, end_logits, answer_start, answer_end)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            loss_list.append(loss.item())

            

        train_loss = sum(loss_list) / len(loss_list)
        valid_loss, f1, em = valid_model(model, valid_loader)
        print(f"Epoch {i+1}, train_loss: {train_loss:.3f}, valid_loss: {valid_loss:.3f}, EM: {em:.3f}, F1: {f1:.3f}")

        if loss < last_valid_loss:
            model_to_save = copy.deepcopy(model)
            model_to_save.load_state_dict(model.state_dict())
    
    # Save model
    model_to_save.save(save_path)

*   Set up learning rates and data

In [18]:
valid_loader = dataloader_from_filename(f"{data_path}/squad_valid.jsonl", batch_size=64)

In [19]:
model_vanilla = AutoModel.from_pretrained(f"prajjwal1/bert-{model_size}")
bert_squad = BERTForSQuAD(model_vanilla)
train_loader_list = [dataloader_from_filename(f"{data_path}/squad_train_vanilla.jsonl", 
                                              batch_size=4)]
train_schedule = [0] # indices in the train_loader_list

learning_rate = 5e-5
adam = torch.optim.Adam(bert_squad.parameters(), lr=learning_rate)

In [20]:
torch_train(model=bert_squad, 
            optimizer=adam, 
            train_loader_list=train_loader_list,
            train_schedule=train_schedule,
            valid_loader=valid_loader,
            epochs=10,
            save_path=f"{model_path}/vanilla_finetuning")

100%|██████████| 21900/21900 [04:26<00:00, 82.20it/s]


Epoch 1, train_loss: 3.057, valid_loss: 2.448, EM: 22.564, F1: 37.036


100%|██████████| 21900/21900 [04:30<00:00, 80.93it/s]


Epoch 2, train_loss: 2.063, valid_loss: 2.238, EM: 26.405, F1: 42.300


100%|██████████| 21900/21900 [04:31<00:00, 80.73it/s]


Epoch 3, train_loss: 1.680, valid_loss: 2.260, EM: 26.868, F1: 43.562


100%|██████████| 21900/21900 [04:31<00:00, 80.62it/s]


Epoch 4, train_loss: 1.376, valid_loss: 2.434, EM: 24.948, F1: 41.523


100%|██████████| 21900/21900 [04:35<00:00, 79.55it/s]


Epoch 5, train_loss: 1.102, valid_loss: 2.687, EM: 24.210, F1: 40.718


100%|██████████| 21900/21900 [04:37<00:00, 78.80it/s]


Epoch 6, train_loss: 0.853, valid_loss: 2.993, EM: 23.406, F1: 39.397


100%|██████████| 21900/21900 [04:40<00:00, 78.02it/s]


Epoch 7, train_loss: 0.638, valid_loss: 3.594, EM: 21.485, F1: 37.283


100%|██████████| 21900/21900 [04:40<00:00, 78.08it/s]


Epoch 8, train_loss: 0.467, valid_loss: 4.308, EM: 20.908, F1: 36.407


100%|██████████| 21900/21900 [04:37<00:00, 78.92it/s]


Epoch 9, train_loss: 0.341, valid_loss: 4.745, EM: 20.009, F1: 35.564


100%|██████████| 21900/21900 [04:39<00:00, 78.37it/s]


Epoch 10, train_loss: 0.257, valid_loss: 5.307, EM: 18.675, F1: 34.513
