In [None]:
import os
import sys

root_dir = os.path.abspath(os.path.join(os.getcwd(),os.pardir,os.pardir))
sys.path.append(os.path.join(root_dir,'src/laboro_distilbert'))

from tqdm.notebook import tqdm
import tokenization
from transformers import BatchEncoding
from transformers import DistilBertForQuestionAnswering, Trainer, TrainingArguments
import torch
import pickle
import collections

import json
from pathlib import Path
import math

In [None]:
# Locations of the DDQA RC-QA files under <root_dir>/data/ddqa/RC-QA.
# Each split has two JSON files: the raw one and a "tokenized_" twin; both are
# passed together to read_squad(), which reads them index-by-index in parallel.
ddqa_path = os.path.join(root_dir,'data/ddqa/RC-QA')
train_path = os.path.join(ddqa_path,'DDQA-1.0_RC-QA_train.json')
tokenized_train_path = os.path.join(ddqa_path,'tokenized_DDQA-1.0_RC-QA_train.json')
dev_path = os.path.join(ddqa_path,'DDQA-1.0_RC-QA_dev.json')
tokenized_dev_path = os.path.join(ddqa_path, 'tokenized_DDQA-1.0_RC-QA_dev.json')


In [None]:
def read_squad(path, tokenized_path, is_training):
    """Load a SQuAD-v2-style JSON file together with its pre-tokenized twin.

    The two files are walked in lockstep: entry ``data[g]['paragraphs'][p]``
    (and ``['qas'][q]`` below it) of the tokenized file is taken to describe the
    same paragraph / question as the same indices of the raw file.

    Args:
        path: path of the raw JSON file.
        tokenized_path: path of the tokenized JSON file with the same layout.
        is_training: selects the shape of the returned answers (see Returns).

    Returns:
        A 5-tuple of parallel lists
        ``(contexts, questions, answers, tokenized_contexts, tokenized_questions)``.

        * ``is_training`` true: each answer is a dict with ``text``,
          ``answer_start`` and ``answer_end`` (inclusive character offset).
          Impossible questions get ``text == ""`` and both offsets ``-1``.
          Only the first listed answer is used. Questions whose answer text
          does not occur in the context are reported and dropped.
          NOTE: ``answer_end`` is written into the dict loaded from JSON.
        * ``is_training`` false: each answer is the list of all answer texts
          (empty for impossible questions); nothing is dropped.
    """
    with open(Path(path), 'rb') as f:
        squad_dict = json.load(f)

    with open(Path(tokenized_path), 'rb') as tokenized_f:
        tokenized_squad_dict = json.load(tokenized_f)

    contexts = []
    questions = []
    answers = []
    tokenized_contexts = []
    tokenized_questions = []

    # enumerate() keeps the indices into the tokenized file correct even when a
    # question is skipped with `continue` below (a hand-maintained counter
    # would be left behind and shift every following question by one).
    for g_i, group in enumerate(squad_dict['data']):
        for p_i, passage in enumerate(group['paragraphs']):
            tokenized_passage = tokenized_squad_dict['data'][g_i]['paragraphs'][p_i]

            # Spaces and ellipses are replaced by "." in the raw context.
            context = passage['context'].replace(" ", ".").replace("…", ".")
            doc_tokens = tokenization.convert_to_unicode(context)
            tokenized_context = tokenized_passage['context']
            tokenized_doc_tokens = tokenization.convert_to_unicode(tokenized_context)

            for q_i, qa in enumerate(passage['qas']):
                question = qa['question']
                tokenized_question = tokenized_passage['qas'][q_i]['question']
                is_impossible = qa["is_impossible"]

                if is_training:
                    if not is_impossible:
                        answer = qa["answers"][0]
                        orig_answer_text = answer["text"]
                        # Inclusive offset of the last answer character.
                        answer["answer_end"] = answer["answer_start"] + len(orig_answer_text) - 1

                        cleaned_answer_text = tokenization.convert_to_unicode(orig_answer_text)
                        if doc_tokens.find(cleaned_answer_text) == -1:
                            print("Could not find answer: '%s' vs. '%s'"
                                  % (doc_tokens, cleaned_answer_text))
                            continue
                    else:
                        answer = {"answer_start": -1, "answer_end": -1, "text": ""}
                    contexts.append(doc_tokens)
                    questions.append(question)
                    answers.append(answer)
                    tokenized_contexts.append(tokenized_doc_tokens)
                    tokenized_questions.append(tokenized_question)
                else:
                    if is_impossible:
                        all_answers_to_this_q = []
                    else:
                        all_answers_to_this_q = [ans["text"] for ans in qa["answers"]]
                    contexts.append(context)
                    questions.append(question)
                    answers.append(all_answers_to_this_q)
                    tokenized_contexts.append(tokenized_context)
                    tokenized_questions.append(tokenized_question)

    return contexts, questions, answers, tokenized_contexts, tokenized_questions


# Training split (is_training=True): answers are dicts carrying character offsets.
# Dev split (is_training=False): answers are lists of acceptable answer texts.
train_contexts, train_questions, train_answers, tokenized_train_contexts, tokenized_train_questions = read_squad(train_path,tokenized_train_path,True)
val_contexts, val_questions, val_answers, tokenized_val_contexts, tokenized_val_questions = read_squad(dev_path,tokenized_dev_path,False)





In [None]:
def _check_is_max_context(doc_spans, cur_span_index, position):
  """Check if this is the 'max context' doc span for the token."""

  # Because of the sliding window approach taken to scoring documents, a single
  # token can appear in multiple documents. E.g.
  #  Doc: the man went to the store and bought a gallon of milk
  #  Span A: the man went to the
  #  Span B: to the store and bought
  #  Span C: and bought a gallon of
  #  ...
  #
  # Now the word 'bought' will have two scores from spans B and C. We only
  # want to consider the score with "maximum context", which we define as
  # the *minimum* of its left and right context (the *sum* of left and
  # right context will always be the same, of course).
  #
  # In the example the maximum context for 'bought' would be span C since
  # it has 1 left context and 3 right context, while span B has 4 left context
  # and 0 right context.
  best_score = None
  best_span_index = None
  for (span_index, doc_span) in enumerate(doc_spans):
    end = doc_span.start + doc_span.length - 1
    if position < doc_span.start:
      continue
    if position > end:
      continue
    num_left_context = position - doc_span.start
    num_right_context = end - position
    score = min(num_left_context, num_right_context) + 0.01 * doc_span.length
    if best_score is None or score > best_score:
      best_score = score
      best_span_index = span_index

  return cur_span_index == best_span_index

In [None]:
def texts2encoding(contexts, questions, answers, tokenized_contexts, tokenized_questions, vocab_file, is_training, start_offset=0):
    """Turn parallel example lists into one column-oriented BatchEncoding.

    Every example is expanded by text2feature() into zero or more feature
    tuples (one per document window). Each tuple contributes one row; the row's
    'eg_id' is the example's list index plus `start_offset`, so rows can be
    traced back to their example later. An example stops contributing as soon
    as text2feature() yields a falsy item.

    Returns:
        BatchEncoding with the columns eg_id, input_ids, attention_mask,
        token_type_ids, tokens, token_to_orig_map, token_is_max_context,
        start_positions, end_positions and is_impossible.
    """
    # Column names in the order text2feature() lays out its tuple.
    feature_columns = ('input_ids', 'attention_mask', 'token_type_ids',
                       'tokens', 'token_to_orig_map', 'token_is_max_context',
                       'start_positions', 'end_positions', 'is_impossible')

    columns = {'eg_id': []}
    columns.update((column, []) for column in feature_columns)

    for idx in tqdm(range(len(contexts))):
        feature_stream = text2feature(contexts[idx], questions[idx], answers[idx],
                                      tokenized_contexts[idx], tokenized_questions[idx],
                                      vocab_file, is_training)
        for features in feature_stream:
            if not features:
                break
            assert len(features)==9
            for column, value in zip(feature_columns, features):
                columns[column].append(value)
            columns['eg_id'].append(idx+start_offset)

    return BatchEncoding(columns)

# Tokenizers already built, keyed by (vocab_file, do_lower_case). Building one
# reads the vocabulary file, so doing it once per example is needlessly slow.
_TOKENIZER_CACHE = {}


def _get_tokenizer(vocab_file, do_lower_case):
    """Return the FullTokenizer for these settings, constructing it only once."""
    cache_key = (vocab_file, do_lower_case)
    if cache_key not in _TOKENIZER_CACHE:
        _TOKENIZER_CACHE[cache_key] = tokenization.FullTokenizer(
            vocab_file=vocab_file, do_lower_case=do_lower_case)
    return _TOKENIZER_CACHE[cache_key]


def text2feature(context, question, answer, tokenized_context, tokenized_question, vocab_file, is_training, do_lower_case=True, max_seq_length=512, max_query_length=64, doc_stride=128):
    """Generate model-ready features for one question/context pair.

    The context is tokenized and cut into overlapping windows so that
    "[CLS] query [SEP] window [SEP]" fits into `max_seq_length`; one feature is
    yielded per window.

    Args:
        context: raw context string; used only to check the character alignment.
        question: raw question string (not used here; kept for the call signature).
        answer: if `is_training`, a dict with "text", "answer_start" and
            "answer_end" (inclusive character offsets, "text" empty for
            impossible questions); otherwise a list of answer texts.
        tokenized_context: context text handed to the tokenizer.
        tokenized_question: question text handed to the tokenizer.
        vocab_file: vocabulary file for tokenization.FullTokenizer.
        is_training: when true, start/end token positions are computed.
        do_lower_case: forwarded to the tokenizer.
        max_seq_length: length every input is padded to.
        max_query_length: query tokens beyond this count are dropped.
        doc_stride: how far the window start moves between windows.

    Yields:
        ``(input_ids, input_mask, segment_ids, tokens, token_to_orig_map,
        token_is_max_context, start_position, end_position, is_impossible)``.
        `token_to_orig_map` maps an input position to the character offset in
        the context where that token begins. In training, a window that does
        not fully contain the answer (or an impossible question) gets
        start/end position 0; outside training both are None.

        If the tokens do not add up to the length of `context`, a single
        ``None`` is yielded and the generator ends.

    Raises:
        ValueError: if `doc_stride` is not positive or the query leaves no
            room for any context token.
    """
    if doc_stride <= 0:
        raise ValueError("doc_stride must be positive, got %r" % (doc_stride,))

    if is_training:
        is_impossible = (answer["text"]=="")
    else:
        is_impossible = (len(answer)==0)

    tokenizer = _get_tokenizer(vocab_file, do_lower_case)

    query_tokens = tokenizer.tokenize(tokenized_question)
    if len(query_tokens) > max_query_length:
        query_tokens = query_tokens[0:max_query_length]

    max_tokens_for_doc = max_seq_length - len(query_tokens) - 3
    if max_tokens_for_doc <= 0:
        raise ValueError("max_seq_length leaves no room for context tokens")

    # tok_to_orig_index[t]: character offset in the context where token t starts.
    # orig_to_tok_index[c]: index of the token that covers context character c.
    tok_to_orig_index = []
    orig_to_tok_index = []
    tokens_for_test = []
    all_doc_tokens = tokenizer.tokenize(tokenized_context)

    for (i, token) in enumerate(all_doc_tokens):
        tok_to_orig_index.append(len(orig_to_tok_index))
        if token[0] == '▁':
            # The leading marker is not part of the context text.
            surface = token[1:]
        elif token == '[UNK]':
            # An unknown token is counted as exactly one context character.
            surface = '.'
        else:
            surface = token
        # Every character of this token maps back to token i.
        orig_to_tok_index.extend([i] * len(surface))
        tokens_for_test.append(surface)

    if not len(context) == len(''.join(tokens_for_test)):
        # Character alignment is broken, so the offsets below would be
        # meaningless: report it and stop instead of emitting bad features.
        yield None
        return

    tok_start_position = None
    tok_end_position = None
    if is_training:
        if is_impossible:
            tok_start_position = -1
            tok_end_position = -1
        else:
            tok_start_position = orig_to_tok_index[answer["answer_start"]]
            if answer["answer_end"] < len(orig_to_tok_index) - 1:
                # Token just before the one covering the first character after the answer.
                tok_end_position = orig_to_tok_index[answer["answer_end"] + 1] - 1
            else:
                tok_end_position = orig_to_tok_index[-1]

    # Sliding windows over the document tokens.
    _DocSpan = collections.namedtuple("DocSpan", ["start", "length"])
    doc_spans = []
    start_offset = 0
    while start_offset < len(all_doc_tokens):
        length = min(len(all_doc_tokens) - start_offset, max_tokens_for_doc)
        doc_spans.append(_DocSpan(start=start_offset, length=length))
        if start_offset + length == len(all_doc_tokens):
            break
        start_offset += min(length, doc_stride)

    for (doc_span_index, doc_span) in enumerate(doc_spans):
        tokens = ["[CLS]"] + list(query_tokens) + ["[SEP]"]
        segment_ids = [0] * len(tokens)
        token_to_orig_map = {}
        token_is_max_context = {}

        for i in range(doc_span.length):
            split_token_index = doc_span.start + i
            token_to_orig_map[len(tokens)] = tok_to_orig_index[split_token_index]
            token_is_max_context[len(tokens)] = _check_is_max_context(
                doc_spans, doc_span_index, split_token_index)
            tokens.append(all_doc_tokens[split_token_index])
            segment_ids.append(1)
        tokens.append("[SEP]")
        segment_ids.append(1)

        input_ids = tokenizer.convert_tokens_to_ids(tokens)
        input_mask = [1] * len(input_ids)

        # Zero-pad up to max_seq_length.
        while len(input_ids) < max_seq_length:
            input_ids.append(0)
            input_mask.append(0)
            segment_ids.append(0)

        start_position = None
        end_position = None
        if is_training:
            doc_start = doc_span.start
            doc_end = doc_span.start + doc_span.length - 1
            answer_in_window = (not is_impossible
                                and tok_start_position >= doc_start
                                and tok_end_position <= doc_end)
            if answer_in_window:
                # Shift past "[CLS] query [SEP]".
                doc_offset = len(query_tokens) + 2
                start_position = tok_start_position - doc_start + doc_offset
                end_position = tok_end_position - doc_start + doc_offset
            else:
                # Position 0 ([CLS]) stands for "no answer in this window".
                start_position = 0
                end_position = 0

        yield input_ids,input_mask,segment_ids,tokens,token_to_orig_map,token_is_max_context,start_position,end_position,is_impossible



In [None]:
# Encode both splits with the vocabulary shipped next to the model, then save
# the encodings so that the following cells can reload them from disk.
vocab_file = os.path.join(root_dir,'model/laboro_distilbert/tokenizer/ccc_13g_unigram_vocab_lower.txt')

train_encodings = texts2encoding(train_contexts, train_questions, train_answers, tokenized_train_contexts, tokenized_train_questions, vocab_file, True)
val_encodings = texts2encoding(val_contexts, val_questions, val_answers, tokenized_val_contexts, tokenized_val_questions,vocab_file,False)

train_encodings_path = os.path.join(ddqa_path, 'train_encodings.pickle')
val_encodings_path = os.path.join(ddqa_path, 'dev_encodings.pickle')

# `with` closes (and thereby flushes) each file as soon as it has been written.
with open(train_encodings_path,'wb') as train_encodings_file:
    pickle.dump(train_encodings,train_encodings_file)
with open(val_encodings_path,'wb') as val_encodings_file:
    pickle.dump(val_encodings,val_encodings_file)


In [None]:
class SquadDataset(torch.utils.data.Dataset):
    """Map-style dataset over a column-oriented encoding (column name -> list of rows).

    Indexing returns a dict of tensors restricted to the columns in
    MODEL_INPUT_KEYS; bookkeeping columns present in the encoding (tokens,
    offset maps, ...) are left out.
    """

    # NOTE(review): "token_type_ids" is kept because the original selection had
    # it, but DistilBERT's forward() does not appear to accept that argument --
    # confirm against the installed transformers version.
    MODEL_INPUT_KEYS = ("input_ids", "attention_mask", "token_type_ids",
                        "start_positions", "end_positions")

    def __init__(self, encodings):
        self.encodings = encodings

    def __getitem__(self, idx):
        item = {}
        for key, column in self.encodings.items():
            if key not in self.MODEL_INPUT_KEYS:
                continue
            value = column[idx]
            # Rows encoded without labels store None for the answer positions;
            # torch.tensor(None) raises, so such entries are simply omitted.
            if value is None:
                continue
            item[key] = torch.tensor(value)
        return item

    def __len__(self):
        # Key access works for a plain dict as well as for a BatchEncoding.
        return len(self.encodings["input_ids"])

# Reload the encodings saved earlier. pickle can execute arbitrary code while
# loading, so only point these paths at files produced by this notebook.
with open(train_encodings_path,'rb') as train_encodings_file:
    train_encodings = pickle.load(train_encodings_file)
with open(val_encodings_path,'rb') as val_encodings_file:
    val_encodings = pickle.load(val_encodings_file)

train_dataset = SquadDataset(train_encodings)
val_dataset = SquadDataset(val_encodings)

train_dataset_path = os.path.join(ddqa_path, 'train_dataset.pickle')
val_dataset_path = os.path.join(ddqa_path, 'dev_dataset.pickle')

# `with` closes (and thereby flushes) each file as soon as it has been written.
with open(train_dataset_path,'wb') as train_dataset_file:
    pickle.dump(train_dataset,train_dataset_file)
with open(val_dataset_path,'wb') as val_dataset_file:
    pickle.dump(val_dataset,val_dataset_file)


In [None]:
# Reload the pickled datasets; the files are closed right after reading.
# pickle can execute arbitrary code while loading -- only load files written by this notebook.
with open(train_dataset_path,'rb') as train_dataset_file:
    train_dataset = pickle.load(train_dataset_file)
with open(val_dataset_path,'rb') as val_dataset_file:
    val_dataset = pickle.load(val_dataset_file)


In [None]:
# If an out-of-memory (OOM) error occurs, rename config_ddqa.json to config.json for this task.
# Doing so turns off output_hidden_states, which saves some memory.


In [None]:
# Fine-tuning run: checkpoints and logs both go to model/laboro_distilbert/output_ddqa.
training_args = TrainingArguments(
    output_dir=os.path.join(root_dir,'model/laboro_distilbert/output_ddqa'),          # output directory
    num_train_epochs=3,              # total number of training epochs
    per_device_train_batch_size=32,  # batch size per device during training
    per_device_eval_batch_size=32,   # batch size for evaluation
    warmup_steps=100,                # number of warmup steps for learning rate scheduler
    weight_decay=0.01,               # strength of weight decay
    logging_dir=os.path.join(root_dir,'model/laboro_distilbert/output_ddqa'),            # directory for storing logs (same folder as output_dir)
    logging_steps=1000,
)

# Start from the weights stored in model/laboro_distilbert.
model_path = os.path.join(root_dir,'model/laboro_distilbert')
model = DistilBertForQuestionAnswering.from_pretrained(model_path)

trainer = Trainer(
    model=model,                         # the instantiated 🤗 Transformers model to be trained
    args=training_args,                  # training arguments, defined above
    train_dataset=train_dataset,         # training dataset
    eval_dataset=val_dataset             # evaluation dataset
)

trainer.train()

In [None]:
# Inference only: this cell rebuilds the Trainer around a saved checkpoint and
# calls predict(); train() is never called here.
training_args = TrainingArguments(
    output_dir=os.path.join(root_dir,'model/laboro_distilbert/output_ddqa'),          # output directory
    num_train_epochs=0,              # total number of training epochs (0: nothing is trained in this cell)
    per_device_train_batch_size=32,  # batch size per device during training
    per_device_eval_batch_size=32,   # batch size for evaluation
    warmup_steps=500,                # number of warmup steps for learning rate scheduler
    weight_decay=0.01,               # strength of weight decay
    logging_dir=os.path.join(root_dir,'model/laboro_distilbert/output_ddqa'),            # directory for storing logs
    logging_steps=1000,
)

# NOTE(review): the checkpoint step (1500) is hard-coded; presumably it is a
# checkpoint written by the training cell -- adjust if your run saved different steps.
model_path = os.path.join(root_dir,'model/laboro_distilbert/output_ddqa/checkpoint-1500')
model = DistilBertForQuestionAnswering.from_pretrained(model_path)

trainer = Trainer(
    model=model,                         # the instantiated 🤗 Transformers model to be trained
    args=training_args,                  # training arguments, defined above
)

# Raw model outputs for every dev feature; write_predictions() reads index 0 as
# start logits and index 1 as end logits.
predictions = trainer.predict(test_dataset=val_dataset).predictions


In [None]:
def get_feature_ids_by_eg_id(eg_idx,features,eg_start_offset):
  """Return the row indices in `features` that belong to one example.

  Args:
    eg_idx: position of the example in the caller's list.
    features: mapping with an 'eg_id' column (one example id per feature row).
    eg_start_offset: value added to `eg_idx` to obtain the stored example id.

  Returns:
    List of matching row indices, in order. The list is empty when no row
    matches; a list is always returned, including for the example with the
    highest id.

  The scan stops at the first row whose id is larger than the one searched
  for, so the 'eg_id' column is expected to be in non-decreasing order.
  """
  wanted_id = eg_idx+eg_start_offset
  feature_ids = []
  for feature_id, row_eg_id in enumerate(features['eg_id']):
    if row_eg_id==wanted_id:
      feature_ids.append(feature_id)
    elif row_eg_id>wanted_id:
      break
  return feature_ids


In [None]:
def _compute_softmax(scores):
  """Compute softmax probability over raw logits."""
  if not scores:
    return []

  max_score = None
  for score in scores:
    if max_score is None or score > max_score:
      max_score = score

  exp_scores = []
  total_sum = 0.0
  for score in scores:
    x = math.exp(score - max_score)
    exp_scores.append(x)
    total_sum += x

  probs = []
  for score in exp_scores:
    probs.append(score / total_sum)
  return probs

In [None]:
def _get_best_indexes(logits, n_best_size):
  """Get the n-best logits from a list."""
  index_and_score = sorted(enumerate(logits), key=lambda x: x[1], reverse=True)

  best_indexes = []
  for i in range(len(index_and_score)):
    if i >= n_best_size:
      break
    best_indexes.append(index_and_score[i][0])
  return best_indexes

def _token_surface_length(token):
  """Number of context characters a token accounts for.

  Follows the alignment rules text2feature() uses when it builds the offset
  maps: '[UNK]' counts as one character and a leading '▁' marker as none.
  """
  if token == '[UNK]':
    return 1
  if token[:1] == '▁':
    return len(token) - 1
  return len(token)


def write_predictions(all_val_contexts, all_val_answers, all_features, all_results, eg_start_offset,
                      max_answer_length=30, do_lower_case=True, n_best_size=20,
                      null_score_diff_threshold=0.0):
  """Decode span predictions and score them by exact match against the gold answers.

  Args:
    all_val_contexts: context string of every example.
    all_val_answers: per example, the list of acceptable answer texts (empty
      list for a question without answer).
    all_features: column mapping with 'eg_id', 'tokens', 'token_to_orig_map'
      and 'token_is_max_context', one row per feature.
    all_results: indexable whose item 0 holds the start logits and item 1 the
      end logits, each indexed by feature row.
    eg_start_offset: value added to an example's list index to get its 'eg_id'.
    max_answer_length: longest allowed answer, in tokens.
    do_lower_case: unused; kept so existing calls keep working.
    n_best_size: number of start/end candidates and of kept predictions.
    null_score_diff_threshold: the empty answer is predicted when
      (null score - best span score) exceeds this value.

  Returns:
    Dict with 'correct_answers', 'all_examples' and 'recall'
    (correct_answers / all_examples, 0.0 when there are no examples). The same
    three values are printed. Examples without any feature row are never
    counted as correct but still count in 'all_examples'.
  """
  _PrelimPrediction = collections.namedtuple(  # pylint: disable=invalid-name
      "PrelimPrediction",
      ["feature_index", "start_index", "end_index", "start_logit", "end_logit"])
  _NbestPrediction = collections.namedtuple(  # pylint: disable=invalid-name
      "NbestPrediction", ["text", "start_logit", "end_logit"])

  # Group the feature rows by example in a single pass instead of rescanning
  # the whole 'eg_id' column for every example.
  feature_ids_by_eg_id = collections.defaultdict(list)
  for feature_id, eg_id in enumerate(all_features['eg_id']):
    feature_ids_by_eg_id[eg_id].append(feature_id)

  correct_counter = 0
  for (example_index, context) in enumerate(all_val_contexts):
    gold_answers = all_val_answers[example_index]

    feature_ids = feature_ids_by_eg_id.get(example_index + eg_start_offset)
    if not feature_ids:
      continue

    features = []
    results = []
    for f_id in feature_ids:
      features.append({
          'tokens': all_features['tokens'][f_id],
          'token_to_orig_map': all_features['token_to_orig_map'][f_id],
          'token_is_max_context': all_features['token_is_max_context'][f_id],
      })
      results.append({
          'start_logits': all_results[0][f_id],
          'end_logits': all_results[1][f_id],
      })

    prelim_predictions = []
    # keep track of the minimum score of null start+end of position 0
    score_null = 1000000  # large and positive
    min_null_feature_index = 0  # the paragraph slice with min null score
    null_start_logit = 0  # the start logit at the slice with min null score
    null_end_logit = 0  # the end logit at the slice with min null score
    for (feature_index, feature) in enumerate(features):
      result = results[feature_index]

      start_indexes = _get_best_indexes(result['start_logits'], n_best_size)
      end_indexes = _get_best_indexes(result['end_logits'], n_best_size)

      feature_null_score = result['start_logits'][0] + result['end_logits'][0]
      if feature_null_score < score_null:
        score_null = feature_null_score
        min_null_feature_index = feature_index
        null_start_logit = result['start_logits'][0]
        null_end_logit = result['end_logits'][0]

      for start_index in start_indexes:
        for end_index in end_indexes:
          # Discard spans that touch padding, the query or special tokens,
          # start outside this window's "max context" region, run backwards,
          # or are too long.
          if start_index >= len(feature['tokens']):
            continue
          if end_index >= len(feature['tokens']):
            continue
          if start_index not in feature['token_to_orig_map']:
            continue
          if end_index not in feature['token_to_orig_map']:
            continue
          if not feature['token_is_max_context'].get(start_index, False):
            continue
          if end_index < start_index:
            continue
          length = end_index - start_index + 1
          if length > max_answer_length:
            continue
          prelim_predictions.append(
              _PrelimPrediction(
                  feature_index=feature_index,
                  start_index=start_index,
                  end_index=end_index,
                  start_logit=result['start_logits'][start_index],
                  end_logit=result['end_logits'][end_index]))
    # The "no answer" candidate always competes.
    prelim_predictions.append(
        _PrelimPrediction(
            feature_index=min_null_feature_index,
            start_index=0,
            end_index=0,
            start_logit=null_start_logit,
            end_logit=null_end_logit))
    prelim_predictions = sorted(
        prelim_predictions,
        key=lambda x: (x.start_logit + x.end_logit),
        reverse=True)

    seen_predictions = {}
    nbest = []
    for pred in prelim_predictions:
      if len(nbest) >= n_best_size:
        break
      feature = features[pred.feature_index]
      if pred.start_index > 0:  # this is a non-null prediction
        token_to_orig_map = feature['token_to_orig_map']
        orig_doc_start = token_to_orig_map[pred.start_index]
        if pred.end_index+1 in token_to_orig_map:
          # The answer ends where the next context token begins.
          orig_doc_end = token_to_orig_map[pred.end_index+1]
        else:
          # Last context token of the window: add that token's own length
          # (tokens[-1] is the closing "[SEP]" and says nothing about it).
          orig_doc_end = (token_to_orig_map[pred.end_index]
                          + _token_surface_length(feature['tokens'][pred.end_index]))
        final_text = "".join(context[orig_doc_start:orig_doc_end])
        if final_text in seen_predictions:
          continue
        seen_predictions[final_text] = True
      else:
        final_text = ""
        seen_predictions[final_text] = True

      nbest.append(
          _NbestPrediction(
              text=final_text,
              start_logit=pred.start_logit,
              end_logit=pred.end_logit))

    # Best-scoring candidate that has text, if any survived the filters.
    best_non_null_entry = None
    for entry in nbest:
      if entry.text:
        best_non_null_entry = entry
        break

    if best_non_null_entry is None:
      # Nothing but the null candidate is available: abstain.
      prediction = ""
    else:
      score_diff = score_null - best_non_null_entry.start_logit - best_non_null_entry.end_logit
      if score_diff > null_score_diff_threshold:
        prediction = ""
      else:
        prediction = best_non_null_entry.text

    if gold_answers:
      is_correct = prediction in gold_answers
    else:
      # A question without gold answers is answered correctly by abstaining.
      is_correct = (prediction == "")
    if is_correct:
      correct_counter += 1

  total_examples = len(all_val_contexts)
  eval_res = {}
  eval_res['correct_answers'] = correct_counter
  eval_res['all_examples'] = total_examples
  eval_res['recall'] = correct_counter/total_examples if total_examples else 0.0
  print('correct_answers = {}'.format(eval_res['correct_answers']))
  print('all_examples = {}'.format(eval_res['all_examples']))
  print('recall = {}'.format( eval_res['recall']))
  return eval_res


In [None]:
# Score the dev-set predictions; eg_start_offset=0 matches the default
# start_offset texts2encoding() used when it numbered the examples.
write_predictions(val_contexts, val_answers, val_encodings, predictions, eg_start_offset=0)
