In [17]:
import pandas as pd
import numpy as np

# Seed passed to the fold-splitting RNGs in make_folds_state_season so the split is reproducible.
SEED = 42
# Number of cross-validation folds; fold ids run from 0 to N_SPLITS - 1.
N_SPLITS = 5

def add_season_column(df: pd.DataFrame, date_col="Sampling_Date") -> pd.DataFrame:
    """Return a copy of ``df`` with ``date_col`` parsed and a ``Season`` column added.

    ``date_col`` is converted with ``pd.to_datetime(..., errors="coerce")``, so
    values that cannot be parsed become NaT.  ``Season`` is derived from the
    calendar month using Southern Hemisphere seasons (Dec-Feb summer, Mar-May
    autumn, Jun-Aug winter, Sep-Nov spring); rows without a valid month are
    labelled ``"unknown"``.  The input frame is not modified.
    """
    # (season label, calendar months) pairs -- Southern Hemisphere.
    season_months = [
        ("summer", [12, 1, 2]),
        ("autumn", [3, 4, 5]),
        ("winter", [6, 7, 8]),
        ("spring", [9, 10, 11]),
    ]

    out = df.copy()
    out[date_col] = pd.to_datetime(out[date_col], errors="coerce")
    month = out[date_col].dt.month

    conditions = [month.isin(months) for _, months in season_months]
    labels = [label for label, _ in season_months]
    out["Season"] = np.select(conditions, labels, default="unknown")
    return out

def _greedy_fold_assignment(meta: pd.DataFrame, n_splits: int, rng) -> dict:
    """Greedily assign every image in ``meta`` to a fold; return ``{image_id: fold}``.

    Fallback used when ``StratifiedGroupKFold`` cannot be used.  ``meta`` must
    hold one row per image with ``image_id`` and ``stratify_key`` columns.
    Images are placed one at a time into the fold with the lowest cost, where

        cost = (fold size - ideal fold size)^2
               + 3 * (fold's count of this key - ideal count of this key)^2

    Both counts include the image being placed.  Ties go to the lowest fold id.
    """
    # Earlier revisions shuffled an otherwise unused list of keys at this point.
    # The draw is kept so the RNG stream -- and therefore the resulting folds --
    # stay identical for a given seed.
    rng.shuffle(meta["stratify_key"].unique().tolist())

    images_by_key = meta.groupby("stratify_key")["image_id"].apply(list).to_dict()
    key_totals = meta["stratify_key"].value_counts().to_dict()
    target_fold_size = len(meta) / n_splits

    fold_sizes = [0] * n_splits
    fold_key_sizes = [dict() for _ in range(n_splits)]
    assignment = {}

    # Place the largest strata first; small strata are then used to even things out.
    for key in sorted(images_by_key, key=lambda k: len(images_by_key[k]), reverse=True):
        images = images_by_key[key]
        rng.shuffle(images)
        target_key_size = key_totals[key] / n_splits

        for image_id in images:
            costs = [
                (fold_sizes[f] + 1 - target_fold_size) ** 2
                + 3.0 * (fold_key_sizes[f].get(key, 0) + 1 - target_key_size) ** 2
                for f in range(n_splits)
            ]
            # list.index returns the first minimum, i.e. the lowest fold id on ties.
            best_fold = costs.index(min(costs))

            fold_sizes[best_fold] += 1
            fold_key_sizes[best_fold][key] = fold_key_sizes[best_fold].get(key, 0) + 1
            assignment[image_id] = best_fold

    return assignment


def make_folds_state_season(df: pd.DataFrame, n_splits=None, seed=None) -> pd.DataFrame:
    """Add a ``fold`` column stratified by State x Season, grouped by image.

    Parameters
    ----------
    df : long-format frame with ``sample_id``, ``Sampling_Date`` and ``State``
        columns.  ``image_id`` is taken as the part of ``sample_id`` before the
        first ``"__"``; several rows may share one ``image_id``.
    n_splits : number of folds; defaults to the module-level ``N_SPLITS``.
    seed : random seed; defaults to the module-level ``SEED``.

    Returns
    -------
    A copy of ``df`` with ``image_id``, ``Season`` and ``fold`` columns
    (``fold`` in ``0 .. n_splits - 1``).  All rows of one image share a fold.
    An existing ``fold`` column is replaced.

    Raises
    ------
    ValueError if ``n_splits`` is smaller than 2.

    Notes
    -----
    ``sklearn.model_selection.StratifiedGroupKFold`` is used when it can be
    imported and accepts the data.  Otherwise a ``RuntimeWarning`` is emitted
    and a greedy assignment is used instead.
    """
    n_splits = N_SPLITS if n_splits is None else int(n_splits)
    seed = SEED if seed is None else seed
    if n_splits < 2:
        raise ValueError(f"n_splits must be at least 2, got {n_splits}")

    df = df.copy()

    # 1) image_id
    df["image_id"] = df["sample_id"].str.split("__").str[0]

    # 2) season
    df = add_season_column(df, "Sampling_Date")

    # 3) Split at one-row-per-image level so repeated rows of an image do not
    #    inflate the strata counts.
    meta = (
        df.drop_duplicates("image_id")[["image_id", "State", "Season"]]
        .copy()
        .reset_index(drop=True)
    )

    # Keep rows with missing values by giving them an explicit placeholder label.
    meta["State"] = meta["State"].fillna("Unknown")
    meta["Season"] = meta["Season"].fillna("Unknown")

    # 4) stratify key: State|Season
    meta["stratify_key"] = meta["State"].astype(str) + "|" + meta["Season"].astype(str)

    # 5) Preferred path: StratifiedGroupKFold.
    folds = np.full(len(meta), -1, dtype="int64")
    try:
        from sklearn.model_selection import StratifiedGroupKFold

        sgkf = StratifiedGroupKFold(n_splits=n_splits, shuffle=True, random_state=seed)
        for fold, (_, va_idx) in enumerate(
            sgkf.split(meta, y=meta["stratify_key"], groups=meta["image_id"])
        ):
            # va_idx holds positions, which match the array built from meta.
            folds[va_idx] = fold

    except (ImportError, ValueError) as exc:
        # 6) Fallback: greedy assignment that does not need StratifiedGroupKFold.
        #    Only import and data-validation errors are handled; anything else
        #    is a real bug and is allowed to propagate.
        import warnings

        warnings.warn(
            f"StratifiedGroupKFold could not be used ({exc!r}); "
            "falling back to greedy fold assignment.",
            RuntimeWarning,
            stacklevel=2,
        )
        rng = np.random.default_rng(seed)
        assignment = _greedy_fold_assignment(meta, n_splits, rng)
        folds = (
            meta["image_id"].map(assignment).fillna(-1).to_numpy(dtype="int64", copy=True)
        )

        # Anything still unassigned (not expected) gets a random fold.
        undecided = folds < 0
        if undecided.any():
            folds[undecided] = rng.integers(0, n_splits, size=int(undecided.sum()))

    meta["fold"] = folds

    # 7) Merge back onto the long-format frame.  A stale "fold" column is dropped
    #    first so re-running on an already-folded frame does not yield fold_x/fold_y.
    df = df.drop(columns=["fold"], errors="ignore").merge(
        meta[["image_id", "fold"]], on="image_id", how="left", validate="many_to_one"
    )

    return df

In [18]:
# ====== Usage ======
# NOTE(review): absolute, machine-specific path -- this cell only runs where that file exists.
df = pd.read_csv("/home/ecs-user/code/happen/kaggle-csrio/csiro-biomass/train.csv")
# Adds image_id, Season and fold columns to the long-format training table.
df = make_folds_state_season(df)

In [19]:


# ====== Check: is the State x Season mix similar across folds? ======
def show_fold_distribution(df: pd.DataFrame):
    """Print per-fold image counts plus State and State x Season cross-tabs.

    ``df`` is first reduced to one row per ``image_id`` so repeated rows of the
    same image are counted once; ``Season`` is recomputed from ``Sampling_Date``.
    """
    wanted = ["image_id", "State", "Sampling_Date", "fold"]
    per_image = add_season_column(df.drop_duplicates("image_id")[wanted].copy(), "Sampling_Date")
    fold = per_image["fold"]
    state = per_image["State"]

    print("Per-fold image counts:")
    print(fold.value_counts().sort_index(), "\n")

    print("Per-fold State counts:")
    print(pd.crosstab(fold, state), "\n")

    print("Per-fold State×Season counts:")
    print(pd.crosstab(fold, [state, per_image["Season"]]))

show_fold_distribution(df)


Per-fold image counts:
fold
0    72
1    71
2    71
3    72
4    71
Name: count, dtype: int64 

Per-fold State counts:
State  NSW  Tas  Vic  WA
fold                    
0       12   28   24   8
1       18   33   16   4
2       17   21   28   5
3       17   27   22   6
4       11   29   22   9 

Per-fold State×Season counts:
State     NSW                  Tas                  Vic            WA       
Season autumn spring summer autumn spring winter spring winter spring winter
fold                                                                        
0           5      2      5      5     15      8      9     15      3      5
1           8      4      6      4     20      9      7      9      2      2
2           3      2     12      7      7      7     10     18      5      0
3           5      2     10      8     10      9      6     16      1      5
4           2      1      8      5     19      5      7     15      1      8


In [20]:
def fix_zero_cells_by_move(meta, key_col="stratify_key", group_col="image_id", fold_col="fold",
                           max_passes=50, seed=42):
    """Move single rows between folds so that no (fold, key) cell is empty.

    Parameters
    ----------
    meta : frame with one row per image, holding ``fold_col`` and ``key_col``.
    key_col : column with the stratification key.
    group_col : accepted for interface compatibility; not used by the function.
    fold_col : column with the fold id.
    max_passes : upper bound on the number of repair passes.
    seed : seed for the random generator that orders cells and picks rows.

    Returns
    -------
    A copy of ``meta`` with ``fold_col`` adjusted; the input is not modified.

    For every empty cell, one randomly chosen row with that key is moved out of
    the fold that currently holds the most rows of the key.  Keys whose best
    fold holds at most one row are skipped, so such cells can stay empty.
    """
    rng = np.random.default_rng(seed)
    result = meta.copy()

    for _pass in range(max_passes):
        counts = pd.crosstab(result[fold_col], result[key_col])
        empty_cells = np.argwhere(counts.to_numpy() == 0)
        if empty_cells.shape[0] == 0:
            break

        fold_labels = counts.index.tolist()
        key_labels = counts.columns.tolist()

        # Visit the empty cells in random order rather than a fixed one.
        rng.shuffle(empty_cells)

        moved_any = False
        for fold_pos, key_pos in empty_cells:
            target_fold = fold_labels[fold_pos]
            key = key_labels[key_pos]

            # The donor is the fold with the most rows of this key; it must keep
            # at least one row, otherwise the key is too sparse to spread.
            per_fold = counts[key]
            if per_fold.max() <= 1:
                continue
            source_fold = per_fold.idxmax()

            in_source = (result[fold_col] == source_fold) & (result[key_col] == key)
            movable = result[in_source]
            if movable.empty:
                continue

            sample_seed = int(rng.integers(0, 1e9))
            chosen = movable.sample(1, random_state=sample_seed).index[0]
            result.loc[chosen, fold_col] = target_fold

            # Recount so later cells in this pass see the move.
            counts = pd.crosstab(result[fold_col], result[key_col])
            moved_any = True

        if not moved_any:
            break

    return result

In [21]:
# df is the long-format training table and already carries a df["fold"] column.
# Collapse it to one row per image_id.
meta = (
    df.drop_duplicates("image_id")[["image_id", "State", "Sampling_Date", "fold"]]
    .copy()
)

# Derive Season with add_season_column (defined earlier in this notebook).
meta = add_season_column(meta, "Sampling_Date")

# stratify_key: "State|Season", with missing values replaced by "Unknown".
meta["State"] = meta["State"].fillna("Unknown")
meta["Season"] = meta["Season"].fillna("Unknown")
meta["stratify_key"] = meta["State"].astype(str) + "|" + meta["Season"].astype(str)

# meta now has the columns fix_zero_cells_by_move expects.
meta_fixed = fix_zero_cells_by_move(meta, key_col="stratify_key",
                                    group_col="image_id", fold_col="fold")

In [22]:
# Replace the fold column on the long-format table with the repaired per-image folds.
# Not idempotent across re-runs: it expects df to still have a "fold" column to drop.
df = df.drop(columns=["fold"]).merge(
    meta_fixed[["image_id", "fold"]],
    on="image_id",
    how="left"
)

In [23]:
# Per-image table after the zero-cell repair; print its fold x (State, Season) counts.
one = meta_fixed.copy()
print(pd.crosstab(one["fold"], [one["State"], one["Season"]]))

State     NSW                  Tas                  Vic            WA       
Season autumn spring summer autumn spring winter spring winter spring winter
fold                                                                        
0           5      2      5      5     15      8      9     15      3      5
1           8      4      6      4     20      9      7      9      2      2
2           3      2     12      7      7      7     10     18      5      1
3           5      2     10      8     10      9      6     16      1      5
4           2      1      8      5     19      5      7     15      1      7


In [25]:
# One per-image frame for each fold id 0..4, each an independent copy of the matching rows.
fold_0, fold_1, fold_2, fold_3, fold_4 = (
    one.copy().loc[one["fold"] == fold_id] for fold_id in range(5)
)

In [27]:
# Display the per-image rows assigned to fold 0.
fold_0

Unnamed: 0,image_id,State,Sampling_Date,fold,Season,stratify_key
0,ID1011485656,Tas,2015-09-04,0,spring,Tas|spring
25,ID1036339023,Vic,2015-09-30,0,spring,Vic|spring
55,ID1058383417,Tas,2015-05-19,0,autumn,Tas|autumn
75,ID1084819986,Tas,2015-09-04,0,spring,Tas|spring
100,ID1113121340,Vic,2015-09-30,0,spring,Vic|spring
...,...,...,...,...,...,...
1685,ID896386823,Tas,2015-09-11,0,spring,Tas|spring
1710,ID930534670,Tas,2015-06-26,0,winter,Tas|winter
1735,ID956512130,Tas,2015-11-09,0,spring,Tas|spring
1760,ID975115267,WA,2015-07-08,0,winter,WA|winter


In [24]:
# 1) State x Season count table per fold
tab = pd.crosstab(one["fold"], [one["State"], one["Season"]])
print("=== Per-fold State×Season counts ===")
print(tab)

# 2) Are there any zero cells left?
zero_cells = (tab == 0).sum().sum()
print("\n=== Zero cells ===")
print("zero_cells =", int(zero_cells))

# 3) Smallest per-fold count of each (State, Season); the smaller it is, the
#    closer that combination is to being absent from some fold.
min_per_key = tab.min(axis=0).sort_values()
print("\n=== Min count per (State,Season) across folds (lowest 15) ===")
print(min_per_key.head(15))

# 4) Per-fold counts of each State, plus their spread across folds (std / mean)
state_tab = pd.crosstab(one["fold"], one["State"])
state_mean = state_tab.mean(axis=0)
state_std  = state_tab.std(axis=0)
print("\n=== Per-fold State counts ===")
print(state_tab)
print("\n=== State fold-balance (std/mean) ===")
# The 1e-9 in the denominator guards against division by zero.
print((state_std / (state_mean + 1e-9)).sort_values(ascending=False))

# 5) One overall balance score: the mean, over all keys, of the across-fold
#    coefficient of variation (lower means more balanced).
key_mean = tab.mean(axis=0)
key_std  = tab.std(axis=0)
cv_score = (key_std / (key_mean + 1e-9)).replace([np.inf, -np.inf], np.nan)
print("\n=== Overall key imbalance score (mean CV over keys) ===")
print(float(cv_score.mean()))

=== Per-fold State×Season counts ===
State     NSW                  Tas                  Vic            WA       
Season autumn spring summer autumn spring winter spring winter spring winter
fold                                                                        
0           5      2      5      5     15      8      9     15      3      5
1           8      4      6      4     20      9      7      9      2      2
2           3      2     12      7      7      7     10     18      5      1
3           5      2     10      8     10      9      6     16      1      5
4           2      1      8      5     19      5      7     15      1      7

=== Zero cells ===
zero_cells = 0

=== Min count per (State,Season) across folds (lowest 15) ===
State  Season
NSW    spring    1
WA     spring    1
       winter    1
NSW    autumn    2
Tas    autumn    4
NSW    summer    5
Tas    winter    5
Vic    spring    6
Tas    spring    7
Vic    winter    9
dtype: int64

=== Per-fold State counts ===
S