In [8]:
import os
import random
from pathlib import Path
from datetime import datetime  # ★ タイムスタンプ用

import numpy as np
import matplotlib.pyplot as plt

# ===================== Settings =====================

# Root directory of the tactile dataset -- edit to match your environment.
DATASET_ROOT = Path("./normalized_dataset_forplot/20251215_161803")

# Which files to plot; one of:
#   "sample_across_classes" / "all_in_one_class" / "single_file"
PLOT_SCOPE = "sample_across_classes"

# (1) "sample_across_classes": take a few files from every class.
SAMPLES_PER_CLASS = 3  # number of files taken from each class
RANDOM_SAMPLE = True   # True: random sampling, False: first N files in sorted order

# (2) "all_in_one_class": plot every file of this class.
TARGET_CLASS = "paper_A4"

# (3) "single_file": plot exactly one file.
SINGLE_CLASS = "paper_A4"
SINGLE_FILE_INDEX = 0  # index into the sorted .npy list of SINGLE_CLASS


# ===================== ユーティリティ =====================

def list_classes(dataset_root: Path):
    """Return the names of the sub-directories of dataset_root (one per class), sorted."""
    return sorted(entry.name for entry in dataset_root.iterdir() if entry.is_dir())


def list_npy_files(class_dir: Path):
    """Return the .npy files located directly in class_dir, sorted by path."""
    return sorted(class_dir.glob("*.npy"))


def summarize_dataset(dataset_root: Path):
    """
    Print a short overview of the dataset to stdout.

    For every class directory (as returned by list_classes) print the number
    of .npy files and the shape of the first file.

    The first file is opened with mmap_mode="r". np.load then only parses
    the .npy header to obtain the shape, instead of reading the whole trial
    into memory just to print it.
    """
    print("=== Dataset summary ===")
    classes = list_classes(dataset_root)
    print(f"Found {len(classes)} classes:")
    for cls in classes:
        files = list_npy_files(dataset_root / cls)
        print(f"  - {cls}: {len(files)} files")
        if files:
            sample = np.load(files[0], mmap_mode="r")
            print(f"      sample shape: {sample.shape} (T, n_taxels, 3)")
            del sample  # release the memory map (and its file handle) right away
    print("=======================\n")


# ===================== プロット関数 =====================

def plot_all_taxels_grouped_by_axis(data: np.ndarray,
                                    class_name: str,
                                    file_name: str,
                                    output_dir: Path):
    """
    Plot one trial (one .npy file) and save the figure as a PNG.

    data: array of shape (T, n_taxels, 3); the last axis is (x, y, z).

    - One subplot per axis (x, y, z), sharing the time axis.
    - Every taxel is drawn in every subplot, and each taxel keeps the same
      color in all three subplots.
    - The legend is placed to the right of the axes so it does not hide data.
    - The figure is saved as output_dir / f"{class_name}_{stem}.png", where
      stem is file_name without its extension.

    Raises:
        ValueError: if data is not 3-D or its last dimension is not 3.
    """
    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    if data.ndim != 3:
        raise ValueError(f"Expected 3D array, got {data.shape}")
    T, n_taxels, n_axes = data.shape
    if n_axes != 3:
        raise ValueError(f"Last dim must be 3 (x, y, z), got {n_axes}")

    time = np.arange(T)
    axis_names = ["x", "y", "z"]

    # Matplotlib's default color cycle has only 10 colors. With 16 taxels,
    # taxel k and taxel k+10 would get the same color and could not be told
    # apart, so give every taxel its own color.
    if n_taxels <= 20:
        colors = plt.get_cmap("tab20")(np.arange(n_taxels))
    else:
        colors = plt.get_cmap("viridis")(np.linspace(0.0, 1.0, n_taxels))

    # subplots(3, 1) always returns a 1-D ndarray of 3 Axes.
    fig, axes = plt.subplots(3, 1, figsize=(10, 8), sharex=True)
    try:
        for axis, (ax, axis_name) in enumerate(zip(axes, axis_names)):
            for taxel in range(n_taxels):
                # Label only the first subplot so each taxel appears once in the legend.
                label = f"taxel {taxel}" if axis == 0 else "_nolegend_"
                ax.plot(time,
                        data[:, taxel, axis],
                        color=colors[taxel],
                        alpha=0.6,
                        linewidth=0.8,
                        label=label)
            ax.set_ylabel(f"{axis_name} (arb.)")
            ax.grid(True)

        axes[-1].set_xlabel("Time step")

        handles, labels = axes[0].get_legend_handles_labels()
        fig.suptitle(f"{class_name} | {file_name}")
        # Keep the left 85% for the axes; the right 15% is left for the legend.
        fig.tight_layout(rect=(0.0, 0.0, 0.85, 0.95))

        if handles:
            fig.legend(handles,
                       labels,
                       loc="center left",
                       bbox_to_anchor=(0.88, 0.5),
                       borderaxespad=0.)

        out_path = output_dir / f"{class_name}_{Path(file_name).stem}.png"
        fig.savefig(out_path, dpi=300)
    finally:
        # Always release the figure, even if drawing or saving fails.
        plt.close(fig)

    print(f"Saved figure to: {out_path}")


def plot_trials(trial_paths, output_dir: Path):
    """
    Load each trial file in turn and plot/save it with
    plot_all_taxels_grouped_by_axis (one figure per trial, all taxels
    grouped per x/y/z axis).

    trial_paths: iterable of Path; the class name is taken from the
    parent directory name of each file.
    """
    for npy_path in trial_paths:
        cls_name = npy_path.parent.name
        fname = npy_path.name
        print(f"Plotting: class={cls_name}, file={fname}")
        plot_all_taxels_grouped_by_axis(np.load(npy_path), cls_name, fname, output_dir)


# ===================== メイン処理 =====================

if __name__ == "__main__":
    # Fail early, before creating an (empty) output directory, if the dataset
    # root is missing; iterdir() would otherwise raise a less helpful error.
    if not DATASET_ROOT.is_dir():
        raise FileNotFoundError(f"DATASET_ROOT is not a directory: {DATASET_ROOT}")

    # Timestamped output directory so that runs never overwrite each other.
    timestamp_str = datetime.now().strftime("%Y%m%d_%H%M%S")
    OUTPUT_ROOT = Path("./plots")
    RUN_OUTPUT_DIR = OUTPUT_ROOT / timestamp_str
    RUN_OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
    print(f"Output directory: {RUN_OUTPUT_DIR}")

    summarize_dataset(DATASET_ROOT)
    classes = list_classes(DATASET_ROOT)
    if not classes:
        raise RuntimeError("No class directories found under DATASET_ROOT")

    trial_paths = []

    if PLOT_SCOPE == "sample_across_classes":
        # Take up to SAMPLES_PER_CLASS files from every class.
        for cls in classes:
            files = list_npy_files(DATASET_ROOT / cls)
            if not files:
                continue

            if RANDOM_SAMPLE:
                sampled = random.sample(files, min(SAMPLES_PER_CLASS, len(files)))
            else:
                sampled = files[:SAMPLES_PER_CLASS]

            trial_paths.extend(sampled)

        # Mix the classes only in random mode; with RANDOM_SAMPLE=False the
        # "first N files in order" selection keeps a deterministic order.
        if RANDOM_SAMPLE:
            random.shuffle(trial_paths)

    elif PLOT_SCOPE == "all_in_one_class":
        # Every file of one class.
        if TARGET_CLASS not in classes:
            raise ValueError(f"TARGET_CLASS '{TARGET_CLASS}' not found. Available: {classes}")
        class_dir = DATASET_ROOT / TARGET_CLASS
        trial_paths = list_npy_files(class_dir)
        if not trial_paths:
            raise RuntimeError(f"No .npy files in class directory: {class_dir}")

    elif PLOT_SCOPE == "single_file":
        # Exactly one file, selected by index in the sorted file list.
        if SINGLE_CLASS not in classes:
            raise ValueError(f"SINGLE_CLASS '{SINGLE_CLASS}' not found. Available: {classes}")
        class_dir = DATASET_ROOT / SINGLE_CLASS
        files = list_npy_files(class_dir)
        if not files:
            raise RuntimeError(f"No .npy files in class directory: {class_dir}")
        if not (0 <= SINGLE_FILE_INDEX < len(files)):
            raise ValueError(f"SINGLE_FILE_INDEX {SINGLE_FILE_INDEX} out of range (0..{len(files)-1})")

        trial_paths = [files[SINGLE_FILE_INDEX]]

    else:
        raise ValueError(f"Unknown PLOT_SCOPE: {PLOT_SCOPE}")

    # Plot and save.
    print(f"Total trials to plot: {len(trial_paths)}")
    plot_trials(trial_paths, RUN_OUTPUT_DIR)


Output directory: plots/20251215_162635
=== Dataset summary ===
Found 25 classes:
  - 01_table_cover: 10 files
      sample shape: (1255, 16, 3) (T, n_taxels, 3)
  - 02_fur_scarf: 10 files
      sample shape: (1216, 16, 3) (T, n_taxels, 3)
  - 03_washing_towel: 10 files
      sample shape: (1222, 16, 3) (T, n_taxels, 3)
  - 04_carpet1: 10 files
      sample shape: (1274, 16, 3) (T, n_taxels, 3)
  - 05_bubble_wrap: 10 files
      sample shape: (1229, 16, 3) (T, n_taxels, 3)
  - 06_fleece_scarf: 10 files
      sample shape: (1245, 16, 3) (T, n_taxels, 3)
  - 07_knit_hat1: 10 files
      sample shape: (1248, 16, 3) (T, n_taxels, 3)
  - 08_body_towel1: 10 files
      sample shape: (1246, 16, 3) (T, n_taxels, 3)
  - 09_body_towel2: 10 files
      sample shape: (1247, 16, 3) (T, n_taxels, 3)
  - 10_carpet2: 10 files
      sample shape: (1278, 16, 3) (T, n_taxels, 3)
  - 11_work_gloves: 10 files
      sample shape: (1247, 16, 3) (T, n_taxels, 3)
  - 12_knit_hat2: 10 files
      sample shape: 

In [None]:
#### Build train/val/test DataLoaders from TactileSequenceDataset

import torch  # used below for torch.float32 and torch.Generator
from torch.utils.data import DataLoader, random_split

# NOTE(review): TactileSequenceDataset is not defined in this cell; it is
# presumably defined in another notebook cell or module -- confirm before running.

# Root directory that contains one sub-directory per class.
DATASET_ROOT = "./dataset_root"

# Build the Dataset. The exact meaning of seq_start/seq_end is defined by
# TactileSequenceDataset (not visible here).
dataset = TactileSequenceDataset(
    root_dir=DATASET_ROOT,
    seq_start=400,
    seq_end=1200,
    dtype=torch.float32,
)

# Split roughly 7:2:1 into train/val/test. test_len takes the remainder so
# the three lengths always sum to len(dataset), as random_split requires.
total_len = len(dataset)
train_len = int(total_len * 0.7)
val_len = int(total_len * 0.2)
test_len = total_len - train_len - val_len

# A fixed generator seed gives the same split on every run.
train_dataset, val_dataset, test_dataset = random_split(
    dataset, [train_len, val_len, test_len],
    generator=torch.Generator().manual_seed(42),
)

BATCH_SIZE = 32

# Only the training loader shuffles; validation/test keep a fixed order.
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    drop_last=False,
)

val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    drop_last=False,
)

test_loader = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    drop_last=False,
)

print("train_batches:", len(train_loader))
print("val_batches  :", len(val_loader))
print("test_batches :", len(test_loader))


In [6]:
####データセット全体にて標準化し，プロット


import os
import random
from pathlib import Path
from datetime import datetime

import numpy as np
import matplotlib.pyplot as plt


# ===================== Settings =====================

DATASET_ROOT = Path("./All_materials")  # root directory of the tactile dataset

# One of "sample_across_classes" / "all_in_one_class" / "single_file".
PLOT_SCOPE = "sample_across_classes"

# "sample_across_classes": files per class, sampled randomly or the first N.
SAMPLES_PER_CLASS = 3
RANDOM_SAMPLE = True

# "all_in_one_class": every file of this class is plotted.
TARGET_CLASS = "paper_A4"

# "single_file": class and index (into its sorted .npy list) of the one file.
SINGLE_CLASS = "paper_A4"
SINGLE_FILE_INDEX = 0

# ---- Normalization settings ----
# If True, every plotted trial is offset-corrected and z-scored with
# dataset-wide statistics before plotting.
APPLY_GLOBAL_ZSCORE = True

# Baseline window for the per-trial offset: t = OFFSET_T0..OFFSET_T1 (both inclusive).
OFFSET_T0 = 0
OFFSET_T1 = 100
EPS = 1e-8  # added to the variance before the square root


# ===================== ユーティリティ =====================

def list_classes(dataset_root: Path):
    """Return the names of the sub-directories of dataset_root (one per class), sorted."""
    return sorted(entry.name for entry in dataset_root.iterdir() if entry.is_dir())


def list_npy_files(class_dir: Path):
    """Return the .npy files located directly in class_dir, sorted by path."""
    return sorted(class_dir.glob("*.npy"))


def iter_all_npy_files(dataset_root: Path):
    """Lazily yield every .npy path under dataset_root, class by class
    (classes and files in the sorted order of list_classes / list_npy_files)."""
    for class_name in list_classes(dataset_root):
        yield from list_npy_files(dataset_root / class_name)


def summarize_dataset(dataset_root: Path):
    """
    Print a short overview of the dataset to stdout.

    For every class directory (as returned by list_classes) print the number
    of .npy files and the shape of the first file.

    The first file is opened with mmap_mode="r". np.load then only parses
    the .npy header to obtain the shape, instead of reading the whole trial
    into memory just to print it.
    """
    print("=== Dataset summary ===")
    classes = list_classes(dataset_root)
    print(f"Found {len(classes)} classes:")
    for cls in classes:
        files = list_npy_files(dataset_root / cls)
        print(f"  - {cls}: {len(files)} files")
        if files:
            sample = np.load(files[0], mmap_mode="r")
            print(f"      sample shape: {sample.shape} (T, n_taxels, 3)")
            del sample  # release the memory map (and its file handle) right away
    print("=======================\n")


# ===================== 正規化（追加） =====================

def compute_global_stats_with_offset(dataset_root: Path,
                                     offset_t0: int = 0,
                                     offset_t1: int = 100,
                                     eps: float = 1e-8):
    """
    Estimate per-feature (taxel x axis) mean/std over the whole dataset
    after per-trial offset removal.

    1) For every trial, subtract its own baseline: the mean over time steps
       offset_t0..offset_t1 (both inclusive), separately per taxel/axis.
    2) Accumulate the sum and sum of squares of the offset-removed values
       over all time steps of all .npy files under dataset_root, in float64.

    Returns:
      mean: float32 array of shape (n_taxels, 3)
      std : float32 array of shape (n_taxels, 3), computed as sqrt(var + eps)

    Raises:
      ValueError: invalid offset range, unexpected array shape, n_taxels
                  mismatch between files, or a trial too short for the range.
      RuntimeError: no .npy file was found.
    """
    # An empty or reversed window would silently produce NaN baselines.
    if not (0 <= offset_t0 <= offset_t1):
        raise ValueError(f"Invalid offset range {offset_t0}..{offset_t1} "
                         f"(need 0 <= offset_t0 <= offset_t1)")

    paths = list(iter_all_npy_files(dataset_root))
    if not paths:
        raise RuntimeError(f"No .npy files found under: {dataset_root}")

    sum_ = None
    sumsq_ = None
    total_count = 0
    n_taxels_ref = None

    for i, path in enumerate(paths, start=1):
        data = np.load(path)  # (T, n_taxels, 3)
        if data.ndim != 3 or data.shape[-1] != 3:
            raise ValueError(f"Invalid shape {data.shape} in {path} (expected (T, n_taxels, 3))")

        T, n_taxels, _ = data.shape
        if n_taxels_ref is None:
            n_taxels_ref = n_taxels
            sum_ = np.zeros((n_taxels_ref, 3), dtype=np.float64)
            sumsq_ = np.zeros((n_taxels_ref, 3), dtype=np.float64)
        elif n_taxels != n_taxels_ref:
            raise ValueError(f"n_taxels mismatch: {n_taxels} != {n_taxels_ref} in {path}")

        if T <= offset_t1:
            raise ValueError(f"T={T} is too short for offset range {offset_t0}..{offset_t1} in {path}")

        # Convert first, so that the baseline and the differences are computed
        # in float64 too (not only cast after subtracting in the file dtype).
        data = data.astype(np.float64, copy=False)
        baseline = data[offset_t0:offset_t1 + 1].mean(axis=0, keepdims=True)  # (1, n_taxels, 3)
        data0 = data - baseline  # (T, n_taxels, 3)

        sum_ += data0.sum(axis=0)              # (n_taxels, 3)
        sumsq_ += np.square(data0).sum(axis=0)  # (n_taxels, 3)
        total_count += T

        if i % 200 == 0:
            print(f"[global-stats] processed {i}/{len(paths)} files...")

    mean = sum_ / total_count
    # E[x^2] - E[x]^2 can dip slightly below 0 through rounding; clamp it.
    var = np.maximum(sumsq_ / total_count - mean * mean, 0.0)
    std = np.sqrt(var + eps)

    return mean.astype(np.float32), std.astype(np.float32)


def apply_offset_and_global_zscore(data: np.ndarray,
                                  mean: np.ndarray,
                                  std: np.ndarray,
                                  offset_t0: int = 0,
                                  offset_t1: int = 100):
    """
    Remove the per-trial offset, then z-score with global statistics.

    data: (T, n_taxels, 3)
    mean/std: (n_taxels, 3) -- global statistics of offset-removed data,
              as returned by compute_global_stats_with_offset.

    The baseline is the mean of data[offset_t0:offset_t1 + 1] (both ends
    inclusive) per taxel/axis. The result is (data - baseline - mean) / std
    and has the same shape as data.

    Raises:
      ValueError: data is not 3-D, the offset range is invalid or does not
                  fit into T, or mean/std do not have shape data.shape[1:].
    """
    if data.ndim != 3:
        raise ValueError(f"Expected 3D array (T, n_taxels, 3), got {data.shape}")
    T = data.shape[0]
    # An empty or reversed window would silently produce NaN baselines.
    if not (0 <= offset_t0 <= offset_t1):
        raise ValueError(f"Invalid offset range {offset_t0}..{offset_t1} "
                         f"(need 0 <= offset_t0 <= offset_t1)")
    if T <= offset_t1:
        raise ValueError(f"T={T} is too short for offset range {offset_t0}..{offset_t1}")
    if mean.shape != data.shape[1:] or std.shape != data.shape[1:]:
        raise ValueError(f"mean/std shapes {mean.shape}/{std.shape} do not match "
                         f"per-frame shape {data.shape[1:]}")

    baseline = data[offset_t0:offset_t1 + 1].mean(axis=0, keepdims=True)  # (1, n_taxels, 3)
    return (data - baseline - mean[None, :, :]) / std[None, :, :]


# ===================== プロット関数（元のまま） =====================

def plot_all_taxels_grouped_by_axis(data: np.ndarray,
                                    class_name: str,
                                    file_name: str,
                                    output_dir: Path):
    """
    Plot one trial (one .npy file) and save the figure as a PNG.

    data: array of shape (T, n_taxels, 3); the last axis is (x, y, z).

    - One subplot per axis (x, y, z), sharing the time axis.
    - Every taxel is drawn in every subplot, and each taxel keeps the same
      color in all three subplots.
    - The legend is placed to the right of the axes so it does not hide data.
    - The figure is saved as output_dir / f"{class_name}_{stem}.png", where
      stem is file_name without its extension.

    Raises:
        ValueError: if data is not 3-D or its last dimension is not 3.
    """
    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    if data.ndim != 3:
        raise ValueError(f"Expected 3D array, got {data.shape}")
    T, n_taxels, n_axes = data.shape
    if n_axes != 3:
        raise ValueError(f"Last dim must be 3 (x, y, z), got {n_axes}")

    time = np.arange(T)
    axis_names = ["x", "y", "z"]

    # Matplotlib's default color cycle has only 10 colors. With 16 taxels,
    # taxel k and taxel k+10 would get the same color and could not be told
    # apart, so give every taxel its own color.
    if n_taxels <= 20:
        colors = plt.get_cmap("tab20")(np.arange(n_taxels))
    else:
        colors = plt.get_cmap("viridis")(np.linspace(0.0, 1.0, n_taxels))

    # subplots(3, 1) always returns a 1-D ndarray of 3 Axes.
    fig, axes = plt.subplots(3, 1, figsize=(10, 8), sharex=True)
    try:
        for axis, (ax, axis_name) in enumerate(zip(axes, axis_names)):
            for taxel in range(n_taxels):
                # Label only the first subplot so each taxel appears once in the legend.
                label = f"taxel {taxel}" if axis == 0 else "_nolegend_"
                ax.plot(time,
                        data[:, taxel, axis],
                        color=colors[taxel],
                        alpha=0.6,
                        linewidth=0.8,
                        label=label)
            ax.set_ylabel(f"{axis_name} (arb.)")
            ax.grid(True)

        axes[-1].set_xlabel("Time step")

        handles, labels = axes[0].get_legend_handles_labels()
        fig.suptitle(f"{class_name} | {file_name}")
        # Keep the left 85% for the axes; the right 15% is left for the legend.
        fig.tight_layout(rect=(0.0, 0.0, 0.85, 0.95))

        if handles:
            fig.legend(handles,
                       labels,
                       loc="center left",
                       bbox_to_anchor=(0.88, 0.5),
                       borderaxespad=0.)

        out_path = output_dir / f"{class_name}_{Path(file_name).stem}.png"
        fig.savefig(out_path, dpi=300)
    finally:
        # Always release the figure, even if drawing or saving fails.
        plt.close(fig)

    print(f"Saved figure to: {out_path}")


def plot_trials(trial_paths, output_dir: Path, global_mean=None, global_std=None):
    """
    Load each trial and plot/save it with plot_all_taxels_grouped_by_axis.

    When the module-level flag APPLY_GLOBAL_ZSCORE is True, every trial is
    first passed through apply_offset_and_global_zscore (baseline window
    OFFSET_T0..OFFSET_T1) using global_mean/global_std. Both are then
    required; a RuntimeError is raised when the first trial is processed
    if either is None.

    trial_paths: iterable of Path; the class name is the parent directory name.
    """
    for npy_path in trial_paths:
        cls_name, fname = npy_path.parent.name, npy_path.name
        print(f"Plotting: class={cls_name}, file={fname}")

        trial = np.load(npy_path)  # (T, n_taxels, 3)

        if APPLY_GLOBAL_ZSCORE:
            stats_missing = global_mean is None or global_std is None
            if stats_missing:
                raise RuntimeError("global_mean/std are required when APPLY_GLOBAL_ZSCORE=True")
            trial = apply_offset_and_global_zscore(trial,
                                                   mean=global_mean,
                                                   std=global_std,
                                                   offset_t0=OFFSET_T0,
                                                   offset_t1=OFFSET_T1)

        plot_all_taxels_grouped_by_axis(trial, cls_name, fname, output_dir)


# ===================== メイン処理 =====================

if __name__ == "__main__":
    # Fail early, before creating an (empty) output directory, if the dataset
    # root is missing; iterdir() would otherwise raise a less helpful error.
    if not DATASET_ROOT.is_dir():
        raise FileNotFoundError(f"DATASET_ROOT is not a directory: {DATASET_ROOT}")

    # Timestamped output directory so that runs never overwrite each other.
    timestamp_str = datetime.now().strftime("%Y%m%d_%H%M%S")
    OUTPUT_ROOT = Path("./plots")
    RUN_OUTPUT_DIR = OUTPUT_ROOT / timestamp_str
    RUN_OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
    print(f"Output directory: {RUN_OUTPUT_DIR}")

    summarize_dataset(DATASET_ROOT)
    classes = list_classes(DATASET_ROOT)
    if not classes:
        raise RuntimeError("No class directories found under DATASET_ROOT")

    # ---- Global statistics over the whole dataset (per taxel x axis) ----
    global_mean, global_std = None, None
    if APPLY_GLOBAL_ZSCORE:
        print(f"\n[compute] global mean/std with offset t={OFFSET_T0}..{OFFSET_T1} (per feature: taxel×axis)")
        global_mean, global_std = compute_global_stats_with_offset(
            DATASET_ROOT, offset_t0=OFFSET_T0, offset_t1=OFFSET_T1, eps=EPS
        )
        print("[done] global stats computed.")
        print("  mean shape:", global_mean.shape, "std shape:", global_std.shape)
        print("  std min/max:", float(global_std.min()), float(global_std.max()))

        # Save statistics and settings so the normalization can be reproduced.
        # DATASET_ROOT is recorded too, because the stats depend on it.
        np.save(RUN_OUTPUT_DIR / "global_mean_taxel_axis.npy", global_mean)
        np.save(RUN_OUTPUT_DIR / "global_std_taxel_axis.npy", global_std)
        with open(RUN_OUTPUT_DIR / "normalization_info.txt", "w", encoding="utf-8") as f:
            f.write(f"DATASET_ROOT={DATASET_ROOT}\n")
            f.write(f"APPLY_GLOBAL_ZSCORE={APPLY_GLOBAL_ZSCORE}\n")
            f.write(f"OFFSET_T0={OFFSET_T0}\n")
            f.write(f"OFFSET_T1={OFFSET_T1}\n")
            f.write(f"EPS={EPS}\n")
        print(f"Saved normalization stats to: {RUN_OUTPUT_DIR}")

    trial_paths = []

    if PLOT_SCOPE == "sample_across_classes":
        for cls in classes:
            files = list_npy_files(DATASET_ROOT / cls)
            if not files:
                continue

            if RANDOM_SAMPLE:
                sampled = random.sample(files, min(SAMPLES_PER_CLASS, len(files)))
            else:
                sampled = files[:SAMPLES_PER_CLASS]

            trial_paths.extend(sampled)

        # Mix the classes only in random mode; with RANDOM_SAMPLE=False the
        # "first N files in order" selection keeps a deterministic order.
        if RANDOM_SAMPLE:
            random.shuffle(trial_paths)

    elif PLOT_SCOPE == "all_in_one_class":
        if TARGET_CLASS not in classes:
            raise ValueError(f"TARGET_CLASS '{TARGET_CLASS}' not found. Available: {classes}")
        class_dir = DATASET_ROOT / TARGET_CLASS
        trial_paths = list_npy_files(class_dir)
        if not trial_paths:
            raise RuntimeError(f"No .npy files in class directory: {class_dir}")

    elif PLOT_SCOPE == "single_file":
        if SINGLE_CLASS not in classes:
            raise ValueError(f"SINGLE_CLASS '{SINGLE_CLASS}' not found. Available: {classes}")
        class_dir = DATASET_ROOT / SINGLE_CLASS
        files = list_npy_files(class_dir)
        if not files:
            raise RuntimeError(f"No .npy files in class directory: {class_dir}")
        if not (0 <= SINGLE_FILE_INDEX < len(files)):
            raise ValueError(f"SINGLE_FILE_INDEX {SINGLE_FILE_INDEX} out of range (0..{len(files)-1})")
        trial_paths = [files[SINGLE_FILE_INDEX]]

    else:
        raise ValueError(f"Unknown PLOT_SCOPE: {PLOT_SCOPE}")

    print(f"Total trials to plot: {len(trial_paths)}")
    plot_trials(trial_paths, RUN_OUTPUT_DIR, global_mean=global_mean, global_std=global_std)


Output directory: plots/20251215_155716
=== Dataset summary ===
Found 25 classes:
  - 01_table_cover: 10 files
      sample shape: (1255, 16, 3) (T, n_taxels, 3)
  - 02_fur_scarf: 10 files
      sample shape: (1216, 16, 3) (T, n_taxels, 3)
  - 03_washing_towel: 10 files
      sample shape: (1222, 16, 3) (T, n_taxels, 3)
  - 04_carpet1: 10 files
      sample shape: (1274, 16, 3) (T, n_taxels, 3)
  - 05_bubble_wrap: 10 files
      sample shape: (1229, 16, 3) (T, n_taxels, 3)
  - 06_fleece_scarf: 10 files
      sample shape: (1245, 16, 3) (T, n_taxels, 3)
  - 07_knit_hat1: 10 files
      sample shape: (1248, 16, 3) (T, n_taxels, 3)
  - 08_body_towel1: 10 files
      sample shape: (1246, 16, 3) (T, n_taxels, 3)
  - 09_body_towel2: 10 files
      sample shape: (1247, 16, 3) (T, n_taxels, 3)
  - 10_carpet2: 10 files
      sample shape: (1278, 16, 3) (T, n_taxels, 3)
  - 11_work_gloves: 10 files
      sample shape: (1247, 16, 3) (T, n_taxels, 3)
  - 12_knit_hat2: 10 files
      sample shape: 

In [7]:
####データセット全体にて標準化し，保存し直す


import random
from pathlib import Path
from datetime import datetime

import numpy as np

# ===================== Settings =====================

# Source dataset: one sub-directory per class, each holding .npy trials.
DATASET_ROOT = Path("./All_materials/")

# Output base directory; a run writes <OUTPUT_ROOT>/<timestamp>/<class>/*.npy.
OUTPUT_ROOT = Path("./normalized_dataset_forplot")

# Baseline (offset) window, both ends inclusive.
OFFSET_T0 = 0
OFFSET_T1 = 100

# Added to the variance before taking the square root (numerical stability).
EPS = 1e-8

# dtype of the saved arrays (float32 recommended).
SAVE_DTYPE = np.float32

# Skip output files that already exist.
SKIP_IF_EXISTS = True


# ===================== ユーティリティ =====================

def list_classes(dataset_root: Path):
    """Return the names of the sub-directories of dataset_root (one per class), sorted."""
    return sorted(entry.name for entry in dataset_root.iterdir() if entry.is_dir())

def list_npy_files(class_dir: Path):
    """Return the .npy files located directly in class_dir, sorted by path."""
    return sorted(class_dir.glob("*.npy"))

def iter_all_npy_files(dataset_root: Path):
    """Lazily yield (class_name, path) for every .npy file under dataset_root,
    class by class (classes and files in sorted order)."""
    for class_name in list_classes(dataset_root):
        class_files = list_npy_files(dataset_root / class_name)
        yield from ((class_name, npy_path) for npy_path in class_files)


# ===================== 統計量（global mean/std）の推定 =====================

def compute_global_stats_with_offset(dataset_root: Path,
                                     offset_t0: int = 0,
                                     offset_t1: int = 100,
                                     eps: float = 1e-8):
    """
    Over all .npy files under dataset_root:
      1) subtract each trial's baseline, i.e. the mean over t=offset_t0..offset_t1
         (both inclusive), separately per taxel/axis;
      2) estimate the global mean/std of the offset-removed values per
         taxel/axis, accumulating sums in float64.

    Returns:
      mean: float32 array of shape (n_taxels, 3)
      std : float32 array of shape (n_taxels, 3), computed as sqrt(var + eps)

    Raises:
      ValueError: invalid offset range, unexpected array shape, trial too
                  short for the range, or n_taxels mismatch between files.
      RuntimeError: no .npy file was found.
    """
    # An empty or reversed window would silently produce NaN baselines.
    if not (0 <= offset_t0 <= offset_t1):
        raise ValueError(f"Invalid offset range {offset_t0}..{offset_t1} "
                         f"(need 0 <= offset_t0 <= offset_t1)")

    paths = list(iter_all_npy_files(dataset_root))
    if not paths:
        raise RuntimeError(f"No .npy files found under: {dataset_root}")

    sum_ = None
    sumsq_ = None
    total_count = 0
    n_taxels_ref = None

    for i, (_, path) in enumerate(paths, start=1):
        data = np.load(path)  # (T, n_taxels, 3)

        if data.ndim != 3 or data.shape[-1] != 3:
            raise ValueError(f"Invalid shape {data.shape} in {path} (expected (T, n_taxels, 3))")

        T, n_taxels, _ = data.shape
        if T <= offset_t1:
            raise ValueError(f"T={T} too short for offset range {offset_t0}..{offset_t1} in {path}")

        if n_taxels_ref is None:
            n_taxels_ref = n_taxels
            sum_ = np.zeros((n_taxels_ref, 3), dtype=np.float64)
            sumsq_ = np.zeros((n_taxels_ref, 3), dtype=np.float64)
        elif n_taxels != n_taxels_ref:
            raise ValueError(f"n_taxels mismatch: {n_taxels} != {n_taxels_ref} in {path}")

        # Convert first, so that the baseline and the differences are computed
        # in float64 too (not only cast after subtracting in the file dtype).
        data = data.astype(np.float64, copy=False)
        baseline = data[offset_t0:offset_t1 + 1].mean(axis=0, keepdims=True)  # (1, n_taxels, 3)
        data0 = data - baseline  # (T, n_taxels, 3)

        sum_ += data0.sum(axis=0)              # (n_taxels, 3)
        sumsq_ += np.square(data0).sum(axis=0)  # (n_taxels, 3)
        total_count += T

        if i % 200 == 0:
            print(f"[global-stats] processed {i}/{len(paths)} files...")

    mean = sum_ / total_count
    # E[x^2] - E[x]^2 can dip slightly below 0 through rounding; clamp it.
    var = np.maximum(sumsq_ / total_count - mean * mean, 0.0)
    std = np.sqrt(var + eps)

    return mean.astype(np.float32), std.astype(np.float32)


def normalize_trial(data: np.ndarray, mean: np.ndarray, std: np.ndarray,
                    offset_t0: int = 0, offset_t1: int = 100):
    """
    Remove the per-trial offset, then z-score with global statistics.

    data: (T, n_taxels, 3)
    mean/std: (n_taxels, 3) -- global statistics of offset-removed data,
              as returned by compute_global_stats_with_offset.

    The baseline is the mean of data[offset_t0:offset_t1 + 1] (both ends
    inclusive) per taxel/axis. The result is (data - baseline - mean) / std
    and has the same shape as data.

    Raises:
      ValueError: data is not 3-D, the offset range is invalid or does not
                  fit into T, or mean/std do not have shape data.shape[1:].
    """
    if data.ndim != 3:
        raise ValueError(f"Expected 3D array (T, n_taxels, 3), got {data.shape}")
    T = data.shape[0]
    # An empty or reversed window would silently produce NaN baselines.
    if not (0 <= offset_t0 <= offset_t1):
        raise ValueError(f"Invalid offset range {offset_t0}..{offset_t1} "
                         f"(need 0 <= offset_t0 <= offset_t1)")
    if T <= offset_t1:
        raise ValueError(f"T={T} too short for offset range {offset_t0}..{offset_t1}")
    if mean.shape != data.shape[1:] or std.shape != data.shape[1:]:
        raise ValueError(f"mean/std shapes {mean.shape}/{std.shape} do not match "
                         f"per-frame shape {data.shape[1:]}")

    baseline = data[offset_t0:offset_t1 + 1].mean(axis=0, keepdims=True)  # (1, n_taxels, 3)
    return (data - baseline - mean[None, :, :]) / std[None, :, :]


# ===================== 保存処理（同ディレクトリ構造） =====================

def save_normalized_dataset(dataset_root: Path,
                            out_root: Path,
                            mean: np.ndarray,
                            std: np.ndarray,
                            offset_t0: int = 0,
                            offset_t1: int = 100,
                            save_dtype=np.float32,
                            skip_if_exists: bool = True):
    """
    Normalize every trial under dataset_root and save it under out_root,
    keeping the class-directory layout:

      dataset_root/classA/*.npy  ->  out_root/classA/*.npy
      dataset_root/classB/*.npy  ->  out_root/classB/*.npy
      ...

    Each trial goes through normalize_trial(data, mean, std, offset_t0,
    offset_t1) and is cast to save_dtype before saving.

    skip_if_exists: leave already existing output files untouched. Outputs
    are first written to "<name>.npy.tmp" and then renamed into place. An
    interrupted run therefore never leaves a truncated "<name>.npy" that a
    later run with skip_if_exists=True would mistake for a finished file.

    Raises:
      RuntimeError: if normalization changed the array shape.
    """
    all_items = list(iter_all_npy_files(dataset_root))
    n_total = len(all_items)
    print(f"[save] total files to process: {n_total}")

    prepared_dirs = set()  # output class dirs already created (one mkdir per class)
    n_skipped = 0

    for i, (cls, path) in enumerate(all_items, start=1):
        out_class_dir = out_root / cls
        if out_class_dir not in prepared_dirs:
            out_class_dir.mkdir(parents=True, exist_ok=True)
            prepared_dirs.add(out_class_dir)

        out_path = out_class_dir / path.name
        if skip_if_exists and out_path.exists():
            n_skipped += 1
        else:
            data = np.load(path)  # (T, n_taxels, 3)
            data_norm = normalize_trial(data, mean, std, offset_t0, offset_t1).astype(save_dtype)

            # Sanity check: the normalization must be purely element-wise.
            if data_norm.shape != data.shape:
                raise RuntimeError(f"shape changed: {data.shape} -> {data_norm.shape} at {path}")

            tmp_path = out_path.with_name(out_path.name + ".tmp")
            try:
                # Writing through a file object keeps np.save from appending ".npy".
                with open(tmp_path, "wb") as f:
                    np.save(f, data_norm)
                tmp_path.replace(out_path)  # atomic rename on the same filesystem
            except BaseException:
                tmp_path.unlink(missing_ok=True)
                raise

        # Skipped files count too, so progress always reaches n_total.
        if i % 200 == 0:
            print(f"[save] processed {i}/{n_total} files...")

    if n_skipped:
        print(f"[save] skipped {n_skipped} existing files.")
    print("[save] done.")


# ===================== main =====================

if __name__ == "__main__":
    # Fail early, before creating an (empty) output directory, if the source
    # dataset is missing; iterdir() would otherwise raise a less helpful error.
    if not DATASET_ROOT.is_dir():
        raise FileNotFoundError(f"DATASET_ROOT is not a directory: {DATASET_ROOT}")

    # Timestamped output directory. Because it is new for every run,
    # SKIP_IF_EXISTS only matters if save_normalized_dataset is re-run on it.
    timestamp_str = datetime.now().strftime("%Y%m%d_%H%M%S")
    RUN_OUT_DIR = OUTPUT_ROOT / timestamp_str
    RUN_OUT_DIR.mkdir(parents=True, exist_ok=True)

    print(f"DATASET_ROOT : {DATASET_ROOT}")
    print(f"OUTPUT_DIR   : {RUN_OUT_DIR}")
    print(f"OFFSET range : {OFFSET_T0}..{OFFSET_T1} (inclusive)")

    # 1) Global mean/std of the offset-removed data.
    print("\n[1/2] computing global mean/std ...")
    global_mean, global_std = compute_global_stats_with_offset(
        DATASET_ROOT, offset_t0=OFFSET_T0, offset_t1=OFFSET_T1, eps=EPS
    )
    print("[1/2] done.")
    print("  mean shape:", global_mean.shape, "std shape:", global_std.shape)
    print("  std  min/max:", float(global_std.min()), float(global_std.max()))

    # Save the statistics and settings so the normalization can be reproduced.
    np.save(RUN_OUT_DIR / "global_mean_taxel_axis.npy", global_mean)
    np.save(RUN_OUT_DIR / "global_std_taxel_axis.npy", global_std)
    with open(RUN_OUT_DIR / "normalization_info.txt", "w", encoding="utf-8") as f:
        f.write(f"DATASET_ROOT={DATASET_ROOT}\n")
        f.write(f"OFFSET_T0={OFFSET_T0}\n")
        f.write(f"OFFSET_T1={OFFSET_T1}\n")
        f.write(f"EPS={EPS}\n")
        # np.dtype(...).name yields e.g. "float32" rather than "<class 'numpy.float32'>".
        f.write(f"SAVE_DTYPE={np.dtype(SAVE_DTYPE).name}\n")
        f.write(f"SKIP_IF_EXISTS={SKIP_IF_EXISTS}\n")

    # 2) Normalize every file and save it with the same directory layout.
    print("\n[2/2] normalizing and saving dataset ...")
    save_normalized_dataset(
        dataset_root=DATASET_ROOT,
        out_root=RUN_OUT_DIR,
        mean=global_mean,
        std=global_std,
        offset_t0=OFFSET_T0,
        offset_t1=OFFSET_T1,
        save_dtype=SAVE_DTYPE,
        skip_if_exists=SKIP_IF_EXISTS
    )
    print(f"\nAll done. Normalized dataset saved under: {RUN_OUT_DIR}")


DATASET_ROOT : All_materials
OUTPUT_DIR   : normalized_dataset_forplot/20251215_161803
OFFSET range : 0..100 (inclusive)

[1/2] computing global mean/std ...
[global-stats] processed 200/250 files...
[1/2] done.
  mean shape: (16, 3) std shape: (16, 3)
  std  min/max: 280.9351501464844 4830.4208984375

[2/2] normalizing and saving dataset ...
[save] total files to process: 250
[save] processed 200/250 files...
[save] done.

All done. Normalized dataset saved under: normalized_dataset_forplot/20251215_161803
