https://github.com/kushalj001/pytorch-question-answering/blob/master/1.%20DrQA.ipynb

https://github.com/Kprerak-iisc/data-science/blob/b8b8e402e4a0ba7aa9c946f56c20e1d0251c158b/NLP/QA/BiDAF.ipynb

In [1]:
import pandas as pd
import numpy as np
import torchtext
import torch
from torch import nn
import json, re, unicodedata, string, typing, time
import torch.nn.functional as F
import spacy
from collections import Counter
import pickle
import Bert_Question_Answering_preprocessing as BQ
from nltk import word_tokenize
nlp = spacy.load('en')
%load_ext autoreload
%autoreload 2

In [2]:
train_data = BQ.load_json(r'C:\Users\abc\jupyter\pytorch\Bert\squad_train.json')
valid_data = BQ.load_json(r'C:\Users\abc\jupyter\pytorch\Bert\squad_dev.json')

train_list = BQ.parse_data(train_data)
valid_list = BQ.parse_data(valid_data)

print('Train list len : ', len(train_list))
print('Valid list lne : ', len(valid_list))

train_df = pd.DataFrame(train_list)
valid_df = pd.DataFrame(valid_list)

Length of data:  442
Data Keys:  dict_keys(['title', 'paragraphs'])
Title:  Beyoncé
Length of data:  35
Data Keys:  dict_keys(['title', 'paragraphs'])
Title:  Normans
Train list len :  86821
Valid list lne :  20302


In [3]:
def normalize_spaces(text):
    """
    정규화 과정
    """
    text = re.sub(r'\s', ' ',text)
    return text

train_df.context = train_df.context.apply(normalize_spaces)
valid_df.context = valid_df.context.apply(normalize_spaces)
                  

In [4]:
train_df.head()

Unnamed: 0,id,context,question,label,answer
0,56be85543aeaaa14008c9063,Beyoncé Giselle Knowles-Carter (/biːˈjɒnseɪ/ b...,When did Beyonce start becoming popular?,"[269, 286]",in the late 1990s
1,56be85543aeaaa14008c9065,Beyoncé Giselle Knowles-Carter (/biːˈjɒnseɪ/ b...,What areas did Beyonce compete in when she was...,"[207, 226]",singing and dancing
2,56be85543aeaaa14008c9066,Beyoncé Giselle Knowles-Carter (/biːˈjɒnseɪ/ b...,When did Beyonce leave Destiny's Child and bec...,"[526, 530]",2003
3,56bf6b0f3aeaaa14008c9601,Beyoncé Giselle Knowles-Carter (/biːˈjɒnseɪ/ b...,In what city and state did Beyonce grow up?,"[166, 180]","Houston, Texas"
4,56bf6b0f3aeaaa14008c9602,Beyoncé Giselle Knowles-Carter (/biːˈjɒnseɪ/ b...,In which decade did Beyonce become famous?,"[276, 286]",late 1990s


In [5]:
%time vocab_text = BQ.gather_text_for_vocab([train_df, valid_df])
print("Number of sentences in dataset: ", len(vocab_text))

Wall time: 350 ms
Number of sentences in dataset:  112776


In [6]:
%time word2idx, idx2word, word_vocab = BQ.build_word_vocab(vocab_text)

raw-vocab: 108286
glove-vocab: 108286
vocab-length: 108288
word2idx-length: 108288
Wall time: 23.8 s


In [7]:
%time train_df['context_ids'] = train_df.context.apply(BQ.context_to_ids, word2idx=word2idx)
%time valid_df['context_ids'] = valid_df.context.apply(BQ.context_to_ids, word2idx=word2idx)

%time train_df['question_ids'] = train_df.question.apply(BQ.question_to_ids, word2idx = word2idx)
%time valid_df['question_ids'] = valid_df.question.apply(BQ.question_to_ids, word2idx = word2idx)

Wall time: 1min 7s
Wall time: 19.6 s
Wall time: 6.38 s
Wall time: 1.48 s


In [8]:
train_err = BQ.get_error_indices(train_df, idx2word)
valid_err = BQ.get_error_indices(valid_df, idx2word)

train_df.drop(train_err, inplace=True)
valid_df.drop(valid_err, inplace=True)

Error indices: 986
Error indices: 208


In [9]:
# 시작과 끝 위치 반환
train_label_idx = train_df.apply(BQ.index_answer, axis=1, idx2word=idx2word)
valid_label_idx = valid_df.apply(BQ.index_answer, axis=1, idx2word=idx2word)

train_df['label_idx'] = train_label_idx
valid_df['label_idx'] = valid_label_idx

In [10]:
import pickle
with open('drqastoi.pickle','wb') as handle:
    pickle.dump(word2idx, handle)
    
train_df.to_pickle('drqatrain.pkl')
valid_df.to_pickle('drqavalid.pkl')

In [11]:
train_df = pd.read_pickle('drqatrain.pkl')
valid_df = pd.read_pickle('drqavalid.pkl')

In [12]:
class SquadDataset:
    """
    -데이터 프레임을 batch로 나눕니다.
    -패딩을 통해 각 배치에 대한 context and question을 동적으로 해당 배치의 최대 크기로 패딩
    -문맥과 질문에 대한 마스크를 계산합니다.
    -컨텍스트에 대한 범위를 계산합니다.
    """
    
    def __init__(self, data, batch_size):
        self.batch_size = batch_size
        data = [data[i:i+self.batch_size] for i in range(0, len(data), self.batch_size)]
        self.data = data
        
    def get_span(self, text):
        
        text = nlp(text, disable=['parser','tagger','ner'])
        span = [(w.idx, w.idx+len(w.text)) for w in text]

        return span

    def __len__(self):
        return len(self.data)
    
    def __iter__(self):
        """
        데이터 배치를 생성하고 산출합니다.
        
        각 수익은 다음으로 구성됩니다.
        : padded_context : 각 배치에 대한 컨텍스트의 패딩 된 텐서
        : padded_question : 각 배치에 대한 질문의 패딩 된 텐서
        : context_mask & question_mask : 질문 및 컨텍스트에 대한 제로 마스크
        : label : 시작 및 끝 인덱스 wrt context_ids
        : context_text, answer_text : 유효성 검사 중에 측정 항목을 계산하는 데 사용됩니다.
        : context_spans : 문맥 텍스트의 범위
        : ids : 평가에 사용되는 question_ids
        """
        
        for batch in self.data:
                            
            spans = []
            context_text = []
            answer_text = []
            
            max_context_len = max([len(ctx) for ctx in batch.context_ids])
            padded_context = torch.LongTensor(len(batch), max_context_len).fill_(1)
            
            for ctx in batch.context:
                context_text.append(ctx)
                spans.append(self.get_span(ctx))
            
            for ans in batch.answer:
                answer_text.append(ans)
                
            for i, ctx in enumerate(batch.context_ids):
                padded_context[i, :len(ctx)] = torch.LongTensor(ctx)
            
            max_question_len = max([len(ques) for ques in batch.question_ids])
            padded_question = torch.LongTensor(len(batch), max_question_len).fill_(1)
            
            for i, ques in enumerate(batch.question_ids):
                padded_question[i,: len(ques)] = torch.LongTensor(ques)
                
            
            label = torch.LongTensor(list(batch.label_idx))
            context_mask = torch.eq(padded_context, 1)
            question_mask = torch.eq(padded_question, 1)
            
            ids = list(batch.id)  
            
            yield (padded_context, padded_question, context_mask, 
                   question_mask, label, context_text, answer_text, ids)

In [13]:
train_dataset = SquadDataset(train_df, 32)
valid_dataset = SquadDataset(valid_df, 32)
a = next(iter(train_dataset))
a[0].shape, a[1].shape, a[2].shape, a[3].shape, a[4].shape

(torch.Size([32, 215]),
 torch.Size([32, 19]),
 torch.Size([32, 215]),
 torch.Size([32, 19]),
 torch.Size([32, 2]))

In [14]:
def create_glove_matrix():
    """
    glove word 벡터 텍스트 파일을 구문 분석하고 다음과 같은 단어가 포함 된 사전을 반환합니다.
    키와 각각의 사전 훈련 된 단어 벡터를 값으로 사용합니다.
    """
    glove_dict = {}
    with open(r"C:\Users\abc\jupyter\pytorch\Bert\glove.840B.300d.txt", "r", encoding="utf-8") as f:
        for line in f:
            values = line.split(' ')
            word = values[0]
            vector = np.asarray(values[1:], dtype="float32")
            glove_dict[word] = vector

    f.close()
    
    return glove_dict

In [15]:
glove_dict = create_glove_matrix()
# https://www.kaggle.com/takuok/glove840b300dtxt

In [16]:
def create_word_embedding(glove_dict):
    """
    GloVe 어휘에서 일반적으로 사용되는 단어의 가중치 행렬을 만들고
    데이터 셋의 어휘. zero 벡터로 OOV 단어를 초기화합니다.
    """
    weights_matrix = np.zeros((len(word_vocab), 300))
    words_found = 0
    for i, word in enumerate(word_vocab):
        try:
            weights_matrix[i] = glove_dict[word]
            words_found += 1
        except:
            pass
    return weights_matrix, words_found

In [17]:
weights_matrix, words_found = create_word_embedding(glove_dict)

In [18]:
print("Total words found in glove vocab: ", words_found)

Total words found in glove vocab:  89541


In [19]:
np.save('drqaglove_vt.npy',weights_matrix)

In [20]:
class AlignQuestionEmbedding(nn.Module):
    
    def __init__(self, input_dim):        
        
        super().__init__()
        
        self.linear = nn.Linear(input_dim, input_dim)
        
        self.relu = nn.ReLU()
        
    def forward(self, context, question, question_mask):
        
        # context = [bs, ctx_len, emb_dim]
        # question = [bs, qtn_len, emb_dim]
        # question_mask = [bs, qtn_len]
    
        ctx_ = self.linear(context)
        ctx_ = self.relu(ctx_)
        # ctx_ = [bs, ctx_len, emb_dim]
        
        qtn_ = self.linear(question)
        qtn_ = self.relu(qtn_)
        # qtn_ = [bs, qtn_len, emb_dim]
        
        qtn_transpose = qtn_.permute(0,2,1)
        # qtn_transpose = [bs, emb_dim, qtn_len]
        
        align_scores = torch.bmm(ctx_, qtn_transpose)
        # align_scores = [bs, ctx_len, qtn_len]
        
        qtn_mask = question_mask.unsqueeze(1).expand(align_scores.size())
        # qtn_mask = [bs, 1, qtn_len] => [bs, ctx_len, qtn_len]
        
        # Fills elements of self tensor(align_scores) with value(-float(inf)) where mask is True. 
        # The shape of mask must be broadcastable with the shape of the underlying tensor.
        align_scores = align_scores.masked_fill(qtn_mask == 1, -float('inf'))
        # align_scores = [bs, ctx_len, qtn_len]
        
        align_scores_flat = align_scores.view(-1, question.size(1))
        # align_scores = [bs*ctx_len, qtn_len]
        
        alpha = F.softmax(align_scores_flat, dim=1)
        alpha = alpha.view(-1, context.shape[1], question.shape[1])
        # alpha = [bs, ctx_len, qtn_len]
        
        align_embedding = torch.bmm(alpha, question)
        # align = [bs, ctx_len, emb_dim]
        
        return align_embedding

In [21]:
class StackedBiLSTM(nn.Module):
    
    def __init__(self, input_dim, hidden_dim, num_layers, dropout):
        
        super().__init__()
        
        self.dropout = dropout
        
        self.num_layers = num_layers
        
        self.lstms = nn.ModuleList()
        
        for i in range(self.num_layers):
            
            input_dim = input_dim if i == 0 else hidden_dim * 2
            
            self.lstms.append(nn.LSTM(input_dim, hidden_dim,
                                      batch_first=True, bidirectional=True))
           
    
    def forward(self, x):
        # x = [bs, seq_len, feature_dim]

        outputs = [x]
        for i in range(self.num_layers):

            lstm_input = outputs[-1]
            lstm_out = F.dropout(lstm_input, p=self.dropout)
            lstm_out, (hidden, cell) = self.lstms[i](lstm_input)
           
            outputs.append(lstm_out)

    
        output = torch.cat(outputs[1:], dim=2)
        # [bs, seq_len, num_layers*num_dir*hidden_dim]
        
        output = F.dropout(output, p=self.dropout)
      
        return output

In [22]:
class LinearAttentionLayer(nn.Module):
    
    def __init__(self, input_dim):
        
        super().__init__()
        
        self.linear = nn.Linear(input_dim, 1)
        
    def forward(self, question, question_mask):
        
        # question = [bs, qtn_len, input_dim] = [bs, qtn_len, bi_lstm_hid_dim]
        # question_mask = [bs,  qtn_len]
        
        qtn = question.view(-1, question.shape[-1])
        # qtn = [bs*qtn_len, hid_dim]
        
        attn_scores = self.linear(qtn)
        # attn_scores = [bs*qtn_len, 1]
        
        attn_scores = attn_scores.view(question.shape[0], question.shape[1])
        # attn_scores = [bs, qtn_len]
        
        attn_scores = attn_scores.masked_fill(question_mask == 1, -float('inf'))
        
        alpha = F.softmax(attn_scores, dim=1)
        # alpha = [bs, qtn_len]
        
        return alpha

In [23]:
def weighted_average(x, weights):
    # x = [bs, len, dim]
    # weights = [bs, len]
    
    weights = weights.unsqueeze(1)
    # weights = [bs, 1, len]
    
    w = weights.bmm(x).squeeze(1)
    # w = [bs, 1, dim] => [bs, dim]
    
    return w

In [24]:
class BilinearAttentionLayer(nn.Module):
    
    def __init__(self, context_dim, question_dim):
        
        super().__init__()
        
        self.linear = nn.Linear(question_dim, context_dim)
        
    def forward(self, context, question, context_mask):
        
        # context = [bs, ctx_len, ctx_hid_dim] = [bs, ctx_len, hid_dim*6] = [bs, ctx_len, 768]
        # question = [bs, qtn_hid_dim] = [bs, qtn_len, 768]
        # context_mask = [bs, ctx_len]
        
        qtn_proj = self.linear(question)
        # qtn_proj = [bs, ctx_hid_dim]
        
        qtn_proj = qtn_proj.unsqueeze(2)
        # qtn_proj = [bs, ctx_hid_dim, 1]
        
        scores = context.bmm(qtn_proj)
        # scores = [bs, ctx_len, 1]
        
        scores = scores.squeeze(2)
        # scores = [bs, ctx_len]
        
        scores = scores.masked_fill(context_mask == 1, -float('inf'))
        
        #alpha = F.log_softmax(scores, dim=1)
        # alpha = [bs, ctx_len]
        
        return scores

In [25]:
class DocumentReader(nn.Module):
    
    def __init__(self, hidden_dim, embedding_dim, num_layers, num_directions, dropout, device):
        
        super().__init__()
        
        self.device = device
        
        #self.embedding = self.get_glove_embedding()
        
        self.context_bilstm = StackedBiLSTM(embedding_dim * 2, hidden_dim, num_layers, dropout)
        
        self.question_bilstm = StackedBiLSTM(embedding_dim, hidden_dim, num_layers, dropout)
        
        self.glove_embedding = self.get_glove_embedding()
        
        def tune_embedding(grad, words=1000):
            grad[words:] = 0
            return grad
        
        self.glove_embedding.weight.register_hook(tune_embedding)
        
        self.align_embedding = AlignQuestionEmbedding(embedding_dim)
        
        self.linear_attn_question = LinearAttentionLayer(hidden_dim*num_layers*num_directions) 
        
        self.bilinear_attn_start = BilinearAttentionLayer(hidden_dim*num_layers*num_directions, 
                                                          hidden_dim*num_layers*num_directions)
        
        self.bilinear_attn_end = BilinearAttentionLayer(hidden_dim*num_layers*num_directions,
                                                        hidden_dim*num_layers*num_directions)
        
        self.dropout = nn.Dropout(dropout)
   
        
    def get_glove_embedding(self):
        
        weights_matrix = np.load('drqaglove_vt.npy')
        num_embeddings, embedding_dim = weights_matrix.shape
        embedding = nn.Embedding.from_pretrained(torch.FloatTensor(weights_matrix).to(self.device),freeze=False)

        return embedding
    
    
    def forward(self, context, question, context_mask, question_mask):
       
        # context = [bs, len_c]
        # question = [bs, len_q]
        # context_mask = [bs, len_c]
        # question_mask = [bs, len_q]
        
        
        ctx_embed = self.glove_embedding(context)
        # ctx_embed = [bs, len_c, emb_dim]
        
        ques_embed = self.glove_embedding(question)
        # ques_embed = [bs, len_q, emb_dim]
        

        ctx_embed = self.dropout(ctx_embed)
     
        ques_embed = self.dropout(ques_embed)
             
        align_embed = self.align_embedding(ctx_embed, ques_embed, question_mask)
        # align_embed = [bs, len_c, emb_dim]  
        
        ctx_bilstm_input = torch.cat([ctx_embed, align_embed], dim=2)
        # ctx_bilstm_input = [bs, len_c, emb_dim*2]
                
        ctx_outputs = self.context_bilstm(ctx_bilstm_input)
        # ctx_outputs = [bs, len_c, hid_dim*layers*dir] = [bs, len_c, hid_dim*6]
       
        qtn_outputs = self.question_bilstm(ques_embed)
        # qtn_outputs = [bs, len_q, hid_dim*6]
    
        qtn_weights = self.linear_attn_question(qtn_outputs, question_mask)
        # qtn_weights = [bs, len_q]
            
        qtn_weighted = weighted_average(qtn_outputs, qtn_weights)
        # qtn_weighted = [bs, hid_dim*6]
        
        start_scores = self.bilinear_attn_start(ctx_outputs, qtn_weighted, context_mask)
        # start_scores = [bs, len_c]
         
        end_scores = self.bilinear_attn_end(ctx_outputs, qtn_weighted, context_mask)
        # end_scores = [bs, len_c]
        
      
        return start_scores, end_scores

In [26]:
device = torch.device('cuda')
HIDDEN_DIM = 128
EMB_DIM = 300
NUM_LAYERS = 3
NUM_DIRECTIONS = 2
DROPOUT = 0.3
device = torch.device('cuda')

model = DocumentReader(HIDDEN_DIM,
                       EMB_DIM, 
                       NUM_LAYERS, 
                       NUM_DIRECTIONS, 
                       DROPOUT, 
                       device).to(device)

In [27]:
optimizer = torch.optim.Adamax(model.parameters())

In [28]:
def count_parameters(model):
    '''Returns the number of trainable parameters in the model.'''
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f'The model has {count_parameters(model):,} trainable parameters')

The model has 36,527,549 trainable parameters


In [29]:
def train(model, train_dataset):
    '''
    Trains the model.
    '''
    
    print("Starting training ........")
    
    train_loss = 0.
    batch_count = 0
    
    # put the model in training mode
    model.train()
    
    # iterate through training data
    for batch in train_dataset:

        if batch_count % 500 == 0:
            print(f"Starting batch: {batch_count}")
        batch_count += 1

        context, question, context_mask, question_mask, label, ctx, ans, ids = batch
        
        # place the tensors on GPU
        context, context_mask, question, question_mask, label = context.to(device), context_mask.to(device),\
                                    question.to(device), question_mask.to(device), label.to(device)
        
        # forward pass, get the predictions
        preds = model(context, question, context_mask, question_mask)

        start_pred, end_pred = preds
        
        # separate labels for start and end position
        start_label, end_label = label[:,0], label[:,1]
        
        # calculate loss
        loss = F.cross_entropy(start_pred, start_label) + F.cross_entropy(end_pred, end_label)
        
        # backward pass, calculates the gradients
        loss.backward()
        
        # gradient clipping
        torch.nn.utils.clip_grad_norm_(model.parameters(), 10)
        
        # update the gradients
        optimizer.step()
        
        # zero the gradients to prevent them from accumulating
        optimizer.zero_grad()

        train_loss += loss.item()

    return train_loss/len(train_dataset)

In [30]:
def valid(model, valid_dataset):
    '''
    Performs validation.
    '''
    
    print("Starting validation .........")
   
    valid_loss = 0.

    batch_count = 0
    
    f1, em = 0., 0.
    
    # puts the model in eval mode. Turns off dropout
    model.eval()
    
    predictions = {}
    
    for batch in valid_dataset:

        if batch_count % 500 == 0:
            print(f"Starting batch {batch_count}")
        batch_count += 1

        context, question, context_mask, question_mask, label, context_text, answers, ids = batch

        context, context_mask, question, question_mask, label = context.to(device), context_mask.to(device),\
                                    question.to(device), question_mask.to(device), label.to(device)

        with torch.no_grad():

            preds = model(context, question, context_mask, question_mask)

            p1, p2 = preds

            y1, y2 = label[:,0], label[:,1]

            loss = F.cross_entropy(p1, y1) + F.cross_entropy(p2, y2)

            valid_loss += loss.item()

            
            # get the start and end index positions from the model preds
            
            batch_size, c_len = p1.size()
            ls = nn.LogSoftmax(dim=1)
            mask = (torch.ones(c_len, c_len) * float('-inf')).to(device).tril(-1).unsqueeze(0).expand(batch_size, -1, -1)
            
            score = (ls(p1).unsqueeze(2) + ls(p2).unsqueeze(1)) + mask
            score, s_idx = score.max(dim=1)
            score, e_idx = score.max(dim=1)
            s_idx = torch.gather(s_idx, 1, e_idx.view(-1, 1)).squeeze()
            
            # stack predictions
            for i in range(batch_size):
                id = ids[i]
                pred = context[i][s_idx[i]:e_idx[i]+1]
                pred = ' '.join([idx2word[idx.item()] for idx in pred])
                predictions[id] = pred
            
            
            
    em, f1 = evaluate(predictions)            
    return valid_loss/len(valid_dataset), em, f1

In [31]:
def evaluate(predictions):
    '''
    Gets a dictionary of predictions with question_id as key
    and prediction as value. The validation dataset has multiple 
    answers for a single question. Hence we compare our prediction
    with all the answers and choose the one that gives us
    the maximum metric (em or f1). 
    This method first parses the JSON file, gets all the answers
    for a given id and then passes the list of answers and the 
    predictions to calculate em, f1.
    
    
    :param dict predictions
    Returns
    : exact_match: 1 if the prediction and ground truth 
      match exactly, 0 otherwise.
    : f1_score: 
    '''
    with open(r'C:\Users\abc\jupyter\pytorch\Bert\squad_dev.json','r',encoding='utf-8') as f:
        dataset = json.load(f)
        
    dataset = dataset['data']
    f1 = exact_match = total = 0
    for article in dataset:
        for paragraph in article['paragraphs']:
            for qa in paragraph['qas']:
                total += 1
                if qa['id'] not in predictions:
                    continue
                
                ground_truths = list(map(lambda x: x['text'], qa['answers']))
                
                prediction = predictions[qa['id']]
                
                exact_match += metric_max_over_ground_truths(
                    exact_match_score, prediction, ground_truths)
                
                f1 += metric_max_over_ground_truths(
                    f1_score, prediction, ground_truths)
                
    
    exact_match = 100.0 * exact_match / total
    f1 = 100.0 * f1 / total
    
    return exact_match, f1

In [32]:
def normalize_answer(s):
    '''
    Performs a series of cleaning steps on the ground truth and 
    predicted answer.
    '''
    def remove_articles(text):
        return re.sub(r'\b(a|an|the)\b', ' ', text)

    def white_space_fix(text):
        return ' '.join(text.split())

    def remove_punc(text):
        exclude = set(string.punctuation)
        return ''.join(ch for ch in text if ch not in exclude)

    def lower(text):
        return text.lower()

    return white_space_fix(remove_articles(remove_punc(lower(s))))


def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
    '''
    Returns maximum value of metrics for predicition by model against
    multiple ground truths.
    
    :param func metric_fn: can be 'exact_match_score' or 'f1_score'
    :param str prediction: predicted answer span by the model
    :param list ground_truths: list of ground truths against which
                               metrics are calculated. Maximum values of 
                               metrics are chosen.
                            
    
    '''
    scores_for_ground_truths = []
    for ground_truth in ground_truths:
        score = metric_fn(prediction, ground_truth)
        scores_for_ground_truths.append(score)
        
    return max(scores_for_ground_truths)


def f1_score(prediction, ground_truth):
    '''
    Returns f1 score of two strings.
    '''
    prediction_tokens = normalize_answer(prediction).split()
    ground_truth_tokens = normalize_answer(ground_truth).split()
    common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
    num_same = sum(common.values())
    if num_same == 0:
        return 0
    precision = 1.0 * num_same / len(prediction_tokens)
    recall = 1.0 * num_same / len(ground_truth_tokens)
    f1 = (2 * precision * recall) / (precision + recall)
    return f1


def exact_match_score(prediction, ground_truth):
    '''
    Returns exact_match_score of two strings.
    '''
    return (normalize_answer(prediction) == normalize_answer(ground_truth))


def epoch_time(start_time, end_time):
    '''
    Helper function to record epoch time.
    '''
    elapsed_time = end_time - start_time
    elapsed_mins = int(elapsed_time / 60)
    elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
    return elapsed_mins, elapsed_secs

In [35]:
train_losses = []
valid_losses = []
ems = []
f1s = []
epochs = 5

for epoch in range(epochs):
    print(f"Epoch {epoch+1}")
    
    start_time = time.time()
    
    train_loss = train(model, train_dataset)
    valid_loss, em, f1 = valid(model, valid_dataset)
    
    end_time = time.time()
    
    epoch_mins, epoch_secs = epoch_time(start_time, end_time)
    
    train_losses.append(train_loss)
    valid_losses.append(valid_loss)
    ems.append(em)
    f1s.append(f1)
    
    print(f"Epoch train loss : {train_loss}| Time: {epoch_mins}m {epoch_secs}s")
    print(f"Epoch valid loss: {valid_loss}")
    print(f"Epoch EM: {em}")
    print(f"Epoch F1: {f1}")
    print("====================================================================================")

Epoch 1
Starting training ........
Starting batch: 0
Starting batch: 500
Starting batch: 1000
Starting batch: 1500
Starting batch: 2000
Starting batch: 2500
Starting validation .........
Starting batch 0
Starting batch 500
Epoch train loss : 3.185348530269762| Time: 8m 37s
Epoch valid loss: 3.4415279266181265
Epoch EM: 27.97944917038659
Epoch F1: 33.81359561770958
Epoch 2
Starting training ........
Starting batch: 0
Starting batch: 500
Starting batch: 1000
Starting batch: 1500
Starting batch: 2000
Starting batch: 2500
Starting validation .........
Starting batch 0
Starting batch 500
Epoch train loss : 2.897416961601853| Time: 8m 25s
Epoch valid loss: 3.465092615811688
Epoch EM: 28.518487324180914
Epoch F1: 34.20278672225792
Epoch 3
Starting training ........
Starting batch: 0
Starting batch: 500
Starting batch: 1000
Starting batch: 1500
Starting batch: 2000
Starting batch: 2500
Starting validation .........
Starting batch 0
Starting batch 500
Epoch train loss : 2.686902788695763| Time:

RuntimeError: cuDNN error: CUDNN_STATUS_INTERNAL_ERROR