# Libraries & Data

In [None]:
%%capture
# Install Hugging Face `transformers` into the current runtime (output is hidden by %%capture).
! pip install transformers
# Colab-only: mount Google Drive at /content/drive. Paths under drive/MyDrive are
# used further down in the notebook for checkpoints and logs.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
%%capture
# Download and unpack the ParSQuAD archive into the working directory.
! wget https://github.com/BigData-IsfahanUni/ParSQuAD/blob/main/Dataset/ParSQuAD.zip?raw=true -O ParSQuAD.zip
! unzip ParSQuAD.zip
# Start from empty cache directories: remove whatever a previous run left behind
# (any "No such file" errors from `rm` are hidden by %%capture) ...
! rm -r examples/train
! rm -r examples/eval
! rm -r features/train
! rm -r features/eval
# ... and recreate them; make_examples / make_features write their .bin files here.
! mkdir -p examples/train
! mkdir -p examples/eval
! mkdir -p features/train
! mkdir -p features/eval
# Remove the HistConcat/ directory if it exists.
! rm -r HistConcat/

In [None]:
import numpy as np
import torch
import json
import pickle
import unicodedata
from tqdm import tqdm
from copy import deepcopy
import transformers
from transformers.optimization import get_linear_schedule_with_warmup
from transformers import BertModel, BertForQuestionAnswering, AutoTokenizer, AutoModel
import os
from collections import defaultdict, namedtuple
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW, Adam, RMSprop
from copy import deepcopy
import random

In [None]:
def read_file(filename):
  """Load and return the JSON document stored at `filename`.

  The file is decoded explicitly as UTF-8 so that the Persian text it
  contains is read identically regardless of the platform's default locale.

  Args:
    filename: path of the JSON file to read.

  Returns:
    The parsed JSON value (dict/list/...).
  """
  with open(filename, 'r', encoding='utf-8') as f:
    return json.load(f)

def load_data(filename):
  """Return the object unpickled from the binary file `filename`.

  NOTE: unpickling executes arbitrary code; only use on files produced by
  this notebook (see save_data).
  """
  with open(filename, 'rb') as handle:
    return pickle.load(handle)

def save_data(data, filename):
    """Pickle `data` into the binary file `filename`, overwriting it."""
    serialized = pickle.dumps(data)
    with open(filename, "wb") as handle:
        handle.write(serialized)

# Locations of the unpacked ParSQuAD splits and the Hugging Face model id that
# is used for both the tokenizer and the encoder.
train_path = 'ParSQuAD/ParSQuAD-automatic-train.json'
eval_path = 'ParSQuAD/ParSQuAD-automatic-dev.json'
model_path_or_name = 'HooshvareLab/bert-fa-zwnj-base'

# load data
train_data = read_file(train_path)
eval_data = read_file(eval_path)

# load tokenizer and model
# `tokenizer` and `model` are module-level globals read by make_features
# (model.config.max_position_embeddings sets the window length).
tokenizer = AutoTokenizer.from_pretrained(model_path_or_name)
model = AutoModel.from_pretrained(model_path_or_name)

Downloading (…)okenizer_config.json:   0%|          | 0.00/292 [00:00<?, ?B/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/565 [00:00<?, ?B/s]

Downloading (…)solve/main/vocab.txt:   0%|          | 0.00/426k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/1.11M [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/134 [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/473M [00:00<?, ?B/s]

Some weights of BertModel were not initialized from the model checkpoint at HooshvareLab/bert-fa-zwnj-base and are newly initialized: ['bert.pooler.dense.weight', 'bert.pooler.dense.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


# Official Evaluation Code

In [None]:
import json, string, re
from collections import Counter, defaultdict


def is_overlapping(x1, x2, y1, y2):
  """Return True when interval [x1, x2] and interval [y1, y2] intersect.

  Both ends are treated as inclusive, so intervals that merely touch
  (x2 == y1) count as overlapping.
  """
  latest_start = max(x1, y1)
  earliest_end = min(x2, y2)
  return latest_start <= earliest_end

def normalize_answer(s):
  """Lower-case `s`, then drop ASCII punctuation, the English articles
  a/an/the, and redundant whitespace (in that order)."""
  lowered = s.lower()
  # Delete every character listed in string.punctuation (ASCII only).
  without_punc = lowered.translate(str.maketrans('', '', string.punctuation))
  without_articles = re.sub(r'\b(a|an|the)\b', ' ', without_punc)
  # Collapse runs of whitespace and trim both ends.
  return ' '.join(without_articles.split())

def f1_score(prediction, ground_truth):
  """Token-level F1 between two strings after normalize_answer.

  Tokens are whitespace-separated; the overlap is counted as a multiset
  intersection. Returns 0 when no token is shared.
  """
  pred_toks = normalize_answer(prediction).split()
  gold_toks = normalize_answer(ground_truth).split()
  shared = Counter(pred_toks) & Counter(gold_toks)
  num_same = sum(shared.values())
  if num_same == 0:
    return 0
  precision = 1.0 * num_same / len(pred_toks)
  recall = 1.0 * num_same / len(gold_toks)
  return (2 * precision * recall) / (precision + recall)

def compute_span_overlap(pred_span, gt_span, text):
  """Classify how a predicted answer relates to one reference answer.

  Returns a `(label, f1)` pair. `label` is one of 'Exact match',
  'Partial overlap', 'No overlap' or 'Span indexing error' (either span
  cannot be located in `text` with str.find). An unanswerable reference is
  scored 1.0 only when the prediction is the unanswerable marker too.
  """
  unanswerable = 'غیرقابل‌پاسخ'
  if gt_span == unanswerable:
    return ('Exact match', 1.0) if pred_span == unanswerable else ('No overlap', 0.)

  fscore = f1_score(pred_span, gt_span)
  pred_start, gt_start = text.find(pred_span), text.find(gt_span)
  if -1 in (pred_start, gt_start):
    return 'Span indexing error', fscore

  # Character ranges of the first occurrence of each span inside `text`.
  overlap = is_overlapping(pred_start, pred_start + len(pred_span),
                           gt_start, gt_start + len(gt_span))

  if exact_match_score(pred_span, gt_span):
    label = 'Exact match'
  elif overlap:
    label = 'Partial overlap'
  else:
    label = 'No overlap'
  return label, fscore

def exact_match_score(prediction, ground_truth):
  """Return True when both strings are identical after normalize_answer."""
  normalized_prediction = normalize_answer(prediction)
  normalized_reference = normalize_answer(ground_truth)
  return normalized_prediction == normalized_reference

def display_counter(title, c, c2=None):
  """Print `title`, then one line per key of Counter `c` (most common first).

  Each line shows the key's count, the total count and its percentage; when
  `c2` is given (a mapping key -> list of F1 values) the mean F1 of that key,
  as a percentage, is appended.
  """
  print(title)
  total = sum(c.values())
  for key, count in c.most_common():
    share = count * 100. / total
    if c2:
      mean_f1 = sum(c2[key]) * 100. / len(c2[key])
      print('%s: %d / %d, %.1f%%, F1: %.1f' % (key, count, total, share, mean_f1))
    else:
      print('%s: %d / %d, %.1f%%' % (key, count, total, share))

def leave_one_out_max(prediction, ground_truths, article):
  """Leave-one-out F1 of `prediction` against the references.

  With a single reference this is just the best F1 against it. Otherwise
  each reference is dropped in turn, the best F1 against the remaining ones
  is taken, and the results are averaged.
  """
  if len(ground_truths) == 1:
    return metric_max_over_ground_truths(prediction, ground_truths, article)[1]

  held_out_scores = []
  for skipped in range(len(ground_truths)):
    remaining = [ref for pos, ref in enumerate(ground_truths) if pos != skipped]
    held_out_scores.append(metric_max_over_ground_truths(prediction, remaining, article)[1])
  return 1.0 * sum(held_out_scores) / len(held_out_scores)


def metric_max_over_ground_truths(prediction, ground_truths, article):
  """Return the `(label, f1)` result of compute_span_overlap with the highest
  F1 over all references (the first one wins on ties)."""
  scored = (compute_span_overlap(prediction, reference, article)
            for reference in ground_truths)
  return max(scored, key=lambda result: result[1])


def handle_cannot(refs):
  """Collapse a list of reference answers with respect to unanswerability.

  If at least as many references are the unanswerable marker as are real
  spans, the question is treated as unanswerable and a single-marker list is
  returned; otherwise the marker entries are dropped and only the real spans
  are returned.

  The marker returned here is the same Persian string that
  compute_span_overlap and eval_fn compare against. (The previous
  'CANNOTANSWER' return value was never recognised by either of them.)

  Args:
    refs: list of reference answer strings.

  Returns:
    A new list: either `[marker]` or the non-marker references.
  """
  unanswerable = 'غیرقابل‌پاسخ'
  num_cannot = sum(1 for ref in refs if ref == unanswerable)
  num_spans = len(refs) - num_cannot
  if num_cannot >= num_spans:
    return [unanswerable]
  return [ref for ref in refs if ref != unanswerable]


def leave_one_out(refs):
  """Inter-reference agreement: the average, over references, of the best F1
  each reference achieves against any *other* reference.

  Args:
    refs: list of reference answer strings (non-empty).

  Returns:
    1.0 for a single reference, otherwise the mean best pairwise F1.
  """
  if len(refs) == 1:
    return 1.
  total_f1 = 0.0
  for i, ref in enumerate(refs):
    # Best F1 of this reference against all the others; never below 0.
    total_f1 += max([0] + [f1_score(ref, other)
                           for j, other in enumerate(refs) if j != i])
  return total_f1 / len(refs)






def eval_fn(val_results, model_results, verbose):
  """Score model predictions against reference answers and print a report.

  Args:
    val_results: iterable of articles; each has a 'paragraphs' list whose
      items provide 'id', 'context' and 'qas' (each qa has 'id' and
      'answers', a list of dicts with a 'text' key).
    model_results: mapping paragraph id -> question id -> prediction. A
      prediction is either the answer string itself or a sequence whose first
      element is the answer string.
    verbose: when truthy, print every scored prediction and the per-label
      overlap statistics.

  Returns:
    dict with the keys "unfiltered_f1", "f1", "HEQ", "DHEQ", "yes/no",
    "followup" and "unanswerable_acc". All values are percentages except
    "yes/no" and "followup", which are None because this function never
    computes them. A statistic with no samples is reported as 0.0.

  Side effects:
    Appends the 'Overall F1' line to 'val_report.txt' in the working
    directory and prints the report to stdout.
  """
  def _mean_pct(values):
    # Mean of `values` as a percentage; 0.0 instead of ZeroDivisionError when empty.
    return 100.0 * sum(values) / len(values) if values else 0.0

  def _ratio_pct(numerator, denominator):
    return 100.0 * numerator / denominator if denominator else 0.0

  span_overlap_stats = Counter()
  total_qs = 0.
  f1_stats = defaultdict(list)
  unfiltered_f1s = []
  human_f1 = []
  HEQ = 0.   # questions where the model F1 reaches the human F1
  DHEQ = 0.  # paragraphs where every scored question does
  total_dials = 0.
  unanswerables = []
  for p in val_results:
    for par in p['paragraphs']:
      did = par['id']
      good_dial = 1.
      for qa in par['qas']:
        q_idx = qa['id']
        val_spans = handle_cannot([anss['text'] for anss in qa['answers']])
        hf1 = leave_one_out(val_spans)

        if did not in model_results or q_idx not in model_results[did]:
          # No prediction for this question: counts as F1 = 0.
          good_dial = 0
          f1_stats['NO ANSWER'].append(0.0)
          if val_spans == ['غیرقابل‌پاسخ']:
            unanswerables.append(0.0)
          total_qs += 1
          unfiltered_f1s.append(0.0)
          if hf1 >= .4:
            human_f1.append(hf1)
          continue

        prediction = model_results[did][q_idx]
        pred_span = prediction if isinstance(prediction, str) else prediction[0]

        max_overlap, _ = metric_max_over_ground_truths(
          pred_span, val_spans, par['context'])
        max_f1 = leave_one_out_max(
          pred_span, val_spans, par['context'])
        unfiltered_f1s.append(max_f1)

        # dont eval on low agreement instances
        if hf1 < .4:
          continue

        human_f1.append(hf1)

        if val_spans == ['غیرقابل‌پاسخ']:
          unanswerables.append(max_f1)
        if verbose:
          print("-" * 20)
          print(pred_span)
          print(val_spans)
          print(max_f1)
          print("-" * 20)
        if max_f1 >= hf1:
          HEQ += 1.
        else:
          good_dial = 0.
        span_overlap_stats[max_overlap] += 1
        f1_stats[max_overlap].append(max_f1)
        total_qs += 1.
      DHEQ += good_dial
      total_dials += 1

  DHEQ_score = _ratio_pct(DHEQ, total_dials)
  HEQ_score = _ratio_pct(HEQ, total_qs)
  all_f1s = sum(f1_stats.values(), [])
  overall_f1 = _mean_pct(all_f1s)
  unfiltered_f1 = _mean_pct(unfiltered_f1s)
  unanswerable_score = _mean_pct(unanswerables)
  human_f1_score = _mean_pct(human_f1)
  metric_json = {"unfiltered_f1": unfiltered_f1, "f1": overall_f1, "HEQ": HEQ_score, "DHEQ": DHEQ_score, "yes/no": None, "followup": None, "unanswerable_acc": unanswerable_score}

  report_lines = [
    'Overall F1: %.1f' % overall_f1,
    'Unfiltered F1 ({0:d} questions): {1:.1f}'.format(len(unfiltered_f1s), unfiltered_f1),
    # str.format does not treat '%' specially, so a single '%' is what prints one percent sign.
    'Accuracy On Unanswerable Questions: {0:.1f} % ({1:d} questions)'.format(unanswerable_score, len(unanswerables)),
    'Human F1: %.1f' % human_f1_score,
    'Model F1 >= Human F1 (Questions): %d / %d, %.1f%%' % (HEQ, total_qs, HEQ_score),
    'Model F1 >= Human F1 (Dialogs): %d / %d, %.1f%%' % (DHEQ, total_dials, DHEQ_score),
  ]

  if verbose:
    print("=======================")
    display_counter('Overlap Stats', span_overlap_stats, f1_stats)
  print("=======================")
  for line in report_lines:
    print(line)
  print("=======================")

  # One line per call, newline-terminated so successive runs do not run together.
  with open('val_report.txt', 'a') as f:
    f.write(report_lines[0] + '\n')

  return metric_json

def run_eval():
  """Print and return the mean best-reference F1 of the current predictions.

  Reads two module-level globals: `eval_data` (the dev-set JSON) for the
  reference answers and contexts, and `eval_p.answers` (question id ->
  predicted answer string) for the predictions. For every predicted question
  the F1 from compute_span_overlap is maximised over its references.

  Questions whose reference list is empty are skipped, and the score is 0.0
  when nothing could be scored (previously both cases raised).

  Returns:
    The mean F1 in [0, 1] (also printed).
  """
  references = dict()
  for article in eval_data['data']:
    for paragraph in article['paragraphs']:
      for qa in paragraph['qas']:
        references[qa['id']] = {
            'answers': [a['text'] for a in qa['answers']],
            'context': paragraph['context'],
        }

  f1s = []
  for qid, model_answer in eval_p.answers.items():
    orig_answers = references[qid]['answers']
    context = references[qid]['context']
    if not orig_answers:
      # Nothing to compare against; max() over an empty list would raise.
      continue
    f1s.append(max(compute_span_overlap(model_answer, orig_answer, context)[1]
                   for orig_answer in orig_answers))

  f1_score_ = sum(f1s) / len(f1s) if f1s else 0.0
  print('f1 is', f1_score_)
  return f1_score_

In [None]:
# Re-read both ParSQuAD splits into the `train_data` / `eval_data` globals
# (same two files as the `train_path` / `eval_path` cell near the top).
train_data = read_file('ParSQuAD/ParSQuAD-automatic-train.json')
eval_data = read_file('ParSQuAD/ParSQuAD-automatic-dev.json')

In [None]:
# eval_data['data'][0]['paragraphs'][0]['qas']

# Preprocess Data

In [None]:
class QA_DATA:
  """One question/answer example built by make_examples.

  Attributes:
    question, context, qid, q_num, answer_start, answer_end, is_answerable:
      stored as passed in.
    answer: the answer *text* (the 'text' entry of the `answer` argument).
    cleaned_answer: the full `answer` dict as passed in (make_features reads
      its 'start', 'end' and 'text' keys).
    cleaned_context: same object as `context`.
  """

  def __init__(self,
               question,
               context,
               answer,
               qid,
               q_num,
               answer_start,
               answer_end,
               is_answerable):

    self.question = question
    self.context = context
    self.qid = qid
    self.q_num = q_num
    self.answer_start = answer_start
    self.answer_end = answer_end
    self.is_answerable = is_answerable
    # `answer` arrives as a dict with a 'text' key; keep the plain text and the dict.
    self.answer = answer['text']
    self.cleaned_answer = answer
    self.cleaned_context = context

  def __repr__(self):
    # Uses only attributes set in __init__ (q_num and the answer text).
    return ''.join([
        'context -> ' + self.context[:100] + '\n',
        'question ->' + self.question + '\n',
        'question id ->' + str(self.qid) + '\n',
        'turn_number ->' + str(self.q_num) + '\n',
        'answer ->' + self.answer + '\n',
    ])

In [None]:
class Feature:
  """Plain record for one tokenized window of a question/context pair.

  Every constructor argument is stored unchanged under the attribute of the
  same name. make_features later attaches `max_context_mask` and `mask_span`
  to instances.
  """

  def __init__(self,
               qid,
               question_part,
               input_ids,
               attention_mask,
               token_type_ids,
               offset_mappings,
               max_context_dict,
               start,
               end,
               is_answerable,
               context,
               cleaned_context,
               context_start,
               context_end,
               example_start_char,
               example_end_char,
               example_answer):

    # Which question this window belongs to and its index among the windows.
    self.qid, self.question_part = qid, question_part
    # Tokenizer outputs for this window.
    self.input_ids, self.attention_mask = input_ids, attention_mask
    self.token_type_ids, self.offset_mappings = token_type_ids, offset_mappings
    self.max_context_dict = max_context_dict
    # Answer token positions inside the window.
    self.start, self.end = start, end
    self.is_answerable = is_answerable
    # Context text and the token range it occupies in the window.
    self.context, self.cleaned_context = context, cleaned_context
    self.context_start, self.context_end = context_start, context_end
    # Answer of the source example.
    self.example_start_char, self.example_end_char = example_start_char, example_end_char
    self.example_answer = example_answer

  def __repr__(self):
    lines = (
        'qid --> ' + str(self.qid),
        'quesion part --> ' + str(self.question_part),
        'answer part --> ' + str(self.start) + ' ' + str(self.end),
    )
    return ''.join(line + '\n' for line in lines)

# Examples

In [None]:
def make_examples(datas, data_type):
  """Convert a SQuAD-style JSON dict into pickled lists of QA_DATA examples.

  Every paragraph context gets the unanswerable marker appended (separated by
  a space). Answerable questions point at their first listed answer;
  questions flagged `is_impossible` point at the appended marker. Examples
  are written to `examples/<data_type>/<data_type>_examples_<i>.bin` in
  chunks of 5000.

  Args:
    datas: parsed dataset with a top-level 'data' list of articles.
    data_type: sub-directory / file prefix, e.g. 'train' or 'eval'.

  NOTE(review): questions with an empty 'answers' list are skipped before
  the `is_impossible` flag is looked at. In SQuAD-2.0-style files
  unanswerable questions usually carry an empty 'answers' list - confirm
  whether ParSQuAD does the same, because then the unanswerable branch below
  is never reached and no unanswerable example is produced.
  """
  unanswerable = 'غیرقابل‌پاسخ'
  each_file_size = 5000
  data_dir = f'examples/{data_type}/'
  os.makedirs(data_dir, exist_ok=True)

  def _flush(batch, index):
    # Write one chunk of examples to its numbered file.
    filename = f'{data_type}_examples_' + str(index) + '.bin'
    save_data(batch, os.path.join(data_dir, filename))

  examples = []
  example_file_index = 0

  # Wrapping the list itself (not enumerate) lets tqdm show the total.
  for data in tqdm(datas['data']):
    for par in data['paragraphs']:
      context = par['context'] + ' ' + unanswerable

      for q_num, qa in enumerate(par['qas']):
        answers = qa['answers']
        if len(answers) == 0:
          continue
        first_answer = answers[0]
        is_answerable = not qa['is_impossible']

        if is_answerable:
          answer_ = {
              'text': first_answer['text'],
              'start': first_answer['answer_start'],
              'end': first_answer['answer_start'] + len(first_answer['text'])
          }
        else:
          # Point at the marker appended above, i.e. the end of `context`,
          # even if the paragraph text itself happens to contain the marker.
          ans_start = len(context) - len(unanswerable)
          answer_ = {
              'text': unanswerable,
              'start': ans_start,
              'end': ans_start + len(unanswerable)
          }

        examples.append(QA_DATA(question=qa['question'],
                                context=context,
                                answer=answer_,
                                qid=qa['id'],
                                q_num=q_num,
                                answer_start=answer_['start'],
                                answer_end=answer_['end'],
                                is_answerable=is_answerable))

        if len(examples) % each_file_size == 0:
          _flush(examples, example_file_index)
          example_file_index += 1
          examples = []

  # Remaining examples that did not fill a whole chunk.
  if examples:
    _flush(examples, example_file_index)

# Features

In [None]:
def make_features(data_type):
  """Tokenize the pickled examples of `data_type` into Feature windows.

  For every file in `examples/<data_type>/` the examples are tokenized with
  the global `tokenizer` (question first, context second, windows of
  `model.config.max_position_embeddings` tokens with stride 128, only the
  context is truncated). Each window becomes one Feature carrying the token
  range of the context, the answer's token span inside the window, and a
  per-token "max context" score/mask. The features of input file number i
  are saved to `features/<data_type>/<data_type>_features_<i>.bin`.

  Windows in which the answer span is not fully located are dropped for
  'train'; for any other `data_type` they are kept with start = end = 0.

  Args:
    data_type: sub-directory / file prefix, e.g. 'train' or 'eval'.
  """
  data_dir = f'examples/{data_type}/'
  example_files = [os.path.join(data_dir, example_file) for example_file in os.listdir(data_dir)]
  features_dir = f'features/{data_type}/'
  os.makedirs(features_dir, exist_ok=True)

  for file_index, filename in enumerate(example_files):
    examples = load_data(filename)
    features_list = []
    for example in tqdm(examples, leave=False, position=0):
      example_features = []
      answer_start_char = example.cleaned_answer['start']
      answer_end_char = example.cleaned_answer['end']

      # tokenize current example (question, then context)
      text_tokens = tokenizer(
          example.question,
          example.cleaned_context,
          max_length=model.config.max_position_embeddings,
          padding='max_length',
          truncation='only_second',
          return_overflowing_tokens=True,
          return_offsets_mapping=True,
          stride=128)

      for idx in range(len(text_tokens['input_ids'])):
        chunk_offset_mapping = text_tokens['offset_mapping'][idx]
        max_context_dict = {}

        # find start and end of context: tokens with offset (0, 0) delimit it.
        # Defaults cover the case where no delimiter is found.
        context_start = 0
        context_end = len(chunk_offset_mapping) - 1

        # First (0, 0) offset after position 0; the context begins right after it.
        for token_idx, token in enumerate(chunk_offset_mapping[1:]):
          if token[0] == 0 and token[1] == 0:
            context_start = token_idx + 2
            break

        # Next (0, 0) offset; the context ends right before it.
        for token_idx, token in enumerate(chunk_offset_mapping[context_start:]):
          if token[0] == 0 and token[1] == 0:
            context_end = token_idx + context_start - 1
            break

        # Score of every context token in this window, keyed by its character offsets.
        for context_idx, data in enumerate(chunk_offset_mapping[context_start: context_end + 1]):
          max_context_dict[f'({data[0]},{data[1]})'] = min(context_idx, context_end - context_idx) + (context_end - context_start + 1) * .01

        # find the answer span; reset per window so nothing leaks from the previous one
        found_start = False
        found_end = False
        start = end = 0
        last_token = None
        for token_idx, token in enumerate(chunk_offset_mapping[context_start: context_end + 1]):
          if token[0] == answer_start_char and not found_start:
            found_start = True
            start = token_idx + context_start

          elif last_token and last_token[0] < answer_start_char and token[0] > answer_start_char:
            # The answer starts inside the previous token.
            found_start = True
            start = (token_idx - 1) + context_start

          if token[1] == answer_end_char and not found_end:
            found_end = True
            end = token_idx + context_start

          elif last_token and last_token[1] < answer_end_char and token[1] > answer_end_char:
            # The answer ends inside this token.
            found_end = True
            end = token_idx + context_start
          last_token = token

        assert not (found_start and found_end and end < start), 'start and end do not match'

        if not (found_start and found_end):
          # since there is no prediction we throw the example out (only when training)
          if data_type == 'train':
            continue
          start, end = 0, 0

        example_features.append(Feature(example.qid,
                                        idx,
                                        text_tokens['input_ids'][idx],
                                        text_tokens['attention_mask'][idx],
                                        text_tokens['token_type_ids'][idx],
                                        chunk_offset_mapping,
                                        max_context_dict,
                                        start,
                                        end,
                                        example.is_answerable,
                                        example.context,
                                        example.cleaned_context,
                                        context_start,
                                        context_end,
                                        example.answer_start,
                                        example.answer_end,
                                        example.answer))

      # create max context mask: a token is kept in a window only if no other
      # window of the same example scores it higher
      for feature_1 in example_features:
        max_context_mask = {}
        for key in list(feature_1.max_context_dict.keys()):
          max_context_mask[key] = True
          for feature_2 in example_features:
            if key in feature_2.max_context_dict:
              if feature_1.max_context_dict[key] < feature_2.max_context_dict[key]:
                max_context_mask[key] = False
        feature_1.max_context_mask = max_context_mask

        found_start = found_end = False
        start_mask = end_mask = 0
        # now compute span mask
        for key_idx, (key, value) in enumerate(feature_1.max_context_mask.items()):
          if key_idx == 0 and value:
            found_start = True
          elif value and not found_start:
            start_mask = key_idx
            found_start = True
          elif not value and found_start and not found_end:
            end_mask = key_idx
            found_end = True
          elif key_idx == len(feature_1.max_context_mask) - 1 and value and not found_end:
            end_mask = key_idx + 1
        # Use the window's own context offset rather than a leftover loop variable.
        feature_1.mask_span = [feature_1.context_start + start_mask, feature_1.context_start + end_mask]
      features_list.extend(example_features)

    filename = f'{data_type}_features_' + str(file_index) + '.bin'
    save_data(features_list, os.path.join(features_dir, filename))

In [None]:
# Build the on-disk caches: examples first (written under examples/<split>/),
# then features, which make_features reads back from those example files.
make_examples(train_data, 'train')
make_examples(eval_data, 'eval')
make_features('train')
make_features('eval')

442it [00:00, 672.54it/s]
35it [00:00, 2178.92it/s]


In [None]:
# Map question id -> answer text for the examples stored in the first eval
# example file only (eval_examples_0.bin).
A = load_data('examples/eval/eval_examples_0.bin')
eval_data_ans = dict()
for a in A:
  eval_data_ans[a.qid] = a.answer

In [None]:
eval_data_ans.keys()

dict_keys(['56ddde6b9a695914005b9628', '56ddde6b9a695914005b9629', '56ddde6b9a695914005b962a', '56dddf4066d3e219004dad5f', '56dddf4066d3e219004dad60', '56dddf4066d3e219004dad61', '56dde0379a695914005b9637', '56dde0ba66d3e219004dad75', '56dde0ba66d3e219004dad76', '56dde27d9a695914005b9651', '56dde27d9a695914005b9652', '56dde2fa66d3e219004dad9b', '56de0ffd4396321400ee258d', '56de0ffd4396321400ee258e', '56de10b44396321400ee2593', '56de10b44396321400ee2594', '56de10b44396321400ee2595', '56de148dcffd8e1900b4b5bd', '56de15104396321400ee25b7', '56de15104396321400ee25b8', '56de15104396321400ee25b9', '56de1563cffd8e1900b4b5c2', '56de1563cffd8e1900b4b5c4', '56de15dbcffd8e1900b4b5ca', '56de1645cffd8e1900b4b5d1', '56de16ca4396321400ee25c5', '56de16ca4396321400ee25c7', '56de16ca4396321400ee25c8', '56de1728cffd8e1900b4b5d7', '56de179dcffd8e1900b4b5da', '56de179dcffd8e1900b4b5db', '56de179dcffd8e1900b4b5dc', '56de17f9cffd8e1900b4b5e0', '56de17f9cffd8e1900b4b5e2', '56de17f9cffd8e1900b4b5e3', '56de3cd0

# DataLoader

In [None]:
class DataManager:
  """Serves batches of pickled features from the numbered files in a directory.

  Files are visited in the order of the integer found in the third
  '_'-separated field of their name; only one file is held in memory at a
  time.
  """

  def __init__(self, current_file, current_index, data_dir, batch_size, shuffle=True):
    def file_number(name):
      # '<split>_features_<n>.bin' -> n
      return int(name.split('_')[2].split('.')[0])

    ordered_names = sorted(os.listdir(data_dir), key=file_number)
    self.files = [os.path.join(data_dir, name) for name in ordered_names]
    self.shuffle = shuffle
    # Total number of features; every file is loaded once just to count it.
    self.data_len = sum(len(load_data(path)) for path in self.files)
    self.batch_size = batch_size
    self.reset_datamanager(current_file, current_index)

  def reset_datamanager(self, current_file_index, current_index):
    """Jump to position `current_index` inside file number `current_file_index`."""
    self.current_index = current_index
    self.current_file_index = current_file_index
    self.features = self.load_data_file(self.files[self.current_file_index])

  def load_data_file(self, filename):
    """Load one feature file, shuffling its content in place when enabled."""
    data = load_data(filename)
    if self.shuffle:
      random.shuffle(data)
    return data

  def next(self):
    """Return `(batch, finished)`; `finished` is True for the last batch of
    the last file, after which the manager is rewound to the first file."""
    batch = self.features[self.current_index:self.current_index + self.batch_size]
    self.temp = batch
    self.current_index += self.batch_size

    if self.current_index < len(self.features):
      return batch, False

    # Current file exhausted: move on to the next one, or wrap around.
    self.current_index = 0
    self.current_file_index += 1
    if self.current_file_index == len(self.files):
      self.reset_datamanager(current_file_index=0, current_index=0)
      return batch, True
    self.features = self.load_data_file(self.files[self.current_file_index])
    return batch, False

In [None]:
class DataLoader:
  """Iterator over feature batches read from `features/train/` or
  `features/eval/`, converted to tensors.

  Note: this class has the same name as `torch.utils.data.DataLoader`
  imported at the top of the notebook and replaces it from here on.
  """

  def __init__(self, current_file, current_index, batch_size, shuffle=True, training=True):
    split = 'train' if training else 'eval'
    self.batch_size = batch_size
    self.data_manager = DataManager(current_file, current_index, f'features/{split}/', batch_size, shuffle)

  def __iter__(self):
    self.stop_iteration = False
    return self

  def __len__(self):
    # Total number of features floor-divided by the batch size.
    return int(self.data_manager.data_len // self.batch_size)

  def reset_dataloader(self, current_file, current_index):
    """Reposition the underlying DataManager."""
    self.data_manager.reset_datamanager(current_file, current_index)

  def features_2_tensor(self, features):
    """Pack a list of features into a dict of tensors plus the raw list
    (under the 'features' key)."""
    def as_long(attr):
      return torch.LongTensor([getattr(feature, attr) for feature in features])

    def as_positions(attr):
      return torch.cat([torch.tensor([getattr(feature, attr)]) for feature in features]).view(-1)

    return {
        'input_ids': as_long('input_ids'),
        'attention_mask': as_long('attention_mask'),
        'token_type_ids': as_long('token_type_ids'),
        'start_positions': as_positions('start'),
        'end_positions': as_positions('end'),
        'features': features,
    }

  def __next__(self):
    # The batch returned together with finished=True is still delivered;
    # iteration stops on the call after it.
    if self.stop_iteration:
      raise StopIteration
    features, self.stop_iteration = self.data_manager.next()
    return self.features_2_tensor(features)

# Utils

In [None]:
# Load the first eval feature file into `A`, replacing whatever `A` held before.
A = load_data('features/eval/eval_features_0.bin')

In [None]:
# Raw model output for one feature window: its start/end logits and the Feature.
feature_output = namedtuple(
    'feature_output',
        ['start_logit', 'end_logit', 'feature'])

# One candidate answer span: token indices inside feature number
# `feature_index` of question `qid`, with the logits at those indices.
PrelimPrediction = namedtuple(
    "PrelimPrediction",
        ["feature_index", "start_index", "end_index", "start_logit", "end_logit", "qid"]
    )

# Not referenced by any code in this part of the notebook.
NbestPrediction = namedtuple(
    "NbestPrediction", ["text", "start_logit", "end_logit"]
)

# Final answer for one question id (built in EvalProcessOutput.get_predictions).
Answer = namedtuple(
    'Answer', ['qid', 'answer']
)

def to_numpy(tensor):
  """Return `tensor` as a NumPy array, detached from autograd and moved to CPU."""
  detached = tensor.detach()
  return detached.cpu().numpy()

class EvalProcessOutput:
  """Collects per-window logits during evaluation and turns them into one
  answer string per question id.

  Args:
    n_best_size: how many top start and top end indices are combined per window.
    answer_max_len: maximum answer length in tokens.
    answerability_threshold: a question is answered with a span when
      `null_score - best_span_score <= answerability_threshold`; otherwise
      the unanswerable marker is returned.

  Attributes:
    answers: question id -> predicted answer text (filled by extract_answers).
    dialogs_answers: question id -> list of Answer tuples (filled by get_predictions).
    examples_output: the collected feature_output tuples.
  """

  def __init__(self, n_best_size=4, answer_max_len=40, answerability_threshold=0.0):
    self.answers = defaultdict(list)
    self.examples_output = []
    self.n_best_size = n_best_size
    self.answer_max_len = answer_max_len
    self.answerability_threshold = answerability_threshold
    self.ps = []


  def process_feature_output(self, start_logits, end_logits, features):
    """Store the logits of one batch, one entry per feature window."""
    for start_logit, end_logit, feature in zip(start_logits, end_logits, features):
      self.examples_output.append(
          feature_output(start_logit, end_logit, feature)
      )

  def stack_features(self):
    """Group the stored outputs by the question id of their feature."""
    examples_list = defaultdict(list)
    for feature_out in self.examples_output:
      examples_list[feature_out.feature.qid].append(feature_out)
    return examples_list


  def process_output(self):
    """Compute `answers` and then `dialogs_answers` from the stored outputs."""
    self.extract_answers()
    self.get_predictions()

  def get_predictions(self):
    """Wrap every extracted answer in an Answer tuple, keyed by question id."""
    dialogs = defaultdict(list)
    for example_qid, answer in self.answers.items():
      dialogs[example_qid].append(Answer(example_qid, answer))
    self.dialogs_answers = dialogs


  def extract_answers(self):
    """Pick the best answer span (or the unanswerable marker) per question.

    The null score of a question is the smallest `start_logit[0] +
    end_logit[0]` over its windows. Candidate spans combine the n-best start
    and end indices of each window; the best-scoring span is compared with
    the null score to decide answerability.
    """
    def score(prediction):
      return prediction.start_logit + prediction.end_logit

    examples_list = self.stack_features()
    for example_qid, example in examples_list.items():
      null_score = np.inf
      null_feature_index = 0
      null_start_logit = null_end_logit = 0.0
      span_predictions = []
      self.example = example
      for feature_index, feature_out in enumerate(example):
        feature = feature_out.feature
        feature_null_score = feature_out.start_logit[0] + feature_out.end_logit[0]

        if feature_null_score < null_score:
          null_score = feature_null_score
          null_feature_index = feature_index
          null_start_logit = feature_out.start_logit[0]
          null_end_logit = feature_out.end_logit[0]

        start_indexes = self.get_best_indexes(feature_out.start_logit)
        end_indexes = self.get_best_indexes(feature_out.end_logit)

        for start_index in start_indexes:
          for end_index in end_indexes:
            if start_index > feature.context_end:
              continue
            if end_index < feature.context_start:
              continue
            if start_index < feature.mask_span[0]:
              continue
            if end_index - start_index + 1 > self.answer_max_len:
              continue
            # A span of exactly one token (end == start) is a valid answer.
            if end_index < start_index:
              continue

            span_predictions.append(
                PrelimPrediction(
                  feature_index=feature_index,
                  start_index=start_index,
                  end_index=end_index,
                  start_logit=feature_out.start_logit[start_index],
                  end_logit=feature_out.end_logit[end_index],
                  qid=example_qid))

      # the null prediction stands for "unanswerable"
      null_prediction = PrelimPrediction(
          feature_index=null_feature_index,
          start_index=0,
          end_index=0,
          start_logit=null_start_logit,
          end_logit=null_end_logit,
          qid=example_qid)

      span_predictions = sorted(span_predictions, key=score, reverse=True)
      # Kept for inspection: all candidates including the null one, best first.
      self.t = sorted(span_predictions + [null_prediction], key=score, reverse=True)

      # Compare the null score with the best *span*; comparing it with the
      # overall best would compare null with itself whenever null ranks first.
      best_pred = span_predictions[0] if span_predictions else None
      is_answerable = (best_pred is not None
                       and null_score - score(best_pred) <= self.answerability_threshold)
      if is_answerable:
        feature = example[best_pred.feature_index].feature
        start_char = feature.offset_mappings[best_pred.start_index][0]
        end_char = feature.offset_mappings[best_pred.end_index][1]
        # Offsets are (start, end) with an exclusive end, so no +1 here.
        answer = feature.cleaned_context[start_char: end_char]
      else:
        answer = 'غیرقابل‌پاسخ'

      self.answers[example_qid] = answer


  def get_best_indexes(self, logits):
    """Return the indices of the `n_best_size` largest logits, best first."""
    ranked = sorted(enumerate(logits), key=lambda x: x[1], reverse=True)
    return [index for index, _ in ranked[:self.n_best_size]]

# Model

In [None]:
class BertHAE(nn.Module):
  """Extractive QA model: an encoder followed by a linear start/end head.

  The head maps every token's hidden state to two scores, one for being the
  answer start and one for being the answer end.
  """

  def __init__(self, bert, device):
    """
    Args:
      bert: encoder module exposing `config.hidden_size` and returning an
        object with a `last_hidden_state` attribute when called.
      device: device the input tensors are moved to inside `forward`.
    """
    super(BertHAE, self).__init__()
    self.transformer = bert
    self.start_end_head = nn.Linear(self.transformer.config.hidden_size, 2)
    nn.init.normal_(self.start_end_head.weight, mean=.0, std=.02)
    self.device = device

  def forward(self, x):
    """Compute start and end logits for a batch.

    Args:
      x: mapping of keyword name -> tensor, forwarded to the encoder as
        keyword arguments.

    Returns:
      (start_logits, end_logits): the two halves of the head output with
      the trailing size-1 dimension squeezed away.
    """
    # Move inputs to the device this module was configured with (not a
    # module-level `device` global) and build a new dict so the caller's
    # batch is left untouched.
    inputs = {key: value.to(self.device) for key, value in x.items()}
    # transformer output
    transformer_output = self.transformer(**inputs)
    start_end_logits = self.start_end_head(transformer_output.last_hidden_state)
    start_logits, end_logits = start_end_logits.split(1, dim=-1)
    return start_logits.squeeze(-1), end_logits.squeeze(-1)

# Saving Settings

In [None]:
# Everything this notebook persists (checkpoints and logs) lives under this
# folder, addressed relative to the current working directory.
drive_prefix = 'drive/MyDrive/Parsbert_P_ParSQuAD/'
drive_checkpoint_dir = 'Checkpoint/'
drive_log_dir = 'Log/'
checkpoint_dir = os.path.join(drive_prefix, drive_checkpoint_dir)
log_dir = os.path.join(drive_prefix, drive_log_dir)

meta_log_file = os.path.join(drive_prefix, drive_log_dir, 'test.txt')
prediction_file_prefix = os.path.join(drive_prefix, drive_log_dir, 'prediction_')
loss_log_file = os.path.join(drive_prefix, drive_log_dir, 'loss.txt')
mean_f1_file = os.path.join(drive_prefix, drive_log_dir, 'mean_f1.txt')

# Create the folders and probe write/read access in ONE guarded step, so that
# any filesystem failure is reported as a drive-access problem. Previously the
# directory creation and the write probe ran outside the try block, which made
# the access check unreachable whenever the drive was actually unavailable.
try:
  if not os.path.exists(drive_prefix):
    # os.mkdir (not makedirs) on purpose: if the parent folder is missing the
    # call fails instead of silently creating a local look-alike tree.
    os.mkdir(drive_prefix)
    print('crated saved dir')
  os.makedirs(checkpoint_dir, exist_ok=True)
  os.makedirs(log_dir, exist_ok=True)

  # check if drive is accessible: create/truncate the probe file, then reopen it
  with open(meta_log_file, 'w') as f:
    pass
  with open(meta_log_file, 'r') as f:
    pass
except OSError as err:
  print('No Access to Drive')
  # Raise instead of exit(): the cell stops with a visible error and the
  # kernel (with everything already loaded) stays alive.
  raise RuntimeError(f'Cannot read/write under {drive_prefix!r}') from err

crated saved dir


In [None]:
# Copies one checkpoint file from the Drive root into HistConcat/Checkpoint/.
# NOTE(review): the recorded output of this cell shows the copy failing because
# the source file does not exist, and the destination is not the folder the
# checkpoint lookup reads (drive_prefix + drive_checkpoint_dir) — confirm
# whether this cell is still needed.
! cp drive/MyDrive/checkpoint_1_0_0 HistConcat/Checkpoint/

cp: cannot stat 'drive/MyDrive/checkpoint_1_0_0': No such file or directory


In [None]:
def _format_loss_line(loss_collection, epoch, step):
  """Build the progress line shared by `print_loss` and `save_loss`.

  Reads the notebook globals `epochs`, `train_dataloader` and
  `accumulation_steps`; the reported loss is the mean of `loss_collection`
  rounded to 4 decimals.
  """
  steps_per_epoch = int(len(train_dataloader) / accumulation_steps)
  mean_loss = round(sum(loss_collection) / len(loss_collection), 4)
  return f'EPOCH [{epoch + 1}/{epochs}] | STEP [{step}/{steps_per_epoch}] | Loss {mean_loss}'


def print_loss(loss_collection, epoch, step):
  """Print the mean of `loss_collection` with its epoch/step position."""
  print(_format_loss_line(loss_collection, epoch, step))


def save_loss(loss_collection, epoch, step):
  """Append the same progress line, newline-terminated, to `loss_log_file`."""
  with open(loss_log_file, 'a') as log:
    print(_format_loss_line(loss_collection, epoch, step), file=log)



# check the checkpoints drive
def _checkpoint_sort_key(filename):
  """Return (epoch, file, index) parsed from `<prefix>_<epoch>_<file>_<index>`.

  Returns None when the name does not follow that pattern, so unrelated
  entries in the folder (e.g. `.ipynb_checkpoints`) are skipped instead of
  crashing the sort.
  """
  fields = filename.split('_')
  if len(fields) < 4:
    return None
  try:
    return tuple(int(field) for field in fields[1:4])
  except ValueError:
    return None


checkpoint_files = os.listdir(os.path.join(drive_prefix, drive_checkpoint_dir))
_usable_checkpoints = [name for name in checkpoint_files if _checkpoint_sort_key(name) is not None]
_ignored_entries = [name for name in checkpoint_files if _checkpoint_sort_key(name) is None]
if _ignored_entries:
  print('ignoring non-checkpoint entries ', _ignored_entries)

checkpoint_available = len(_usable_checkpoints) > 0
if not checkpoint_available:
  print('No checkpoint found, training from begining')

# Defined in both cases so later cells never hit a NameError.
current_checkpoint = None
if checkpoint_available:
  # The latest checkpoint is the one with the largest (epoch, file, index).
  current_checkpoint = sorted(_usable_checkpoints, key=_checkpoint_sort_key)[-1]
  print('checkpoints to load ', current_checkpoint)
  current_checkpoint = os.path.join(drive_prefix, drive_checkpoint_dir, current_checkpoint)


def save_prediction(epoch, step, prediction_log):
  """Append one block of predictions to `prediction.txt` in the log folder.

  The block is an epoch/step banner line, then `prediction_log`, then two
  newlines that separate it from the next block.
  """
  target = os.path.join(drive_prefix, drive_log_dir, 'prediction.txt')
  banner = f'--------- EPOCH {epoch} STEP {step} ---------\n'
  with open(target, 'a') as out:
    out.writelines([banner, prediction_log, '\n', '\n'])

def save_checkpoint(epoch, current_file, current_index):
  """Write the current training state to the checkpoint folder.

  The file is named `checkpoint_<epoch>_<current_file>_<current_index>`.
  Besides the arguments, it stores the notebook globals `train_step` and the
  state dicts of `optimizer`, `scheduler` and `berthae`.
  """
  target = os.path.join(
      drive_prefix,
      drive_checkpoint_dir,
      f'checkpoint_{epoch}_{current_file}_{current_index}',
  )
  training_state = dict(
      epoch=epoch,
      step=train_step,
      optimizer_dict=optimizer.state_dict(),
      scheduler_dict=scheduler.state_dict(),
      model_dict=berthae.state_dict(),
      train_current_file=current_file,
      train_current_index=current_index,
  )
  torch.save(training_state, target)

def load_checkpoint():
    """Read the checkpoint at the global `current_checkpoint` path.

    Only the saved state is returned; applying it to the model, optimizer
    and scheduler is left to the caller.

    Returns:
      Tuple `(epoch, step, optimizer_dict, scheduler_dict, model_dict,
      train_current_file, train_current_index)`.
    """
    # map_location='cpu' lets a checkpoint written on a GPU runtime be
    # restored on a CPU-only one; the caller's load_state_dict() calls then
    # copy the values onto whatever device the live objects use.
    checkpoint_config = torch.load(current_checkpoint, map_location='cpu')
    ordered_keys = (
        'epoch',
        'step',
        'optimizer_dict',
        'scheduler_dict',
        'model_dict',
        'train_current_file',
        'train_current_index',
    )
    return tuple(checkpoint_config[key] for key in ordered_keys)

# Train loop

In [None]:
# ---- Hyper-parameters ----
# NOTE(review): with epochs = 0, `optimization_steps` below is 0 and
# `range(start_epoch, epochs)` in the training loop is empty, so no training
# step runs — confirm this is intentional.
epochs = 0
lr = 3e-5
beta_1 = .9
beta_2 = .999
eps = 1e-6
batch_size = 10
# Number of batches whose gradients are summed before one optimizer step.
accumulation_steps = 1
accumulation_counter = 0

# ---- Model and loss ----
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
# The pretrained encoder is deep-copied so `model` itself is left unchanged.
berthae = BertHAE(deepcopy(model), device).to(device)
loss_fn = nn.CrossEntropyLoss().to(device)

# ---- Data and bookkeeping ----
# Per-optimizer-step losses collected since the last log line.
loss_collection = []
# NOTE(review): these keyword arguments (current_file, current_index, training)
# are not torch.utils.data.DataLoader parameters; presumably a notebook-defined
# DataLoader shadows the torch import — TODO confirm.
train_dataloader = DataLoader(current_file=0, current_index=0, batch_size=batch_size, shuffle=True, training=True)
eval_dataloader = DataLoader(current_file=0, current_index=0, batch_size=1, shuffle=False, training=False)
# A loss line is emitted once this many losses have been collected.
each_step_log = 100
# Resume position; overwritten below when a checkpoint is available.
start_epoch = 0
start_step = 0
current_file = 0
current_index = 0


# ---- Optimizer and learning-rate schedule ----
optimization_steps = int(epochs * len(train_dataloader) / accumulation_steps)
# The first 10% of the optimizer steps are warm-up steps.
warmup_ratio = .1
warmup_steps = int(optimization_steps * warmup_ratio)

optimizer = AdamW(berthae.parameters(), lr=lr, betas=(beta_1,beta_2), eps=eps)
scheduler = get_linear_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=warmup_steps,
    num_training_steps=optimization_steps)

# load checkpoint if available
if checkpoint_available:
  print('loading checkpoint')
  # NOTE(review): start_step is restored here but not read anywhere in this cell.
  start_epoch, start_step, optimizer_dict, scheduler_dict, berthae_dict, current_file, current_index = load_checkpoint()
  # load state dicts
  berthae.load_state_dict(berthae_dict)
  optimizer.load_state_dict(optimizer_dict)
  scheduler.load_state_dict(scheduler_dict)

# Remember which data file training (re)starts from, and position the
# training dataloader at that file/index.
current_file_index_ = current_file
train_dataloader.reset_dataloader(current_file, current_index)
berthae.train()
# Main loop: one pass over the training dataloader per epoch, followed by a
# checkpoint and a full evaluation pass.
for epoch in range(start_epoch, epochs):
  train_step = 0  # batches seen in this epoch (stored in checkpoints as 'step')
  acc_loss = 0    # loss summed over the current accumulation window
  log_step = 0    # optimizer steps taken in this epoch

  for data in train_dataloader:
    # Detect that the dataloader's data manager moved on to another file.
    if train_dataloader.data_manager.current_file_index != current_file_index_:
      current_file_index_ = train_dataloader.data_manager.current_file_index
      print(current_file_index_)
      print('-------------')
      # Checkpoint on every third file, except the file this run started from.
      if int(current_file_index_) % 3 == 0 and current_file_index_ != current_file:
        save_checkpoint(epoch, current_file_index_, 0)
    # Remove the labels and the 'features' entry so that `data` only holds
    # what is forwarded to the model as keyword arguments.
    start_positions = data.pop('start_positions').to(device)
    end_positions = data.pop('end_positions').to(device)
    features = data.pop('features')
    start_logits, end_logits = berthae(data)
    # Mean of the start-position and end-position cross-entropy losses.
    loss = (loss_fn(start_logits, start_positions) + loss_fn(end_logits, end_positions)) / 2
    # Scale so the gradients summed over one accumulation window average out.
    loss = loss / accumulation_steps
    acc_loss += loss.item()
    loss.backward()

    # Emit one log line whenever `each_step_log` losses have been collected.
    # NOTE(review): this check runs before the current window's loss is
    # appended, so the logged average covers the previous optimizer steps while
    # the label is log_step + 1; losses left over at the end of an epoch carry
    # into the next one — confirm this is intended.
    if len(loss_collection) % each_step_log == 0 and len(loss_collection) != 0:
      print_loss(loss_collection, epoch, log_step + 1)
      save_loss(loss_collection, epoch, log_step + 1)
      loss_collection = []


    accumulation_counter += 1
    # End of an accumulation window: record its loss and update the weights.
    if accumulation_counter % accumulation_steps == 0:
      loss_collection.append(acc_loss)
      acc_loss = 0
      log_step += 1
      optimizer.step()
      scheduler.step()
      optimizer.zero_grad()
      torch.cuda.empty_cache()
      accumulation_counter = 0

    train_step += 1

  # End of epoch: the checkpoint is labelled with the NEXT epoch at file 0,
  # index 0, so a resumed run starts there.
  save_checkpoint(epoch + 1, 0, 0)
  berthae.eval()
  print('-------------------- Evaluation --------------------')
  eval_p = EvalProcessOutput()
  with torch.no_grad():
    for step, data in enumerate(eval_dataloader):
      # Popped only to keep them out of the model's keyword arguments.
      start_positions = data.pop('start_positions')
      end_positions = data.pop('end_positions')
      features = data.pop('features')
      start_logits, end_logits = berthae(data)
      eval_p.process_feature_output(to_numpy(start_logits),
                                    to_numpy(end_logits),
                                    features)

  eval_p.process_output()
  run_eval()
  # Back to training mode for the next epoch.
  berthae.train()

In [None]:
  # Stand-alone evaluation pass. These statements repeat, line for line, the
  # evaluation that closes each epoch of the training loop, so the model can
  # be evaluated without training.
  # NOTE(review): the whole cell is indented one level; this presumably relies
  # on the notebook stripping the common leading indentation — confirm.
  berthae.eval()
  print('-------------------- Evaluation --------------------')
  eval_p = EvalProcessOutput()
  with torch.no_grad():
    for step, data in enumerate(eval_dataloader):
      # Popped only to keep them out of the model's keyword arguments.
      start_positions = data.pop('start_positions')
      end_positions = data.pop('end_positions')
      features = data.pop('features')
      start_logits, end_logits = berthae(data)
      eval_p.process_feature_output(to_numpy(start_logits),
                                    to_numpy(end_logits),
                                    features)

  eval_p.process_output()
  run_eval()
  # Restore training mode after evaluating.
  berthae.train()