# SQuAD Q&A

This notebook contains training scripts for models that tackle the question answering problem on the [SQuAD](https://rajpurkar.github.io/SQuAD-explorer/) v1.1 dataset. The task consists of selecting the answer to a given question as a span of words in the given context paragraph. The newest version of the dataset (v2.0) also contains unanswerable questions, but the version we worked on (v1.1) does not.

## Colab requirements

Before restarting the runtime (remember to select a GPU runtime), run the following cells.

In [None]:
# Install the base requirements shipped with the squad-question-answering repository.
# NOTE(review): the cell output below shows a `git clone` of squad-question-answering,
# but the clone command itself is missing from this cell. The repository must be cloned
# before this install can find the requirements file — restore the clone command.
!pip install -r squad-question-answering/init/base_requirements.txt

Cloning into 'squad-question-answering'...
remote: Enumerating objects: 459, done.[K
remote: Total 459 (delta 0), reused 0 (delta 0), pack-reused 459[K
Receiving objects: 100% (459/459), 24.36 MiB | 18.31 MiB/s, done.
Resolving deltas: 100% (270/270), done.


In [None]:
import os, sys
sys.path.insert(0, "squad-question-answering")
os.chdir("squad-question-answering")
sys.path.insert(0, "src")
import os
import json
from functools import partial
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import wandb
import transformers
import collections
import training
import utils
import layer
import layer_utils
import config
import torch
from transformers.trainer_utils import set_seed
from torch.utils.data import Dataset
from operator import attrgetter
from tokenizers import BertWordPieceTokenizer, Tokenizer
from tokenizers.implementations import BaseTokenizer
from tokenizers.models import WordLevel
from tokenizers.normalizers import Lowercase, Sequence, Strip, StripAccents
from tokenizers.pre_tokenizers import Punctuation
from tokenizers.pre_tokenizers import Sequence as PreSequence
from tokenizers.pre_tokenizers import Whitespace

# Reload the project modules imported above (training, utils, layer, ...) automatically
# when their source files change, without restarting the kernel.
%load_ext autoreload
%autoreload 2

# Check the current configuration variables
# NOTE(review): this list is only displayed when it is the last expression of its
# notebook cell; in the middle of a cell (or in a plain script) it has no effect.
[(item, getattr(config, item)) for item in dir(config) if not item.startswith("__")]

# Since we are using wandb, we have to initialize a few environment variables so that
# we can use them throughout the notebook
%env WANDB_PROJECT=squad-qa
%env WANDB_ENTITY=wadaboa
%env WANDB_MODE=online
%env WANDB_RESUME=never
%env WANDB_WATCH=false
%env WANDB_SILENT=true

# Authenticate with Weights & Biases
!wandb login
# Seed the random number generators (transformers' set_seed) for reproducibility
set_seed(config.RANDOM_SEED)
# NOTE(review): presumably returns the CUDA device when available, else CPU — confirm in utils
DEVICE = utils.get_device()

class SquadTorchDataset(Dataset):
    """Map-style torch ``Dataset`` over a SQuAD DataFrame.

    Each item is the tuple ``(index, question, context)``. When the DataFrame
    has an ``answer`` column, ``answer_start`` and ``answer_end`` are appended.
    """

    def __init__(self, df):
        # Work on a private copy with a fresh 0..n-1 index, so that the integer
        # positions produced by the sampler are valid row labels.
        self.df = df.copy().reset_index(drop=True)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        assert isinstance(index, int)
        frame = self.df
        item = (index, frame.at[index, "question"], frame.at[index, "context"])
        if "answer" in frame.columns:
            item += (frame.at[index, "answer_start"], frame.at[index, "answer_end"])
        return item
    
class SquadDataset:
    """Raw SQuAD training/test data loaded into pandas DataFrames.

    Labelled JSON files are flattened to one row per answer (exact duplicate
    rows are dropped). Unlabelled files are flattened to one row per question.
    The flattened DataFrame is cached next to the JSON file as a ``.pkl`` and
    re-used on later runs.

    Args:
        train_set_path: path of the training JSON file, or ``None`` to skip it.
        test_set_path: path of the test JSON file; a falsy value skips it.
        subset: when < 1.0, keep only this fraction of the rows of each loaded
            split, sampled with ``np.random`` without replacement.

    Attributes:
        raw_train_df: training DataFrame, or ``None``.
        raw_test_df: test DataFrame, or ``None``.
        test_has_labels: whether the test DataFrame has an ``answer`` column
            (``False`` when no test set was loaded).
    """

    JSON_RECORD_PATH = ["data", "paragraphs", "qas", "answers"]

    def __init__(
        self, train_set_path=None, test_set_path=None, subset=1.0,
    ):
        # Save training and test set paths
        self.train_set_path = train_set_path
        self.test_set_path = test_set_path

        # Process the training set
        self.raw_train_df = None
        if self.train_set_path is not None:
            assert os.path.exists(
                self.train_set_path
            ), "Missing SQuAD training set .json file"
            self.train_df_path = f"{os.path.splitext(self.train_set_path)[0]}.pkl"
            self.raw_train_df = self._load_dataset(
                self.train_set_path, self.train_df_path
            )
            if subset < 1.0:
                self.raw_train_df = self._get_portion(self.raw_train_df, subset)

        # Process the test set. test_has_labels always exists so that callers
        # can query it even when no test set was given.
        self.raw_test_df = None
        self.test_has_labels = False
        if self.test_set_path:
            assert os.path.exists(
                self.test_set_path
            ), "Missing SQuAD testing set .json file"
            self.test_df_path = f"{os.path.splitext(self.test_set_path)[0]}.pkl"
            self.raw_test_df = self._load_dataset(self.test_set_path, self.test_df_path)
            self.test_has_labels = "answer" in self.raw_test_df.columns
            if subset < 1.0:
                self.raw_test_df = self._get_portion(self.raw_test_df, subset)

    def _add_end_index(self, df):
        """Add the ``answer_end`` column: start offset + answer length (exclusive end)."""
        df["answer_end"] = df["answer_start"] + df["answer"].str.len()
        return df

    def _load_dataset(self, dataset_path, dataframe_path):
        """Return the flattened DataFrame, using the pickle cache when readable."""
        if os.path.exists(dataframe_path):
            try:
                return pd.read_pickle(dataframe_path)
            except ValueError:
                # Unreadable cache: fall through and rebuild it from the JSON
                pass

        # JSON is UTF-8 by specification; do not depend on the platform locale
        with open(dataset_path, encoding="utf-8") as json_fp:
            json_file = json.load(json_fp)

        # The file is considered labelled if its first question has answers
        first_answers = pd.json_normalize(json_file, self.JSON_RECORD_PATH[:-1]).loc[
            0, "answers"
        ]
        if len(first_answers) > 0:
            df = self._load_dataset_with_labels(json_file)
        else:
            df = self._load_dataset_no_labels(json_file)
        df = df.reset_index(drop=True)
        df.to_pickle(dataframe_path)
        return df

    def _load_dataset_with_labels(self, json_file):
        """Flatten a labelled SQuAD JSON into one row per answer."""
        df = pd.json_normalize(
            json_file, self.JSON_RECORD_PATH, meta=[["data", "title"]])
        df_questions = pd.json_normalize(
            json_file, self.JSON_RECORD_PATH[:-1], meta=[["data", "title"]])
        df_contexts = pd.json_normalize(
            json_file, self.JSON_RECORD_PATH[:-2], meta=[["data", "title"]])
        # Broadcast each context to its questions, then each question to its answers
        contexts = np.repeat(df_contexts["context"].values, df_contexts.qas.str.len())
        answers_per_question = df_questions["answers"].str.len()
        contexts = np.repeat(contexts, answers_per_question)
        df["context"] = contexts
        df["question_id"] = np.repeat(df_questions["id"].values, answers_per_question)
        df["question"] = np.repeat(
            df_questions["question"].values, answers_per_question)
        df["context_id"] = df["context"].factorize()[0]
        df.rename(columns={"data.title": "title", "text": "answer"}, inplace=True)
        df = self._add_end_index(df)
        df = df.drop_duplicates()
        return df

    def _load_dataset_no_labels(self, json_file):
        """Flatten an unlabelled SQuAD JSON into one row per question."""
        df_questions = pd.json_normalize(
            json_file, self.JSON_RECORD_PATH[:-1], meta=[["data", "title"]])
        df_contexts = pd.json_normalize(
            json_file, self.JSON_RECORD_PATH[:-2], meta=[["data", "title"]])
        df_questions["context"] = np.repeat(
            df_contexts["context"].values, df_contexts.qas.str.len())
        df_questions["context_id"] = df_questions["context"].factorize()[0]

        # Rename columns
        df_questions.rename(
            columns={"data.title": "title", "id": "question_id"}, inplace=True)
        if "answers" in df_questions.columns:
            df_questions = df_questions.drop("answers", axis="columns")

        return df_questions

    def _get_portion(self, df, subset=1.0):
        """Return a random ``subset`` fraction of the rows, re-indexed from 0."""
        amount = int(df.shape[0] * subset)
        random_indexes = np.random.choice(
            np.arange(df.shape[0]), size=amount, replace=False
        )
        return df.iloc[random_indexes].reset_index(drop=True)

# Locate the raw SQuAD JSON files under ./data and load them
DATA_FOLDER = os.path.join(os.getcwd(), "data")
TRAIN_DATA_FOLDER = os.path.join(DATA_FOLDER, "training")
TEST_DATA_FOLDER = os.path.join(DATA_FOLDER, "testing")
TRAIN_SET_PATH = os.path.join(TRAIN_DATA_FOLDER, "training_set.json")
TEST_SET_PATH = os.path.join(TEST_DATA_FOLDER, "test_set.json")

squad_dataset = SquadDataset(
    subset=config.DATA_SUBSET,
    train_set_path=TRAIN_SET_PATH,
    test_set_path=TEST_SET_PATH,
)

# Default trainer arguments (batch size, logging frequency, checkpoint location, ...)
TRAINER_ARGS = utils.get_default_trainer_args()

# Load pre-trained word embeddings (via the Gensim API) and wrap their matrix in an
# embedding layer; the PAD token's vocabulary id is passed as pad_id.
embedding_model, vocab = utils.load_embedding_model(
    config.EMBEDDING_MODEL_NAME,
    embedding_dimension=config.EMBEDDING_DIMENSION,
    unk_token=config.UNK_TOKEN,
    pad_token=config.PAD_TOKEN,
)
embedding_layer = layer_utils.get_embedding_module(
    embedding_model,
    pad_id=vocab[config.PAD_TOKEN],
)
print(embedding_layer)

#Now we will define tokenizers
class SquadTokenizer:
    """Base class wrapping HuggingFace ``tokenizers`` for SQuAD batches.

    Subclasses implement :meth:`select_tokenizer`, which picks the tokenizer
    for an entity. They also implement :meth:`__call__`, which turns a list of
    dataset items into a dictionary of tensors and is used as data collator.
    """

    # Encoding attributes extracted from every tokenizers.Encoding, in this order
    ENCODING_ATTR = [
        "ids",
        "type_ids",
        "tokens",
        "offsets",
        "attention_mask",
        "special_tokens_mask",
        "overflowing",
        "word_ids",
    ]
    ENCODING_ATTR_ID = {k: i for i, k in enumerate(ENCODING_ATTR)}
    ATTRGETTER = attrgetter(*ENCODING_ATTR)

    def __init__(self, device="cpu"):
        self.device = device

    def tokenize(self, inputs, entity=None, special=False):
        """Encode a batch with the tokenizer selected for ``entity``.

        When ``special`` is False, special tokens are not added and padding is
        temporarily disabled. The tokenizer's original padding configuration
        is restored afterwards, even if encoding raises.
        """
        tokenizer = self.select_tokenizer(entity)
        # None when the tokenizer has padding disabled
        tokenizer_padding = tokenizer.padding
        if not special:
            tokenizer.no_padding()
        try:
            return tokenizer.encode_batch(inputs, add_special_tokens=special)
        finally:
            if tokenizer_padding is not None:
                tokenizer.enable_padding(**tokenizer_padding)

    def detokenize(self, inputs, entity=None, special=True):
        """Decode a batch of id sequences; special tokens are kept iff ``special``."""
        tokenizer = self.select_tokenizer(entity)
        return tokenizer.decode_batch(inputs, skip_special_tokens=not special)

    def get_pad_token_id(self):
        """Return the padding id configured on the context tokenizer."""
        tokenizer = self.select_tokenizer(entity="context")
        return tokenizer.padding["pad_id"]

    def find_tokenized_answer_indexes(self, offsets, starts, ends):
        """Map character-level answer spans to token indexes.

        Args:
            offsets: tensor indexed as ``offsets[i, token, 0|1]``, holding the
                character start/end of each token of example ``i``.
            starts, ends: per-example sequences of answer character starts/ends.

        Returns:
            A ``(batch, max_answers, 2)`` tensor of ``[start_token, end_token]``
            pairs. Entries are -100 for padding and for answers whose boundaries
            do not coincide with token boundaries.
        """
        batch_size = len(starts)
        max_answers = max((len(row) for row in starts), default=0)
        indexes = torch.full((batch_size, max_answers, 2), -100, device=self.device)
        for i, (row_starts, row_ends) in enumerate(zip(starts, ends)):
            token_starts = offsets[i, :, 0]
            token_ends = offsets[i, :, 1]
            for j, (s, e) in enumerate(zip(row_starts, row_ends)):
                start_index = torch.nonzero(token_starts == s)
                end_index = torch.nonzero(token_ends == e)
                if len(start_index) > 0 and len(end_index) > 0:
                    # First matching token for each boundary
                    indexes[i, j, 0] = start_index[0, 0]
                    indexes[i, j, 1] = end_index[0, 0]
        return indexes

    def find_subword_indexes(self, word_ids):
        """Build boolean masks of word-start and word-end tokens.

        ``word_ids`` is a list of rows (one per example) mapping each token to
        its word id, or ``None`` for special/padding tokens. A token is a word
        start if it belongs to a word and the previous token belongs to a
        different word (or it is the first token). It is a word end if the next
        token belongs to a different word (or it is the last token).

        Returns:
            ``(start_mask, end_mask)``, two bool tensors of shape
            ``(len(word_ids), len(word_ids[0]))``.
        """
        num_rows = len(word_ids)
        num_cols = len(word_ids[0]) if num_rows > 0 else 0
        # Build on the host and transfer once, instead of one device write per token
        start_rows = [[False] * num_cols for _ in range(num_rows)]
        end_rows = [[False] * num_cols for _ in range(num_rows)]

        for ids, row_start, row_end in zip(word_ids, start_rows, end_rows):
            last = len(ids) - 1
            for j, word in enumerate(ids):
                if word is None:
                    continue
                if j == 0 or ids[j - 1] != word:
                    row_start[j] = True
                # Mark the word's own last token (not the token after it)
                if j == last or ids[j + 1] != word:
                    row_end[j] = True

        start_mask = torch.tensor(start_rows, dtype=torch.bool, device=self.device)
        end_mask = torch.tensor(end_rows, dtype=torch.bool, device=self.device)
        return (
            start_mask.view(num_rows, num_cols),
            end_mask.view(num_rows, num_cols),
        )

    def select_tokenizer(self, entity=None):
        raise NotImplementedError()

    def __call__(self, inputs):
        raise NotImplementedError()
        
class RecurrentSquadTokenizer(SquadTokenizer):
    """Collator for the recurrent models.

    Questions and contexts are encoded separately, each with its own
    ``tokenizers.Tokenizer``, and returned as a dictionary of tensors.
    """

    def __init__(self, question_tokenizer, context_tokenizer, device="cpu"):
        super().__init__(device=device)
        for candidate in (question_tokenizer, context_tokenizer):
            assert isinstance(candidate, Tokenizer)
        self.question_tokenizer = question_tokenizer
        self.context_tokenizer = context_tokenizer

    def select_tokenizer(self, entity=None):
        assert entity in ("question", "context")
        if entity == "context":
            return self.context_tokenizer
        return self.question_tokenizer

    def __call__(self, inputs):
        # Items are (index, question, context[, answer_start, answer_end])
        columns = tuple(zip(*inputs))
        testing = len(columns) <= 3
        if testing:
            indexes, questions, contexts = columns
        else:
            indexes, questions, contexts, answers_start, answers_end = columns

        question_encodings = self.tokenize(questions, entity="question", special=True)
        context_encodings = self.tokenize(contexts, entity="context", special=True)
        # Transpose: one tuple per encoding attribute, each holding the whole batch
        question_fields = list(zip(*map(self.ATTRGETTER, question_encodings)))
        context_fields = list(zip(*map(self.ATTRGETTER, context_encodings)))

        attr_id = self.ENCODING_ATTR_ID
        device = self.device

        def tensor(fields, name, dtype=None):
            return torch.tensor(fields[attr_id[name]], dtype=dtype, device=device)

        batch = {
            "question_ids": tensor(question_fields, "ids"),
            "question_type_ids": tensor(question_fields, "type_ids"),
            "question_attention_mask": tensor(
                question_fields, "attention_mask", torch.bool),
            "question_special_tokens_mask": tensor(
                question_fields, "special_tokens_mask", torch.bool),
            "context_ids": tensor(context_fields, "ids"),
            "context_type_ids": tensor(context_fields, "type_ids"),
            "context_attention_mask": tensor(
                context_fields, "attention_mask", torch.bool),
            "context_special_tokens_mask": tensor(
                context_fields, "special_tokens_mask", torch.bool),
            "context_offsets": tensor(context_fields, "offsets"),
            "indexes": torch.tensor(indexes, dtype=torch.long, device=device),
        }

        # Offsets of positions outside the attention mask are overwritten with -100
        real_positions = batch["context_attention_mask"].unsqueeze(-1).repeat(1, 1, 2)
        batch["context_offsets"] = torch.where(
            real_positions, batch["context_offsets"], -100)
        # NOTE: the "lenghts" key spelling is what the models read; keep it
        batch["question_lenghts"] = batch["question_attention_mask"].count_nonzero(dim=1)
        batch["context_lenghts"] = batch["context_attention_mask"].count_nonzero(dim=1)
        start_mask, end_mask = self.find_subword_indexes(
            context_fields[attr_id["word_ids"]])
        batch["subword_start_mask"] = start_mask
        batch["subword_end_mask"] = end_mask

        if not testing:
            batch["answers"] = self.find_tokenized_answer_indexes(
                batch["context_offsets"], answers_start, answers_end)
        return batch

def get_recurrent_tokenizer(vocab, max_context_tokens, unk_token, pad_token, device="cpu"):
    """Build the word-level ``RecurrentSquadTokenizer`` used by the recurrent models.

    Questions and contexts get identical word-level tokenizers. Normalization
    strips accents, lowercases and strips whitespace. Pre-tokenization splits
    on whitespace and punctuation. Padding is on the right with ``pad_token``
    (``pad_type_id=1``). Only the context tokenizer is truncated, to
    ``max_context_tokens`` tokens.
    """

    def build_word_level_tokenizer():
        word_tokenizer = Tokenizer(WordLevel(vocab, unk_token=unk_token))
        word_tokenizer.normalizer = Sequence([StripAccents(), Lowercase(), Strip()])
        word_tokenizer.pre_tokenizer = PreSequence([Whitespace(), Punctuation()])
        word_tokenizer.enable_padding(
            direction="right",
            pad_id=vocab[pad_token],
            pad_type_id=1,
            pad_token=pad_token,
        )
        return word_tokenizer

    question_tokenizer = build_word_level_tokenizer()
    context_tokenizer = build_word_level_tokenizer()
    context_tokenizer.enable_truncation(max_context_tokens)
    return RecurrentSquadTokenizer(question_tokenizer, context_tokenizer, device=device)


# Word-level tokenizer built on the embedding vocabulary
recurrent_tokenizer = get_recurrent_tokenizer(
    vocab=vocab,
    max_context_tokens=config.MAX_CONTEXT_TOKENS,
    unk_token=config.UNK_TOKEN,
    pad_token=config.PAD_TOKEN,
    device=DEVICE,
)

# We will now create a class called SquadDataManager, which acts as a preprocessor
class SquadDataManager:
    """Turn a SquadDataset into torch datasets ready for training.

    Training data processing:
      1. Drop answers that cannot be located in the tokenized context.
      2. Group rows so that each row is one question with the lists of its answers.
      3. Split into train/validation by article title.

    Test data is only grouped.

    Args:
        dataset: the raw ``SquadDataset``.
        tokenizer: a ``SquadTokenizer``, used here to locate answers in contexts.
        val_split: target fraction of grouped training rows used for validation.
        device: stored on the instance; answer lookup itself runs on CPU.
    """

    def __init__(self, dataset, tokenizer, val_split=0.2, device="cpu"):
        assert isinstance(dataset, SquadDataset)
        assert isinstance(tokenizer, SquadTokenizer)
        self.dataset = dataset
        self.tokenizer = tokenizer
        self.val_split = val_split
        self.device = device

        # Declare every split attribute up front, so that a missing split reads
        # as None instead of raising AttributeError
        self.whole_dataset = None
        self.train_df, self.val_df = None, None
        self.train_dataset, self.val_dataset = None, None
        self.test_df, self.test_dataset = None, None

        # Preprocess the raw train dataset and perform train/val split
        if self.dataset.raw_train_df is not None:
            train_df = self.dataset.raw_train_df.copy()
            train_df = self._remove_lost_answers(train_df)
            train_df = self._group_answers(train_df)
            self.whole_dataset = SquadTorchDataset(train_df)
            self.train_df, self.val_df = self._train_val_split(train_df, self.val_split)
            self.train_dataset = SquadTorchDataset(self.train_df)
            self.val_dataset = SquadTorchDataset(self.val_df)

        # Preprocess the raw test dataset
        if self.dataset.raw_test_df is not None:
            test_df = self.dataset.raw_test_df.copy()
            self.test_df = self._group_answers(test_df)
            self.test_dataset = SquadTorchDataset(self.test_df)

    def _remove_lost_answers(self, df):
        """Drop rows whose answer span cannot be mapped onto context tokens."""
        tokenized_contexts = self.tokenizer.tokenize(
            df["context"].tolist(), "context", special=False)
        lost_truncated, lost_dirty = self._lost_answers_indexes(df, tokenized_contexts)
        to_remove = lost_truncated + lost_dirty
        # The helper returns *positions*; translate them to labels so that
        # drop() is correct for any index, not only a 0..n-1 RangeIndex
        clean_df = df.drop(df.index[to_remove])
        assert len(clean_df) == len(df) - len(to_remove), (
            f"Before {len(df)}, " f"after {len(clean_df)}, " f"removed {len(to_remove)}")
        return clean_df

    def _lost_answers_indexes(self, df, tokenized_contexts):
        """Find answers whose character span does not align with token boundaries.

        Returns:
            ``(lost_truncated, lost_dirty)``, two lists of positional row indexes.
            "Truncated" answers start or end after the last kept token. "Dirty"
            answers have boundaries that match no token boundary inside the
            kept context, or belong to a context with no usable tokens.
        """
        lost_dirty, lost_truncated = [], []
        rows = zip(
            tokenized_contexts, df["answer_start"].tolist(), df["answer_end"].tolist())
        for i, (encoding, start, end) in enumerate(rows):
            # Small per-row bookkeeping: CPU tensors avoid a host/device round
            # trip for every training example
            mask = (
                torch.tensor(encoding.attention_mask).bool()
                & ~torch.tensor(encoding.special_tokens_mask).bool())
            offsets = torch.tensor(encoding.offsets)[mask]
            if offsets.numel() == 0:
                lost_dirty.append(i)
                continue
            start_found = bool((offsets[:, 0] == start).any())
            end_found = bool((offsets[:, 1] == end).any())
            if start_found and end_found:
                continue
            if start > offsets[-1, 0] or end > offsets[-1, 1]:
                lost_truncated.append(i)
            else:
                lost_dirty.append(i)

        return lost_truncated, lost_dirty

    def _group_answers(self, df):
        """Collapse per-answer rows into one row per question with answer lists."""
        if "answer" not in df.columns:
            return df

        return (
            df.groupby(["question_id", "question", "title", "context_id", "context"])
            .agg({"answer": list, "answer_start": list, "answer_end": list})
            .reset_index()
        )

    def _train_val_split(self, df, val_split):
        """Split rows by article title so that no article spans both splits.

        Titles are visited from most to least frequent. They are added to the
        validation set until the next one would exceed
        ``round(len(df) * val_split)`` rows; the scan stops there. If the most
        frequent title alone exceeds that size, the validation set is empty.
        """
        val_size = round(df.shape[0] * val_split)
        val_actual_size = 0
        val_keys = []
        for t, n in df["title"].value_counts().to_dict().items():
            if val_actual_size + n > val_size:
                break
            val_keys.append(t)
            val_actual_size += n

        # Build the train and validation DataFrames
        train_df = df[~df["title"].isin(val_keys)].reset_index(drop=True)
        val_df = df[df["title"].isin(val_keys)].reset_index(drop=True)
        return train_df, val_df

# Cleaned, grouped and split data for the recurrent models
recurrent_dm = SquadDataManager(
    dataset=squad_dataset,
    tokenizer=recurrent_tokenizer,
    val_split=config.VAL_SPLIT,
    device=DEVICE,
)

# We will first build the QA baseline model from scratch.
# Below is the abstract QA model.
class QAModel(nn.Module):
    """Abstract base class for the question answering models.

    Subclasses list parameter-name prefixes in ``IGNORE_LAYERS``; entries whose
    name starts with any of them are left out of :meth:`state_dict`.
    :meth:`load_state_dict` is non-strict by default, so that such reduced
    checkpoints can be loaded back.
    """

    IGNORE_LAYERS = []

    def __init__(self):
        super().__init__()

    def count_parameters(self):
        """Return the number of trainable parameters."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

    def state_dict(self, *args, **kwargs):
        """Return the module state without the entries matched by ``IGNORE_LAYERS``.

        Accepts the same arguments as ``nn.Module.state_dict``
        (``destination``, ``prefix``, ``keep_vars``).
        """
        st_dict = super().state_dict(*args, **kwargs)
        ignored_prefixes = tuple(self.IGNORE_LAYERS)
        filtered = collections.OrderedDict(
            (k, v) for k, v in st_dict.items() if not k.startswith(ignored_prefixes)
        )
        # Keep the per-module version metadata that load_state_dict relies on
        metadata = getattr(st_dict, "_metadata", None)
        if metadata is not None:
            filtered._metadata = metadata
        return filtered

    def load_state_dict(self, state_dict, strict=False, **kwargs):
        """Load a state dict, non-strictly by default (ignored layers are absent)."""
        return super().load_state_dict(state_dict, strict=strict, **kwargs)
    
class QABaselineModel(QAModel):
    """Recurrent baseline for extractive QA.

    Word embeddings are projected to ``hidden_size`` and encoded by a shared
    LSTM. The question states are averaged over their lengths and multiplied
    element-wise into every context state to form the start features. A second
    LSTM over those features yields the end features, and both feed
    ``layer.QAOutput``.

    The embedding weights are excluded from saved checkpoints (``IGNORE_LAYERS``).
    """

    IGNORE_LAYERS = ["embedding.weight"]

    def __init__(
        self,
        embedding_module,
        hidden_size=100,
        num_recurrent_layers=2,
        bidirectional=False,
        dropout_rate=0.2,
        device="cpu",
    ):
        super().__init__()
        # Submodules are created in a fixed order (this affects random init)
        self.embedding = embedding_module
        self.embedding_dimension = embedding_module.weight.shape[-1]
        self.projection = nn.Linear(self.embedding_dimension, hidden_size)
        self.recurrent_module = layer.LSTM(
            hidden_size,
            hidden_size,
            batch_first=True,
            num_layers=num_recurrent_layers,
            bidirectional=bidirectional,
            dropout=dropout_rate,
        )
        out_dim = 2 * hidden_size if bidirectional else hidden_size
        self.out_lstm = layer.LSTM(
            out_dim, hidden_size, batch_first=True, bidirectional=bidirectional)
        self.output_layer = layer.QAOutput(
            out_dim,
            1,
            dropout_rate=dropout_rate,
            classifier_bias=True,
            device=device,
        )
        self.device = device
        self.to(self.device)

    def forward(self, **inputs):
        question_lengths = inputs["question_lenghts"]
        context_lengths = inputs["context_lenghts"]

        question_emb = self.embedding(inputs["question_ids"])
        context_emb = self.embedding(inputs["context_ids"])
        question_hidden = self.projection(question_emb)
        context_hidden = self.projection(context_emb)

        question_states, _ = self.recurrent_module(question_hidden, question_lengths)
        context_states, _ = self.recurrent_module(context_hidden, context_lengths)

        # Sum over time divided by the true length; this is a mean over real steps
        # only if layer.LSTM zero-fills padded positions — presumably; confirm
        question_summary = question_states.sum(dim=1) / question_lengths.view(-1, 1)
        time_steps = context_states.shape[1]
        start_input = context_states * question_summary.unsqueeze(1).repeat(
            1, time_steps, 1)
        end_input, _ = self.out_lstm(start_input, context_lengths)
        return self.output_layer(start_input, end_input, **inputs)

# Baseline training: W&B runs are grouped under "baseline", and the same group name
# is read back via os.getenv to build the checkpoint directory
%env WANDB_RUN_GROUP=baseline
baseline_run_name = utils.get_run_name()
# TRAINER_ARGS is partially applied; run_name is supplied when the trainer is built
baseline_args = partial(
    TRAINER_ARGS,
    output_dir=f"./checkpoints/{os.getenv('WANDB_RUN_GROUP')}/{baseline_run_name}",
    num_train_epochs=30,
    per_device_train_batch_size=128,
    per_device_eval_batch_size=128,
)

baseline_model = QABaselineModel(embedding_layer, device=DEVICE)
print(f"The baseline model has {baseline_model.count_parameters()} parameters")

# Adam with a constant learning rate (no warmup or decay)
baseline_optimizer = optim.Adam(baseline_model.parameters(), lr=1e-3)
baseline_lr_scheduler = transformers.get_constant_schedule(baseline_optimizer)
# The tokenizer doubles as data collator: its __call__ turns a list of dataset
# items into a dictionary of tensors
baseline_trainer = training.SquadTrainer(
    model=baseline_model,
    args=baseline_args(run_name=baseline_run_name),
    data_collator=recurrent_dm.tokenizer,
    train_dataset=recurrent_dm.train_dataset,
    eval_dataset=recurrent_dm.val_dataset,
    optimizers=(baseline_optimizer, baseline_lr_scheduler),
)

# NOTE(review): project/entity here ("QA"/"nkaranam") differ from the WANDB_PROJECT /
# WANDB_ENTITY environment variables set at the top ("squad-qa"/"wadaboa"). The
# explicit arguments are expected to win for this run — confirm which is intended.
wandb.init(project="QA", entity="nkaranam")
baseline_trainer.train()
baseline_test_output = baseline_trainer.predict(recurrent_dm.test_dataset)
# NOTE(review): only displayed if this is the last expression of its notebook cell
baseline_test_output.metrics
baseline_answers_path = "results/answers/baseline.json"
# NOTE(review): predictions[-1] is assumed to hold the decoded answers — confirm in
# training.SquadTrainer
utils.save_answers(baseline_answers_path, baseline_test_output.predictions[-1])
wandb.save(baseline_answers_path);
wandb.finish()

# Now we will set up the BiDAF model's run (training and saving results follow below)
%env WANDB_RUN_GROUP=bidaf
bidaf_run_name = utils.get_run_name()
bidaf_args = partial(
    TRAINER_ARGS,
    output_dir=f"./checkpoints/{os.getenv('WANDB_RUN_GROUP')}/{bidaf_run_name}",
    num_train_epochs=18, #18
    per_device_train_batch_size=60,
    per_device_eval_batch_size=60,
)

class QABiDAFModel(QAModel):
    """Bi-Directional Attention Flow (BiDAF) model over word embeddings.

    Pipeline:
      1. Word embedding (+ dropout).
      2. Bias-free projection to ``hidden_size``.
      3. Highway network.
      4. Shared bidirectional contextual LSTM.
      5. Attention flow.
      6. 2-layer bidirectional modeling LSTM.
      7. Bidirectional output LSTM.

    Start features are the attention output concatenated with the modeling
    output. End features are the attention output concatenated with the output
    LSTM's states. Both feed ``layer.QAOutput``.

    The word-embedding weights are excluded from saved checkpoints.
    """

    IGNORE_LAYERS = ["word_embedding.weight"]

    def __init__(
        self,
        embedding_module,
        hidden_size=100,
        highway_depth=2,
        dropout_rate=0.2,
        device="cpu",
    ):
        super().__init__()
        # Output width of the bidirectional LSTMs
        bi_hidden = 2 * hidden_size

        # Submodules are created in a fixed order (this affects random init)
        self.word_embedding = embedding_module
        self.word_embedding_dimension = embedding_module.weight.shape[-1]
        self.projection = nn.Linear(
            self.word_embedding_dimension, hidden_size, bias=False)
        self.highway_depth = highway_depth
        self.highway = layer_utils.get_highway(highway_depth, hidden_size, device=device)
        self.dropout_rate = dropout_rate
        self.dropout = nn.Dropout(dropout_rate)
        self.contextual_embedding = layer.LSTM(
            hidden_size, hidden_size, batch_first=True, num_layers=1, bidirectional=True)
        self.attention = layer.AttentionFlow(bi_hidden, device=device)
        # Attention flow output is 8 * hidden_size wide
        self.modeling_layer = layer.LSTM(
            4 * bi_hidden,
            hidden_size,
            batch_first=True,
            bidirectional=True,
            num_layers=2,
            dropout=dropout_rate,
        )
        self.out_lstm = layer.LSTM(
            bi_hidden, hidden_size, batch_first=True, bidirectional=True)
        # 8h (attention) + 2h (modeling or output LSTM) = 10 * hidden_size
        self.output_layer = layer.QAOutput(
            5 * bi_hidden,
            1,
            dropout_rate=dropout_rate,
            classifier_bias=False,
            device=device,
        )
        self.device = device
        self.to(self.device)

    def forward(self, **inputs):
        question_mask = inputs["question_attention_mask"]
        context_mask = inputs["context_attention_mask"]
        question_lengths = inputs["question_lenghts"]
        context_lengths = inputs["context_lenghts"]

        # Dropout on word embeddings: question first, then context
        question_words = self.dropout(self.word_embedding(inputs["question_ids"]))
        context_words = self.dropout(self.word_embedding(inputs["context_ids"]))

        question_proj = self.projection(question_words)
        context_proj = self.projection(context_words)
        question_hw = self.highway(question_proj)
        context_hw = self.highway(context_proj)

        # The same contextual encoder is shared by questions and contexts
        question_enc, _ = self.contextual_embedding(question_hw, question_lengths)
        context_enc, _ = self.contextual_embedding(context_hw, context_lengths)

        attended = self.attention(question_enc, context_enc, question_mask, context_mask)
        modeled, _ = self.modeling_layer(attended, context_lengths)
        modeled_end, _ = self.out_lstm(modeled, context_lengths)

        start_features = torch.cat([attended, modeled], dim=-1)
        end_features = torch.cat([attended, modeled_end], dim=-1)
        return self.output_layer(start_features, end_features, **inputs)
    
# Initialize the BiDAF model
bidaf_model = QABiDAFModel(embedding_layer, device=DEVICE)
print(f"The BiDAF model has {bidaf_model.count_parameters()} parameters")
# Adadelta (lr=0.5) with a constant schedule
bidaf_optimizer = optim.Adadelta(bidaf_model.parameters(), lr=0.5)
bidaf_lr_scheduler = transformers.get_constant_schedule(bidaf_optimizer)
# Same recurrent data/collator as the baseline
bidaf_trainer = training.SquadTrainer(
    model=bidaf_model,
    args=bidaf_args(run_name=bidaf_run_name),
    data_collator=recurrent_dm.tokenizer,
    train_dataset=recurrent_dm.train_dataset,
    eval_dataset=recurrent_dm.val_dataset,
    optimizers=(bidaf_optimizer, bidaf_lr_scheduler),)

# NOTE(review): wandb is already imported at the top of the notebook; this is redundant
import wandb
# NOTE(review): same hard-coded project/entity as the baseline run, which differ from
# the WANDB_PROJECT / WANDB_ENTITY environment variables — confirm which is intended
wandb.init(project="QA", entity="nkaranam")
bidaf_trainer.train()
bidaf_test_output = bidaf_trainer.predict(recurrent_dm.test_dataset)
# NOTE(review): only displayed if this is the last expression of its notebook cell
bidaf_test_output.metrics

bidaf_answers_path = "results/answers/bidaf.json"
utils.save_answers(bidaf_answers_path, bidaf_test_output.predictions[-1])
wandb.save(bidaf_answers_path);
wandb.finish()

#Now that we are done with playing around with the baseline models,
#we will now use Transformers-based modules

#We need to define a new tokenizer for the transformer based module because
# the tokenization is different for transformer models
class TransformerSquadTokenizer(SquadTokenizer):
    """Collator for the transformer models.

    Each (question, context) pair is encoded jointly as one sequence by a
    single tokenizer.
    """

    def __init__(self, tokenizer, device="cpu"):
        super().__init__(device=device)
        assert isinstance(tokenizer, (Tokenizer, BaseTokenizer))
        self.tokenizer = tokenizer

    def select_tokenizer(self, entity=None):
        # One tokenizer serves every entity
        return self.tokenizer

    def __call__(self, inputs):
        # Items are (index, question, context[, answer_start, answer_end])
        columns = tuple(zip(*inputs))
        testing = len(columns) <= 3
        if testing:
            indexes, questions, contexts = columns
        else:
            indexes, questions, contexts, answers_start, answers_end = columns

        pairs = list(zip(questions, contexts))
        encodings = self.tokenize(pairs, special=True)
        # Transpose: one tuple per encoding attribute, each holding the whole batch
        fields = list(zip(*map(self.ATTRGETTER, encodings)))
        attr_id = self.ENCODING_ATTR_ID
        device = self.device

        def tensor(name, dtype=None):
            return torch.tensor(fields[attr_id[name]], dtype=dtype, device=device)

        batch = {
            "context_ids": tensor("ids"),
            "context_type_ids": tensor("type_ids"),
            "attention_mask": tensor("attention_mask", torch.bool),
            "special_tokens_mask": tensor("special_tokens_mask", torch.bool),
            "offsets": tensor("offsets"),
            "indexes": torch.tensor(indexes, dtype=torch.long, device=device),
        }

        # Context tokens: non-zero type id, excluding special tokens
        context_mask = batch["context_type_ids"].bool() & ~batch["special_tokens_mask"]
        batch["context_attention_mask"] = context_mask
        # Offsets outside the context are overwritten with -100
        batch["context_offsets"] = torch.where(
            context_mask.unsqueeze(-1).repeat(1, 1, 2),
            batch["offsets"],
            -100,
        )
        start_mask, end_mask = self.find_subword_indexes(fields[attr_id["word_ids"]])
        batch["subword_start_mask"] = start_mask
        batch["subword_end_mask"] = end_mask

        if not testing:
            batch["answers"] = self.find_tokenized_answer_indexes(
                batch["context_offsets"], answers_start, answers_end)
        return batch

# NOTE(review): get_transformer_tokenizer is neither defined nor imported anywhere in
# this notebook, so this line raises NameError as written. It presumably built a
# BertWordPieceTokenizer (imported at the top but otherwise unused) from
# config.BERT_VOCAB_PATH, limited to config.MAX_BERT_TOKENS, and wrapped it in
# TransformerSquadTokenizer — restore that definition before running this cell.
transformer_tokenizer = get_transformer_tokenizer(
    config.BERT_VOCAB_PATH, config.MAX_BERT_TOKENS, device=DEVICE)
# Same cleaning/grouping/splitting as recurrent_dm, with the transformer tokenizer
transformer_dm = SquadDataManager(
    squad_dataset, transformer_tokenizer, val_split=config.VAL_SPLIT, device=DEVICE)

#We will now run the BERT Model
class QABertModel(QAModel):
    """Extractive QA model: pretrained transformer encoder + LSTM + QA output layer.

    Subclasses swap the encoder by overriding ``MODEL_TYPE`` and ``get_model``
    and, when the encoder takes different arguments, ``get_model_inputs``.
    """

    # Hidden size of bert-base. Used only as a fallback when the loaded
    # encoder's config does not report ``hidden_size``.
    BERT_OUTPUT_SIZE = 768
    MODEL_TYPE = "bert-base-uncased"

    def __init__(self, dropout_rate=0.2, device="cpu"):
        """Build the encoder and the QA head, then move the module to ``device``.

        Args:
            dropout_rate: dropout probability forwarded to ``layer.QAOutput``.
            device: device for the module and for the LSTM length tensor
                built in ``forward``.
        """
        super().__init__()

        # Pretrained transformer encoder (chosen by ``get_model``)
        self.bert_model = self.get_model()
        # Size the head from the encoder that was actually loaded instead of
        # assuming 768, so subclasses with wider checkpoints also work.
        hidden_size = getattr(
            self.bert_model.config, "hidden_size", self.BERT_OUTPUT_SIZE)
        self.out_lstm = layer.LSTM(
            hidden_size,
            hidden_size,
            batch_first=True,
            bidirectional=False,
        )
        self.output_layer = layer.QAOutput(
            hidden_size,
            1,
            dropout_rate=dropout_rate,
            classifier_bias=False,
            device=device,
        )
        self.device = device
        self.to(self.device)

    def get_model(self):
        """Load the pretrained BERT encoder named by ``MODEL_TYPE``."""
        return transformers.BertModel.from_pretrained(self.MODEL_TYPE)

    def get_model_inputs(self, **inputs):
        """Map batch-dict keys to the keyword arguments the encoder expects."""
        return {
            "input_ids": inputs["context_ids"],
            "token_type_ids": inputs["context_type_ids"],
            "attention_mask": inputs["attention_mask"],
        }

    def forward(self, **inputs):
        """Run encoder -> LSTM -> QA output layer on one batch.

        ``inputs`` is the batch dict produced by the data collator. The encoder
        only receives the keys chosen by ``get_model_inputs``, while the whole
        dict is forwarded to the output layer.
        """
        bert_inputs = self.get_model_inputs(**inputs)
        # Element 0 of the encoder output is the last hidden state,
        # presumably (batch, seq_len, hidden) -- TODO confirm for all encoders
        bert_outputs = self.bert_model(**bert_inputs)[0]
        batch_size, seq_len = bert_outputs.shape[0], bert_outputs.shape[1]
        # Every sequence is given the full padded length as its LSTM length.
        lengths = torch.full(
            (batch_size,), seq_len, dtype=torch.long, device=self.device)
        end_input, _ = self.out_lstm(bert_outputs, lengths)
        return self.output_layer(bert_outputs, end_input, **inputs)
    
%env WANDB_RUN_GROUP=bert
bert_run_name = utils.get_run_name()
bert_args = partial(
    TRAINER_ARGS,
    output_dir=f"./checkpoints/{os.getenv('WANDB_RUN_GROUP')}/{bert_run_name}",
    num_train_epochs=3,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
)
bert_model = QABertModel(device=DEVICE)
bert_optimizer = optim.Adam(bert_model.parameters(), lr=5e-5)
bert_lr_scheduler = transformers.get_constant_schedule(bert_optimizer)
bert_trainer = training.SquadTrainer(
    model=bert_model,
    args=bert_args(run_name=bert_run_name),
    data_collator=transformer_dm.tokenizer,
    train_dataset=transformer_dm.train_dataset,
    eval_dataset=transformer_dm.val_dataset,
    optimizers=(bert_optimizer, bert_lr_scheduler),)

import wandb
wandb.init(project="QA", entity="nkaranam")
bert_trainer.train()
bert_test_output = bert_trainer.predict(transformer_dm.test_dataset)
bert_test_output.metrics
bert_answers_path = "results/answers/bert.json"
utils.save_answers(bert_answers_path, bert_test_output.predictions[-1])
wandb.save(bert_answers_path);
wandb.finish()

# Next, we train DistilBERT. The class below adapts the BERT model to the DistilBERT encoder.
class QADistilBertModel(QABertModel):
    """QABertModel variant whose encoder is a pretrained DistilBERT.

    The constructor, LSTM and QA head are inherited unchanged from
    QABertModel; only the encoder and the arguments fed to it differ.
    """

    MODEL_TYPE = "distilbert-base-uncased"

    def get_model(self):
        """Load the pretrained DistilBERT encoder named by ``MODEL_TYPE``."""
        encoder_cls = transformers.DistilBertModel
        return encoder_cls.from_pretrained(self.MODEL_TYPE)

    def get_model_inputs(self, **inputs):
        """Select the encoder arguments from the batch dict.

        Unlike the BERT parent, no ``token_type_ids`` are passed, because
        ``DistilBertModel.forward`` does not take that argument.
        """
        ids = inputs["context_ids"]
        mask = inputs["attention_mask"]
        return dict(input_ids=ids, attention_mask=mask)

%env WANDB_RUN_GROUP=distilbert
distilbert_run_name = utils.get_run_name()
distilbert_args = partial(
    TRAINER_ARGS,
    output_dir=f"./checkpoints/{os.getenv('WANDB_RUN_GROUP')}/{distilbert_run_name}",
    num_train_epochs=3,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
)
distilbert_model = QADistilBertModel(device=DEVICE)
distilbert_optimizer = optim.Adam(distilbert_model.parameters(), lr=5e-5)
distilbert_lr_scheduler = transformers.get_constant_schedule(distilbert_optimizer)
distilbert_trainer = training.SquadTrainer(
    model=distilbert_model,
    args=distilbert_args(run_name=distilbert_run_name),
    data_collator=transformer_dm.tokenizer,
    train_dataset=transformer_dm.train_dataset,
    eval_dataset=transformer_dm.val_dataset,
    optimizers=(distilbert_optimizer, distilbert_lr_scheduler),
)
import wandb

wandb.init(project="QA", entity="nkaranam")
distilbert_trainer.train()
distilbert_test_output = distilbert_trainer.predict(transformer_dm.test_dataset)
distilbert_test_output.metrics
distilbert_answers_path = "results/answers/distilbert.json"
utils.save_answers(distilbert_answers_path, distilbert_test_output.predictions[-1])
wandb.save(distilbert_answers_path);
wandb.finish()

# Finally, we train a model that uses ELECTRA as the transformer encoder.
class QAElectraModel(QABertModel):
    """QABertModel variant whose encoder is the pretrained ELECTRA discriminator.

    The constructor and ``get_model_inputs`` are inherited from QABertModel,
    so token type ids are still passed to the encoder (``ElectraModel``
    accepts them).
    """

    MODEL_TYPE = "google/electra-base-discriminator"

    def get_model(self):
        """Load the pretrained ELECTRA encoder named by ``MODEL_TYPE``."""
        checkpoint = self.MODEL_TYPE
        return transformers.ElectraModel.from_pretrained(checkpoint)
    
%env WANDB_RUN_GROUP=electra
electra_run_name = utils.get_run_name()
electra_args = partial(
    TRAINER_ARGS,
    output_dir=f"./checkpoints/{os.getenv('WANDB_RUN_GROUP')}/{electra_run_name}",
    num_train_epochs=3,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,)

electra_model = QAElectraModel(device=DEVICE)
electra_optimizer = optim.Adam(electra_model.parameters(), lr=5e-5)
electra_lr_scheduler = transformers.get_constant_schedule(electra_optimizer)
electra_trainer = training.SquadTrainer(
    model=electra_model,
    args=electra_args(run_name=electra_run_name),
    data_collator=transformer_dm.tokenizer,
    train_dataset=transformer_dm.train_dataset,
    eval_dataset=transformer_dm.val_dataset,
    optimizers=(electra_optimizer, electra_lr_scheduler),)

import wandb
wandb.init(project="QA", entity="nkaranam")
electra_trainer.train()
electra_test_output = electra_trainer.predict(transformer_dm.test_dataset)
electra_test_output.metrics
electra_answers_path = "results/answers/electra.json"
utils.save_answers(electra_answers_path, electra_test_output.predictions[-1])
wandb.save(electra_answers_path);
wandb.finish()




In [None]:
# Error analysis: compare the wrong predictions made by each model.

import sys
sys.path.insert(0, "src")
from transformers.trainer_utils import set_seed
import config
%load_ext autoreload
%autoreload 2
set_seed(config.RANDOM_SEED)
with open('results/wrong/baseline.json') as f:
    baseline_errors = json.load(f)
with open('results/wrong/bidaf.json') as f:
    bidaf_errors = json.load(f)
with open('results/wrong/bert.json') as f:
    bert_errors = json.load(f)
with open('results/wrong/distilbert.json') as f:
    distilbert_errors = json.load(f)
with open('results/wrong/electra.json') as f:
    electra_errors = json.load(f)

print(f"The Baseline model makes {len(baseline_errors)} errors")
print(f"The BiDAF model makes {len(bidaf_errors)} errors")
print(f"The BERT model makes {len(bert_errors)} errors")
print(f"The DistilBERT model makes {len(distilbert_errors)} errors")
print(f"The ELECTRA model makes {len(electra_errors)} errors")

for e in random_best_common_errors:
    context = electra_errors[e]["context"]
    question = electra_errors[e]["question"]
    answers = electra_errors[e]["answers"]
    electra_pred = electra_errors[e]["prediction"]
    bidaf_pred = bidaf_errors[e]["prediction"]
    print(f"Context: {context}")
    print(f"Question: {question}")
    print(f"Answers: {answers}")
    print(f"Predictions: [ELECTRA] {electra_pred} [BiDAF] {bidaf_pred}")
    print()