In [1]:
import pandas as pd
import numpy as np
import my_modules

In [2]:
# Build the sample data.
# Each element of `groups_sample` plays the role of 'id_for_fold' (a race ID);
# the rows are assumed to already be in chronological order.
# The table below lists (race ID, number of runners), one entry per race.
runners_per_race = [
    ('race_20230101_R1', 3),  # race 1 (3 runners)
    ('race_20230101_R2', 2),  # race 2 (2 runners)
    ('race_20230105_R1', 3),  # race 3 (3 runners)
    ('race_20230105_R2', 2),  # race 4 (2 runners)
    ('race_20230108_R1', 2),  # race 5 (2 runners)
    ('race_20230108_R2', 3),  # race 6 (3 runners)
    ('race_20230110_R1', 2),  # race 7 (2 runners)
    ('race_20230110_R2', 2),  # race 8 (2 runners)
]
# Expand the table to one row per runner, keeping the race order.
groups_sample = np.array([
    race_id
    for race_id, n_runners in runners_per_race
    for _ in range(n_runners)
])

# X is the feature matrix. Only its row count has to match groups_sample,
# so the values are dummies (two columns of consecutive integers).
X_sample = np.arange(2 * groups_sample.size).reshape(-1, 2)
# y is the target; dummy values as well.
y_sample = np.arange(groups_sample.size)

# Count the distinct races so the expected fold sizes can be worked out by hand.
n_total_unique_groups = pd.Series(groups_sample).nunique()
print(f"Total unique groups (races): {n_total_unique_groups}\n")

# Pick n_splits so the folds are easy to verify by eye.
# Expected layout for n_total_unique_groups = 8 and n_splits = 3:
#   test_size = 8 // (3 + 1) = 2
#   Fold 0: Train groups[:2], Test groups[2:4]
#   Fold 1: Train groups[:4], Test groups[4:6]
#   Fold 2: Train groups[:6], Test groups[6:]
custom_splitter = my_modules.GroupTimeSeriesSplit(n_splits=3)

# Run the split and report every fold.
for i, (train_indices, test_indices) in enumerate(custom_splitter.split(X_sample, y_sample, groups=groups_sample)):
    print(f"--- Fold {i+1} ---")

    train_groups_in_fold = pd.Series(groups_sample[train_indices]).unique()
    test_groups_in_fold = pd.Series(groups_sample[test_indices]).unique()
    # Sort each side once; the sorted lists are reused for the display and
    # for the chronological-order check below.
    train_groups_sorted = sorted(train_groups_in_fold)
    test_groups_sorted = sorted(test_groups_in_fold)

    print(f"  Train Indices: {train_indices}")
    print(f"  Train Groups : {train_groups_sorted}")
    print(f"  Test Indices : {test_indices}")
    print(f"  Test Groups  : {test_groups_sorted}")
    print("-" * 20)

    # --- Programmatic checks backing up the visual inspection ---
    # 1. No group may appear in both the train and the test set.
    common_groups = set(train_groups_sorted) & set(test_groups_sorted)
    if common_groups:
        print(f"  [NG] Fold {i+1}: Common groups found between train and test: {common_groups}")
    else:
        print(f"  [OK] Fold {i+1}: No common groups between train and test.")

    # 2. The earliest test group must come after the latest train group.
    #    This assumes group IDs sort chronologically (plain string/number
    #    comparison), which holds for the 'race_YYYYMMDD_Rn' IDs used here.
    if train_groups_sorted and test_groups_sorted:
        last_train_group = train_groups_sorted[-1]
        first_test_group = test_groups_sorted[0]
        if first_test_group <= last_train_group:
            print(f"  [NG] Fold {i+1}: Test groups do not strictly follow train groups. "
                  f"Last train: {last_train_group}, First test: {first_test_group}")
        else:
            print(f"  [OK] Fold {i+1}: Test groups strictly follow train groups.")
    else:
        # Report each empty side independently, so an empty train set is not
        # hidden when the test set happens to be empty too.
        if not test_groups_sorted:
            print(f"  [Info] Fold {i+1}: Test set is empty (no test groups). This might be an issue if not the last fold or if unexpected.")
        if not train_groups_sorted:
            print(f"  [Info] Fold {i+1}: Train set is empty (no train groups). This might be an issue.")
    print("\n")

Total unique groups (races): 8

--- Fold 1 ---
  Train Indices: [0 1 2 3 4]
  Train Groups : ['race_20230101_R1', 'race_20230101_R2']
  Test Indices : [5 6 7 8 9]
  Test Groups  : ['race_20230105_R1', 'race_20230105_R2']
--------------------
  [OK] Fold 1: No common groups between train and test.
  [OK] Fold 1: Test groups strictly follow train groups.


--- Fold 2 ---
  Train Indices: [0 1 2 3 4 5 6 7 8 9]
  Train Groups : ['race_20230101_R1', 'race_20230101_R2', 'race_20230105_R1', 'race_20230105_R2']
  Test Indices : [10 11 12 13 14]
  Test Groups  : ['race_20230108_R1', 'race_20230108_R2']
--------------------
  [OK] Fold 2: No common groups between train and test.
  [OK] Fold 2: Test groups strictly follow train groups.


--- Fold 3 ---
  Train Indices: [ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14]
  Train Groups : ['race_20230101_R1', 'race_20230101_R2', 'race_20230105_R1', 'race_20230105_R2', 'race_20230108_R1', 'race_20230108_R2']
  Test Indices : [15 16 17 18]
  Test Groups  