In [None]:
# ============================================================
# MM-Fit Trial Extractor (rebuild meta.csv + npz)  [ALL REPS]
# - Device: sw_r (smartwatch right)
# - Exercises: pushups, lunges, dumbbell_rows
# - Label mapping: RELATIVE time (frame/fps) <-> IMU relative seconds
# - Robust Drive I/O: per-session local cache copy
# - Output: OUT_DIR/npz/*.npz + OUT_DIR/meta_*.csv
#
# ✅ Added:
#   - session(w00~w20) -> participant mapping saved into meta + npz
#   - nonfinite(X) guard: skip OR (optional) linear interpolate per channel
# ============================================================

import os, glob, shutil, time
import numpy as np
import pandas as pd
from scipy.interpolate import interp1d

# -----------------------
# CONFIG (EDIT)
# -----------------------
DATA_ROOT = "/content/drive/MyDrive/Colab Notebooks/HAR_data/mm-fit"  # mm-fit root (w00,w01,...)
USE_DRIVE_OUTPUT = True

OUT_DIR = "/content/drive/MyDrive/Colab Notebooks/HAR_data/mmfit_imu_3ex_trials"
if USE_DRIVE_OUTPUT:
    OUT_DIR = "/content/drive/MyDrive/Colab Notebooks/HAR_data/mmfit_imu_3ex_trials"

CACHE_DIR = "/content/mmfit_cache"
DEVICE = "sw_r"
STREAMS = ["acc", "gyr"]

TARGET_LABELS = {"pushups", "lunges", "dumbbell_rows"}
ONLY_REPS = None          # ✅ ALL REPS (no filtering)
VIDEO_FPS = 30.0

TARGET_FS = 100.0         # extractor output fs
TRIM_SEC = 0.0

# nonfinite policy: "skip" (추천) or "fix"
NONFINITE_POLICY = "skip"

# -----------------------
# session -> participant mapping
# -----------------------
WORKOUT_TO_PARTICIPANT = {
    "w00": 2, "w01": 0, "w02": 1, "w03": 0, "w04": 1,
    "w05": 2, "w06": 0, "w07": 1, "w08": 0, "w09": 1,
    "w10": 0, "w11": 1, "w12": 3, "w13": 4, "w14": 0,
    "w15": 1, "w16": 5, "w17": 6, "w18": 7, "w19": 8,
    "w20": 9
}

# -----------------------
# dirs
# -----------------------
os.makedirs(OUT_DIR, exist_ok=True)
NPZ_DIR = os.path.join(OUT_DIR, "npz")
os.makedirs(NPZ_DIR, exist_ok=True)
os.makedirs(CACHE_DIR, exist_ok=True)

OUT_META = os.path.join(OUT_DIR, f"meta_{DEVICE}_{'_'.join(sorted(TARGET_LABELS))}.csv")

# -----------------------
# helpers
# -----------------------
def is_gdrive_path(p: str) -> bool:
    return p.startswith("/content/drive")

def list_sessions(root: str):
    return sorted([d for d in glob.glob(os.path.join(root, "w*")) if os.path.isdir(d)])

def np_load_safe(path: str, retries: int = 3):
    last = None
    for _ in range(retries):
        try:
            return np.load(path, allow_pickle=False)
        except Exception as e:
            last = e
            time.sleep(0.4)
    raise last

def read_labels_csv(path: str) -> pd.DataFrame:
    df = pd.read_csv(path, header=None)
    df.columns = ["start_frame", "end_frame", "reps", "exercise"]
    df["exercise"] = df["exercise"].astype(str).str.strip()
    df["start_frame"] = pd.to_numeric(df["start_frame"], errors="coerce")
    df["end_frame"]   = pd.to_numeric(df["end_frame"],   errors="coerce")
    df["reps"]        = pd.to_numeric(df["reps"],        errors="coerce")
    df = df.dropna(subset=["start_frame","end_frame","reps"])
    df = df[df["end_frame"] > df["start_frame"]].copy()
    return df

def parse_imu(arr: np.ndarray):
    # expected (N,5) = [device_id, timestamp(ms), x, y, z]
    if arr.ndim != 2 or arr.shape[1] < 5:
        raise ValueError(f"Unexpected IMU shape: {arr.shape}")
    ts_ms = arr[:, 1].astype(np.float64)
    xyz = arr[:, -3:].astype(np.float64)
    return ts_ms, xyz

def infer_fs(ts_ms: np.ndarray):
    d = np.diff(ts_ms)
    d = d[(d > 0) & np.isfinite(d)]
    if len(d) == 0:
        return 0.0
    md = float(np.median(d))  # ms
    return 1000.0 / md

def safe_trim_by_seconds(s: int, e: int, trim_sec: float, fs: float, n: int):
    if trim_sec <= 0 or fs <= 0:
        return s, e
    trim_n = int(round(trim_sec * fs))
    s2 = min(max(0, s + trim_n), n - 1)
    e2 = max(min(n, e - trim_n), s2 + 1)
    return s2, e2

def resample_to_fs(t_sec: np.ndarray, x: np.ndarray, target_fs: float):
    if target_fs is None or target_fs <= 0:
        return t_sec.astype(np.float32), x.astype(np.float32)

    if len(t_sec) < 2:
        return t_sec.astype(np.float32), x.astype(np.float32)

    dur = float(t_sec[-1])
    if dur <= 0:
        return t_sec.astype(np.float32), x.astype(np.float32)

    dt = 1.0 / float(target_fs)
    t_new = np.arange(0.0, dur + 1e-9, dt)
    if len(t_new) < 2:
        return t_sec.astype(np.float32), x.astype(np.float32)

    f = interp1d(t_sec, x, axis=0, kind="linear",
                 fill_value="extrapolate", bounds_error=False)
    x_new = f(t_new).astype(np.float32)
    return t_new.astype(np.float32), x_new

def ensure_cached_session(sid: str):
    src_dir = os.path.join(DATA_ROOT, sid)
    if not os.path.isdir(src_dir):
        raise FileNotFoundError(f"Session dir not found: {src_dir}")

    if not is_gdrive_path(src_dir):
        return src_dir

    dst_dir = os.path.join(CACHE_DIR, sid)
    os.makedirs(dst_dir, exist_ok=True)

    need = [f"{sid}_labels.csv"] + [f"{sid}_{DEVICE}_{st}.npy" for st in STREAMS]
    for fn in need:
        src = os.path.join(src_dir, fn)
        dst = os.path.join(dst_dir, fn)
        if not os.path.exists(src):
            continue
        if not os.path.exists(dst):
            shutil.copy2(src, dst)
    return dst_dir

def is_finite(X: np.ndarray) -> bool:
    return np.isfinite(X).all()

def fix_nonfinite_linear(X: np.ndarray) -> np.ndarray:
    """
    X(T,C)에서 NaN/Inf가 있으면 time 축 기준으로 채널별 선형 보간.
    양 끝은 nearest처럼 채움.
    """
    X = np.asarray(X, dtype=np.float64)
    T, C = X.shape
    t = np.arange(T, dtype=np.float64)
    X2 = X.copy()

    for c in range(C):
        y = X2[:, c]
        m = np.isfinite(y)
        if m.all():
            continue
        if m.sum() < 2:
            # 유효 값 거의 없으면 0으로
            y[:] = 0.0
            X2[:, c] = y
            continue
        y_interp = np.interp(t, t[m], y[m])
        X2[:, c] = y_interp

    return X2.astype(np.float32)

# -----------------------
# main extraction
# -----------------------
sessions = [os.path.basename(d) for d in list_sessions(DATA_ROOT)]
print(f"Found {len(sessions)} sessions under: {DATA_ROOT}")

records = []
bad = []
skipped_trials = []

for sid in sessions:
    try:
        work_dir = ensure_cached_session(sid)

        label_path = os.path.join(work_dir, f"{sid}_labels.csv")
        acc_path   = os.path.join(work_dir, f"{sid}_{DEVICE}_acc.npy")
        gyr_path   = os.path.join(work_dir, f"{sid}_{DEVICE}_gyr.npy")

        if not os.path.exists(label_path):
            bad.append((sid, "no_labels"))
            continue
        if not os.path.exists(acc_path) or not os.path.exists(gyr_path):
            bad.append((sid, "missing_acc_or_gyr"))
            continue

        participant = WORKOUT_TO_PARTICIPANT.get(sid, None)
        if participant is None:
            bad.append((sid, "no_participant_mapping"))
            continue

        lab = read_labels_csv(label_path)
        lab = lab[lab["exercise"].isin(TARGET_LABELS)].copy()

        if ONLY_REPS is not None:
            lab = lab[lab["reps"].round().astype(int) == int(ONLY_REPS)].copy()

        if len(lab) == 0:
            continue

        ts_a, acc_all = parse_imu(np_load_safe(acc_path))
        ts_g, gyr_all = parse_imu(np_load_safe(gyr_path))

        t_rel_a = (ts_a - ts_a[0]) / 1000.0
        t_rel_g = (ts_g - ts_g[0]) / 1000.0

        fs_a = infer_fs(ts_a)
        fs_g = infer_fs(ts_g)
        fs_native = float(np.nanmean([fs_a, fs_g]))

        lab = lab.sort_values("start_frame").reset_index(drop=True)
        print(f"\n[{sid}] participant={participant} | sets={len(lab)} | fs_native~{fs_native:.2f}Hz")

        for set_id, r in lab.iterrows():
            sf = float(r["start_frame"])
            ef = float(r["end_frame"])
            reps = float(r["reps"])
            ex = str(r["exercise"]).strip()

            t0 = sf / VIDEO_FPS
            t1 = ef / VIDEO_FPS
            if t1 <= t0:
                continue

            s_a = int(np.searchsorted(t_rel_a, t0, side="left"))
            e_a = int(np.searchsorted(t_rel_a, t1, side="left"))
            s_g = int(np.searchsorted(t_rel_g, t0, side="left"))
            e_g = int(np.searchsorted(t_rel_g, t1, side="left"))

            n_a, n_g = len(t_rel_a), len(t_rel_g)
            s_a = max(0, min(n_a - 1, s_a)); e_a = max(s_a + 1, min(n_a, e_a))
            s_g = max(0, min(n_g - 1, s_g)); e_g = max(s_g + 1, min(n_g, e_g))

            s_a, e_a = safe_trim_by_seconds(s_a, e_a, TRIM_SEC, fs_a, n_a)
            s_g, e_g = safe_trim_by_seconds(s_g, e_g, TRIM_SEC, fs_g, n_g)

            acc_seg = acc_all[s_a:e_a].astype(np.float32)
            gyr_seg = gyr_all[s_g:e_g].astype(np.float32)

            t_seg = t_rel_a[s_a:e_a]
            t_seg = (t_seg - t_seg[0]).astype(np.float32)

            T = min(len(acc_seg), len(gyr_seg), len(t_seg))
            if T < 8:
                continue
            acc_seg = acc_seg[:T]
            gyr_seg = gyr_seg[:T]
            t_seg = t_seg[:T]

            X = np.concatenate([acc_seg, gyr_seg], axis=1).astype(np.float32)  # (T,6)

            # nonfinite guard (before resample)
            if not is_finite(X):
                if NONFINITE_POLICY == "fix":
                    X = fix_nonfinite_linear(X)
                else:
                    skipped_trials.append((sid, set_id, ex, "nonfinite_before_resample"))
                    continue
            if not is_finite(X):
                skipped_trials.append((sid, set_id, ex, "still_nonfinite_before_resample"))
                continue

            # resample to TARGET_FS
            t_rs, X_rs = resample_to_fs(t_seg, X, TARGET_FS)
            fs_out = float(TARGET_FS)

            # nonfinite guard (after resample)
            if not is_finite(X_rs):
                if NONFINITE_POLICY == "fix":
                    X_rs = fix_nonfinite_linear(X_rs)
                else:
                    skipped_trials.append((sid, set_id, ex, "nonfinite_after_resample"))
                    continue
            if not is_finite(X_rs):
                skipped_trials.append((sid, set_id, ex, "still_nonfinite_after_resample"))
                continue

            out_name = f"{sid}_set{set_id:03d}_{ex}_reps{int(round(reps))}.npz"
            out_path = os.path.join(NPZ_DIR, out_name)

            np.savez_compressed(
                out_path,
                X=X_rs,
                t=t_rs,
                session=sid,
                participant=int(participant),   # ✅ added
                set_id=int(set_id),
                exercise=ex,
                reps=float(reps),
                start_frame=float(sf),
                end_frame=float(ef),
                fs=float(fs_out),
                fs_native=float(fs_native),
                device=DEVICE
            )

            records.append({
                "session": sid,
                "participant": int(participant),  # ✅ added
                "set_id": int(set_id),
                "exercise": ex,
                "reps": float(reps),
                "start_frame": float(sf),
                "end_frame": float(ef),
                "T": int(len(t_rs)),
                "dur_sec": float(t_rs[-1]) if len(t_rs) > 1 else 0.0,
                "fs": float(fs_out),
                "fs_native": float(fs_native),
                "device": DEVICE,
                "npz_path": out_path,
            })

    except Exception as e:
        bad.append((sid, f"error:{repr(e)[:160]}"))
        continue

meta = pd.DataFrame(records)

print("\n--- DONE ---")
print("Bad sessions (first 20):", bad[:20])
print("Skipped trials (first 20):", skipped_trials[:20])
print("Total trials saved:", len(meta))

if len(meta) > 0:
    meta.to_csv(OUT_META, index=False)
    print("Meta CSV:", OUT_META)
    print("NPZ dir:", NPZ_DIR)
    print("\nCounts by exercise:")
    print(meta["exercise"].value_counts())
    print("\nParticipants count:")
    print(meta["participant"].value_counts().sort_index())
else:
    print("[WARN] No trials extracted. Check paths / device availability / mapping.")


Found 21 sessions under: /content/drive/MyDrive/Colab Notebooks/HAR_data/mm-fit

[w00] participant=2 | sets=9 | fs_native~100.00Hz


  slope = (y_hi - y_lo) / (x_hi - x_lo)[:, None]
  y_new = slope*(x_new - x_lo)[:, None] + y_lo



[w01] participant=0 | sets=10 | fs_native~100.00Hz

[w02] participant=1 | sets=9 | fs_native~100.00Hz

[w03] participant=0 | sets=9 | fs_native~100.00Hz

[w04] participant=1 | sets=8 | fs_native~100.00Hz

[w05] participant=2 | sets=9 | fs_native~100.00Hz

[w06] participant=0 | sets=9 | fs_native~100.00Hz

[w07] participant=1 | sets=9 | fs_native~100.00Hz

[w08] participant=0 | sets=9 | fs_native~100.00Hz

[w09] participant=1 | sets=9 | fs_native~100.00Hz

[w10] participant=0 | sets=9 | fs_native~100.00Hz

[w11] participant=1 | sets=9 | fs_native~100.00Hz


  slope = (y_hi - y_lo) / (x_hi - x_lo)[:, None]
  y_new = slope*(x_new - x_lo)[:, None] + y_lo



[w12] participant=3 | sets=9 | fs_native~100.00Hz

[w13] participant=4 | sets=9 | fs_native~100.00Hz

[w14] participant=0 | sets=9 | fs_native~100.00Hz

[w15] participant=1 | sets=9 | fs_native~100.00Hz

[w16] participant=5 | sets=10 | fs_native~100.00Hz

[w17] participant=6 | sets=10 | fs_native~100.00Hz

[w18] participant=7 | sets=9 | fs_native~105.56Hz

[w19] participant=8 | sets=9 | fs_native~105.56Hz

[w20] participant=9 | sets=9 | fs_native~100.00Hz

--- DONE ---
Bad sessions (first 20): []
Skipped trials (first 20): [('w00', 0, 'pushups', 'nonfinite_after_resample'), ('w11', 6, 'dumbbell_rows', 'nonfinite_after_resample')]
Total trials saved: 189
Meta CSV: /content/drive/MyDrive/Colab Notebooks/HAR_data/mmfit_imu_3ex_trials/meta_sw_r_dumbbell_rows_lunges_pushups.csv
NPZ dir: /content/drive/MyDrive/Colab Notebooks/HAR_data/mmfit_imu_3ex_trials/npz

Counts by exercise:
exercise
pushups          64
dumbbell_rows    63
lunges           62
Name: count, dtype: int64

Participants cou

In [1]:
# =========================
# Count-only K-auto (Multi-event) + Windowing version  (MM-Fit drop-in)
#
# ✅ This version (UPDATED for new data):
#   1) Activity-specific LOSO (single-activity subject shift)  ✅ subject = participant
#   2) Skip non-finite trials (NaN/Inf) at loading stage       ✅ skip only
#   3) Logs formatted per your template (per activity)
#
# ✅ Only changed parts:
#   - [THIS PATCH] Fold TEST Summary now aggregates across ALL trials of the test subject
#                 (mean±std for GT/Pred/Diff/k_hat/entropy + n_trials)
#                 ❗Everything else unchanged.
#
# (Model / Loss / Train / Windowing / Inference / Viz are UNCHANGED)
# =========================

import os
import random
import numpy as np
import pandas as pd
import seaborn as sns

import matplotlib.pyplot as plt
from scipy.ndimage import gaussian_filter1d

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader


# ---------------------------------------------------------------------
# 1) Strict Seeding
# ---------------------------------------------------------------------
def set_strict_seed(seed: int):
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


# ---------------------------------------------------------------------
# 2) Data Loading  (MM-Fit)  ✅ UPDATED: requires participant
# ---------------------------------------------------------------------
def load_mmfit_meta(meta_csv_path: str,
                    target_exercises: set,
                    only_reps=None,
                    require_device: str = "sw_r"):
    """
    meta.csv 로딩.
    ✅ NEW DATA expected columns:
      participant, session, set_id, exercise, reps, fs, npz_path
    """
    if not os.path.exists(meta_csv_path):
        raise FileNotFoundError(f"[MM-Fit] meta csv not found: {meta_csv_path}")

    meta = pd.read_csv(meta_csv_path)

    # ✅ participant required
    need_cols = ["participant", "session", "set_id", "exercise", "reps", "fs", "npz_path"]
    for c in need_cols:
        if c not in meta.columns:
            raise ValueError(f"[MM-Fit] meta csv missing column: {c} (have={list(meta.columns)})")

    meta["exercise"] = meta["exercise"].astype(str).str.strip()
    meta["session"] = meta["session"].astype(str).str.strip()

    # participant sanitize
    meta["participant"] = pd.to_numeric(meta["participant"], errors="coerce")
    meta = meta.dropna(subset=["participant"]).copy()
    meta["participant"] = meta["participant"].astype(int)

    # device 필터 (device 컬럼 있을 때만)
    if "device" in meta.columns and require_device is not None:
        meta = meta[meta["device"].astype(str).str.strip() == require_device].copy()

    # exercise 필터
    meta = meta[meta["exercise"].isin(set(target_exercises))].copy()

    # reps 필터 (원하면)
    if only_reps is not None:
        meta = meta[meta["reps"].round().astype(int) == int(only_reps)].copy()

    # npz_path 존재 확인
    meta["npz_path"] = meta["npz_path"].astype(str)
    ok = meta["npz_path"].apply(lambda p: os.path.exists(p))
    meta = meta[ok].copy()

    meta = meta.sort_values(["participant", "session", "exercise", "set_id"]).reset_index(drop=True)
    return meta


def _is_finite_np(x: np.ndarray) -> bool:
    return np.isfinite(x).all()


def prepare_trial_list_mmfit(meta_rows: pd.DataFrame,
                            expected_fs: float,
                            skip_nonfinite: bool = True,
                            verbose_skip: bool = True):
    """
    - 각 row = 1 trial
    - count label = reps
    - per-trial z-score
    - ✅ nonfinite(X에 NaN/Inf)면 스킵 (skip only)
    """
    trial_list = []
    skipped = []  # (reason, npz_path)

    for _, r in meta_rows.iterrows():
        npz_path = str(r["npz_path"])
        try:
            d = np.load(npz_path, allow_pickle=True)
        except Exception:
            skipped.append(("npz_load_fail", npz_path))
            continue

        if "X" not in d.files:
            skipped.append(("missing_X", npz_path))
            continue

        X = d["X"].astype(np.float32)  # (T,C)
        reps = float(r["reps"])
        sid = str(r["session"])
        pid = int(r["participant"])
        set_id = int(r["set_id"])
        ex = str(r["exercise"])

        # fs check
        try:
            fs_npz = float(d["fs"]) if "fs" in d.files else float(r["fs"])
        except Exception:
            fs_npz = float(r["fs"])

        if abs(fs_npz - float(expected_fs)) > 1e-3:
            raise ValueError(
                f"[MM-Fit] fs mismatch for {npz_path}: fs_npz={fs_npz}, expected_fs={expected_fs}. "
                f"-> extractor 단계에서 TARGET_FS를 expected_fs로 통일하세요."
            )

        # ✅ skip nonfinite raw X
        if skip_nonfinite and (not _is_finite_np(X)):
            skipped.append(("nonfinite_X", npz_path))
            continue

        # per-trial z-score
        mean = X.mean(axis=0)
        std = X.std(axis=0)
        std = np.where(std < 1e-6, 1e-6, std).astype(np.float32)
        norm_np = (X - mean) / std

        # ✅ (안전) 정규화 후에도 nonfinite면 스킵
        if skip_nonfinite and (not _is_finite_np(norm_np)):
            skipped.append(("nonfinite_after_norm", npz_path))
            continue

        trial_list.append({
            "data": norm_np,              # (T,C)
            "count": float(reps),         # trial total count
            "meta": f"subj{pid}_{sid}_set{set_id:02d}_{ex}",
            "participant": pid,
            "session": sid,
            "exercise": ex,
            "npz_path": npz_path,
        })

    if verbose_skip and len(skipped) > 0:
        print(f"[MM-Fit] Skipped {len(skipped)} trials (nonfinite/load issues). Examples:")
        for i in range(min(5, len(skipped))):
            print("  -", skipped[i][0], ":", skipped[i][1])

    return trial_list, skipped


# ---------------------------------------------------------------------
# 2.5) ✅ Windowing (UNCHANGED)
# ---------------------------------------------------------------------
def trial_list_to_windows(trial_list, fs, win_sec=8.0, stride_sec=4.0, drop_last=True):
    win_len = int(round(win_sec * fs))
    stride = int(round(stride_sec * fs))
    assert win_len > 0 and stride > 0

    windows = []
    for item in trial_list:
        x = item["data"]  # (T,C)
        T = x.shape[0]
        total_count = float(item["count"])
        meta = item["meta"]

        total_dur = max(T / float(fs), 1e-6)
        rate_trial = total_count / total_dur  # reps/s

        if T < win_len:
            win_dur = T / float(fs)
            windows.append({
                "data": x,
                "count": rate_trial * win_dur,
                "meta": f"{meta}__win[0:{T}]",
                "parent_meta": meta,
                "parent_T": T,
                "win_start": 0,
                "win_end": T,
            })
            continue

        last_start = T - win_len
        starts = list(range(0, last_start + 1, stride))

        for st in starts:
            ed = st + win_len
            win_dur = win_len / float(fs)
            windows.append({
                "data": x[st:ed],
                "count": rate_trial * win_dur,
                "meta": f"{meta}__win[{st}:{ed}]",
                "parent_meta": meta,
                "parent_T": T,
                "win_start": st,
                "win_end": ed,
            })

        if not drop_last:
            last_st = starts[-1] + stride
            if last_st < T:
                ed = T
                win_dur = (ed - last_st) / float(fs)
                windows.append({
                    "data": x[last_st:ed],
                    "count": rate_trial * win_dur,
                    "meta": f"{meta}__win[{last_st}:{ed}]",
                    "parent_meta": meta,
                    "parent_T": T,
                    "win_start": last_st,
                    "win_end": ed,
                })

    return windows


def predict_count_by_windowing(model, x_np, fs, win_sec, stride_sec, device, tau=1.0, batch_size=64):
    win_len = int(round(win_sec * fs))
    stride = int(round(stride_sec * fs))
    T = x_np.shape[0]
    total_dur = T / float(fs)

    if T <= win_len:
        x_tensor = torch.tensor(x_np, dtype=torch.float32).transpose(0, 1).unsqueeze(0).to(device)  # (1,C,T)
        with torch.no_grad():
            rate_hat, _, _, _ = model(x_tensor, mask=None, tau=tau)
        pred_count = float(rate_hat.item() * total_dur)
        return pred_count, np.array([float(rate_hat.item())], dtype=np.float32)

    starts = list(range(0, T - win_len + 1, stride))
    windows = np.stack([x_np[st:st + win_len] for st in starts], axis=0)  # (N, win_len, C)

    xw = torch.tensor(windows, dtype=torch.float32).permute(0, 2, 1).to(device)  # (N, C, win_len)

    rates = []
    model.eval()
    with torch.no_grad():
        for i in range(0, xw.shape[0], batch_size):
            xb = xw[i:i + batch_size]
            r_hat, _, _, _ = model(xb, mask=None, tau=tau)  # (B,)
            rates.append(r_hat.detach().cpu().numpy())

    rates = np.concatenate(rates, axis=0)  # (N,)
    rate_mean = float(rates.mean())
    pred_count = rate_mean * total_dur
    return float(pred_count), rates


# ---------------------------------------------------------------------
# 2.8) Dataset / Collate (UNCHANGED)
# ---------------------------------------------------------------------
class TrialDataset(Dataset):
    def __init__(self, trial_list):
        self.trials = trial_list

    def __len__(self):
        return len(self.trials)

    def __getitem__(self, idx):
        item = self.trials[idx]
        data = torch.tensor(item['data'], dtype=torch.float32).transpose(0, 1)  # (C, T)
        count = torch.tensor(item['count'], dtype=torch.float32)
        return data, count, item['meta']


def collate_variable_length(batch):
    max_len = max([x[0].shape[1] for x in batch])
    C = batch[0][0].shape[0]

    padded_data, masks, counts, metas, lengths = [], [], [], [], []
    for data, count, meta in batch:
        T = data.shape[1]
        lengths.append(T)

        pad_size = max_len - T
        if pad_size > 0:
            pad = torch.zeros(C, pad_size)
            d_padded = torch.cat([data, pad], dim=1)
            mask = torch.cat([torch.ones(T), torch.zeros(pad_size)], dim=0)
        else:
            d_padded = data
            mask = torch.ones(T)

        padded_data.append(d_padded)
        masks.append(mask)
        counts.append(count)
        metas.append(meta)

    return {
        "data": torch.stack(padded_data),         # (B, C, T_max)
        "mask": torch.stack(masks),               # (B, T_max)
        "count": torch.stack(counts),             # (B,)
        "length": torch.tensor(lengths, dtype=torch.float32),  # (B,)
        "meta": metas
    }


# ---------------------------------------------------------------------
# 3) Model (UNCHANGED)
# ---------------------------------------------------------------------
class ManifoldEncoder(nn.Module):
    def __init__(self, input_ch, hidden_dim=128, latent_dim=16):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv1d(input_ch, hidden_dim, 5, padding=2),
            nn.ReLU(),
            nn.Conv1d(hidden_dim, hidden_dim, 5, padding=2),
            nn.ReLU(),
            nn.Conv1d(hidden_dim, latent_dim, 1)
        )

    def forward(self, x):
        z = self.net(x)            # (B, D, T)
        z = z.transpose(1, 2)      # (B, T, D)
        return z


class ManifoldDecoder(nn.Module):
    def __init__(self, latent_dim, hidden_dim, out_ch):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv1d(latent_dim, hidden_dim, 5, padding=2),
            nn.ReLU(),
            nn.Conv1d(hidden_dim, hidden_dim, 5, padding=2),
            nn.ReLU(),
            nn.Conv1d(hidden_dim, out_ch, 1)
        )

    def forward(self, z):
        zt = z.transpose(1, 2)     # (B, D, T)
        x_hat = self.net(zt)       # (B, C, T)
        return x_hat


class MultiRateHead(nn.Module):
    def __init__(self, latent_dim=16, hidden=64, K_max=6):
        super().__init__()
        self.K_max = K_max
        self.net = nn.Sequential(
            nn.Linear(latent_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 1 + K_max)  # [amp_logit | phase_logits...]
        )

    def forward(self, z, tau=1.0):
        out = self.net(z)                     # (B,T,1+K)
        amp = F.softplus(out[..., 0])         # (B,T) >=0
        phase_logits = out[..., 1:]           # (B,T,K)
        phase = F.softmax(phase_logits / tau, dim=-1)  # (B,T,K), sum=1
        return amp, phase, phase_logits


class KAutoCountModel(nn.Module):
    def __init__(self, input_ch, hidden_dim=128, latent_dim=16, K_max=6, k_hidden=64):
        super().__init__()
        self.encoder = ManifoldEncoder(input_ch, hidden_dim, latent_dim)
        self.decoder = ManifoldDecoder(latent_dim, hidden_dim, input_ch)
        self.rate_head = MultiRateHead(latent_dim, hidden=hidden_dim, K_max=K_max)
        self._init_weights()

    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, (nn.Conv1d, nn.Linear)):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

        with torch.no_grad():
            b = self.rate_head.net[-1].bias
            b.zero_()
            b[0].fill_(-2.0)  # amp logit bias만 -2

    @staticmethod
    def _masked_mean_time(x, mask=None, eps=1e-6):
        if mask is None:
            return x.mean(dim=1)
        if x.dim() == 2:
            m = mask.to(dtype=x.dtype, device=x.device)
            return (x * m).sum(dim=1) / (m.sum(dim=1) + eps)
        elif x.dim() == 3:
            m = mask.to(dtype=x.dtype, device=x.device).unsqueeze(-1)
            return (x * m).sum(dim=1) / (m.sum(dim=1) + eps)
        else:
            raise ValueError(f"Unsupported dim for masked mean: {x.dim()}")

    def forward(self, x, mask=None, tau=1.0):
        z = self.encoder(x)              # (B,T,D)
        x_hat = self.decoder(z)          # (B,C,T)

        amp_t, phase_p, phase_logits = self.rate_head(z, tau=tau)
        rates_k_t = amp_t.unsqueeze(-1) * phase_p  # (B,T,K)

        micro_rate_t = amp_t  # (B,T)

        p_bar = self._masked_mean_time(phase_p, mask)           # (B,K)
        k_hat = 1.0 / (p_bar.pow(2).sum(dim=1) + 1e-6)          # (B,) in [1,K]

        rep_rate_t = micro_rate_t / (k_hat.unsqueeze(1) + 1e-6) # (B,T)
        if mask is not None:
            rep_rate_t = rep_rate_t * mask

        if mask is None:
            avg_rep_rate = rep_rate_t.mean(dim=1)
        else:
            avg_rep_rate = (rep_rate_t * mask).sum(dim=1) / (mask.sum(dim=1) + 1e-6)

        aux = {
            "rates_k_t": rates_k_t,
            "phase_p": phase_p,
            "phase_logits": phase_logits,
            "micro_rate_t": micro_rate_t,
            "rep_rate_t": rep_rate_t,
            "k_hat": k_hat,
        }
        return avg_rep_rate, z, x_hat, aux


# ---------------------------------------------------------------------
# 4) Loss utils (UNCHANGED)
# ---------------------------------------------------------------------
def masked_recon_mse(x_hat, x, mask, eps=1e-6):
    mask = mask.to(dtype=x.dtype, device=x.device)
    mask_bc = mask.unsqueeze(1)              # (B,1,T)
    se = (x_hat - x) ** 2                    # (B,C,T)
    se = se * mask_bc
    denom = (mask.sum() * x.shape[1]) + eps  # valid(B*T)*C
    return se.sum() / denom


def temporal_smoothness(v, mask=None, eps=1e-6):
    dv = torch.abs(v[:, 1:] - v[:, :-1])  # (B,T-1)
    if mask is None:
        return dv.mean()
    m = mask[:, 1:] * mask[:, :-1]
    m = m.to(dtype=dv.dtype, device=dv.device)
    return (dv * m).sum() / (m.sum() + eps)


def phase_entropy_loss(phase_p, mask=None, eps=1e-8):
    ent = -(phase_p * (phase_p + eps).log()).sum(dim=-1)  # (B,T)
    if mask is None:
        return ent.mean()
    ent = ent * mask
    return ent.sum() / (mask.sum() + eps)


def effK_usage_loss(phase_p, mask=None, eps=1e-6):
    if mask is None:
        p_bar = phase_p.mean(dim=1)  # (B,K)
    else:
        m = mask.to(dtype=phase_p.dtype, device=phase_p.device).unsqueeze(-1)  # (B,T,1)
        p_bar = (phase_p * m).sum(dim=1) / (m.sum(dim=1) + eps)

    effK = 1.0 / (p_bar.pow(2).sum(dim=1) + eps)
    return effK.mean(), effK.detach()


# ---------------------------------------------------------------------
# 5) Train (UNCHANGED)
# ---------------------------------------------------------------------
def train_one_epoch(model, loader, optimizer, config, device):
    model.train()
    stats = {k: 0.0 for k in [
        'loss', 'loss_rate', 'loss_recon', 'loss_smooth', 'loss_phase_ent', 'loss_effk',
        'mae_count'
    ]}

    fs = config["fs"]
    tau = config.get("tau", 1.0)

    lam_recon = config.get("lambda_recon", 1.0)
    lam_smooth = config.get("lambda_smooth", 0.05)
    lam_phase_ent = config.get("lambda_phase_ent", 0.01)
    lam_effk = config.get("lambda_effk", 0.005)

    for batch in loader:
        x = batch["data"].to(device)         # (B,C,T)
        mask = batch["mask"].to(device)      # (B,T)
        y_count = batch["count"].to(device)  # (B,)
        length = batch["length"].to(device)  # (B,)

        duration = torch.clamp(length / fs, min=1e-6)  # sec
        y_rate = y_count / duration                    # reps/s

        optimizer.zero_grad()

        rate_hat, z, x_hat, aux = model(x, mask, tau=tau)

        loss_rate = F.mse_loss(rate_hat, y_rate)
        loss_recon = masked_recon_mse(x_hat, x, mask)
        loss_smooth = temporal_smoothness(aux["rep_rate_t"], mask)
        loss_phase_ent = phase_entropy_loss(aux["phase_p"], mask)
        loss_effk, _ = effK_usage_loss(aux["phase_p"], mask)

        loss = (loss_rate
                + lam_recon * loss_recon
                + lam_smooth * loss_smooth
                + lam_phase_ent * loss_phase_ent
                + lam_effk * loss_effk)

        loss.backward()
        optimizer.step()

        count_hat = rate_hat * duration
        stats['loss'] += loss.item()
        stats['loss_rate'] += loss_rate.item()
        stats['loss_recon'] += loss_recon.item()
        stats['loss_smooth'] += loss_smooth.item()
        stats['loss_phase_ent'] += loss_phase_ent.item()
        stats['loss_effk'] += loss_effk.item()
        stats['mae_count'] += torch.abs(count_hat - y_count).mean().item()

    n = max(1, len(loader))
    return {k: v / n for k, v in stats.items()}


# ---------------------------------------------------------------------
# 6) Visualization helpers (UNCHANGED)
# ---------------------------------------------------------------------
def compute_phase_entropy_mean(phase_p_np, eps=1e-8):
    phase_p_np = np.asarray(phase_p_np, dtype=np.float32)
    ent_t = -(phase_p_np * np.log(phase_p_np + eps)).sum(axis=1)  # (T,)
    return float(ent_t.mean())


# ---------------------------------------------------------------------
# 7) LOSO runner (per activity)  ✅ UPDATED: subject = participant
# ---------------------------------------------------------------------
def run_loso_for_one_activity(meta_all: pd.DataFrame, activity: str, CONFIG: dict, device):
    meta_act = meta_all[meta_all["exercise"] == activity].copy()
    if len(meta_act) == 0:
        print(f"[WARN] No trials for activity={activity}")
        return None

    # ✅ subject shift 기준: participant
    subjects = sorted(meta_act["participant"].unique().tolist())

    print("\n" + "-"*80)
    print(f" >>> Starting LOSO (count-only, K-auto) + WINDOWING")
    print("-"*80)

    loso_results = []
    total_skipped = 0

    for fold_idx, test_subj in enumerate(subjects):
        set_strict_seed(CONFIG["seed"])

        train_meta = meta_act[meta_act["participant"] != test_subj].copy()
        test_meta  = meta_act[meta_act["participant"] == test_subj].copy()

        train_trials, skipped_train = prepare_trial_list_mmfit(
            train_meta, expected_fs=CONFIG["fs"],
            skip_nonfinite=True, verbose_skip=False
        )
        test_trials, skipped_test = prepare_trial_list_mmfit(
            test_meta, expected_fs=CONFIG["fs"],
            skip_nonfinite=True, verbose_skip=False
        )
        total_skipped += (len(skipped_train) + len(skipped_test))

        if len(train_trials) == 0 or len(test_trials) == 0:
            print(f"Fold {fold_idx+1:2d} | Test: subject{test_subj} | [SKIP] no valid trials after filtering (nonfinite/empty).")
            continue

        train_data = trial_list_to_windows(
            train_trials,
            fs=CONFIG["fs"],
            win_sec=CONFIG["win_sec"],
            stride_sec=CONFIG["stride_sec"],
            drop_last=CONFIG["drop_last"]
        )

        if len(train_data) == 0:
            print(f"Fold {fold_idx+1:2d} | Test: subject{test_subj} | [SKIP] no windows generated.")
            continue

        train_loader = DataLoader(
            TrialDataset(train_data),
            batch_size=CONFIG["batch_size"],
            shuffle=True,
            collate_fn=collate_variable_length,
            num_workers=0
        )

        input_ch = train_data[0]["data"].shape[1]
        model = KAutoCountModel(
            input_ch=input_ch,
            hidden_dim=CONFIG["hidden_dim"],
            latent_dim=CONFIG["latent_dim"],
            K_max=CONFIG["K_max"]
        ).to(device)

        optimizer = torch.optim.Adam(model.parameters(), lr=CONFIG["lr"])
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.5)

        for _ in range(CONFIG["epochs"]):
            _ = train_one_epoch(model, train_loader, optimizer, CONFIG, device)
            scheduler.step()

        model.eval()

        fold_abs_err_sum = 0.0
        fold_res_str = ""

        # =========================
        # ✅ PATCH: aggregate per-trial stats for Fold TEST Summary
        # =========================
        gt_list, pred_list, diff_list = [], [], []
        khat_list, ent_list = [], []

        for item in test_trials:
            x_np = item["data"]

            count_pred_win, _ = predict_count_by_windowing(
                model,
                x_np=x_np,
                fs=CONFIG["fs"],
                win_sec=CONFIG["win_sec"],
                stride_sec=CONFIG["stride_sec"],
                device=device,
                tau=CONFIG.get("tau", 1.0),
                batch_size=CONFIG.get("batch_size", 64)
            )

            count_gt = float(item["count"])
            abs_err = abs(count_pred_win - count_gt)
            fold_abs_err_sum += abs_err
            fold_res_str += f"[Pred(win): {count_pred_win:.1f} / GT: {count_gt:.0f}]"

            x_tensor = torch.tensor(x_np, dtype=torch.float32).transpose(0, 1).unsqueeze(0).to(device)  # (1,C,T)
            with torch.no_grad():
                _, _, _, aux = model(x_tensor, mask=None, tau=CONFIG.get("tau", 1.0))

            phase_p = aux["phase_p"].squeeze(0).detach().cpu().numpy()  # (T,K)
            k_hat = float(aux["k_hat"].item())
            ent = compute_phase_entropy_mean(phase_p)

            gt_list.append(count_gt)
            pred_list.append(float(count_pred_win))
            diff_list.append(float(count_pred_win - count_gt))
            khat_list.append(float(k_hat))
            ent_list.append(float(ent))

        fold_mae = fold_abs_err_sum / max(1, len(test_trials))
        loso_results.append(fold_mae)

        # ✅ 요청한 로그 포맷 (subject{participant_id})
        print(f"Fold {fold_idx+1:2d} | Test: subject{test_subj} | MAE: {fold_mae:.2f} | {fold_res_str}")

        if len(gt_list) > 0:
            gt_m, gt_s     = float(np.mean(gt_list)),   float(np.std(gt_list))
            pred_m, pred_s = float(np.mean(pred_list)), float(np.std(pred_list))
            diff_m, diff_s = float(np.mean(diff_list)), float(np.std(diff_list))
            k_m, k_s       = float(np.mean(khat_list)), float(np.std(khat_list))
            e_m, e_s       = float(np.mean(ent_list)),  float(np.std(ent_list))

            print(
                f"[Fold TEST Summary] subject{test_subj} | "
                f"GT={gt_m:.2f}±{gt_s:.2f} | Pred(win)={pred_m:.2f}±{pred_s:.2f} | "
                f"Diff={diff_m:+.2f}±{diff_s:.2f} | k_hat(full)={k_m:.2f}±{k_s:.2f} | "
                f"phase_entropy(full)={e_m:.3f}±{e_s:.3f} | n_trials={len(gt_list)}"
            )

    if len(loso_results) == 0:
        print("-"*80)
        print(" >>> No valid folds were evaluated.")
        print("-"*80)
        return None

    print("-"*80)
    print(f" >>> Final LOSO Result (Average MAE): {np.mean(loso_results):.3f}")
    print(f" >>> Standard Deviation: {np.std(loso_results):.3f}")
    print("-"*80)

    return {
        "activity": activity,
        "mean_mae": float(np.mean(loso_results)),
        "std_mae": float(np.std(loso_results)),
        "n_folds": int(len(loso_results)),
        "total_skipped": int(total_skipped),
    }


# ---------------------------------------------------------------------
# 8) Main
# ---------------------------------------------------------------------
def main():
    CONFIG = {
        "seed": 42,

        # MM-Fit meta.csv 경로 (✅ NEW DATA: must include participant)
        "mmfit_meta_csv": "/content/drive/MyDrive/Colab Notebooks/HAR_data/mmfit_imu_3ex_trials/meta_sw_r_dumbbell_rows_lunges_pushups.csv",

        "MMFIT_TARGET_EXERCISES": {"pushups", "lunges", "dumbbell_rows"},
        "MMFIT_ONLY_REPS": None,
        "MMFIT_REQUIRE_DEVICE": "sw_r",

        "epochs": 100,
        "lr": 5e-4,
        "batch_size": 64,

        "fs": 100,

        "win_sec": 8.0,
        "stride_sec": 4.0,
        "drop_last": True,

        "hidden_dim": 128,
        "latent_dim": 16,
        "K_max": 6,

        "lambda_recon": 1.0,
        "lambda_smooth": 0.05,
        "lambda_phase_ent": 0.01,
        "lambda_effk": 0.0075,

        "tau": 1.0,
    }

    set_strict_seed(CONFIG["seed"])
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")

    meta_all = load_mmfit_meta(
        meta_csv_path=CONFIG["mmfit_meta_csv"],
        target_exercises=CONFIG["MMFIT_TARGET_EXERCISES"],
        only_reps=CONFIG["MMFIT_ONLY_REPS"],
        require_device=CONFIG["MMFIT_REQUIRE_DEVICE"],
    )
    if len(meta_all) == 0:
        print("[MM-Fit] meta is empty after filtering. Check meta_csv / filters.")
        return

    # ✅ 활동별 단일-activity LOSO 실행 (subject=participant)
    results = []
    for act in ["pushups", "lunges", "dumbbell_rows"]:
        out = run_loso_for_one_activity(meta_all, act, CONFIG, device)
        if out is not None:
            results.append(out)

    if len(results) > 0:
        print("\n" + "="*80)
        print("Activity Summary (single-activity LOSO, subject=participant)")
        print("="*80)
        for r in results:
            print(
                f"- {r['activity']}: meanMAE={r['mean_mae']:.3f}, std={r['std_mae']:.3f}, "
                f"folds={r['n_folds']}, skipped={r['total_skipped']}"
            )
        print("="*80)


if __name__ == "__main__":
    main()


Device: cuda

--------------------------------------------------------------------------------
 >>> Starting LOSO (count-only, K-auto) + WINDOWING
--------------------------------------------------------------------------------
Fold  1 | Test: subject0 | MAE: 1.38 | [Pred(win): 7.2 / GT: 10][Pred(win): 10.0 / GT: 10][Pred(win): 9.9 / GT: 10][Pred(win): 9.6 / GT: 10][Pred(win): 8.6 / GT: 10][Pred(win): 7.6 / GT: 10][Pred(win): 9.8 / GT: 10][Pred(win): 11.9 / GT: 10][Pred(win): 10.8 / GT: 10][Pred(win): 7.5 / GT: 10][Pred(win): 8.1 / GT: 10][Pred(win): 7.8 / GT: 10][Pred(win): 11.4 / GT: 10][Pred(win): 9.2 / GT: 10][Pred(win): 8.9 / GT: 10][Pred(win): 9.0 / GT: 10][Pred(win): 7.6 / GT: 10][Pred(win): 8.4 / GT: 10]
[Fold TEST Summary] subject0 | GT=10.00±0.00 | Pred(win)=9.07±1.34 | Diff=-0.93±1.34 | k_hat(full)=1.15±0.07 | phase_entropy(full)=0.197±0.064 | n_trials=18
Fold  2 | Test: subject1 | MAE: 1.71 | [Pred(win): 8.5 / GT: 10][Pred(win): 8.7 / GT: 10][Pred(win): 8.8 / GT: 10][Pred(w