In [10]:
# === Q6: Quantitative results (Dev) ===
import os
from utils import compute_metrics

# Input files: the paths the training script writes to by default.
GT_SQL_DEV      = "data/dev.sql"
GT_RECORDS_DEV  = "records/ground_truth_dev.pkl"
PRED_SQL_DEV    = "results/t5_ft_ft_experiment_dev.sql"
PRED_RECORD_DEV = "records/t5_ft_ft_experiment_dev.pkl"

# Fail fast before computing anything; the assertion message names the missing file.
for required_path in (GT_SQL_DEV, GT_RECORDS_DEV, PRED_SQL_DEV, PRED_RECORD_DEV):
    assert os.path.exists(required_path), required_path

# compute_metrics is called positionally as (gt_sql, pred_sql, gt_records, pred_records);
# its 4-tuple result is unpacked as query EM, record EM, record F1 and per-example error messages
# (names as used here — NOTE(review): confirm against utils.compute_metrics).
sql_em, record_em, record_f1, error_msgs = compute_metrics(
    GT_SQL_DEV,
    PRED_SQL_DEV,
    GT_RECORDS_DEV,
    PRED_RECORD_DEV,
)

# Percent formatting with two decimals, so values can be pasted straight into a table.
def pct(x):
    """Render a fraction as a percentage string with two decimals (0.7242 -> '72.42')."""
    return format(x * 100, ".2f")

# Report dev metrics as percentages; test-set numbers are filled in by hand from the leaderboard.
print(
    "=== Dev Metrics ===",
    f"Query EM (Dev): {pct(sql_em)}",
    f"Record F1 (Dev): {pct(record_f1)}",
    "\n=== Test Metrics ===",
    "Query EM (Test): N/A (fill from leaderboard)",
    "Record F1 (Test): N/A (fill from leaderboard)",
    sep="\n",
)


=== Dev Metrics ===
Query EM (Dev): 2.36
Record F1 (Dev): 72.42

=== Test Metrics ===
Query EM (Test): N/A (fill from leaderboard)
Record F1 (Test): N/A (fill from leaderboard)


In [9]:
# === Q6: Qualitative error analysis on Dev (robust: ignore records structure) ===
import os, re, pickle
from collections import Counter, defaultdict

GT_NL_DEV       = "data/dev.nl"
# NOTE(review): these point at the "_ec" experiment, while the quantitative-metrics cell reads
# results/t5_ft_ft_experiment_dev.sql and records/t5_ft_ft_experiment_dev.pkl. The error
# analysis below may therefore not describe the run whose metrics are reported — confirm
# which experiment is intended. Reassigning these names also overwrites the values that
# cell set in the shared kernel namespace.
PRED_SQL_DEV    = "results/t5_ft_ft_experiment_ec_dev.sql"
PRED_RECORD_DEV = "records/t5_ft_ft_experiment_ec_dev.pkl"

# Check each input separately (as the quantitative cell does) so a failure names the missing file.
assert os.path.exists(GT_NL_DEV), GT_NL_DEV
assert os.path.exists(PRED_SQL_DEV), PRED_SQL_DEV
assert os.path.exists(PRED_RECORD_DEV), PRED_RECORD_DEV

# Read the dev NL questions and predicted SQL. Line i of each file belongs to example i, so
# blank lines must NOT be filtered out in the middle of a file: an empty prediction would
# vanish and shift every later example out of alignment with error_msgs. Only trailing
# blank lines (e.g. from a final newline) are dropped.
def _read_aligned_lines(path):
    """Return the stripped lines of `path`, removing only trailing blank lines."""
    with open(path, "r", encoding="utf-8") as f:
        lines = [ln.strip() for ln in f]
    while lines and lines[-1] == "":
        lines.pop()
    return lines

dev_nl = _read_aligned_lines(GT_NL_DEV)
pred_sqls = _read_aligned_lines(PRED_SQL_DEV)

# Load (records, error_msgs); only error_msgs is used for bucketing, so the internal
# structure of records does not matter here.
# NOTE(review): pickle.load can execute arbitrary code — only load files this project produced.
with open(PRED_RECORD_DEV, "rb") as f:
    records, error_msgs = pickle.load(f)

TOTAL = len(error_msgs)
print(f"Total dev examples: {TOTAL}")

# The three sequences should be parallel. If they are not, say so explicitly instead of
# truncating silently, then fall back to the shortest length so the cell still runs.
N = min(TOTAL, len(dev_nl), len(pred_sqls))
if not (TOTAL == len(dev_nl) == len(pred_sqls)):
    print(
        f"WARNING: length mismatch (error_msgs={TOTAL}, nl={len(dev_nl)}, "
        f"pred_sql={len(pred_sqls)}); truncating to {N} - examples may be misaligned"
    )
dev_nl = dev_nl[:N]
pred_sqls = pred_sqls[:N]
error_msgs = error_msgs[:N]

# Error categories as (label, compiled regex), tried in order; the first match wins.
# All but the last are case-insensitive; the last is a catch-all for any message with at
# least one non-newline character. Add or remove categories as needed.
BUCKETS = [
    (label, re.compile(rx, re.IGNORECASE))
    for label, rx in (
        ("Syntax error",      r"syntax|parse"),
        ("No such table/col", r"no such (table|column)|does not exist|unknown column"),
        ("Ambiguous column",  r"ambiguous"),
        ("Type mismatch",     r"type mismatch|incompatible datatype|cannot cast"),
        ("Value/arith error", r"divide by|overflow|out of range|invalid value"),
        ("Permission/engine", r"permission|not allowed|engine|timeout"),
    )
] + [("Other exec error", re.compile(r".+"))]  # catch-all

# Tally each failed example into the first matching bucket and keep up to three samples
# per bucket: NL question, predicted SQL and error message, each cut to 200 characters.
bucket_counts = Counter()
bucket_examples = defaultdict(list)

for idx, err in enumerate(error_msgs):
    # An empty (falsy) message means the predicted query executed successfully.
    if not err:
        continue
    lowered = err.lower()
    category = next(
        (label for label, regex in BUCKETS if regex.search(lowered)),
        "Other exec error",
    )
    bucket_counts[category] += 1

    samples = bucket_examples[category]
    if len(samples) < 3:
        samples.append({
            "nl_query": dev_nl[idx][:200],
            "pred_sql": pred_sqls[idx][:200],
            "error_msg": err.strip()[:200],
        })

# Print COUNT/N for every non-empty bucket (in BUCKETS order) followed by its sample errors.
print("\n=== Qualitative Error Buckets (Dev) ===")
for bucket_name, _pattern in BUCKETS:
    hits = bucket_counts[bucket_name]
    if not hits:
        continue
    report_lines = [f"- {bucket_name}: {hits}/{N}"]
    for sample in bucket_examples[bucket_name]:
        report_lines += [
            "  • Example:",
            f"    NL : {sample['nl_query']}",
            f"    SQL: {sample['pred_sql']}",
            f"    MSG: {sample['error_msg']}",
        ]
    report_lines.append("")  # blank separator line between buckets
    print("\n".join(report_lines))


Total dev examples: 466

=== Qualitative Error Buckets (Dev) ===
- Syntax error: 27/466
  • Example:
    NL : list all arrivals from any airport to baltimore on thursday morning arriving before 9am
    SQL: SELECT DISTINCT flight_1.flight_id FROM flight flight_1, airport_service airport_service_1, city city_1, airport_service airport_service_2, city city_2, days days_1, date_day date_day_1 WHERE flight_1
    MSG: OperationalError: near "flight_1": syntax error
  • Example:
    NL : please give me the flight times i would like to fly from boston to baltimore in the morning before 8
    SQL: SELECT DISTINCT flight_1.departure_time, flight_1.arrival_time FROM flight flight_1, airport_service airport_service_1, city city_1, airport_service airport_service_2, city city_2 WHERE flight_1.from_
    MSG: OperationalError: near "800": syntax error
  • Example:
    NL : i would like information on flights from pittsburgh to baltimore arriving in baltimore before 10am on thursday
    SQL: SELECT D