<a href="https://colab.research.google.com/github/alexpod1000/SQuAD-QA/blob/main/ModelTrainExperimentalCode.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
#%%bash
#[[ ! -e /colabtools ]] && exit  # Continue only if running on Google Colab

# Clone repository
# https://sysadmins.co.za/clone-a-private-github-repo-with-personal-access-token/
# For cloning the main branch:
#!git clone https://fb5b65b126107273e595ce8b6c9d2d533103c6e2:x-oauth-basic@github.com/alexpod1000/SQuAD-QA.git
# For cloning the "evaluation-features" branch
#!git clone --branch evaluation-features https://fb5b65b126107273e595ce8b6c9d2d533103c6e2:x-oauth-basic@github.com/alexpod1000/SQuAD-QA.git
# Change current working directory to match project
#%cd SQuAD-QA/
#!pwd

In [2]:
# External imports
import copy
import numpy as np
import pandas as pd
import string
import torch

from nltk.tokenize import TreebankWordTokenizer, SpaceTokenizer
from typing import Tuple, List, Dict, Any, Union

# Project imports
from squad_data.parser import SquadFileParser
from squad_data.utils import build_mappers_and_dataframe, add_paragraphs_spans
from evaluation.evaluation_metrics import Evaluator
from evaluation.utils import extract_answer, build_evaluation_dict

### Download Embedding

In [3]:
from utils.embedding_utils import EmbeddingDownloader

embedding_downloader = EmbeddingDownloader(
    "embedding_models", 
    "embedding_model.kv", 
    model_name="fasttext-wiki-news-subwords-300"
)

embedding_model = embedding_downloader.load()

Loading pre-downloaded embeddings from /home/giulio/Documenti/ProgettiGIT/SQuAD-QA/embedding_models/embedding_model.kv
End!
Embedding dimension: 50


### Parse the json and get the data

In [4]:
parser = SquadFileParser("squad_data/data/training_set.json")
data = parser.parse_documents()

########################### DEBUG
# reduce size for faster testing
#full_data = data
#data = []
#for i in range(1): # use only the first 1 documents
#  data.append(full_data[i])

### Prepare the mappers and datafram

In [5]:
paragraphs_mapper, questions_mapper, df = build_mappers_and_dataframe(data)
print(questions_mapper[next(iter(questions_mapper))])
print(paragraphs_mapper[next(iter(paragraphs_mapper))])
df.head()

To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France?
Architecturally, the school has a Catholic character. Atop the Main Building's gold dome is a golden statue of the Virgin Mary. Immediately in front of the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend "Venite Ad Me Omnes". Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basilica is the Grotto, a Marian place of prayer and reflection. It is a replica of the grotto at Lourdes, France where the Virgin Mary reputedly appeared to Saint Bernadette Soubirous in 1858. At the end of the main drive (and in a direct line that connects through 3 statues and the Gold Dome), is a simple, modern stone statue of Mary.


Unnamed: 0,paragraph_id,question_id,answer_id,answer_start,answer_text
0,0_0,5733be284776f41900661182,0,515,Saint Bernadette Soubirous
1,0_0,5733be284776f4190066117f,0,188,a copper statue of Christ
2,0_0,5733be284776f41900661180,0,279,the Main Building
3,0_0,5733be284776f41900661181,0,381,a Marian place of prayer and reflection
4,0_0,5733be284776f4190066117e,0,92,a golden statue of the Virgin Mary


In [6]:
def preprocess_text(text_dict: Dict[str, Any], text_key: Union[str, None] = None) -> Any:
    text_dict = copy.deepcopy(text_dict)
    # just tokenize and remove punctuation for now
    # TODO: add better punctuation removal later
    tokenizer = SpaceTokenizer()#TreebankWordTokenizer()
    for key in text_dict.keys():
        if text_key is not None:
            text = tokenizer.tokenize(text_dict[key][text_key])
            text_dict[key][text_key] = text
        else:
            text = tokenizer.tokenize(text_dict[key])
            text_dict[key] = text
    return text_dict

In [7]:
paragraphs_mapper = preprocess_text(paragraphs_mapper)
questions_mapper = preprocess_text(questions_mapper)

In [8]:
# Extend the paragraphs mapper to include spans
paragraphs_spans_mapper = add_paragraphs_spans(paragraphs_mapper)

In [9]:
print(paragraphs_spans_mapper['0_0']['text'])
print(paragraphs_spans_mapper['0_0']['spans'])

['Architecturally,', 'the', 'school', 'has', 'a', 'Catholic', 'character.', 'Atop', 'the', 'Main', "Building's", 'gold', 'dome', 'is', 'a', 'golden', 'statue', 'of', 'the', 'Virgin', 'Mary.', 'Immediately', 'in', 'front', 'of', 'the', 'Main', 'Building', 'and', 'facing', 'it,', 'is', 'a', 'copper', 'statue', 'of', 'Christ', 'with', 'arms', 'upraised', 'with', 'the', 'legend', '"Venite', 'Ad', 'Me', 'Omnes".', 'Next', 'to', 'the', 'Main', 'Building', 'is', 'the', 'Basilica', 'of', 'the', 'Sacred', 'Heart.', 'Immediately', 'behind', 'the', 'basilica', 'is', 'the', 'Grotto,', 'a', 'Marian', 'place', 'of', 'prayer', 'and', 'reflection.', 'It', 'is', 'a', 'replica', 'of', 'the', 'grotto', 'at', 'Lourdes,', 'France', 'where', 'the', 'Virgin', 'Mary', 'reputedly', 'appeared', 'to', 'Saint', 'Bernadette', 'Soubirous', 'in', '1858.', 'At', 'the', 'end', 'of', 'the', 'main', 'drive', '(and', 'in', 'a', 'direct', 'line', 'that', 'connects', 'through', '3', 'statues', 'and', 'the', 'Gold', 'Dome),

### DataConverter and CustomQADataset

In [10]:
from data_loading.utils import DataConverter, padder_collate_fn
from data_loading.qa_dataset import CustomQADataset

data_converter = DataConverter(embedding_model, paragraphs_spans_mapper)
datasetQA = CustomQADataset(data_converter, df, paragraphs_mapper, questions_mapper)
data_loader = torch.utils.data.DataLoader(datasetQA, collate_fn = padder_collate_fn, batch_size=10, shuffle=True)

test_batch = next(iter(data_loader))
print(test_batch["paragraph_emb"].shape)
print(test_batch["y_gt"].shape)

torch.Size([10, 233, 50])
torch.Size([10, 2])


# Model train

In [11]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from timeit import default_timer as timer

In [12]:
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"The device is {device}")

The device is cpu


  return torch._C._cuda_getDeviceCount() > 0


Model:

(paragraph_emb, question_emb) -> (answer_start, answer_end) // for each token in paragraph_emb

In [13]:
def train_step(model, optimizer, loss_function, dataloader, device="cpu"):
    acc_loss = 0
    acc_start_accuracy = 0
    acc_end_accuracy = 0
    count = 0

    time_start = timer()
    
    model.train()
    for batch in dataloader:
        paragraph_in = batch["paragraph_emb"]
        question_in = batch["question_emb"]
        answer_spans_start = batch["y_gt"][:, 0]
        answer_spans_end = batch["y_gt"][:, 1]
        # Clear gradients
        model.zero_grad()
        # Place to right device
        paragraph_in = paragraph_in.to(device)
        question_in = question_in.to(device)
        answer_spans_start = answer_spans_start.to(device)
        answer_spans_end = answer_spans_end.to(device)
        # Run forward pass
        pred_answer_start_scores, pred_answer_end_scores = model(paragraph_in, question_in)
        # Compute the CrossEntropyLoss
        loss = loss_function(pred_answer_start_scores, answer_spans_start) + loss_function(pred_answer_end_scores, answer_spans_end)
        # Compute gradients
        loss.backward()
        # Optimizer step
        optimizer.step()
        # --- Compute metrics ---
        # Get span indexes
        pred_span_start_idxs = torch.argmax(pred_answer_start_scores, axis=-1).cpu().detach()
        pred_span_end_idxs = torch.argmax(pred_answer_end_scores, axis=-1).cpu().detach()
        gt_start_idxs = answer_spans_start.cpu().detach()
        gt_end_idxs = answer_spans_end.cpu().detach()
        # two accs
        start_accuracy = torch.sum(gt_start_idxs == pred_span_start_idxs) / len(pred_span_start_idxs)
        end_accuracy = torch.sum(gt_end_idxs == pred_span_end_idxs) / len(pred_span_end_idxs)
        # Gather stats
        acc_loss += loss.item()
        acc_start_accuracy += start_accuracy.item()
        acc_end_accuracy += end_accuracy.item()
        count += 1
    time_end = timer()
    return {
        "loss": acc_loss / count, 
        "accuracy_start": acc_start_accuracy / count, 
        "accuracy_end": acc_end_accuracy / count,
        "time": time_end - time_start
    }

In [14]:
# create Evaluator object
evaluator = Evaluator(documents_list=data)

In [15]:
def evaluate_model_on_data(model, evaluator, dataloader, paragraphs_mapper, questions_mapper, device):
    eval_dict = build_evaluation_dict(model, dataloader, paragraphs_mapper, questions_mapper, device)
    print(f"DEBUG: Eval_dict: {eval_dict}")
    stats = {}
    stats['exact_match'] = evaluator.ExactMatch(eval_dict)
    stats['f1'] = evaluator.F1(eval_dict)
    return stats

In [16]:
class WeightedSum(nn.Module):
    def __init__(self, input_dim):
        """
        General idea, given a random dummy weights vector, 
        learn to weight it based on query
        """
        super(WeightedSum, self).__init__()
        self.weights = nn.Parameter(torch.randn(input_dim))

    def forward(self, input_emb, mask=None):
        # TODO: if needed, implement time masking
        batch, timesteps, embed_dim = input_emb.shape
        # w dot q_j
        dot_prods = torch.matmul(input_emb, self.weights)
        # exp(w dot q_j)
        exp_prods = torch.exp(dot_prods)
        # normalization factor
        sum_exp_prods = torch.sum(exp_prods, dim=1)
        sum_exp_prods = sum_exp_prods.repeat(timesteps, 1).T
        # b_j
        b = exp_prods / sum_exp_prods
        # q (embedding) = sum_t(b_t * q_t)
        b_scal_q = input_emb * b[:, :, None]
        # now sum along correct axis
        q = torch.sum(b_scal_q, axis=1)
        return q

In [17]:
class LSTM_QA(nn.Module):

    def __init__(self, embedding_dim, hidden_dim, tagset_size):
        super(LSTM_QA, self).__init__()
        self.tagset_size = tagset_size
        self.hidden_dim = hidden_dim
        self.embedding_dim = embedding_dim
        # The LSTM takes word embeddings as inputs, and outputs hidden states
        # with dimensionality hidden_dim.
        self.paragraph_embedder = nn.LSTM(embedding_dim, hidden_dim, bidirectional=True, batch_first=True)
        self.question_embedder = nn.LSTM(embedding_dim, hidden_dim, bidirectional=True, batch_first=True)
        self.weighted_sum = WeightedSum(hidden_dim * 2)
        # to classify from similarity to prob of start and prob of end
        self.sim_to_prob = nn.Linear(1, 2) # given a similarity score, predict P(start), P(end)

    def forward(self, paragraphs, questions):
        batch_size, seq_len, n_feat = paragraphs.shape
        # As we assume batch_first true, then our sentence_embeddings will have correct shape
        paragraphs_seq_emb, _ = self.paragraph_embedder(paragraphs) # (batch, seq_len, n_feats * n_dirs)
        questions_seq_emb, _ = self.question_embedder(questions) # (batch, seq_len, n_feats * n_dirs)
        # weighted sum
        questions_state_repr = self.weighted_sum(questions_seq_emb)
        # compute similarities -> (batch, timestep, 1)
        similarities = torch.bmm(paragraphs_seq_emb, questions_state_repr[:, :, None])
        #print(f"INSIDE MODEL: similarities shape: {similarities.shape}") #DEBUG
        # --- Given a similarity score, predict P(start), P(end) ---
        # similarities flattened
        similarities = similarities.contiguous()
        similarities = similarities.view(-1, 1) # as similarity dim is 1 -> viewed shape is (batch*timestep, 1)
        start_end_scores = self.sim_to_prob(similarities)
        start_end_scores = start_end_scores.view(batch_size, seq_len, 2) # where 2 is (P(start), P(end))
        
        start_logits = start_end_scores[:, :, 0]
        end_logits = start_end_scores[:, :, 1]
        
        # if we view each sequence of tokens as a feature vector
        # we can interpret the start/end assignation problem as 
        # a classification with a variable number of classes
        # thus assume that our model outputs logits that will just be passed
        # to a softmax, to build a probable distribution of the start token
        return start_logits, end_logits
        
        #self.sim_to_prob
        
        # we can for each similarity score predict just two scalars (simple 1 to 2 mapping NN), and use P(start)[idx_of_start] = 1, rest 0
        # and P(end)[idx_of_end] = 1, rest 0 for the groundtruth
        
        # Crude question repr: take last state of lstm (remove padding), concat it (as it's a bilstm) and use it as question representation
        # 
        #return paragraphs_seq_emb, questions_seq_emb, questions_state_repr, similarities
        #
        # THANK YOU DUDE: https://www.kdnuggets.com/2018/06/taming-lstms-variable-sized-mini-batches-pytorch.html
        # Project to tag space
        # Dim transformation: (batch_size, seq_len, nb_lstm_units) -> (batch_size * seq_len, nb_lstm_units)
        # this one is a bit tricky as well. First we need to reshape the data so it goes into the linear layer
        #
        #lstm_out = lstm_out.contiguous()
        #lstm_out = lstm_out.view(-1, lstm_out.shape[2])
        #lstm_out = self.dropout(lstm_out)
        # Run through actual linear layer
        #tag_logits = self.hidden_to_tag(lstm_out)
        # Dim transformation: (batch_size * seq_len, nb_lstm_units) -> (batch_size, seq_len, nb_tags)
        #tag_logits = tag_logits.view(batch_size, seq_len, self.tagset_size)
        #return tag_logits

In [18]:
#torch.nn.functional.softmax(outs_mod[0])

In [19]:
# Define baseline model
model = LSTM_QA(embedding_model.vector_size, 128, 10).to(device)
# NOTE: weight=torch.Tensor(class_weights_train).to(device) sucks badly, don't use it, it fucks up performance completly
loss_function = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.01, weight_decay=0.0001, amsgrad=True)

In [20]:
data_converter = DataConverter(embedding_model, paragraphs_spans_mapper)
datasetQA = CustomQADataset(data_converter, df, paragraphs_mapper, questions_mapper)

In [21]:
train_data_loader = torch.utils.data.DataLoader(datasetQA, collate_fn = padder_collate_fn, batch_size=64, shuffle=True)

In [22]:
history = {"train_loss": [], "train_acc_start": [], "train_acc_end": []}
loop_start = timer()
# lr scheduler
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.5, patience=5, threshold=0.01)
for epoch in range(50):
    train_dict = train_step(model, optimizer, loss_function, train_data_loader, device=device)
    eval_results = evaluate_model_on_data(model, evaluator, train_data_loader, paragraphs_mapper, questions_mapper, device)
    cur_lr = optimizer.param_groups[0]['lr']
    print(f'Epoch: {epoch}, lr: {cur_lr}, Train loss: {train_dict["loss"]:.4f},  Train acc start: {train_dict["accuracy_start"]:.4f}, Train acc end: {train_dict["accuracy_end"]:.4f}, Time: {train_dict["time"]:.4f}')
    history["train_loss"].append(train_dict["loss"]);history["train_acc_start"].append(train_dict["accuracy_start"]);history["train_acc_end"].append(train_dict["accuracy_end"]);
    #history["val_loss"].append(val_dict["loss"]);history["val_acc"].append(val_dict["accuracy"]);
    #scheduler.step(val_dict["loss"])
    print(f"Evaluation Results: {eval_results}")
loop_end = timer()
print(f"Elapsed time: {(loop_end - loop_start):.4f}")

DEBUG: Eval_dict: {'573398164776f41900660e21': 'Luigi', '5733b699d058e614000b6118': 'About', '573393184776f41900660da7': '$9.7', '573388ce4776f41900660cc6': 'regalia.', '573394c84776f41900660ddd': '(1987–2005),', '5733b3d64776f419006610a6': 'F.', '57339b36d058e614000b5ea6': '2012[update]', '57338a51d058e614000b5cf4': 'Father', '573387acd058e614000b5cb5': 'Division', '5733a6424776f41900660f4e': 'offered.', '573385394776f41900660c80': 'Father', '5733a70c4776f41900660f63': 'First', '5733cd504776f41900661290': 'F.', '5733c743d058e614000b622d': 'Army', '57339b36d058e614000b5ea4': '2012[update]', '5733bf84d058e614000b61c1': "Mary's", '5733ad384776f41900660fec': '14-story', '5733afd3d058e614000b6046': 'Best', '5733974d4776f41900660e17': '2005,', '573382a14776f41900660c2e': 'Notre', '5733c3184776f419006611c4': 'On', '57339a5bd058e614000b5e91': 'million', '573388ce4776f41900660cc5': '1924.', '5733cd504776f41900661292': 'basketball,', '5733974d4776f41900660e18': 'Jenkins,', '573387acd058e614000b

KeyboardInterrupt: 