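# Fine-tuning the Reader

This notebook evaluates an extractive question-answering reader on a held-out SQuAD split, fine-tunes it on SQuAD, re-evaluates it, and then compares several off-the-shelf readers per theme on the competition question set.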


In [None]:
# # Transformers installation
! pip install transformers datasets
# ! pip install sentencepiece
# # To install from source instead of the last release, comment the command above and uncomment the following one.
# # ! pip install git+https://github.com/huggingface/transformers.git

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from transformers import pipeline

# Question answering

## Load SQuAD dataset

In [None]:
from datasets import load_dataset

# Load only the first 10,000 rows of the SQuAD train split (split-slicing syntax `train[:10000]`).
squad = load_dataset("squad", split="train[:10000]")



In [None]:
# Hold out 20% of the loaded rows as a test set. The fixed seed makes the split, and every
# F1 score computed on it below, reproducible across kernel restarts.
# NOTE: this rebinds `squad` from a Dataset to a DatasetDict, so the cell is not idempotent.
# Re-run the load cell above before re-running this one.
squad = squad.train_test_split(test_size=0.2, seed=42)

In [None]:
# Number of training rows after the 80/20 split.
squad["train"].num_rows

8000

## QA reader

In [None]:
# Baseline reader (a BigBird-RoBERTa checkpoint, Natural Questions per its model id),
# evaluated on the SQuAD test split before any further fine-tuning.
# NOTE(review): `qa` is rebound to a DataFrame in the Submission section; cells that call
# `qa(...)` fail if re-run after that.
qa = pipeline("question-answering", model = 'vasudevgupta/bigbird-roberta-natural-questions')

In [None]:
# Inspect one held-out example; its `context`, `question` and `answers` fields are used below.
squad['test'][0]

In [None]:
# Read whole columns once instead of materialising every test row three times.
# `list(...)` keeps plain Python lists whatever type `Dataset.__getitem__` returns for a column.
# Only the first gold answer of each example is used as the F1 reference.
test_split = squad['test']
contexts = list(test_split['context'])
questions = list(test_split['question'])
answers = [ans['text'][0] for ans in test_split['answers']]

In [None]:
# Batched inference with the baseline reader over every held-out (question, context) pair.
results = qa(question = questions, context = contexts, batch_size = 32)

In [None]:
import re
import string
import collections

#@title Eval functions
def get_f1(answer, pred_answer):
  """Whitespace-token F1 between a gold answer and a prediction, without normalisation.

  Tokens are compared exactly (case- and punctuation-sensitive), unlike `calc_f1`.

  Args:
    answer: gold answer string.
    pred_answer: predicted answer string.

  Returns:
    F1 in [0, 1]; 0 when no tokens overlap or either side is empty.
  """
  # Drop empty strings produced by leading/trailing or repeated spaces so they do not
  # inflate the precision/recall denominators.
  answer_tokens = [tok for tok in answer.split(' ') if tok]
  pred_tokens = [tok for tok in pred_answer.split(' ') if tok]
  # Multiset intersection of whole tokens: a predicted token only matches an identical
  # gold token (not a substring of the gold string), and each gold token matches at most once.
  common = collections.Counter(answer_tokens) & collections.Counter(pred_tokens)
  num_same = sum(common.values())
  if num_same == 0:
    return 0
  precision = num_same / len(pred_tokens)
  recall = num_same / len(answer_tokens)
  return (2 * precision * recall) / (precision + recall)

def normalize_answer(s):
  """Canonicalise an answer string for SQuAD-style comparison.

  Applied in order: lowercase, strip `string.punctuation` characters, replace the
  whole words "a", "an" and "the" with a space, then collapse whitespace runs into
  single spaces (trimming both ends).
  """
  punctuation = set(string.punctuation)
  text = s.lower()
  text = ''.join(c for c in text if c not in punctuation)
  text = re.sub(r'\b(a|an|the)\b', ' ', text, flags=re.UNICODE)
  return ' '.join(text.split())

def get_tokens(s):
  """Whitespace tokens of `normalize_answer(s)`; falsy input (None or "") gives []."""
  return normalize_answer(s).split() if s else []

def calc_f1(a_gold, a_pred):
  """SQuAD-style token F1 between a gold answer and a prediction.

  Both strings go through `get_tokens` (normalise + split). When either side has no
  tokens the score is 1 if both are empty and 0 otherwise; otherwise it is the
  harmonic mean of token precision and recall over the multiset overlap.
  """
  gold = get_tokens(a_gold)
  pred = get_tokens(a_pred)
  if not gold or not pred:
    return int(gold == pred)
  overlap = collections.Counter(gold) & collections.Counter(pred)
  matched = sum(overlap.values())
  if not matched:
    return 0
  precision = matched / len(pred)
  recall = matched / len(gold)
  return 2 * precision * recall / (precision + recall)

## Evaluation (without fine-tuning)

In [None]:
# Mean SQuAD token F1 of the baseline predictions against the first gold answer of each example.
f1 = [calc_f1(answers[i], pred['answer']) for i, pred in enumerate(results)]

print(np.mean(f1))

0.8667531061356916


## Preprocess

In [None]:
from transformers import AutoTokenizer

# Tokenizer of the same checkpoint as the baseline reader, used for preprocessing and training.
tokenizer = AutoTokenizer.from_pretrained("vasudevgupta/bigbird-roberta-natural-questions")

In [None]:
def preprocess_function(examples):
    """Tokenise a batch of SQuAD examples and attach answer start/end token labels.

    Used with `squad.map(..., batched=True)`, so `examples` is a dict of column lists
    containing "question", "context" and "answers". Each answer is a dict with
    "answer_start" and "text" lists. The function relies on the module-level `tokenizer`.

    Each (question, context) pair is encoded to exactly 384 tokens, truncating only the
    context. The label is the (start, end) token span of the first gold answer. It is
    (0, 0) when the example has no answer or that answer is not fully inside the kept
    part of the context.

    Returns:
        The tokenizer output without "offset_mapping", plus "start_positions" and
        "end_positions" lists (one entry per example).
    """
    questions = [q.strip() for q in examples["question"]]
    inputs = tokenizer(
        questions,
        examples["context"],
        max_length=384,
        truncation="only_second",
        return_offsets_mapping=True,
        padding="max_length",
    )

    # Offsets are only needed here to build the labels, so remove them from the model inputs.
    offset_mapping = inputs.pop("offset_mapping")
    answers = examples["answers"]
    start_positions = []
    end_positions = []

    for i, offset in enumerate(offset_mapping):
        answer = answers[i]
        # No gold answer (empty lists, SQuAD v2 style): label (0, 0).
        if not answer["answer_start"]:
            start_positions.append(0)
            end_positions.append(0)
            continue
        start_char = answer["answer_start"][0]
        end_char = start_char + len(answer["text"][0])
        sequence_ids = inputs.sequence_ids(i)

        # Find the start and end of the context (the tokens whose sequence id is 1).
        idx = 0
        while sequence_ids[idx] != 1:
            idx += 1
        context_start = idx
        while idx < len(sequence_ids) and sequence_ids[idx] == 1:
            idx += 1
        context_end = idx - 1

        # Label (0, 0) unless the WHOLE answer lies inside the kept context. Testing only
        # for a disjoint span would accept answers cut off by truncation and put the end
        # label on a token after the context.
        if offset[context_start][0] > start_char or offset[context_end][1] < end_char:
            start_positions.append(0)
            end_positions.append(0)
        else:
            # Start: last context token that begins at or before start_char.
            idx = context_start
            while idx <= context_end and offset[idx][0] <= start_char:
                idx += 1
            start_positions.append(idx - 1)

            # End: scan left while tokens still end at/after end_char; the answer's last
            # token is the one just right of where the scan stops.
            idx = context_end
            while idx >= context_start and offset[idx][1] >= end_char:
                idx -= 1
            end_positions.append(idx + 1)

    inputs["start_positions"] = start_positions
    inputs["end_positions"] = end_positions
    return inputs

In [None]:
# Tokenise both splits; the original text columns are removed so only model inputs and labels remain.
tokenized_squad = squad.map(preprocess_function, batched=True, remove_columns=squad["train"].column_names)

In [None]:
from transformers import DefaultDataCollator

# Inputs are already padded to max_length in preprocess_function, so no dynamic padding is needed.
data_collator = DefaultDataCollator()

## Train

In [None]:
from transformers import AutoModelForQuestionAnswering, TrainingArguments, Trainer

# Fine-tune starting from the same checkpoint used as the baseline reader.
# (Add `from_tf=True` only when loading TensorFlow weights.)
model = AutoModelForQuestionAnswering.from_pretrained("vasudevgupta/bigbird-roberta-natural-questions")

In [None]:
# Fine-tune on the train split and evaluate on the test split after every epoch.
# Checkpoints go under `my_awesome_qa_model/`.
# NOTE(review): no `compute_metrics` is passed, so per-epoch evaluation reports loss only.
# F1 is computed separately with the pipeline below.
# NOTE(review): newer transformers releases deprecate `evaluation_strategy` (use
# `eval_strategy`) and `tokenizer=` (use `processing_class=`). Adjust these if you upgrade.
training_args = TrainingArguments(
    output_dir="my_awesome_qa_model",
    evaluation_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=3,
    weight_decay=0.01,
    push_to_hub=False,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_squad["train"],
    eval_dataset=tokenized_squad["test"],
    tokenizer=tokenizer,
    data_collator=data_collator,
)

trainer.train()

## Inference after fine-tuning

In [None]:
# Move the fine-tuned model to CPU before building the inference pipeline.
# NOTE(review): presumably done to free GPU memory after training; CPU inference over the
# whole test split will be slow — confirm this is intended.
model = model.to('cpu')

In [None]:
# Wrap the in-memory fine-tuned model and its tokenizer in a question-answering pipeline.
qa_finetuned = pipeline("question-answering", model = model, tokenizer = tokenizer)

In [None]:
# Same held-out questions/contexts as the baseline run, now answered by the fine-tuned reader.
results = qa_finetuned(question = questions, context = contexts, batch_size = 32)

In [None]:
# Mean SQuAD token F1 of the fine-tuned reader, computed exactly as for the baseline.
f1 = []
for i, pred in enumerate(results):
  f1.append(calc_f1(answers[i], pred['answer']))

print(np.mean(f1))

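You can also replicate the `pipeline` results manually. The cells below are commented out; to run them, uncomment them and first define a single `question` and `context` string.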

Tokenize the text and return PyTorch tensors:

In [None]:
# from transformers import AutoTokenizer

# # tokenizer = AutoTokenizer.from_pretrained("my_awesome_qa_model")
# inputs = tokenizer(question, context, return_tensors="pt")

In [None]:
# from transformers import AutoModelForQuestionAnswering
# import torch 

# # model = AutoModelForQuestionAnswering.from_pretrained("my_awesome_qa_model")
# with torch.no_grad():
#     outputs = model(**inputs)

In [None]:
# answer_start_index = outputs.start_logits.argmax()
# answer_end_index = outputs.end_logits.argmax()

In [None]:
# predict_answer_tokens = inputs.input_ids[0, answer_start_index : answer_end_index + 1]
# tokenizer.decode(predict_answer_tokens)

## Submission

In [None]:
# Competition data, read from the working directory:
# paragraphs (id, paragraph, theme) and question/answer pairs (question, theme, paragraph_id, answer).
# NOTE(review): this rebinds `qa`, previously the baseline question-answering pipeline, to a
# DataFrame. Cells above that call `qa(...)` fail if re-run after this one.
paras = pd.read_csv('paragraphs.csv')
qa = pd.read_csv('question_answers.csv')

In [None]:
# Preview the paragraphs table.
paras.head()

Unnamed: 0,id,paragraph,theme
0,1,The iPod is a line of portable media players a...,IPod
1,2,"Like other digital music players, iPods can se...",IPod
2,3,Apple's iTunes software (and other alternative...,IPod
3,4,"Before the release of iOS 5, the iPod branding...",IPod
4,5,"In mid-2015, a new model of the iPod Touch was...",IPod


In [None]:
# Distinct paragraph themes, in order of first appearance.
themes = paras['theme'].unique()

In [None]:
# Number of distinct themes.
len(themes)

30

In [None]:
# Preview the labelled question/answer table.
qa.head()

Unnamed: 0,question,theme,paragraph_id,answer
0,Which company produces the iPod?,IPod,1,Apple
1,When was the original iPod released?,IPod,1,"October 23, 2001"
2,How many different types of iPod are currently...,IPod,1,three
3,What kind of device is the iPod?,IPod,1,portable media players
4,The iPod Touch uses what kind of interface?,IPod,1,touchscreen


In [None]:
# Group the labelled questions by theme. For each theme, the question list, the text of
# the paragraph each question refers to, and the gold answer list are aligned by position.
# Paragraph text is looked up by `id` through one pre-built Series, using the named
# 'paragraph' column rather than a column position. If an id is duplicated, the first row wins.
para_text_by_id = paras.drop_duplicates('id').set_index('id')['paragraph']
qns_theme = dict()
paras_theme = dict()
ans_theme = dict()
for theme in themes:
  qns_sub = qa[qa['theme'] == theme]
  qns_theme[theme] = qns_sub['question'].to_list()
  paras_theme[theme] = para_text_by_id.loc[qns_sub['paragraph_id'].to_list()].to_list()
  ans_theme[theme] = qns_sub['answer'].to_list()

In [None]:
# Reader for the per-theme evaluation. The `results` cells further down were captured
# after swapping this model id; each one names its model in a trailing comment.
reader = pipeline("question-answering", model = 'distilbert-base-uncased-distilled-squad')

In [None]:
# Per-theme mean SQuAD token F1 of `reader` against the gold answers.
# Themes with no labelled questions are scored NaN without calling the reader.
results = dict()
pred_ans_theme = dict()
for theme in themes:
  theme_questions = qns_theme[theme]
  if not theme_questions:
    pred_ans_theme[theme] = []
    results[theme] = float('nan')
    continue
  preds = reader(question = theme_questions, context = paras_theme[theme], batch_size = 32)
  # The QA pipeline returns a bare dict (not a list) when given a single question.
  if isinstance(preds, dict):
    preds = [preds]
  pred_ans_theme[theme] = preds
  f1 = [calc_f1(gold, pred['answer']) for gold, pred in zip(ans_theme[theme], preds)]
  results[theme] = np.mean(f1)

In [None]:
# Per-theme F1 captured with reader = distilbert-base-uncased-distilled-squad.
# NaN rows are themes without labelled questions (the `unk_themes` list below).
# NOTE: on a fresh Run All, every `results` cell shows the dict from the reader defined above.
results # distilbert-base-uncased-distilled-squad

{'IPod': 0.8959154053493675,
 '2008_Sichuan_earthquake': 0.9132090456424258,
 'Wayback_Machine': 0.9733044733044734,
 'Canadian_Armed_Forces': 0.879115432258808,
 'Cardinal_(Catholicism)': 0.9749418591523854,
 'Human_Development_Index': 0.9697802197802198,
 'Heresy': 0.9080459770114943,
 'Warsaw_Pact': 0.8777777777777778,
 'Materialism': 0.921875,
 'Pub': 0.9447196620583718,
 'Web_browser': 0.9546666666666667,
 'Catalan_language': 0.9018252580752582,
 'Paper': 0.9519607843137255,
 'Adult_contemporary_music': 0.9176190476190477,
 'Nanjing': nan,
 'Dialect': nan,
 'Southampton': nan,
 'The_Times': nan,
 'Immunology': nan,
 'Imamah_(Shia_doctrine)': nan,
 'Grape': nan,
 'United_States_dollar': nan,
 'Everton_F.C.': nan,
 'Hard_rock': nan,
 'Great_Plains': nan,
 'Biodiversity': nan,
 'Federal_Bureau_of_Investigation': nan,
 'Mary_(mother_of_Jesus)': nan,
 'Unknown': nan,
 'DevRev': nan}

In [None]:
# Per-theme F1 captured after manually switching `reader` to deepset/tinyroberta-squad2.
# The code in this notebook does not load that model, so a fresh Run All will not reproduce this output.
results # deepset/tinyroberta-squad2

{'IPod': 0.899560554749234,
 '2008_Sichuan_earthquake': 0.9279586786083776,
 'Wayback_Machine': 0.8845598845598847,
 'Canadian_Armed_Forces': 0.869203078911672,
 'Cardinal_(Catholicism)': 0.893474677685204,
 'Human_Development_Index': 0.9406593406593406,
 'Heresy': 0.9655172413793104,
 'Warsaw_Pact': 0.6317460317460317,
 'Materialism': 0.928702731092437,
 'Pub': 0.9857142857142857,
 'Web_browser': 0.9457777777777778,
 'Catalan_language': 0.9183836996336996,
 'Paper': 0.9,
 'Adult_contemporary_music': 0.9178571428571428,
 'Nanjing': nan,
 'Dialect': nan,
 'Southampton': nan,
 'The_Times': nan,
 'Immunology': nan,
 'Imamah_(Shia_doctrine)': nan,
 'Grape': nan,
 'United_States_dollar': nan,
 'Everton_F.C.': nan,
 'Hard_rock': nan,
 'Great_Plains': nan,
 'Biodiversity': nan,
 'Federal_Bureau_of_Investigation': nan,
 'Mary_(mother_of_Jesus)': nan,
 'Unknown': nan,
 'DevRev': nan}

In [None]:
# Per-theme F1 captured after manually switching `reader` to deepset/minilm-uncased-squad2.
# The code in this notebook does not load that model, so a fresh Run All will not reproduce this output.
results # deepset/minilm-uncased-squad2

{'IPod': 0.9323627237542331,
 '2008_Sichuan_earthquake': 0.921692050093172,
 'Wayback_Machine': 0.9116161616161617,
 'Canadian_Armed_Forces': 0.8879788028063891,
 'Cardinal_(Catholicism)': 0.9289959132064396,
 'Human_Development_Index': 0.8791208791208791,
 'Heresy': 0.9471264367816091,
 'Warsaw_Pact': 0.7666666666666667,
 'Materialism': 0.9733455882352942,
 'Pub': 0.9696725317693059,
 'Web_browser': 0.992,
 'Catalan_language': 0.9285506160506162,
 'Paper': 0.8736694677871147,
 'Adult_contemporary_music': 0.9021428571428571,
 'Nanjing': nan,
 'Dialect': nan,
 'Southampton': nan,
 'The_Times': nan,
 'Immunology': nan,
 'Imamah_(Shia_doctrine)': nan,
 'Grape': nan,
 'United_States_dollar': nan,
 'Everton_F.C.': nan,
 'Hard_rock': nan,
 'Great_Plains': nan,
 'Biodiversity': nan,
 'Federal_Bureau_of_Investigation': nan,
 'Mary_(mother_of_Jesus)': nan,
 'Unknown': nan,
 'DevRev': nan}

In [None]:
# Per-theme F1 captured after manually switching `reader` to deepset/electra-base-squad2.
# NOTE(review): these scores are identical to the bhadresh-savani/electra-base-squad2 cell
# below. One of the two runs probably reused a stale `reader`; re-run to confirm.
results # deepset/electra-base-squad2

{'IPod': 0.9646427995484599,
 '2008_Sichuan_earthquake': 0.9798941798941798,
 'Wayback_Machine': 0.9545454545454546,
 'Canadian_Armed_Forces': 0.9420718462823726,
 'Cardinal_(Catholicism)': 0.9195366795366795,
 'Human_Development_Index': 0.9807692307692307,
 'Heresy': 0.9408866995073891,
 'Warsaw_Pact': 0.9444444444444444,
 'Materialism': 0.984375,
 'Pub': 0.9879032258064516,
 'Web_browser': 0.9466666666666668,
 'Catalan_language': 0.9807692307692307,
 'Paper': 1.0,
 'Adult_contemporary_music': 1.0,
 'Nanjing': nan,
 'Dialect': nan,
 'Southampton': nan,
 'The_Times': nan,
 'Immunology': nan,
 'Imamah_(Shia_doctrine)': nan,
 'Grape': nan,
 'United_States_dollar': nan,
 'Everton_F.C.': nan,
 'Hard_rock': nan,
 'Great_Plains': nan,
 'Biodiversity': nan,
 'Federal_Bureau_of_Investigation': nan,
 'Mary_(mother_of_Jesus)': nan,
 'Unknown': nan,
 'DevRev': nan}

In [None]:
# Per-theme F1 captured after manually switching `reader` to bhadresh-savani/electra-base-squad2.
# NOTE(review): every score matches the deepset/electra-base-squad2 cell above exactly,
# which suggests `reader` was not actually reloaded between the two runs; re-run to confirm.
results # bhadresh-savani/electra-base-squad2

{'IPod': 0.9646427995484599,
 '2008_Sichuan_earthquake': 0.9798941798941798,
 'Wayback_Machine': 0.9545454545454546,
 'Canadian_Armed_Forces': 0.9420718462823726,
 'Cardinal_(Catholicism)': 0.9195366795366795,
 'Human_Development_Index': 0.9807692307692307,
 'Heresy': 0.9408866995073891,
 'Warsaw_Pact': 0.9444444444444444,
 'Materialism': 0.984375,
 'Pub': 0.9879032258064516,
 'Web_browser': 0.9466666666666668,
 'Catalan_language': 0.9807692307692307,
 'Paper': 1.0,
 'Adult_contemporary_music': 1.0,
 'Nanjing': nan,
 'Dialect': nan,
 'Southampton': nan,
 'The_Times': nan,
 'Immunology': nan,
 'Imamah_(Shia_doctrine)': nan,
 'Grape': nan,
 'United_States_dollar': nan,
 'Everton_F.C.': nan,
 'Hard_rock': nan,
 'Great_Plains': nan,
 'Biodiversity': nan,
 'Federal_Bureau_of_Investigation': nan,
 'Mary_(mother_of_Jesus)': nan,
 'Unknown': nan,
 'DevRev': nan}

In [None]:
# Per-theme F1 captured after manually switching `reader` to deepset/roberta-base-squad2-distilled.
# The code in this notebook does not load that model, so a fresh Run All will not reproduce this output.
results # deepset/roberta-base-squad2-distilled

{'IPod': 0.9212994159928123,
 '2008_Sichuan_earthquake': 0.9343928955264653,
 'Wayback_Machine': 0.9386724386724388,
 'Canadian_Armed_Forces': 0.8640602217688148,
 'Cardinal_(Catholicism)': 0.94754387293087,
 'Human_Development_Index': 0.9624542124542125,
 'Heresy': 0.9586206896551724,
 'Warsaw_Pact': 0.8222222222222223,
 'Materialism': 0.9397321428571428,
 'Pub': 0.9803379416282643,
 'Web_browser': 0.9386666666666668,
 'Catalan_language': 0.9054410866910867,
 'Paper': 0.9568627450980391,
 'Adult_contemporary_music': 0.9494505494505495,
 'Nanjing': nan,
 'Dialect': nan,
 'Southampton': nan,
 'The_Times': nan,
 'Immunology': nan,
 'Imamah_(Shia_doctrine)': nan,
 'Grape': nan,
 'United_States_dollar': nan,
 'Everton_F.C.': nan,
 'Hard_rock': nan,
 'Great_Plains': nan,
 'Biodiversity': nan,
 'Federal_Bureau_of_Investigation': nan,
 'Mary_(mother_of_Jesus)': nan,
 'Unknown': nan,
 'DevRev': nan}

In [None]:
# Per-theme F1 captured after manually switching `reader` to deepset/roberta-base-squad2.
# The code in this notebook does not load that model, so a fresh Run All will not reproduce this output.
results # deepset/roberta-base-squad2

{'IPod': 0.9015667115902966,
 '2008_Sichuan_earthquake': 0.9207785967603639,
 'Wayback_Machine': 0.876984126984127,
 'Canadian_Armed_Forces': 0.8717619761547798,
 'Cardinal_(Catholicism)': 0.8973508227378196,
 'Human_Development_Index': 0.9569597069597069,
 'Heresy': 0.896551724137931,
 'Warsaw_Pact': 0.6317460317460317,
 'Materialism': 0.8840598739495797,
 'Pub': 0.9607177768468091,
 'Web_browser': 0.9546666666666667,
 'Catalan_language': 0.9034299034299033,
 'Paper': 0.911029411764706,
 'Adult_contemporary_music': 0.9104761904761905,
 'Nanjing': nan,
 'Dialect': nan,
 'Southampton': nan,
 'The_Times': nan,
 'Immunology': nan,
 'Imamah_(Shia_doctrine)': nan,
 'Grape': nan,
 'United_States_dollar': nan,
 'Everton_F.C.': nan,
 'Hard_rock': nan,
 'Great_Plains': nan,
 'Biodiversity': nan,
 'Federal_Bureau_of_Investigation': nan,
 'Mary_(mother_of_Jesus)': nan,
 'Unknown': nan,
 'DevRev': nan}

## Evaluation on generated questions

In [None]:
# Paragraphs with model-generated questions; `GeneratedQuestions` is stored as a stringified list.
# NOTE(review): the `Unnamed: 0*` columns look like leftover index columns from earlier CSV round-trips.
gen = pd.read_csv('paragraphs_gen.csv')
gen.head()

Unnamed: 0.2,Unnamed: 0,Unnamed: 0.1,id,Paragraph,Theme,GeneratedQuestions
0,0,0,1,The iPod is a line of portable media players a...,IPod,"['The iPod is what kind of portable device?', ..."
1,1,1,2,"Like other digital music players, iPods can se...",IPod,['What other mobile device is used for externa...
2,2,2,3,Apple's iTunes software (and other alternative...,IPod,"['To what can iTunes be used?', ""What is also ..."
3,3,3,4,"Before the release of iOS 5, the iPod branding...",IPod,['Before iOS how did iPod use a media player?'...
4,4,4,5,"In mid-2015, a new model of the iPod Touch was...",IPod,"[""What is faster than the iPod Touch's predece..."


In [None]:
# Themes with no labelled questions in question_answers.csv (their per-theme F1 above is NaN).
# The generated questions are used to inspect the reader on these themes.
unk_themes = ['Nanjing', 'Dialect', 'Southampton', 'The_Times', 'Immunology', 'Imamah_(Shia_doctrine)', 'Grape', 'United_States_dollar', 'Everton_F.C.',
              'Hard_rock', 'Great_Plains', 'Biodiversity', 'Federal_Bureau_of_Investigation', 'Mary_(mother_of_Jesus)', 'Unknown', 'DevRev']

In [None]:
from ast import literal_eval
# Parse the stringified question lists back into Python lists. `literal_eval` only accepts
# literals, so unlike `eval` it is safe on file contents.
# NOTE: not idempotent. Re-running fails because the column already holds lists.
gen['GeneratedQuestions'] = gen['GeneratedQuestions'].apply(literal_eval)

In [None]:
# Paragraphs and their generated question lists for a single theme (aligned by position).
theme = 'Nanjing'
theme_rows = gen.loc[gen['Theme'] == theme]
gen_paras = theme_rows['Paragraph'].to_list()
gen_qns = theme_rows['GeneratedQuestions'].to_list()

In [None]:
# Number of paragraphs for the selected theme.
len(gen_paras)

66

In [None]:
# Pick the first paragraph of the theme and count its generated questions.
idx = 0
len(gen_qns[idx])

171

In [None]:
# Answer every generated question against the same paragraph (context repeated once per question).
# NOTE: the pipeline returns a single dict instead of a list when there is only one question.
result = reader(question = gen_qns[idx], context = [gen_paras[idx]] * len(gen_qns[idx]), batch_size = 32)

In [None]:
# Show the paragraph the questions were generated from.
gen_paras[idx]

'Nanjing ( listen; Chinese: 南京, "Southern Capital") is the city situated in the heartland of lower Yangtze River region in China, which has long been a major centre of culture, education, research, politics, economy, transport networks and tourism. It is the capital city of Jiangsu province of People\'s Republic of China and the second largest city in East China, with a total population of 8,216,100, and legally the capital of Republic of China which lost the mainland during the civil war. The city whose name means "Southern Capital" has a prominent place in Chinese history and culture, having served as the capitals of various Chinese dynasties, kingdoms and republican governments dating from the 3rd century AD to 1949. Prior to the advent of pinyin romanization, Nanjing\'s city name was spelled as Nanking or Nankin. Nanjing has a number of other names, and some historical names are now used as names of districts of the city, and among them there is the name Jiangning (江寧), whose forme

In [None]:
# Map each generated question to the reader's predicted answer span. The dict is built in a
# single pass; the predictions list is not rebuilt for every question. A single-question
# run returns a bare dict, so normalise it to a list first. Duplicate questions keep the last answer.
result_list = [result] if isinstance(result, dict) else result
summary = {question: pred['answer'] for question, pred in zip(gen_qns[idx], result_list)}

In [None]:
# Generated question -> predicted answer. This theme has no gold answers to score against.
summary

{'What type of government has Nanjing served?': 'republican',
 "What is Nanjing's short title?": 'Ning',
 'What dynasty became a capital of China in the late 1960s and early 1980s?': 'Jin',
 'What are the past names used as for district names in Nanjing?': 'districts of the city',
 'When did Nanjing become the capital?': 'Jin dynasty',
 'What is now used as names of districts of the city?': 'Jiangning',
 'Nanjing lost what to the civil war in China?': 'the mainland',
 'The name of Nanjing would be what?': 'Southern Capital',
 'For which period of time is the name Nanjing used as the name of the city?': 'Ming dynasty',
 'What had Nanjing served as the capital of?': 'Chinese dynasties, kingdoms and republican governments',
 'When being the capital of a state, what word is adopted as the abbreviation of Nanjing?': 'ROC, Jing',
 "What part of the city is in the People's Republic of China?": 'southern',
 'What dynasty had the name Nanjing attributed to?': 'Ming dynasty',
 'Along with politi