In [None]:
from glob import glob

# List every entry under ../data. As the bare last expression of the cell, the
# resulting list is presumably what the notebook displays as the cell output.
glob("../data/*")

In [None]:
import re
from datetime import datetime, timedelta
from collections import defaultdict
from typing import List, Dict, Any, Tuple

# TODO: use a more rigorous train/val/test splitting scheme to validate the model's effectiveness

# Matches a trailing "_YYYYMMDD" stamp followed by a .parquet or .csv extension.
DATE_RE = re.compile(r"_(\d{8})\.(?:parquet|csv)$")


def parse_date(p: str):
    """Extract the calendar date encoded at the end of a data file name.

    Args:
        p: File path or name, expected to end in ``_YYYYMMDD.parquet`` or
            ``_YYYYMMDD.csv``.

    Returns:
        A ``datetime.date`` for the embedded stamp, or ``None`` when the name
        does not match the pattern or the eight digits do not form a real
        calendar date (for example ``_20250230``).
    """
    m = DATE_RE.search(p)
    if m is None:
        return None
    try:
        return datetime.strptime(m.group(1), "%Y%m%d").date()
    except ValueError:
        # Eight digits that are not a valid date: report it like any other
        # non-matching name so callers can filter it out instead of crashing.
        return None

def month_key(d):
    """Return the ``(year, month)`` tuple identifying the month that ``d`` falls in."""
    year, month = d.year, d.month
    return (year, month)

def rolling_month_splits(
    paths: List[str],
    train_months: int = 4,
    val_months: int = 1,
    test_months: int = 1,
    purge_days: int = 0,      # typically 1-5; set from the label's look-ahead horizon
    warmup_days: int = 0      # typically 1-3/5; enable when features need look-back / EMA
) -> List[Dict[str, Any]]:
    """Build rolling, month-aligned (train, val, test) folds from dated file paths.

    Each fold covers ``train_months`` consecutive months for training, the
    next ``val_months`` months for validation and the following
    ``test_months`` months for testing. The window then advances by one month.
    "Month" means a month for which at least one dated file exists, so a
    calendar month without data is skipped over rather than counted.

    Args:
        paths: Candidate file paths. Entries for which ``parse_date`` returns
            ``None`` are ignored.
        train_months: Number of months in the training window (>= 1).
        val_months: Number of months in the validation window (>= 1).
        test_months: Number of months in the test window (>= 1).
        purge_days: When > 0, training days that fall within this many
            calendar days before the first used validation day or the first
            used test day are removed. The window is measured in calendar
            days, not in days that have data, so a gap such as a weekend can
            consume it entirely.
        warmup_days: When > 0, the first ``warmup_days`` data days of the
            validation window and of the test window are dropped. A window
            that has no more than ``warmup_days`` days is left untouched.

    Returns:
        A list of folds in chronological order. Each fold is a dict with
        ``"train"``, ``"val"`` and ``"test"`` (sorted lists of paths) and
        ``"meta"``, which holds ``train_months``, ``val_months`` and
        ``test_months`` (lists of ``(year, month)`` keys), ``val_month`` and
        ``test_month`` (the last month of the respective window) and
        ``train_span`` / ``val_span`` / ``test_span`` (first and last day
        used, or ``(None, None)`` if purging emptied the training set).
        Returns ``[]`` when no path carries a date or there are too few
        months for a single fold.

    Raises:
        ValueError: If any of the three month counts is smaller than 1.
    """
    if min(train_months, val_months, test_months) < 1:
        raise ValueError(
            "train_months, val_months and test_months must all be >= 1"
        )

    # 1) Index the dated paths by day; undated paths are dropped here.
    by_date = defaultdict(list)
    for p in paths:
        d = parse_date(p)
        if d is not None:
            by_date[d].append(p)
    if not by_date:
        return []

    # 2) Group the days that have data by month. Days are visited in sorted
    #    order, so the dict's insertion order is the chronological month order.
    month_to_days = defaultdict(list)
    for d in sorted(by_date):
        month_to_days[month_key(d)].append(d)
    months = list(month_to_days)

    def pick(days):
        """Return the sorted paths of every file dated on one of ``days``."""
        return sorted(p for d in set(days) for p in by_date[d])

    # 3) Slide the window one month at a time; ``end`` is the exclusive index
    #    of the fold's last test month.
    need = train_months + val_months + test_months
    folds = []
    for end in range(need, len(months) + 1):
        test_lo = end - test_months
        val_lo = test_lo - val_months
        train_lo = val_lo - train_months

        train_ms = months[train_lo:val_lo]
        val_ms = months[val_lo:test_lo]
        test_ms = months[test_lo:end]

        # Every listed month has at least one day and every window spans at
        # least one month, so none of these collections can start out empty.
        train_days = {d for m in train_ms for d in month_to_days[m]}
        val_days = [d for m in val_ms for d in month_to_days[m]]
        test_days = [d for m in test_ms for d in month_to_days[m]]

        # 4) Feature warm-up: skip the first data days of val / test, unless
        #    that would consume the whole window.
        if warmup_days > 0:
            if len(val_days) > warmup_days:
                val_days = val_days[warmup_days:]
            if len(test_days) > warmup_days:
                test_days = test_days[warmup_days:]

        # 5) Purge: drop training days in [boundary - purge_days, boundary)
        #    for the first used val day and the first used test day, so labels
        #    that look ahead cannot leak evaluation data into training.
        if purge_days > 0:
            window = timedelta(days=purge_days)
            for boundary in (val_days[0], test_days[0]):
                train_days = {
                    d for d in train_days
                    if not (boundary - window <= d < boundary)
                }

        # 6) Map the selected days back to file paths.
        if train_days:
            train_span = (min(train_days), max(train_days))
        else:
            train_span = (None, None)

        folds.append({
            "train": pick(train_days),
            "val": pick(val_days),
            "test": pick(test_days),
            "meta": {
                "train_months": train_ms,
                "val_months": val_ms,
                "test_months": test_ms,
                # Single-month keys kept for existing callers: the last
                # month of each evaluation window.
                "val_month": val_ms[-1],
                "test_month": test_ms[-1],
                "train_span": train_span,
                "val_span": (val_days[0], val_days[-1]),
                "test_span": (test_days[0], test_days[-1]),
            },
        })

    return folds


# Gather every entry in the data directory; the splitter ignores undated names.
all_paths = glob("../data/*")

# Roll a 4-month train / 1-month val / 1-month test window over the files.
folds = rolling_month_splits(
    paths=all_paths,
    train_months=4,
    val_months=1,
    test_months=1,
    purge_days=0,      # 0-1 when labels do not span days; 1 for t+1-day returns
    warmup_days=0,     # 1-3 if features are unstable on a month's first days
)

# Report the number of folds, then the months and file counts of each one.
print(f"Total folds: {len(folds)}")
for k, f in enumerate(folds, start=1):
    meta = f["meta"]
    print(
        f"[Fold {k}] train={meta['train_months']} | val={meta['val_month']} | test={meta['test_month']}",
        f"  train_files={len(f['train'])}, val_files={len(f['val'])}, test_files={len(f['test'])}",
        sep="\n",
    )


Total folds: 4
[Fold 1] train=[(2024, 12), (2025, 1), (2025, 2), (2025, 3)] | val=(2025, 4) | test=(2025, 5)
  train_files=340, val_files=105, test_files=95
[Fold 2] train=[(2025, 1), (2025, 2), (2025, 3), (2025, 4)] | val=(2025, 5) | test=(2025, 6)
  train_files=390, val_files=95, test_files=100
[Fold 3] train=[(2025, 2), (2025, 3), (2025, 4), (2025, 5)] | val=(2025, 6) | test=(2025, 7)
  train_files=395, val_files=100, test_files=115
[Fold 4] train=[(2025, 3), (2025, 4), (2025, 5), (2025, 6)] | val=(2025, 7) | test=(2025, 8)
  train_files=405, val_files=115, test_files=16
