In [None]:
import pandas as pd
import numpy as np
import torchtext
import torch
from torch import nn
import json, re, unicodedata, string, typing, time
import torch.nn.functional as F
import spacy
from collections import Counter
from tqdm.auto import tqdm
import pickle
from nltk import word_tokenize
import ipynb.fs
nlp = spacy.load('en_core_web_sm')
from .defs.model import *
# When True, the model-setup cell restores weights from `model_path` via
# `model.load_state_dict` before training starts; when False the model keeps
# whatever initialisation `DocumentReader` gives it.
load = False

In [None]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('device: ' + str(device))

In [None]:
class SquadDataset:
    """Slice a pre-processed SQuAD table into batches of padded tensors.

    ``data`` must support ``len()``, row slicing (``data[i:j]``) and attribute
    access to the columns ``context_ids``, ``context``, ``answer``,
    ``question_ids``, ``label_idx`` and ``id``.  The notebook passes the result
    of ``pd.read_pickle``, so this is presumably a pandas DataFrame.

    Per row, ``context_ids[0]`` holds the context token ids and
    ``context_ids[1]`` one 2-element feature vector per context token;
    ``question_ids[0]`` holds the question token ids.

    Iterating yields, per batch, the tuple
    ``(padded_context, padded_question, context_mask, question_mask, label,
    context_text, answer_text, ids, context_fts)``.
    """

    # Token id used both to pad short sequences and to derive the masks.
    PAD_IDX = 1

    def __init__(self, data, batch_size):
        self.batch_size = batch_size
        # Pre-slice once so that len() and iteration work on ready-made batches.
        self.data = [data[start:start + batch_size]
                     for start in range(0, len(data), batch_size)]

    def get_span(self, text):
        """Return ``(start_char, end_char)`` offsets for every spaCy token of ``text``.

        Uses the notebook-level ``nlp`` pipeline with parser, tagger and NER
        disabled.  Not called by ``__iter__``; kept for external callers.
        """
        doc = nlp(text, disable=['parser', 'tagger', 'ner'])
        return [(token.idx, token.idx + len(token.text)) for token in doc]

    def __len__(self):
        """Number of batches."""
        return len(self.data)

    def __iter__(self):
        """Yield one tuple of padded tensors and raw fields per batch.

        Earlier revisions also tokenised every context with spaCy here to build
        a ``spans`` list that was never yielded; that dead (and slow) work has
        been removed.  The yielded tuple is unchanged.
        """
        for batch in self.data:
            n_rows = len(batch)
            context_text = list(batch.context)
            answer_text = list(batch.answer)

            # Contexts: token ids padded with PAD_IDX, features padded with 0.
            context_entries = list(batch.context_ids)
            max_context_len = max(len(entry[0]) for entry in context_entries)
            padded_context = torch.LongTensor(n_rows, max_context_len).fill_(self.PAD_IDX)
            context_fts = torch.LongTensor(n_rows, max_context_len, 2).fill_(0)
            for row, entry in enumerate(context_entries):
                padded_context[row, :len(entry[0])] = torch.LongTensor(entry[0])
                context_fts[row, :len(entry[1])] = torch.LongTensor(entry[1])

            # Questions: token ids padded with PAD_IDX.
            question_entries = list(batch.question_ids)
            max_question_len = max(len(entry[0]) for entry in question_entries)
            padded_question = torch.LongTensor(n_rows, max_question_len).fill_(self.PAD_IDX)
            for row, entry in enumerate(question_entries):
                padded_question[row, :len(entry[0])] = torch.LongTensor(entry[0])

            label = torch.LongTensor(list(batch.label_idx))
            # True wherever the position holds the padding id.
            context_mask = torch.eq(padded_context, self.PAD_IDX)
            question_mask = torch.eq(padded_question, self.PAD_IDX)

            ids = list(batch.id)

            yield (padded_context, padded_question, context_mask,
                   question_mask, label, context_text, answer_text, ids, context_fts)

In [None]:
# Load the pre-processed train/validation frames and wrap each in batches of 32.
# For a quick smoke run, append .head(500) to the read_pickle call below.
train_df, valid_df = (
    pd.read_pickle(path)
    for path in ('/scratch/arjunth2001/drqa/drqatrain.pkl',
                 '/scratch/arjunth2001/drqa/drqavalid.pkl')
)
train_dataset, valid_dataset = (
    SquadDataset(frame, 32) for frame in (train_df, valid_df)
)

In [None]:
# Hyper-parameters passed positionally to DocumentReader, plus the learning rate.
HIDDEN_DIM, EMB_DIM = 128, 300
NUM_LAYERS, NUM_DIRECTIONS = 3, 2
DROPOUT, LR = 0.3, 1e-5

model = DocumentReader(HIDDEN_DIM, EMB_DIM, NUM_LAYERS, NUM_DIRECTIONS, DROPOUT, device)
model = model.to(device)
optimizer = torch.optim.Adamax(model.parameters(), lr=LR)

# Checkpoint location: read here when `load` is set, written by the training loop.
model_path = "models/model.pt"
if load:
    model.load_state_dict(torch.load(model_path, map_location=device))

In [None]:
def train(model, train_dataset):
    """Run one optimisation pass over ``train_dataset`` and return the mean batch loss.

    Uses the notebook-level ``device`` and ``optimizer``.  Each batch is the
    9-tuple produced by ``SquadDataset.__iter__``; ``label[:, 0]`` and
    ``label[:, 1]`` are used as the start and end targets.
    """
    model.train()
    running_loss = 0.
    for batch in tqdm(train_dataset):
        # Positions 5-7 (context text, answer text, ids) are not needed for the loss.
        context, question, context_mask, question_mask, label, _, _, _, fts = batch
        context, question, context_mask, question_mask, label, fts = (
            tensor.to(device)
            for tensor in (context, question, context_mask, question_mask, label, fts)
        )

        start_pred, end_pred = model(context, question, context_mask, question_mask, fts)
        start_loss = F.cross_entropy(start_pred, label[:, 0])
        end_loss = F.cross_entropy(end_pred, label[:, 1])
        loss = start_loss + end_loss

        loss.backward()
        # Cap the global gradient norm at 10 before stepping.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 10)
        optimizer.step()
        optimizer.zero_grad()
        running_loss += loss.item()
    return running_loss / len(train_dataset)

In [None]:
def valid(model, valid_dataset):
    """Evaluate ``model`` on ``valid_dataset``.

    Returns ``(mean_batch_loss, exact_match, f1)``.  Relies on the
    notebook-level ``device``, ``idx2word`` and ``evaluate``.

    For every example the span ``(start, end)`` with ``start <= end`` that
    maximises ``log_softmax(p1)[start] + log_softmax(p2)[end]`` is selected,
    and the context token ids in that span are mapped back to words through
    ``idx2word`` and joined with spaces.
    """
    valid_loss = 0.
    model.eval()
    predictions = {}
    for batch in tqdm(valid_dataset):
        # Raw context/answer strings (positions 5-6) are not used here.
        context, question, context_mask, question_mask, label, _, _, ids, fts = batch
        context, context_mask, question, question_mask, label, fts = context.to(device), context_mask.to(device),\
                                    question.to(device), question_mask.to(device), label.to(device), fts.to(device)
        with torch.no_grad():
            p1, p2 = model(context, question, context_mask, question_mask, fts)
            y1, y2 = label[:,0], label[:,1]
            loss = F.cross_entropy(p1, y1) + F.cross_entropy(p2, y2)
            valid_loss += loss.item()
            batch_size, c_len = p1.size()
            ls = nn.LogSoftmax(dim=1)
            # tril(-1) keeps -inf strictly below the diagonal (start > end) and
            # zeroes the rest, so spans that end before they start can never win.
            mask = (torch.ones(c_len, c_len) * float('-inf')).to(device).tril(-1).unsqueeze(0).expand(batch_size, -1, -1)
            # score[b, s, e] = log P(start = s) + log P(end = e) (+ mask)
            score = (ls(p1).unsqueeze(2) + ls(p2).unsqueeze(1)) + mask
            score, s_idx = score.max(dim=1)   # best start for every end position
            score, e_idx = score.max(dim=1)   # best end overall
            # squeeze(1), not squeeze(): a bare squeeze() collapses a batch of
            # one example to a 0-dim tensor and s_idx[i] below would raise.
            s_idx = torch.gather(s_idx, 1, e_idx.view(-1, 1)).squeeze(1)
            for i in range(batch_size):
                question_id = ids[i]
                pred = context[i][s_idx[i]:e_idx[i]+1]
                pred = ' '.join([idx2word[idx.item()] for idx in pred])
                predictions[question_id] = pred
    em, f1 = evaluate(predictions)
    return valid_loss/len(valid_dataset), em, f1

In [None]:
def evaluate(predictions, dataset_path='../data/dev-v2.0.json'):
    """Score ``predictions`` against a SQuAD-format JSON file.

    Args:
        predictions: dict mapping question id -> predicted answer string.
        dataset_path: path to the SQuAD JSON file; defaults to the dev set
            location this notebook has always used.

    Returns:
        ``(exact_match, f1)`` as percentages over *all* questions in the file.
        Questions missing from ``predictions`` count towards the total and
        contribute zero to both metrics.

    Questions with an empty ``answers`` list (the unanswerable questions of
    SQuAD v2.0) are scored against the single gold answer ``''`` instead of
    crashing on ``max()`` of an empty list.  An empty dataset yields
    ``(0.0, 0.0)`` rather than a ZeroDivisionError.
    """
    with open(dataset_path, 'r', encoding='utf-8') as f:
        dataset = json.load(f)['data']

    f1 = exact_match = total = 0
    for article in dataset:
        for paragraph in article['paragraphs']:
            for qa in paragraph['qas']:
                total += 1
                if qa['id'] not in predictions:
                    continue

                ground_truths = [answer['text'] for answer in qa['answers']]
                if not ground_truths:
                    # No gold answer: compare against the empty string.
                    ground_truths = ['']

                prediction = predictions[qa['id']]

                exact_match += metric_max_over_ground_truths(
                    exact_match_score, prediction, ground_truths)

                f1 += metric_max_over_ground_truths(
                    f1_score, prediction, ground_truths)

    if total == 0:
        return 0.0, 0.0

    exact_match = 100.0 * exact_match / total
    f1 = 100.0 * f1 / total

    return exact_match, f1

In [None]:
def normalize_answer(s):
    """Canonicalise an answer string for metric comparison.

    Steps, in order: lower-case, drop every character in
    ``string.punctuation``, replace the standalone words a/an/the with a
    space, then collapse all whitespace runs to single spaces.
    """
    lowered = s.lower()

    punctuation = set(string.punctuation)
    without_punctuation = ''.join(ch for ch in lowered if ch not in punctuation)

    without_articles = re.sub(r'\b(a|an|the)\b', ' ', without_punctuation)

    return ' '.join(without_articles.split())


def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
    """Return the best ``metric_fn(prediction, truth)`` over all ``ground_truths``.

    Raises ValueError when ``ground_truths`` is empty.
    """
    return max(metric_fn(prediction, truth) for truth in ground_truths)


def f1_score(prediction, ground_truth):
    """Token-level F1 between the normalised prediction and ground truth.

    Tokens are the whitespace-separated words of ``normalize_answer(...)``.
    Returns the integer 0 when the two share no tokens.
    """
    pred_tokens = normalize_answer(prediction).split()
    truth_tokens = normalize_answer(ground_truth).split()

    # Multiset intersection: each shared token counts as often as it appears in both.
    overlap = sum((Counter(pred_tokens) & Counter(truth_tokens)).values())
    if overlap == 0:
        return 0

    precision = 1.0 * overlap / len(pred_tokens)
    recall = 1.0 * overlap / len(truth_tokens)
    return (2 * precision * recall) / (precision + recall)


def exact_match_score(prediction, ground_truth):
    """True when prediction and ground truth are identical after ``normalize_answer``."""
    normalized_prediction = normalize_answer(prediction)
    normalized_truth = normalize_answer(ground_truth)
    return normalized_prediction == normalized_truth

In [None]:
import pickle
# Vocabulary used to map token ids back to words (inverted into idx2word below).
# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
with open('/scratch/arjunth2001/drqa/drqastoi.pickle', 'rb') as vocab_file:
    word2idx = pickle.load(vocab_file)

In [None]:
# Invert the vocabulary so predicted token ids can be decoded back to words.
idx2word = {index: token for token, index in word2idx.items()}

In [None]:
import os

# Per-epoch history, kept for later inspection/plotting.
train_losses = []
valid_losses = []
ems = []
f1s = []
epochs = 50
best_loss = np.inf
best_epoch = 0

# torch.save does not create missing parent directories; make sure the
# checkpoint folder exists up front instead of failing after a full epoch.
checkpoint_dir = os.path.dirname(model_path)
if checkpoint_dir:
    os.makedirs(checkpoint_dir, exist_ok=True)

for epoch in range(epochs):
    print(f"Epoch {epoch+1}") 
    train_loss = train(model, train_dataset)
    valid_loss, em, f1 = valid(model, valid_dataset)
    train_losses.append(train_loss)
    valid_losses.append(valid_loss)
    ems.append(em)
    f1s.append(f1)
    # Checkpoint only when the validation loss improves.
    if (valid_loss < best_loss):
        torch.save(model.state_dict(), model_path)
        best_loss = valid_loss
        best_epoch = epoch+1
    print("====================================================================================")
    print(f"Epoch train loss : {train_loss}  valid loss: {valid_loss} EM: {em} F1: {f1} Best Epoch: {best_epoch}")
    print("====================================================================================")