In [1]:
import numpy as np
import os
from sklearn.model_selection import train_test_split
import json

# ---------------------------
# Paths: per-subject input .npz files, pooled outputs, and split metadata.
data_dir = "data_trials_preprocessed"
save_dir = "data_processed"
split_dir = "splits"
for _out_dir in (save_dir, split_dir):
    os.makedirs(_out_dir, exist_ok=True)

# Subject grouping: A01–A05 are the in-distribution (ID) subjects that get a
# within-subject split; A06–A07 form the OOD validation set and A08–A09 the
# OOD test set.
train_ids = ["A01T", "A02T", "A03T", "A04T", "A05T"]
val_ood_ids = ["A06T", "A07T"]
test_ood_ids = ["A08T", "A09T"]

# ----------------------------------------------------------
def load_npz(subj, directory=None):
    """Load one subject's trials from ``<directory>/<subj>.npz``.

    Args:
        subj: Subject identifier used as the file stem, e.g. ``"A01T"``.
        directory: Folder containing the ``.npz`` files. When omitted, the
            module-level ``data_dir`` is used.

    Returns:
        Tuple ``(X, y)`` read from the archive's ``"X"`` and ``"y"`` entries.
    """
    if directory is None:
        directory = data_dir
    # np.load on an .npz returns an NpzFile that keeps the underlying file
    # handle open; the context manager closes it once both arrays are read
    # into memory.
    with np.load(os.path.join(directory, f"{subj}.npz")) as archive:
        return archive["X"], archive["y"]

def stratified_split(X, y, test_size=0.15, val_size=0.176, seed=42):
    """Split one subject's trials into stratified train / val / test parts.

    Two successive stratified ``train_test_split`` calls share the same
    ``random_state``. The first call holds out ``test_size`` of all trials as
    the test set. The second holds out ``val_size`` of the *remaining* trials
    as the validation set. With the defaults (0.15, then 0.176 of the
    remaining 0.85, i.e. about 0.15), the result is roughly
    train:val:test ≈ 70:15:15, consistent with the ratios used in SL.ipynb.

    Returns:
        Tuple ``(X_train, X_val, X_test, y_train, y_val, y_test)``.
    """
    X_rest, X_test, y_rest, y_test = train_test_split(
        X, y,
        stratify=y,
        test_size=test_size,
        random_state=seed,
    )
    X_train, X_val, y_train, y_val = train_test_split(
        X_rest, y_rest,
        stratify=y_rest,
        test_size=val_size,
        random_state=seed,
    )
    return X_train, X_val, X_test, y_train, y_val, y_test

# ----------------------------------------------------------
# In-distribution (ID) group, A01–A05: each subject is split on its own via
# stratified_split, and the per-subject parts are collected per split.
train_X_all, train_y_all = [], []
val_X_all, val_y_all = [], []
test_X_all, test_y_all = [], []

for subj in train_ids:
    X, y = load_npz(subj)
    (X_train, X_val, X_test,
     y_train, y_val, y_test) = stratified_split(X, y)
    train_X_all.append(X_train)
    train_y_all.append(y_train)
    val_X_all.append(X_val)
    val_y_all.append(y_val)
    test_X_all.append(X_test)
    test_y_all.append(y_test)
    print(f"{subj}: train={len(y_train)}, val={len(y_val)}, test={len(y_test)}")

# Pool per-subject arrays into single datasets.
def concat_save(X_list, y_list, name, out_dir=None):
    """Concatenate arrays along axis 0 and save them as one compressed .npz.

    Args:
        X_list: Sequence of feature arrays to stack along the first axis.
        y_list: Sequence of label arrays, concatenated in the same order.
        name: Output file name, e.g. ``"train_5subj.npz"``.
        out_dir: Destination folder. When omitted, the module-level
            ``save_dir`` is used.

    Raises:
        ValueError: if the pooled X and y do not have the same number of
            samples (would otherwise be saved silently misaligned), or if a
            list is empty (raised by ``np.concatenate``).

    The archive is written with ``np.savez_compressed`` under keys ``X`` and
    ``y``, and the pooled feature shape is printed.
    """
    if out_dir is None:
        out_dir = save_dir
    X = np.concatenate(X_list, axis=0)
    y = np.concatenate(y_list, axis=0)
    # Guard against writing features and labels that no longer line up.
    if len(X) != len(y):
        raise ValueError(
            f"X and y lengths differ for {name}: {len(X)} vs {len(y)}"
        )
    np.savez_compressed(os.path.join(out_dir, name), X=X, y=y)
    print(f" Saved {name}: {X.shape}")

# Write the pooled ID train / val / test sets.
for X_parts, y_parts, out_name in (
    (train_X_all, train_y_all, "train_5subj.npz"),
    (val_X_all, val_y_all, "val_id_5subj.npz"),
    (test_X_all, test_y_all, "test_id_5subj.npz"),
):
    concat_save(X_parts, y_parts, out_name)

# ----------------------------------------------------------
# Val-OOD (A06, A07): all trials of these held-out subjects are kept as-is
# (no within-subject split) and pooled into one file.
val_ood_X, val_ood_y = [], []
for subj in val_ood_ids:
    X, y = load_npz(subj)
    val_ood_X.append(X)
    val_ood_y.append(y)
concat_save(val_ood_X, val_ood_y, "val_ood_2subj.npz")

# Test-OOD (A08, A09): same treatment as Val-OOD.
test_ood_X, test_ood_y = [], []
for subj in test_ood_ids:
    X, y = load_npz(subj)
    test_ood_X.append(X)
    test_ood_y.append(y)
concat_save(test_ood_X, test_ood_y, "test_ood_2subj.npz")

# ----------------------------------------------------------
# Save the subject-to-group assignment as JSON.
splits_info = {
    "train_ids": train_ids,
    "val_ood_ids": val_ood_ids,
    "test_ood_ids": test_ood_ids,
}
subjects_path = os.path.join(split_dir, "subjects.json")
with open(subjects_path, "w", encoding="utf-8") as f:
    json.dump(splits_info, f, indent=2)
# Report the path actually written rather than a hard-coded string that
# would go stale if split_dir changed.
print(f" Saved {subjects_path}")


A01T: train=201, val=43, test=44
A02T: train=201, val=43, test=44
A03T: train=201, val=43, test=44
A04T: train=201, val=43, test=44
A05T: train=201, val=43, test=44
 Saved train_5subj.npz: (1005, 22, 1000)
 Saved val_id_5subj.npz: (215, 22, 1000)
 Saved test_id_5subj.npz: (220, 22, 1000)
 Saved val_ood_2subj.npz: (576, 22, 1000)
 Saved test_ood_2subj.npz: (576, 22, 1000)
 Saved splits/subjects.json
