<a href="https://colab.research.google.com/github/alexpod1000/SQuAD-QA/blob/main/Train_model_GPU_DRQA_experimental.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Run the following cells only if using Colab
if 'google.colab' in str(get_ipython()):
    # Clone repository
    !git clone https://github.com/alexpod1000/SQuAD-QA.git
    # Change current working directory to match project
    %cd SQuAD-QA/
    !pwd

In [2]:
# External imports
import copy
import nltk
import numpy as np
import pandas as pd
import string
import torch
import json

from nltk.tokenize import TreebankWordTokenizer, SpaceTokenizer
from typing import Tuple, List, Dict, Any, Union

# Project imports
from squad_data.parser import SquadFileParser
from squad_data.utils import build_mappers_and_dataframe, add_paragraphs_spans
from evaluation.evaluate import evaluate_predictions
from evaluation.utils import build_evaluation_dict
from utils import split_dataframe

### Download Embedding

In [3]:
from utils.embedding_utils import EmbeddingDownloader

embedding_downloader = EmbeddingDownloader(
    "embedding_models", 
    "embedding_model.kv", 
    model_name="fasttext-wiki-news-subwords-300"
)

embedding_model = embedding_downloader.load()

Loading pre-downloaded embeddings from /home/alexpod/uni/magistrale_ai/secondo_anno/nlp/project/SQuAD-QA/embedding_models/embedding_model.kv
End!
Embedding dimension: 300


### Parse the json and get the data

In [4]:
train_file_json = "squad_data/data/training_set.json"
test_file_json = "squad_data/data/dev-v1.1.json"

train_parser = SquadFileParser(train_file_json)
test_parser = SquadFileParser(test_file_json)

train_data = train_parser.parse_documents()
test_data = test_parser.parse_documents()

### Prepare the mappers and dataframe

In [5]:
paragraphs_mapper, df = build_mappers_and_dataframe(train_data, limit_answers=1)
print(paragraphs_mapper[next(iter(paragraphs_mapper))])
df.head()

Architecturally, the school has a Catholic character. Atop the Main Building's gold dome is a golden statue of the Virgin Mary. Immediately in front of the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend "Venite Ad Me Omnes". Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basilica is the Grotto, a Marian place of prayer and reflection. It is a replica of the grotto at Lourdes, France where the Virgin Mary reputedly appeared to Saint Bernadette Soubirous in 1858. At the end of the main drive (and in a direct line that connects through 3 statues and the Gold Dome), is a simple, modern stone statue of Mary.


Unnamed: 0,doc_id,paragraph_id,question_id,answer_id,answer_start,answer_text,question_text
0,0,0_0,5733be284776f41900661182,0,515,Saint Bernadette Soubirous,To whom did the Virgin Mary allegedly appear i...
1,0,0_0,5733be284776f4190066117f,0,188,a copper statue of Christ,What is in front of the Notre Dame Main Building?
2,0,0_0,5733be284776f41900661180,0,279,the Main Building,The Basilica of the Sacred heart at Notre Dame...
3,0,0_0,5733be284776f41900661181,0,381,a Marian place of prayer and reflection,What is the Grotto at Notre Dame?
4,0,0_0,5733be284776f4190066117e,0,92,a golden statue of the Virgin Mary,What sits on top of the Main Building at Notre...


In [6]:
def preprocess_text(text_dict: Dict[str, Any], text_key: Union[str, None] = None) -> Any:
    """
    Return a deep copy of `text_dict` whose texts are replaced by token lists.

    Tokenization uses NLTK's SpaceTokenizer, so punctuation is NOT removed:
    it stays attached to the neighbouring word.

    Args:
        text_dict: mapping from an id to either a raw string (when `text_key`
            is None) or to a dict-like entry holding the string under `text_key`.
        text_key: optional key selecting the text inside each entry.

    Returns:
        The copied mapping with every selected text replaced by its tokens.
        The input mapping is left untouched.
    """
    tokenized = copy.deepcopy(text_dict)
    space_tokenizer = SpaceTokenizer()  # alternative considered: TreebankWordTokenizer()
    if text_key is None:
        # Values are the raw strings themselves
        for entry_id, raw_text in tokenized.items():
            tokenized[entry_id] = space_tokenizer.tokenize(raw_text)
    else:
        # Values are containers; only the `text_key` field is tokenized
        for entry in tokenized.values():
            entry[text_key] = space_tokenizer.tokenize(entry[text_key])
    return tokenized

In [7]:
# Paragraphs are split on spaces; questions go through NLTK's word tokenizer.
# NOTE: this cell overwrites its own inputs, so it must be run exactly once per kernel.
paragraphs_mapper = preprocess_text(paragraphs_mapper)
df['question_text'] = df['question_text'].apply(nltk.word_tokenize)

In [8]:
# Extend the paragraphs mapper to include spans
paragraphs_spans_mapper = add_paragraphs_spans(paragraphs_mapper)

In [9]:
print(paragraphs_spans_mapper['0_0']['text'])
print(paragraphs_spans_mapper['0_0']['spans'])

['Architecturally,', 'the', 'school', 'has', 'a', 'Catholic', 'character.', 'Atop', 'the', 'Main', "Building's", 'gold', 'dome', 'is', 'a', 'golden', 'statue', 'of', 'the', 'Virgin', 'Mary.', 'Immediately', 'in', 'front', 'of', 'the', 'Main', 'Building', 'and', 'facing', 'it,', 'is', 'a', 'copper', 'statue', 'of', 'Christ', 'with', 'arms', 'upraised', 'with', 'the', 'legend', '"Venite', 'Ad', 'Me', 'Omnes".', 'Next', 'to', 'the', 'Main', 'Building', 'is', 'the', 'Basilica', 'of', 'the', 'Sacred', 'Heart.', 'Immediately', 'behind', 'the', 'basilica', 'is', 'the', 'Grotto,', 'a', 'Marian', 'place', 'of', 'prayer', 'and', 'reflection.', 'It', 'is', 'a', 'replica', 'of', 'the', 'grotto', 'at', 'Lourdes,', 'France', 'where', 'the', 'Virgin', 'Mary', 'reputedly', 'appeared', 'to', 'Saint', 'Bernadette', 'Soubirous', 'in', '1858.', 'At', 'the', 'end', 'of', 'the', 'main', 'drive', '(and', 'in', 'a', 'direct', 'line', 'that', 'connects', 'through', '3', 'statues', 'and', 'the', 'Gold', 'Dome),

In [10]:
df_train, df_val = split_dataframe(df, train_ratio=0.7)

In [11]:
print(f"Total samples: {len(df)}, Train samples: {len(df_train)}, Validation samples: {len(df_val)}")

Total samples: 87599, Train samples: 60876, Validation samples: 26723


### DataConverter and CustomQADataset

In [12]:
from data_loading.utils import DataConverter, padder_collate_fn
from data_loading.qa_dataset import CustomQADataset

data_converter = DataConverter(embedding_model, paragraphs_spans_mapper)
datasetQA = CustomQADataset(data_converter, df_train, paragraphs_mapper)
data_loader = torch.utils.data.DataLoader(datasetQA, collate_fn = padder_collate_fn, batch_size=10, shuffle=True)

test_batch = next(iter(data_loader))
print(test_batch["paragraph_emb"].shape)
print(test_batch["y_gt"].shape)

torch.Size([10, 162, 300])
torch.Size([10, 2])


# Model training

In [13]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from timeit import default_timer as timer
from tqdm import tqdm

from models.utils import SpanExtractor

In [14]:
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"The device is {device}")

The device is cuda


Model:

(paragraph_emb, question_emb) -> (answer_start_scores, answer_end_scores), with one start score and one end score for each token in paragraph_emb

In [15]:
def train_step(model, optimizer, loss_function, dataloader, device="cpu", show_progress=False):
    """
    Run one optimisation pass over `dataloader` and report averaged metrics.

    Args:
        model: module called as `model(batch)`, returning (start_scores, end_scores).
        optimizer: optimizer stepped once per batch.
        loss_function: criterion applied separately to start and end scores;
            the two losses are summed.
        dataloader: iterable of batches; each batch exposes `batch["y_gt"]`,
            indexed as [:, 0] for span starts and [:, 1] for span ends.
        device: device the ground-truth indices are moved to.
        show_progress: wrap the dataloader in a tqdm progress bar when True.

    Returns:
        dict with "loss", "accuracy_start", "accuracy_end" (each a mean of
        per-batch values) and "time" (elapsed seconds for the whole pass).
    """
    totals = {"loss": 0, "accuracy_start": 0, "accuracy_end": 0}
    n_batches = 0

    started_at = timer()

    model.train()
    batches = tqdm(dataloader) if show_progress else dataloader
    for batch in batches:
        # Ground-truth start/end indices, moved next to the model outputs
        gt_start = batch["y_gt"][:, 0].to(device)
        gt_end = batch["y_gt"][:, 1].to(device)
        # Reset gradients left over from the previous batch
        model.zero_grad()
        # Forward pass
        start_scores, end_scores = model(batch)
        # Total loss is the sum of the start-index and end-index losses
        start_loss = loss_function(start_scores, gt_start)
        end_loss = loss_function(end_scores, gt_end)
        loss = start_loss + end_loss
        # Backward pass and parameter update
        loss.backward()
        optimizer.step()
        # --- Metrics ---
        # Predicted indices are compared against CPU copies of the ground truth
        # (presumably SpanExtractor returns CPU index tensors; verify).
        pred_start, pred_end = SpanExtractor.extract_most_probable(start_scores, end_scores)
        start_hits = torch.sum(gt_start.cpu().detach() == pred_start)
        end_hits = torch.sum(gt_end.cpu().detach() == pred_end)
        # Accumulate per-batch values
        totals["loss"] += loss.item()
        totals["accuracy_start"] += (start_hits / len(pred_start)).item()
        totals["accuracy_end"] += (end_hits / len(pred_end)).item()
        n_batches += 1
    finished_at = timer()

    summary = {name: total / n_batches for name, total in totals.items()}
    summary["time"] = finished_at - started_at
    return summary

In [16]:
def validation_step(model, loss_function, dataloader, device="cpu", show_progress=False):
    """
    Evaluate `model` over `dataloader` without updating its parameters.

    Args:
        model: module called as `model(batch)`, returning (start_scores, end_scores).
        loss_function: criterion applied separately to start and end scores;
            the two losses are summed.
        dataloader: iterable of batches; each batch exposes `batch["y_gt"]`,
            indexed as [:, 0] for span starts and [:, 1] for span ends.
        device: device the ground-truth indices are moved to.
        show_progress: wrap the dataloader in a tqdm progress bar when True.

    Returns:
        dict with "loss", "accuracy_start", "accuracy_end" (each a mean of
        per-batch values) and "time" (elapsed seconds for the whole pass).
    """
    loss_total = 0
    start_acc_total = 0
    end_acc_total = 0
    n_batches = 0

    started_at = timer()
    batches = tqdm(dataloader) if show_progress else dataloader

    # Evaluation mode + no autograd bookkeeping
    model.eval()
    with torch.no_grad():
        for batch in batches:
            targets = batch["y_gt"]
            gt_start = targets[:, 0].to(device)
            gt_end = targets[:, 1].to(device)
            # Forward pass
            start_scores, end_scores = model(batch)
            # Same objective as in training: start loss + end loss
            loss = loss_function(start_scores, gt_start) + loss_function(end_scores, gt_end)
            # --- Metrics ---
            pred_start, pred_end = SpanExtractor.extract_most_probable(start_scores, end_scores)
            start_matches = gt_start.cpu().detach() == pred_start
            end_matches = gt_end.cpu().detach() == pred_end
            batch_start_acc = torch.sum(start_matches) / len(pred_start)
            batch_end_acc = torch.sum(end_matches) / len(pred_end)
            # Accumulate per-batch values
            loss_total += loss.item()
            start_acc_total += batch_start_acc.item()
            end_acc_total += batch_end_acc.item()
            n_batches += 1
    elapsed = timer() - started_at
    return {
        "loss": loss_total / n_batches,
        "accuracy_start": start_acc_total / n_batches,
        "accuracy_end": end_acc_total / n_batches,
        "time": elapsed
    }

In [18]:
class WeightedSum(nn.Module):
    """
    Collapse a sequence of vectors into one vector via learned attention.

    A single learnable vector `w` scores every timestep (`w . q_t`); the scores
    are softmax-normalised over time and used to average the inputs.
    """

    def __init__(self, input_dim):
        """
        Args:
            input_dim: size of the last dimension of the inputs to `forward`.
        """
        super(WeightedSum, self).__init__()
        self.weights = nn.Parameter(torch.randn(input_dim))

    def forward(self, input_emb, mask=None):
        """
        Args:
            input_emb: tensor of shape (batch, timesteps, input_dim).
            mask: optional tensor of shape (batch, timesteps); non-zero / True
                marks timesteps to keep, zero / False timesteps get zero weight.
                A row that is masked out entirely yields NaN, as there is
                nothing to normalise over.

        Returns:
            Tensor of shape (batch, input_dim): the attention-weighted sum of
            the timesteps.
        """
        # w dot q_t -> (batch, timesteps)
        scores = torch.matmul(input_emb, self.weights)
        if mask is not None:
            keep = mask.to(device=scores.device, dtype=torch.bool)
            # -inf score => exactly zero weight after the softmax
            scores = scores.masked_fill(~keep, float("-inf"))
        # softmax is the same exp / sum(exp) normalisation as before, but it
        # subtracts the row maximum internally, so large scores no longer
        # overflow to inf / inf = NaN.
        attention = torch.softmax(scores, dim=1)
        # q = sum_t(b_t * q_t)
        return torch.sum(input_emb * attention[:, :, None], dim=1)

**Compatibility functions**

**Multiplicative (dot)**:

p = paragraph emb shape: [B, T, E] (Query)

q = question weighted shape: [B, E] reshaped to [B, E, 1] (Keys)

scores = p @ q (of shape: [B, T, 1])

**General bilinear**:

p = paragraph emb shape: [B, T, Ep] (Query)

q = question weighted shape: [B, Eq] reshaped to [B, Eq, 1] (Keys)

W = parameter matrix of shape: [Ep, Eq]

scores = p @ W @ q (of shape: [B, T, 1])

In [19]:
class BilinearCompatibility(nn.Module):
    """
    Bilinear compatibility score f(q, K) = q.T @ W @ K with a learned matrix W.

    For comparison, multiplicative/dot compatibility is f(q, K) = q.T @ K.
    In DrQA terms, q is the embedded paragraph (p) and K is the embedded
    question (q).
    """

    def __init__(self, query_dim, keys_dim):
        """
        Args:
            query_dim: feature size of the query vectors (rows of W).
            keys_dim: feature size of the key vectors (columns of W).
        """
        super().__init__()
        initial_matrix = torch.randn(query_dim, keys_dim)
        self.weights = nn.Parameter(initial_matrix)

    def forward(self, query, keys):
        """
        Args:
            query: tensor of shape (batch, seq_len, query_dim).
            keys: tensor of shape (batch, keys_dim); a trailing axis is added
                so it acts as a column vector per batch element.

        Returns:
            Tensor of shape (batch, seq_len, 1) with one score per query position.
        """
        projected_query = torch.matmul(query, self.weights)  # (batch, seq_len, keys_dim)
        keys_column = keys.unsqueeze(2)                      # (batch, keys_dim, 1)
        return torch.matmul(projected_query, keys_column)

In [20]:
class LSTM_QA(nn.Module):
    """
    DrQA-style span predictor.

    Paragraph and question token embeddings are encoded by two separate
    bidirectional LSTMs. The question states are pooled into a single vector
    (WeightedSum), which is compared against every paragraph state through two
    independent bilinear forms: one scoring answer starts, one scoring answer
    ends. Each similarity is passed through its own 1->1 linear layer to
    produce the final logits.
    """

    def __init__(self, embedding_dim, hidden_dim):
        """
        Args:
            embedding_dim: size of the input token embeddings.
            hidden_dim: hidden size of each LSTM direction (encoder outputs
                have size 2 * hidden_dim because the LSTMs are bidirectional).
        """
        super(LSTM_QA, self).__init__()
        self.hidden_dim = hidden_dim
        self.embedding_dim = embedding_dim
        # The LSTMs take word embeddings as inputs, and output hidden states
        # with dimensionality hidden_dim per direction.
        self.paragraph_embedder = nn.LSTM(embedding_dim, hidden_dim, bidirectional=True, batch_first=True)
        self.question_embedder = nn.LSTM(embedding_dim, hidden_dim, bidirectional=True, batch_first=True)
        self.weighted_sum = WeightedSum(hidden_dim * 2)
        # used to compute similarity scores
        self.general_bilinear_start = BilinearCompatibility(hidden_dim * 2, hidden_dim * 2)
        self.general_bilinear_end = BilinearCompatibility(hidden_dim * 2, hidden_dim * 2)
        # map a similarity score to a start / end logit
        self.sim_to_start = nn.Linear(1, 1)
        self.sim_to_end = nn.Linear(1, 1)

    def forward(self, inputs):
        """
        Args:
            inputs: dict with "paragraph_emb" (batch, paragraph_len, embedding_dim)
                and "question_emb" (batch, question_len, embedding_dim) tensors.
                Both are moved to the model's device here.

        Returns:
            (start_logits, end_logits), each of shape (batch, paragraph_len).
            They are unnormalised scores over paragraph positions, meant to be
            fed to a softmax-based loss such as CrossEntropyLoss.
        """
        # Use the device of the model parameters as the target device
        curr_device = next(self.parameters()).device
        paragraphs = inputs["paragraph_emb"].to(curr_device)
        questions = inputs["question_emb"].to(curr_device)
        # batch_first=True, so outputs are (batch, seq_len, 2 * hidden_dim)
        paragraphs_seq_emb, _ = self.paragraph_embedder(paragraphs)
        questions_seq_emb, _ = self.question_embedder(questions)
        # Pool the question states into a single vector per example
        questions_state_repr = self.weighted_sum(questions_seq_emb)
        # Similarities -> (batch, paragraph_len, 1). The end head must use its
        # OWN bilinear form; previously both heads went through
        # general_bilinear_start, leaving general_bilinear_end unused.
        similarities_start = self.general_bilinear_start(paragraphs_seq_emb, questions_state_repr)
        similarities_end = self.general_bilinear_end(paragraphs_seq_emb, questions_state_repr)
        # nn.Linear acts on the last axis, so it can be applied directly to the
        # (batch, paragraph_len, 1) tensor; dropping that axis afterwards gives
        # one logit per paragraph token.
        start_logits = self.sim_to_start(similarities_start).squeeze(-1)
        end_logits = self.sim_to_end(similarities_end).squeeze(-1)
        # NOTE(review): padded paragraph positions are not masked here; confirm
        # whether the collate function / loss account for padding.
        return start_logits, end_logits

In [21]:
# Define baseline model
model = LSTM_QA(embedding_model.vector_size, 128).to(device)
loss_function = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.01, weight_decay=0.0001, amsgrad=True)

In [22]:
data_converter = DataConverter(embedding_model, paragraphs_spans_mapper)
dataset_train_QA = CustomQADataset(data_converter, df_train, paragraphs_mapper)
dataset_val_QA = CustomQADataset(data_converter, df_val, paragraphs_mapper)

In [23]:
train_data_loader = torch.utils.data.DataLoader(dataset_train_QA, collate_fn = padder_collate_fn, batch_size=128, shuffle=True)
val_data_loader = torch.utils.data.DataLoader(dataset_val_QA, collate_fn = padder_collate_fn, batch_size=128, shuffle=True)

In [None]:
# One list per tracked metric, filled with one value per epoch
history = {
    f"{split}_{metric}": []
    for split in ("train", "val")
    for metric in ("loss", "acc_start", "acc_end")
}
loop_start = timer()
# lr scheduler (constructed, but its step call below is disabled, so the
# learning rate stays at the optimizer's initial value)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.5, patience=5, threshold=0.01)
for epoch in range(50):
    train_dict = train_step(model, optimizer, loss_function, train_data_loader, device=device, show_progress=True)
    val_dict = validation_step(model, loss_function, val_data_loader, device=device, show_progress=True)
    cur_lr = optimizer.param_groups[0]['lr']
    print(f'Epoch: {epoch}, '
          f'lr: {cur_lr}, '
          f'Train loss: {train_dict["loss"]:.4f}, '
          f'Train acc start: {train_dict["accuracy_start"]:.4f}, '
          f'Train acc end: {train_dict["accuracy_end"]:.4f}, '
          f'Val loss: {val_dict["loss"]:.4f}, '
          f'Val acc start: {val_dict["accuracy_start"]:.4f}, '
          f'Val acc end: {val_dict["accuracy_end"]:.4f}, '
          f'Time: {train_dict["time"]:.4f}')
    # Record this epoch's metrics for both splits
    for split, metrics in (("train", train_dict), ("val", val_dict)):
        history[f"{split}_loss"].append(metrics["loss"])
        history[f"{split}_acc_start"].append(metrics["accuracy_start"])
        history[f"{split}_acc_end"].append(metrics["accuracy_end"])
    #scheduler.step(val_dict["loss"])
loop_end = timer()
print(f"Elapsed time: {(loop_end - loop_start):.4f}")

# Evaluation

Model evaluation on test set

## Quantitative evaluation

In [26]:
test_paragraphs_mapper, df_test = build_mappers_and_dataframe(test_data, limit_answers=1)

In [27]:
# Same preprocessing as for the training data: space-split paragraphs,
# NLTK word-tokenized questions.
# NOTE: this cell overwrites its own inputs, so it must be run exactly once per kernel.
test_paragraphs_mapper = preprocess_text(test_paragraphs_mapper)
df_test['question_text'] = df_test['question_text'].apply(nltk.word_tokenize)

In [28]:
# Extend the paragraphs mapper to include spans
test_paragraphs_spans_mapper = add_paragraphs_spans(test_paragraphs_mapper)

In [29]:
data_converter.paragraphs_spans_mapper = test_paragraphs_spans_mapper

In [30]:
dataset_test_QA = CustomQADataset(data_converter, df_test, test_paragraphs_mapper)
test_data_loader = torch.utils.data.DataLoader(dataset_test_QA, collate_fn = padder_collate_fn, batch_size=128, shuffle=True)

In [None]:
# Load the raw test-set JSON that the evaluation script compares against
with open(test_file_json, "r", encoding="utf-8") as f:
    dataset_json = json.load(f)
# `tokenizer` was never defined at notebook level (it only existed as a local
# inside preprocess_text), so this cell raised NameError on a fresh kernel.
# NOTE(review): SpaceTokenizer is used here because it is what tokenized the
# paragraphs in test_paragraphs_mapper; confirm this is what
# build_evaluation_dict expects.
tokenizer = SpaceTokenizer()
pred_dict = build_evaluation_dict(model, test_data_loader, test_paragraphs_mapper, tokenizer, device, show_progress=True)
eval_results = evaluate_predictions(dataset_json, pred_dict)
print(eval_results)