In [1]:
%%capture
! pip install hazm
! pip install accelerate
! pip install bitsandbytes
! pip install gdown
from hazm import *
import gdown
import pickle as pk
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
import json
import pickle
import shutil
import random
import unicodedata
from tqdm import tqdm
from copy import deepcopy
import os
from collections import defaultdict, namedtuple
import transformers
from transformers.optimization import get_linear_schedule_with_warmup
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW, Adam, RMSprop
import numpy as np

def make_dir(dir_name):
  if not os.path.exists(dir_name):
    os.mkdir(dir_name)

def read_file(filename):
  with open(filename, 'r') as f:
    return json.load(f)

def load_data(filename):
  with open(filename, 'rb') as f:
    x = pickle.load(f)
  return x

def save_data(data, filename):
    with open(filename, "wb") as f:
        pickle.dump(data, f)

def send_2_device(tokens):
  new_tokens = dict()
  for k, v in tokens.items():
    new_tokens[k] = v.to(device)
  return new_tokens

def to_numpy(tensor):
  return tensor.detach().cpu().numpy()

def get_prob(output):
  scores = [score.view(-1) for score in output.scores]
  sequences = output.sequences.view(-1)
  log_probs = []
  for score, seq in zip(scores, sequences[1: ]):
    log_prob = F.log_softmax(score, dim=0)
    log_probs.append(log_prob[seq])
  log_probs = torch.stack(log_probs).mean().cpu().item()
  return log_probs

gdown.download(id="1tzaAHUkIkGzhbZpCIKmOYYQKvvqoRfsO")
gdown.download(id="16wPRHP2AC5WI2m7Y_fEWEOUYl7ynMKrb")
gdown.download(id="1qYU3601tCOI-MTQut8Nb7mSZMDLXuASY")

make_dir('examples')
make_dir('features')
make_dir('examples/train')
make_dir('examples/eval')
make_dir('examples/test')
make_dir('features/train')
make_dir('features/eval')
make_dir('features/test')


train_data = load_data('PCoQA_Train.pk')
eval_data = load_data('PCoQA_Eval.pk')
test_data = load_data('PCoQA_Test.pk')

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Eval Code

In [2]:
"""# Official Evaluation Code"""

import json, string, re
from collections import Counter, defaultdict


def is_overlapping(x1, x2, y1, y2):
  return max(x1, y1) <= min(x2, y2)

def normalize_answer(s):
  """Lower text and remove punctuation, articles and extra whitespace."""
  def remove_articles(text):
    return re.sub(r'\b(a|an|the)\b', ' ', text)
  def white_space_fix(text):
    return ' '.join(text.split())
  def remove_punc(text):
    exclude = set(string.punctuation)
    return ''.join(ch for ch in text if ch not in exclude)
  def lower(text):
    return text.lower()
  return white_space_fix(remove_articles(remove_punc(lower(s))))

def f1_score(prediction, ground_truth):
  prediction_tokens = normalize_answer(prediction).split()
  ground_truth_tokens = normalize_answer(ground_truth).split()
  common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
  num_same = sum(common.values())
  if num_same == 0:
    return 0
  precision = 1.0 * num_same / len(prediction_tokens)
  recall = 1.0 * num_same / len(ground_truth_tokens)
  f1 = (2 * precision * recall) / (precision + recall)
  return f1

def compute_span_overlap(pred_span, gt_span, text):
  if gt_span == 'غیرقابل‌پاسخ':
    if pred_span == 'غیرقابل‌پاسخ':
      return 'Exact match', 1.0
    return 'No overlap', 0.
  fscore = f1_score(pred_span, gt_span)
  pred_start = text.find(pred_span)
  gt_start = text.find(gt_span)

  if pred_start == -1 or gt_start == -1:
    return 'Span indexing error', fscore

  pred_end = pred_start + len(pred_span)
  gt_end = gt_start + len(gt_span)

  fscore = f1_score(pred_span, gt_span)
  overlap = is_overlapping(pred_start, pred_end, gt_start, gt_end)

  if exact_match_score(pred_span, gt_span):
    return 'Exact match', fscore
  if overlap:
    return 'Partial overlap', fscore
  else:
    return 'No overlap', fscore

def exact_match_score(prediction, ground_truth):
  return (normalize_answer(prediction) == normalize_answer(ground_truth))

def display_counter(title, c, c2=None):
  print(title)
  for key, _ in c.most_common():
    if c2:
      print('%s: %d / %d, %.1f%%, F1: %.1f' % (
        key, c[key], sum(c.values()), c[key] * 100. / sum(c.values()), sum(c2[key]) * 100. / len(c2[key])))
    else:
      print('%s: %d / %d, %.1f%%' % (key, c[key], sum(c.values()), c[key] * 100. / sum(c.values())))

def leave_one_out_max(prediction, ground_truths, article):
  if len(ground_truths) == 1:
    return metric_max_over_ground_truths(prediction, ground_truths, article)[1]
  else:
    t_f1 = []
    # leave out one ref every time
    for i in range(len(ground_truths)):
      idxes = list(range(len(ground_truths)))
      idxes.pop(i)
      refs = [ground_truths[z] for z in idxes]
      t_f1.append(metric_max_over_ground_truths(prediction, refs, article)[1])
  return 1.0 * sum(t_f1) / len(t_f1)


def metric_max_over_ground_truths(prediction, ground_truths, article):
  scores_for_ground_truths = []
  for ground_truth in ground_truths:
    score = compute_span_overlap(prediction, ground_truth, article)
    scores_for_ground_truths.append(score)
  return max(scores_for_ground_truths, key=lambda x: x[1])


def handle_cannot(refs):
  num_cannot = 0
  num_spans = 0
  for ref in refs:
    if ref == 'غیرقابل‌پاسخ':
      num_cannot += 1
    else:
      num_spans += 1
  if num_cannot >= num_spans:
    refs = ['CANNOTANSWER']
  else:
    refs = [x for x in refs if x != 'غیرقابل‌پاسخ']
  return refs


def leave_one_out(refs):
  if len(refs) == 1:
    return 1.
  splits = []
  for r in refs:
    splits.append(r.split())
  t_f1 = 0.0
  for i in range(len(refs)):
    m_f1 = 0
    for j in range(len(refs)):
      if i == j:
        continue
      f1_ij = f1_score(refs[i], refs[j])
      if f1_ij > m_f1:
        m_f1 = f1_ij
    t_f1 += m_f1
  return t_f1 / len(refs)






def eval_fn(val_results, model_results, verbose):
  span_overlap_stats = Counter()
  sentence_overlap = 0.
  para_overlap = 0.
  total_qs = 0.
  f1_stats = defaultdict(list)
  unfiltered_f1s = []
  total_dials = 0.
  unanswerables = []
  for p in val_results:
    for par in p['paragraphs']:
      did = par['id']
      qa_list = par['qas']
      good_dial = 1.
      for qa in qa_list:
        q_idx = qa['id']
        val_spans = [anss['text'] for anss in qa['answers']]
        val_spans = handle_cannot(val_spans)
        hf1 = leave_one_out(val_spans)

        if did not in model_results or q_idx not in model_results[did]:
          # print(did, q_idx, 'no prediction for this dialogue id')
          good_dial = 0
          f1_stats['NO ANSWER'].append(0.0)
          if val_spans == ['غیرقابل‌پاسخ']:
            unanswerables.append(0.0)
          total_qs += 1
          unfiltered_f1s.append(0.0)
          if hf1 >= .4:
            human_f1.append(hf1)
          continue

        pred_span, pred_yesno, pred_followup = model_results[did][q_idx]

        max_overlap, _ = metric_max_over_ground_truths( \
          pred_span, val_spans, par['context'])
        max_f1 = leave_one_out_max( \
          pred_span, val_spans, par['context'])
        unfiltered_f1s.append(max_f1)

        # dont eval on low agreement instances
        if hf1 < .4:
          continue

        human_f1.append(hf1)

        if val_spans == ['غیرقابل‌پاسخ']:
          unanswerables.append(max_f1)
        if verbose:
          print("-" * 20)
          print(pred_span)
          print(val_spans)
          print(max_f1)
          print("-" * 20)
        if max_f1 >= hf1:
          HEQ += 1.
        else:
          good_dial = 0.
        span_overlap_stats[max_overlap] += 1
        f1_stats[max_overlap].append(max_f1)
        total_qs += 1.
      DHEQ += good_dial
      total_dials += 1


  DHEQ_score = 100.0 * DHEQ / total_dials
  HEQ_score = 100.0 * HEQ / total_qs
  all_f1s = sum(f1_stats.values(), [])
  overall_f1 = 100.0 * sum(all_f1s) / len(all_f1s)
  unfiltered_f1 = 100.0 * sum(unfiltered_f1s) / len(unfiltered_f1s)
  unanswerable_score = (100.0 * sum(unanswerables) / len(unanswerables))
  metric_json = {"unfiltered_f1": unfiltered_f1, "f1": overall_f1, "HEQ": HEQ_score, "DHEQ": DHEQ_score, "yes/no": yesno_score, "followup": followup_score, "unanswerable_acc": unanswerable_score}
  if verbose:
    print("=======================")
    display_counter('Overlap Stats', span_overlap_stats, f1_stats)
  print("=======================")
  print('Overall F1: %.1f' % overall_f1)
  with open('val_report.txt', 'a') as f:
    f.write('Overall F1: %.1f' % overall_f1)

  print('Unfiltered F1 ({0:d} questions): {1:.1f}'.format(len(unfiltered_f1s), unfiltered_f1))
  print('Accuracy On Unanswerable Questions: {0:.1f} %% ({1:d} questions)'.format(unanswerable_score, len(unanswerables)))
  print('Human F1: %.1f' % (100.0 * sum(human_f1) / len(human_f1)))
  print('Model F1 >= Human F1 (Questions): %d / %d, %.1f%%' % (HEQ, total_qs, 100.0 * HEQ / total_qs))
  print('Model F1 >= Human F1 (Dialogs): %d / %d, %.1f%%' % (DHEQ, total_dials, 100.0 * DHEQ / total_dials))
  print("=======================")
  output_string = 'Overall F1: %.1f\n' % overall_f1
  output_string += 'Yes/No Accuracy : %.1f\n' % yesno_score
  output_string += 'Followup Accuracy : %.1f\n' % followup_score
  output_string += 'Unfiltered F1 ({0:d} questions): {1:.1f}\n'.format(len(unfiltered_f1s), unfiltered_f1)
  output_string += 'Accuracy On Unanswerable Questions: {0:.1f} %% ({1:d} questions)\n'.format(unanswerable_score, len(unanswerables))
  output_string += 'Human F1: %.1f\n' % (100.0 * sum(human_f1) / len(human_f1))
  output_string += 'Model F1 >= Human F1 (Questions): %d / %d, %.1f%%\n' % (HEQ, total_qs, 100.0 * HEQ / total_qs)
  output_string += 'Model F1 >= Human F1 (Dialogs): %d / %d, %.1f%%' % (DHEQ, total_dials, 100.0 * DHEQ / total_dials)

  # save_prediction(epoch, train_step, output_string)

  return metric_json

def run_eval(mode, predicted_obj):
  new_eval_data = dict()
  mean_f1s_ = [[] for _ in range(30)]
  if mode == 'eval':
    main_data = eval_data
  elif mode == 'test':
    main_data = test_data

  for data in main_data:
    new_eval_data[str(data['id'])] = data


  res = {
      'question': [],
      'pred': [],
      'orig': [],
      'f1': [],
      'heq-q': []
  }

  f1s = []
  dialog_f1s = []
  dialog_hfs = []
  unans_score = []
  heq_q = []
  heq_d = []
  heq_m = []
  EM = []

  dialogs_f1s = dict()
  dialogs_hfs = dict()

  results = []
  for q_idx, model_answer in predicted_obj.answers.items():
    d_id, q_num = q_idx.split('#')[0], int(q_idx.split('#')[1])
    d_id, q_num = q_idx.split('#')[0], int(q_idx.split('#')[1])

    if d_id not in dialogs_f1s.keys():
      dialogs_f1s[d_id] = []
      dialogs_hfs[d_id] = []

    qa = new_eval_data[d_id]['qas'][q_num]
    answers_num = len(qa['answers'])
    orig_answers = [qa['answers'][qidx]['text'] for qidx in range(answers_num)]
    if 'غیرقابل‌پاسخ' in orig_answers:
      if model_answer.startswith('غیرقابل'):
        unans_score.append(1.0)
      else:
        unans_score.append(0.0)

    res['question'].append(qa['question'])
    res['pred'].append(model_answer)
    res['orig'].append(orig_answers)

    context = new_eval_data[d_id]['article']
    hf = qa['hf']
    f1s_ = [compute_span_overlap(model_answer, orig_answer, context)[1] for orig_answer in orig_answers]
    max_f1 = max(f1s_)
    dialog_hfs.append(hf)
    dialog_f1s.append(max_f1)
    f1s.append(max_f1)
    mean_f1s_[q_num].append(max_f1)

    if int(max_f1) == 1:
      EM.append(1.)
    else:
      EM.append(0.)

    if max_f1 >= hf:
      heq_q.append(1)
    else:
      heq_q.append(0)

    res['f1'].append(max_f1)
    res['heq-q'].append(heq_q)

    dialogs_f1s[d_id].append(max_f1)
    dialogs_hfs[d_id].append(hf)


  for key in dialogs_f1s.keys():
    dialog_f1s = dialogs_f1s[key]
    dialog_hfs = dialogs_hfs[key]

    heq_d.append(all(x >= y for x, y in zip(dialog_f1s, dialog_hfs)))
    heq_m.append(sum(dialog_f1s) >= sum(dialog_hfs))



  f1_score_ = sum(f1s) / len(f1s)
  heq_q_score_ = sum(heq_q) / len(heq_q)
  heq_m_score_ = sum(heq_m) / len(heq_m)
  heq_d_score_ = sum(heq_d) / len(heq_d)
  unans_score_ = sum(unans_score) / len(unans_score)
  EM_score_ = sum(EM) / len(EM)


  mean_f1s = [sum(mean_f1) / (1e-8 + len(mean_f1)) for mean_f1 in mean_f1s_]
  print(f'f1 score is {f1_score_}')
  print(f'HEQ-Q score is {heq_q_score_}')
  print(f'HEQ-M score is {heq_m_score_}')
  print(f'HEQ-D score is {heq_d_score_}')
  print(f'EM score is {EM_score_}')
  print(f'Unanswerable score is {unans_score_}')
  print(mean_f1s)
  return f1_score_, heq_q_score_, heq_d_score_, mean_f1s, res

# MT5



In [3]:
# %%capture
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
tokenizer = AutoTokenizer.from_pretrained("google/mt5-base")
model = AutoModelForSeq2SeqLM.from_pretrained("google/mt5-base")

# MT0

In [4]:
%%capture
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
tokenizer = AutoTokenizer.from_pretrained("bigscience/mt0-base")
model = AutoModelForSeq2SeqLM.from_pretrained("bigscience/mt0-base")

# Rest

In [5]:
class CQA_DATA:

  def __init__(self,
               question,
               context,
               history,
               answer,
               qid,
               q_num,
               answer_start,
               answer_end,
               is_answerable):

    self.question = question
    self.context = context
    self.answer = answer
    self.history = history
    self.qid = qid
    self.q_num = q_num
    self.answer_start = answer_start
    self.answer_end = answer_end
    self.is_answerable = is_answerable
    self.cleaned_context = None
    self.cleaned_answer = {
        'text': self.answer,
        'start': self.answer_start,
        'end': self.answer_end
    }
    self.cleaned_context = context

  def __repr__(self):
    repr = ''
    repr += 'context -> ' + self.context[:100] + '\n'
    repr += 'question ->' + self.question + '\n'
    repr += 'question id ->' + str(self.qid) + '\n'
    repr += 'turn_number ->' + str(self.q_num) + '\n'
    repr += 'answer ->' + self.answer + '\n'
    return repr

class Feature:

  def __init__(self,
               qid,
               question_part,
               input_ids,
               attention_mask,
               feature_answerability,
               offset_mappings,
               max_context_dict,
               label_tokens,
               is_answerable,
               context,
               cleaned_context,
               context_start,
               context_end,
               example_start_char,
               example_end_char,
               example_answer):

    self.qid = qid
    self.question_part = question_part
    self.input_ids = input_ids
    self.attention_mask = attention_mask
    self.feature_answerability = feature_answerability
    self.offset_mappings = offset_mappings
    self.max_context_dict = max_context_dict
    self.label_tokens = label_tokens
    self.is_answerable = is_answerable
    self.context = context
    self.cleaned_context = cleaned_context
    self.context_start = context_start
    self.context_end = context_end
    self.example_start_char = example_start_char
    self.example_end_char = example_end_char
    self.example_answer = example_answer

  def __repr__(self):
    repr = ''
    repr += 'qid --> ' + str(self.qid) + '\n'
    repr += 'quesion part --> ' + str(self.question_part) + '\n'
    repr += 'answer --> ' + self.example_answer + '\n'
    return repr

"""# Examples"""

def make_examples(data, data_type, num_sample, clean_samples=True):
  examples = []
  each_file_size = 1000
  example_file_index = 0
  data_dir = f'examples/{data_type}/'

  for dialog_num, dialog in enumerate(tqdm(data[:num_sample ], leave=False, position=0)):
    dialog_history = []
    dialog_container = []
    dialog_id = dialog['id']
    title = dialog['title']
    context = dialog['article']
    dialog_len = len(dialog['qas'])

    for q_num, qa in enumerate(dialog['qas']):
      history = []
      question = qa['rewritten_question']
      answer = qa['answers'][0]
      is_answerable = False if answer['text'] == 'غیرقابل‌پاسخ' else True

      if not is_answerable:
        answer_start = 0
        answer_end = 0

      if not q_num == 0:
        history = deepcopy(dialog_history)

      qid = f'{dialog_id}#{q_num}'
      cqa_example = CQA_DATA(question=question,
                             context=context[ :-13],
                             history=history,
                             answer=answer['text'],
                             qid=qid,
                             q_num=q_num,
                             answer_start=answer['start'],
                             answer_end=answer['end'],
                             is_answerable=is_answerable)

      examples.append(cqa_example)
      dialog_history.append([cqa_example.question, cqa_example.answer])

    if (dialog_num + 1) % each_file_size == 0:
      filename = f'{data_type}_examples_' + str(example_file_index) + '.bin'
      save_data(examples, os.path.join(data_dir, filename))
      example_file_index += 1
      examples = []

  if examples != []:
    filename = f'{data_type}_examples_' + str(example_file_index) + '.bin'
    save_data(examples, os.path.join(data_dir, filename))

In [6]:
def make_features(data_type):
  data_dir = f'examples/{data_type}/'
  example_files = os.listdir(data_dir)
  example_files = [os.path.join(data_dir, example_file) for example_file in example_files]
  features_list = []
  features_dir = f'features/{data_type}/'
  current_file_index = 0
  max_history_to_consider = 5
  offset_ = 1

  for file_index, filename in enumerate(example_files):
    examples = load_data(filename)
    for example in examples:
      example_features = []
      concatenated_question = []

      # concat history
      for hist in example.history[-max_history_to_consider:]:
        concatenated_question.append(hist[0])

      # append current question to concatenated question
      concatenated_question.append(example.question)

      # make string out of concatenated question
      concatenated_question = ' '.join(concatenated_question)

      text_tokens = tokenizer(
          concatenated_question,
          example.cleaned_context,
          max_length=512,
          padding='max_length',
          truncation='only_second',
          return_overflowing_tokens=True,
          return_offsets_mapping=True,
          stride=128)



      # find start and end of context
      for idx in range(len(text_tokens['input_ids'])):
        found_start = False
        found_end = False
        context_start = 0
        context_end = 511
        max_context_dict = {}

        for token_idx, token in enumerate(text_tokens['offset_mapping'][idx]):
          if token[0] == 0 and token[1] == 0:
            context_start = token_idx + offset_
            break

        for token_idx, token in enumerate(text_tokens['offset_mapping'][idx][context_start:]):
          if token[0] == 0 and token[1] == 0:
            context_end = token_idx + context_start - 1
            break

        chunk_offset_mapping = text_tokens['offset_mapping'][idx]
        for context_idx, data in enumerate(chunk_offset_mapping[context_start: context_end + 1]):
          max_context_dict[f'({data[0]},{data[1]})'] = min(context_idx, context_end - context_idx) + (context_end - context_start + 1) * .01

        # find and mark current question answer
        last_token = None
        for token_idx, token in enumerate(chunk_offset_mapping[context_start: context_end + 1]):
          if token[0] == example.cleaned_answer['start'] and not found_start:
            found_start = True
            start = token_idx + context_start
          elif last_token and last_token[0] < example.cleaned_answer['start'] and token[0] > example.cleaned_answer['start']:
            found_start = True
            start = (token_idx - 1) + context_start
          if token[1] == example.cleaned_answer['end'] and not found_end:
            found_end = True
            end = token_idx + context_start
          elif last_token and last_token[1] < example.cleaned_answer['end'] and token[1] > example.cleaned_answer['end'] and last_token:
            found_end = True
            end = token_idx + context_start
          last_token = token


        # add feature to features list
        if end < start and found_start and found_end:
          assert False, 'start and end do not match'

        elif ((not found_start) or (not found_end)) and data_type != 'train':
          start, end = 0, 0

        if (not found_start) or (not found_end):
          answer = 'غیرقابل‌پاسخ'
        else:
          answer = example.answer

        feature_answerability = False if answer == 'غیرقابل‌پاسخ' else True

        label_tokens = tokenizer(
          example.answer,
          truncation=True)['input_ids']

        feature_answerability = torch.LongTensor([feature_answerability])

        # print(answerability)
        # print(answer)
        # print('-----------')

        example_features.append(Feature(example.qid,
                                          idx,
                                          text_tokens['input_ids'][idx],
                                          text_tokens['attention_mask'][idx],
                                          feature_answerability,
                                          text_tokens['offset_mapping'][idx],
                                          max_context_dict,
                                          label_tokens,
                                          example.is_answerable,
                                          example.context,
                                          example.cleaned_context,
                                          context_start,
                                          context_end,
                                          example.answer_start,
                                          example.answer_end,
                                          example.answer))
      features_list.extend(example_features)
  filename = f'{data_type}_features_' + str(file_index) + '.bin'
  print(len(features_list))
  save_data(features_list, os.path.join(features_dir, filename))

make_examples(train_data, 'train', 1000000)
make_examples(eval_data, 'eval', 1000000)
make_examples(test_data, 'test', 1000000)
make_features('train')
make_features('eval')
make_features('test')

                                                   

18066
3675
3687


In [7]:
class DataManager:

  def __init__(self, current_file, current_index, data_dir, batch_size, shuffle=True):
    self.files = sorted(os.listdir(data_dir), key=lambda x: int(x.split('_')[2].split('.')[0]))
    self.files = list(map(lambda x: os.path.join(data_dir, x), self.files))
    self.shuffle = shuffle
    self.data_len = 0
    for filename in self.files:
      self.data_len += len(load_data(filename))
    self.batch_size = batch_size
    self.reset_datamanager(current_file, current_index)

  def reset_datamanager(self, current_file_index, current_index):
    self.current_index = current_index
    self.current_file_index = current_file_index
    self.features = self.load_data_file(self.files[self.current_file_index])

  def load_data_file(self, filename):
    if self.shuffle:
      data = load_data(filename)
      random.shuffle(data)
      return data
    else:
      return load_data(filename)

  def next(self):
    temp = self.features[self.current_index:self.current_index + self.batch_size]
    self.temp = temp
    self.current_index += self.batch_size
    if self.current_index >= len(self.features):
      self.current_index = 0
      self.current_file_index += 1
      if self.current_file_index == len(self.files):
        self.reset_datamanager(current_file_index=0, current_index=0)
        return temp, True
      else:
        self.features = self.load_data_file(self.files[self.current_file_index])
    return temp, False


class DataLoader:

  def __init__(self, current_file, current_index, batch_size, shuffle=True, data_type='train'):
    self.data_type = data_type
    self.batch_size = batch_size
    self.data_manager = DataManager(current_file, current_index, f'features/{data_type}/', batch_size, shuffle)

  def __iter__(self):
    self.stop_iteration = False
    return self

  def __len__(self):
    return int(self.data_manager.data_len // self.batch_size)

  def reset_dataloader(self, current_file, current_index):
    self.data_manager.reset_datamanager(current_file, current_index)

  def features_2_tensor(self, features):
    x = dict()
    x['input_ids'] = torch.LongTensor([feature.input_ids for feature in features])
    x['attention_mask'] = torch.LongTensor([feature.attention_mask for feature in features])
    x['label_tokens'] = torch.LongTensor([feature.label_tokens for feature in features])
    x['feature_answerability'] = torch.LongTensor([feature.feature_answerability for feature in features])
    x['features'] = features
    return x

  def __next__(self):
    if self.stop_iteration:
      raise StopIteration
    features, self.stop_iteration = self.data_manager.next()
    return self.features_2_tensor(features)

In [8]:
"""# Saving Settings"""
hist_num_str = 11

drive_prefix = f'.'
drive_checkpoint_dir = 'Checkpoint/'
drive_log_dir = 'Log/'
checkpoint_dir = os.path.join(drive_prefix, drive_checkpoint_dir)
log_dir = os.path.join(drive_prefix, drive_log_dir)

meta_log_file = os.path.join(drive_prefix, drive_log_dir, 'test.txt')
prediction_file_prefix = os.path.join(drive_prefix, drive_log_dir, 'prediction_')
loss_log_file = os.path.join(drive_prefix, drive_log_dir, 'loss.txt')
mean_f1_file = os.path.join(drive_prefix, drive_log_dir, 'mean_f1.txt')
eval_res_file = os.path.join(drive_prefix, drive_log_dir, 'eval_result.json')
test_res_file = os.path.join(drive_prefix, drive_log_dir, 'test_result.json')

if not os.path.exists(drive_prefix):
  os.mkdir(drive_prefix)
  print('crated saved dir')
if not os.path.exists(checkpoint_dir):
  os.mkdir(checkpoint_dir)
if not os.path.exists(log_dir):
  os.mkdir(log_dir)

with open(meta_log_file, 'w') as f:
  pass
# check if drive is accessible
try:
   with open(os.path.join(drive_prefix, drive_log_dir, 'test.txt'), 'r') as f:
      pass
except:
  print('No Access to Drive')
  exit()


def print_loss(loss_collection, epoch, step):
  txt = f'EPOCH [{epoch + 1}/{epochs}] | STEP [{step}/{int(len(train_dataloader) / accumulation_steps)}] | Loss {round(sum(loss_collection) / len(loss_collection), 4)}'
  print(txt)

def save_loss(loss_collection, epoch, step):
  txt = f'EPOCH [{epoch + 1}/{epochs}] | STEP [{step}/{int(len(train_dataloader) / accumulation_steps)}] | Loss {round(sum(loss_collection) / len(loss_collection), 4)}'
  with open(loss_log_file, 'a') as f:
    f.write(txt)
    f.write('\n')



checkpoint_available = False


def save_prediction(epoch, step, prediction_log):
  with open(os.path.join(drive_prefix, drive_log_dir, 'prediction.txt'), 'a') as f:
    f.write(f'--------- EPOCH {epoch} STEP {step} ---------\n')
    f.write(prediction_log)
    f.write('\n')
    f.write('\n')

def save_checkpoint(filename):
  checkpoint_config = {
  'model_dict': model_p.state_dict()}
  torch.save(checkpoint_config, filename)

def load_checkpoint(current_checkpoint):
    checkpoint_config = torch.load(current_checkpoint)
    return checkpoint_config['model_dict']

In [9]:
class EvalProcessOutput:
  def __init__(self):
    self.features = dict()
    self.dialog_answers = defaultdict(list)
    self.answers = dict()
    self.dialog_len = dict()
    self.log = dict()

  def process_feature_output(self, feature, predicted_text, prob):
    qid = feature.qid
    q_num = int(qid.split('#')[1])
    dialog_id = qid.split('#')[0]

    if dialog_id not in self.dialog_len.keys():
      self.dialog_len[dialog_id] = 0

    if qid not in self.features.keys():
      self.features[qid] = []
      self.dialog_len[dialog_id] += 1
    self.features[qid].append([predicted_text, prob, feature.example_answer])

  def process_output(self):
    for dialog_id, dialog_len in self.dialog_len.items():
      self.dialog_answers[dialog_id] = ['' for _ in range(dialog_len)]

    for qid, text_prob in self.features.items():
      sorted_preds = sorted(text_prob, key=lambda x: x[1], reverse=True)
      best_pred = sorted_preds[0]
      q_num = int(qid.split('#')[1])
      dialog_id = qid.split('#')[0]
      self.dialog_answers[dialog_id][q_num] = best_pred[0]
      self.log[qid] = best_pred
      if best_pred[1] < .2:
        self.answers[qid] = 'غیرقابل‌پاسخ'
      else:
        self.answers[qid] = best_pred[0]

In [10]:
class Persian_Model(nn.Module):

  def __init__(self, transformer):
    super(Persian_Model, self).__init__()
    self.transformer = transformer
    self.answerability_head = nn.Linear(self.transformer.config.hidden_size, 2)
    self.answerability_loss_fn = nn.CrossEntropyLoss(weight=torch.tensor([.5, 1.]).to(device))
    nn.init.normal_(self.answerability_head.weight, mean=.0, std=.02)


  def forward(self, input_ids, attention_mask, label_tokens, feature_answerability, mode):

    input_ids = input_ids.view(1, -1)
    attention_mask = attention_mask.view(1, -1)
    label_tokens = label_tokens.view(1, -1)
    feature_answerability = feature_answerability.view(-1)

    if mode == 'train':
      output = self.transformer(input_ids=input_ids,
                                attention_mask=attention_mask,
                                labels=label_tokens,
                                output_hidden_states=True)
      s_token_repr = output.encoder_hidden_states[-1][:, 0, :]
      answerability_score = self.answerability_head(s_token_repr)
      answerability_loss = self.answerability_loss_fn(answerability_score, feature_answerability)
      if feature_answerability:
        loss = output.loss + .5 * answerability_loss
      else:
        loss = answerability_loss
      return loss

    if mode == 'eval':
      output = model.generate(input_ids=input_ids.view(1, -1),
                              attention_mask=attention_mask.view(1, -1),
                              max_length=150,
                              do_sample=True,
                              top_p=.2,
                              top_k=100,
                              temperature=.7,
                              return_dict_in_generate=True,
                              output_hidden_states=True)
      s_token_repr = output.encoder_hidden_states[-1][:, 0, :]
      answerability_score = self.answerability_head(s_token_repr)
      return output, torch.softmax(answerability_score.view(-1), dim=0)[1].cpu().item()



In [None]:
epochs = 3
lr = 4e-5
beta_1 = .9
beta_2 = .999
eps = 1e-6
batch_size = 1
weight_decay = 0.0
accumulation_steps = 10
accumulation_counter = 0
q_scores = dict()
h_f1 = 0
k = 3
weight_decay = 0.0
f1_list = []
best_eval_f1 = 0.0

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
model_p = Persian_Model(model).to(device)
loss_fn = nn.CrossEntropyLoss().to(device)

loss_collection = []
train_dataloader = DataLoader(current_file=0, current_index=0, batch_size=batch_size, shuffle=True, data_type='train')
eval_dataloader = DataLoader(current_file=0, current_index=0, batch_size=1, shuffle=False, data_type='eval')
test_dataloader = DataLoader(current_file=0, current_index=0, batch_size=1, shuffle=False, data_type='test')
each_step_log = 1000
start_epoch = 0
start_step = 0
current_file = 0
current_index = 0


optimization_steps = int(epochs * len(train_dataloader) / accumulation_steps)
warmup_ratio = .1
warmup_steps = int(optimization_steps * warmup_ratio)

optimizer = AdamW(model_p.parameters(), lr=lr, betas=(beta_1,beta_2), eps=eps, weight_decay=weight_decay)
scheduler = get_linear_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=warmup_steps,
    num_training_steps=optimization_steps)

# laod checkpoint if available
if checkpoint_available:
  print('loading checkpoint')
  start_epoch, start_step, optimizer_dict, scheduler_dict, qa_model_dict, current_file, current_index = load_checkpoint()
  # load state dicts
  model_p.load_state_dict(qa_model_dict)
  optimizer.load_state_dict(optimizer_dict)
  scheduler.load_state_dict(scheduler_dict)

current_file_index_ = current_file
train_dataloader.reset_dataloader(current_file, current_index)
model_p.train()


for epoch in range(start_epoch, epochs):
  train_step = 1
  acc_loss = 0
  log_step = 0

  for data in train_dataloader:
    input_ids = data.pop('input_ids').to(device)
    attention_mask = data.pop('attention_mask').to(device)
    label_tokens = data.pop('label_tokens').to(device)
    feature_answerability = data.pop('feature_answerability').to(device)
    features = data.pop('features')

    # run output
    loss = model_p(input_ids=input_ids,
                     attention_mask=attention_mask,
                     label_tokens=label_tokens,
                     feature_answerability=feature_answerability,
                     mode='train')
    
    loss /= accumulation_steps

    acc_loss += loss.item()
    loss.backward()


    if train_step % each_step_log == 0:
      print_loss(loss_collection, epoch, log_step + 1)
      save_loss(loss_collection, epoch, log_step + 1)
      loss_collection = []

    accumulation_counter += 1
    if accumulation_counter % accumulation_steps == 0:
      loss_collection.append(acc_loss)
      acc_loss = 0
      log_step += 1
      optimizer.step()
      scheduler.step()
      optimizer.zero_grad()
      torch.cuda.empty_cache()
      accumulation_counter = 0

    train_step += 1

  q_scores = dict()
  model_p.eval()
  turn_f1s = [[] for _ in range(20)]
  print('-------------------- Evaluation --------------------')
  eval_p = EvalProcessOutput()
  P = []
  with torch.no_grad():
    for step, data in tqdm(enumerate(eval_dataloader), position=0, leave=True):
      input_ids = data.pop('input_ids').to(device)
      attention_mask = data.pop('attention_mask').to(device)
      label_tokens = data.pop('label_tokens').to(device)
      feature_answerability = data.pop('feature_answerability').to(device)
      features = data.pop('features')

      # run output
      output, score = model_p(input_ids=input_ids,
                     attention_mask=attention_mask,
                     label_tokens=label_tokens,
                     feature_answerability=feature_answerability,
                     mode='eval')
      predicted_text = tokenizer.batch_decode(output.sequences, skip_special_tokens=True)
      eval_p.process_feature_output(features[0], predicted_text[0], score)

  eval_p.process_output()
  f1, heq_q, heq_d, mean_f1s, res = run_eval('eval', eval_p)
  f1_list.append(f1)
  if f1 > best_eval_f1:
    print('saving model...')
    best_eval_checkpoint = f'best_model_{epoch}.pt'
    save_checkpoint(best_eval_checkpoint)
    print('model saved successfully')
    best_eval_f1 = f1
    with open(eval_res_file, 'w') as f:
      json.dump(res, f)

  print('Best Eval F1', best_eval_f1)
  early_stop = all([f1_list[-k] > i for i in f1_list[-k+1:]]) if epoch + 1 >= k else False
  if early_stop:
    print('Early Stopping')
    break
  model_p.train()



if args.do_test:
    model_p.load_state_dict(checkpoint_config['model_dict'])
    model_p.eval()
    print('-------------------- Test Time --------------------')
    test_p = EvalProcessOutput()
    with torch.no_grad():
        for step, data in tqdm(enumerate(test_dataloader), position=0, leave=True):
          input_ids = data.pop('input_ids').to(device)
          attention_mask = data.pop('attention_mask').to(device)
          label_tokens = data.pop('label_tokens').to(device)
          feature_answerability = data.pop('feature_answerability').to(device)
          features = data.pop('features')

          # run output
          output, score = model_p(input_ids=input_ids,
                        attention_mask=attention_mask,
                        label_tokens=label_tokens,
                        feature_answerability=feature_answerability,
                        mode='eval')
          predicted_text = tokenizer.batch_decode(output.sequences, skip_special_tokens=True)
          test_p.process_feature_output(features[0], predicted_text[0], score)


    test_p.process_output()
    f1, heq_q, heq_d, mean_f1s, res = run_eval('test', test_p)
    with open(test_res_file, 'w') as f:
        json.dump(res, f)


EPOCH [1/3] | STEP [100/1806] | Loss 1.1641
EPOCH [1/3] | STEP [200/1806] | Loss 0.9241
EPOCH [1/3] | STEP [300/1806] | Loss 0.7988
EPOCH [1/3] | STEP [400/1806] | Loss 0.7405
EPOCH [1/3] | STEP [500/1806] | Loss 0.6752
EPOCH [1/3] | STEP [600/1806] | Loss 0.6791
EPOCH [1/3] | STEP [700/1806] | Loss 0.6402
EPOCH [1/3] | STEP [800/1806] | Loss 0.6358
EPOCH [1/3] | STEP [900/1806] | Loss 0.6318
EPOCH [1/3] | STEP [1000/1806] | Loss 0.621
EPOCH [1/3] | STEP [1100/1806] | Loss 0.6231
EPOCH [1/3] | STEP [1200/1806] | Loss 0.6089
EPOCH [1/3] | STEP [1300/1806] | Loss 0.6169
EPOCH [1/3] | STEP [1400/1806] | Loss 0.5915
EPOCH [1/3] | STEP [1500/1806] | Loss 0.5959
EPOCH [1/3] | STEP [1600/1806] | Loss 0.5817
EPOCH [1/3] | STEP [1700/1806] | Loss 0.5965
EPOCH [1/3] | STEP [1800/1806] | Loss 0.618
-------------------- Evaluation --------------------


3675it [28:46,  2.13it/s]


f1 score is 0.243096101098699
HEQ-Q score is 0.20384615384615384
HEQ-M score is 0.0
HEQ-D score is 0.0
EM score is 0.1276923076923077
Unanswerable score is 0.419811320754717
[0.4403444845795336, 0.3374695543878552, 0.2633324948299554, 0.18929478785160694, 0.2302177242296849, 0.21770963590263157, 0.16041371982438704, 0.2093665690219475, 0.15129209134808966, 0.20788297012551063, 0.22073011311671292, 0.23650798192756442, 0.20477611932107465, 0.17169790137972582, 0.3040374717281253, 0.02776249003213817, 0.3738977066078778, 0.406249998984375, 0.0, 0.3333333322222222, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
saving model...
model saved successfully
Best Eval F1 0.243096101098699
EPOCH [2/3] | STEP [101/1806] | Loss 0.5758
EPOCH [2/3] | STEP [201/1806] | Loss 0.5411
EPOCH [2/3] | STEP [301/1806] | Loss 0.5346
EPOCH [2/3] | STEP [401/1806] | Loss 0.5479
EPOCH [2/3] | STEP [501/1806] | Loss 0.5269
EPOCH [2/3] | STEP [601/1806] | Loss 0.5211
EPOCH [2/3] | STEP [701/1806] | Loss 0.5446
E

3675it [28:40,  2.14it/s]


f1 score is 0.26879865334871966
HEQ-Q score is 0.21461538461538462
HEQ-M score is 0.0
HEQ-D score is 0.0
EM score is 0.1346153846153846
Unanswerable score is 0.35377358490566035
[0.4424112827803949, 0.3804207987982704, 0.31119675319614576, 0.23549679788013508, 0.2656764876125048, 0.22702305581160148, 0.18573445040128375, 0.24374321403414656, 0.17519530025830432, 0.21454470429630138, 0.23047723982299068, 0.23007009409004398, 0.2272345910728886, 0.21981424137037636, 0.3225533687722894, 0.0, 0.34920634862433864, 0.406249998984375, 0.0, 0.3333333322222222, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
saving model...
model saved successfully
Best Eval F1 0.26879865334871966


In [None]:
A = eval_p.features

In [None]:
save_data(A, 'mmf.pk')