<a href="https://colab.research.google.com/github/abdmomin/text2sql/blob/main/evaluation_for_sql.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
%pip install -Uq sqlglot

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m461.8/461.8 kB[0m [31m6.0 MB/s[0m eta [36m0:00:00[0m
[?25h[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
ibis-framework 9.5.0 requires sqlglot<25.21,>=23.4, but you have sqlglot 26.18.1 which is incompatible.[0m[31m
[0m

In [None]:
import sqlglot
from sqlglot.expressions import Select


def extract_sql_components(sql: str) -> dict:
    """Break a SQL string into the pieces used for partial-match scoring.

    The query is parsed with ``sqlglot.parse_one``. Column and table names
    are collected from the whole expression tree and lower-cased.

    Args:
        sql: The SQL statement to analyse.

    Returns:
        A dict with the keys ``"columns"``, ``"tables"`` and ``"keywords"``,
        each mapping to a set of strings. If parsing raises, all three sets
        are empty and an extra ``"error"`` key holds the exception message.
    """
    try:
        tree = sqlglot.parse_one(sql)
    except Exception as e:
        # Best effort: report the failure to the caller instead of raising.
        return {
            "columns": set(),
            "tables": set(),
            "keywords": set(),
            "error": str(e),
        }

    exp = sqlglot.expressions
    column_names: set[str] = {
        node.name.lower() for node in tree.find_all(exp.Column) if node.name
    }
    table_names: set[str] = {
        node.name.lower() for node in tree.find_all(exp.Table) if node.name
    }

    # (key in the Select node's ``args``, label reported to the caller).
    clause_labels = (
        ("where", "where"),
        ("group", "group"),
        ("order", "order"),
        ("having", "having"),
        ("limit", "limit"),
        ("joins", "join"),
    )

    # Clauses are only inspected on the root expression, and only when that
    # root is a SELECT; any other statement type yields no keywords.
    clause_names: set[str] = set()
    if isinstance(tree, Select):
        clause_names = {
            label for arg, label in clause_labels if tree.args.get(arg)
        }

    return {
        "columns": column_names,
        "tables": table_names,
        "keywords": clause_names,
    }


def jaccard_similarity(set1: set[str], set2: set[str]) -> float:
    """Return the Jaccard index ``|set1 & set2| / |set1 | set2|``.

    Args:
        set1: First set of items.
        set2: Second set of items.

    Returns:
        A value in ``[0.0, 1.0]``. Two empty sets are treated as identical
        and score ``1.0``.
    """
    union = set1 | set2
    if not union:
        # Both sets are empty: the ratio would be 0/0. Two empty sets agree
        # completely, so report a perfect match instead of raising
        # ZeroDivisionError.
        return 1.0
    return len(set1 & set2) / len(union)


def compute_sql_partial_match(predicted_sql: str, gold_sql: str) -> dict:
    """Score a predicted query against a gold query component by component.

    Both queries are decomposed with ``extract_sql_components``; the column,
    table and keyword sets are then compared with ``jaccard_similarity``.

    Args:
        predicted_sql: The SQL produced by the system under evaluation.
        gold_sql: The reference SQL.

    Returns:
        A dict with ``"column_score"``, ``"table_score"``,
        ``"keyword_score"`` and ``"total_score"``. The total is a weighted
        sum: 0.5 for columns, 0.3 for tables and 0.2 for keywords.
    """
    predicted_parts = extract_sql_components(predicted_sql)
    gold_parts = extract_sql_components(gold_sql)

    def similarity(component: str) -> float:
        # Compare the same component of both queries.
        return jaccard_similarity(predicted_parts[component], gold_parts[component])

    column_similarity = similarity("columns")
    table_similarity = similarity("tables")
    keyword_similarity = similarity("keywords")

    result = {
        "column_score": column_similarity,
        "table_score": table_similarity,
        "keyword_score": keyword_similarity,
    }
    result["total_score"] = (
        0.5 * column_similarity + 0.3 * table_similarity + 0.2 * keyword_similarity
    )
    return result

In [None]:
import re

def normalize_sql(sql: str) -> list[str]:
    """Lower-case a SQL string and split it into whitespace-delimited tokens.

    The characters ``( ) , ; < > =`` are treated as separators and do not
    appear in the output.

    Args:
        sql: The SQL text to tokenize.

    Returns:
        The list of lower-cased tokens, in order of appearance.
    """
    # Substitute a space rather than deleting the symbol: deleting it would
    # fuse neighbours that have no whitespace between them, e.g.
    # "name,price" -> "nameprice" or "price>=100" -> "price100".
    return re.sub(r'[(),;<>=]', ' ', sql.lower()).split()


def compute_f1(prediction: str|dict[str, float], truth: str|dict[str, float]) -> int | float:
  if isinstance(prediction, str) and isinstance(truth, str):
    pred_tokens = normalize_sql(prediction)
    truth_tokens = normalize_sql(truth)

  pred_tokens = prediction
  truth_tokens = truth

  # if either the prediction or the truth has no-answer then f1 = 1 if they agree, 0 otherwise
  if len(pred_tokens) == 0 or len(truth_tokens) == 0:
    return int(pred_tokens == truth_tokens)

  common_tokens = set(pred_tokens) & set(truth_tokens)

  # if there are no common tokens then f1 = 0
  if len(common_tokens) == 0:
    return 0

  prec = len(common_tokens) / len(pred_tokens)
  rec = len(common_tokens) / len(truth_tokens)

  return 2 * (prec * rec) / (prec + rec)

In [None]:
def f1_score(pred_sql, gold_sql):
    """Compute per-component and macro-averaged F1 for two SQL queries.

    Both queries are decomposed with ``extract_sql_components`` and each
    component set (columns, tables, keywords) is scored with ``compute_f1``.

    Args:
        pred_sql: The predicted SQL string.
        gold_sql: The reference SQL string.

    Returns:
        A dict with ``"column_f1"``, ``"table_f1"``, ``"keyword_f1"`` and
        ``"macro_f1"``, the last being the unweighted mean of the other
        three.
    """
    predicted_parts = extract_sql_components(pred_sql)
    gold_parts = extract_sql_components(gold_sql)

    scores = {
        "column_f1": compute_f1(predicted_parts["columns"], gold_parts["columns"]),
        "table_f1": compute_f1(predicted_parts["tables"], gold_parts["tables"]),
        "keyword_f1": compute_f1(predicted_parts["keywords"], gold_parts["keywords"]),
    }

    # Macro average: every component counts equally. A weighted variant
    # (0.5 columns / 0.3 tables / 0.2 keywords) is deliberately not reported.
    scores["macro_f1"] = (
        scores["column_f1"] + scores["table_f1"] + scores["keyword_f1"]
    ) / 3
    return scores

In [None]:
# Smoke test: the two queries differ only in column order and in the comparison operator (> vs >=).
predicted_sql = "SELECT name, price FROM products WHERE price > 100"
gold_sql = "SELECT price, name FROM products WHERE price >= 100"

compute_sql_partial_match(predicted_sql, gold_sql)

{'column_score': 1.0,
 'table_score': 1.0,
 'keyword_score': 1.0,
 'total_score': 1.0}

In [None]:
extract_sql_components(predicted_sql), extract_sql_components(gold_sql)

({'columns': {'name', 'price'}, 'tables': {'products'}, 'keywords': {'where'}},
 {'columns': {'name', 'price'}, 'tables': {'products'}, 'keywords': {'where'}})

In [None]:
compute_f1(predicted_sql, gold_sql)

0.5742574257425743

In [None]:
f1_score(predicted_sql, gold_sql)

{'column_f1': 1.0, 'table_f1': 1.0, 'keyword_f1': 1.0, 'macro_f1': 1.0}

In [None]:
# Smoke test: the two queries share only the `name` column and the WHERE clause; their tables differ.
pred = "SELECT name, email FROM customers WHERE name='John'"
gold = "SELECT price, name FROM products WHERE price >= 100"

compute_f1(pred, gold)

0.49019607843137253

In [None]:
compute_sql_partial_match(pred, gold)

{'column_score': 0.3333333333333333,
 'table_score': 0.0,
 'keyword_score': 1.0,
 'total_score': 0.3666666666666667}

In [None]:
f1_score(pred, gold)

{'column_f1': 0.5, 'table_f1': 0, 'keyword_f1': 1.0, 'macro_f1': 0.5}

In [None]:
extract_sql_components(pred), extract_sql_components(gold)

({'columns': {'email', 'name'},
  'tables': {'customers'},
  'keywords': {'where'}},
 {'columns': {'name', 'price'}, 'tables': {'products'}, 'keywords': {'where'}})