In [19]:
from datasets import load_dataset
from torch.utils.data import DataLoader, Dataset
from transformers import AutoTokenizer
import torch
from transformers import DistilBertForQuestionAnswering
from transformers import default_data_collator
from transformers import AdamW
from tqdm.notebook import tqdm
import datetime
import numpy as np
import collections
import evaluate

In [2]:
dataset = load_dataset("squad")

#lets sample a small dataset
dataset['train'] = dataset['train'].select([i for i in range(7500)])
dataset['validation'] = dataset['validation'].select([i for i in range(750)])

dataset

Downloading builder script:   0%|          | 0.00/5.27k [00:00<?, ?B/s]

Downloading metadata:   0%|          | 0.00/2.36k [00:00<?, ?B/s]

Downloading readme:   0%|          | 0.00/7.67k [00:00<?, ?B/s]

Downloading data files:   0%|          | 0/2 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/8.12M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/1.05M [00:00<?, ?B/s]

Extracting data files:   0%|          | 0/2 [00:00<?, ?it/s]

Generating train split:   0%|          | 0/87599 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/10570 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['id', 'title', 'context', 'question', 'answers'],
        num_rows: 7500
    })
    validation: Dataset({
        features: ['id', 'title', 'context', 'question', 'answers'],
        num_rows: 750
    })
})

In [10]:
def train_data_preprocess(examples):
    
    """
    generate start and end indexes of answer in context
    """
    
    def find_context_start_end_index(sequence_ids):
        """
        returns the token index in whih context starts and ends
        """
        token_idx = 0
        while sequence_ids[token_idx] != 1:  #means its special tokens or tokens of queston
            token_idx += 1                   # loop only break when context starts in tokens
        context_start_idx = token_idx
    
        while sequence_ids[token_idx] == 1:
            token_idx += 1
        context_end_idx = token_idx - 1
        return context_start_idx,context_end_idx  
    
    
    questions = [q.strip() for q in examples["question"]]
    context = examples["context"]
    answers = examples["answers"]
    
    inputs = tokenizer(
        questions,
        context,
        max_length=512,
        truncation="only_second",
        stride=128,
        return_overflowing_tokens=True,  #returns id of base context
        return_offsets_mapping=True,  # returns (start_index,end_index) of each token
        padding="max_length"
    )


    start_positions = []
    end_positions = []

    
    for i,mapping_idx_pairs in enumerate(inputs['offset_mapping']):
        context_idx = inputs['overflow_to_sample_mapping'][i]
    
        # from main context
        answer = answers[context_idx]
        answer_start_char_idx = answer['answer_start'][0]
        answer_end_char_idx = answer_start_char_idx + len(answer['text'][0])

    
        # now we have to find it in sub contexts
        tokens = inputs['input_ids'][i]
        sequence_ids = inputs.sequence_ids(i)
   
        # finding the context start and end indexes wrt sub context tokens
        context_start_idx,context_end_idx = find_context_start_end_index(sequence_ids)
    
        #if the answer is not fully inside context label it as (0,0)
        # starting and end index of charecter of full context text
        context_start_char_index = mapping_idx_pairs[context_start_idx][0]
        context_end_char_index = mapping_idx_pairs[context_end_idx][1]
    

        #If the answer is not fully inside the context, label is (0, 0)
        if (context_start_char_index > answer_start_char_idx) or (
            context_end_char_index < answer_end_char_idx):
            start_positions.append(0)
            end_positions.append(0)
    
        else:

            # else its start and end token positions
            # here idx indicates index of token
            idx = context_start_idx
            while idx <= context_end_idx and mapping_idx_pairs[idx][0] <= answer_start_char_idx:
                idx += 1
            start_positions.append(idx - 1)  
        

            idx = context_end_idx
            while idx >= context_start_idx and mapping_idx_pairs[idx][1] > answer_end_char_idx:
                idx -= 1
            end_positions.append(idx + 1)
    
    inputs["start_positions"] = start_positions
    inputs["end_positions"] = end_positions
    return inputs

def preprocess_validation_examples(examples):
    """
    preprocessing validation data
    """
    questions = [q.strip() for q in examples["question"]]
    inputs = tokenizer(
        questions,
        examples["context"],
        max_length=512,
        truncation="only_second",
        stride=128,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
    )

    sample_map = inputs.pop("overflow_to_sample_mapping")

    base_ids = []

    for i in range(len(inputs["input_ids"])):
        
        # take the base id (ie in cases of overflow happens we get base id)
        base_context_idx = sample_map[i]
        base_ids.append(examples["id"][base_context_idx])
        
        # sequence id indicates the input. 0 for first input and 1 for second input
        # and None for special tokens by default
        sequence_ids = inputs.sequence_ids(i)
        offset = inputs["offset_mapping"][i]
        # for Question tokens provide offset_mapping as None
        inputs["offset_mapping"][i] = [
            o if sequence_ids[k] == 1 else None for k, o in enumerate(offset)
        ]

    inputs["base_id"] = base_ids
    return inputs

In [11]:
class DataQA(Dataset):
    def __init__(self, dataset,mode="train"):
        self.mode = mode
        
        
        if self.mode == "train":
            # sampling
            self.dataset = dataset["train"]
            self.data = self.dataset.map(train_data_preprocess,
                                                      batched=True,
                            remove_columns= dataset["train"].column_names)
        
        else:
            self.dataset = dataset["validation"]
            self.data = self.dataset.map(preprocess_validation_examples,
            batched=True,remove_columns = dataset["validation"].column_names,
               )
            
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):

        out = {}
        example = self.data[idx]
        out['input_ids'] = torch.tensor(example['input_ids'])
        out['attention_mask'] = torch.tensor(example['attention_mask'])

        
        if self.mode == "train":

            out['start_positions'] = torch.unsqueeze(torch.tensor(example['start_positions']),dim=0)
            out['end_positions'] = torch.unsqueeze(torch.tensor(example['end_positions']),dim=0)
            
        return out

In [14]:
trained_checkpoint = "distilbert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(trained_checkpoint)


train_dataset = DataQA(dataset,mode="train")
val_dataset = DataQA(dataset,mode="validation")



for i,d in enumerate(train_dataset):
    for k in d.keys():
        print(k + ' : ', d[k].shape)
    print('--'*40)

    if i == 3:
        break
        
print('__'*50)

for i,d in enumerate(val_dataset):
    for k in d.keys():
        print(k + ' : ', len(d[k]))
    print('--'*40)
    
    if i == 3:
        break

Map:   0%|          | 0/750 [00:00<?, ? examples/s]

input_ids :  torch.Size([512])
attention_mask :  torch.Size([512])
start_positions :  torch.Size([1])
end_positions :  torch.Size([1])
--------------------------------------------------------------------------------
input_ids :  torch.Size([512])
attention_mask :  torch.Size([512])
start_positions :  torch.Size([1])
end_positions :  torch.Size([1])
--------------------------------------------------------------------------------
input_ids :  torch.Size([512])
attention_mask :  torch.Size([512])
start_positions :  torch.Size([1])
end_positions :  torch.Size([1])
--------------------------------------------------------------------------------
input_ids :  torch.Size([512])
attention_mask :  torch.Size([512])
start_positions :  torch.Size([1])
end_positions :  torch.Size([1])
--------------------------------------------------------------------------------
____________________________________________________________________________________________________
input_ids :  512
attention_mask :  

In [15]:
train_dataloader = DataLoader(
    train_dataset,
    shuffle=True,
    collate_fn=default_data_collator,
    batch_size=2,
)
eval_dataloader = DataLoader(
    val_dataset, collate_fn=default_data_collator, batch_size=2
)

for batch in train_dataloader:
   print(batch['input_ids'].shape)
   print(batch['attention_mask'].shape)
   print(batch['start_positions'].shape)
   print(batch['end_positions'].shape)
   break

print('---'*20)

for batch in eval_dataloader:
   print(batch['input_ids'].shape)
   print(batch['attention_mask'].shape)
   break

torch.Size([2, 512])
torch.Size([2, 512])
torch.Size([2, 1])
torch.Size([2, 1])
------------------------------------------------------------
torch.Size([2, 512])
torch.Size([2, 512])


In [16]:
from transformers import DistilBertForQuestionAnswering
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Available device: {device}')

checkpoint =  "distilbert-base-uncased"
model = DistilBertForQuestionAnswering.from_pretrained(checkpoint)
model = model.to(device)

Available device: cpu


Downloading:   0%|          | 0.00/256M [00:00<?, ?B/s]

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForQuestionAnswering: ['vocab_projector.bias', 'vocab_transform.bias', 'vocab_transform.weight', 'vocab_layer_norm.weight', 'vocab_layer_norm.bias', 'vocab_projector.weight']
- This IS expected if you are initializing DistilBertForQuestionAnswering from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertForQuestionAnswering from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of DistilBertForQuestionAnswering were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['qa_outputs.bias', 'qa_outputs.weight']
You should probably TRAIN this mode

In [29]:
optimizer = AdamW(model.parameters(), lr=2e-5)

epochs = 3

# Total number of training steps is [number of batches] x [number of epochs]. 
# (Note that this is not the same as the number of training samples).
total_steps = len(train_dataloader) * epochs
print(total_steps)


def format_time(elapsed):
    '''
    Takes a time in seconds and returns a string hh:mm:ss
    '''
    # Round to the nearest second.
    elapsed_rounded = int(round((elapsed)))
    
    # Format as hh:mm:ss
    return str(datetime.timedelta(seconds=elapsed_rounded))

11265


In [30]:
# we need processed validation data to get offsets at the time of evaluation
validation_processed_dataset = dataset["validation"].map(preprocess_validation_examples, batched=True,
                                                         remove_columns = dataset["validation"].column_names,)

In [32]:
def predict_answers_and_evaluate(start_logits,end_logits,eval_set,examples):
    """
    make predictions 
    Args:
    start_logits : strat_position prediction logits
    end_logits: end_position prediction logits
    eval_set: processed val data
    examples: unprocessed val data with context text
    """
    # appending all id's corresponding to the base context id
    example_to_features = collections.defaultdict(list)
    for idx, feature in enumerate(eval_set):
        example_to_features[feature["base_id"]].append(idx)

    n_best = 20
    max_answer_length = 30
    predicted_answers = []

    for example in examples:
        example_id = example["id"]
        context = example["context"]
        answers = []

        # looping through each sub contexts corresponding to a context and finding
        # answers
        for feature_index in example_to_features[example_id]:
            start_logit = start_logits[feature_index]
            end_logit = end_logits[feature_index]
            offsets = eval_set["offset_mapping"][feature_index]
        
            # sorting the predictions of all hidden states and taking best n_best prediction
            # means taking the index of top 20 tokens
            start_indexes = np.argsort(start_logit).tolist()[::-1][:n_best]
            end_indexes = np.argsort(end_logit).tolist()[::-1][:n_best]
        
    
            for start_index in start_indexes:
                for end_index in end_indexes:
                
                    # Skip answers that are not fully in the context
                    if offsets[start_index] is None or offsets[end_index] is None:
                        continue
                    # Skip answers with a length that is either < 0 or > max_answer_length.
                    if (
                        end_index < start_index
                        or end_index - start_index + 1 > max_answer_length
                       ):
                        continue

                    answers.append({
                        "text": context[offsets[start_index][0] : offsets[end_index][1]],
                        "logit_score": start_logit[start_index] + end_logit[end_index],
                        })

    
            # Select the answer with the best score
        if len(answers) > 0:
            best_answer = max(answers, key=lambda x: x["logit_score"])
            predicted_answers.append(
                {"id": example_id, "prediction_text": best_answer["text"]}
            )
        else:
            predicted_answers.append({"id": example_id, "prediction_text": ""})
    
    metric = evaluate.load("squad")

    theoretical_answers = [
            {"id": ex["id"], "answers": ex["answers"]} for ex in examples
    ]
    
    metric_ = metric.compute(predictions=predicted_answers, references=theoretical_answers)
    return predicted_answers,metric_

In [33]:
import random,time
import numpy as np

# to reproduce results
seed_val = 42
random.seed(seed_val)
np.random.seed(seed_val)
torch.manual_seed(seed_val)
torch.cuda.manual_seed_all(seed_val)


#storing all training and validation stats
stats = []


#to measure total training time
total_train_time_start = time.time()

for epoch in range(epochs):
    print(' ')
    print(f'=====Epoch {epoch + 1}=====')
    print('Training....')
     
    # ===============================
    #    Train
    # ===============================   
    # measure how long training epoch takes
    t0 = time.time()
     
    training_loss = 0
    # loop through train data
    model.train()
    for step,batch in enumerate(train_dataloader):
         
        # we will print train time in every 40 epochs
        if step%40 == 0 and not step == 0:
              elapsed_time = format_time(time.time() - t0)
              # Report progress.
              print('  Batch {:>5,}  of  {:>5,}.    Elapsed: {:}.'.format(step, len(train_dataloader), elapsed_time))

         
       
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        start_positions = batch['start_positions'].to(device)
        end_positions = batch['end_positions'].to(device)
            


        #set gradients to zero
        model.zero_grad()

        result = model(input_ids = input_ids, 
                        attention_mask = attention_mask,
                        start_positions = start_positions,
                        end_positions = end_positions,
                        return_dict=True)
         
        loss = result.loss
    
        #accumulate the loss over batches so that we can calculate avg loss at the end
        training_loss += loss.item()      

        #perform backward prorpogation
        loss.backward()

        # update the gradients
        optimizer.step()

    # calculate avg loss
    avg_train_loss = training_loss/len(train_dataloader) 
 
    # calculates training time
    training_time = format_time(time.time() - t0)
     
    
    print("")
    print("  Average training loss: {0:.2f}".format(avg_train_loss))
    print("  Training epoch took: {:}".format(training_time))
    
    
    # ===============================
    #    Validation
    # ===============================
     
    print("")
    print("Running Validation...")

    t0 = time.time()

    # Put the model in evaluation mode--the dropout layers behave differently
    # during evaluation.
    model.eval()
     

    start_logits,end_logits = [],[]
    for step,batch in enumerate(eval_dataloader):
         
       
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)

         
        # Tell pytorch not to bother with constructing the compute graph during
        # the forward pass, since this is only needed for backprop (training).
        with torch.no_grad():  
             result = model(input_ids = input_ids, 
                        attention_mask = attention_mask,return_dict=True)
        


        start_logits.append(result.start_logits.cpu().numpy())
        end_logits.append(result.end_logits.cpu().numpy())
   

    start_logits = np.concatenate(start_logits)
    end_logits = np.concatenate(end_logits)
    # start_logits = start_logits[: len(val_dataset)]
    # end_logits = end_logits[: len(val_dataset)]




    # calculating metrics
    answers,metrics_ = predict_answers_and_evaluate(start_logits,end_logits,validation_processed_dataset,dataset["validation"])
    print(f'Exact match: {metrics_["exact_match"]}, F1 score: {metrics_["f1"]}')


    print('')
    # Measure how long the validation run took.
    validation_time = format_time(time.time() - t0)

    print("  Validation took: {:}".format(validation_time))

print("")
print("Training complete!")

print("Total training took {:} (h:mm:ss)".format(format_time(time.time()-total_train_time_start)))

 
=====Epoch 1=====
Training....
  Batch    40  of  3,755.    Elapsed: 0:03:37.
  Batch    80  of  3,755.    Elapsed: 0:07:23.
  Batch   120  of  3,755.    Elapsed: 0:11:09.
  Batch   160  of  3,755.    Elapsed: 0:14:56.
  Batch   200  of  3,755.    Elapsed: 0:18:44.
  Batch   240  of  3,755.    Elapsed: 0:22:30.
  Batch   280  of  3,755.    Elapsed: 0:26:20.
  Batch   320  of  3,755.    Elapsed: 0:30:09.
  Batch   360  of  3,755.    Elapsed: 0:33:52.
  Batch   400  of  3,755.    Elapsed: 0:42:42.
  Batch   440  of  3,755.    Elapsed: 0:46:19.
  Batch   480  of  3,755.    Elapsed: 0:50:44.
  Batch   520  of  3,755.    Elapsed: 7:35:42.
  Batch   560  of  3,755.    Elapsed: 7:38:53.
  Batch   600  of  3,755.    Elapsed: 7:42:07.
  Batch   640  of  3,755.    Elapsed: 7:45:22.
  Batch   680  of  3,755.    Elapsed: 7:48:44.
  Batch   720  of  3,755.    Elapsed: 7:52:22.
  Batch   760  of  3,755.    Elapsed: 7:56:00.
  Batch   800  of  3,755.    Elapsed: 7:59:38.
  Batch   840  of  3,755.  

Downloading builder script:   0%|          | 0.00/4.53k [00:00<?, ?B/s]

Downloading extra modules:   0%|          | 0.00/3.32k [00:00<?, ?B/s]

Exact match: 25.733333333333334, F1 score: 62.46494776351328

  Validation took: 0:40:28
 
=====Epoch 2=====
Training....
  Batch    40  of  3,755.    Elapsed: 0:03:48.
  Batch    80  of  3,755.    Elapsed: 0:07:45.
  Batch   120  of  3,755.    Elapsed: 0:11:42.
  Batch   160  of  3,755.    Elapsed: 0:15:39.
  Batch   200  of  3,755.    Elapsed: 0:19:36.
  Batch   240  of  3,755.    Elapsed: 0:23:30.
  Batch   280  of  3,755.    Elapsed: 0:27:23.
  Batch   320  of  3,755.    Elapsed: 0:31:07.
  Batch   360  of  3,755.    Elapsed: 0:34:51.
  Batch   400  of  3,755.    Elapsed: 0:38:42.
  Batch   440  of  3,755.    Elapsed: 0:42:34.
  Batch   480  of  3,755.    Elapsed: 0:46:24.
  Batch   520  of  3,755.    Elapsed: 0:50:20.
  Batch   560  of  3,755.    Elapsed: 0:57:55.
  Batch   600  of  3,755.    Elapsed: 1:01:44.
  Batch   640  of  3,755.    Elapsed: 1:05:31.
  Batch   680  of  3,755.    Elapsed: 1:09:18.
  Batch   720  of  3,755.    Elapsed: 1:13:04.
  Batch   760  of  3,755.    Ela

  Batch 3,040  of  3,755.    Elapsed: 3:11:50.
  Batch 3,080  of  3,755.    Elapsed: 3:14:17.
  Batch 3,120  of  3,755.    Elapsed: 3:16:42.
  Batch 3,160  of  3,755.    Elapsed: 3:19:01.
  Batch 3,200  of  3,755.    Elapsed: 3:21:25.
  Batch 3,240  of  3,755.    Elapsed: 3:23:50.
  Batch 3,280  of  3,755.    Elapsed: 3:26:16.
  Batch 3,320  of  3,755.    Elapsed: 3:28:48.
  Batch 3,360  of  3,755.    Elapsed: 3:31:07.
  Batch 3,400  of  3,755.    Elapsed: 3:33:25.
  Batch 3,440  of  3,755.    Elapsed: 3:35:50.
  Batch 3,480  of  3,755.    Elapsed: 3:38:14.
  Batch 3,520  of  3,755.    Elapsed: 3:40:39.
  Batch 3,560  of  3,755.    Elapsed: 3:43:04.
  Batch 3,600  of  3,755.    Elapsed: 3:45:31.
  Batch 3,640  of  3,755.    Elapsed: 3:47:57.
  Batch 3,680  of  3,755.    Elapsed: 3:50:18.
  Batch 3,720  of  3,755.    Elapsed: 3:52:36.

  Average training loss: 0.38
  Training epoch took: 3:54:35

Running Validation...
Exact match: 24.933333333333334, F1 score: 62.120821474024766

  Vali

In [34]:
question = "How many parameters does BERT-large have?"
context = "BERT-large is really big... it has 24-layers and an embedding size of 1,024, for a total of 340M parameters! Altogether it is 1.34GB, so expect it to take a couple minutes to download to your Colab instance."

input_ids = tokenizer.encode(question, context)
print (f'We have about {len(input_ids)} tokens generated')

tokens = tokenizer.convert_ids_to_tokens(input_ids)
print(" ")
print('Some examples of token-input_id pairs:')

for i, (token,inp_id) in enumerate(zip(tokens,input_ids)):
    print(token,":",inp_id)

We have about 70 tokens generated
 
Some examples of token-input_id pairs:
[CLS] : 101
how : 2129
many : 2116
parameters : 11709
does : 2515
bert : 14324
- : 1011
large : 2312
have : 2031
? : 1029
[SEP] : 102
bert : 14324
- : 1011
large : 2312
is : 2003
really : 2428
big : 2502
. : 1012
. : 1012
. : 1012
it : 2009
has : 2038
24 : 2484
- : 1011
layers : 9014
and : 1998
an : 2019
em : 7861
##bed : 8270
##ding : 4667
size : 2946
of : 1997
1 : 1015
, : 1010
02 : 6185
##4 : 2549
, : 1010
for : 2005
a : 1037
total : 2561
of : 1997
340 : 16029
##m : 2213
parameters : 11709
! : 999
altogether : 10462
it : 2009
is : 2003
1 : 1015
. : 1012
34 : 4090
##gb : 18259
, : 1010
so : 2061
expect : 5987
it : 2009
to : 2000
take : 2202
a : 1037
couple : 3232
minutes : 2781
to : 2000
download : 8816
to : 2000
your : 2115
cola : 15270
##b : 2497
instance : 6013
. : 1012
[SEP] : 102


In [35]:
sep_idx = tokens.index('[SEP]')

# we will provide including [SEP] token which seperates question from context and 1 for rest.
token_type_ids = [0 for i in range(sep_idx+1)] + [1 for i in range(sep_idx+1,len(tokens))]
print(token_type_ids)

[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]


In [40]:
# Run our example through the model.
out = model(torch.tensor([input_ids]))

start_logits,end_logits = out['start_logits'],out['end_logits']
# Find the tokens with the highest `start` and `end` scores.
answer_start = torch.argmax(start_logits)
answer_end = torch.argmax(end_logits)

ans = ''.join(tokens[answer_start:answer_end])
print('Predicted answer:', ans)

Predicted answer: 340##m
