In [None]:
!pip install python-Levenshtein
!pip install SPARQLWrapper

In [None]:
import argparse
import glob
import json
import os
import re
import ssl
from collections import Counter, defaultdict
from typing import Dict, List, Any, Tuple, Union

from Levenshtein import distance as levenshtein_distance
from SPARQLWrapper import SPARQLWrapper, JSON
from google.colab import files, drive
from torchtext.data.metrics import bleu_score
from tqdm import tqdm
# NOTE(review): this replaces the stdlib's default HTTPS context factory with one that
# skips certificate verification, for every later HTTPS request in this kernel that
# relies on that default (presumably including the DBpedia calls made via
# SPARQLWrapper). It looks like a workaround for certificate errors; confirm it is
# still needed and consider removing or scoping it.
ssl._create_default_https_context = ssl._create_unverified_context

In [None]:
# Connect to Google Drive: mount it at /content/gdrive. MODELS_FOLDER (configured in
# the next cell) lives under this mount point; reports are read from there and the
# completed reports are written back there.
drive.mount('/content/gdrive')

In [None]:
# Which trained models to evaluate; together these settings select the Drive sub-folder.
MODEL_TYPE = "transformer"  # cnns2s or transformer
COPY_FLAG = "copy"  # or no_copy
DATASET_FAMILY = "LC-QuAD"  # Monument or LC-QuAD
DATASET_NAME = "intermediary_question_tagged_all"  # DONT FORGET TO SET
MODELS_FOLDER = "/".join([
    "/content/gdrive/MyDrive/PRETRAINED",
    MODEL_TYPE,
    COPY_FLAG,
    DATASET_FAMILY,
    DATASET_NAME,
    "",  # empty last part keeps the trailing slash
])

# Name of the per-model error report file that is read and completed below.
REPORT_FILENAME = 'error_report.json'

# NOTE(review): not read anywhere in the visible notebook.
REGENERATE = False

## Consts

In [None]:
# Decoding table used by reverse_replacements: for each row, the LAST element is an
# intermediary-SPARQL token and the FIRST element is the pure-SPARQL text it is turned
# back into. Middle elements (e.g. full namespace URIs, 'res:') are not read there.
# Rows are applied in list order, each one first with the padded token and then with
# its whitespace-stripped form. So when two rows share a token (e.g. the two
# ' sep_dot ' rows), only the first one has any effect.
REPLACEMENTS = [
    ['dbo:', 'http://dbpedia.org/ontology/', 'dbo_'],
    ['dbp:', 'http://dbpedia.org/property/', 'dbp_'],
    ['dbc:', 'http://dbpedia.org/resource/Category:', 'dbc_'],
    ['dbr:', 'res:', 'http://dbpedia.org/resource/', 'dbr_'],
    ['dct:', 'dct_'],
    ['geo:', 'geo_'],
    ['georss:', 'georss_'],
    ['rdf:', 'rdf_'],
    ['rdfs:', 'rdfs_'],
    ['foaf:', 'foaf_'],
    ['owl:', 'owl_'],
    ['yago:', 'yago_'],
    ['skos:', 'skos_'],
    # Grouping parentheses decode to padded ' ( ' / ' ) '.
    [' ( ', '  par_open  '],
    [' ) ', '  par_close  '],
    [' ( ', ' sparql_open '],
    [' ) ', ' sparql_close '],
    # attr_* parentheses decode without a leading space ('(' and ') ').
    ['(', ' attr_open '],
    [') ', ')', ' attr_close '],
    ['{', ' brack_open '],
    ['}', ' brack_close '],
    [' . ', ' sep_dot '],
    ['. ', ' sep_dot '],  # never reached: the row above already consumed sep_dot
    ['?', 'var_'],  # may also hit resource names containing 'var_' (repaired later)
    ['*', 'wildcard'],
    [' <= ', ' math_leq '],
    [' >= ', ' math_geq '],
    [' < ', ' math_lt '],
    [' > ', ' math_gt ']
]


# URI / prefix rewriting rules, applied in list order with str.replace by
# encode_resources_in_interm_sparql and encode_resources_in_pure_sparql:
#   'match'         -> text to look for (a full namespace URI or a short 'xxx:' prefix)
#   'interm_sparql' -> replacement when producing intermediary SPARQL ('xxx_')
#   'pure_sparql'   -> replacement when producing pure SPARQL ('xxx:')
# Order matters: the .../resource/Category: URI must come before the generic
# .../resource/ URI, otherwise category URIs would be shortened to dbr.
URI_SHORTENERS = [
    {
        'match': 'http://dbpedia.org/ontology/',
        'interm_sparql': 'dbo_',
        'pure_sparql': 'dbo:'
    },
    {
        'match': 'http://dbpedia.org/property/',
        'interm_sparql': 'dbp_',
        'pure_sparql': 'dbp:'
    },
    {
        'match': 'http://dbpedia.org/resource/Category:',
        'interm_sparql': 'dbc_',
        'pure_sparql': 'dbc:'
    },
    {
        'match': 'http://dbpedia.org/resource/',
        'interm_sparql': 'dbr_',
        'pure_sparql': 'dbr:'
    },
    # Short-prefix forms: map 'xxx:' to itself (pure) or to 'xxx_' (intermediary).
    {
        'match': 'dbo:',
        'interm_sparql': 'dbo_',
        'pure_sparql': 'dbo:'
    },
    {
        'match': 'dbp:',
        'interm_sparql': 'dbp_',
        'pure_sparql': 'dbp:'
    },
    {
        'match': 'dbc:',
        'interm_sparql': 'dbc_',
        'pure_sparql': 'dbc:'
    },
    {
        'match': 'dbr:',
        'interm_sparql': 'dbr_',
        'pure_sparql': 'dbr:'
    },
]


# Encodes SPARQL syntax into intermediary-SPARQL tokens. It is applied in list order
# with str.replace by convert_to_interm_sparql_except_resources, after that function
# has lower-cased its input (hence the lower-case prefixes). Note that the '.' rule
# turns every remaining dot into sep_dot, and the '?' rule turns every '?' into var_.
PURE_SPARQL_TO_INTERM = [
    {
        'pure_sparql': 'dct:',
        'interm_sparql': 'dct_'
    },
    {
        'pure_sparql': 'geo:',
        'interm_sparql': 'geo_'
    },
    {
        'pure_sparql': 'georss:',
        'interm_sparql': 'georss_'
    },
    {
        'pure_sparql': 'rdf:',
        'interm_sparql': 'rdf_'
    },
    {
        'pure_sparql': 'rdfs:',
        'interm_sparql': 'rdfs_'
    },
    {
        'pure_sparql': 'foaf:',
        'interm_sparql': 'foaf_'
    },
    {
        'pure_sparql': 'owl:',
        'interm_sparql': 'owl_'
    },
    {
        'pure_sparql': 'yago:',
        'interm_sparql': 'yago_'
    },
    {
        'pure_sparql': 'skos:',
        'interm_sparql': 'skos_'
    },
    # Parentheses inside resources are encoded separately (see PURE_SPARQL_TO_INTERM_RES).
    {
        'pure_sparql': '(',
        'interm_sparql': '  par_open  '
    },
    {
        'pure_sparql': ')',
        'interm_sparql': '  par_close  '
    },
    {
        'pure_sparql': '{',
        'interm_sparql': '  brack_open  '
    },
    {
        'pure_sparql': '}',
        'interm_sparql': '  brack_close  '
    },
    {
        'pure_sparql': '.',
        'interm_sparql': '  sep_dot  '
    },
    {
        'pure_sparql': '?',
        'interm_sparql': ' var_'
    },
    {
        'pure_sparql': '*',
        'interm_sparql': '  wildcard  '
    },
    # Comparison operators are only matched when surrounded by single spaces.
    {
        'pure_sparql': ' <= ',
        'interm_sparql': ' math_leq '
    },
    {
        'pure_sparql': ' >= ',
        'interm_sparql': ' math_geq '
    },
    {
        'pure_sparql': ' < ',
        'interm_sparql': ' math_lt '
    },
    {
        'pure_sparql': ' > ',
        'interm_sparql': ' math_gt '
    }
]


# Parenthesis encoding that encode_resources_in_interm_sparql applies inside a resource
# name; it uses attr_* tokens rather than the par_* tokens used for SPARQL syntax.
PURE_SPARQL_TO_INTERM_RES = [
    {'pure_sparql': pure, 'interm_sparql': interm}
    for pure, interm in (('(', ' attr_open '), (')', ' attr_close '))
]


# Placeholder token that abstract_resources substitutes for each DBpedia prefix.
PLACEHOLDERS = {
    prefix: f'<{kind}>'
    for prefix, kind in (
        ('dbo', 'ontology'),
        ('dbp', 'property'),
        ('dbc', 'category'),
        ('dbr', 'resource'),
    )
}

# When decoding 'var_' back to '?', resource names that merely contain 'var_' are hit
# too (e.g. Sarovar_Bridge becomes Saro?Bridge). This regex finds such a '?' inside a
# dbr/dbo/dbc/dbp name so it can be reverted, without touching real variables.
CATCH_VAR_IN_RESOURCE_NAME_RE = re.compile(r'db[rocp]:[a-zA-Z0-9_]*?(\?)[a-zA-Z0-9]', flags=re.IGNORECASE)

# In interm sparql, ORDER BY DESC(?uri) becomes _obd_ var_uri (ASC -> _oba_).
# Groups: 1 = marker + variable, 2 = marker, 3 = variable (once var_ is decoded to '?').
REPLACE_ORDER_BY_RE = re.compile(r".*((_ob[ad]_)\s*(\?[a-zA-Z]+))", flags=re.IGNORECASE)

# Catch resources (dbo:/dbr:/dbc:/dbp: names) in pure sparql queries, up to the next
# whitespace, '|' or '}'.
GET_RESOURCES_PURE_SPARQL_RE = re.compile(r"(db[orcp]:.*?)(?:[\s|}])", flags=re.IGNORECASE)

# Splits a uri into the resource type part (resource, ontology, property, etc.) and the
# resource name part.
SPLIT_URI_RE = re.compile(r'https?:\/\/dbpedia.org\/(.*?)\/(.*)', re.IGNORECASE)

# Captures everything between <>, especially useful for LC-QuAD.
FIND_RESOURCES_BTW_ANGLE_BRACKETS_RE = re.compile(r'\<(.*?)\>', re.IGNORECASE)

# Template capturing everything except the given resources (fill in {resources}).
RE_TEMPLATE_EXCLUDE_SPECIFIC_RESOURCES = '(.*?)({resources}|$)'

# Captures the parentheses (par_open ... par_close) that follow a FILTER in interm sparql.
REPLACE_FILTER_PAR_RE = re.compile(r"(?:FILTER)\s*?(par_open).*(par_close)", re.IGNORECASE)

# Captures the parentheses (par_open ... par_close) that follow a COUNT or an
# ORDER BY ASC/DESC in interm sparql.
REPLACE_REST_PAR_RE = re.compile(r"(?:count|order by (?:asc|desc))\s*?(par_open).*?(par_close)", re.IGNORECASE)

# Captures the quotes that follow a regex in interm sparql so they can be encoded as
# the 'sparql_quote' symbol.
REPLACE_QUOTES_RE = re.compile("regex par_open var_[a-z]+,(').*?(')(?:,(').*?('))? par_close", re.IGNORECASE)

# Most math clauses in raw interm sparql have no spaces; this regex is used to insert some.
INSERT_SPACES_MATH_RE = re.compile(r"(var_[a-zA-Z0-9]+)(math_[gl]t)(.*?)[$|\s]", re.IGNORECASE)

# Master regex to catch resources in interm sparql queries in ANY dataset.
# NOTE(review): the classes '[a-z^db[orcp]' contain a literal '[' (Python emits a
# "Possible nested set" FutureWarning); the pattern is kept byte-for-byte as before.
GET_RESOURCES_INTERM_SPARQL_RE = re.compile(r"(db[orcp]_.*?(?:(?:\s*?(?:attr|par)_open([a-z^db[orcp]_*?)(?:attr|par)_close [a-z^db[orcp]]*?)|(?:\s*?(?:attr|par)_open(.*?)(?:attr|par)_close)|(?:\s?_.*?)|(?:,?_.*?)|(?:\s*?\-.*?))*\??)(?:\s|brack_close)")

# https://stackoverflow.com/a/29922050 - split a camel-case resource name into words.
CAMEL_CASE_SPLIT_RE = re.compile(r'[A-Z]?[a-z]+|[A-Z]+(?=[A-Z]|$)')

# Short prefixes for the main DBpedia namespaces: resources, properties, classes.
RESOURCE_ABBRV, PROPERTY_ABBRV, CLASS_ABBRV = 'dbr', 'dbp', 'dbo'

# Matching namespace path segments (http://dbpedia.org/<segment>/...).
RESOURCE_TYPE, PROPERTY_TYPE, CLASS_TYPE = 'resource', 'property', 'ontology'


# Exceptions consulted by correct_sep_dots when deciding whether a trailing '.' on a
# resource is part of its name or a separator (useful for DBNQA, which has many
# formatting errors with sep dots).
EXCEPTIONS_NOT_REPLACE = [ 'dbr_Mode._Set._Clear.', 'dbr_Cand.theol.', 'dbr_Observe._Hack._Make.']
ENDS_EXCEPTIONS_NOT_REPLACE = ['dept.', 'bros.', 'litt.', 'corp.', 'gent.']
ENDS_EXCEPTIONS_REPLACE = ['hop.']

# Explicit rdf:type URI; it should not be counted as a resource.
SPARQL_TYPE = "http://www.w3.org/1999/02/22-rdf-syntax-ns#type"

# DBpedia SPARQL endpoint used by query_dbpedia, and the DBpedia graph IRI.
ENDPOINT, GRAPH = "http://dbpedia.org/sparql", "http://dbpedia.org"

In [None]:
from dataclasses import dataclass


@dataclass
class Flags:
    """Per-namespace switches used by encode_resources_in_interm_sparql.

    When ``keep_uris`` is true and the flag for a resource's namespace is set, that
    resource keeps its pure-SPARQL prefix (``dbr:``) instead of the intermediary
    token (``dbr_``). All flags default to False. ``Flags()`` behaves as before, and
    flags can now also be set at construction, e.g. ``Flags(dbr=True)``.
    """

    dbr: bool = False  # resources (dbr / http://dbpedia.org/resource/)
    dbp: bool = False  # properties (dbp / http://dbpedia.org/property/)
    dbc: bool = False  # categories (dbc / http://dbpedia.org/resource/Category:)
    dbo: bool = False  # ontology (dbo / http://dbpedia.org/ontology/)


## Utils

In [None]:
def escape_uri_for_regex(uri: str) -> str:
    """Escape ``uri`` so that it can be embedded literally in a regular expression.

    Delegates to re.escape. That function covers every regex metacharacter, including
    brackets, braces, anchors, '|' and backslash, which the previous hand-written
    replacements missed; those only handled ( ) . ? * +. The old quote replacement was
    a no-op, and its non-raw backslash literals raise invalid-escape warnings on
    recent Python versions.

    Args:
        uri: text to match literally.

    Returns:
        A pattern string that matches exactly ``uri``.
    """
    return re.escape(uri)


def convert_to_interm_sparql_except_resources(match: re.Match) -> str:
    """Lower-case the matched pure-SPARQL text and encode its syntax tokens.

    Every PURE_SPARQL_TO_INTERM rule is applied in order with str.replace. Whitespace
    runs are then collapsed to a single space.
    """
    text = match.group(0).lower()
    for rule in PURE_SPARQL_TO_INTERM:
        text = text.replace(rule['pure_sparql'], rule['interm_sparql'])
    return re.sub(r'\s+', ' ', text)


def encode_resources_in_interm_sparql(match: re.Match, keep_uris: bool, flags: "Flags" = None) -> str:
    """Rewrite a matched resource (``<...>``) for intermediary SPARQL.

    Full DBpedia URIs and short prefixes are rewritten with URI_SHORTENERS, in order.
    When ``keep_uris`` is true and ``flags`` enables the resource's namespace, the
    pure-SPARQL prefix (e.g. ``dbr:``) is used. Otherwise the intermediary token (e.g.
    ``dbr_``) is used and, if ``keep_uris`` is false, parentheses inside the resource
    become attr_open / attr_close. Whitespace runs are collapsed, and the first and last
    characters (presumably the angle brackets) are dropped.

    Args:
        match: regex match whose group(0) is the bracketed resource.
        keep_uris: keep pure-SPARQL prefixes for the namespaces enabled in ``flags``.
        flags: per-namespace switches; required when ``keep_uris`` is true.

    Returns:
        The encoded resource without its enclosing characters.

    Raises:
        ValueError: if ``keep_uris`` is true but no ``flags`` were provided.
    """
    resource: str = match.group(0)

    if flags is not None:
        should_replace = (
            (flags.dbr and resource.startswith(('<dbr', '<http://dbpedia.org/resource/')))
            or (flags.dbp and resource.startswith(('<dbp', '<http://dbpedia.org/property/')))
            # Category URIs live under /resource/Category: (the same prefix that
            # URI_SHORTENERS maps to dbc), not under a /category/ path.
            or (flags.dbc and resource.startswith(('<dbc', '<http://dbpedia.org/resource/Category:')))
            or (flags.dbo and resource.startswith(('<dbo', '<http://dbpedia.org/ontology/')))
        )
    elif keep_uris:
        raise ValueError("Keep uris is true but no flags provided")
    else:
        should_replace = False

    target_key = 'pure_sparql' if (keep_uris and should_replace) else 'interm_sparql'
    for r in URI_SHORTENERS:
        resource = resource.replace(r['match'], r[target_key])

    if not keep_uris:
        for r in PURE_SPARQL_TO_INTERM_RES:
            resource = resource.replace(r['pure_sparql'], r['interm_sparql'])

    resource = re.sub(r'\s+', ' ', resource)

    return resource[1:-1]


def lowercase_except_resources(match: re.Match) -> str:
    """Return the whole matched text lower-cased.

    NOTE(review): despite the name, this lower-cases the entire match. Any skipping of
    resources must come from the regex this is used with.
    """
    matched_text = match.group(0)
    return str.lower(matched_text)


def encode_resources_in_pure_sparql(match: re.Match) -> str:
    """Shorten DBpedia URIs in the match to pure-SPARQL prefixes (``dbo:``, ``dbr:``, ...).

    URI_SHORTENERS rules are applied in order. The first and last characters of the
    match (presumably the enclosing ``<`` ``>``) are then dropped.
    """
    text = match.group(0)
    for shortener in URI_SHORTENERS:
        text = text.replace(shortener['match'], shortener['pure_sparql'])
    return text[1:-1]


def replace_resource_with_uri(match: re.Match, expressions: List[str], replacement: str) -> str:
    """Replace the first listed expression found in group 1 with ``replacement``.

    Group 1 is prefixed with a space so that an expression at its very start can still
    be matched with the leading-whitespace anchor. Only the first expression (in list
    order) that occurs in it is considered, and the search then stops, even if the
    anchored regex does not match. Group 2 is appended unchanged.

    Args:
        match: match with two groups: text to search, and text to append.
        expressions: candidate literal expressions, tried in order.
        replacement: literal text that replaces the found expression.

    Returns:
        The (possibly rewritten) group 1, with its leading space, followed by group 2.
    """
    to_replace: str = ' ' + match.group(1)
    resource: str = match.group(2)

    for expr in expressions:
        if expr in to_replace:
            pattern = rf'(\s{escape_uri_for_regex(expr)})'
            # A function replacement inserts ``replacement`` literally. As a template
            # string, backslashes and group references in it would be interpreted by
            # re.sub, either corrupting the text or raising re.error.
            to_replace = re.sub(pattern, lambda _m: f' {replacement}', to_replace)
            break

    return to_replace + resource


def correct_parentheses_interm_sparql(match: re.Match) -> str:
    """Re-tag the par_open / par_close captured by groups 1 and 2 as sparql_open / sparql_close.

    Group spans are absolute positions in the searched string, so they are rebased on
    the start of the whole match. Text around and between the two groups is kept as-is.
    """
    retag = {'par_open': 'sparql_open', 'par_close': 'sparql_close'}
    text = match.group(0)
    base = match.start(0)
    open_start, open_end = match.start(1) - base, match.end(1) - base
    close_start, close_end = match.start(2) - base, match.end(2) - base

    pieces = [
        text[:open_start],
        retag[match.group(1)],
        text[open_end:close_start],
        retag[match.group(2)],
        text[close_end:],
    ]
    return ''.join(pieces)


def replace_quotes(match: re.Match) -> str:
    """Replace every quote captured by a participating group with ' sparql_quote '.

    Groups are processed from last to first so that earlier spans stay valid while
    later ones grow. Spans are rebased on the start of the whole match. Groups that
    did not participate (span -1) are skipped.
    """
    chars = list(match.group(0))
    base = match.start(0)

    for group_idx in reversed(range(1, len(match.groups()) + 1)):
        start, end = match.span(group_idx)
        if -1 in (start, end):
            continue
        chars[start - base:end - base] = ' sparql_quote '

    return ''.join(chars)


def correct_sep_dots(match: re.Match) -> str:
    """Decide whether a trailing '.' on a matched resource is a separator or part of its name.

    The match is returned padded with spaces (unchanged) unless it ends with '.' and:
      * its penultimate character is a digit, or its last '_' part is listed in
        ENDS_EXCEPTIONS_REPLACE, or
      * it is not in EXCEPTIONS_NOT_REPLACE, the text between its last two dots is
        longer than 4 characters, and its last '_' part is longer than 4 characters
        and not in ENDS_EXCEPTIONS_NOT_REPLACE.
    In those cases the dot is dropped and ' sep_dot ' is appended.
    """
    raw = match.group(0)
    token = raw.strip()
    unchanged = f' {raw} '

    if token[-1] != '.':
        return unchanged
    if token in EXCEPTIONS_NOT_REPLACE:
        return unchanged

    last_part = token.split('_')[-1]
    as_separator = f' {token[:-1]} sep_dot '

    if token[-2].isnumeric() or last_part.lower() in ENDS_EXCEPTIONS_REPLACE:
        return as_separator
    if len(token.split('.')[-2]) <= 4:
        return unchanged
    if len(last_part) <= 4 or last_part.lower() in ENDS_EXCEPTIONS_NOT_REPLACE:
        return unchanged
    return as_separator


def remove_spaces_from_resources(match: re.Match) -> str:
    """Delete spaces and backslashes from a matched resource and pad it with one space each side.

    NOTE(review): a later definition with the same name (in the "Escape SPARQL" cell)
    shadows this one and additionally decodes attr_dot. Once that cell has run,
    callers get the later version.
    """
    kept = [ch for ch in match.group(0) if ch != ' ' and ch != '\\']
    return ' ' + ''.join(kept) + ' '


def insert_spaces_math(match: re.Match) -> str:
    """Rebuild a math clause as a leading space followed by each captured group plus one space."""
    return ' ' + ''.join(group + ' ' for group in match.groups())


def insert_resources(m: re.Match, resources: List[str], v:bool = False) -> str:
    """Return the matched text with capture group ``i`` replaced by ``resources[i - 1]``.

    Groups are replaced left to right, and ``offset`` tracks how much the earlier
    replacements have grown or shrunk the text. Group spans are absolute positions in
    the searched string, so they are rebased on the start of the whole match before
    slicing; previously this was only correct for matches starting at index 0. Groups
    that did not participate in the match are left alone.

    Args:
        m: match whose groups mark the spans to replace (assumed non-overlapping and
            in increasing order).
        resources: replacement text for each group, in group order.
        v: print step-by-step debugging output.

    Returns:
        The rewritten matched text.
    """
    whole_match = list(m.group(0))
    if v:
        print("WM", whole_match)

    # Start negative so that absolute spans map onto indices within group(0).
    offset = -m.start(0)

    for r in range(len(m.groups())):
        start, end = m.span(r + 1)
        if start == -1:
            continue
        if v:
            print("group:", m.group(r + 1))
            print("repl span:", (start, end))
            print('replace:', whole_match[start + offset:end + offset])
            print('by:', resources[r])

        whole_match[start + offset:end + offset] = list(resources[r])
        if v:
            print('replaced:', whole_match)

        offset += len(resources[r]) - (end - start)

        if v:
            print('offset:', offset)

    return ''.join(whole_match)


def abstract_resources(match: re.Match) -> str:
    """Replace a prefixed DBpedia name with a generic placeholder token.

    A small allow-list of names (in both ``dbo:x`` and ``dbo_x`` spellings, compared
    case-insensitively after stripping) is returned verbatim. Otherwise, the first
    three characters must be a PLACEHOLDERS key, and the match becomes that
    placeholder (e.g. ' <ontology> '). The result is padded with one space each side.

    Raises:
        ValueError: if the first three characters are not a known prefix.
    """
    kept_verbatim = {
        'dbo:abstract', 'dbp:length', 'dbo:location', 'dbo:designer', 'dbp:complete', 'dbp:nativename', 'dbp:height',
        'dbo_abstract', 'dbp_length', 'dbo_location', 'dbo_designer', 'dbp_complete', 'dbp_nativename', 'dbp_height',
    }
    text = match.group(0)
    if text.strip().lower() in kept_verbatim:
        return f' {text} '

    prefix = text[:3]
    if prefix not in PLACEHOLDERS:
        raise ValueError(f"UNKNOWN PREFIX: {prefix}")
    return f' {PLACEHOLDERS[prefix]} '


### Escape SPARQL

In [None]:
def correct_interm_sparql(interm_sparql: str) -> str:
    """Normalize a raw intermediary-SPARQL query so that it can be decoded reliably.

    Steps, in order:
      1. Fold attr_open / attr_close into padded par_open / par_close and collapse whitespace.
      2. Pad a few tokens that are often glued to their neighbours (brack_*, dbp_length)
         and split 'var_uri.' into 'var_uri sep_dot'.
      3. Re-tag the parentheses of FILTER / COUNT / ORDER BY as sparql_open / sparql_close,
         encode regex quotes, let correct_sep_dots decide on trailing dots of resources,
         and space out math operators.
      4. Lower-case FILTER / COUNT / UNION, decode '%3F' to '?', collapse whitespace and strip.
      5. Rename the remaining par_* to attr_*, then sparql_* to par_*.

    Args:
        interm_sparql: one query in intermediary SPARQL.

    Returns:
        The corrected intermediary SPARQL query.
    """
    interm_sparql = interm_sparql.replace('attr_open', ' par_open ')
    interm_sparql = interm_sparql.replace('attr_close', ' par_close ')

    interm_sparql = re.sub(r'\s+', ' ', interm_sparql)

    interm_sparql = interm_sparql.replace('var_uri.', 'var_uri sep_dot')
    interm_sparql = interm_sparql.replace('brack_open', ' brack_open ')
    interm_sparql = interm_sparql.replace('brack_close', ' brack_close ')
    interm_sparql = interm_sparql.replace('dbp_length', ' dbp_length ')

    # Parentheses governed by FILTER / COUNT / ORDER BY are SPARQL syntax: tag them as
    # sparql_* so that the final renaming below keeps them apart from attr_* ones.
    interm_sparql = REPLACE_FILTER_PAR_RE.sub(correct_parentheses_interm_sparql, interm_sparql)
    interm_sparql = REPLACE_REST_PAR_RE.sub(correct_parentheses_interm_sparql, interm_sparql)
    interm_sparql = REPLACE_QUOTES_RE.sub(replace_quotes, interm_sparql)
    interm_sparql = GET_RESOURCES_INTERM_SPARQL_RE.sub(correct_sep_dots, interm_sparql)
    interm_sparql = INSERT_SPACES_MATH_RE.sub(insert_spaces_math, interm_sparql)

    for keyword in ('FILTER', 'COUNT', 'UNION'):
        interm_sparql = interm_sparql.replace(keyword, keyword.lower())

    interm_sparql = interm_sparql.replace('%3F', '?')

    interm_sparql = re.sub(r'\s+', ' ', interm_sparql)
    interm_sparql = interm_sparql.strip()

    # Order matters: par_* must become attr_* before sparql_* is renamed to par_*.
    interm_sparql = interm_sparql.replace('par_open', 'attr_open')
    interm_sparql = interm_sparql.replace('par_close', 'attr_close')
    interm_sparql = interm_sparql.replace('sparql_open', 'par_open')
    interm_sparql = interm_sparql.replace('sparql_close', 'par_close')

    return interm_sparql

In [None]:
def reverse_replacements(query: str) -> str:
    """Decode intermediary-SPARQL tokens back into pure SPARQL using REPLACEMENTS.

    For each row, in table order, the last element (the encoding) is replaced by the
    first element (the pure SPARQL text). The padded encoding is replaced first, then
    its whitespace-stripped form.
    """
    for row in REPLACEMENTS:
        pure, encoded = row[0], row[-1]
        query = query.replace(encoded, pure).replace(encoded.strip(), pure)
    return query


def escape_order_by(query: str) -> str:
    """Expand an encoded ``_oba_`` / ``_obd_`` marker into ``ORDER BY ASC(...)`` / ``DESC(...)``.

    REPLACE_ORDER_BY_RE starts with a greedy ``.*``, so it captures the last marker
    (plus the variable after it) on the line. Every occurrence of that exact
    "marker + variable" text is then rewritten. Queries without a marker are returned
    unchanged.

    Args:
        query: query in which var_ has already been decoded to '?'.

    Returns:
        The query with the ORDER BY clause expanded.
    """
    matches = REPLACE_ORDER_BY_RE.findall(query)
    if not matches:
        return query

    # findall yields one 3-tuple per match: (marker + variable, marker, variable).
    marked_clause, marker, variable = matches[0]

    # The regex is case-insensitive, so normalize the marker. Previously an upper-case
    # marker matched neither branch and left order_by unbound (UnboundLocalError).
    direction = 'ASC' if marker.lower() == '_oba_' else 'DESC'
    return query.replace(marked_clause, f"ORDER BY {direction}({variable})")


def remove_spaces_from_resources(match: re.Match) -> str:
    """Glue a matched resource name back together and decode its attr_dot tokens.

    Spaces and backslashes are deleted first. Every attr_dot variant, with or without
    surrounding spaces, is therefore already the bare token when it is turned back
    into '.'. The result is padded with one space on each side.
    """
    compact = ''.join(ch for ch in match.group(0) if ch not in ' \\')
    return ' ' + compact.replace('attr_dot', '.') + ' '


def add_var_in_resource_names(match: re.Match) -> str:
    """Undo a '?' that leaked into a resource name by restoring the original 'var_'."""
    matched_text: str = match.group(0)
    return matched_text.replace('?', 'var_')


def interm_sparql_to_pure_sparql(interm_query: str):
    """Decode one intermediary-SPARQL query into pure SPARQL (not yet escaped for querying).

    Pipeline: correct_interm_sparql, then reverse_replacements, then escape_order_by.
    After that it repairs any '?' that leaked into resource names, normalizes braces
    and a few keywords, glues resource names back together
    (remove_spaces_from_resources), decodes '%3F' and collapses whitespace.
    """
    query = escape_order_by(reverse_replacements(correct_interm_sparql(interm_query)))
    # Undo 'var_' -> '?' damage inside resource names such as dbr:Sarovar_Bridge.
    query = CATCH_VAR_IN_RESOURCE_NAME_RE.sub(add_var_in_resource_names, query)

    for old, new in (
        ("where{", "where {"),
        ("}", " } "),
        ("FILTER", "filter"),
        ("COUNT", "count"),
        ("UNION", "union"),
    ):
        query = query.replace(old, new)

    query = GET_RESOURCES_PURE_SPARQL_RE.sub(remove_spaces_from_resources, query)

    query = query.replace("dbp:length", " dbp:length ")
    query = re.sub(" ?%3F", "?", query)
    return re.sub(r"\s+", " ", query).strip()


def generate_pure_sparql(interm_sparql: List[str]) -> List[str]:
    """Decode a list of intermediary-SPARQL queries into pure SPARQL, one output per input.

    NOTE(review): this name is redefined later in the same cell with a single-query
    signature. Once the cell has run, calls resolve to that later definition.
    """
    return [interm_sparql_to_pure_sparql(interm_query) for interm_query in interm_sparql]


def escape_parentheses_in_entities(match: re.Match) -> str:
    """Backslash-escape every '(' and ')' in the matched resource name."""
    table = str.maketrans({'(': '\\(', ')': '\\)'})
    return match.group(0).translate(table)


def escape_parentheses(query: str) -> str:
    """Escape parentheses inside every prefixed DBpedia name found by GET_RESOURCES_PURE_SPARQL_RE."""
    return GET_RESOURCES_PURE_SPARQL_RE.sub(escape_parentheses_in_entities, query)


def escape_ampersands(query: str) -> str:
    """Backslash-escape every lone '&' in ``query``, leaving '&&' sequences untouched.

    Look-arounds replace the previous manual index walk. That walk escaped nothing when
    the first '&' was at position 0 (its loop condition was ``> 0``), and it raised
    IndexError when the query ended with '&'.

    Args:
        query: pure SPARQL query.

    Returns:
        The query with each '&' that is not adjacent to another '&' escaped.
    """
    return re.sub(r'(?<!&)&(?!&)', r'\\&', query)


def escape_dots_in_resources(match:re.Match) -> str:
    """Backslash-escape every '.' in the matched resource name."""
    return '\\.'.join(match.group(0).split('.'))


def escape_dots(query: str) -> str:
    """Escape dots inside every prefixed DBpedia name found by GET_RESOURCES_PURE_SPARQL_RE."""
    return GET_RESOURCES_PURE_SPARQL_RE.sub(escape_dots_in_resources, query)


def escape_plus(query: str) -> str:
    """Backslash-escape every '+' in ``query``.

    str.replace covers all occurrences. The previous index walk returned the query
    unchanged whenever its first '+' was at position 0.
    """
    return query.replace('+', '\\+')


def escape_star(query: str) -> str:
    """Backslash-escape every '*' in ``query``.

    str.replace covers all occurrences. The previous index walk returned the query
    unchanged whenever its first '*' was at position 0.
    """
    return query.replace('*', '\\*')


def escape_query(query: str) -> str:
    """Backslash-escape the characters that break prefixed DBpedia names in a pure SPARQL query.

    Parentheses and dots are escaped only inside matched dbo/dbp/dbc/dbr names. Lone
    '&' and every '+' are escaped everywhere, and ORDER BY markers are expanded. Then
    every quote, comma, '!' and '/' is escaped. '*' is not escaped.
    """
    for step in (escape_parentheses, escape_ampersands, escape_dots, escape_plus, escape_order_by):
        query = step(query)
    for char in ("'", ",", "!", "/"):
        query = query.replace(char, "\\" + char)
    return query


def escape_for_querying(pure_sparql: List[str]) -> List[str]:
    """Apply escape_query to every query, showing a tqdm progress bar."""
    return [escape_query(query) for query in tqdm(pure_sparql)]


def convert_to_pure_sparql(in_path_intem_sparql: str, out_path_pure_sparql: str, escape: bool = False) -> None:
    """Convert a file of intermediary-SPARQL queries (one per line) into pure SPARQL.

    Args:
        in_path_intem_sparql: UTF-8 input file with one intermediary query per line.
        out_path_pure_sparql: UTF-8 output file with one pure query per line (no trailing newline).
        escape: when true, also apply escape_query to every decoded query.
    """
    with open(in_path_intem_sparql, 'r', encoding="utf-8") as f:
        interm_sparql = f.read().strip().split('\n')

    # Decode per query directly. The list-based generate_pure_sparql is shadowed by the
    # single-query function of the same name, which fails when it is given a list.
    pure_sparql = [interm_sparql_to_pure_sparql(query) for query in interm_sparql]

    if escape:
        pure_sparql = escape_for_querying(pure_sparql)

    with open(out_path_pure_sparql, 'w', encoding="utf-8") as f:
        f.write('\n'.join(pure_sparql))


def generate_pure_sparql(interm_sparql: str, to_pr=False) -> str:
    """Decode and escape one intermediary-SPARQL query so that it can be sent to DBpedia.

    Args:
        interm_sparql: one query in intermediary SPARQL.
        to_pr: unused; kept so that existing positional callers keep working.

    Returns:
        The escaped pure SPARQL query, with ' var_b ' decoded to ' ?b '. Stray angle
        brackets are removed. Whitespace-delimited comparison operators (' < ', ' > ',
        ' <= ', ' >= '), which reverse_replacements produces from the math_* tokens, are
        kept; stripping every '<' / '>' used to turn FILTER comparisons into invalid SPARQL.
    """
    pure_sparql: str = interm_sparql_to_pure_sparql(interm_sparql)
    pure_sparql = escape_query(pure_sparql)
    pure_sparql = pure_sparql.replace(' var_b ', ' ?b ')
    # First alternative: a standalone comparison operator (kept). Second: any other
    # '<' or '>' (dropped).
    return re.sub(r'(?<!\S)([<>]=?)(?!\S)|[<>]', lambda m: m.group(1) or '', pure_sparql)


def generate_pure_sparql_for_report(partial_report: List[Dict[str, str]]) -> List[Dict[str, str]]:
    """Add escaped pure-SPARQL versions of the gold and predicted queries to every entry.

    Each entry must provide intermediary-SPARQL strings under 'trg' and 'predicted'.
    Entries are updated in place with 'pure_trg' and 'pure_predicted', and the same
    list is returned. The leftover debug argument (a check for entry id '4016') has
    been removed; it was unused downstream and made entries without an 'id' fail.
    """
    for entry in partial_report:
        entry['pure_trg'] = generate_pure_sparql(entry['trg'])
        entry['pure_predicted'] = generate_pure_sparql(entry['predicted'])

    return partial_report

### Query DBpedia

In [None]:
def query_dbpedia(query: str) -> Dict[Any, Any]:
    """Send ``query`` to ENDPOINT and return the response converted from JSON.

    Errors raised by SPARQLWrapper (network, endpoint, parsing) propagate to the caller.
    """
    client = SPARQLWrapper(ENDPOINT)
    client.setReturnFormat(JSON)
    client.setQuery(query)
    return client.query().convert()

In [None]:
def query_dbpedia_for_report(complete_report: List[Dict[str, str]], predicted=False) -> List[Dict[str, str]]:
    """Run the predicted (or gold) pure-SPARQL query of every entry against DBpedia.

    The outcome is stored in entry['dbpedia']['predicted'] (or ['trg']) as
    ``{'query_result': <response>, 'is_error': False}``. Any exception is printed and
    recorded as ``{'query_result': [], 'is_error': True}``, so one failing query does
    not abort the whole report. Entries are updated in place and the list is returned.
    """
    side = 'predicted' if predicted else 'trg'
    query_key = 'pure_' + side

    def store(entry, result, failed):
        # Reuse existing DBpedia data so that the other side's results are preserved.
        data = entry.get('dbpedia', {'predicted': {}, 'trg': {}})
        data[side]['query_result'] = result
        data[side]['is_error'] = failed
        entry['dbpedia'] = data

    for entry in tqdm(complete_report):
        try:
            result = query_dbpedia(entry[query_key])
            store(entry, result, False)
        except Exception as error:
            print(f"[ERROR] at query id {entry['id']}:")
            print(error)
            store(entry, [], True)

    return complete_report


### Get Metrics

In [None]:
def get_full_report_metrics(report: List[Dict]):
    """Compute the BLEU score and exact answer accuracy over a report.

    Args:
        report: entries with intermediary-SPARQL strings under 'predicted' and 'trg',
            and optionally DBpedia results under 'dbpedia'.

    Returns:
        A dict with 'bleu_score' and, when DBpedia results are present,
        'answer_accuracy' (both formatted as strings). Returns an empty dict for an
        empty report.
    """
    report_metrics = {}
    if not report:
        return report_metrics

    # BLEU over whitespace tokens. ':' is replaced by '_' first, so 'dbo:x' and
    # 'dbo_x' count as the same token.
    predicted = [entry['predicted'].replace(':', '_').split() for entry in report]
    trg = [[entry['trg'].replace(':', '_').split()] for entry in report]

    report_metrics['bleu_score'] = f"{bleu_score(predicted, trg)}"

    # Answer accuracy: the predicted and gold DBpedia results are identical.
    if 'dbpedia' in report[0]:
        correct_answer_count = sum(
            entry['dbpedia']['predicted'] == entry['dbpedia']['trg'] for entry in report
        )
        report_metrics['answer_accuracy'] = f"{correct_answer_count/len(report)}"
    else:
        print("No info available on answer accuracy")

    return report_metrics

In [None]:
def _is_empty_answer(answer: Dict) -> bool:
    """Return True when one DBpedia answer counts as "empty" for precision and recall.

    ``answer`` is one side of entry['dbpedia'], as written by query_dbpedia_for_report:
    ``{'query_result': ..., 'is_error': bool}``. An answer is empty if the query
    failed, if it is a boolean result that is false, if it has no bindings, or if its
    first binding has a key 'value' equal to 0.
    """
    if answer['is_error']:
        return True
    result = answer['query_result']
    if 'boolean' in result:
        return not result['boolean']
    bindings = result['results']['bindings']
    if not bindings:
        return True
    # NOTE(review): in SPARQL JSON results each binding maps variable names to
    # {'type', 'value'} dicts. This check therefore only applies to a variable literally
    # named 'value', and compares that dict to 0, so it likely never matches. It is kept
    # unchanged to preserve the metric definition.
    first = bindings[0]
    return 'value' in first and first['value'] == 0


def get_precision_recall(report: List[Dict]):
    """Compute answer-level precision and recall of predicted vs. gold DBpedia results.

    Follows https://github.com/semantic-systems/NLIWOD
    (qa.ml/src/main/java/org/aksw/mlqa/utils/PrintCSV.java):
      * TN: predicted empty and gold empty
      * FN: predicted empty, gold not empty
      * FP: predicted not empty and (gold empty, or the answers differ)
      * TP: predicted not empty and the answers are identical
    "Empty" is decided by _is_empty_answer. Answers are compared as whole
    ``{'query_result', 'is_error'}`` dicts. Each entry is classified exactly once, and
    entries whose answers cannot be inspected are printed and skipped.

    Args:
        report: entries carrying entry['dbpedia']['predicted'] and entry['dbpedia']['trg'].

    Returns:
        ``{'recall': float, 'precision': float}``, with 0.0 when a denominator is 0.
        Returns ``{}`` when the report is empty or has no DBpedia results.
    """
    report_metrics = {}
    if not report or 'dbpedia' not in report[0]:
        return report_metrics

    tp_cnt = fn_cnt = fp_cnt = 0

    for entry in report:
        pred, gold = entry['dbpedia']['predicted'], entry['dbpedia']['trg']
        try:
            pred_is_empty = _is_empty_answer(pred)
            gold_is_empty = _is_empty_answer(gold)
        except Exception as e:
            # Malformed answer: report it and leave it out of every count.
            print("ERROR")
            print(e)
            print(pred)
            print(gold)
            continue

        if pred_is_empty and gold_is_empty:
            continue  # true negative: does not enter precision or recall
        if pred_is_empty:
            fn_cnt += 1
        elif gold_is_empty or pred != gold:
            fp_cnt += 1
        else:
            tp_cnt += 1

    report_metrics['recall'] = tp_cnt / (tp_cnt + fn_cnt) if tp_cnt + fn_cnt else 0.0
    report_metrics['precision'] = tp_cnt / (tp_cnt + fp_cnt) if tp_cnt + fp_cnt else 0.0

    return report_metrics

In [None]:
def generate_report(error_report_path: str, run_template_metrics: bool, run_dbpedia: bool, out) -> Dict[str, float]:
    """Complete one error report with pure SPARQL and DBpedia answers, save it, and score it.

    Args:
        error_report_path: JSON file holding a list of entries with intermediary-SPARQL
            strings under 'trg' and 'predicted'.
        run_template_metrics: currently unused.
        run_dbpedia: when true, run the gold and predicted queries against DBpedia.
        out: path where the completed report is written as JSON.

    Returns:
        The metrics from get_precision_recall (recall and precision).
    """
    with open(error_report_path, 'r', encoding='utf-8') as f:
        report = json.load(f)

    complete_report = generate_pure_sparql_for_report(report)

    if run_dbpedia:
        print("QUERYING DBPEDIA FOR EXPECTED RESULT...")
        complete_report = query_dbpedia_for_report(complete_report, predicted=False)
        print("QUERYING DBPEDIA FOR PREDICTED RESULT...")
        complete_report = query_dbpedia_for_report(complete_report, predicted=True)

    with open(out, 'w', encoding='utf-8') as f:
        json.dump(complete_report, f)

    full_report_metrics = get_precision_recall(complete_report)

    print('FULL REPORT:')
    for key, val in full_report_metrics.items():
        print(f'\t{key}: {val}')

    return full_report_metrics

## Main

In [None]:
# One sub-folder per trained run. glob order depends on the filesystem, so sort for a
# stable processing order, and skip stray files that are not run folders.
models_paths = sorted(p for p in glob.glob(f"{MODELS_FOLDER}/*") if os.path.isdir(p))
reports_paths = [f'{m}/{REPORT_FILENAME}' for m in models_paths]
print(reports_paths)

In [None]:
results = []
for report_path in reports_paths:
    print(report_path)
    # Write the completed report next to the report it came from. Deriving the folder
    # from the loop index assumed that glob returned folders named 1..N in numeric
    # order, which can pair a report with the wrong run folder.
    out = os.path.join(os.path.dirname(report_path), 'error_report_complete.json')
    res = generate_report(report_path, False, True, out)
    print(res)
    results.append(res)
    print('-----------------')

In [None]:
def _mean(values):
    """Return the arithmetic mean of ``values``, or NaN when there are none."""
    values = list(values)
    return sum(values) / len(values) if values else float('nan')


def _f1(precision, recall):
    """Return the harmonic mean of precision and recall, or 0.0 when both are 0."""
    total = precision + recall
    return 2 * precision * recall / total if total else 0.0


if reports_paths:
    print(reports_paths[0])
precisions = [float(r['precision']) for r in results]
recalls = [float(r['recall']) for r in results]
print("PRECISION AVERAGE:", _mean(precisions))
print("RECALL AVERAGE:", _mean(recalls))
# Mean of the per-run F1 scores (not the F1 of the averaged precision and recall).
print("F1 AVERAGE:", _mean(_f1(p, rc) for p, rc in zip(precisions, recalls)))