In [1]:
import argparse
import collections
import json
import numpy as np
import os
import re
import string
import sys
import glob
import csv, json
import pandas as pd

In [15]:
# --- Run configuration -------------------------------------------------------
# Model identifiers. Neither name is read by the functions visible in this
# notebook; they are kept for whatever other cells/scripts reference them.
model_prefix = model_dir = 'bert-base-uncased'
# Not read by the visible code either -- presumably the number of encoder
# layers of the model (TODO confirm).
layers = 12

# Inputs/outputs of the evaluation: the dataset JSON that is scored against,
# and the directory holding predict.csv / predict.json / results.csv.
data_path = "dev-v2.0.json"
preds_dir = "results/bert-base-uncased/epoch_2"

# Read by main() to limit how many entries of the dataset's 'data' list are used.
sample_size = 500

In [16]:
def make_qid_to_has_ans(dataset):
    """Map every question id in ``dataset`` to whether it has a gold answer.

    ``dataset`` is iterated as a sequence of articles; each article holds a
    ``'paragraphs'`` list and each paragraph a ``'qas'`` list whose entries
    carry ``'id'`` and ``'answers'`` keys.

    Returns:
        dict: ``{question_id: bool}`` -- ``True`` when the question's
        ``'answers'`` value is truthy (a non-empty list), ``False`` otherwise.
    """
    return {
        qa['id']: bool(qa['answers'])
        for article in dataset
        for paragraph in article['paragraphs']
        for qa in paragraph['qas']
    }

def normalize_answer(s):
    """Canonicalise an answer string for comparison.

    The steps run in this order: lower-case, strip ASCII punctuation, blank
    out the stand-alone articles "a", "an" and "the", then collapse every run
    of whitespace into a single space (trimming both ends).

    Because punctuation is removed before the article pass, a token such as
    "A.N." first becomes "an" and is then dropped as an article.
    """
    lowered = s.lower()

    # Remove every character listed in string.punctuation.
    punctuation = set(string.punctuation)
    without_punct = ''.join(ch for ch in lowered if ch not in punctuation)

    # Replace whole-word articles by a space; leftover gaps are squeezed below.
    article_pattern = re.compile(r'\b(a|an|the)\b', re.UNICODE)
    without_articles = article_pattern.sub(' ', without_punct)

    return ' '.join(without_articles.split())

def get_tokens(s):
    """Return the whitespace-separated tokens of the normalised form of ``s``.

    A falsy ``s`` (empty string, ``None``) yields an empty list without being
    normalised.
    """
    return normalize_answer(s).split() if s else []

def compute_exact(a_gold, a_pred):
    """Exact-match score: 1 if both answers are equal after normalisation, else 0.

    Both strings go through ``normalize_answer`` first, so differences in
    case, punctuation, the articles a/an/the and spacing are ignored
    (e.g. "The Eiffel Tower!" matches "eiffel tower"); any other difference
    in wording makes the score 0.
    """
    gold_norm = normalize_answer(a_gold)
    pred_norm = normalize_answer(a_pred)
    return 1 if gold_norm == pred_norm else 0

def compute_f1(a_gold, a_pred):
    """Token-level F1 between a gold answer and a predicted answer.

    Both strings are tokenised with ``get_tokens``.  Returns an ``int``
    (0 or 1) when either side has no tokens or when nothing overlaps, and a
    ``float`` F1 value otherwise.
    """
    gold_tokens = get_tokens(a_gold)
    pred_tokens = get_tokens(a_pred)

    # When either side is "no answer", the score is 1 only if both are empty.
    if not gold_tokens or not pred_tokens:
        return int(gold_tokens == pred_tokens)

    # Multiset intersection: each shared token counts as often as it occurs
    # on both sides.
    overlap = collections.Counter(gold_tokens) & collections.Counter(pred_tokens)
    matched = sum(overlap.values())
    if matched == 0:
        return 0

    precision = 1.0 * matched / len(pred_tokens)  # share of predicted tokens that are correct
    recall = 1.0 * matched / len(gold_tokens)     # share of gold tokens that were found
    return (2 * precision * recall) / (precision + recall)

def get_raw_scores(dataset, preds):
    """Compute per-question exact-match and F1 scores.

    Args:
        dataset: sequence of articles (``'paragraphs'`` -> ``'qas'`` ->
            ``'id'`` / ``'answers'`` with ``'text'`` entries).
        preds: mapping from question id to the predicted answer string.

    Returns:
        tuple: ``(exact_scores, f1_scores)``, two dicts keyed by question id.
        Questions without an entry in ``preds`` are left out of both dicts.
    """
    exact_scores, f1_scores = {}, {}

    every_qa = (
        qa
        for article in dataset
        for paragraph in article['paragraphs']
        for qa in paragraph['qas']
    )
    for qa in every_qa:
        qid = qa['id']

        # Keep only reference answers that are non-empty after normalisation.
        # With none left, the question is unanswerable and the only correct
        # prediction is the empty string.
        references = [
            ans['text'] for ans in qa['answers'] if normalize_answer(ans['text'])
        ] or ['']

        if qid not in preds:
            continue
        prediction = preds[qid]

        # A question may have several references; the best match counts.
        exact_scores[qid] = max(compute_exact(ref, prediction) for ref in references)
        f1_scores[qid] = max(compute_f1(ref, prediction) for ref in references)

    return exact_scores, f1_scores

def apply_no_ans_threshold(scores, na_probs, qid_to_has_ans, na_prob_thresh):
    """Override scores for questions the model considers unanswerable.

    For each question id in ``scores``: when ``na_probs[qid]`` is strictly
    greater than ``na_prob_thresh`` the prediction is treated as "no answer"
    and scored 1.0 if the question really has no answer and 0.0 otherwise;
    in every other case the original score is kept.

    Returns:
        dict: a new ``{qid: score}`` mapping; ``scores`` is not modified.
    """
    return {
        qid: (
            float(not qid_to_has_ans[qid])
            if na_probs[qid] > na_prob_thresh
            else raw_score
        )
        for qid, raw_score in scores.items()
    }

def make_eval_dict(exact_scores, f1_scores, qid_list=None):
    """Aggregate per-question scores into percentage metrics.

    Args:
        exact_scores: ``{qid: exact-match score}``.
        f1_scores: ``{qid: F1 score}``.
        qid_list: optional list of question ids to restrict the average to.
            A falsy value (``None`` or an empty list) means "use every entry
            of the score dicts".

    Returns:
        collections.OrderedDict with keys ``'exact'``, ``'f1'`` (both mean
        scores multiplied by 100) and ``'total'`` (number of questions
        averaged over), in that order.

    Raises:
        ZeroDivisionError: when there is nothing to average over.
    """
    if qid_list:
        total = len(qid_list)
        exact_values = [exact_scores[qid] for qid in qid_list]
        f1_values = [f1_scores[qid] for qid in qid_list]
    else:
        total = len(exact_scores)
        exact_values = exact_scores.values()
        f1_values = f1_scores.values()

    summary = collections.OrderedDict()
    summary['exact'] = 100.0 * sum(exact_values) / total
    summary['f1'] = 100.0 * sum(f1_values) / total
    summary['total'] = total
    return summary

def merge_eval(main_eval, new_eval, prefix):
    """Copy every entry of ``new_eval`` into ``main_eval`` under a prefixed key.

    Each key ``k`` of ``new_eval`` is stored as ``'<prefix>_<k>'``; e.g. with
    ``prefix='has_ans'`` the keys ``exact``/``f1``/``total`` become
    ``has_ans_exact``/``has_ans_f1``/``has_ans_total``.  ``main_eval`` is
    modified in place and nothing is returned.
    """
    for metric, value in new_eval.items():
        main_eval['%s_%s' % (prefix, metric)] = value

def main(data_file, pred_file):
    """Score a predictions file against a SQuAD-style dataset file.

    Args:
        data_file: path to a JSON file whose top-level ``'data'`` key holds
            the list of articles.  Only the first ``sample_size`` articles
            (module-level setting) are evaluated.
        pred_file: path to a JSON file mapping question id -> predicted
            answer string.

    Returns:
        tuple: ``(exact, f1, exact_no_ans, f1_no_ans, exact_has_ans,
        f1_has_ans)``, all percentages.  Questions without a prediction are
        ignored.  When a subset (answerable / unanswerable) contains no
        predicted question, its two values are ``float('nan')``.

    Raises:
        ZeroDivisionError: raised by ``make_eval_dict`` when no question of
            the sampled dataset has a prediction at all.
    """
    # JSON is UTF-8 by specification; do not depend on the platform default.
    with open(data_file, encoding='utf-8') as f:
        dataset_json = json.load(f)
    # Keep exactly `sample_size` articles (the previous `[0:sample_size-1]`
    # slice silently dropped one).
    dataset = dataset_json['data'][:sample_size]
    with open(pred_file, encoding='utf-8') as f:
        preds = json.load(f)

    # No per-question "no answer" probability is available, so every question
    # gets 0.0; with the threshold of 1.0 used below, no score is overridden.
    na_probs = {k: 0.0 for k in preds}

    qid_to_has_ans = make_qid_to_has_ans(dataset)  # qid -> True/False
    # Answerable / unanswerable question ids that also have a prediction.
    has_ans_qids_in_pred = [
        qid for qid, has_ans in qid_to_has_ans.items() if has_ans and qid in preds
    ]
    no_ans_qids_in_pred = [
        qid for qid, has_ans in qid_to_has_ans.items() if not has_ans and qid in preds
    ]

    exact_raw, f1_raw = get_raw_scores(dataset, preds)

    exact_thresh = apply_no_ans_threshold(exact_raw, na_probs, qid_to_has_ans, 1.0)
    f1_thresh = apply_no_ans_threshold(f1_raw, na_probs, qid_to_has_ans, 1.0)

    out_eval = make_eval_dict(exact_thresh, f1_thresh)

    # Guard on the lists that are actually passed on: make_eval_dict treats an
    # empty qid_list as "no filter" and would report the overall score as the
    # subset score.
    if has_ans_qids_in_pred:
        has_ans_eval = make_eval_dict(exact_thresh, f1_thresh, qid_list=has_ans_qids_in_pred)
        merge_eval(out_eval, has_ans_eval, 'has_ans')

    if no_ans_qids_in_pred:
        no_ans_eval = make_eval_dict(exact_thresh, f1_thresh, qid_list=no_ans_qids_in_pred)
        merge_eval(out_eval, no_ans_eval, 'no_ans')

    # A subset that was not evaluated has no keys in out_eval; report NaN for
    # it instead of failing with KeyError.
    missing = float('nan')
    exact, f1 = out_eval['exact'], out_eval['f1']
    exact_no_ans = out_eval.get('no_ans_exact', missing)
    f1_no_ans = out_eval.get('no_ans_f1', missing)
    exact_has_ans = out_eval.get('has_ans_exact', missing)
    f1_has_ans = out_eval.get('has_ans_f1', missing)

    return exact, f1, exact_no_ans, f1_no_ans, exact_has_ans, f1_has_ans


def convert_preds_to_json(preds_dir):
    """Convert ``<preds_dir>/predict.csv`` into ``<preds_dir>/predict.json``.

    The CSV must have the columns ``Id`` and ``Predicted``; the JSON file
    written is a single object mapping each ``Id`` to its ``Predicted`` text
    (a later row with the same ``Id`` overrides an earlier one).

    The CSV is decoded as UTF-8 (a leading byte-order mark is tolerated).
    The previous ``unicode_escape`` decoding treated the bytes as Latin-1 and
    rewrote backslash sequences, which corrupted non-ASCII predictions; it is
    kept only as a fallback for files that are not valid UTF-8.

    Args:
        preds_dir: directory containing ``predict.csv``; ``predict.json`` is
            created (or overwritten) in the same directory.
    """
    csv_path = os.path.join(preds_dir, "predict.csv")
    json_path = os.path.join(preds_dir, "predict.json")

    def read_predictions(encoding):
        # newline='' lets the csv module handle line endings itself, so
        # quoted fields containing line breaks are parsed correctly.
        with open(csv_path, encoding=encoding, newline='') as csv_handle:
            return {row['Id']: row['Predicted'] for row in csv.DictReader(csv_handle)}

    try:
        predictions = read_predictions('utf-8-sig')
    except UnicodeDecodeError:
        # Not UTF-8: fall back to the legacy decoding so files that used to
        # load keep loading.
        predictions = read_predictions('unicode_escape')

    # json.dump escapes non-ASCII characters by default, so the output is
    # plain ASCII; the context manager closes the file even if writing fails.
    with open(json_path, "w", encoding='utf-8') as json_handle:
        json.dump(predictions, json_handle)

def evaluate(preds_dir, data_path):
    """Score ``<preds_dir>/predict.json`` and save the metrics as a CSV.

    Runs ``main`` on the dataset at ``data_path`` and the predictions JSON in
    ``preds_dir``, then writes a one-row table to ``<preds_dir>/results.csv``
    with the columns ``layer`` (always 0), ``exact``, ``f1``,
    ``exact_no_ans``, ``f1_no_ans``, ``exact_has_ans`` and ``f1_has_ans``.

    Args:
        preds_dir: directory holding ``predict.json``; also receives
            ``results.csv``.
        data_path: path to the dataset JSON file passed on to ``main``.
    """
    json_file = preds_dir + "/predict.json"

    scores = main(data_path, json_file)
    exact, f1, exact_no_ans, f1_no_ans, exact_has_ans, f1_has_ans = scores

    row = {
        'layer': 0,
        'exact': exact,
        'f1': f1,
        'exact_no_ans': exact_no_ans,
        'f1_no_ans': f1_no_ans,
        'exact_has_ans': exact_has_ans,
        'f1_has_ans': f1_has_ans,
    }
    # All values are scalars, so an explicit single-row index is supplied.
    results = pd.DataFrame(row, index=[0])

    out_path = preds_dir + "/" + "results.csv"
    results.to_csv(out_path, index=False)

In [17]:
# Step 1: turn <preds_dir>/predict.csv into <preds_dir>/predict.json.
convert_preds_to_json(preds_dir)

# Step 2: score predict.json against the dataset and write <preds_dir>/results.csv.
evaluate(preds_dir, data_path)