# Baseline Model

For the baseline we have build a simple transformer that just has as an input context+question and tries to predict an answer

In [1]:
import pandas as pd
import numpy as np
import json

Helper class to covert json to dataframe for easier batch processing

In [2]:
class Squad:
    def __init__(self, input_location):
        self.location = input_location
        file = open(input_location)
        json_file = json.load(file)
        # Save version and data
        self.version = json_file['version']
        self.data = json_file['data']
        
        df_builder = [] # We will store every row of dataframe here
        for sample in self.data:
            title = sample['title'] # Get title
            paragraphs = sample['paragraphs']
            
            for paragraph in paragraphs:
                context = paragraph['context'] # Get context, e.g. a paragraph
                questions = paragraph['qas']
                
                for question in questions:
                    q_id = question['id'] # Question id
                    q_content = question['question'] # Question itself
                    answers = question['answers'] # Possible answers
                    is_impossible = question['is_impossible'] # If it is possible to answer
                    
                    # Build a row of dataframe
                    qas = {
                        'id':q_id,
                        'wiki_title':title,
                        'context':context,
                        'content':q_content,
                        'is_impossible':is_impossible
                    }
                    if is_impossible:
                        qas['answer'] = ""
                        qas['answer_start'] = 0
                        qas['answer_end'] = 0
                    else:
                        answer = answers[0]
                        qas['answer'] = answer['text']
                        qas['answer_start'] = answer['answer_start']
                        qas['answer_end'] = answer['answer_start']+len(answer['text'])
                    df_builder.append(qas) 
        self.df = pd.DataFrame(df_builder)

In [3]:
train_sq = Squad('./data/train-v2.0.json')
test_sq = Squad('./data/dev-v2.0.json')
train_df  = train_sq.df
test_df  = test_sq.df

In [4]:
import torch
from torch import nn
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter()

In [5]:
from torchtext import *
from torchtext.data import *

# Taken from here for easier work with dataframe and torchtext
# https://gist.github.com/notnami/3c4d636f2b79e206b26acfe349f2657a
class DataFrameExampleSet:
    def __init__(self, df, fields):
        self._df = df
        self._fields = fields
        self._fields_dict = {field_name: (field_name, field)
                             for field_name, field in fields.items()
                             if field is not None}

    def __iter__(self):
        for item in tqdm(self._df.itertuples(), total=len(self)):
            example = Example.fromdict(item._asdict(), fields=self._fields_dict)
            yield example

    def __len__(self):
        return len(self._df)

    def shuffle(self, random_state=None):
        self._df = self._df.sample(frac=1.0, random_state=random_state)


class DataFrameDataset(Dataset):
    def __init__(self, df, fields, filter_pred=None):
        examples = DataFrameExampleSet(df, fields)
        super().__init__(examples, fields, filter_pred=filter_pred)


class DataFrameBucketIterator(BucketIterator):
    def data(self):
        if isinstance(self.dataset.examples, DataFrameExampleSet):
            if self.shuffle:
                self.dataset.examples.shuffle()
            dataset = self.dataset
        else:
            dataset = super().data()
        return dataset

In [6]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [7]:
import torchtext
from typing import *
from torchtext.data import *
from tqdm.notebook import tqdm
from torchtext.data.utils import get_tokenizer
import dill
load = True
# TRG_LEN=15
# CONTEXT_Q_LEN = 400

if load:
    with open("model/CONTEXT.Field","rb") as f:
        CONTEXT=dill.load(f)
    with open("model/QUESTION.Field","rb") as f:
        QUESTION=dill.load(f)
else:
    # Init Fields 
    
    # Here will be context and question
    CONTEXT = torchtext.data.Field(tokenize = get_tokenizer("basic_english"),
                          init_token = '<sos>',
                          eos_token = '<eos>',
                          lower = False,
                          batch_first = False)
    # here the target 
    QUESTION = torchtext.data.Field(tokenize = get_tokenizer("basic_english"), 
                         init_token = '<sos>',
                         eos_token = '<eos>',
                         lower = False,
                         batch_first = False)

START = torchtext.data.Field(sequential=False, is_target=True, use_vocab=False)
END = torchtext.data.Field(sequential=False, is_target=True, use_vocab=False)
# Will store id to later check correctness
ID = torchtext.data.Field(is_target=True, sequential=False)

In [8]:
train_dataset = DataFrameDataset(train_df, fields={'context':CONTEXT,'content':QUESTION, 'id':ID,
                                                   'answer_start':START, 'answer_end':END})
test_dataset = DataFrameDataset(test_df, fields={'context':CONTEXT,'content':QUESTION, 'id':ID,
                                                'answer_start':START, 'answer_end':END})

In [9]:
if load:
    pass
else:
    # Build vocabulary from our data, target will have the same vocab as context + questions
    CONTEXT.build_vocab(train_dataset, min_freq=100)
    QUESTION.build_vocab([''])
    QUESTION.vocab = CONTEXT.vocab
    
    
    with open("model/CONTEXT.Field","wb+")as f:
        dill.dump(CONTEXT,f)
    with open("model/QUESTION.Field","wb+")as f:
        dill.dump(QUESTION,f)
ID.build_vocab(list(train_df.id)+ list(test_df.id))

HBox(children=(FloatProgress(value=0.0, max=130319.0), HTML(value='')))




In [10]:
batch_size = 128
# Create iterators
train_iterator, test_iterator = DataFrameBucketIterator.splits((train_dataset, test_dataset), 
                                    batch_size = batch_size,
                                    device = 'cpu')

In [11]:
class BaselineAttn(nn.Module):
    def __init__(self, hidden_size):
        super(BaselineAttn, self).__init__()
        
        self.linear_context = nn.Linear(hidden_size, hidden_size)
        self.linear_question = nn.Linear(hidden_size, hidden_size)
        
        self.linear_combination = nn.Linear(hidden_size, 1)
    
    def forward(self, context, question):
        lin_context = self.linear_context(context) # [1, batch, hidden_size]
        lin_question = self.linear_question(question) # [question_seq_len, batch, hidden_size]
        
        combined = torch.tanh(lin_context + lin_question) # [question_seq_len, batch, hidden_size]
        lin_combined = self.linear_combination(combined).squeeze(2) # [question_seq_len, batch]
        
        attn = torch.nn.functional.softmax(lin_combined, dim=0) # [question_seq_len, batch]
        return attn

In [12]:
class BaselineModel(nn.Module):
    def __init__(self, context_vocab, emb_size, hidden_size,
                dropout=0.1):
        super(BaselineModel, self).__init__()
        
        self.context_emb = nn.Embedding(context_vocab, emb_size)
        self.question_emb = nn.Embedding(context_vocab, emb_size)
        
        self.dropout_1 = nn.Dropout(dropout)
        self.dropout_2 = nn.Dropout(dropout)
        
        self.rnn_context = nn.GRU(emb_size, hidden_size)
        self.rnn_question = nn.GRU(emb_size, hidden_size)
        
        self.attn = BaselineAttn(hidden_size)
        
        self.fc_out = nn.Linear(hidden_size*2, 2)
        
        self._init_params()
    
    def forward(self, context, question):
        context_embedded = self.dropout_1(self.context_emb(context)) # [context_seq_len, batch_size, emb_size]
        question_embedded = self.dropout_2(self.question_emb(question)) # [question_seq_len, batch_size, emb_size]
        
        output_context, hidden_context = self.rnn_context(context_embedded) # [context_seq_len, batch_size, hidden_size]
        output_question, hidden_question = self.rnn_question(question_embedded) # [question_seq_len, batch_size, hidden_size]
        print(hidden_context.shape)
        print(hidden_question.shape)
        
        res = torch.zeros([
            output_context.shape[0],
            output_context.shape[1],
            output_context.shape[2] * 2
        ]).to(device)
        
        for i in range(len(output_context)):
            hp = output_context[i, ...]
            attn = self.attn(hp, output_question)
            attn_applied = torch.bmm(attn.transpose(0,1).unsqueeze(1), output_question.transpose(0,1)).squeeze(1)
            res[i] = torch.cat([hp, attn_applied],dim=1)
        
        logits = self.fc_out(res)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1)
        end_logits = end_logits.squeeze(-1)
        
        return start_logits, end_logits
    
    def _init_params(self):
        for p in self.parameters():
            if p.dim() > 1:
                torch.nn.init.xavier_uniform_(p)

In [13]:
context_vocab = len(CONTEXT.vocab)
emb_size=256
hidden_size=512

model = BaselineModel(context_vocab, emb_size, hidden_size).to(device)

In [14]:
def count_trainable_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


print(f'Model has a total of {count_trainable_parameters(model):,} of trainable parameters')

Model has a total of 8,723,971 of trainable parameters


In [15]:
optimizer = torch.optim.Adam(model.parameters(),lr=2e-5)
loss_func = nn.CrossEntropyLoss()

In [16]:
def train(model, iterator, optimizer, loss_func):
    """
    Runs training loop for whole dataset in iterator
    
    model - model to be trained
    iterator - data loader from which we take source and target
    optimizer - our optimizer
    loss_func - function which will compute loss
    return average loss
    """
    model.train() # Switch to train
    epoch_loss = [] # We will calculate cumulative loss
    
    for i, batch in enumerate(iterator):
        optimizer.zero_grad()
        
        context = batch.context.to(device)
        content = batch.content.to(device)
        start_positions = batch.answer_start.to(device)
        end_positions = batch.answer_end.to(device)
        
        start_logits, end_logits = model(context, content)
        
        start_logits, end_logits = start_logits.transpose(0,1), end_logits.transpose(0,1)
        ignored_index = start_logits.size(1)
        start_positions.clamp_(0, ignored_index)
        end_positions.clamp_(0, ignored_index)

        loss_fct = torch.nn.CrossEntropyLoss(ignore_index=ignored_index)
        
        start_loss = loss_fct(start_logits, start_positions)
        end_loss = loss_fct(end_logits, end_positions)
        total_loss = (start_loss + end_loss) / 2
    
        writer.add_scalar(f'Loss/train Epoch {epoch}', total_loss, i)
        
        epoch_loss.append(total_loss.item())
        
        total_loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
        optimizer.step()
    return epoch_loss

In [17]:
def evaluate(model, iterator, loss_func):
    """
    Runs an evaluation loop and returns average loss
    
    model - model to be evaluated
    iterator - data loader with validation set
    loss_func - function which will compute loss
    returns average loss
    """
    model.eval() # Switch to eval
    epoch_loss = 0 # We will calculate cumulative loss
    
    with torch.no_grad():
        to_return = []
        
        for i, batch in enumerate(iterator):
            optimizer.zero_grad()

            context = batch.context.to(device)
            content = batch.content.to(device)
            start_positions = batch.answer_start.to(device)
            end_positions = batch.answer_end.to(device)

            start_logits, end_logits = model(context, content)
            
            start_logits, end_logits = start_logits.transpose(0,1), end_logits.transpose(0,1)
            ignored_index = start_logits.size(1)
            start_positions.clamp_(0, ignored_index)
            end_positions.clamp_(0, ignored_index)

            loss_fct = torch.nn.CrossEntropyLoss(ignore_index=ignored_index)

            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2
            
            start_pred = start_logits.softmax(dim=1).topk(1, dim=1)[1].squeeze().cpu().detach().numpy()
            end_pred = end_logits.softmax(dim=1).topk(1, dim=1)[1].squeeze().cpu().detach().numpy()
            
            to_return.append((start_pred, end_pred, batch.id))
            epoch_loss += total_loss.item()

            optimizer.step()
    return epoch_loss / len(iterator), to_return

## Training

In [18]:
import torch.nn.functional as F
import numpy.random as random 

best_loss = float('inf')
epochs = 4
for epoch in range(epochs):
    train_loss = train(model, train_iterator, optimizer, loss_func)
    
    eval_loss, preds = evaluate(model, test_iterator, loss_func)
    
    print(preds[0])
    # save "best" model
    if best_loss > eval_loss:
        best_loss = eval_loss
        torch.save(model.state_dict(), 'baseline.model')
    print(f"Epoch {epoch}. Train loss: {np.mean(train_loss)}. Eval loss: {eval_loss}")

HBox(children=(FloatProgress(value=0.0, max=130319.0), HTML(value='')))

torch.Size([1, 128, 512])
torch.Size([1, 128, 512])
torch.Size([1, 128, 512])
torch.Size([1, 128, 512])
torch.Size([1, 128, 512])
torch.Size([1, 128, 512])
torch.Size([1, 128, 512])
torch.Size([1, 128, 512])
torch.Size([1, 128, 512])
torch.Size([1, 128, 512])
torch.Size([1, 128, 512])
torch.Size([1, 128, 512])


KeyboardInterrupt: 

In [None]:
eval_loss, preds = evaluate(model, test_iterator, loss_func)

## Evaluation

In [None]:
# flatten output
predictions = []
labels = []
for i in preds:
    for seq1,seq2, tgt in zip(i[0], i[1], i[2]):
        predictions.append((seq1,seq2))
        labels.append(tgt)

In [None]:
def get_preds(predictions, labels, df):
    # Transform our predictions
    my_preds = {}

    for pred, tgt in zip(predictions, labels):
        start, end = pred
        tg_id = ID.vocab.itos[tgt]
        res = df[df.id == tg_id].context.values[0][start:end]
        my_preds[ID.vocab.itos[tgt]] = res
    return my_preds

In [None]:
my_preds = get_preds(predictions, labels, test_df)

In [None]:
my_preds

In [None]:
from evaluate_answers import *

In [None]:
# For the more representetive results we have taken script that squad owner's have written to check predictions

dataset = test_sq.data
preds = my_preds
na_probs = {k: 0.0 for k in preds}

qid_to_has_ans = make_qid_to_has_ans(dataset) 
has_ans_qids = [k for k, v in qid_to_has_ans.items() if v]
no_ans_qids = [k for k, v in qid_to_has_ans.items() if not v]
exact_raw, f1_raw = get_raw_scores(dataset, preds)
exact_thresh = apply_no_ans_threshold(exact_raw, na_probs, qid_to_has_ans,
                                      1.0)
f1_thresh = apply_no_ans_threshold(f1_raw, na_probs, qid_to_has_ans,
                                   1.0)
out_eval = make_eval_dict(exact_thresh, f1_thresh)
if has_ans_qids:
    has_ans_eval = make_eval_dict(exact_thresh, f1_thresh, qid_list=has_ans_qids)
    merge_eval(out_eval, has_ans_eval, 'HasAns')
if no_ans_qids:
    no_ans_eval = make_eval_dict(exact_thresh, f1_thresh, qid_list=no_ans_qids)
    merge_eval(out_eval, no_ans_eval, 'NoAns')
print(json.dumps(out_eval, indent=2))

## Some samples of the results

In [None]:
for i in range(5):
    choice = np.random.choice(list(my_preds))
    row = test_df[test_df.id == choice].iloc[0]
    print("Context: ", str(row.context))
    print()
    print("Question: ", str(row.content))
    print()
    if row.is_impossible:
        print("Impossible to answer")
    else:
        print("Answer: ", row.answer)
    print()
    if my_preds[choice]:
        print("Predicted answer: ", my_preds[choice])
    else:
        print("Predicted impossbile to answer")
    print("\n//////////////////// \n")