Natural Questions Dataset:
Official Link:
This dataset contains very long contexts (passages), each paired with a single question about the passage. The answer to each question is a sub-span of the passage, given as a <start_token> and an <end_token>.
What makes this dataset really challenging is the length of the passages and of the answers.
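
To make the start/end token format concrete, here is a minimal sketch of how an answer span can be recovered from a passage. It assumes the passage is tokenized by splitting on single spaces and that <end_token> is exclusive; the helper name `extract_span` and the toy passage are made up for illustration.

```python
def extract_span(document_text, start_token, end_token):
    """Return the tokens in [start_token, end_token) joined back into a string."""
    # Assumption: tokens are separated by single spaces
    tokens = document_text.split(" ")
    return " ".join(tokens[start_token:end_token])

# Toy passage (hypothetical), with HTML tags counted as tokens
doc = "<P> The quick brown fox jumps over the lazy dog </P>"
print(extract_span(doc, 2, 5))  # quick brown fox
```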

I chose this dataset because it will be interesting to work with transformers on such long sequences. When I tried this dataset with a seq2seq model, I kept getting OOM (out-of-memory) errors. The reason is the extreme length of the contexts and their answers. The contexts are entire Wikipedia pages in HTML format, so their length is typically over 2000 tokens.
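
One way to check context lengths before training is to count whitespace-separated tokens per record. The sketch below assumes each JSONL line has a `document_text` field (as in the simplified training file); the function name `context_lengths` and the two sample records are hypothetical.

```python
import json

def context_lengths(jsonl_lines):
    """Whitespace-token count of `document_text` for each JSON line."""
    return [len(json.loads(line)["document_text"].split(" ")) for line in jsonl_lines]

# Two tiny made-up records standing in for lines of the real file
sample = [
    json.dumps({"document_text": "<H1> Title </H1> <P> some text </P>"}),
    json.dumps({"document_text": "<P> a b c </P>"}),
]
print(context_lengths(sample))  # [7, 5]
```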

I started experimenting with the original dataset and kept watering down the task: first reducing the number of examples, then the answer length, then the batch size, but I kept getting OOM errors. Finally, I restricted context lengths to 1000.

The problem was the sequence length of the contexts. If I restrict them to 1000, the model has enough memory to work with. The number of samples has little effect, but the batch size does. I will remove this 1000-length restriction when I move to transformers.
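
A rough way to see why this is so: the size of a padded batch depends on the longest sequence in it and on the number of rows, not on how many examples the dataset holds. The sketch below counts token slots only and is not a real memory estimate; `peak_padded_tokens` and the lengths used are hypothetical.

```python
def peak_padded_tokens(lengths, batch_size):
    """Largest number of token slots (max_len * rows) in any padded batch."""
    # Sort by length so each batch groups similar lengths, loosely mimicking a bucketing iterator
    ordered = sorted(lengths)
    peak = 0
    for i in range(0, len(ordered), batch_size):
        batch = ordered[i:i + batch_size]
        peak = max(peak, max(batch) * len(batch))
    return peak

# Hypothetical context lengths in tokens
short = [1000] * 64
longer = [2000] * 64
print(peak_padded_tokens(short, 32))       # 32000
print(peak_padded_tokens(short * 10, 32))  # 32000: more examples, same peak
print(peak_padded_tokens(longer, 32))      # 64000: longer contexts double it
print(peak_padded_tokens(short, 64))       # 64000: so does a larger batch
```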

In [1]:
import torch
import json
from torchtext import data
from itertools import chain
import torch.nn as nn
import torch.optim as optim
import time
from torch.nn import Embedding

In [2]:
%%bash
file=/content/train.jsonl
if [ ! -f "$file" ]; then
  # Since the dataset itself is huge, we will make train and test set from the original train file itself
  # A simple wget won't be able to get the Natural Questions dataset present at https://ai.google.com/research/NaturalQuestions/download
  # Use the advice given in https://www.kaggle.com/c/deepfake-detection-challenge/discussion/121194 
  # the download link is https://storage.cloud.google.com/natural_questions/v1.0-simplified/simplified-nq-train.jsonl.gz 
  
  # put curl command here

  zcat /content/train.jsonl.gz > /content/train.jsonl
  
fi

In [26]:
# In the beginning, restricting the total number of examples to 10k.
total_limit = 10000

file = "/content/train.jsonl"
f = open(file)

total = 0
examples = []
for i,line in enumerate(f):
    if total >= total_limit:
        break
    ak = json.loads(line)
    context = ak['document_text']
    question = ak['question_text']
    start = ak['annotations'][0]['long_answer']['start_token']
    end = ak['annotations'][0]['long_answer']['end_token']
    # Keep only examples with a valid long answer of more than 200 tokens and a context under 1000 tokens
    if not (start < end and end-start > 200 and len(context.split(" ")) < 1000):
        continue
    answer = " ".join(context.split(" ")[start:end])
    examples.append([context,question,answer])
    total += 1
f.close()

# Will do the architecture with separate encoders for context and question and a decoder for answer
context = data.Field(sequential=True, tokenize='spacy', init_token='<sos>', eos_token='<eos>')
question = data.Field(sequential=True, tokenize='spacy', init_token='<sos>', eos_token='<eos>')
answer = data.Field(sequential=True, tokenize='spacy', init_token='<sos>', eos_token='<eos>')

fields = [('context', context), ('question', question), ('answer', answer)]

Examples = [data.Example.fromlist([i[0], i[1], i[2]], fields) for i in examples]
Dataset = data.Dataset(Examples, fields)

train_dataset,valid_dataset = Dataset.split(split_ratio=[0.85,0.15])

context.build_vocab(train_dataset,min_freq=2,max_size = 20000,vectors = "glove.6B.100d", 
                 unk_init = torch.Tensor.normal_)
question.build_vocab(train_dataset,min_freq=2,max_size = 5000,vectors = "glove.6B.100d", 
                 unk_init = torch.Tensor.normal_)
answer.build_vocab(train_dataset,min_freq=2,max_size = 10000,vectors = "glove.6B.100d", 
                 unk_init = torch.Tensor.normal_)


BATCH_SIZE = 32
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
train_iterator, valid_iterator = data.BucketIterator.splits((train_dataset, valid_dataset), batch_size=BATCH_SIZE,
                                                            sort_key=lambda x: len(x.context),
                                                            sort_within_batch=True,device=device)

In [27]:
print(Examples[0])
print(len(Examples))

<torchtext.data.example.Example object at 0x7f84afaf1048>
247


In [28]:
class Encoder(nn.Module):
  def __init__(self,input_dim,emb_dim,hid_dim,n_layers,dropout,bidirectional):
    super().__init__()
    self.hid_dim = hid_dim
    
    self.embedding = nn.Embedding(input_dim,emb_dim)

    self.rnn = nn.LSTM(input_size=emb_dim,hidden_size = hid_dim,num_layers= n_layers,dropout= dropout,bidirectional = bidirectional)

    self.dropout = nn.Dropout(dropout)
  
  def forward(self,input,hidden=None,cell_state=None):
    
    embedded = self.dropout(self.embedding(input))

    if hidden is not None:
      outputs, (hidden,cell_state) = self.rnn(embedded,(hidden,cell_state))
    else:
      outputs, (hidden,cell_state) = self.rnn(embedded)

    return hidden,cell_state

In [29]:
class Decoder(nn.Module):
  def __init__(self,output_dim,emb_dim,hid_dim,n_layers,bidirectional,dropout):
    super().__init__()

    self.embedded = nn.Embedding(output_dim,emb_dim)

    self.rnn = nn.LSTM(input_size=emb_dim,hidden_size=hid_dim,num_layers=n_layers,bidirectional=bidirectional,dropout=dropout)

    self.dropout = nn.Dropout(dropout)

    no_of_directions = 2 if bidirectional else 1

    self.fc_out = nn.Linear(no_of_directions*hid_dim,output_dim)

  def forward(self,input,hidden,cell_state):
    input = input.unsqueeze(0)

    input = self.dropout(self.embedded(input))

    output, (hidden,cell_state) = self.rnn(input,(hidden,cell_state))

    output = output.squeeze(0)
    
    prediction = self.fc_out(output)

    return prediction, hidden , cell_state


In [30]:
import random
class Seq2seq(nn.Module):
  def __init__(self,context_dim,question_dim,answer_dim,emd_dim,hid_dim,n_layers,bidirectional,dropout):
    super().__init__()

    self.context_encoder = Encoder(context_dim,emd_dim,hid_dim,n_layers,dropout,bidirectional)
    self.question_encoder = Encoder(question_dim,emd_dim,hid_dim,n_layers,dropout,bidirectional)
    self.answer_decoder = Decoder(answer_dim,emd_dim,hid_dim,n_layers,bidirectional,dropout)

    self.answer_dim = answer_dim

  def forward(self,context,question,answer,teacher_forcing =0.5):

    hidden,cell_state = self.context_encoder(context)

    hidden,cell_state = self.question_encoder(question,hidden,cell_state)

    answer_len = len(answer)
    batch_size = answer.shape[1]

    outputs = torch.zeros(answer_len,batch_size,self.answer_dim).to(device)

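    # Start from the <sos> row; outputs[0] stays zero and is skipped by the loss
    k = answer[0]
    for j in range(1,answer_len):
      prediction, hidden, cell_state = self.answer_decoder(k,hidden,cell_state)
      outputs[j] = prediction
      # teacher_forcing is the probability of feeding the ground-truth token instead of the prediction
      k = answer[j] if random.random() < teacher_forcing else prediction.argmax(1)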

    return outputs

In [31]:
context_dim = len(context.vocab)
question_dim = len(question.vocab)
answer_dim = len(answer.vocab)
emb_dim = 100
hid_dim = 100
n_layers = 1
bidirectional = False
dropout = 0.5

model = Seq2seq(context_dim,question_dim,answer_dim,emb_dim,hid_dim,n_layers,bidirectional,dropout).to(device)

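def init_weights(m):
    # model.apply visits every submodule, so initialize only each module's own
    # parameters and leave the embedding layers untouched
    if isinstance(m, Embedding):
        return
    for name, param in m.named_parameters(recurse=False):
        nn.init.normal_(param.data, mean=0, std=0.01)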
        
model.apply(init_weights)

optimizer = optim.Adam(model.parameters())


TRG_PAD_IDX = answer.vocab.stoi[answer.pad_token]

criterion = nn.CrossEntropyLoss(ignore_index = TRG_PAD_IDX)


In [32]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f'The model has {count_parameters(model):,} trainable parameters')

The model has 1,835,590 trainable parameters


In [33]:
def train(model, iterator, optimizer, criterion, clip):
    
    model.train()
    
    epoch_loss = 0
    
    for i, batch in enumerate(iterator):
        
        context_ = batch.context
        question_ = batch.question
        answer_ = batch.answer
        
        optimizer.zero_grad()
        
        output = model(context_, question_,answer_)
        
        trg = answer_
        #trg = [trg len, batch size]
        #output = [trg len, batch size, output dim]
        
        output_dim = output.shape[-1]
        
        output = output[1:].view(-1, output_dim)
        trg = trg[1:].view(-1)
        
        #trg = [(trg len - 1) * batch size]
        #output = [(trg len - 1) * batch size, output dim]
        
        loss = criterion(output, trg)
        
        loss.backward()
        
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        
        optimizer.step()
        
        epoch_loss += loss.item()
        
    return epoch_loss / len(iterator)

In [34]:
def evaluate(model, iterator, criterion):
    
    model.eval()
    
    epoch_loss = 0
    
    with torch.no_grad():
    
        for i, batch in enumerate(iterator):

            context_ = batch.context
            question_ = batch.question
            answer_ = batch.answer
        
            output = model(context_, question_,answer_,0) #turn off teacher forcing

            trg = answer_
            #trg = [trg len, batch size]
            #output = [trg len, batch size, output dim]

            output_dim = output.shape[-1]
            
            output = output[1:].view(-1, output_dim)
            trg = trg[1:].view(-1)

            #trg = [(trg len - 1) * batch size]
            #output = [(trg len - 1) * batch size, output dim]

            loss = criterion(output, trg)

            epoch_loss += loss.item()
        
    return epoch_loss / len(iterator)

In [35]:
def epoch_time(start_time, end_time):
    elapsed_time = end_time - start_time
    elapsed_mins = int(elapsed_time / 60)
    elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
    return elapsed_mins, elapsed_secs

In [36]:
import math
N_EPOCHS = 10
CLIP = 1

best_valid_loss = float('inf')

for epoch in range(N_EPOCHS):
    
    start_time = time.time()
    
    train_loss = train(model, train_iterator, optimizer, criterion, CLIP)
    valid_loss = evaluate(model, valid_iterator, criterion)
    
    end_time = time.time()
    
    epoch_mins, epoch_secs = epoch_time(start_time, end_time)
    
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), 'tut2-model.pt')
    
    print(f'Epoch: {epoch+1:02} | Time: {epoch_mins}m {epoch_secs}s')
    print(f'\tTrain Loss: {train_loss:.3f} | Train PPL: {math.exp(train_loss):7.3f}')
    print(f'\t Val. Loss: {valid_loss:.3f} |  Val. PPL: {math.exp(valid_loss):7.3f}')

Epoch: 01 | Time: 0m 34s
	Train Loss: 8.356 | Train PPL: 4256.789
	 Val. Loss: 8.340 |  Val. PPL: 4188.036
Epoch: 02 | Time: 0m 34s
	Train Loss: 8.319 | Train PPL: 4100.958
	 Val. Loss: 8.248 |  Val. PPL: 3821.181
Epoch: 03 | Time: 0m 34s
	Train Loss: 7.641 | Train PPL: 2082.310
	 Val. Loss: 6.339 |  Val. PPL: 566.443
Epoch: 04 | Time: 0m 33s
	Train Loss: 6.034 | Train PPL: 417.516
	 Val. Loss: 5.093 |  Val. PPL: 162.831
Epoch: 05 | Time: 0m 34s
	Train Loss: 5.121 | Train PPL: 167.522
	 Val. Loss: 4.200 |  Val. PPL:  66.661
Epoch: 06 | Time: 0m 34s
	Train Loss: 4.584 | Train PPL:  97.898
	 Val. Loss: 3.716 |  Val. PPL:  41.115
Epoch: 07 | Time: 0m 34s
	Train Loss: 4.385 | Train PPL:  80.252
	 Val. Loss: 3.560 |  Val. PPL:  35.163
Epoch: 08 | Time: 0m 34s
	Train Loss: 4.373 | Train PPL:  79.251
	 Val. Loss: 3.541 |  Val. PPL:  34.491
Epoch: 09 | Time: 0m 34s
	Train Loss: 4.376 | Train PPL:  79.549
	 Val. Loss: 3.526 |  Val. PPL:  33.981
Epoch: 10 | Time: 0m 33s
	Train Loss: 4.363 | Trai