<a href="https://colab.research.google.com/github/alexpod1000/SQuAD-QA/blob/main/ModelTrainExperimentalCode.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
#%%bash
#[[ ! -e /colabtools ]] && exit  # Continue only if running on Google Colab

# Clone repository
# https://sysadmins.co.za/clone-a-private-github-repo-with-personal-access-token/
# For cloning the main branch:
#!git clone https://fb5b65b126107273e595ce8b6c9d2d533103c6e2:x-oauth-basic@github.com/alexpod1000/SQuAD-QA.git
# For cloning the "evaluation-features" branch
#!git clone --branch evaluation-features https://fb5b65b126107273e595ce8b6c9d2d533103c6e2:x-oauth-basic@github.com/alexpod1000/SQuAD-QA.git
# Change current working directory to match project
#%cd SQuAD-QA/
#!pwd

In [2]:
# External imports
import copy
import nltk
import numpy as np
import pandas as pd
import string
import torch

from nltk.tokenize import TreebankWordTokenizer, SpaceTokenizer
from typing import Tuple, List, Dict, Any, Union

# Project imports
from squad_data.parser import SquadFileParser
from squad_data.utils import build_mappers_and_dataframe, add_paragraphs_spans
from evaluation.evaluation_metrics import Evaluator
from evaluation.utils import extract_answer, build_evaluation_dict

### Download Embedding

In [3]:
from utils.embedding_utils import EmbeddingDownloader

embedding_downloader = EmbeddingDownloader(
    "embedding_models", 
    "embedding_model.kv", 
    model_name="fasttext-wiki-news-subwords-300"
)

embedding_model = embedding_downloader.load()

Loading pre-downloaded embeddings from /home/alexpod/uni/magistrale_ai/secondo_anno/nlp/project/SQuAD-QA/embedding_models/embedding_model.kv
End!
Embedding dimension: 300


### Parse the json and get the data

In [4]:
parser = SquadFileParser("squad_data/data/training_set.json")
data = parser.parse_documents()

########################### DEBUG
# reduce size for faster testing
#full_data = data
#data = []
#for i in range(1): # use only the first 1 documents
#  data.append(full_data[i])

### Prepare the mappers and datafram

In [31]:
from transformers import AutoTokenizer, PreTrainedTokenizerFast
from typing import Union, Tuple, List, Dict, Any
from squad_data import Document

def index_of_first(lst, pred):
    for i, v in enumerate(lst):
        if pred(v):
            return i
    return None

def split_paragraph_if_needed(paragraph, question, answer_span, tokenizer):
    """
    Attempts to tokenize a paragraph and question together, if too long
    because of tokenizer's max length, then will split the paragraph into
    multiple slices.
    
    Returns a list of paragraph slices with answer span, such that:
        - a paragraph slice with no answer will have answer mapped to (CLS, CLS)
        - a paragraph slice with answer will be mapped to the index of answer.
    """
    max_length = 384
    doc_stride = 128
    pad_on_right = tokenizer.padding_side == "right"
    # TODO: add max_length and doc_stride
    tokenized_input_pair = tokenizer(
        question,
        paragraph,
        truncation="only_second" if pad_on_right else "only_first",
        max_length=max_length,
        stride=doc_stride,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
    )
    # outputs
    paragraph_splits = []
    answer_spans = []
    # get answer end char idx
    ans_start = answer_span[0]
    ans_end = answer_span[1]
    """
    1) Find index of context segments in the tokenized example
    2) Within the context segments (start from context_segment_idx), 
       find the token corresponding to span of answer: start and end.
    """
    for offset_idx, offset in enumerate(tokenized_input_pair.offset_mapping):
        # get sequence ids
        sequence_ids = tokenized_input_pair.sequence_ids(offset_idx)
        # find start index of context segment
        context_segment_idx = sequence_ids.index(1)
        span_start_offset_idx = index_of_first(
            tokenized_input_pair.offset_mapping[offset_idx][context_segment_idx:], 
            lambda span: span[0] <= ans_start <= span[1]
        )
        span_end_offset_idx = index_of_first(
            tokenized_input_pair.offset_mapping[offset_idx][context_segment_idx:], 
            lambda span: span[0] <= ans_end <= span[1]
        )
        # Decode split into a string
        decoded_split = tokenizer.decode(tokenized_input_pair.input_ids[offset_idx][context_segment_idx:], skip_special_tokens=True)
        # 
        paragraph_splits.append(decoded_split)
        if span_start_offset_idx is not None and span_end_offset_idx is not None:
            # If answer span is fully in current slice
            # add segment idx offset
            span_start_offset_idx += context_segment_idx
            span_end_offset_idx += context_segment_idx + 1 # the plus 1 is needed for correct slicing
            answer_spans.append((span_start_offset_idx, span_end_offset_idx))
            # TODO: we have span indexes now, map them to groundtruth (?)
        elif span_start_offset_idx is None and span_end_offset_idx is None:
            # If span not in this slice, but in another slice
            # map answer to (CLS, CLS)
            cls_idx = tokenized_input_pair.input_ids[offset_idx].index(tokenizer.cls_token_id)
            # NOTE(Alex): although I think it's always 0
            answer_spans.append((cls_idx, cls_idx))
        else:
            # span spans along multiple slices -> throw the sample away 
            # (should be only like 4 samples across the whole dataset)
            # Discard sample
            pass
    
    return (paragraph_splits, answer_spans)

def build_bert_mappers_and_dataframe(
    tokenizer: PreTrainedTokenizerFast,
    documents_list: List[Document], 
    limit_answers: int = -1
    ) -> Tuple[Dict[str, str], pd.DataFrame]:
    """
    Given a list of SQuAD Document objects, returns mapper to transform from
    paragraph id to paragraph text and a dataframe containing paragraph id, 
    question id, text and answer details.
    Args:
        documents_list (List[Document]): list of parsed SQuAD document objects.
        limit_answers (int): limit number of returned answers per question
            to this amount (-1 to return all the available answers).
        tokenizer: Huggingface tokenizer

    Returns:
        paragraphs_mapper: mapper from paragraph id to paragraph text
        dataframe: Pandas dataframe with the following schema
            (paragraph_id, question_id, question_text, answer_id, answer_start, answer_text)
    """

    # type for np array: np.ndarray
    # given a paragraph id, maps the paragraph to its text or embeddings (or both)
    split_paragraphs_mapper = {}
    # dataframe
    dataframe_list = []
    for doc_idx, document in enumerate(documents_list):
        # for each paragraph
        for par_idx, paragraph in enumerate(document.paragraphs):
            par_text = paragraph.context.strip()

            # for each question
            for question in paragraph.questions:
                question_id = question.id
                question_text = question.question.strip()

                # TODO(Alex): at this level, could need to start building par/question

                # take only "limit_answers" answers for every question.
                answer_range = len(question.answers) if limit_answers == -1 else limit_answers
                for answer_id, answer in enumerate(question.answers[:answer_range]):
                    # NOTE: in training set, there's only one answer per question.
                    answer_text = answer.text.strip()
                    # get span
                    answer_start = answer.answer_start
                    answer_end = answer.answer_start + len(answer_text)

                    par_splits, split_answer_spans = split_paragraph_if_needed(par_text, question_text, (answer_start, answer_end), tokenizer)
                    
                    pair_overflows = len(par_splits) > 1
                    
                    for split_idx, (split_text, split_ans_span) in enumerate(zip(par_splits, split_answer_spans)):
                        """
                        NOTE(Alex): since in tokenization phase we also use question, our ID depends on question too
                                    For example if for question1, the pair <question1, par> goes above the limit,
                                    but for <question2, par> it does not, then we'll still need to keep track of
                                    different splits of par, depending on each question.
                        """
                        if pair_overflows:
                            split_par_id = "{}_{}_{}_{}".format(doc_idx, par_idx, question_id, split_idx)
                        else:
                            """
                            If no length overflow, then we don't need question_id or split_idx
                            To optimize memory, we can map same splits to same id 
                            (some pairs <question, par> won't overflow anyway)
                            """
                            split_par_id = "{}_{}".format(doc_idx, par_idx)
                        split_paragraphs_mapper[split_par_id] = split_text
                        # build dataframe entry
                        dataframe_list.append({
                            "paragraph_id": split_par_id,
                            "question_id": question_id,
                            "answer_id": answer_id,
                            "answer_start": answer_start,
                            "answer_text": answer_text,
                            "question_text": question_text,
                            "tokenizer_answer_start": split_ans_span[0],
                            "tokenizer_answer_end": split_ans_span[1],
                        })
    return split_paragraphs_mapper, pd.DataFrame(dataframe_list)

In [26]:
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")

In [32]:
#par_text, ques_text, ans_start, ans_text = build_bert_mappers_and_dataframe(tokenizer, data, limit_answers=1)
paragraphs_mapper, df = build_bert_mappers_and_dataframe(tokenizer, data, limit_answers=1)

In [33]:
paragraphs_mapper

{'0_0': 'architecturally, the school has a catholic character. atop the main building\'s gold dome is a golden statue of the virgin mary. immediately in front of the main building and facing it, is a copper statue of christ with arms upraised with the legend " venite ad me omnes ". next to the main building is the basilica of the sacred heart. immediately behind the basilica is the grotto, a marian place of prayer and reflection. it is a replica of the grotto at lourdes, france where the virgin mary reputedly appeared to saint bernadette soubirous in 1858. at the end of the main drive ( and in a direct line that connects through 3 statues and the gold dome ), is a simple, modern stone statue of mary.',
 '0_1': "as at most other universities, notre dame's students run a number of news media outlets. the nine student - run outlets include three newspapers, both a radio and television station, and several magazines and journals. begun as a one - page journal in september 1876, the scholas

In [22]:
fuck

(["the men's basketball team has over 1, 600 wins, one of only 12 schools who have reached that mark, and have appeared in 28 ncaa tournaments. former player austin carr holds the record for most points scored in a single game of the tournament with 61. although the team has never won the ncaa tournament, they were named by the helms athletic foundation as national champions twice. the team has orchestrated a number of upsets of number one ranked teams, the most notable of which was ending ucla's record 88 - game winning streak in 1974. the team has beaten an additional eight number - one teams, and those nine wins rank second, to ucla's 10, all - time in wins against the top team. the team plays in newly renovated purcell pavilion ( within the edmund p. joyce center ), which reopened for the beginning of the 2009 – 2010 season. the team is coached by mike brey, who, as of the 2014 – 15 season, his fifteenth at notre dame, has achieved a 332 - 165 record. in 2009 they were invited to t

In [16]:
tokenized_example = tokenizer(
    ques_text,
    par_text * 3,
    truncation="only_second" if tokenizer.padding_side == "right" else "only_first",
    max_length=384,
    stride=128,
    return_overflowing_tokens=True,
    return_offsets_mapping=True,
    padding="max_length",
)

In [24]:
len(tokenized_example.offset_mapping)

2

In [20]:
sids = tokenized_example.sequence_ids()

In [19]:
tokenized_example

{'input_ids': [[101, 2000, 3183, 2106, 1996, 6261, 2984, 9382, 3711, 1999, 8517, 1999, 10223, 26371, 2605, 1029, 102, 6549, 2135, 1010, 1996, 2082, 2038, 1037, 3234, 2839, 1012, 10234, 1996, 2364, 2311, 1005, 1055, 2751, 8514, 2003, 1037, 3585, 6231, 1997, 1996, 6261, 2984, 1012, 3202, 1999, 2392, 1997, 1996, 2364, 2311, 1998, 5307, 2009, 1010, 2003, 1037, 6967, 6231, 1997, 4828, 2007, 2608, 2039, 14995, 6924, 2007, 1996, 5722, 1000, 2310, 3490, 2618, 4748, 2033, 18168, 5267, 1000, 1012, 2279, 2000, 1996, 2364, 2311, 2003, 1996, 13546, 1997, 1996, 6730, 2540, 1012, 3202, 2369, 1996, 13546, 2003, 1996, 24665, 23052, 1010, 1037, 14042, 2173, 1997, 7083, 1998, 9185, 1012, 2009, 2003, 1037, 15059, 1997, 1996, 24665, 23052, 2012, 10223, 26371, 1010, 2605, 2073, 1996, 6261, 2984, 22353, 2135, 2596, 2000, 3002, 16595, 9648, 4674, 2061, 12083, 9711, 2271, 1999, 8517, 1012, 2012, 1996, 2203, 1997, 1996, 2364, 3298, 1006, 1998, 1999, 1037, 3622, 2240, 2008, 8539, 2083, 1017, 11342, 1998, 1996, 2

In [9]:
def index_of_first(lst, pred):
    for i,v in enumerate(lst):
        if pred(v):
            return i
    return None

In [31]:
help(tokenizer.decode)

Help on method decode in module transformers.tokenization_utils_base:

decode(token_ids: Union[int, List[int], ForwardRef('np.ndarray'), ForwardRef('torch.Tensor'), ForwardRef('tf.Tensor')], skip_special_tokens: bool = False, clean_up_tokenization_spaces: bool = True, **kwargs) -> str method of transformers.models.distilbert.tokenization_distilbert_fast.DistilBertTokenizerFast instance
    Converts a sequence of ids in a string, using the tokenizer and vocabulary with options to remove special
    tokens and clean up tokenization spaces.
    
    Similar to doing ``self.convert_tokens_to_string(self.convert_ids_to_tokens(token_ids))``.
    
    Args:
        token_ids (:obj:`Union[int, List[int], np.ndarray, torch.Tensor, tf.Tensor]`):
            List of tokenized input ids. Can be obtained using the ``__call__`` method.
        skip_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether or not to remove special tokens in the decoding.
        clean_up

In [34]:
tokenizer.decode(tokenized_example.input_ids[offset_idx][context_segment_idx:], skip_special_tokens=True)

'the grotto, a marian place of prayer and reflection. it is a replica of the grotto at lourdes, france where the virgin mary reputedly appeared to saint bernadette soubirous in 1858. at the end of the main drive ( and in a direct line that connects through 3 statues and the gold dome ), is a simple, modern stone statue of mary. architecturally, the school has a catholic character. atop the main building\'s gold dome is a golden statue of the virgin mary. immediately in front of the main building and facing it, is a copper statue of christ with arms upraised with the legend " venite ad me omnes ". next to the main building is the basilica of the sacred heart. immediately behind the basilica is the grotto, a marian place of prayer and reflection. it is a replica of the grotto at lourdes, france where the virgin mary reputedly appeared to saint bernadette soubirous in 1858. at the end of the main drive ( and in a direct line that connects through 3 statues and the gold dome ), is a simple

In [35]:
"""
1) Find index of context segments in the tokenized example
2) Within the context segments (start from context_segment_idx), 
   find the token corresponding to span of answer: start and end.
"""
paragraph_splits = []
answer_spans = []
# get answer end char idx
ans_end = ans_start + len(ans_text)
for offset_idx, offset in enumerate(tokenized_example.offset_mapping):
    # get sequence ids
    sequence_ids = tokenized_example.sequence_ids(offset_idx)
    # find start index of context segment
    context_segment_idx = sequence_ids.index(1)
    span_start_offset_idx = index_of_first(tokenized_example.offset_mapping[offset_idx][context_segment_idx:], lambda span: span[0] <= ans_start <= span[1])
    span_end_offset_idx = index_of_first(tokenized_example.offset_mapping[offset_idx][context_segment_idx:], lambda span: span[0] <= ans_end <= span[1])
    # TODO: check if both of the indexes is None -> the answer is not in this span
    # TODO: check if one of the indexes is None -> the answer spans among multiple slices -> remove sample from dataframe
    # add segment idx offset
    decoded_split = tokenizer.decode(tokenized_example.input_ids[offset_idx][context_segment_idx:], skip_special_tokens=True)
    paragraph_splits.append(decoded_split)
    if span_start_offset_idx is not None and span_end_offset_idx is not None:
        span_start_offset_idx += context_segment_idx
        span_end_offset_idx += context_segment_idx + 1 # the plus 1 is needed for correct slicing
        print("Span in slice",offset_idx, "span", span_start_offset_idx, span_end_offset_idx)
        answer_spans.append((span_start_offset_idx, span_end_offset_idx))
        # TODO: we have span indexes now, map them to groundtruth (?)
    elif span_start_offset_idx is None and span_end_offset_idx is None:
        # span not in this slice, but in another slice
        # TODO: map answer to (CLS, CLS)
        print("Span not in slice", offset_idx)
        cls_idx = tokenized_example.input_ids[offset_idx].index(tokenizer.cls_token_id)
        # NOTE(Alex): although I think it's always 0
        answer_spans.append((cls_idx, cls_idx))
    else:
        # span spans along multiple slices -> throw the sample away (should be only like 4 samples across the whole dataset)
        # TODO: discard sample
        pass
"""
NOTE: this logic is used for sample creation only, such that each sample is "short enough" for BERT; 
      a duplicate of this logic will need to be used in QADataset Dataloader class when we'll take
      short samples' text, tokenize them again, and find the correct index
ALTERNATIVE: for BERT models we could directly get the answer spans, and pass them in dataframe to another QADataset
             built specifically for BERT, that will just take the data from dataframe (way nicer and faster solution).
SUGGESTION: we could also use specific dict keys and in QADataset pick stuff from these keys: 
                - if these keys are absent then don't use BERT logic (eg span_start and span_end) and use previous logic
                - if these keys are present, then just use them and gather the BERT samples.
                Call these keys like "tokenizer_span_idx" (to make them kinda unique)
"""

Span in slice 0 span 130 138
Span not in slice 1


'\nNOTE: this logic is used for sample creation only, such that each sample is "short enough" for BERT; \n      a duplicate of this logic will need to be used in QADataset Dataloader class when we\'ll take\n      short samples\' text, tokenize them again, and find the correct index\nALTERNATIVE: for BERT models we could directly get the answer spans, and pass them in dataframe to another QADataset\n             built specifically for BERT, that will just take the data from dataframe (way nicer and faster solution).\nSUGGESTION: we could also use specific dict keys and in QADataset pick stuff from these keys: \n                - if these keys are absent then don\'t use BERT logic (eg span_start and span_end) and use previous logic\n                - if these keys are present, then just use them and gather the BERT samples.\n                Call these keys like "tokenizer_span_idx" (to make them kinda unique)\n'

In [36]:
paragraph_splits

['architecturally, the school has a catholic character. atop the main building\'s gold dome is a golden statue of the virgin mary. immediately in front of the main building and facing it, is a copper statue of christ with arms upraised with the legend " venite ad me omnes ". next to the main building is the basilica of the sacred heart. immediately behind the basilica is the grotto, a marian place of prayer and reflection. it is a replica of the grotto at lourdes, france where the virgin mary reputedly appeared to saint bernadette soubirous in 1858. at the end of the main drive ( and in a direct line that connects through 3 statues and the gold dome ), is a simple, modern stone statue of mary. architecturally, the school has a catholic character. atop the main building\'s gold dome is a golden statue of the virgin mary. immediately in front of the main building and facing it, is a copper statue of christ with arms upraised with the legend " venite ad me omnes ". next to the main buildi

In [64]:
tokenizer.decode(tokenized_example.input_ids[0][span_start_offset_idx: span_end_offset_idx])

'saint bernadette soubirous'

In [44]:
tokenized_example.offset_mapping[0][span_start_offset_idx: span_end_offset_idx]

[(438, 440),
 (441, 444),
 (445, 447),
 (447, 451),
 (452, 454),
 (455, 458),
 (458, 462)]

In [31]:
tokenized_example.offset_mapping[0][context_segment_idx:]

[(0, 13),
 (13, 15),
 (15, 16),
 (17, 20),
 (21, 27),
 (28, 31),
 (32, 33),
 (34, 42),
 (43, 52),
 (52, 53),
 (54, 58),
 (59, 62),
 (63, 67),
 (68, 76),
 (76, 77),
 (77, 78),
 (79, 83),
 (84, 88),
 (89, 91),
 (92, 93),
 (94, 100),
 (101, 107),
 (108, 110),
 (111, 114),
 (115, 121),
 (122, 126),
 (126, 127),
 (128, 139),
 (140, 142),
 (143, 148),
 (149, 151),
 (152, 155),
 (156, 160),
 (161, 169),
 (170, 173),
 (174, 180),
 (181, 183),
 (183, 184),
 (185, 187),
 (188, 189),
 (190, 196),
 (197, 203),
 (204, 206),
 (207, 213),
 (214, 218),
 (219, 223),
 (224, 226),
 (226, 229),
 (229, 232),
 (233, 237),
 (238, 241),
 (242, 248),
 (249, 250),
 (250, 252),
 (252, 254),
 (254, 256),
 (257, 259),
 (260, 262),
 (263, 265),
 (265, 268),
 (268, 269),
 (269, 270),
 (271, 275),
 (276, 278),
 (279, 282),
 (283, 287),
 (288, 296),
 (297, 299),
 (300, 303),
 (304, 312),
 (313, 315),
 (316, 319),
 (320, 326),
 (327, 332),
 (332, 333),
 (334, 345),
 (346, 352),
 (353, 356),
 (357, 365),
 (366, 368),
 (

In [20]:
tokenized_example

{'input_ids': [[101, 2000, 3183, 2106, 1996, 6261, 2984, 9382, 3711, 1999, 8517, 1999, 10223, 26371, 2605, 1029, 102, 6549, 2135, 1010, 1996, 2082, 2038, 1037, 3234, 2839, 1012, 10234, 1996, 2364, 2311, 1005, 1055, 2751, 8514, 2003, 1037, 3585, 6231, 1997, 1996, 6261, 2984, 1012, 3202, 1999, 2392, 1997, 1996, 2364, 2311, 1998, 5307, 2009, 1010, 2003, 1037, 6967, 6231, 1997, 4828, 2007, 2608, 2039, 14995, 6924, 2007, 1996, 5722, 1000, 2310, 3490, 2618, 4748, 2033, 18168, 5267, 1000, 1012, 2279, 2000, 1996, 2364, 2311, 2003, 1996, 13546, 1997, 1996, 6730, 2540, 1012, 3202, 2369, 1996, 13546, 2003, 1996, 24665, 23052, 1010, 1037, 14042, 2173, 1997, 7083, 1998, 9185, 1012, 2009, 2003, 1037, 15059, 1997, 1996, 24665, 23052, 2012, 10223, 26371, 1010, 2605, 2073, 1996, 6261, 2984, 22353, 2135, 2596, 2000, 3002, 16595, 9648, 4674, 2061, 12083, 9711, 2271, 1999, 8517, 1012, 2012, 1996, 2203, 1997, 1996, 2364, 3298, 1006, 1998, 1999, 1037, 3622, 2240, 2008, 8539, 2083, 1017, 11342, 1998, 1996, 2

In [11]:
ques_text

'To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France?'

In [5]:
paragraphs_mapper, df = build_mappers_and_dataframe(data, limit_answers=1)
print(paragraphs_mapper[next(iter(paragraphs_mapper))])
df.head()

Architecturally, the school has a Catholic character. Atop the Main Building's gold dome is a golden statue of the Virgin Mary. Immediately in front of the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend "Venite Ad Me Omnes". Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basilica is the Grotto, a Marian place of prayer and reflection. It is a replica of the grotto at Lourdes, France where the Virgin Mary reputedly appeared to Saint Bernadette Soubirous in 1858. At the end of the main drive (and in a direct line that connects through 3 statues and the Gold Dome), is a simple, modern stone statue of Mary.


Unnamed: 0,paragraph_id,question_id,answer_id,answer_start,answer_text,question_text
0,0_0,5733be284776f41900661182,0,515,Saint Bernadette Soubirous,To whom did the Virgin Mary allegedly appear i...
1,0_0,5733be284776f4190066117f,0,188,a copper statue of Christ,What is in front of the Notre Dame Main Building?
2,0_0,5733be284776f41900661180,0,279,the Main Building,The Basilica of the Sacred heart at Notre Dame...
3,0_0,5733be284776f41900661181,0,381,a Marian place of prayer and reflection,What is the Grotto at Notre Dame?
4,0_0,5733be284776f4190066117e,0,92,a golden statue of the Virgin Mary,What sits on top of the Main Building at Notre...


In [6]:
def preprocess_text(text_dict: Dict[str, Any], text_key: Union[str, None] = None) -> Any:
    text_dict = copy.deepcopy(text_dict)
    # just tokenize and remove punctuation for now
    # TODO: add better punctuation removal later
    tokenizer = SpaceTokenizer()#TreebankWordTokenizer()
    for key in text_dict.keys():
        if text_key is not None:
            text = tokenizer.tokenize(text_dict[key][text_key])
            text_dict[key][text_key] = text
        else:
            text = tokenizer.tokenize(text_dict[key])
            text_dict[key] = text
    return text_dict

In [7]:
paragraphs_mapper = preprocess_text(paragraphs_mapper)
df['question_text'] = df.apply(lambda row: nltk.word_tokenize(row['question_text']), axis=1)

In [8]:
# Extend the paragraphs mapper to include spans
paragraphs_spans_mapper = add_paragraphs_spans(paragraphs_mapper)

In [9]:
print(paragraphs_spans_mapper['0_0']['text'])
print(paragraphs_spans_mapper['0_0']['spans'])

['Architecturally,', 'the', 'school', 'has', 'a', 'Catholic', 'character.', 'Atop', 'the', 'Main', "Building's", 'gold', 'dome', 'is', 'a', 'golden', 'statue', 'of', 'the', 'Virgin', 'Mary.', 'Immediately', 'in', 'front', 'of', 'the', 'Main', 'Building', 'and', 'facing', 'it,', 'is', 'a', 'copper', 'statue', 'of', 'Christ', 'with', 'arms', 'upraised', 'with', 'the', 'legend', '"Venite', 'Ad', 'Me', 'Omnes".', 'Next', 'to', 'the', 'Main', 'Building', 'is', 'the', 'Basilica', 'of', 'the', 'Sacred', 'Heart.', 'Immediately', 'behind', 'the', 'basilica', 'is', 'the', 'Grotto,', 'a', 'Marian', 'place', 'of', 'prayer', 'and', 'reflection.', 'It', 'is', 'a', 'replica', 'of', 'the', 'grotto', 'at', 'Lourdes,', 'France', 'where', 'the', 'Virgin', 'Mary', 'reputedly', 'appeared', 'to', 'Saint', 'Bernadette', 'Soubirous', 'in', '1858.', 'At', 'the', 'end', 'of', 'the', 'main', 'drive', '(and', 'in', 'a', 'direct', 'line', 'that', 'connects', 'through', '3', 'statues', 'and', 'the', 'Gold', 'Dome),

In [10]:
len(df)

87599

### DataConverter and CustomQADataset

In [11]:
from data_loading.utils import DataConverter, padder_collate_fn
from data_loading.qa_dataset import CustomQADataset

data_converter = DataConverter(embedding_model, paragraphs_spans_mapper)
datasetQA = CustomQADataset(data_converter, df, paragraphs_mapper)
data_loader = torch.utils.data.DataLoader(datasetQA, collate_fn = padder_collate_fn, batch_size=10, shuffle=True)

test_batch = next(iter(data_loader))
print(test_batch["paragraph_emb"].shape)
print(test_batch["y_gt"].shape)

torch.Size([10, 145, 300])
torch.Size([10, 2])


# Model train

In [12]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from timeit import default_timer as timer

from models.utils import SpanExtractor

In [13]:
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"The device is {device}")

The device is cuda


Model:

(paragraph_emb, question_emb) -> (answer_start, answer_end) // for each token in paragraph_emb

In [14]:
def train_step(model, optimizer, loss_function, dataloader, device="cpu"):
    acc_loss = 0
    acc_start_accuracy = 0
    acc_end_accuracy = 0
    count = 0

    time_start = timer()
    
    model.train()
    for batch in dataloader:
        paragraph_in = batch["paragraph_emb"]
        question_in = batch["question_emb"]
        answer_spans_start = batch["y_gt"][:, 0]
        answer_spans_end = batch["y_gt"][:, 1]
        # Clear gradients
        model.zero_grad()
        # Place to right device
        paragraph_in = paragraph_in.to(device)
        question_in = question_in.to(device)
        answer_spans_start = answer_spans_start.to(device)
        answer_spans_end = answer_spans_end.to(device)
        # Run forward pass
        pred_answer_start_scores, pred_answer_end_scores = model(paragraph_in, question_in)
        # Compute the CrossEntropyLoss
        loss = loss_function(pred_answer_start_scores, answer_spans_start) + loss_function(pred_answer_end_scores, answer_spans_end)
        # Compute gradients
        loss.backward()
        # Optimizer step
        optimizer.step()
        # --- Compute metrics ---
        # Get span indexes
        pred_span_start_idxs, pred_span_end_idxs = SpanExtractor.extract_most_probable(pred_answer_start_scores, pred_answer_end_scores)
        gt_start_idxs = answer_spans_start.cpu().detach()
        gt_end_idxs = answer_spans_end.cpu().detach()
        # two accs
        start_accuracy = torch.sum(gt_start_idxs == pred_span_start_idxs) / len(pred_span_start_idxs)
        end_accuracy = torch.sum(gt_end_idxs == pred_span_end_idxs) / len(pred_span_end_idxs)
        # Gather stats
        acc_loss += loss.item()
        acc_start_accuracy += start_accuracy.item()
        acc_end_accuracy += end_accuracy.item()
        count += 1
    time_end = timer()
    return {
        "loss": acc_loss / count, 
        "accuracy_start": acc_start_accuracy / count, 
        "accuracy_end": acc_end_accuracy / count,
        "time": time_end - time_start
    }

In [15]:
# create Evaluator object
evaluator = Evaluator(documents_list=data)

In [29]:
def evaluate_model_on_data(model, evaluator, dataloader, paragraphs_mapper, device, debug=False):
    eval_dict = build_evaluation_dict(model, dataloader, paragraphs_mapper, device)
    if debug:
        print(f"DEBUG: Eval_dict: {eval_dict}")
    stats = {}
    stats['exact_match'] = evaluator.ExactMatch(eval_dict)
    stats['f1'] = evaluator.F1(eval_dict)
    return stats

In [17]:
class WeightedSum(nn.Module):
    def __init__(self, input_dim):
        """
        General idea, given a random dummy weights vector, 
        learn to weight it based on query
        """
        super(WeightedSum, self).__init__()
        self.weights = nn.Parameter(torch.randn(input_dim))

    def forward(self, input_emb, mask=None):
        # TODO: if needed, implement time masking
        batch, timesteps, embed_dim = input_emb.shape
        # w dot q_j
        dot_prods = torch.matmul(input_emb, self.weights)
        # exp(w dot q_j)
        exp_prods = torch.exp(dot_prods)
        # normalization factor
        sum_exp_prods = torch.sum(exp_prods, dim=1)
        sum_exp_prods = sum_exp_prods.repeat(timesteps, 1).T
        # b_j
        b = exp_prods / sum_exp_prods
        # q (embedding) = sum_t(b_t * q_t)
        b_scal_q = input_emb * b[:, :, None]
        # now sum along correct axis
        q = torch.sum(b_scal_q, axis=1)
        return q

**Compatibility functions**

**Multiplicative (dot)**:

p = paragraph emb shape: [B, T, E] (Query)

q = question weighted shape: [B, E] reshaped to [B, E, 1] (Keys)

scores = p @ q (of shape: [B, T, 1])

**General bilinear**:

p = paragraph emb shape: [B, T, Ep] (Query)

q = question weighted shape: [B, Eq] reshaped to [B, Eq, 1] (Keys)

W = parameter matrix of shape: [Ep, Eq]

scores = p @ W @ q (of shape: [B, T, 1])

In [18]:
class BilinearCompatibility(nn.Module):
    def __init__(self, query_dim, keys_dim):
        """
        Perform bilinear compatibility f(q, K) = q.T @ W @ K
        Recall: multiplicative/dot compatibility is f(q, K) = q.T @ K
        
        Where: 
            q -> embedded paragraphs (p in DrQA)
            K -> embedded question (q in DrQA)
        """
        super(BilinearCompatibility, self).__init__()
        self.weights = nn.Parameter(torch.randn(query_dim, keys_dim))

    def forward(self, query, keys):
        """
        query: batch of shape (batch, seq_len, query_dim) (Query)
        keys = batch of shape (batch, key_dim) which will be reshaped into [batch, key_dim, 1] (Keys)
        """
        return query @ self.weights @ keys[:, :, None]

In [19]:
class LSTM_QA(nn.Module):

    def __init__(self, embedding_dim, hidden_dim, tagset_size):
        super(LSTM_QA, self).__init__()
        self.tagset_size = tagset_size
        self.hidden_dim = hidden_dim
        self.embedding_dim = embedding_dim
        # The LSTM takes word embeddings as inputs, and outputs hidden states
        # with dimensionality hidden_dim.
        self.paragraph_embedder = nn.LSTM(embedding_dim, hidden_dim, bidirectional=True, batch_first=True)
        self.question_embedder = nn.LSTM(embedding_dim, hidden_dim, bidirectional=True, batch_first=True)
        self.weighted_sum = WeightedSum(hidden_dim * 2)
        # used to compute similarity scores
        self.general_bilinear_start = BilinearCompatibility(hidden_dim * 2, hidden_dim * 2)
        self.general_bilinear_end = BilinearCompatibility(hidden_dim * 2, hidden_dim * 2)
        # to classify from similarity to prob of start and prob of end
        self.sim_to_start = nn.Linear(1, 1) # given a similarity score, predict P(start)
        self.sim_to_end = nn.Linear(1, 1) # given a similarity score, predict P(end)

    def forward(self, paragraphs, questions):
        batch_size, seq_len, n_feat = paragraphs.shape
        # As we assume batch_first true, then our sentence_embeddings will have correct shape
        paragraphs_seq_emb, _ = self.paragraph_embedder(paragraphs) # (batch, seq_len, n_feats * n_dirs)
        questions_seq_emb, _ = self.question_embedder(questions) # (batch, seq_len, n_feats * n_dirs)
        # weighted sum
        questions_state_repr = self.weighted_sum(questions_seq_emb)
        #return paragraphs_seq_emb, questions_state_repr
        # compute similarities -> (batch, timestep, 1)
        similarities_start = self.general_bilinear_start(paragraphs_seq_emb, questions_state_repr)
        similarities_end = self.general_bilinear_start(paragraphs_seq_emb, questions_state_repr)
        #print(f"INSIDE MODEL: similarities shape: {similarities.shape}") #DEBUG
        # --- Given a similarity score, predict P(start), P(end) ---
        # similarities flattened
        similarities_start = similarities_start.contiguous()
        similarities_start = similarities_start.view(-1, 1) # as similarity dim is 1 -> viewed shape is (batch*timestep, 1)
        start_scores = self.sim_to_start(similarities_start)
        start_logits = start_scores.view(batch_size, seq_len) # P(start)
        
        similarities_end = similarities_end.contiguous()
        similarities_end = similarities_end.view(-1, 1) # as similarity dim is 1 -> viewed shape is (batch*timestep, 1)
        end_scores = self.sim_to_end(similarities_end)
        end_logits = end_scores.view(batch_size, seq_len) # P(end)
        
        # if we view each sequence of tokens as a feature vector
        # we can interpret the start/end assignation problem as 
        # a classification with a variable number of classes
        # thus assume that our model outputs logits that will just be passed
        # to a softmax, to build a probable distribution of the start token
        return start_logits, end_logits

In [20]:
#torch.nn.functional.softmax(outs_mod[0])

In [30]:
# Define baseline model
model = LSTM_QA(embedding_model.vector_size, 128, 10).to(device)
loss_function = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.01, weight_decay=0.0001, amsgrad=True)

In [31]:
data_converter = DataConverter(embedding_model, paragraphs_spans_mapper)
datasetQA = CustomQADataset(data_converter, df, paragraphs_mapper)

In [32]:
train_data_loader = torch.utils.data.DataLoader(datasetQA, collate_fn = padder_collate_fn, batch_size=64, shuffle=True)

In [33]:
history = {"train_loss": [], "train_acc_start": [], "train_acc_end": []}
loop_start = timer()
# lr scheduler
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.5, patience=5, threshold=0.01)
for epoch in range(50):
    train_dict = train_step(model, optimizer, loss_function, train_data_loader, device=device)
    eval_results = evaluate_model_on_data(model, evaluator, train_data_loader, paragraphs_mapper, device, debug=True)
    cur_lr = optimizer.param_groups[0]['lr']
    print(f'Epoch: {epoch}, lr: {cur_lr}, Train loss: {train_dict["loss"]:.4f},  Train acc start: {train_dict["accuracy_start"]:.4f}, Train acc end: {train_dict["accuracy_end"]:.4f}, Time: {train_dict["time"]:.4f}')
    history["train_loss"].append(train_dict["loss"]);history["train_acc_start"].append(train_dict["accuracy_start"]);history["train_acc_end"].append(train_dict["accuracy_end"]);
    #history["val_loss"].append(val_dict["loss"]);history["val_acc"].append(val_dict["accuracy"]);
    #scheduler.step(val_dict["loss"])
    print(f"Evaluation Results: {eval_results}")
loop_end = timer()
print(f"Elapsed time: {(loop_end - loop_start):.4f}")

Might fail ['When', 'the', 'news', 'arrived', 'in', 'England', 'it', 'caused', 'an', 'outcry.', 'In', 'response,', 'a', 'combined', 'bounty', 'of', '£1,000', 'was', 'offered', 'for', "Every's", 'capture', 'by', 'the', 'Privy', 'Council', 'and', 'East', 'India', 'Company,', 'leading', 'to', 'the', 'first', 'worldwide', 'manhunt', 'in', 'recorded', 'history.', 'The', 'plunder', 'of', "Aurangzeb's", 'treasure', 'ship', 'had', 'serious', 'consequences', 'for', 'the', 'English', 'East', 'India', 'Company.', 'The', 'furious', 'Mughal', 'Emperor', 'Aurangzeb', 'ordered', 'Sidi', 'Yaqub', 'and', 'Nawab', 'Daud', 'Khan', 'to', 'attack', 'and', 'close', 'four', 'of', 'the', "company's", 'factories', 'in', 'India', 'and', 'imprison', 'their', 'officers,', 'who', 'were', 'almost', 'lynched', 'by', 'a', 'mob', 'of', 'angry', 'Mughals,', 'blaming', 'them', 'for', 'their', "countryman's", 'depredations,', 'and', 'threatened', 'to', 'put', 'an', 'end', 'to', 'all', 'English', 'trading', 'in', 'India.'

IOPub data rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_data_rate_limit`.

Current values:
NotebookApp.iopub_data_rate_limit=1000000.0 (bytes/sec)
NotebookApp.rate_limit_window=3.0 (secs)



Epoch: 0, lr: 0.01, Train loss: 8.3310,  Train acc start: 0.0423, Train acc end: 0.0603, Time: 398.3827
Evaluation Results: {'exact_match': 5.293439422824462, 'f1': 9.209568529959142}


KeyboardInterrupt: 