In [None]:
from process_sql import get_schema, Schema, get_sql
from evaluation import build_valid_col_units, rebuild_sql_val, rebuild_sql_col, build_foreign_key_map_from_json, eval_exec_match, Evaluator
import os, nltk
# Locations of the evaluation resources, relative to the working directory.
db_dir = "database"
table = "table.json"

# Fetch the NLTK 'punkt' resource (presumably required by the imported SQL
# parsing helpers -- verify against process_sql).
nltk.download('punkt')

# Foreign-key maps built from the table description file; indexed by database
# name when scoring predictions.
kmaps = build_foreign_key_map_from_json(table)

In [None]:
def _empty_sql():
    """Return a fresh, empty parsed-SQL structure standing in for an unparsable prediction."""
    return {
        "except": None,
        "from": {
            "conds": [],
            "table_units": []
        },
        "groupBy": [],
        "having": [],
        "intersect": None,
        "limit": None,
        "orderBy": [],
        "select": [
            False,
            []
        ],
        "union": None,
        "where": []
    }


def evaluate(label_str, pred_str, label_dbname, db_dir, kmaps):
    """Compute execution accuracy and exact-match accuracy of predicted SQL.

    Args:
        label_str: Iterable of gold SQL query strings.
        pred_str: Iterable of predicted SQL query strings, aligned with
            ``label_str``.
        label_dbname: Iterable of database names, one per example. Each
            database is expected at ``<db_dir>/<name>/<name>.sqlite``.
        db_dir: Directory containing one sub-directory per database.
        kmaps: Mapping from database name to its foreign-key map.

    Returns:
        A dict ``{"execute": float, "exact": float}`` holding the fraction of
        examples whose prediction matched the gold query by execution and by
        exact match respectively. Both values are ``0.0`` when there are no
        examples.

    Raises:
        Whatever ``get_sql`` raises when a *gold* query cannot be parsed;
        only failures while parsing the prediction are tolerated.

    Note:
        The three input sequences are combined with ``zip``, so extra items
        in the longer sequences are silently ignored.
    """
    total = 0
    execute = 0.0
    exact = 0.0

    for p_str, g_str, db_name in zip(pred_str, label_str, label_dbname):
        db = os.path.join(db_dir, db_name, db_name + ".sqlite")
        schema = Schema(get_schema(db))
        g_sql = get_sql(schema, g_str)
        total += 1

        try:
            p_sql = get_sql(schema, p_str)
        except Exception:
            # An unparsable prediction is scored against the gold query as an
            # empty query. A new dict is built every time so that nothing is
            # shared between examples.
            p_sql = _empty_sql()

        # Rebuild both parses for value evaluation.
        kmap = kmaps[db_name]
        g_valid_col_units = build_valid_col_units(g_sql['from']['table_units'], schema)
        g_sql = rebuild_sql_val(g_sql)
        g_sql = rebuild_sql_col(g_valid_col_units, g_sql, kmap)
        p_valid_col_units = build_valid_col_units(p_sql['from']['table_units'], schema)
        p_sql = rebuild_sql_val(p_sql)
        p_sql = rebuild_sql_col(p_valid_col_units, p_sql, kmap)

        # Count execution matches in the counter that is actually returned.
        if eval_exec_match(db, p_str, g_str, p_sql, g_sql):
            execute += 1.0

        evaluator = Evaluator()
        if evaluator.eval_exact_match(p_sql, g_sql):
            exact += 1.0

    if total == 0:
        # Nothing to score: report zero accuracy instead of dividing by zero.
        return {"execute": 0.0, "exact": 0.0}

    return {"execute": execute / total, "exact": exact / total}

In [None]:
# Template for a `compute_metrics` callback built on top of `evaluate`.
# The function is kept inside a string literal, so it is NOT defined when this
# cell runs; copy it into the training script and adapt it there.
# It relies on names that are not defined in the cells shown here:
#   - `tokenizer`: must provide `batch_decode` and `pad_token_id`.
#   - `pred.label_dbname`: list of database names, one per example
#     (presumably a custom attribute added by the caller -- verify).
# Label positions equal to -100 are replaced with the pad token id before
# decoding, which modifies `pred.label_ids` in place.
'''
def compute_metrics(pred):
    labels_ids = pred.label_ids
    label_dbname = pred.label_dbname # list of names of databases
    pred_ids = pred.predictions

    # all unnecessary tokens are removed
    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    labels_ids[labels_ids == -100] = tokenizer.pad_token_id
    label_str = tokenizer.batch_decode(labels_ids, skip_special_tokens=True)

    output = evaluate(label_str, pred_str, label_dbname, db_dir, kmaps)

    return {
        "execution_accuracy": round(output["execute"], 4),
        "exact_matching": round(output["exact"], 4),
    }
'''