In [12]:
import json
from pathlib import Path

def read_squad(path):
    path = Path(path)
    with open(path, 'rb') as f:
        squad_dict = json.load(f)

    contexts = []
    questions = []
    answers = []
    for group in squad_dict['data']:
        for passage in group['paragraphs']:
            context = passage['context']
            for qa in passage['qas']:
                question = qa['question']
                for answer in qa['answers']:
                    contexts.append(context)
                    questions.append(question)
                    answers.append(answer)

    return contexts, questions, answers

train_contexts, train_questions, train_answers = read_squad(r'D:\software\github\GZK_Code\XAI\2022.03.03\squad\train-v2.0.json')
val_contexts, val_questions, val_answers = read_squad(r'D:\software\github\GZK_Code\XAI\2022.03.03\squad\dev-v2.0.json')


In [13]:
train_contexts, train_questions, train_answers = train_contexts[0:5], train_questions[0:5], train_answers[0:5]
val_contexts, val_questions, val_answers = val_contexts[0:5], val_questions[0:5], val_answers[0:5]
sep_train_contexts = []
sep_train_questions = []
sep_train_answers = []

In [14]:
train_questions

['When did Beyonce start becoming popular?',
 'What areas did Beyonce compete in when she was growing up?',
 "When did Beyonce leave Destiny's Child and become a solo singer?",
 'In what city and state did Beyonce  grow up? ',
 'In which decade did Beyonce become famous?']

In [15]:
import nltk as tk
import re
null_answer = {'text': '[NULL]', 'answer_start': 0}
for i in range(5):
    tokens = tk.sent_tokenize(train_contexts[i])
    for token in tokens:
        if train_answers[i]['text'] in token:
            sep_train_contexts.append(token)
            answer_start = re.search(train_answers[i]['text'], token)
            answer = {'text': train_answers[i]['text'], 'answer_start':  answer_start.span()[0]}
            sep_train_answers.append(answer)
            sep_train_questions.append(train_questions[i])
        else:
            sep_train_contexts.append('[NULL]' + token)
            sep_train_answers.append(null_answer)
            sep_train_questions.append(train_questions[i])
            


In [16]:
sep_train_questions

['When did Beyonce start becoming popular?',
 'When did Beyonce start becoming popular?',
 'When did Beyonce start becoming popular?',
 'When did Beyonce start becoming popular?',
 'What areas did Beyonce compete in when she was growing up?',
 'What areas did Beyonce compete in when she was growing up?',
 'What areas did Beyonce compete in when she was growing up?',
 'What areas did Beyonce compete in when she was growing up?',
 "When did Beyonce leave Destiny's Child and become a solo singer?",
 "When did Beyonce leave Destiny's Child and become a solo singer?",
 "When did Beyonce leave Destiny's Child and become a solo singer?",
 "When did Beyonce leave Destiny's Child and become a solo singer?",
 'In what city and state did Beyonce  grow up? ',
 'In what city and state did Beyonce  grow up? ',
 'In what city and state did Beyonce  grow up? ',
 'In what city and state did Beyonce  grow up? ',
 'In which decade did Beyonce become famous?',
 'In which decade did Beyonce become famous?'

In [17]:
def add_end_idx(answers, contexts):
    for answer, context in zip(answers, contexts):
        gold_text = answer['text']
        start_idx = answer['answer_start']
        end_idx = start_idx + len(gold_text)

        # sometimes squad answers are off by a character or two – fix this
        if context[start_idx:end_idx] == gold_text:
            answer['answer_end'] = end_idx
        elif context[start_idx-1:end_idx-1] == gold_text:
            answer['answer_start'] = start_idx - 1
            answer['answer_end'] = end_idx - 1     # When the gold label is off by one character
        elif context[start_idx-2:end_idx-2] == gold_text:
            answer['answer_start'] = start_idx - 2
            answer['answer_end'] = end_idx - 2     # When the gold label is off by two characters

add_end_idx(sep_train_answers, sep_train_contexts)
add_end_idx(val_answers, val_contexts)

In [50]:
from transformers import DistilBertTokenizerFast
tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-uncased')

train_encodings = tokenizer(sep_train_contexts, sep_train_questions, truncation=True, padding=True)
val_encodings = tokenizer(val_contexts, val_questions, truncation=True, padding=True)

In [49]:
sep_train_answers[5]

{'text': 'singing and dancing', 'answer_start': 60, 'answer_end': 79}

In [51]:
def add_token_positions(encodings, answers):
    start_positions = []
    end_positions = []
    print("len len(answers) : ",len(answers))
    for i in range(len(answers)):
        print(i, encodings.char_to_token(i, answers[i]['answer_start']), answers[i]['answer_start'])

        start_positions.append(encodings.char_to_token(i, answers[i]['answer_start']))

        
        end_positions.append(encodings.char_to_token(i, answers[i]['answer_end'] - 1))
        # if None, the answer passage has been truncated
        if start_positions[-1] is None:
            start_positions[-1] = tokenizer.model_max_length
        if end_positions[-1] is None:
            end_positions[-1] = tokenizer.model_max_length
    encodings.update({'start_positions': start_positions, 'end_positions': end_positions})

add_token_positions(train_encodings, sep_train_answers)
add_token_positions(val_encodings, val_answers)

len len(answers) :  20
0 1 0
1 25 122
2 1 0
3 1 0
4 1 0
5 13 60
6 1 0
7 1 0
8 1 0
9 1 0
10 1 0
11 17 76
12 1 0
13 5 19
14 1 0
15 1 0
16 1 0
17 27 129
18 1 0
19 1 0
len len(answers) :  5
0 41 159
1 41 159
2 41 159
3 41 159
4 28 94


In [52]:
import torch

class SquadDataset(torch.utils.data.Dataset):
    def __init__(self, encodings):
        self.encodings = encodings

    def __getitem__(self, idx):
        return {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}

    def __len__(self):
        return len(self.encodings.input_ids)

train_dataset = SquadDataset(train_encodings)
val_dataset = SquadDataset(val_encodings)

In [53]:
from transformers import DistilBertForQuestionAnswering
model = DistilBertForQuestionAnswering.from_pretrained("distilbert-base-uncased")

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForQuestionAnswering: ['vocab_transform.bias', 'vocab_layer_norm.weight', 'vocab_projector.bias', 'vocab_layer_norm.bias', 'vocab_projector.weight', 'vocab_transform.weight']
- This IS expected if you are initializing DistilBertForQuestionAnswering from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertForQuestionAnswering from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of DistilBertForQuestionAnswering were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['qa_outputs.weight', 'qa_outputs.bias']
You should probably TRAIN this mode

In [54]:
from torch.utils.data import DataLoader
from transformers import AdamW

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

model.to(device)
model.train()

train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)

optim = AdamW(model.parameters(), lr=5e-5)

for epoch in range(3):
    for batch in train_loader:
        optim.zero_grad()
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        start_positions = batch['start_positions'].to(device)
        end_positions = batch['end_positions'].to(device)
        outputs = model(input_ids, attention_mask=attention_mask, start_positions=start_positions, end_positions=end_positions)
        loss = outputs[0]
        loss.backward()
        optim.step()

model.eval()
torch.save(model, "DistilBertForQuestionAnswering.pth")



DistilBertForQuestionAnswering(
  (distilbert): DistilBertModel(
    (embeddings): Embeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (transformer): Transformer(
      (layer): ModuleList(
        (0): TransformerBlock(
          (attention): MultiHeadSelfAttention(
            (dropout): Dropout(p=0.1, inplace=False)
            (q_lin): Linear(in_features=768, out_features=768, bias=True)
            (k_lin): Linear(in_features=768, out_features=768, bias=True)
            (v_lin): Linear(in_features=768, out_features=768, bias=True)
            (out_lin): Linear(in_features=768, out_features=768, bias=True)
          )
          (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (ffn): FFN(
            (dropout): Dropout(p=0.1, inplace=False)
            

In [55]:
torch.save(model, "model.pth")

In [56]:
torch.load("model.pth")

DistilBertForQuestionAnswering(
  (distilbert): DistilBertModel(
    (embeddings): Embeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (transformer): Transformer(
      (layer): ModuleList(
        (0): TransformerBlock(
          (attention): MultiHeadSelfAttention(
            (dropout): Dropout(p=0.1, inplace=False)
            (q_lin): Linear(in_features=768, out_features=768, bias=True)
            (k_lin): Linear(in_features=768, out_features=768, bias=True)
            (v_lin): Linear(in_features=768, out_features=768, bias=True)
            (out_lin): Linear(in_features=768, out_features=768, bias=True)
          )
          (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (ffn): FFN(
            (dropout): Dropout(p=0.1, inplace=False)
            