In [30]:
import pandas as pd
import numpy as np

import matplotlib.pyplot as plt
import seaborn as sns


import torch
import torch.nn as nn
import torch.functional as F

import random

In [33]:
class Dataset():
    """Tabular store of (context, question, answer) examples for extractive QA.

    Only the columns listed in ``_COLUMNS`` are kept from the source table.
    The retained table is exposed as ``self.cqa_data`` (a ``pandas.DataFrame``).

    Examples are addressed by *position* (0 .. len-1), so an out-of-range
    position raises ``IndexError``. That is what Python's sequence protocol
    needs for ``for example in dataset`` and ``list(dataset)`` to terminate.
    """

    # Columns kept from the source table; also the keys of every example dict.
    _COLUMNS = ("answer", "question", "context", "str_idx", "end_idx")

    def __init__(self,cqa_data):
        """Load the examples.

        Parameters
        ----------
        cqa_data
            Anything ``pandas.read_csv`` accepts (path, URL, file-like), or an
            already-loaded ``pandas.DataFrame``.

        Raises
        ------
        KeyError
            If any of the required columns is missing from the table.
        """
        if not isinstance(cqa_data, pd.DataFrame):
            cqa_data = pd.read_csv(cqa_data)
        # .copy() detaches the stored table from the caller's frame so later
        # column additions on self.cqa_data never touch the source object.
        self.cqa_data = cqa_data[list(self._COLUMNS)].copy()

    def __len__(self):
        """Return the number of stored examples."""
        return len(self.cqa_data)

    def __getitem__(self, id):
        """Return the example at position ``id`` as a plain dict.

        The dict always has exactly the keys in ``_COLUMNS``, even if extra
        columns were attached to ``self.cqa_data`` afterwards.

        Raises
        ------
        IndexError
            If ``id`` is outside the stored range.
        """
        # Positional lookup: independent of the frame's index labels and
        # raises IndexError (not KeyError) past the end.
        item = self.cqa_data.iloc[id]
        return {column: item[column] for column in self._COLUMNS}

In [34]:
# Path of the CSV file holding the QA examples; it is passed to Dataset, which reads it with pd.read_csv.
cqa_data = 'data.csv'

In [35]:
# Build the dataset from the CSV path, report how many examples it holds,
# and show the underlying DataFrame as the cell output.
dataset = Dataset(cqa_data=cqa_data)
print("length:",len(dataset))
dataset.cqa_data

length: 1500


Unnamed: 0,answer,question,context,str_idx,end_idx
0,2003,When did Beyonce leave Destiny's Child and bec...,Beyoncé Giselle Knowles-Carter (/biːˈjɒnseɪ/ b...,526,530
1,late 1990s,In which decade did Beyonce become famous?,Beyoncé Giselle Knowles-Carter (/biːˈjɒnseɪ/ b...,276,286
2,Destiny's Child,In what R&B group was she the lead singer?,Beyoncé Giselle Knowles-Carter (/biːˈjɒnseɪ/ b...,320,335
3,Mathew Knowles,Who managed the Destiny's Child group?,Beyoncé Giselle Knowles-Carter (/biːˈjɒnseɪ/ b...,360,374
4,Dangerously in Love,What was the first album Beyoncé released as a...,Beyoncé Giselle Knowles-Carter (/biːˈjɒnseɪ/ b...,505,524
...,...,...,...,...,...
1495,game or prey,Hunted species are usually referred to as what?,Hunting is the practice of killing or trapping...,469,481
1496,Hunting,What is it called to kill or trap an animal?,Hunting is the practice of killing or trapping...,0,7
1497,"food, recreation, to remove predators",Why do humans hunt?,Hunting is the practice of killing or trapping...,185,222
1498,2010s,When was lawful hunting distinguished from poa...,Hunting is the practice of killing or trapping...,295,300


In [36]:
from transformers import DistilBertTokenizerFast

# Checkpoint name; the same name is reused later when the model is loaded with from_pretrained.
model_name = 'distilbert-base-cased-distilled-squad'

# tokenize_align calls char_to_token on the encodings this tokenizer produces.
tokenizer = DistilBertTokenizerFast.from_pretrained(model_name)

In [37]:
def tokenize_align(example):
    """Tokenize one QA example and convert its answer span to token indices.

    The context and the question are encoded together (context first) and
    padded/truncated to 512 tokens. ``example['str_idx']`` and
    ``example['end_idx']`` are character offsets; the end offset is treated
    as exclusive, so the last answer character is ``end_idx - 1``.

    Parameters
    ----------
    example
        Mapping with the keys ``'context'``, ``'question'``, ``'str_idx'``
        and ``'end_idx'``.

    Returns
    -------
    dict
        ``input_ids``, ``attention_mask``, ``start_positions`` and
        ``end_positions``. A position for which ``char_to_token`` yields
        ``None`` is replaced by ``tokenizer.model_max_length``
        (presumably the answer was truncated away -- TODO confirm).
    """
    encoding = tokenizer(
        example['context'],
        example['question'],
        truncation=True,
        padding='max_length',
        max_length=512,
    )

    def _token_index(char_offset):
        # Map a character offset to its token index, with the sentinel fallback.
        token_index = encoding.char_to_token(char_offset)
        return tokenizer.model_max_length if token_index is None else token_index

    span_start = _token_index(int(example['str_idx']))
    span_end = _token_index(int(example['end_idx']) - 1)

    return {
        'input_ids': encoding['input_ids'],
        'attention_mask': encoding['attention_mask'],
        'start_positions': span_start,
        'end_positions': span_end,
    }

In [38]:
# Frame for the tokenizer outputs: one column per key of the dict returned by tokenize_align.
tokenized_df = pd.DataFrame(columns=['input_ids','attention_mask','start_positions','end_positions'])

In [39]:
# Tokenize every example and build the frame with a single constructor call.
# Assigning rows one at a time through .loc enlarges (and copies) the frame
# on every iteration, which is needlessly slow for thousands of examples.
tokenized_df = pd.DataFrame(
    [tokenize_align(dataset[i]) for i in range(len(dataset))],
    columns=['input_ids', 'attention_mask', 'start_positions', 'end_positions'],
)

In [40]:
# Attach the tokenizer outputs column-wise; rows are aligned on the index.
# Any copies attached by a previous run are dropped first so the cell is
# idempotent: re-running it would otherwise append duplicate column names.
dataset.cqa_data = pd.concat(
    [dataset.cqa_data.drop(columns=tokenized_df.columns, errors='ignore'), tokenized_df],
    axis=1,
)

In [41]:
# Preview the first rows to check that the token columns were attached.
dataset.cqa_data.head()

Unnamed: 0,answer,question,context,str_idx,end_idx,input_ids,attention_mask,start_positions,end_positions
0,2003,When did Beyonce leave Destiny's Child and bec...,Beyoncé Giselle Knowles-Carter (/biːˈjɒnseɪ/ b...,526,530,"[101, 24041, 144, 22080, 25384, 118, 5007, 113...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...",125,125
1,late 1990s,In which decade did Beyonce become famous?,Beyoncé Giselle Knowles-Carter (/biːˈjɒnseɪ/ b...,276,286,"[101, 24041, 144, 22080, 25384, 118, 5007, 113...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...",63,64
2,Destiny's Child,In what R&B group was she the lead singer?,Beyoncé Giselle Knowles-Carter (/biːˈjɒnseɪ/ b...,320,335,"[101, 24041, 144, 22080, 25384, 118, 5007, 113...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...",75,78
3,Mathew Knowles,Who managed the Destiny's Child group?,Beyoncé Giselle Knowles-Carter (/biːˈjɒnseɪ/ b...,360,374,"[101, 24041, 144, 22080, 25384, 118, 5007, 113...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...",86,88
4,Dangerously in Love,What was the first album Beyoncé released as a...,Beyoncé Giselle Knowles-Carter (/biːˈjɒnseɪ/ b...,505,524,"[101, 24041, 144, 22080, 25384, 118, 5007, 113...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...",120,123


In [42]:
# Convert the frame to a list with one dict per row.
dict_data = dataset.cqa_data.to_dict(orient='records')
# Display a single record as a sanity check.
dict_data[200]

{'answer': 'one hour',
 'question': 'What period of time do we set our clocks forward in DST?',
 'context': 'Daylight saving time (DST) or summer time is the practice of advancing clocks during summer months by one hour so that in the evening daylight is experienced an hour longer, while sacrificing normal sunrise times. Typically, regions with summer time adjust clocks forward one hour close to the start of spring and adjust them backward in the autumn to standard time.',
 'str_idx': 102,
 'end_idx': 110,
 'input_ids': [101,
  2295,
  4568,
  7740,
  1159,
  113,
  18448,
  1942,
  114,
  1137,
  2247,
  1159,
  1110,
  1103,
  2415,
  1104,
  11120,
  24998,
  1219,
  2247,
  1808,
  1118,
  1141,
  2396,
  1177,
  1115,
  1107,
  1103,
  3440,
  13258,
  1110,
  4531,
  1126,
  2396,
  2039,
  117,
  1229,
  21718,
  1665,
  2047,
  21361,
  1158,
  2999,
  23859,
  1551,
  119,
  16304,
  117,
  4001,
  1114,
  2247,
  1159,
  14878,
  24998,
  1977,
  1141,
  2396,
  1601,
  1106,

# Training

In [43]:
# NOTE(review): this import rebinds the name `Dataset`, shadowing the Dataset class
# defined earlier in this notebook; cells that construct that class cannot be re-run
# after this one unless the class definition is re-run first.
from datasets import Dataset

# `dataset` is rebound too: from here on it is the object returned by
# Dataset.from_pandas, no longer the CSV-backed wrapper.
dataset = Dataset.from_pandas(pd.DataFrame(dict_data))

In [44]:
# Hold out 33% of the examples. A fixed seed makes the shuffle -- and so the
# split and every metric computed on it -- reproducible across kernel restarts.
Data_splited = dataset.train_test_split(test_size=0.33, seed=42)

In [46]:
# The 'train' part is used for fine-tuning; the 'test' part serves as the validation set.
train_dataset = Data_splited['train']
val_dataset = Data_splited['test']

In [48]:
# NOTE(review): DataLoader is imported here but not used in any visible cell.
from torch.utils.data import DataLoader

# Restrict what indexing the datasets returns to these four columns, formatted as PyTorch tensors ('pt').
columns_to_return = ['input_ids','attention_mask', 'start_positions', 'end_positions']
train_dataset.set_format(type='pt', columns=columns_to_return)
val_dataset.set_format(type='pt', columns=columns_to_return)

In [49]:
# Inspect one formatted training example.
train_dataset[0]

{'input_ids': tensor([  101, 13359,  2744,  1181,  2744,  4907,  8465, 22964,  6709,   113,
           120,   384, 28289,  1186, 28290,  1643, 28200,  1179,   120,   132,
          1497, 17238,   131,   164,   175, 28287,  1162,   119,  1260,   119,
           366,  4847,   175, 28287, 28275, 28311,   119,   188,  3624,   368,
         28276,   119,   185, 28279, 28311,   166,   132,  1659,  1428,  1137,
           122,  1345, 11831,   782,  1542,  1357,  8688,   114,   117,  1255,
         20685,  2692, 17252,  3720,  3171,  1377, 22964,  6709,   117,   164,
           183,   122,   166,  1108,   170,  3129,  1105,  1497,   113,  1118,
          9709,  1105,  3485,  1104,  1401,   114,  3996,  1105,   170,   191,
         25074, 11848,  7301,  9003,  1104,  1103, 20359,  3386,   117,  1150,
          1724,  3120,  1111,  1103,  3444,  3267,   119,  1124,  3388,  1105,
          1144,  4441,  1231,  2728,  6540,  4529,  1112,  1141,  1104,  1103,
          2020,  4992,  1104,  1117,  3

In [50]:
from sklearn.metrics import f1_score

def compute_metrics(pred):
    """Score answer-start and answer-end predictions with macro-averaged F1.

    Parameters
    ----------
    pred
        Object exposing ``label_ids`` and ``predictions``. Index 0 of each is
        used for the answer start and index 1 for the answer end; each
        ``predictions`` entry must support ``argmax(-1)``. Presumably the
        evaluation object that ``Trainer`` hands to ``compute_metrics`` --
        confirm against the caller.

    Returns
    -------
    dict
        ``{'f1_start': <float>, 'f1_end': <float>}``.
    """
    def _macro_f1(slot):
        # slot 0 -> start positions, slot 1 -> end positions.
        gold = pred.label_ids[slot]
        guessed = pred.predictions[slot].argmax(-1)
        return f1_score(gold, guessed, average='macro')

    return {'f1_start': _macro_f1(0), 'f1_end': _macro_f1(1)}

In [51]:
# NOTE(review): BertConfig is imported here but not used in any visible cell.
from transformers import DistilBertForQuestionAnswering, BertConfig

# Load the question-answering model from the same checkpoint name as the tokenizer (model_name).
pytorch_model = DistilBertForQuestionAnswering.from_pretrained(model_name)

In [52]:
from transformers import Trainer, TrainingArguments

# Fine-tuning configuration: 3 epochs, batches of 32 per device with gradients
# accumulated over 2 batches before each optimizer step.
# NOTE(review): the recorded run finished after 48 steps in total, so
# logging_steps=50 never triggers an intermediate log line -- consider lowering it.
training_args = TrainingArguments(
    output_dir='results',         
    overwrite_output_dir=True,
    num_train_epochs=3,
    gradient_accumulation_steps=2,            
    per_device_train_batch_size=32,  
    per_device_eval_batch_size=32,  
    warmup_steps=20,                
    weight_decay=0.01,              
    logging_dir=None,       
    logging_steps=50
)

# compute_metrics is called on evaluation results and adds the f1_start / f1_end
# entries to the evaluation report.
trainer = Trainer(
    model=pytorch_model,                
    args=training_args,                  
    train_dataset=train_dataset,         
    eval_dataset=val_dataset,
    compute_metrics=compute_metrics
)

# Runs the fine-tuning; pytorch_model is updated in place.
trainer.train()

  0%|          | 0/48 [00:00<?, ?it/s]

{'train_runtime': 4794.1973, 'train_samples_per_second': 0.629, 'train_steps_per_second': 0.01, 'train_loss': 1.563680648803711, 'epoch': 3.0}


TrainOutput(global_step=48, training_loss=1.563680648803711, metrics={'train_runtime': 4794.1973, 'train_samples_per_second': 0.629, 'train_steps_per_second': 0.01, 'train_loss': 1.563680648803711, 'epoch': 3.0})

In [54]:
# Evaluate on the held-out split; the report includes the loss and the metrics from compute_metrics.
trainer.evaluate(val_dataset)

  0%|          | 0/16 [00:00<?, ?it/s]

{'eval_loss': 1.128614068031311,
 'eval_f1_start': 0.5770132040701306,
 'eval_f1_end': 0.5877568125838022,
 'eval_runtime': 253.2324,
 'eval_samples_per_second': 1.955,
 'eval_steps_per_second': 0.063,
 'epoch': 3.0}

In [83]:
# Smoke-test the fine-tuned model on a hand-written example.

# Put the model in inference mode so dropout is disabled and the prediction
# is deterministic (it may still be in training mode after fine-tuning).
pytorch_model.eval()

question, text = 'where was i yesterday?','Yesterday imad and I were at the club des pins.'

# Same argument order as in tokenize_align: context first, question second.
# The tensors are moved to the model's device so this also works when the
# model was trained on an accelerator.
input_dict = tokenizer(text, question, return_tensors='pt').to(pytorch_model.device)

input_ids = input_dict['input_ids']
attention_mask = input_dict['attention_mask']

# No gradients are needed for inference.
with torch.no_grad():
    outputs = pytorch_model(input_ids, attention_mask=attention_mask)

start_logits = outputs[0]
end_logits = outputs[1]

# Most likely first and last token of the answer (batch of one).
start_token = int(torch.argmax(start_logits, 1)[0])
end_token = int(torch.argmax(end_logits, 1)[0])

all_tokens = tokenizer.convert_ids_to_tokens(input_ids[0].tolist())
# Merge word pieces back into text rather than joining them with spaces,
# which would leave '##' fragments in answers that contain split words.
answer = tokenizer.convert_tokens_to_string(all_tokens[start_token : end_token + 1])

print(question, answer.capitalize())

where was i yesterday? Club des pins
