In [3]:
# -*- coding: utf-8 -*-
"""
Windows 下使用：
1) 把这个脚本文件放到和 *.jsonl 相同的目录；
2) 在该目录执行:  python add_ctx_windows.py
3) 新文件会输出到 ./with_ctx/ 目录
"""

import os
import re
import json
from pathlib import Path

# ===== Mirror of the Scala-side `allProfiles` ordering (must stay exactly identical) =====
def make_all_profiles():
    """Return the first 50 Spark SQL config profiles in Scala `allProfiles` order.

    Profiles enumerate every combination of AQE x CBO x preferSortMergeJoin x
    autoBroadcastJoinThreshold x shuffle.partitions. The order is that of nested
    loops with AQE outermost and shuffle partitions innermost. The CBO flag drives
    both ``spark.sql.cbo.enabled`` and ``spark.sql.cbo.joinReorder.enabled``.
    All values are strings. Only the first 50 combinations are kept.
    """
    flags = ["true", "false"]
    broadcast_thresholds = ["-1", "64MB", "128MB", "256MB"]
    shuffle_partitions = ["64", "128", "200", "400", "800"]

    every_profile = [
        {
            "spark.sql.adaptive.enabled": adaptive,
            "spark.sql.cbo.enabled": cost_based,
            "spark.sql.cbo.joinReorder.enabled": cost_based,
            "spark.sql.join.preferSortMergeJoin": sort_merge,
            "spark.sql.autoBroadcastJoinThreshold": threshold,
            "spark.sql.shuffle.partitions": partitions,
        }
        for adaptive in flags
        for cost_based in flags
        for sort_merge in flags
        for threshold in broadcast_thresholds
        for partitions in shuffle_partitions
    ]
    return every_profile[:50]


# Built once at import; main() looks profiles up with a 1-based index (pIdx - 1).
ALL_PROFILES = make_all_profiles()
# Number of TPC-DS v2.4 queries per profile (keep consistent with your environment).
Q = 103

def idx_to_pq(N: int, Q: int):
    """Split a 1-based global file number into a 1-based (profile, query) pair.

    Files are numbered consecutively with ``Q`` queries per profile, so
    ``N == (pIdx - 1) * Q + qIdx``. Inputs outside ``N >= 1`` are not rejected;
    callers are expected to range-check the returned profile index.
    """
    block, offset = divmod(N - 1, Q)
    return block + 1, offset + 1

def as_ctx_dict(kv: dict) -> dict:
    """Project a full profile onto the fields stored under each record's ``ctx``.

    Values are copied verbatim as the same strings the Scala side uses. Numeric
    encoding (0/1, log1p, ...) is left to the training side. Any key not listed
    below, such as ``spark.sql.cbo.joinReorder.enabled``, is dropped. The result
    keeps the listed key order. Raises KeyError if *kv* lacks any listed key.
    """
    wanted = (
        "spark.sql.adaptive.enabled",
        "spark.sql.cbo.enabled",
        "spark.sql.join.preferSortMergeJoin",
        "spark.sql.autoBroadcastJoinThreshold",
        "spark.sql.shuffle.partitions",
    )
    return {key: kv[key] for key in wanted}

def _find_numbered_jsonl(src_dir: Path):
    """Return ``(N, path)`` pairs for each ``N.jsonl`` file directly in *src_dir*, sorted by N."""
    num_re = re.compile(r"^(\d+)\.jsonl$", re.IGNORECASE)
    found = []
    for p in src_dir.iterdir():
        if not p.is_file():
            continue
        m = num_re.match(p.name)
        if m:
            found.append((int(m.group(1)), p))
    found.sort(key=lambda t: t[0])  # ascending by file number
    return found


def _add_ctx_to_file(src: Path, dst: Path, ctx: dict) -> None:
    """Copy JSON-lines records from *src* to *dst*, adding ``ctx`` where it is absent.

    Blank lines are dropped. Records that already carry a ``ctx`` key are
    re-serialized unchanged. Output goes to a temporary sibling file that is
    moved onto *dst* only after every line converted, so a bad input line never
    leaves a truncated *dst* behind.

    Raises:
        ValueError: a non-blank line is not valid JSON or is not a JSON object;
            the message names the file and the 1-based line number.
    """
    tmp = dst.with_name(dst.name + ".tmp")
    try:
        # utf-8-sig reads BOM-less UTF-8 identically, and also strips a leading
        # BOM that would otherwise make json.loads reject the first line.
        with open(src, "r", encoding="utf-8-sig", newline="") as fin, \
             open(tmp, "w", encoding="utf-8", newline="") as fout:
            for lineno, line in enumerate(fin, start=1):
                line = line.strip()
                if not line:
                    continue
                try:
                    obj = json.loads(line)
                except json.JSONDecodeError as err:
                    raise ValueError(
                        f"{src}:{lineno}: invalid JSON: {err.msg}"
                    ) from err
                if not isinstance(obj, dict):
                    raise ValueError(
                        f"{src}:{lineno}: expected a JSON object, "
                        f"got {type(obj).__name__}"
                    )
                if "ctx" not in obj:
                    obj["ctx"] = ctx
                # newline="" above keeps "\n" literal, so output is LF even on Windows.
                fout.write(json.dumps(obj, ensure_ascii=False) + "\n")
        os.replace(tmp, dst)
    finally:
        # After a successful os.replace tmp is gone; this only cleans up failures.
        if tmp.exists():
            tmp.unlink()


def main():
    """Add a ``ctx`` config block to every ``N.jsonl`` in the current directory.

    Each ``N.jsonl`` is mapped to a profile via ``idx_to_pq(N, Q)`` and rewritten
    to ``./with_ctx/N.jsonl``. Files whose profile index falls outside
    ``ALL_PROFILES`` are skipped with a warning, and the skip count is reported.
    """
    # Use the current working directory as the jsonl directory.
    src_dir = Path.cwd()
    dst_dir = src_dir / "with_ctx"
    dst_dir.mkdir(parents=True, exist_ok=True)

    files = _find_numbered_jsonl(src_dir)
    print(f"[info] 发现 {len(files)} 个 jsonl 文件（形如 N.jsonl）")

    skipped = 0
    for N, fp in files:
        pIdx, _ = idx_to_pq(N, Q)
        if not (1 <= pIdx <= len(ALL_PROFILES)):
            print(f"[warn] N={N} -> pIdx={pIdx} 超出 profiles 范围（len={len(ALL_PROFILES)}），跳过该文件")
            skipped += 1
            continue

        # Every record in one file shares the same profile: build ctx once per file.
        ctx = as_ctx_dict(ALL_PROFILES[pIdx - 1])
        out_fp = dst_dir / fp.name
        _add_ctx_to_file(fp, out_fp, ctx)

        print(f"[ok] {fp.name} -> 输出到 {out_fp}")

    if skipped:
        print(f"[warn] 共跳过 {skipped} 个文件")
    print(f"\n[done] 所有文件已处理。新文件在 {dst_dir}")


if __name__ == "__main__":
    main()


[info] 发现 244 个 jsonl 文件（形如 N.jsonl）
[ok] 1.jsonl -> 输出到 C:\Users\huyue\Desktop\Result\train\with_ctx\1.jsonl
[ok] 2.jsonl -> 输出到 C:\Users\huyue\Desktop\Result\train\with_ctx\2.jsonl
[ok] 3.jsonl -> 输出到 C:\Users\huyue\Desktop\Result\train\with_ctx\3.jsonl
[ok] 4.jsonl -> 输出到 C:\Users\huyue\Desktop\Result\train\with_ctx\4.jsonl
[ok] 5.jsonl -> 输出到 C:\Users\huyue\Desktop\Result\train\with_ctx\5.jsonl
[ok] 6.jsonl -> 输出到 C:\Users\huyue\Desktop\Result\train\with_ctx\6.jsonl
[ok] 7.jsonl -> 输出到 C:\Users\huyue\Desktop\Result\train\with_ctx\7.jsonl
[ok] 8.jsonl -> 输出到 C:\Users\huyue\Desktop\Result\train\with_ctx\8.jsonl
[ok] 9.jsonl -> 输出到 C:\Users\huyue\Desktop\Result\train\with_ctx\9.jsonl
[ok] 10.jsonl -> 输出到 C:\Users\huyue\Desktop\Result\train\with_ctx\10.jsonl
[ok] 11.jsonl -> 输出到 C:\Users\huyue\Desktop\Result\train\with_ctx\11.jsonl
[ok] 12.jsonl -> 输出到 C:\Users\huyue\Desktop\Result\train\with_ctx\12.jsonl
[ok] 13.jsonl -> 输出到 C:\Users\huyue\Desktop\Result\train\with_ctx\13.jsonl
[ok] 1