# In [None]:
#### Standardize the whole dataset and save it again


import random
from pathlib import Path
from datetime import datetime

import numpy as np

# ===================== Settings =====================

# Source dataset root (one sub-directory per class, each holding .npy files)
DATASET_ROOT = Path("./All_materials/")

# Output root (timestamp/<class name>/*.npy is created underneath it)
OUTPUT_ROOT = Path("./normalized_dataset_forplot")

# Time-step range used to estimate the per-trial offset (both ends inclusive)
OFFSET_T0 = 0
OFFSET_T1 = 100

# Numerical stabilizer; added to the variance before taking the square root
EPS = 1e-8

# dtype used when saving the normalized arrays (float32 recommended)
SAVE_DTYPE = np.float32

# Whether to skip a file whose output already exists
SKIP_IF_EXISTS = True


# ===================== ユーティリティ =====================

def list_classes(dataset_root: Path):
    """Return the names of the sub-directories directly under ``dataset_root``.

    Plain files are ignored. The result is a list sorted in ascending order.
    """
    return sorted(entry.name for entry in dataset_root.iterdir() if entry.is_dir())

def list_npy_files(class_dir: Path):
    """Return the ``*.npy`` entries directly inside ``class_dir`` as a sorted list.

    The search is not recursive; only names matching ``*.npy`` are returned.
    """
    return sorted(class_dir.glob("*.npy"))

def iter_all_npy_files(dataset_root: Path):
    """Yield a ``(class_name, path)`` pair for every .npy file in the dataset.

    Classes are visited in the order given by ``list_classes`` and, within a
    class, files in the order given by ``list_npy_files``.
    """
    for class_name in list_classes(dataset_root):
        npy_paths = list_npy_files(dataset_root / class_name)
        yield from ((class_name, npy_path) for npy_path in npy_paths)


# ===================== 統計量（global mean/std）の推定 =====================

def compute_global_stats_with_offset(dataset_root: Path,
                                     offset_t0: int = 0,
                                     offset_t1: int = 100,
                                     eps: float = 1e-8):
    """
    Estimate the global per-(taxel, axis) mean and std over every .npy file.

    For each trial:
      1) subtract the mean over t = offset_t0..offset_t1 (inclusive), computed
         separately for every (taxel, axis) pair;
      2) accumulate the sum and sum of squares of the offset-removed samples.

    The variance is the population variance E[x^2] - E[x]^2 over all time
    steps of all trials; it is clipped at 0 and ``eps`` is added before the
    square root.

    Args:
      dataset_root: directory whose class sub-directories hold the .npy files.
      offset_t0: first time step of the baseline window (inclusive).
      offset_t1: last time step of the baseline window (inclusive).
      eps: value added to the variance before the square root.

    Returns:
      mean: (n_taxels, 3) float32
      std : (n_taxels, 3) float32

    Raises:
      ValueError: if the offset range is not 0 <= offset_t0 <= offset_t1, if a
        file does not have shape (T, n_taxels, 3), if T <= offset_t1, or if
        n_taxels differs between files.
      RuntimeError: if no .npy file is found under ``dataset_root``.
    """
    # A reversed or negative range would produce an empty (or wrapped-around)
    # baseline slice and silently turn every statistic into NaN.
    if not 0 <= offset_t0 <= offset_t1:
        raise ValueError(
            f"Invalid offset range {offset_t0}..{offset_t1} "
            "(expected 0 <= offset_t0 <= offset_t1)"
        )

    paths = list(iter_all_npy_files(dataset_root))
    if not paths:
        raise RuntimeError(f"No .npy files found under: {dataset_root}")

    sum_ = None
    sumsq_ = None
    total_count = 0
    n_taxels_ref = None

    for i, (_, path) in enumerate(paths):
        data = np.load(path)  # (T, n_taxels, 3)

        if data.ndim != 3 or data.shape[-1] != 3:
            raise ValueError(f"Invalid shape {data.shape} in {path} (expected (T, n_taxels, 3))")

        T, n_taxels, _ = data.shape
        if T <= offset_t1:
            raise ValueError(f"T={T} too short for offset range {offset_t0}..{offset_t1} in {path}")

        if n_taxels_ref is None:
            # The first file fixes the expected taxel count and the accumulator shapes.
            n_taxels_ref = n_taxels
            sum_ = np.zeros((n_taxels_ref, 3), dtype=np.float64)
            sumsq_ = np.zeros((n_taxels_ref, 3), dtype=np.float64)
        elif n_taxels != n_taxels_ref:
            raise ValueError(f"n_taxels mismatch: {n_taxels} != {n_taxels_ref} in {path}")

        # Promote to float64 *before* removing the baseline so that neither the
        # baseline mean nor the subtraction is carried out in a narrower dtype.
        data = data.astype(np.float64, copy=False)
        baseline = data[offset_t0:offset_t1 + 1].mean(axis=0, keepdims=True)  # (1, n_taxels, 3)
        data0 = data - baseline  # (T, n_taxels, 3)

        sum_ += data0.sum(axis=0)              # (n_taxels, 3)
        sumsq_ += (data0 * data0).sum(axis=0)  # (n_taxels, 3)
        total_count += T

        if (i + 1) % 200 == 0:
            print(f"[global-stats] processed {i+1}/{len(paths)} files...")

    mean = sum_ / total_count
    var = (sumsq_ / total_count) - (mean * mean)
    # Rounding can push a near-zero variance slightly below zero.
    var = np.maximum(var, 0.0)
    std = np.sqrt(var + eps)

    return mean.astype(np.float32), std.astype(np.float32)


def normalize_trial(data: np.ndarray, mean: np.ndarray, std: np.ndarray,
                    offset_t0: int = 0, offset_t1: int = 100):
    """
    Remove the per-trial offset and standardize with the global statistics.

    Args:
      data: (T, n_taxels, 3) trial.
      mean/std: (n_taxels, 3) global statistics of the offset-removed data.
      offset_t0: first time step of the baseline window (inclusive).
      offset_t1: last time step of the baseline window (inclusive).

    Returns:
      Array with the same shape as ``data``:
      ``((data - baseline) - mean) / std``, where ``baseline`` is the mean of
      ``data`` over t = offset_t0..offset_t1.

    Raises:
      ValueError: if the offset range is not 0 <= offset_t0 <= offset_t1, if
        ``data`` is not 3-dimensional, or if T <= offset_t1.
    """
    # A reversed or negative range would give an empty (or wrapped-around)
    # baseline slice and silently fill the result with NaN.
    if not 0 <= offset_t0 <= offset_t1:
        raise ValueError(
            f"Invalid offset range {offset_t0}..{offset_t1} "
            "(expected 0 <= offset_t0 <= offset_t1)"
        )
    if data.ndim != 3:
        raise ValueError(f"Invalid shape {data.shape} (expected (T, n_taxels, 3))")

    T = data.shape[0]
    if T <= offset_t1:
        raise ValueError(f"T={T} too short for offset range {offset_t0}..{offset_t1}")

    baseline = data[offset_t0:offset_t1 + 1].mean(axis=0, keepdims=True)  # (1, n_taxels, 3)
    data0 = data - baseline
    return (data0 - mean[None, :, :]) / std[None, :, :]


# ===================== 保存処理（同ディレクトリ構造） =====================

def save_normalized_dataset(dataset_root: Path,
                            out_root: Path,
                            mean: np.ndarray,
                            std: np.ndarray,
                            offset_t0: int = 0,
                            offset_t1: int = 100,
                            save_dtype=np.float32,
                            skip_if_exists: bool = True):
    """
    Normalize every .npy file and save it under the same class layout.

    Input layout:
      dataset_root/
        classA/*.npy
        classB/*.npy
        ...

    Output layout (class directory and file names are kept):
      out_root/
        classA/*.npy
        classB/*.npy
        ...

    Each file is loaded, passed through ``normalize_trial`` with the given
    mean/std and offset range, cast to ``save_dtype`` and written with
    ``np.save``. When ``skip_if_exists`` is true, a file whose output path
    already exists is left untouched.

    Raises:
      RuntimeError: if the normalized array's shape differs from the input's.
    """
    entries = list(iter_all_npy_files(dataset_root))
    n_total = len(entries)
    print(f"[save] total files to process: {n_total}")

    for done, (class_name, src_path) in enumerate(entries, start=1):
        target_dir = out_root / class_name
        target_dir.mkdir(parents=True, exist_ok=True)
        dst_path = target_dir / src_path.name

        already_written = dst_path.exists() if skip_if_exists else False
        if already_written:
            continue

        raw = np.load(src_path)  # (T, n_taxels, 3)
        normalized = normalize_trial(raw, mean, std, offset_t0, offset_t1)
        normalized = normalized.astype(save_dtype)

        # Sanity check: normalization must not alter the array shape.
        if normalized.shape != raw.shape:
            raise RuntimeError(f"shape changed: {raw.shape} -> {normalized.shape} at {src_path}")

        np.save(dst_path, normalized)

        # Progress is only reported for files that were actually written.
        if done % 200 == 0:
            print(f"[save] processed {done}/{n_total} files...")

    print("[save] done.")


# ===================== main =====================

if __name__ == "__main__":
    # Script entry point: compute the global statistics, record them together
    # with the settings, then write the normalized copy of the dataset.

    # Create a timestamped output directory for this run
    # NOTE(review): RUN_OUT_DIR is named after the current time (1 s
    # resolution), so SKIP_IF_EXISTS presumably only takes effect when two
    # runs share a timestamp — confirm intent.
    timestamp_str = datetime.now().strftime("%Y%m%d_%H%M%S")
    RUN_OUT_DIR = OUTPUT_ROOT / timestamp_str
    RUN_OUT_DIR.mkdir(parents=True, exist_ok=True)

    print(f"DATASET_ROOT : {DATASET_ROOT}")
    print(f"OUTPUT_DIR   : {RUN_OUT_DIR}")
    print(f"OFFSET range : {OFFSET_T0}..{OFFSET_T1} (inclusive)")

    # 1) Compute the global mean/std (after offset removal)
    print("\n[1/2] computing global mean/std ...")
    global_mean, global_std = compute_global_stats_with_offset(
        DATASET_ROOT, offset_t0=OFFSET_T0, offset_t1=OFFSET_T1, eps=EPS
    )
    print("[1/2] done.")
    print("  mean shape:", global_mean.shape, "std shape:", global_std.shape)
    print("  std  min/max:", float(global_std.min()), float(global_std.max()))

    # Save the statistics and the settings (for reproducibility)
    np.save(RUN_OUT_DIR / "global_mean_taxel_axis.npy", global_mean)
    np.save(RUN_OUT_DIR / "global_std_taxel_axis.npy", global_std)
    with open(RUN_OUT_DIR / "normalization_info.txt", "w", encoding="utf-8") as f:
        f.write(f"DATASET_ROOT={DATASET_ROOT}\n")
        f.write(f"OFFSET_T0={OFFSET_T0}\n")
        f.write(f"OFFSET_T1={OFFSET_T1}\n")
        f.write(f"EPS={EPS}\n")
        # Uses a global `save_dtype_to_str` helper when one exists, otherwise
        # falls back to str(SAVE_DTYPE).
        # NOTE(review): no such helper is defined in the visible part of this
        # file — confirm whether it is expected to come from elsewhere.
        f.write(f"SAVE_DTYPE={save_dtype_to_str(SAVE_DTYPE) if 'save_dtype_to_str' in globals() else str(SAVE_DTYPE)}\n")
        f.write(f"SKIP_IF_EXISTS={SKIP_IF_EXISTS}\n")

    # 2) Normalize every file and save it under the same directory structure
    print("\n[2/2] normalizing and saving dataset ...")
    save_normalized_dataset(
        dataset_root=DATASET_ROOT,
        out_root=RUN_OUT_DIR,
        mean=global_mean,
        std=global_std,
        offset_t0=OFFSET_T0,
        offset_t1=OFFSET_T1,
        save_dtype=SAVE_DTYPE,
        skip_if_exists=SKIP_IF_EXISTS
    )
    print(f"\nAll done. Normalized dataset saved under: {RUN_OUT_DIR}")
