<a href="https://colab.research.google.com/github/alexpod1000/SQuAD-QA/blob/main/ModelTrainExperimentalCode.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
#%%bash
#[[ ! -e /colabtools ]] && exit  # Continue only if running on Google Colab

# Clone repository
# https://sysadmins.co.za/clone-a-private-github-repo-with-personal-access-token/
# For cloning the main branch:
#!git clone https://fb5b65b126107273e595ce8b6c9d2d533103c6e2:x-oauth-basic@github.com/alexpod1000/SQuAD-QA.git
# For cloning the "evaluation-features" branch
#!git clone --branch evaluation-features https://fb5b65b126107273e595ce8b6c9d2d533103c6e2:x-oauth-basic@github.com/alexpod1000/SQuAD-QA.git
# Change current working directory to match project
#%cd SQuAD-QA/
#!pwd

In [2]:
# External imports
import copy
import nltk
import numpy as np
import pandas as pd
import string
import torch

from functools import partial
from nltk.tokenize import TreebankWordTokenizer, SpaceTokenizer
from transformers import AutoTokenizer
from typing import Tuple, List, Dict, Any, Union

# Project imports
from squad_data.parser import SquadFileParser
from squad_data.utils import build_mappers_and_dataframe_bert
from evaluation.evaluation_metrics import Evaluator
from evaluation.utils import build_evaluation_dict_bert
from utils import split_dataframe

### Parse the json and get the data

In [3]:
train_parser = SquadFileParser("squad_data/data/training_set.json")
test_parser = SquadFileParser("squad_data/data/dev-v1.1.json")

train_data = train_parser.parse_documents()
test_data = test_parser.parse_documents()

########################### DEBUG
# reduce size for faster testing
#full_data = data
#data = []
#for i in range(1): # use only the first 1 documents
#  data.append(full_data[i])

### Prepare the mappers and datafram

In [4]:
def bert_tokenizer_fn(question, paragraph, tokenizer, max_length=384, doc_stride=128):
    pad_on_right = tokenizer.padding_side == "right"
    # Process the sample
    tokenized_input_pair = tokenizer(
        question,
        paragraph,
        truncation="only_second" if pad_on_right else "only_first",
        max_length=max_length,
        stride=doc_stride,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
    )
    return tokenized_input_pair

In [5]:
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
tokenizer_fn_preprocess = partial(bert_tokenizer_fn, tokenizer=tokenizer, max_length=380)
tokenizer_fn_train = partial(bert_tokenizer_fn, tokenizer=tokenizer, max_length=384)

In [6]:
paragraphs_mapper, df = build_mappers_and_dataframe_bert(tokenizer, tokenizer_fn_preprocess, train_data, limit_answers=1)
print(paragraphs_mapper[next(iter(paragraphs_mapper))])
df.head()

architecturally, the school has a catholic character. atop the main building's gold dome is a golden statue of the virgin mary. immediately in front of the main building and facing it, is a copper statue of christ with arms upraised with the legend " venite ad me omnes ". next to the main building is the basilica of the sacred heart. immediately behind the basilica is the grotto, a marian place of prayer and reflection. it is a replica of the grotto at lourdes, france where the virgin mary reputedly appeared to saint bernadette soubirous in 1858. at the end of the main drive ( and in a direct line that connects through 3 statues and the gold dome ), is a simple, modern stone statue of mary.


Unnamed: 0,doc_id,paragraph_id,question_id,answer_id,answer_start,answer_text,question_text,tokenizer_answer_start,tokenizer_answer_end
0,0,0_0,5733be284776f41900661182,0,515,Saint Bernadette Soubirous,To whom did the Virgin Mary allegedly appear i...,130,138
1,0,0_0,5733be284776f4190066117f,0,188,a copper statue of Christ,What is in front of the Notre Dame Main Building?,52,57
2,0,0_0,5733be284776f41900661180,0,279,the Main Building,The Basilica of the Sacred heart at Notre Dame...,81,84
3,0,0_0,5733be284776f41900661181,0,381,a Marian place of prayer and reflection,What is the Grotto at Notre Dame?,95,102
4,0,0_0,5733be284776f4190066117e,0,92,a golden statue of the Virgin Mary,What sits on top of the Main Building at Notre...,33,40


In [7]:
df_train, df_val = split_dataframe(df, train_ratio=0.7)

In [8]:
print(f"Total samples: {len(df)}, Train samples: {len(df_train)}, Validation samples: {len(df_val)}")

Total samples: 88552, Train samples: 61550, Validation samples: 27002


### DataConverter and CustomQADataset

In [9]:
from data_loading.utils import bert_padder_collate_fn
from data_loading.qa_dataset import CustomQADatasetBERT

datasetQA = CustomQADatasetBERT(tokenizer_fn_train, df_train, paragraphs_mapper)
data_loader = torch.utils.data.DataLoader(datasetQA, collate_fn = bert_padder_collate_fn, batch_size=10, shuffle=True)

test_batch = next(iter(data_loader))
print(test_batch["input_ids"].shape)
print(test_batch["y_gt"].shape)

torch.Size([10, 384])
torch.Size([10, 2])


In [10]:
"""
NOTE: this logic is used for sample creation only, such that each sample is "short enough" for BERT; 
      a duplicate of this logic will need to be used in QADataset Dataloader class when we'll take
      short samples' text, tokenize them again, and find the correct index
ALTERNATIVE: for BERT models we could directly get the answer spans, and pass them in dataframe to another QADataset
             built specifically for BERT, that will just take the data from dataframe (way nicer and faster solution).
SUGGESTION: we could also use specific dict keys and in QADataset pick stuff from these keys: 
                - if these keys are absent then don't use BERT logic (eg span_start and span_end) and use previous logic
                - if these keys are present, then just use them and gather the BERT samples.
                Call these keys like "tokenizer_span_idx" (to make them kinda unique)
"""

'\nNOTE: this logic is used for sample creation only, such that each sample is "short enough" for BERT; \n      a duplicate of this logic will need to be used in QADataset Dataloader class when we\'ll take\n      short samples\' text, tokenize them again, and find the correct index\nALTERNATIVE: for BERT models we could directly get the answer spans, and pass them in dataframe to another QADataset\n             built specifically for BERT, that will just take the data from dataframe (way nicer and faster solution).\nSUGGESTION: we could also use specific dict keys and in QADataset pick stuff from these keys: \n                - if these keys are absent then don\'t use BERT logic (eg span_start and span_end) and use previous logic\n                - if these keys are present, then just use them and gather the BERT samples.\n                Call these keys like "tokenizer_span_idx" (to make them kinda unique)\n'

# Model train

In [11]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import transformers

from timeit import default_timer as timer
from tqdm import tqdm
from transformers.optimization import AdamW

from models.utils import SpanExtractor

In [12]:
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"The device is {device}")

The device is cuda


Model:

(input_ids, attention_mask) -> (answer_start, answer_end) // for each token in input_ids

In [13]:
def train_step(model, optimizer, loss_function, dataloader, device="cpu", show_progress=False):
    acc_loss = 0
    acc_start_accuracy = 0
    acc_end_accuracy = 0
    count = 0

    time_start = timer()
    
    model.train()
    wrapped_dataloader = tqdm(dataloader) if show_progress else dataloader
    for batch in wrapped_dataloader:
        # NOTE: we'll pass directly the batch dict to the model for inputs.
        answer_spans_start = batch["y_gt"][:, 0]
        answer_spans_end = batch["y_gt"][:, 1]
        # Clear gradients
        model.zero_grad()
        # Place to right device
        answer_spans_start = answer_spans_start.to(device)
        answer_spans_end = answer_spans_end.to(device)
        # Run forward pass
        pred_answer_start_scores, pred_answer_end_scores = model(batch)
        # Compute the CrossEntropyLoss
        loss = loss_function(pred_answer_start_scores, answer_spans_start) + loss_function(pred_answer_end_scores, answer_spans_end)
        # Compute gradients
        loss.backward()
        # Optimizer step
        optimizer.step()
        # --- Compute metrics ---
        # Get span indexes
        pred_span_start_idxs, pred_span_end_idxs = SpanExtractor.extract_most_probable(pred_answer_start_scores, pred_answer_end_scores)
        gt_start_idxs = answer_spans_start.cpu().detach()
        gt_end_idxs = answer_spans_end.cpu().detach()
        # two accs
        start_accuracy = torch.sum(gt_start_idxs == pred_span_start_idxs) / len(pred_span_start_idxs)
        end_accuracy = torch.sum(gt_end_idxs == pred_span_end_idxs) / len(pred_span_end_idxs)
        # Gather stats
        acc_loss += loss.item()
        acc_start_accuracy += start_accuracy.item()
        acc_end_accuracy += end_accuracy.item()
        count += 1
    time_end = timer()
    return {
        "loss": acc_loss / count, 
        "accuracy_start": acc_start_accuracy / count, 
        "accuracy_end": acc_end_accuracy / count,
        "time": time_end - time_start
    }

In [14]:
# create Evaluator object
evaluator = Evaluator(documents_list=train_data)

In [15]:
def evaluate_model_on_data(model, evaluator, dataloader, paragraphs_mapper, tokenizer, device, debug=False, show_progress=False):
    eval_dict = build_evaluation_dict_bert(model, dataloader, paragraphs_mapper, tokenizer, device, show_progress)
    if debug:
        print(f"DEBUG: Eval_dict: {eval_dict}")
    stats = {}
    stats['exact_match'] = evaluator.ExactMatch(eval_dict)
    stats['f1'] = evaluator.F1(eval_dict)
    return stats

In [16]:
class DistilBertBaseQA(torch.nn.Module):

    def __init__(self, hidden_size, num_labels):
        super(DistilBertBaseQA, self).__init__()
        self.hidden_size = hidden_size
        self.num_labels = num_labels
        self.config = transformers.DistilBertConfig(max_position_embeddings=384)
        #self.bert = transformers.DistilBertModel(bert_config)
        self.bert = transformers.DistilBertModel.from_pretrained('distilbert-base-uncased')#(bert_config)
        self.qa_outputs = torch.nn.Linear(self.hidden_size, self.num_labels)

    def forward(self, inputs):
        # --- 1) Extract data from inputs dictionary and put it on right device
        curr_device = self.bert.device
        input_ids = inputs["input_ids"].to(curr_device)
        attention_mask = inputs["attention_mask"].to(curr_device)
        # --- 2) Run BERT backbone to produce final representation
        output = self.bert(input_ids = input_ids, attention_mask = attention_mask)
        # --- 3) On top of the final representation, run a mapper to get scores for each position.
        sequence_output = output[0]   #(None, seq_len, hidden_size)
        logits = self.qa_outputs(sequence_output) #(None, seq_len, hidden_size)*(hidden_size, 2)=(None, seq_len, 2)
        start_logits, end_logits = logits.split(1, dim=-1)    #(None, seq_len, 1), (None, seq_len, 1)
        start_logits = start_logits.squeeze(-1)  #(None, seq_len)
        end_logits = end_logits.squeeze(-1)    #(None, seq_len)
        # --- 4) Prepare output tuple
        outputs = (start_logits, end_logits,) 
        return outputs

In [17]:
# Define baseline model
model = DistilBertBaseQA(768, 2).to(device)

loss_function = nn.CrossEntropyLoss()
optimizer = AdamW(model.parameters(), lr=0.00001, correct_bias=False)

In [18]:
dataset_train_QA = CustomQADatasetBERT(tokenizer_fn_train, df_train, paragraphs_mapper)
dataset_val_QA = CustomQADatasetBERT(tokenizer_fn_train, df_val, paragraphs_mapper)

In [19]:
train_data_loader = torch.utils.data.DataLoader(dataset_train_QA, collate_fn = bert_padder_collate_fn, batch_size=16, shuffle=True)
val_data_loader = torch.utils.data.DataLoader(dataset_val_QA, collate_fn = bert_padder_collate_fn, batch_size=16, shuffle=True)

In [20]:
history = {"train_loss": [], "train_acc_start": [], "train_acc_end": []}
loop_start = timer()
# lr scheduler
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.5, patience=5, threshold=0.01)
for epoch in range(50):
    train_dict = train_step(model, optimizer, loss_function, train_data_loader, device=device, show_progress=True)
    eval_results = evaluate_model_on_data(model, evaluator, val_data_loader, paragraphs_mapper, tokenizer, device, debug=False, show_progress=True)
    cur_lr = optimizer.param_groups[0]['lr']
    print(f'Epoch: {epoch}, lr: {cur_lr}, Train loss: {train_dict["loss"]:.4f},  Train acc start: {train_dict["accuracy_start"]:.4f}, Train acc end: {train_dict["accuracy_end"]:.4f}, Time: {train_dict["time"]:.4f}')
    history["train_loss"].append(train_dict["loss"]);history["train_acc_start"].append(train_dict["accuracy_start"]);history["train_acc_end"].append(train_dict["accuracy_end"]);
    #history["val_loss"].append(val_dict["loss"]);history["val_acc"].append(val_dict["accuracy"]);
    #scheduler.step(val_dict["loss"])
    print(f"Evaluation Results: {eval_results}")
loop_end = timer()
print(f"Elapsed time: {(loop_end - loop_start):.4f}")

  4%|▎         | 142/3847 [00:54<23:44,  2.60it/s]


KeyboardInterrupt: 

## Simple qualitative evaluation

In [20]:
def get_answer_span_helper(context, question, model, tokenizer_fn, tokenizer, device="cpu"):
    tokenized_input = tokenizer_fn(question, context)
    output_span = model({
        "input_ids": torch.tensor(tokenized_input["input_ids"]).to(device), 
        "attention_mask": torch.tensor(tokenized_input["attention_mask"]).to(device)
    })
    start, end = SpanExtractor.extract_most_probable(output_span[0], output_span[1])
    start = start.item()
    end = end.item()
    return tokenizer.decode(tokenized_input["input_ids"][0][start:end], skip_special_tokens=True)

In [21]:
context = "This is a test message, written to see if our model can correctly predict its outputs."
question = "Who needs to predict its outputs?"
pred_answer = get_answer_span_helper(context, question, model, tokenizer_fn_train, tokenizer, device="cuda")
print(pred_answer)

our model
