In [None]:
import argparse
import json
import re
from typing import Dict, List, Any, Tuple, Union
from torchtext.data.metrics import bleu_score
from collections import Counter, defaultdict

from tqdm import tqdm
from google.colab import files, drive
import glob

In [None]:
# Connect to Google Drive (Colab only): mounts Drive at /content/gdrive so the
# report folders under /content/gdrive/MyDrive/PRETRAINED, which the main loop
# reads, are reachable from this runtime.
drive.mount('/content/gdrive')

In [None]:
# Experiment grid: the main loop evaluates every model type under every copy setting.
MODEL_TYPE = [
  "transformer",
  "cnns2s",
]  # any subset of: "transformer", "cnns2s"
COPY_FLAG = [
  "no_copy",
  "copy",
]  # any subset of: "no_copy", "copy"

# Dataset selection; both values become part of the models folder path.
DATASET_FAMILY = "LC-QuAD"  # "Monument" or "LC-QuAD"
DATASET_NAME = "intermediary_question_tagged_all_no_resources"  # DON'T FORGET TO SET

# Name of the JSON report file expected inside every model folder.
REPORT_FILENAME = 'error_report_complete.json'

## Utils

In [None]:
def load_report(path: str):
  """Load a JSON error report, printing hints when it looks unusable.

  Prints the number of top-level entries in the loaded JSON. If `path` does
  not contain 'oov' and exactly 250 entries were loaded, prints a
  "1 - MUST RERUN" hint (presumably a partial/truncated report -- TODO
  confirm). Any failure while reading, parsing or measuring the report is
  reported as "2 - MUST RERUN" and an empty list is returned instead.

  Args:
    path: Location of the UTF-8 encoded JSON report.

  Returns:
    The decoded JSON value, or [] when loading failed.
  """
  try:
    with open(path, 'r', encoding='utf-8') as handle:
      report = json.load(handle)

    entry_count = len(report)
    print(entry_count)

    # Keep the 'oov' test first: it short-circuits before the size check.
    if 'oov' not in path and entry_count == 250:
      print(f"1 - MUST RERUN {path}")
  except Exception as err:
    # Best effort: a broken report is logged and scored as empty.
    print(f"2 - MUST RERUN {path}: {err}")
    report = []

  return report

In [None]:
def get_answers(path):
  """Collect (predicted, gold) answer pairs from an error report.

  Every report entry is read through ``entry['dbpedia']['predicted']`` and
  ``entry['dbpedia']['trg']``. Each side carries an ``is_error`` flag and a
  ``query_result`` holding either a ``'boolean'`` key or
  ``['results']['bindings']`` rows (looks like the SPARQL JSON results
  format -- confirm).

  The pair produced for an entry is:
    * ``([], [])`` when either side errored or returned zero rows;
    * ``[boolean]`` for a side with a boolean result;
    * ``[first_row['value']]`` for a side whose first row has a ``'value'``
      variable (presumably a COUNT-style query -- TODO confirm);
    * otherwise that side's full list of binding rows.

  Args:
    path: Location of the JSON error report (passed to ``load_report``).

  Returns:
    A list with one (pred, gold) tuple per report entry, or [] (after
    printing a "3 - MUST RERUN" hint) when the report is empty or its first
    entry has no 'dbpedia' section.
  """
  report = load_report(path)

  usable = len(report) > 0 and 'dbpedia' in report[0]
  if not usable:
    print(f"3 - MUST RERUN {path}")
    return []

  def has_no_rows(result, is_boolean):
    # Boolean results have no bindings, so they are never "empty".
    return not is_boolean and len(result['results']['bindings']) == 0

  def side_values(result, is_boolean):
    # Answer list for one side already known to be non-error and non-empty.
    if is_boolean:
      return [result['boolean']]
    rows = result['results']['bindings']
    head = rows[0]
    return [head['value']] if 'value' in head else rows

  pairs = [(item['dbpedia']['predicted'], item['dbpedia']['trg']) for item in report]
  answers = []

  for predicted, target in pairs:
    error_flags = (predicted['is_error'], target['is_error'])
    if any(error_flags):
      answers.append(([], []))
      continue

    pred_result = predicted['query_result']
    gold_result = target['query_result']
    pred_is_bool = 'boolean' in pred_result
    gold_is_bool = 'boolean' in gold_result

    pred_empty = has_no_rows(pred_result, pred_is_bool)
    gold_empty = has_no_rows(gold_result, gold_is_bool)
    if pred_empty or gold_empty:
      answers.append(([], []))
      continue

    answers.append((
      side_values(pred_result, pred_is_bool),
      side_values(gold_result, gold_is_bool),
    ))

  return answers

In [None]:
def get_prec_and_recall(path):
  """Macro-averaged precision and recall over every question in one report.

  Answer pairs come from ``get_answers(path)``. For each (pred, gold) pair:
    * precision = fraction of predicted answers that appear in gold;
    * recall    = fraction of gold answers that appear in the prediction;
    * a pair where either side is empty scores 0 for both.
  The per-question scores are then averaged with equal weight.

  Args:
    path: Location of the JSON error report.

  Returns:
    ``(precision, recall)``, each in [0, 1], or ``(0, 0)`` when the report
    yields no answer pairs.
  """
  answers = get_answers(path)

  if len(answers) == 0:
    return 0, 0

  prec = []
  recall = []

  for pred, gold in answers:
    if len(pred) == 0 or len(gold) == 0:
      prec.append(0)
      recall.append(0)
      continue

    # Count hits from each side separately. Reusing the gold-side hit count
    # as the precision numerator let precision exceed 1 whenever gold held
    # duplicate rows (gold=[x, x], pred=[x] used to give precision 2.0).
    gold_hits = sum(1 for correct in gold if correct in pred)
    pred_hits = sum(1 for found in pred if found in gold)

    recall.append(gold_hits / len(gold))
    prec.append(pred_hits / len(pred))

  out_prec = sum(prec) / len(prec)
  out_recall = sum(recall) / len(recall)

  return out_prec, out_recall

In [None]:
def get_f1(prec, recall):
  """F1 score: the harmonic mean of precision and recall.

  Returns 0 whenever either input is 0, which also avoids a 0/0 division
  when both are 0.
  """
  if prec != 0 and recall != 0:
    return 2 * (prec * recall) / (prec + recall)
  return 0

## Main

In [None]:
# Evaluate every (copy flag, model type) combination. Each entry matched under
# the configuration's models folder is treated as one trained run and is
# expected to contain REPORT_FILENAME.
EXPECTED_RUNS = 3  # averages are printed only when exactly this many runs are found

for c in COPY_FLAG:
  for m in MODEL_TYPE:
    print('===============================================')
    MODELS_FOLDER = f"/content/gdrive/MyDrive/PRETRAINED/{m}/{c}/{DATASET_FAMILY}/{DATASET_NAME}/"

    # glob.glob returns entries in arbitrary order; sort so the per-run log
    # is reproducible from one execution to the next.
    models_paths = sorted(glob.glob(f"{MODELS_FOLDER}/*"))
    # Distinct name so the comprehension does not read as shadowing `m`.
    reports_paths = [f'{model_dir}/{REPORT_FILENAME}' for model_dir in models_paths]
    print(MODELS_FOLDER, len(reports_paths))

    results = []
    for r in reports_paths:
      print(r)
      prec, recall = get_prec_and_recall(r)
      f1 = get_f1(prec, recall)

      results.append((prec, recall, f1))

    if len(results) == EXPECTED_RUNS:
      print("PRECISION AVERAGE:", sum(res[0] for res in results) / len(results) * 100)
      print("RECALL AVERAGE:", sum(res[1] for res in results) / len(results) * 100)
      print("F1 AVERAGE:", sum(res[2] for res in results) / len(results) * 100)
    else:
      # Report the mismatch instead of silently printing nothing.
      print(f"SKIPPED AVERAGES: expected {EXPECTED_RUNS} runs, found {len(results)}")