In [None]:
# Clone the UnifiedQA repository (Senyu-T fork); later cells cd into unifiedqa/bart.
!git clone https://github.com/Senyu-T/unifiedqa

In [None]:
cd unifiedqa/bart

In [None]:
# Make the repo's data script executable and run it; it presumably populates the
# data/ directory that the next cell enters (TODO confirm against the script).
!chmod +x download_data.sh; ./download_data.sh

In [5]:
cd data/natural_questions_with_dpr_para/

/content/unifiedqa/bart/data/natural_questions_with_dpr_para


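Answer-normalization helpers: `remove_spc_token` strips escaped-quote and tokenization artifacts, and `normalize_answer` lowercases the text and removes punctuation, articles and extra whitespace.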

In [39]:
import string
import re

def normalize_answer(s):
  """Canonicalize an answer string (SQuAD-style normalization).

  Steps, applied in this order:
    1. lowercase;
    2. drop every character in string.punctuation;
    3. replace the whole words "a", "an" and "the" with a space;
    4. collapse whitespace runs to single spaces and trim both ends.
  """
  punctuation = set(string.punctuation)
  text = s.lower()
  text = ''.join(c for c in text if c not in punctuation)
  text = re.sub(r'\b(a|an|the)\b', ' ', text)
  return ' '.join(text.split())

def remove_spc_token(s):
  """Undo backslash-escape and tokenization artifacts in an answer string.

  Replacements are applied in this exact order (order matters):
    1. space + two backslash-escaped single quotes -> space + two quotes
       (a double quotation written as two single quotes);
    2. any remaining backslash-escaped single quote -> plain single quote;
    3. a space before a possessive "'s" is removed;
    4. a space before a comma is removed.
  """
  replacements = (
      (' \\\'\\\'', ' \'\''),
      ('\\\'', '\''),
      (' \'s', '\'s'),
      (' ,', ','),
  )
  for old, new in replacements:
    s = s.replace(old, new)
  return s

In [40]:
def read_files(file_name):
  r"""Parse a UnifiedQA-style TSV file into parallel answer/question/context lists.

  Each line is parsed as ``<question>\n<context>\t<answer>``, where ``\n`` is
  the literal two-character sequence backslash + "n" (not a line break) and
  ``\t`` is a real tab character.

  The file is decoded as UTF-8 text. Taking ``str()`` of the raw bytes instead
  yields the bytes repr: every non-ASCII character turns into a ``\xNN``
  escape (e.g. "é" -> ``\xc3\xa9``) and quotes/backslashes gain extra escaping,
  which corrupts the answers and the contexts fed to NER.

  Args:
    file_name: path to the TSV file.

  Returns:
    (answers, questions, contexts): three lists with one entry per input line.
    Answers are the text after the last tab, lowercased and cleaned with
    remove_spc_token and normalize_answer. Contexts are the text between the
    separator and the first tab.
  """
  answers = []
  questions = []
  contexts = []
  with open(file_name, 'r', encoding='utf-8') as inference_in:
    for raw_line in inference_in:
      line = raw_line.rstrip('\r\n')
      # Split on the literal backslash-n separator, not on a newline.
      question, _, rest = line.partition('\\n')
      fields = rest.split('\t')
      questions.append(question)
      ans = fields[-1].lower()
      answers.append(normalize_answer(remove_spc_token(ans)))  # normalize answers
      contexts.append(fields[0])
  return answers, questions, contexts

Install and load the spaCy transformer NER model (`en_core_web_trf`).

In [None]:
!pip3 install spacy-transformers

In [None]:
!python3 -m spacy download en_core_web_trf

In [47]:
import spacy

# spacy.require_gpu() raises when no GPU is available instead of silently
# falling back to CPU, so a CPU-only runtime fails fast here.
spacy.require_gpu()

# Module-level transformer pipeline; spacy_large_ner reads this global.
sp_lg = spacy.load(name='en_core_web_trf')

In [48]:
def spacy_large_ner(document, ans, nlp=None):
  """Return the NER label of the first entity in `document` whose text equals `ans`.

  Args:
    document: text to run the NER pipeline on.
    ans: answer string; compared with each entity's text after stripping
      surrounding whitespace (exact, case-sensitive match).
    nlp: optional pipeline callable returning an object with `.ents`
      (each having `.text` and `.label_`). Defaults to the global `sp_lg`.

  Returns:
    The `label_` of the first matching entity, or None when nothing matches.
  """
  if nlp is None:
    nlp = sp_lg
  for ent in nlp(document).ents:
    if ent.text.strip() == ans:
      return ent.label_
  return None

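Assign an answer-type tag to each question. The *soft* tag is the spaCy NER label of the first entity that exactly matches the answer, searched within the retrieved context passages and then in the answer string itself. The *hard* tag is derived from the question's leading words (who → PERSON, when → DATE, where → LOC, …). If NER finds no match, the hard tag is used as the soft tag too.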

In [49]:
import time

In [50]:
def _hard_answer_type(q_tokens):
  """Return a coarse answer type derived from the question's leading words.

  `q_tokens` is the question split on single spaces. Backslashes are removed
  from the first token before it is compared with the wh-words.
  """
  start = q_tokens[0].replace('\\', '')
  # Guard one-token questions: indexing q_tokens[1] directly raises IndexError.
  second = q_tokens[1] if len(q_tokens) > 1 else ''
  if start == "who":  # always PERSON, even when the true answer may be an ORG
    return 'PERSON'
  if start == "what" and second == "number":
    return 'CARDINAL'
  if start == "how" and second == "many":
    return 'CARDINAL'
  if start == "when":
    return "DATE"
  if start == "where":
    return "LOC"
  return "OTHERS"


def get_tags(answer_list, q_list, context_list):
  """Compute hard and soft answer-type tags for every question.

  Args:
    answer_list: normalized answers (as produced by read_files).
    q_list: questions; indexed in parallel with answer_list.
    context_list: retrieved contexts; indexed in parallel with answer_list.
      Wiki entries inside one context are separated by '- -'.

  Returns:
    (hard_tags, soft_tags), one tag per answer:
    * hard tag: from the question's leading words (see _hard_answer_type);
    * soft tag: label of the first spaCy entity equal to the answer, searched
      in each wiki entry that contains the answer, then in the answer string
      itself; falls back to the hard tag when NER finds nothing.
  """
  soft_tags = []
  hard_tags = []
  start_time = time.time()
  for i, ans in enumerate(answer_list):
    hard_ans_type = _hard_answer_type(q_list[i].split(' '))

    # Soft type: run NER only on wiki entries that actually contain the answer.
    # (The former `if doc.find(ans):` was inverted: find() returns -1, which is
    # truthy, when the answer is absent, and 0, which is falsy, when the answer
    # starts the entry.) An entity's text is a substring of its document, so
    # entries without the answer could never produce a match anyway; skipping
    # them only saves NER calls. Empty answers skip the context search.
    ans_type = None
    for doc in context_list[i].split('- -'):  # divide by wiki entries
      if ans and ans in doc:
        ans_type = spacy_large_ner(doc, ans)
        if ans_type is not None:
          break

    # NER found nothing in the context: try the answer string on its own.
    if ans_type is None:
      ans_type = spacy_large_ner(ans, ans)
    # NER failed again: use the hard type as the soft type.
    if ans_type is None:
      ans_type = hard_ans_type

    soft_tags.append(ans_type)
    hard_tags.append(hard_ans_type)
    if i % 100 == 0:
      print(f"lines {i:5d} out of {len(answer_list)} done in {time.time() - start_time:.2f} seconds")
  assert len(soft_tags) == len(answer_list)
  assert len(hard_tags) == len(answer_list)
  return hard_tags, soft_tags

Now we run the experiments, starting with a sanity check on a small slice of the dev set.

In [51]:
# Parse the dev split into parallel answer / question / context lists.
dev_answers, dev_questions, dev_contexts = read_files(file_name="dev.tsv")

In [52]:
# Dev items 40..50 (inclusive) as a quick sanity check before the full run.
SAMPLE_SLICE = slice(40, 51)
sample_a = dev_answers[SAMPLE_SLICE]
sample_c = dev_contexts[SAMPLE_SLICE]
sample_q = dev_questions[SAMPLE_SLICE]

sample_ht, sample_st = get_tags(answer_list=sample_a, q_list=sample_q, context_list=sample_c)
print(sample_ht)
print(sample_st)

lines     0 out of 11 done in 0.04 seconds
['PERSON', 'PERSON', 'OTHERS', 'OTHERS', 'PERSON', 'OTHERS', 'OTHERS', 'PERSON', 'PERSON', 'PERSON', 'PERSON']
['PERSON', 'GPE', 'ORG', 'PERSON', 'PERSON', 'OTHERS', 'OTHERS', 'PERSON', 'PERSON', 'PERSON', 'PERSON']


In [33]:
# Inspect the gold answers of the sampled dev items.
print(sample_a)

['andy griffith', 'brazil', 'walmart', 'mcdonalds', 'cyndi lauper', 'ymir fritz', 'artax', 'alexander graham bell', 'drake', 'cece', 'blair']


In [32]:
# Inspect the questions of the sampled dev items.
print(sample_q)

['who is the old man in waiting on a woman?', 'who have won the world cup the most times?', 'list of companies with highest number of employees?', 'list of companies with highest number of employees?', 'who sang the original version of true colors?', 'in attack on titan who is the female titan?', 'what is the horses name in never-ending story?', 'who made the first telephone in the world?', 'who is charles off of pretty little liars?', 'who is charles off of pretty little liars?', 'who got pregnant in gossip girl season 5?']


Helper to write the tags to disk, one tag per line, in the same order as the lines of the input TSV.

In [53]:
def write_res(tag_list, output_f):
  """Write `tag_list` to `output_f`, one tag per line.

  Tags are joined with newlines, so no newline follows the last tag. The
  file is opened in text mode and overwritten if it already exists.
  """
  with open(output_f, mode='wt') as out_file:
    joined = '\n'.join(tag_list)
    out_file.write(joined)


In [None]:
# Tag the full dev split (hard tags first, then soft tags).
dev_ht, dev_st = get_tags(answer_list=dev_answers, q_list=dev_questions, context_list=dev_contexts)

In [56]:
# Persist the dev-set tags, one tag per line.
write_res(tag_list=dev_st, output_f="dev_soft_tag.tsv")
write_res(tag_list=dev_ht, output_f="dev_hard_tag.tsv")

In [57]:
# Parse the train split into parallel answer / question / context lists.
train_answers, train_questions, train_contexts = read_files(file_name="train.tsv")

In [None]:
# Tag the full train split (hard tags first, then soft tags); this is the long-running step.
train_ht, train_st = get_tags(answer_list=train_answers, q_list=train_questions, context_list=train_contexts)

lines     0 out of 96676 done in 0.15 seconds
lines   100 out of 96676 done in 10.34 seconds
lines   200 out of 96676 done in 20.13 seconds
lines   300 out of 96676 done in 29.70 seconds
lines   400 out of 96676 done in 39.23 seconds
lines   500 out of 96676 done in 49.75 seconds
lines   600 out of 96676 done in 60.67 seconds
lines   700 out of 96676 done in 71.36 seconds
lines   800 out of 96676 done in 81.82 seconds
lines   900 out of 96676 done in 93.37 seconds
lines  1000 out of 96676 done in 104.09 seconds
lines  1100 out of 96676 done in 114.48 seconds
lines  1200 out of 96676 done in 125.19 seconds
lines  1300 out of 96676 done in 135.84 seconds
lines  1400 out of 96676 done in 146.38 seconds
lines  1500 out of 96676 done in 157.38 seconds
lines  1600 out of 96676 done in 167.12 seconds
lines  1700 out of 96676 done in 177.11 seconds
lines  1800 out of 96676 done in 187.38 seconds
lines  1900 out of 96676 done in 197.87 seconds
lines  2000 out of 96676 done in 207.60 seconds
lin