## preprocess.py

In [1]:
import torch
import numpy as np
import pandas as pd
import pickle
import re, os, string, typing, gc, json
import spacy
from collections import Counter
nlp = spacy.load('en_core_web_sm')

def load_json(path):
    '''
    Loads the JSON file of the Squad dataset.
    Returns the json object of the dataset.
    '''
    with open(path, 'r', encoding='utf-8') as f:
        data = json.load(f)
        
    print("Length of data: ", len(data['data']))
    print("Data Keys: ", data['data'][0].keys())
    print("Title: ", data['data'][0]['title'])
    
    return data




def parse_data(data:dict)->list:
    '''
    Parses the JSON file of Squad dataset by looping through the
    keys and values and returns a list of dictionaries with
    context, query and label triplets being the keys of each dict.
    '''
    data = data['data']
    qa_list = []

    for paragraphs in data:

        for para in paragraphs['paragraphs']:
            context = para['context']

            for qa in para['qas']:
                
                id = qa['id']
                question = qa['question']
                
                for ans in qa['answers']:
                    answer = ans['text']
                    ans_start = ans['answer_start']
                    ans_end = ans_start + len(answer)
                    
                    qa_dict = {}
                    qa_dict['id'] = id
                    qa_dict['context'] = context
                    qa_dict['question'] = question
                    qa_dict['label'] = [ans_start, ans_end]

                    qa_dict['answer'] = answer
                    qa_list.append(qa_dict)    

    
    return qa_list



def filter_large_examples(df):
    '''
    Returns ids of examples where context lengths, query lengths and answer lengths are
    above a particular threshold. These ids can then be dropped from the dataframe. 
    This is explicitly mentioned in QANet but can be done for other models as well.
    '''
    
    ctx_lens = []
    query_lens = []
    ans_lens = []
    for index, row in df.iterrows():
        ctx_tokens = [w.text for w in nlp(row.context, disable=['parser','ner','tagger'])]
        if len(ctx_tokens)>400:
            ctx_lens.append(row.name)

        query_tokens = [w.text for w in nlp(row.question, disable=['parser','tagger','ner'])]
        if len(query_tokens)>50:
            query_lens.append(row.name)

        ans_tokens = [w.text for w in nlp(row.answer, disable=['parser','tagger','ner'])]
        if len(ans_tokens)>30:
            ans_lens.append(row.name)

        assert row.name == index
    
    return set(ans_lens + ctx_lens + query_lens)


def gather_text_for_vocab(dfs:list):
    '''
    Gathers text from contexts and questions to build a vocabulary.
    
    :param dfs: list of dataframes of SQUAD dataset.
    :returns: list of contexts and questions
    '''
    
    text = []
    total = 0
    for df in dfs:
        unique_contexts = list(df.context.unique())
        unique_questions = list(df.question.unique())
        total += df.context.nunique() + df.question.nunique()
        text.extend(unique_contexts + unique_questions)
    
    assert len(text) == total
    
    return text




def build_word_vocab(vocab_text):
    '''
    Builds a word-level vocabulary from the given text.
    
    :param list vocab_text: list of contexts and questions
    :returns 
        dict word2idx: word to index mapping of words
        dict idx2word: integer to word mapping
        list word_vocab: list of words sorted by frequency
    '''
    
    
    words = []
    for sent in vocab_text:
        for word in nlp(sent, disable=['parser','tagger','ner']):
            words.append(word.text)

    word_counter = Counter(words)
    word_vocab = sorted(word_counter, key=word_counter.get, reverse=True)
    print(f"raw-vocab: {len(word_vocab)}")
    word_vocab.insert(0, '<unk>')
    word_vocab.insert(1, '<pad>')
    print(f"vocab-length: {len(word_vocab)}")
    word2idx = {word:idx for idx, word in enumerate(word_vocab)}
    print(f"word2idx-length: {len(word2idx)}")
    idx2word = {v:k for k,v in word2idx.items()}
    
    
    return word2idx, idx2word, word_vocab


def build_char_vocab(vocab_text):
    '''
    Builds a character-level vocabulary from the given text.
    
    :param list vocab_text: list of contexts and questions
    :returns 
        dict char2idx: character to index mapping of words
        list char_vocab: list of characters sorted by frequency
    '''
    
    chars = []
    for sent in vocab_text:
        for ch in sent:
            chars.append(ch)

    char_counter = Counter(chars)
    char_vocab = sorted(char_counter, key=char_counter.get, reverse=True)
    print(f"raw-char-vocab: {len(char_vocab)}")
    high_freq_char = [char for char, count in char_counter.items() if count>=20]
    char_vocab = list(set(char_vocab).intersection(set(high_freq_char)))
    print(f"char-vocab-intersect: {len(char_vocab)}")
    char_vocab.insert(0,'<unk>')
    char_vocab.insert(1,'<pad>')
    char2idx = {char:idx for idx, char in enumerate(char_vocab)}
    print(f"char2idx-length: {len(char2idx)}")
    
    return char2idx, char_vocab



def context_to_ids(text, word2idx):
    '''
    Converts context text to their respective ids by mapping each word
    using word2idx. Input text is tokenized using spacy tokenizer first.
    
    :param str text: context text to be converted
    :param dict word2idx: word to id mapping

    :returns list context_ids: list of mapped ids
    
    :raises assertion error: sanity check
    
    '''
    
    context_tokens = [w.text for w in nlp(text, disable=['parser','tagger','ner'])]
    context_ids = [word2idx[word] for word in context_tokens]
    
    assert len(context_ids) == len(context_tokens)
    return context_ids



    
def question_to_ids(text, word2idx):
    '''
    Converts question text to their respective ids by mapping each word
    using word2idx. Input text is tokenized using spacy tokenizer first.
    
    :param str text: question text to be converted
    :param dict word2idx: word to id mapping
    :returns list context_ids: list of mapped ids
    
    :raises assertion error: sanity check
    
    '''
    
    question_tokens = [w.text for w in nlp(text, disable=['parser','tagger','ner'])]
    question_ids = [word2idx[word] for word in question_tokens]
    
    assert len(question_ids) == len(question_tokens)
    return question_ids
    


    
def test_indices(df, idx2word):
    '''
    Performs the tests mentioned above. This method also gets the start and end of the answers
    with respect to the context_ids for each example.
    
    :param dataframe df: SQUAD df
    :param dict idx2word: inverse mapping of token ids to words
    :returns
        list start_value_error: example idx where the start idx is not found in the start spans
                                of the text
        list end_value_error: example idx where the end idx is not found in the end spans
                              of the text
        list assert_error: examples that fail assertion errors. A majority are due to the above errors
        
    '''

    start_value_error = []
    end_value_error = []
    assert_error = []
    for index, row in df.iterrows():

        answer_tokens = [w.text for w in nlp(row['answer'], disable=['parser','tagger','ner'])]

        start_token = answer_tokens[0]
        end_token = answer_tokens[-1]
        
        context_span  = [(word.idx, word.idx + len(word.text)) 
                         for word in nlp(row['context'], disable=['parser','tagger','ner'])]

        starts, ends = zip(*context_span)

        answer_start, answer_end = row['label']

        try:
            start_idx = starts.index(answer_start)
        except:
            start_value_error.append(index)
        try:
            end_idx  = ends.index(answer_end)
        except:
            end_value_error.append(index)

        try:
            assert idx2word[row['context_ids'][start_idx]] == answer_tokens[0]
            assert idx2word[row['context_ids'][end_idx]] == answer_tokens[-1]
        except:
            assert_error.append(index)


    return start_value_error, end_value_error, assert_error



def get_error_indices(df, idx2word):
    
    start_value_error, end_value_error, assert_error = test_indices(df, idx2word)
    err_idx = start_value_error + end_value_error + assert_error
    err_idx = set(err_idx)
    print(f"Number of error indices: {len(err_idx)}")
    
    return err_idx



def index_answer(row, idx2word):
    '''
    Takes in a row of the dataframe or one training example and
    returns a tuple of start and end positions of answer by calculating 
    spans.
    '''
    
    context_span = [(word.idx, word.idx + len(word.text)) for word in nlp(row.context, disable=['parser','tagger','ner'])]
    starts, ends = zip(*context_span)
    
    answer_start, answer_end = row.label
    start_idx = starts.index(answer_start)
 
    end_idx  = ends.index(answer_end)
    
    ans_toks = [w.text for w in nlp(row.answer,disable=['parser','tagger','ner'])]
    ans_start = ans_toks[0]
    ans_end = ans_toks[-1]
    assert idx2word[row.context_ids[start_idx]] == ans_start
    assert idx2word[row.context_ids[end_idx]] == ans_end
    
    return [start_idx, end_idx]


## DrQA 

Implementation of model proposed in the paper: [Reading Wikipedia to Answer Open-Domain Questions](https://arxiv.org/abs/1704.00051) which is called DrQA by the authors. Specifically, DrQA is an end-to-end system for open domain question answering which involves an information retrieval system as well.  However we only implemented the deep learning model proposed by them.

In [2]:
import pandas as pd
import numpy as np
import torchtext
import torch
from torch import nn
import json, re, unicodedata, string, typing, time
import torch.nn.functional as F
import spacy
from collections import Counter
import pickle
from nltk import word_tokenize
%load_ext autoreload
%autoreload 2

In [3]:
# load dataset json files

train_data = load_json('/kaggle/input/drqna-squad/squad_train.json')
valid_data = load_json('/kaggle/input/drqna-squad/squad_dev.json')

# parse the json structure to return the data as a list of dictionaries

train_list = parse_data(train_data)
valid_list = parse_data(valid_data)

print('Train list len: ',len(train_list))
print('Valid list len: ',len(valid_list))

# converting the lists into dataframes

train_df = pd.DataFrame(train_list)
valid_df = pd.DataFrame(valid_list)

Length of data:  442
Data Keys:  dict_keys(['title', 'paragraphs'])
Title:  University_of_Notre_Dame
Length of data:  48
Data Keys:  dict_keys(['title', 'paragraphs'])
Title:  Super_Bowl_50
Train list len:  87599
Valid list len:  34726


In [4]:
def normalize_spaces(text):
    '''
    Removes extra white spaces from the context.
    '''
    text = re.sub(r'\s', ' ', text)
    return text

train_df.context = train_df.context.apply(normalize_spaces)
valid_df.context = valid_df.context.apply(normalize_spaces)

In [5]:
train_df.head()

Unnamed: 0,id,context,question,label,answer
0,5733be284776f41900661182,"Architecturally, the school has a Catholic cha...",To whom did the Virgin Mary allegedly appear i...,"[515, 541]",Saint Bernadette Soubirous
1,5733be284776f4190066117f,"Architecturally, the school has a Catholic cha...",What is in front of the Notre Dame Main Building?,"[188, 213]",a copper statue of Christ
2,5733be284776f41900661180,"Architecturally, the school has a Catholic cha...",The Basilica of the Sacred heart at Notre Dame...,"[279, 296]",the Main Building
3,5733be284776f41900661181,"Architecturally, the school has a Catholic cha...",What is the Grotto at Notre Dame?,"[381, 420]",a Marian place of prayer and reflection
4,5733be284776f4190066117e,"Architecturally, the school has a Catholic cha...",What sits on top of the Main Building at Notre...,"[92, 126]",a golden statue of the Virgin Mary


In [6]:
# gather text to build vocabularies

%time vocab_text = gather_text_for_vocab([train_df, valid_df])
print("Number of sentences in dataset: ", len(vocab_text))

CPU times: user 447 ms, sys: 19 ms, total: 466 ms
Wall time: 468 ms
Number of sentences in dataset:  118852


In [7]:
# build word vocabulary

%time word2idx, idx2word, word_vocab = build_word_vocab(vocab_text)



raw-vocab: 110483
vocab-length: 110485
word2idx-length: 110485
CPU times: user 9min 23s, sys: 794 ms, total: 9min 24s
Wall time: 9min 24s


In [8]:
# numericalize context and questions for training and validation set

%time train_df['context_ids'] = train_df.context.apply(context_to_ids, word2idx=word2idx)
%time valid_df['context_ids'] = valid_df.context.apply(context_to_ids, word2idx=word2idx)

%time train_df['question_ids'] = train_df.question.apply(question_to_ids,  word2idx=word2idx)
%time valid_df['question_ids'] = valid_df.question.apply(question_to_ids,  word2idx=word2idx)

CPU times: user 17min 14s, sys: 322 ms, total: 17min 14s
Wall time: 17min 14s




CPU times: user 6min 58s, sys: 146 ms, total: 6min 59s
Wall time: 6min 59s




CPU times: user 4min 33s, sys: 84 ms, total: 4min 33s
Wall time: 4min 33s




CPU times: user 1min 47s, sys: 29.1 ms, total: 1min 47s
Wall time: 1min 47s


In [9]:
# get indices with tokenization errors and drop those indices 
train_err = get_error_indices(train_df, idx2word)
train_df.drop(train_err, inplace=True)
valid_err = get_error_indices(valid_df, idx2word)
valid_df.drop(valid_err, inplace=True)



Number of error indices: 923




Number of error indices: 362


In [10]:
# get start and end positions of answers from the context
# this is basically the label for training QA models

train_label_idx = train_df.apply(index_answer, axis=1, idx2word=idx2word)
valid_label_idx = valid_df.apply(index_answer, axis=1, idx2word=idx2word)

train_df['label_idx'] = train_label_idx
valid_df['label_idx'] = valid_label_idx



### Dump data to pickle files 
This ensures that we can directly access the preprocessed dataframe next time.

In [11]:
import pickle
with open('drqastoi.pickle','wb') as handle:
    pickle.dump(word2idx, handle)
    
train_df.to_pickle('drqatrain.pkl')
valid_df.to_pickle('drqavalid.pkl')

### Read data from pickle files

You only need to run the preprocessing once. Some preprocessing functions can take upto 3 mins. Therefore, pickling preprocessed data can save a lot of time.
Once the preprocessed files are saved, you can directly start from here.

In [12]:
train_df = pd.read_pickle('drqatrain.pkl')
valid_df = pd.read_pickle('drqavalid.pkl')
with open('drqastoi.pickle', 'rb') as handle:
    word2idx = pickle.load(handle)

## Dataset/ Dataloader

In [13]:
class SquadDataset:
    '''
    -Divides the dataframe in batches.
    -Pads the contexts and questions dynamically for each batch by padding 
     the examples to the maximum-length sequence in that batch.
    -Calculates masks for context and question.
    -Calculates spans for contexts.
    '''
    
    def __init__(self, data, batch_size):
        
        self.batch_size = batch_size
        data = [data[i:i+self.batch_size] for i in range(0, len(data), self.batch_size)]
        self.data = data
    
    def get_span(self, text):
        
        text = nlp(text, disable=['parser','tagger','ner'])
        span = [(w.idx, w.idx+len(w.text)) for w in text]

        return span

    def __len__(self):
        return len(self.data)
    
    def __iter__(self):
        '''
        Creates batches of data and yields them.
        
        Each yield comprises of:
        :padded_context: padded tensor of contexts for each batch 
        :padded_question: padded tensor of questions for each batch 
        :context_mask & question_mask: zero-mask for question and context
        :label: start and end index wrt context_ids
        :context_text,answer_text: used while validation to calculate metrics
        :context_spans: spans of context text
        :ids: question_ids used in evaluation
        '''
        
        for batch in self.data:
                            
            spans = []
            context_text = []
            answer_text = []
            
            max_context_len = max([len(ctx) for ctx in batch.context_ids])
            padded_context = torch.LongTensor(len(batch), max_context_len).fill_(1)
            
            for ctx in batch.context:
                context_text.append(ctx)
                spans.append(self.get_span(ctx))
            
            for ans in batch.answer:
                answer_text.append(ans)
                
            for i, ctx in enumerate(batch.context_ids):
                padded_context[i, :len(ctx)] = torch.LongTensor(ctx)
            
            max_question_len = max([len(ques) for ques in batch.question_ids])
            padded_question = torch.LongTensor(len(batch), max_question_len).fill_(1)
            
            for i, ques in enumerate(batch.question_ids):
                padded_question[i,: len(ques)] = torch.LongTensor(ques)
                
            
            label = torch.LongTensor(list(batch.label_idx))
            context_mask = torch.eq(padded_context, 1)
            question_mask = torch.eq(padded_question, 1)
            
            ids = list(batch.id)  
            
            yield (padded_context, padded_question, context_mask, 
                   question_mask, label, context_text, answer_text, ids)
            
            

In [14]:
train_dataset = SquadDataset(train_df, 32)

In [15]:
valid_dataset = SquadDataset(valid_df, 32)

In [16]:
a = next(iter(train_dataset))



In [17]:
a[0].shape, a[1].shape, a[2].shape, a[3].shape, a[4].shape

(torch.Size([32, 253]),
 torch.Size([32, 19]),
 torch.Size([32, 253]),
 torch.Size([32, 19]),
 torch.Size([32, 2]))

## Model

An input example during training is comprised of 
* a paragraph / context $p$ consisting of $l$ tokens { $p_{1}$, $p_{2}$,..., $p_{l}$ }
* a question $q$ consisting of $m$ tokens { $q_{1}$, $q_{2}$,..., $q_{m}$ }
* a start and and end position that comes from the context itself. More specifically, the start and end indices of the answer from the context  

### Word Embedding
Pretrained Gove embeddings are used to train the model



In [18]:
def create_glove_matrix():
    '''
    Parses the glove word vectors text file and returns a dictionary with the words as
    keys and their respective pretrained word vectors as values.

    '''
    glove_dict = {}
    with open("/kaggle/input/glove840b300dtxt/glove.840B.300d.txt", "r", encoding="utf-8") as f:
        for line in f:
            values = line.split(' ')
            word = values[0]
            vector = np.asarray(values[1:], dtype="float32")
            glove_dict[word] = vector

    f.close()
    
    return glove_dict

In [19]:
glove_dict = create_glove_matrix()

In [20]:
def create_word_embedding(glove_dict):
    '''
    Creates a weight matrix of the words that are common in the GloVe vocab and
    the dataset's vocab. Initializes OOV words with a zero vector.
    '''
    weights_matrix = np.zeros((len(word_vocab), 300))
    words_found = 0
    for i, word in enumerate(word_vocab):
        try:
            weights_matrix[i] = glove_dict[word]
            words_found += 1
        except:
            pass
    return weights_matrix, words_found

In [21]:
weights_matrix, words_found = create_word_embedding(glove_dict)

In [22]:
print("Total words found in glove vocab: ", words_found)

Total words found in glove vocab:  91272


In [23]:
np.save('drqaglove_vt.npy',weights_matrix)

### Align Question Embedding

The paper has different encoding procedures for the context and the question. The context/paragraph comprises of following additional features:

* exact match : encodes a binary feature if $p$ can be exactly matched to one word in question in its original, lemma or lowercase form
* token features : Includes POS, NER and TF of context tokens and
* aligned question embedding ($f_{align}$) .  

In this implementation we have only implemented the aligned question embedding.
$f_{align}$ has been formulated as shown below:

$$ f_{align} = \sum_{j}a_{i,j}E(q_{j}) $$ 

where $E()$ represents the glove embeddings and

where $\alpha()$ is a single dense layer with relu non-linearity. 

#### Intuition
This feature enables the model to understand what portion of the context is more important or relevant with respect to the question. The products of projections taken at token level ensure a higher value when similar words from the question and context are multiplied. Quoting the paper,
> *these features add soft alignments between similar but non-identical words (e.g., car and vehicle).* 

This is achieved via backpropation and training the weights of the dense layer. While this might seem a bit weird initially, we have to trust the process of backpropogation.   

While implementing, we first calculate the projections of context and question vectors. We then use `torch.bmm` to calculate the product in the numerator of $a_{i,j}$, mask the product and then pass it through the softmax function to get $a_{i,j}$. Finally, we multiply this with the question embeddings. The output of this layer is an additional context embedding which is then concatenated with the glove embeddings.

In [24]:
class AlignQuestionEmbedding(nn.Module):
    
    def __init__(self, input_dim):        
        
        super().__init__()
        
        self.linear = nn.Linear(input_dim, input_dim)
        
        self.relu = nn.ReLU()
        
    def forward(self, context, question, question_mask):
        
        # context = [bs, ctx_len, emb_dim]
        # question = [bs, qtn_len, emb_dim]
        # question_mask = [bs, qtn_len]
    
        ctx_ = self.linear(context)
        ctx_ = self.relu(ctx_)
        # ctx_ = [bs, ctx_len, emb_dim]
        
        qtn_ = self.linear(question)
        qtn_ = self.relu(qtn_)
        # qtn_ = [bs, qtn_len, emb_dim]
        
        qtn_transpose = qtn_.permute(0,2,1)
        # qtn_transpose = [bs, emb_dim, qtn_len]
        
        align_scores = torch.bmm(ctx_, qtn_transpose)
        # align_scores = [bs, ctx_len, qtn_len]
        
        qtn_mask = question_mask.unsqueeze(1).expand(align_scores.size())
        # qtn_mask = [bs, 1, qtn_len] => [bs, ctx_len, qtn_len]
        
        # Fills elements of self tensor(align_scores) with value(-float(inf)) where mask is True. 
        # The shape of mask must be broadcastable with the shape of the underlying tensor.
        align_scores = align_scores.masked_fill(qtn_mask == 1, -float('inf'))
        # align_scores = [bs, ctx_len, qtn_len]
        
        align_scores_flat = align_scores.view(-1, question.size(1))
        # align_scores = [bs*ctx_len, qtn_len]
        
        alpha = F.softmax(align_scores_flat, dim=1)
        alpha = alpha.view(-1, context.shape[1], question.shape[1])
        # alpha = [bs, ctx_len, qtn_len]
        
        align_embedding = torch.bmm(alpha, question)
        # align = [bs, ctx_len, emb_dim]
        
        return align_embedding

## Stacked BiLSTM

The paragraph/context encoding which now has two features (glove and $f_{align}$) is then passed to a multilayer (3 layers) bidirectional LSTM. According to the paper,

> *Speciﬁcally, we choose to use a multi-layer bidirectional long short-term memory network (LSTM), and take the concatenation of each layer’s hidden units in the end. *

To achieve this functionality we cannot directly use the pytorch recurrent layers. Every recurrent layer in pytorch returns a tuple `[output, hidden]` where `output` holds the hidden states of all the timesteps from the __last layer only__. We need to access the hidden states of intermediate layers and then concatenate them at the end.
The following figure illustrates this point in more detail.

<!-- <img src="images/bilstm.png" width="700" height="600"/> -->

This figure shows a 3-layer bidirectional LSTM with an input sequence of size $n$. The green blocks denote the forward LSTMs and the blue blocks backward. Each block is labelled with the value that it calculates. The subscript denotes the time-step and the superscript denotes the depth or the layer-number. For example $hf_{1}^{(0)}$ calculates the first hidden state in forward LSTM in the first layer. 

As highlighted in the diagram, we need the intermediate hidden states passed between the layers along with the final output. To create this in code, we create a `nn.ModuleList` and add 3 LSTM layers to it. The input size of the first layer remains the same but for subsequent LSTMs the input size must be twice the hidden size. This is because the `output` of the first LSTM will have the dimension of `[batch_size, seq_len, hidden_size*num_directions]` and `num_directions` is 2 in our case. In the forward method, we loop through the LSTMs, store the hidden states of each layer and finally return the concatenated output. 



In [25]:
class StackedBiLSTM(nn.Module):
    
    def __init__(self, input_dim, hidden_dim, num_layers, dropout):
        
        super().__init__()
        
        self.dropout = dropout
        
        self.num_layers = num_layers
        
        self.lstms = nn.ModuleList()
        
        for i in range(self.num_layers):
            
            input_dim = input_dim if i == 0 else hidden_dim * 2
            
            self.lstms.append(nn.LSTM(input_dim, hidden_dim,
                                      batch_first=True, bidirectional=True))
           
    
    def forward(self, x):
        # x = [bs, seq_len, feature_dim]

        outputs = [x]
        for i in range(self.num_layers):

            lstm_input = outputs[-1]
            lstm_out = F.dropout(lstm_input, p=self.dropout)
            lstm_out, (hidden, cell) = self.lstms[i](lstm_input)
           
            outputs.append(lstm_out)

    
        output = torch.cat(outputs[1:], dim=2)
        # [bs, seq_len, num_layers*num_dir*hidden_dim]
        
        output = F.dropout(output, p=self.dropout)
      
        return output

## Linear Attention Layer
This layer is used to calculate the importance of each word in the question. This can be achieved by simply taking a softmax over the input. However to add more learning capacity to the model, the inputs are multiplied by a trainable weight vector $w$ and then passed through a softmax function.  

Essentially the layer is performing "attention" on inputs. The $w$ in code is characterized by a linear layer.

In [26]:
class LinearAttentionLayer(nn.Module):
    
    def __init__(self, input_dim):
        
        super().__init__()
        
        self.linear = nn.Linear(input_dim, 1)
        
    def forward(self, question, question_mask):
        
        # question = [bs, qtn_len, input_dim] = [bs, qtn_len, bi_lstm_hid_dim]
        # question_mask = [bs,  qtn_len]
        
        qtn = question.view(-1, question.shape[-1])
        # qtn = [bs*qtn_len, hid_dim]
        
        attn_scores = self.linear(qtn)
        # attn_scores = [bs*qtn_len, 1]
        
        attn_scores = attn_scores.view(question.shape[0], question.shape[1])
        # attn_scores = [bs, qtn_len]
        
        attn_scores = attn_scores.masked_fill(question_mask == 1, -float('inf'))
        
        alpha = F.softmax(attn_scores, dim=1)
        # alpha = [bs, qtn_len]
        
        return alpha
        

The following function just multiplies the weights calculated in the previous layer by the outputs of the question bilstm layer. This allows the model to assign higher values to important words in each question.

$$ q = \sum_{j} b_{j} q_{j} $$

In [27]:
def weighted_average(x, weights):
    # x = [bs, len, dim]
    # weights = [bs, len]
    
    weights = weights.unsqueeze(1)
    # weights = [bs, 1, len]
    
    w = weights.bmm(x).squeeze(1)
    # w = [bs, 1, dim] => [bs, dim]
    
    return w

## Bilinear Attention in NLP

$$ e_{t} = s^{T} W h_{t}$$
where $W$ is a trainable weight vector.
This is the method used in this paper to predict the start and end position of the answer from the context.    


To implement this layer, we characterise $W$ by a linear layer.
First the linear layer is applied to the question, which is equivalent to the product $W.q$. This product is then multiplied by the context using `torch.bmm`.   
Note that softmax is not taken over here to get the weights. This is taken care of when we calculate the loss using cross entropy. The following layer does not actually calculate the attention as a weighted sum. It just uses the bilinear term's representation to predict the span..

In [28]:
class BilinearAttentionLayer(nn.Module):
    
    def __init__(self, context_dim, question_dim):
        
        super().__init__()
        
        self.linear = nn.Linear(question_dim, context_dim)
        
    def forward(self, context, question, context_mask):
        
        # context = [bs, ctx_len, ctx_hid_dim] = [bs, ctx_len, hid_dim*6] = [bs, ctx_len, 768]
        # question = [bs, qtn_hid_dim] = [bs, qtn_len, 768]
        # context_mask = [bs, ctx_len]
        
        qtn_proj = self.linear(question)
        # qtn_proj = [bs, ctx_hid_dim]
        
        qtn_proj = qtn_proj.unsqueeze(2)
        # qtn_proj = [bs, ctx_hid_dim, 1]
        
        scores = context.bmm(qtn_proj)
        # scores = [bs, ctx_len, 1]
        
        scores = scores.squeeze(2)
        # scores = [bs, ctx_len]
        
        scores = scores.masked_fill(context_mask == 1, -float('inf'))
        
        #alpha = F.log_softmax(scores, dim=1)
        # alpha = [bs, ctx_len]
        
        return scores

## Putting it together

The following module brings all the components discussed so far together. It takes in the context and question tokens as inputs and returns the start and end positions of the answer from the context.  

* Aligned question embedding is calculated for the context vector and concatenated (using `torch.cat`) to the context representation. If $d$ is the embedding dimension then context $\epsilon$ $R^{2d}$ and question $\epsilon$ $R^{d}$.
* Context and question representations are then passed to bilstm layers to obtain tensors of dimension `[batch_size, seq_len, hidden_dim*6]` since the LSTM is bidirectional and there are 3 layers of it.
* The embedded question is also passed through the linear attention layer and a weighted sum of its output is taken with the biLSTM output.
* Both these representations are finally passed through the bilinear attention layer to predict the start and end position of the answer.   

In [29]:
class DocumentReader(nn.Module):
    
    def __init__(self, hidden_dim, embedding_dim, num_layers, num_directions, dropout, device):
        
        super().__init__()
        
        self.device = device
        
        #self.embedding = self.get_glove_embedding()
        
        self.context_bilstm = StackedBiLSTM(embedding_dim * 2, hidden_dim, num_layers, dropout)
        
        self.question_bilstm = StackedBiLSTM(embedding_dim, hidden_dim, num_layers, dropout)
        
        self.glove_embedding = self.get_glove_embedding()
        
        def tune_embedding(grad, words=1000):
            grad[words:] = 0
            return grad
        
        self.glove_embedding.weight.register_hook(tune_embedding)
        
        self.align_embedding = AlignQuestionEmbedding(embedding_dim)
        
        self.linear_attn_question = LinearAttentionLayer(hidden_dim*num_layers*num_directions) 
        
        self.bilinear_attn_start = BilinearAttentionLayer(hidden_dim*num_layers*num_directions, 
                                                          hidden_dim*num_layers*num_directions)
        
        self.bilinear_attn_end = BilinearAttentionLayer(hidden_dim*num_layers*num_directions,
                                                        hidden_dim*num_layers*num_directions)
        
        self.dropout = nn.Dropout(dropout)
   
        
    def get_glove_embedding(self):
        
        weights_matrix = np.load('drqaglove_vt.npy')
        num_embeddings, embedding_dim = weights_matrix.shape
        embedding = nn.Embedding.from_pretrained(torch.FloatTensor(weights_matrix).to(self.device),freeze=False)

        return embedding
    
    
    def forward(self, context, question, context_mask, question_mask):
       
        # context = [bs, len_c]
        # question = [bs, len_q]
        # context_mask = [bs, len_c]
        # question_mask = [bs, len_q]
        
        
        ctx_embed = self.glove_embedding(context)
        # ctx_embed = [bs, len_c, emb_dim]
        
        ques_embed = self.glove_embedding(question)
        # ques_embed = [bs, len_q, emb_dim]
        

        ctx_embed = self.dropout(ctx_embed)
     
        ques_embed = self.dropout(ques_embed)
             
        align_embed = self.align_embedding(ctx_embed, ques_embed, question_mask)
        # align_embed = [bs, len_c, emb_dim]  
        
        ctx_bilstm_input = torch.cat([ctx_embed, align_embed], dim=2)
        # ctx_bilstm_input = [bs, len_c, emb_dim*2]
                
        ctx_outputs = self.context_bilstm(ctx_bilstm_input)
        # ctx_outputs = [bs, len_c, hid_dim*layers*dir] = [bs, len_c, hid_dim*6]
       
        qtn_outputs = self.question_bilstm(ques_embed)
        # qtn_outputs = [bs, len_q, hid_dim*6]
    
        qtn_weights = self.linear_attn_question(qtn_outputs, question_mask)
        # qtn_weights = [bs, len_q]
            
        qtn_weighted = weighted_average(qtn_outputs, qtn_weights)
        # qtn_weighted = [bs, hid_dim*6]
        
        start_scores = self.bilinear_attn_start(ctx_outputs, qtn_weighted, context_mask)
        # start_scores = [bs, len_c]
         
        end_scores = self.bilinear_attn_end(ctx_outputs, qtn_weighted, context_mask)
        # end_scores = [bs, len_c]
        
      
        return start_scores, end_scores

###  Hyperparameters

> *We use 3-layer bidirectional LSTMs with h = 128 hidden units for both paragraph and question encoding. Dropout with p = 0.3 is applied to word embeddings and all the hidden units of LSTMs. *

In [30]:
device = torch.device('cuda')
HIDDEN_DIM = 128
EMB_DIM = 300
NUM_LAYERS = 3
NUM_DIRECTIONS = 2
DROPOUT = 0.3
device = torch.device('cuda')

model = DocumentReader(HIDDEN_DIM,
                       EMB_DIM, 
                       NUM_LAYERS, 
                       NUM_DIRECTIONS, 
                       DROPOUT, 
                       device).to(device)

## Training

In [31]:
optimizer = torch.optim.Adamax(model.parameters())

In [32]:
def count_parameters(model):
    '''Returns the number of trainable parameters in the model.'''
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f'The model has {count_parameters(model):,} trainable parameters')

The model has 37,186,649 trainable parameters


In [33]:
def train(model, train_dataset):
    '''
    Trains the model.
    '''
    
    print("Starting training ........")
    
    train_loss = 0.
    batch_count = 0
    
    # put the model in training mode
    model.train()
    
    # iterate through training data
    for batch in train_dataset:

        if batch_count % 500 == 0:
            print(f"Starting batch: {batch_count}")
        batch_count += 1

        context, question, context_mask, question_mask, label, ctx, ans, ids = batch
        
        # place the tensors on GPU
        context, context_mask, question, question_mask, label = context.to(device), context_mask.to(device),\
                                    question.to(device), question_mask.to(device), label.to(device)
        
        # forward pass, get the predictions
        preds = model(context, question, context_mask, question_mask)

        start_pred, end_pred = preds
        
        # separate labels for start and end position
        start_label, end_label = label[:,0], label[:,1]
        
        # calculate loss
        loss = F.cross_entropy(start_pred, start_label) + F.cross_entropy(end_pred, end_label)
        
        # backward pass, calculates the gradients
        loss.backward()
        
        # gradient clipping
        torch.nn.utils.clip_grad_norm_(model.parameters(), 10)
        
        # update the gradients
        optimizer.step()
        
        # zero the gradients to prevent them from accumulating
        optimizer.zero_grad()

        train_loss += loss.item()

    return train_loss/len(train_dataset)

In [34]:
def valid(model, valid_dataset):
    '''
    Performs validation.
    '''
    
    print("Starting validation .........")
   
    valid_loss = 0.

    batch_count = 0
    
    f1, em = 0., 0.
    
    # puts the model in eval mode. Turns off dropout
    model.eval()
    
    predictions = {}
    
    for batch in valid_dataset:

        if batch_count % 500 == 0:
            print(f"Starting batch {batch_count}")
        batch_count += 1

        context, question, context_mask, question_mask, label, context_text, answers, ids = batch

        context, context_mask, question, question_mask, label = context.to(device), context_mask.to(device),\
                                    question.to(device), question_mask.to(device), label.to(device)

        with torch.no_grad():

            preds = model(context, question, context_mask, question_mask)

            p1, p2 = preds

            y1, y2 = label[:,0], label[:,1]

            loss = F.cross_entropy(p1, y1) + F.cross_entropy(p2, y2)

            valid_loss += loss.item()

            
            # get the start and end index positions from the model preds
            
            batch_size, c_len = p1.size()
            ls = nn.LogSoftmax(dim=1)
            mask = (torch.ones(c_len, c_len) * float('-inf')).to(device).tril(-1).unsqueeze(0).expand(batch_size, -1, -1)
            
            score = (ls(p1).unsqueeze(2) + ls(p2).unsqueeze(1)) + mask
            score, s_idx = score.max(dim=1)
            score, e_idx = score.max(dim=1)
            s_idx = torch.gather(s_idx, 1, e_idx.view(-1, 1)).squeeze()
            
            # stack predictions
            for i in range(batch_size):
                id = ids[i]
                pred = context[i][s_idx[i]:e_idx[i]+1]
                pred = ' '.join([idx2word[idx.item()] for idx in pred])
                predictions[id] = pred
            
            
            
    em, f1 = evaluate(predictions)            
    return valid_loss/len(valid_dataset), em, f1
                

In [35]:
def evaluate(predictions):
    '''
    Gets a dictionary of predictions with question_id as key
    and prediction as value. The validation dataset has multiple 
    answers for a single question. Hence we compare our prediction
    with all the answers and choose the one that gives us
    the maximum metric (em or f1). 
    This method first parses the JSON file, gets all the answers
    for a given id and then passes the list of answers and the 
    predictions to calculate em, f1.
    
    
    :param dict predictions
    Returns
    : exact_match: 1 if the prediction and ground truth 
      match exactly, 0 otherwise.
    : f1_score: 
    '''
    with open('/kaggle/input/drqna-squad/squad_dev.json','r',encoding='utf-8') as f:
        dataset = json.load(f)
        
    dataset = dataset['data']
    f1 = exact_match = total = 0
    for article in dataset:
        for paragraph in article['paragraphs']:
            for qa in paragraph['qas']:
                total += 1
                if qa['id'] not in predictions:
                    continue
                
                ground_truths = list(map(lambda x: x['text'], qa['answers']))
                
                prediction = predictions[qa['id']]
                
                exact_match += metric_max_over_ground_truths(
                    exact_match_score, prediction, ground_truths)
                
                f1 += metric_max_over_ground_truths(
                    f1_score, prediction, ground_truths)
                
    
    exact_match = 100.0 * exact_match / total
    f1 = 100.0 * f1 / total
    
    return exact_match, f1



In [36]:
def normalize_answer(s):
    '''
    Performs a series of cleaning steps on the ground truth and 
    predicted answer.
    '''
    def remove_articles(text):
        return re.sub(r'\b(a|an|the)\b', ' ', text)

    def white_space_fix(text):
        return ' '.join(text.split())

    def remove_punc(text):
        exclude = set(string.punctuation)
        return ''.join(ch for ch in text if ch not in exclude)

    def lower(text):
        return text.lower()

    return white_space_fix(remove_articles(remove_punc(lower(s))))


def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
    '''
    Returns maximum value of metrics for predicition by model against
    multiple ground truths.
    
    :param func metric_fn: can be 'exact_match_score' or 'f1_score'
    :param str prediction: predicted answer span by the model
    :param list ground_truths: list of ground truths against which
                               metrics are calculated. Maximum values of 
                               metrics are chosen.
                            
    
    '''
    scores_for_ground_truths = []
    for ground_truth in ground_truths:
        score = metric_fn(prediction, ground_truth)
        scores_for_ground_truths.append(score)
        
    return max(scores_for_ground_truths)


def f1_score(prediction, ground_truth):
    '''
    Returns f1 score of two strings.
    '''
    prediction_tokens = normalize_answer(prediction).split()
    ground_truth_tokens = normalize_answer(ground_truth).split()
    common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
    num_same = sum(common.values())
    if num_same == 0:
        return 0
    precision = 1.0 * num_same / len(prediction_tokens)
    recall = 1.0 * num_same / len(ground_truth_tokens)
    f1 = (2 * precision * recall) / (precision + recall)
    return f1


def exact_match_score(prediction, ground_truth):
    '''
    Returns exact_match_score of two strings.
    '''
    return (normalize_answer(prediction) == normalize_answer(ground_truth))


def epoch_time(start_time, end_time):
    '''
    Helper function to record epoch time.
    '''
    elapsed_time = end_time - start_time
    elapsed_mins = int(elapsed_time / 60)
    elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
    return elapsed_mins, elapsed_secs

In [37]:

train_losses = []
valid_losses = []
ems = []
f1s = []
epochs = 5

for epoch in range(epochs):
    print(f"Epoch {epoch+1}")
    
    start_time = time.time()
    
    train_loss = train(model, train_dataset)
    torch.save(model.state_dict(), 'DrQA_model')

    valid_loss, em, f1 = valid(model, valid_dataset)
    
    end_time = time.time()
    
    epoch_mins, epoch_secs = epoch_time(start_time, end_time)
    
    train_losses.append(train_loss)
    valid_losses.append(valid_loss)
    ems.append(em)
    f1s.append(f1)
    
    print(f"Epoch train loss : {train_loss}| Time: {epoch_mins}m {epoch_secs}s")
    print(f"Epoch valid loss: {valid_loss}")
    print(f"Epoch EM: {em}")
    print(f"Epoch F1: {f1}")
    print("====================================================================================")
    

Epoch 1
Starting training ........




Starting batch: 0
Starting batch: 500
Starting batch: 1000
Starting batch: 1500
Starting batch: 2000
Starting batch: 2500
Starting validation .........
Starting batch 0
Starting batch 500
Starting batch 1000
Epoch train loss : 4.89303082991192| Time: 26m 18s
Epoch valid loss: 4.006683577816357
Epoch EM: 46.72658467360454
Epoch F1: 59.293232057018805
Epoch 2
Starting training ........
Starting batch: 0
Starting batch: 500
Starting batch: 1000
Starting batch: 1500
Starting batch: 2000
Starting batch: 2500
Starting validation .........
Starting batch 0
Starting batch 500
Starting batch 1000
Epoch train loss : 3.6211951912967604| Time: 25m 54s
Epoch valid loss: 3.6315393029977487
Epoch EM: 52.70577105014191
Epoch F1: 64.62190699698601
Epoch 3
Starting training ........
Starting batch: 0
Starting batch: 500
Starting batch: 1000
Starting batch: 1500
Starting batch: 2000
Starting batch: 2500
Starting validation .........
Starting batch 0
Starting batch 500
Starting batch 1000
Epoch train loss

## References

* Papers read/referenced
    1. https://arxiv.org/abs/1704.00051
    2. https://arxiv.org/abs/1606.02858
    3. https://arxiv.org/abs/1409.0473
* Other helpful links
    1. https://lilianweng.github.io/lil-log/2018/06/24/attention-attention.html
    2. https://github.com/facebookresearch/DrQA
    3. https://github.com/hitvoice/DrQA. Special thanks to [Runqi Yang](https://github.com/hitvoice) who helped me clarify some doubts with respect to preprocessing the SQUAD dataset.
    4. https://towardsdatascience.com/the-definitive-guide-to-bidaf-part-3-attention-92352bbdcb07. Good introduction to attention.
    5. https://web.stanford.edu/class/archive/cs/cs224n/cs224n.1184/lectures/lecture10.pdf. The attention section of this notebook is largely inspired and derived from these slides.
* Following links are related to debugging neural nets. Something on which I was stuck for quite some time during training these models.
    1. https://datascience.stackexchange.com/questions/410/choosing-a-learning-rate
    2. https://www.jeremyjordan.me/nn-learning-rate/
    3. https://towardsdatascience.com/estimating-optimal-learning-rate-for-a-deep-neural-network-ce32f2556ce0
    4. https://towardsdatascience.com/learning-rate-schedules-and-adaptive-learning-rate-methods-for-deep-learning-2c8f433990d1
    5. https://towardsdatascience.com/checklist-for-debugging-neural-networks-d8b2a9434f21
    6. https://arxiv.org/abs/1708.07120
    7. https://sgugger.github.io/how-do-you-find-a-good-learning-rate.html