In [None]:
def same_dict_schema(d1: dict, d2: dict) -> bool:
    """Return True when two dicts share the same nested key structure and value types.

    Keys must match exactly at every level. When both values under a key are
    dicts the comparison recurses; otherwise the two values must have exactly
    the same type (so ``bool`` vs ``int`` or ``int`` vs ``float`` differ).
    """
    if d1.keys() != d2.keys():
        return False

    def _compatible(left, right) -> bool:
        # Nested dicts are compared structurally, everything else by exact type.
        if isinstance(left, dict) and isinstance(right, dict):
            return same_dict_schema(left, right)
        return type(left) is type(right)

    return all(_compatible(d1[key], d2[key]) for key in d1)


# Examples
a = {"id": 1, "info": {"name": "Tom", "age": 20}}
b = {"id": 999, "info": {"name": "Alice", "age": 5}}
# `age` is a str here (int in `a`), so the schemas differ; key order is irrelevant.
c = {"info": {"name": "Tom", "age": "1"}, "id": 1}

print(same_dict_schema(a, b))  # True
print(same_dict_schema(a, c))  # False

In [None]:
import re
from typing import List, Tuple

# Filler words checked by valid_seg inside the LCS-based extract_unique_phrase.
# NOTE: valid_seg tests each *character* against this set, so only the
# single-character entries ("请", "帮", "的") can match there; the
# multi-character entries ("帮我", "一下", "请问") are effectively inert.
COMMON_STOP_TOKENS = {"请", "帮", "帮我", "一下", "的", "请问"}


def _clean(text: str) -> str:
    # 去掉全角空格与常见标点
    return re.sub(r"[，。,！？?；;：:\s]", "", text)


def _lcs_indices(a: str, b: str) -> set:
    """返回 a 与 b 的一个 LCS 中 a 的字符索引集合（不唯一，取任一即可）"""
    la, lb = len(a), len(b)
    dp = [[0]*(lb+1) for _ in range(la+1)]
    for i in range(la):
        for j in range(lb):
            if a[i] == b[j]:
                dp[i+1][j+1] = dp[i][j] + 1
            else:
                dp[i+1][j+1] = max(dp[i][j+1], dp[i+1][j])
    # 回溯
    i, j = la, lb
    indices = set()
    while i > 0 and j > 0:
        if a[i-1] == b[j-1]:
            indices.add(i-1)
            i -= 1
            j -= 1
        else:
            if dp[i-1][j] >= dp[i][j-1]:
                i -= 1
            else:
                j -= 1
    return indices


def extract_unique_phrase(main_query: str, alt_queries: List[str]) -> str:
    """Return the longest part of ``main_query`` that is not shared with all alternatives.

    Both sides are normalised with ``_clean``. For every alternative one LCS
    with the main query is computed; main-query positions covered by *every*
    alternative's LCS are treated as template text. The remaining positions are
    merged into contiguous segments, leading/trailing filler words
    (请/帮我/帮/一下/的) are trimmed, and segments that contain no CJK ideograph
    or consist only of stop characters are discarded.

    Args:
        main_query: query whose distinctive phrase is wanted.
        alt_queries: queries assumed to share the main query's template.

    Returns:
        The cleaned main query when ``alt_queries`` is empty; otherwise the
        longest surviving segment (the first one on ties), or "" when none survives.
    """
    main = _clean(main_query)
    if not alt_queries:
        return main
    alt_cleaned = [_clean(q) for q in alt_queries]

    # Positions of `main` lying on an LCS with every alternative = template text.
    lcs_sets = [_lcs_indices(main, alt) for alt in alt_cleaned]
    template_positions = set.intersection(*lcs_sets)

    variable_positions = [i for i in range(len(main)) if i not in template_positions]
    if not variable_positions:
        return ""

    # Merge consecutive positions into half-open [start, end) segments.
    segments: List[Tuple[int, int]] = []
    start = prev = variable_positions[0]
    for pos in variable_positions[1:]:
        if pos != prev + 1:
            segments.append((start, prev + 1))
            start = pos
        prev = pos
    segments.append((start, prev + 1))

    def valid_seg(s: str) -> bool:
        """Keep non-empty segments that are not all stop characters and contain a CJK ideograph."""
        if not s:
            return False
        if all(ch in COMMON_STOP_TOKENS for ch in s):
            return False
        return bool(re.search(r"[\u4e00-\u9fff]", s))

    candidates = []
    for seg_start, seg_end in segments:
        seg = main[seg_start:seg_end]
        # Trim common function words at both ends.
        seg = re.sub(r"^(请|帮我|帮|一下|的)+", "", seg)
        seg = re.sub(r"(请|帮我|帮|一下|的)+$", "", seg)
        if valid_seg(seg):
            candidates.append(seg)

    if not candidates:
        return ""

    # The declared return type is str: return the single longest candidate
    # (max() keeps the first one on ties) rather than the whole list.
    return max(candidates, key=len)


if __name__ == "__main__":
    main_q = "包含朱一龙燕之屋关键词的笔记词云表现，近30天"
    alts = ["包含路觅教育关键词的笔记词云表现，近30天",
            "包含法式床关键词的笔记词云表现，近30天", "查询时间为近90天，关键词为 茅台 的热搜词"]
    # The unrelated third alt shrinks the shared LCS template to "关键词的…词",
    # so the top-ranked segment is expected to be "包含朱一龙燕之屋"
    # (the entity together with the leading "包含"), not the bare entity.
    print(extract_unique_phrase(main_q, alts))

In [None]:
# ...existing code...
import jieba
import re
from typing import List, Tuple,Dict,Set
from collections import Counter
from math import log
# ================== Template stop words (base set) ==================
# Tokens treated as query-template wording rather than entity text. In the
# jieba-based extract_unique_phrase any token in this set is marked as template
# and removed from candidate phrases; init_template_phrases also registers every
# multi-character entry with jieba so it is segmented as a single token.
TEMPLATE_STOP = {
    # Basic request / function words
    "请", "帮", "帮我", "一下", "的", "请问", "需要", "我想查询", "我想要", "我想要抓取", "我想看", "拉取", "查询", "查看", "统计", "导出", "输出", "提供", "求和", "分析", "加以分析",
    # Query structure & conditions
    "包含", "模糊匹配", "筛选", "筛选条件", "条件", "关键词", "关键词为", "关键词包含", "搜索词", "搜索词内容", "竞价词", "品牌", "名称", "名称和", "名称及", "各", "以及", "对比", "分组", "分组展示", "维度", "分搜索词维度", "指标",
    # Time / range cue words
    "时间", "时间为", "时间是", "查询时间", "时间范围", "时间周期", "期间", "至今", "到现在", "全年", "月底", "年初", "月份", "月份的", "月份内", "范围",
    # Time units (so leftover fragments are not glued onto entities)
    "年", "月", "日", "号", "天", "季度",
    # Relative time
    "昨日", "昨天", "今日", "今天", "本月", "上个月", "今年", "去年", "近", "近半年", "近一个月", "近两个月", "近三个月", "近六个月", "近12个月", "近1个月", "近一年来", "近一年",
    # Trends & charts
    "趋势", "趋势图", "指数", "指数趋势", "指数趋势图", "搜索趋势", "搜索趋势图", "折线", "折线图", "折线趋势", "折线趋势图", "榜单", "排名", "TOP", "TOP3",
    # Search metrics
    "搜索", "搜索量", "搜索次数", "搜索人数", "搜索指数", "搜索指数趋势", "搜索指数趋势图",
    # Trending searches / upstream-downstream terms
    "热搜词", "热搜词有哪些", "热搜词都有哪些", "相关", "相关的", "相关热搜词",
    "上下游词", "上下游搜索词", "上下游词数据表现", "上下游搜索词数据表现", "数据表现", "数据",
    # SOV / SOC
    "sov", "soc", "SOV", "SOC", "值", "排名",
    # Content / notes
    "笔记", "热门笔记", "笔记数", "有效笔记", "有效", "新增", "数量", "数量趋势", "标题", "正文", "标签",
    "阅读", "阅读量", "曝光", "曝光数", "曝光量", "点击", "点击数", "点击量", "互动", "互动量", "点赞", "收藏", "评论",
    "累计曝光", "累计阅读", "累计点赞", "累计收藏", "累计评论",
    "日均", "按日均", "每日", "每天", "按月", "按月份", "每个月", "每一个月", "分月", "按月统计", "按月份统计", "by月", "byday", "按天",
    # Accounts / authors / stores / brands
    "账号", "账号名称", "账号ID", "作者", "作者ID", "店铺", "店铺gmv", "gmv", "GMV", "品牌搜索", "品牌搜索趋势",
    # AIPS
    "aips", "资产", "规模", "消耗", "新增aips", "人群", "人群资产", "人群资产规模",
    # Other structural words (several repeat entries above; duplicates are harmless in a set)
    "输出", "导出", "统计", "查询", "需要", "提供", "隐去", "只保留", "只给", "加以", "分析",
    # Single-character noise words (extend if needed)
    "图", "月", "日", "天",
    # Word cloud
    "词云表现", "词云"
}

# Optional extra single-character stop words, meant to drop one-character
# leftovers after a multi-character template word has been split off
# (e.g. a stray trailing char of a metric name). Add/remove entries as needed.
SINGLE_STOP_EXTRA = {"次"}
TEMPLATE_STOP.update(SINGLE_STOP_EXTRA)

# Multi-character template phrases that must always stay whole when segmenting.
# init_template_phrases registers them with jieba (it sorts by length itself,
# so the order inside this set literal carries no meaning).
# NOTE(review): "脱敏社区搜索量" is not in TEMPLATE_STOP, so jieba keeps it whole
# but it is not automatically treated as template text — confirm that is intended.
MANDATORY_LONG_TEMPLATE_PHRASES = {
    "搜索指数趋势图", "搜索指数趋势", "搜索指数", "搜索量", "搜索次数", "搜索人数",
    "上下游搜索词数据表现", "上下游词数据表现", "上下游搜索词",
    "品牌搜索趋势", "品牌搜索", "数量趋势", "指数趋势图", "指数趋势",
    "脱敏社区搜索量", "有效笔记", "热门笔记"
}

def init_template_phrases():
    """Register every multi-character template phrase with jieba's dictionary.

    Each phrase from TEMPLATE_STOP and MANDATORY_LONG_TEMPLATE_PHRASES that is
    longer than one character is added with a very high frequency (10**8) and
    tag "tpl", so jieba keeps it as a single token instead of splitting it.
    Phrases are registered longest-first. Single-character entries are skipped;
    prune SINGLE_STOP_EXTRA if some characters end up over-filtered.
    """
    candidates = TEMPLATE_STOP | MANDATORY_LONG_TEMPLATE_PHRASES
    multi_char = {phrase for phrase in candidates if len(phrase) > 1}
    for phrase in sorted(multi_char, key=len, reverse=True):
        jieba.add_word(phrase, freq=10**8, tag="tpl")


# Register the template phrases once, when this cell is executed.
init_template_phrases()

# ================= Time-expression regexes =================
# _find_time_spans applies every pattern and merges overlapping hits, so the
# order of this list does not affect the result.
TIME_SPAN_PATTERNS = [
    # 1. Y-M-D to Y-M-D (year/month optional on the right-hand side)
    r'(?:20\d{2}|19\d{2}|[0-9]{2})年\d{1,2}月\d{1,2}[日号]?\s*[到至\-—~～]\s*(?:(?:20\d{2}|19\d{2}|[0-9]{2})年)?(?:\d{1,2}月)?\d{1,2}[日号]?',
    # 2. Y-M to Y-M (year optional on the right-hand side)
    r'(?:20\d{2}|19\d{2}|[0-9]{2})年\d{1,2}月\s*[到至\-—~～]\s*(?:(?:20\d{2}|19\d{2}|[0-9]{2})年)?\d{1,2}月',
    # 3. Y-M-D / Y-M / "this year, month M" until now
    r'(?:20\d{2}|19\d{2})年\d{1,2}月\d{1,2}[日号]?到(?:现在|至今)',
    r'(?:20\d{2}|19\d{2})年\d{1,2}月到(?:现在|至今)',
    r'今年\d{1,2}月份?到(?:现在|至今)',
    r'从(?:20\d{2}|19\d{2})年\d{1,2}月\d{1,2}[日号]?到(?:现在|至今)',
    r'从今年\d{1,2}月份?到(?:现在|至今)',
    # 4. Two-digit year with an M-D range
    r'(?:[0-9]{2})年\d{1,2}月\d{1,2}[日号]?[\-—~～到至]\d{1,2}月\d{1,2}[日号]?',
    # 5. Year + month range
    r'(?:20\d{2}|19\d{2}|[0-9]{2})年\d{1,2}月[\-—~～到至]\d{1,2}月',
    r'(?:20\d{2}|19\d{2}|[0-9]{2})年\d{1,2}[\-—~～到至]\d{1,2}月',
    # 6. "Y年M月(份)"
    r'(?:20\d{2}|19\d{2}|[0-9]{2})年\d{1,2}月份?',
    # 7. Y-M-D
    r'(?:20\d{2}|19\d{2}|[0-9]{2})年\d{1,2}月\d{1,2}[日号]',
    # 8. Y-M
    r'(?:20\d{2}|19\d{2}|[0-9]{2})年\d{1,2}月',
    # 9. Compact all-digit ranges
    r'20\d{6}[到至\-—~～]20\d{6}',
    r'20\d{6}[到至\-—~～]\d{4}',
    r'20\d{6}至\d{6}',
    # 10. Dotted date ranges
    r'(?:20\d{2}|19\d{2})[\.年]\d{1,2}\.\d{1,2}[日号]?[\-—~～到至]\d{1,2}\.\d{1,2}[日号]?',
    r'(?:20\d{2}|19\d{2})\.\d{1,2}\.?[\-—~～到至](?:20\d{2}|19\d{2})?\.\d{1,2}',
    # 11. Year + quarter range
    r'(?:20\d{2}|19\d{2})年Q[1-4][到至\-—~～](?:20\d{2}|19\d{2})?年?Q[1-4]',
    # 12. Year range / single year
    r'(?:20\d{2}|19\d{2})年[到至\-—~～](?:20\d{2}|19\d{2})年',
    r'(?:20\d{2}|19\d{2})年',
    # 13. M-D to M-D (optionally parenthesised, optionally prefixed by 半年)
    r'\d{1,2}月\d{1,2}[日号]?[\-—~～到至]\d{1,2}月\d{1,2}[日号]?',
    r'\(\d{1,2}月\d{1,2}[日号]?[\-—~～到至]\d{1,2}月\d{1,2}[日号]?\)',
    r'半年\(\d{1,2}月\d{1,2}[日号]?[\-—~～到至]\d{1,2}月\d{1,2}[日号]?\)',
    # 15. Single M-D (no year)
    r'\d{1,2}月\d{1,2}[日号]',
    # 16. Month to month
    r'\d{1,2}月[\-—~～到至]\d{1,2}月',
    r'\d{1,2}[\-—~～到至]\d{1,2}月',
    # 17. End of month / start of year
    r'(?:20\d{2}|19\d{2})年\d{1,2}月底',
    r'(?:20\d{2}|19\d{2})年年初',
    # 18. Relative periods (broad). "年来" must come before "年": regex
    #     alternation is first-match, so with "年" first "近3年来" only matched
    #     "近3年" and left a stray "来" behind as a would-be entity character.
    r'(?:近|最近|过去)\s*[一二三四五六七八九十百千万\d]+\s*(?:天|日|周|个月|月|季度|年来|年)',
    r'近\s*[一二三四五六七八九十]+\s*年(?:来)?',
    # 19. Relative day counts
    r'近\s*\d+\s*日',
    r'近\s*\d+\s*天',
    r'最近\s*\d+\s*天',
    r'最近\s*\d+\s*日',
    # 20. Fixed relative words
    r'(?:今日|昨天|昨日|上个月|本月|今年|去年|近半年|近一个月|近两个月|近三个月|近六个月|近12个月)',
    # 21. Until now
    r'至今',
    r'到现在',
    # 22. Whole year
    r'(?:20\d{2}|19\d{2})年全年',
    # 23. "M月D号" to "M月D号/日"
    r'\d{1,2}月\d{1,2}号[到至\-—~～]\d{1,2}月\d{1,2}[号日]',
    # 24. Year + M-D range (redundant backup)
    r'(?:20\d{2}|19\d{2})年\d{1,2}月\d{1,2}[日号]?[\-—~～到至]\d{1,2}月\d{1,2}[日号]?',
]

# Compile every time-span pattern once, in list order.
TIME_SPAN_REGEXES = list(map(re.compile, TIME_SPAN_PATTERNS))

# Matches a whole token that is itself a relative time expression,
# e.g. "近30天", "最近三个月", "近一年来" (the unit suffix is optional).
RELATIVE_TIME_TOKEN = re.compile(
    r'^(?:近|最近|过去)'                              # "recent / past" prefix
    r'[一二三四五六七八九十百千万\d]+'                 # Chinese or Arabic count
    r'(?:天|日|周|个月|月|季度|年|年来)?$'            # optional time unit
)

def _find_time_spans(text: str) -> List[Tuple[int,int]]:
    """Locate all time expressions in ``text`` and return them as merged (start, end) spans.

    Every regex in TIME_SPAN_REGEXES is applied. The raw hits are sorted by
    start (longer first on equal starts), and hits that overlap or touch
    (next start <= current end) are merged into one span. Returns [] when
    nothing matches.
    """
    hits = sorted(
        ((m.start(), m.end()) for rx in TIME_SPAN_REGEXES for m in rx.finditer(text)),
        key=lambda span: (span[0], span[0] - span[1]),
    )
    merged: List[Tuple[int, int]] = []
    for start, end in hits:
        if merged and start <= merged[-1][1]:
            # Overlapping/adjacent: extend the current span if this hit reaches further.
            if end > merged[-1][1]:
                merged[-1] = (merged[-1][0], end)
        else:
            merged.append((start, end))
    return merged

def _mask_time_spans(text: str) -> str:
    """Replace each merged time span in ``text`` with a placeholder ``T_TIME<k>``.

    ``k`` is the 0-based index of the span in left-to-right order; text outside
    the spans is copied unchanged. Returns ``text`` itself when there are no spans.
    """
    spans = _find_time_spans(text)
    if not spans:
        return text
    out = []
    cursor = 0
    for k, (begin, finish) in enumerate(spans):
        out.append(text[cursor:begin])
        out.append(f'T_TIME{k}')
        cursor = finish
    out.append(text[cursor:])
    return ''.join(out)

def _is_time_token(tok: str) -> bool:
    """True for a time placeholder (starts with ``T_TIME``) or a whole relative-time token like "近30天"."""
    return tok.startswith('T_TIME') or RELATIVE_TIME_TOKEN.fullmatch(tok) is not None

# Separators that may split a multi-part slot value ("A、B", "A/B", "A B").
SEP_PATTERN = re.compile(r"[、，,/\s]+")

# A token made only of digits (e.g. "30"); such tokens are never treated as entities.
DIGIT_PATTERN = re.compile(r"^\d+$")

# Patterns that capture the entity ("slot") a query is about.
SLOT_PATTERNS = [
    # "包含<X>关键词": X is everything between the two markers (lazy).
    re.compile(r"包含(?P<slot>.+?)关键词"),
    # "关键词为 <X> ...": X stops before a following "的", clause punctuation or
    # the end of the string. A greedy class over CJK text + spaces would
    # otherwise swallow the rest of the query ("茅台 的热搜词").
    re.compile(r"关键词为\s*(?P<slot>[\u4e00-\u9fffA-Za-z0-9 ]+?)(?=\s*的|[，。,；;！？?]|$)"),
    # "关键词包含<X>的": X up to the first "的".
    re.compile(r"关键词包含(?P<slot>.+?)的"),
]

def _normalize_text(q: str) -> str:
    return re.sub(r"[，。,！？?；;：:\s]", "", q)


def _segment(q: str) -> List[str]:
    """Tokenize a query while keeping every time expression as one placeholder token.

    Steps:
    1. Replace time expressions with ``T_TIME<k>`` placeholders (_mask_time_spans).
    2. Remove common punctuation, the enumeration comma '、' and whitespace.
    3. Split on the placeholders and run jieba only on the text between them,
       so each placeholder is emitted as exactly one token. jieba never sees
       the placeholder text, so it cannot glue it to a neighbour
       (e.g. "socT_TIME0") or break it apart.
    4. Drop empty / whitespace-only tokens.
    """
    masked = _mask_time_spans(q)
    cleaned = re.sub(r"[，。,！？?；;：:\s、]", "", masked)
    # With one capturing group, re.split alternates text / placeholder / text ...
    parts = re.split(r"(T_TIME\d+)", cleaned)
    tokens: List[str] = []
    for idx, piece in enumerate(parts):
        if idx % 2 == 1:
            tokens.append(piece)  # a captured placeholder
        elif piece:
            tokens.extend(w for w in jieba.lcut(piece) if w.strip())
    return tokens

def _extract_slot_candidates(q: str) -> List[str]:
    """Collect entity strings captured by SLOT_PATTERNS in ``q``.

    Each captured slot is trimmed of whitespace and then of surrounding quote
    marks; separators (SEP_PATTERN) inside it are removed by re-joining the
    non-empty parts. Results follow pattern order, then match order. Empty
    slots are skipped; duplicates are kept.
    """
    found: List[str] = []
    for pattern in SLOT_PATTERNS:
        for match in pattern.finditer(q):
            slot = match.group("slot").strip()
            if not slot:
                continue
            pieces = [piece for piece in SEP_PATTERN.split(slot.strip("“”\"' ")) if piece]
            if pieces:
                found.append("".join(pieces))
    return found

def _jaccard(a: Set[str], b: Set[str]) -> float:
    """Jaccard similarity |a & b| / |a | b|; 0.0 when both sets are empty."""
    return len(a & b) / (len(a | b) or 1)


def _variable_runs(tokens: List[str], template_tokens: Set[str]) -> List[List[str]]:
    """Split ``tokens`` into maximal runs of consecutive non-template tokens."""
    runs: List[List[str]] = []
    current: List[str] = []
    for tok in tokens:
        if tok in template_tokens:
            if current:
                runs.append(current)
                current = []
        else:
            current.append(tok)
    if current:
        runs.append(current)
    return runs


def _strip_time_boundary(phrase: str) -> str:
    """Remove leading/trailing time-residue characters (日 至 号 到, dashes, tildes, 、).

    NOTE(review): this also trims genuine leading/trailing 日/号/至/到 from entity
    names (e.g. a phrase starting with "日本"); confirm that is acceptable.
    """
    phrase = re.sub(r'^[日至号到至\-—~～、]+', '', phrase)
    return re.sub(r'[日至号到至\-—~～、]+$', '', phrase)


def _run_to_phrase(run: List[str]) -> str:
    """Turn a run of variable tokens into a candidate phrase ('' when nothing survives).

    Drops template stop words, time tokens and all-digit tokens, joins the rest,
    then strips time-residue boundary characters, embedded ``T_TIME<k>``
    placeholders and a lone leftover 'T'. Single-character stop words are rejected.
    """
    kept = [w for w in run
            if w not in TEMPLATE_STOP and not _is_time_token(w) and not DIGIT_PATTERN.match(w)]
    if not kept:
        return ""
    phrase = _strip_time_boundary("".join(kept))
    phrase = re.sub(r'T_TIME\d+', '', phrase)
    phrase = re.sub(r'^T+$', '', phrase)
    phrase = _strip_time_boundary(phrase)
    if len(phrase) == 1 and phrase in TEMPLATE_STOP:
        return ""
    return phrase


def extract_unique_phrase(main_query: str,
                          alt_queries: List[str],
                          jaccard_threshold: float = 0.15,
                          majority_threshold: float = 0.6,
                          return_all: bool = True) -> List[str]:
    """Extract the entity-like phrase(s) that distinguish ``main_query`` from its alternatives.

    Pipeline (time spans are pre-masked, template tokens decided by majority,
    slot matches fused in):
    1. Slot candidates are pulled from ``main_query`` with SLOT_PATTERNS.
    2. All queries are tokenized with ``_segment``.
    3. Alternatives whose token-set Jaccard similarity with the main query is
       below ``jaccard_threshold`` are ignored ("comparable" = the rest).
    4. A main-query token is template text when it occurs in at least
       ``majority_threshold`` of the comparable alternatives, is in
       TEMPLATE_STOP, or is a time token. Maximal runs of the other tokens
       become candidate phrases.
    5. Slot candidates are appended. Candidates that appear verbatim in at least
       ``majority_threshold`` of the comparable alternatives are dropped.
    6. Candidates are ranked by (is a slot match, length, fewer appearances).

    Args:
        main_query: the query to analyse.
        alt_queries: queries expected to share the main query's template.
        jaccard_threshold: minimum token-set similarity for an alternative to count.
        majority_threshold: fraction of comparable alternatives that makes a
            token (or a whole candidate phrase) "template".
        return_all: when False, return at most one phrase (``[""]`` if none).

    Returns:
        Ranked unique phrases. ``[main_query]`` when ``alt_queries`` is empty.
    """
    def _limit(result: List[str]) -> List[str]:
        # Apply the return_all contract uniformly on every return path.
        return result if return_all else (result[:1] or [""])

    if not alt_queries:
        return [main_query]

    slot_main = _extract_slot_candidates(main_query)
    main_tokens = _segment(main_query)
    main_set = set(main_tokens)

    # Keep (raw text, tokens) of the alternatives that look like the same template.
    valid_alts = []
    for q in alt_queries:
        toks = _segment(q)
        if _jaccard(main_set, set(toks)) >= jaccard_threshold:
            valid_alts.append((q, toks))

    if not valid_alts:
        # Nothing comparable: fall back to slot matches, else the non-template tokens.
        if slot_main:
            return _limit(slot_main)
        filtered = [t for t in main_tokens if t not in TEMPLATE_STOP and not _is_time_token(t)]
        return ["".join(filtered)] if filtered else [""]

    alt_count = len(valid_alts)
    # Document frequency of each token across the comparable alternatives.
    df = Counter(w for _, toks in valid_alts for w in set(toks))

    template_tokens: Set[str] = {
        w for w in main_tokens
        if df[w] / alt_count >= majority_threshold or w in TEMPLATE_STOP or _is_time_token(w)
    }

    candidates = []
    for run in _variable_runs(main_tokens, template_tokens):
        phrase = _run_to_phrase(run)
        if phrase:
            candidates.append(phrase)

    # Fuse in slot matches.
    for s in slot_main:
        if s and s not in candidates:
            candidates.append(s)

    valid_raw = [q for q, _ in valid_alts]

    def appearances(c: str) -> int:
        # Count only comparable alternatives, matching the alt_count denominator;
        # dissimilar queries that merely mention the same entity must not evict it.
        return sum(1 for q in valid_raw if c in q)

    # Drop pseudo-variables: phrases present verbatim in a majority of comparable alts.
    filtered = [c for c in candidates if appearances(c) < alt_count * majority_threshold]

    def score(c: str):
        in_slot = 1 if c in slot_main else 0
        return (in_slot, len(c), -appearances(c))

    # dict.fromkeys de-duplicates while keeping the ranked order.
    final = list(dict.fromkeys(sorted(filtered, key=score, reverse=True)))
    return _limit(final)


# Smoke tests: print the phrase(s) extracted for each (main query, alternatives) pair.
if __name__ == "__main__":
    tests = [
        (
            "包含朱一龙燕之屋关键词的笔记词云表现，近30天",
            [
                "包含路觅教育关键词的笔记词云表现，近30天",
                "包含法式床关键词的笔记词云表现，近30天",
                "查询时间为近90天，关键词为 茅台 的热搜词",
                "儿童液体钙类目下，包含液体钙、ad关键词的热门笔记，近30天",
                "近一年来，用户搜索关键词模糊匹配“泡泡玛特股票”的搜索量，需要每一个月的",
            ],
        ),
        (
            "包含VIPKID关键词的笔记词云表现，近30天",
            [
                "包含路觅教育关键词的笔记词云表现，近30天",
                "包含法式床关键词的笔记词云表现，近30天",
            ],
        ),
        (
            "包含HOLLISTER关键词的笔记词云表现，近8个月",
            [
                "包含路觅教育关键词的笔记词云表现，近30天",
                "查询时间为近90天，关键词包含 茅台1935 的热搜词",
            ],
        ),
    ]
    for mq, alts in tests:
        print(mq, "->", extract_unique_phrase(mq, alts))

# Single demo: the third alternative mentions the same entity with another template.
if __name__ == "__main__":
    main_q = "包含朱一龙燕之屋关键词的笔记词云表现，近30天"
    alts = [
        "包含路觅教育关键词的笔记词云表现，近30天",
        "包含法式床关键词的笔记词云表现，近30天",
        "查询时间为近90天，关键词为 朱一龙燕之屋 的热搜词",
    ]
    extracted = extract_unique_phrase(main_q, alts)
    print(extracted)  # expected: ['朱一龙燕之屋']
# ...existing code...

In [None]:
import json
target_file_url = '../log/20250919/144659_only_retrieved.json'
with open(target_file_url, 'r', encoding='utf-8') as f:
    data = json.load(f)

# Use the first top-level section whose key mentions "entity"; fail with a
# readable message instead of a bare IndexError when there is none.
target_key = next((key for key in data if 'entity' in key), None)
if target_key is None:
    raise KeyError(f"no key containing 'entity' in {target_file_url}; keys: {list(data)[:10]}")
data = data[target_key]
retrieved = data["retrieved"]

# Map each pattern query -> the queries retrieved for it.
# NOTE(review): assumes every pattern["retrieved"] entry is a non-empty sequence
# whose first element is a dict with a "query" key — confirm against the log format.
ct = {}
for retrieved_item in retrieved:
    pattern = retrieved_item["pattern"]
    ct[pattern["query"]] = [pr[0]['query'] for pr in pattern["retrieved"]]

In [None]:
# Show what the extractor returns for every retrieved (query, alternatives) pair.
for main_q, alts in ct.items():
    print("main_q:", main_q, "alts:", alts)
    extracted = extract_unique_phrase(main_q, alts)
    print("extracted:", extracted)
    print()

In [None]:
import json
# Load the template and test splits.
template_file_url = '../data/20250916/template_AIO_0.95_4.json'
test_file_url = '../data/20250916/test_AIO_0.95_4.json'
with open(template_file_url, encoding='utf-8') as f:
    templates = json.load(f)
with open(test_file_url, encoding='utf-8') as f:
    tests = json.load(f)

template_query = [item["query"] for item in templates]
test_query = [item["query"] for item in tests]
querys = [*template_query, *test_query]

# Write one query per line (no trailing newline).
querys_file_url = '../data/20250916/template_test_querys.txt'
with open(querys_file_url, 'w', encoding='utf-8') as f:
    f.write("\n".join(querys))

In [None]:
import pandas as pd
embedding_parquet = "../data/ignore/20250916_AIO_correct_erased_embedding.parquet"
df = pd.read_parquet(embedding_parquet)
df.columns

## 处理spider相关的错误

In [None]:
import json
target_file = "../data/spider_dsl/test/spider_sim_value.json"
# Explicit UTF-8, like every other open() in this notebook: the platform default
# encoding (e.g. on Windows) can fail to decode non-ASCII JSON.
with open(target_file, "r", encoding="utf-8") as f:
    data = json.load(f)
len(data)
# 569 503

In [None]:
# Every record needs a usable top-level id before re-numbering. An id of 0 is
# accepted (the next step handles it); None/"" are not. .get() keeps the failure
# message itself from raising KeyError when "question" is absent.
for item in data:
    assert item.get("id") not in (None, ""), f"Missing id in item {item.get('question')}"

In [None]:

# Shift ids into a shared range and flatten each record's config.
# WARNING: mutates `data` in place and is NOT idempotent. Afterwards config no
# longer holds "id", so re-running this cell without reloading `data` adds the
# offset a second time.
SPIDER_VALUE_ID_OFFSET = 569 + 569  # NOTE(review): presumably sizes of the splits merged before this one — confirm
for item in data:
    raw_id = item['config'].get("id") or item.get("id")
    item["id"] = int(raw_id) + SPIDER_VALUE_ID_OFFSET
    item["table_name"] = item['config'].get("table_name") or item.get("table_name")
    item["config"] = {
        "dimension": item['config'].get("dimension", []),
        "measure": item['config'].get("measure", []),
        "filter": item['config'].get("filter", [])
    }

In [None]:
# Persist the cleaned spider "value" split.
clean_path = "../data/spider_dsl/test/spider_sim_value_clean.json"
with open(clean_path, mode="w", encoding="utf-8") as fh:
    json.dump(data, fh, indent=2, ensure_ascii=False)

In [None]:
# The three cleaned spider splits to merge, in this order.
files = [
    f"../data/spider_dsl/test/spider_sim_{split}_clean.json"
    for split in ("value", "column", "question")
]

In [None]:
import json
# Concatenate the cleaned splits. Each file is read under a context manager so
# its handle is closed, and one list is extended in place (sum(lists, []) would
# re-copy the accumulator for every file).
data = []
for file in files:
    with open(file, "r", encoding="utf-8") as f:
        records = json.load(f)
    if not isinstance(records, list):
        raise TypeError(f"{file} must contain a JSON list, got {type(records).__name__}")
    data.extend(records)

with open("../data/spider_dsl/test/spider_sim_merged.json", "w", encoding='utf-8') as f:
    json.dump(data, f, ensure_ascii=False, indent=2)

## 处理bird相关的文件

In [None]:
import json
target_file = "../data/bird_dsl/test/bird_sim_question.json"
# Explicit UTF-8, like the writes in this notebook, instead of the platform default.
with open(target_file, "r", encoding="utf-8") as f:
    data = json.load(f)
len(data)
# NOTE(review): the "569 503" counts noted here were copied from the spider
# cell; the bird split uses a 353-based id offset, so they probably don't apply.

In [None]:
# Every record needs a usable top-level id before re-numbering. An id of 0 is
# accepted (the next step handles it); None/"" are not. .get() keeps the failure
# message itself from raising KeyError when "question" is absent.
for item in data:
    assert item.get("id") not in (None, ""), f"Missing id in item {item.get('question')}"

In [None]:

# Shift ids into a shared range and flatten each record's config.
# WARNING: mutates `data` in place and is NOT idempotent. Afterwards config no
# longer holds "id", so re-running this cell without reloading `data` adds the
# offset a second time.
BIRD_QUESTION_ID_OFFSET = 353 * 3  # NOTE(review): presumably 3 earlier splits of 353 records — confirm
for item in data:
    raw_id = item['config'].get("id") or item.get("id")
    item["id"] = int(raw_id) + BIRD_QUESTION_ID_OFFSET
    item["table_name"] = item['config'].get(
        "table_name") or item.get("table_name")
    item["config"] = {
        "dimension": item['config'].get("dimension", []),
        "measure": item['config'].get("measure", []),
        "filter": item['config'].get("filter", [])
    }

In [None]:
# Write the cleaned records next to the source file as "<name>_clean.json".
clean_path = f"{target_file.rpartition('.')[0]}_clean.json"
with open(clean_path, mode="w", encoding="utf-8") as fh:
    json.dump(data, fh, indent=2, ensure_ascii=False)