# 1) Main plot
- `Knees bending-Frontal elevation of arms`
- `Waist bends forward-Jump front & back`
- `Jogging-Running`

In [2]:
# =========================
# ✅ UPDATED (comment-based): only 2 plots per scenario
#   1) Window-level rep_rate mean±std (+ boundary + A/B shading)
#   2) Latent trajectory PCA2 (scatter + mean trajectory line + boundary marker + start/end + std circles)
#   - NO titles
#   - Bigger fonts
# =========================

import os
import glob
import random
import numpy as np
import pandas as pd
import seaborn as sns  # kept
import matplotlib.pyplot as plt
from scipy.ndimage import gaussian_filter1d  # kept
from matplotlib.patches import Circle

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from matplotlib.collections import LineCollection


# ---------------------------------------------------------------------
# ✅ Save helpers
# ---------------------------------------------------------------------
def _ensure_dir(path: str):
    """Create directory ``path`` (and any missing parents).

    Nothing happens when ``path`` is ``None`` or an empty string, and an
    already-existing directory is left as it is.
    """
    if path is not None and path != "":
        os.makedirs(path, exist_ok=True)

def _safe_fname(s: str):
    """Turn an arbitrary value into a string usable as a file-name fragment.

    Path separators, characters commonly rejected by file systems, newlines
    and tabs become ``_``; surrounding whitespace is stripped; inner spaces
    become ``_``; and runs of underscores are collapsed to a single one.
    """
    to_underscore = str.maketrans({ch: '_' for ch in '/\\:*?"<>|\n\t'})
    cleaned = str(s).translate(to_underscore).strip().replace(' ', '_')
    # Repeat until stable: one pass can still leave '__' behind (e.g. '___').
    while '__' in cleaned:
        cleaned = cleaned.replace('__', '_')
    return cleaned


# ---------------------------------------------------------------------
# 1) Strict Seeding
# ---------------------------------------------------------------------
def set_strict_seed(seed: int):
    """Seed every random source used by this file and pin cuDNN settings.

    Seeds Python's ``random``, NumPy and PyTorch (CPU, plus all CUDA devices
    when CUDA is available), exports ``PYTHONHASHSEED``, and sets the cuDNN
    flags to deterministic / non-benchmark.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)

    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


# ---------------------------------------------------------------------
# 2) Data Loading
# ---------------------------------------------------------------------
def load_mhealth_dataset(data_dir, target_activities_map, column_names):
    """Load every ``mHealth_subject*.log`` file in ``data_dir``, split by activity.

    Each log is read as a headerless tab-separated table, truncated to the
    first ``len(column_names)`` columns and labelled with ``column_names``
    (which must contain ``'activity_id'``).

    Args:
        data_dir: Directory that holds the ``mHealth_subject*.log`` files.
        target_activities_map: Mapping ``activity code -> activity name``;
            only these codes are extracted.
        column_names: Column labels to assign to the loaded table.

    Returns:
        ``{subject_key: {activity_name: DataFrame}}`` where each DataFrame
        holds the rows of that activity without the ``'activity_id'`` column.
        Activities with no rows are omitted. ``subject_key`` is
        ``"subject<N>"`` when digits can be parsed from the file name,
        otherwise the file stem. An empty dict is returned when no log file
        is found. Files that fail to load are reported and skipped.
    """
    file_list = sorted(glob.glob(os.path.join(data_dir, "mHealth_subject*.log")))

    if not file_list:
        print(f"[Warning] No mHealth logs found in {data_dir}")
        return {}

    print(f"Loading {len(file_list)} subjects from {data_dir}...")

    full_dataset = {}
    for file_path in file_list:
        file_name = os.path.basename(file_path)
        subj_part = file_name.split('.')[0]
        try:
            subj_id_num = int(''.join(filter(str.isdigit, subj_part)))
            subj_key = f"subject{subj_id_num}"
        except ValueError:
            # int('') fails when the stem carries no digits; fall back to the stem.
            subj_key = subj_part

        try:
            df = pd.read_csv(file_path, sep="\t", header=None)
            df = df.iloc[:, :len(column_names)]
            df.columns = column_names

            subj_data = {}
            for label_code, activity_name in target_activities_map.items():
                activity_df = df[df['activity_id'] == label_code]
                if not activity_df.empty:
                    subj_data[activity_name] = activity_df.drop(columns=['activity_id'])
        except Exception as e:
            # Best effort: one unreadable log must not abort the whole load.
            print(f"Error loading {file_name}: {e}")
            continue

        full_dataset[subj_key] = subj_data

    return full_dataset


def prepare_trial_list(label_config, full_data, target_map, feature_map):
    """Build z-scored trial records from ``(subject, activity id, count)`` labels.

    Args:
        label_config: Iterable of ``(subj, act_id, gt_count)`` triples.
        full_data: ``{subject: {activity_name: DataFrame}}``.
        target_map: Mapping ``act_id -> activity_name``.
        feature_map: Mapping ``act_id -> list of column names`` to select.

    Returns:
        A list of dicts with keys ``'data'`` (float32 array, normalised per
        column over the whole trial), ``'count'`` (float) and ``'meta'``
        (``"<subj>_<activity>"``). Labels without matching data are reported
        and skipped.
    """
    trials = []
    for subj, act_id, gt_count in label_config:
        act_name = target_map.get(act_id)
        feats = feature_map.get(act_id)

        has_data = subj in full_data and act_name in full_data[subj]
        if not has_data:
            print(f"[Skip] Missing data for {subj} - {act_name}")
            continue

        signal = full_data[subj][act_name][feats].values.astype(np.float32)

        # Per-column z-score; the small constant keeps flat channels finite.
        mu = signal.mean(axis=0)
        sigma = signal.std(axis=0) + 1e-6

        trials.append({
            'data': (signal - mu) / sigma,
            'count': float(gt_count),
            'meta': f"{subj}_{act_name}",
        })

    return trials


# ---------------------------------------------------------------------
# ✅ Mixed A-B builder
# ---------------------------------------------------------------------
def get_single_trial_from_full_data(subj, act_id, gt_count, full_data, target_map, feature_map, normalize=True):
    """Fetch one subject/activity recording as a trial record.

    Args:
        subj: Subject key in ``full_data``.
        act_id: Activity code, resolved through ``target_map`` / ``feature_map``.
        gt_count: Ground-truth repetition count stored in the record.
        full_data: ``{subject: {activity_name: DataFrame}}``.
        target_map: Mapping ``act_id -> activity_name``.
        feature_map: Mapping ``act_id -> list of column names`` to select.
        normalize: When true, z-score each column over the whole trial.

    Returns:
        ``{"data": float32 array, "count": float, "meta": "<subj>_<activity>"}``
        or ``None`` when the subject or activity is not present.
    """
    act_name = target_map.get(act_id)
    feats = feature_map.get(act_id)

    available = subj in full_data and act_name in full_data[subj]
    if not available:
        return None

    signal = full_data[subj][act_name][feats].values.astype(np.float32)

    if normalize:
        mu = signal.mean(axis=0)
        sigma = signal.std(axis=0) + 1e-6  # avoid division by zero on flat channels
        signal = (signal - mu) / sigma

    return {
        "data": signal,
        "count": float(gt_count),
        "meta": f"{subj}_{act_name}",
    }


def build_mixed_ab_trial(subj, actA_id, actB_id, config, full_data):
    """Concatenate activity A then activity B of one subject into a mixed trial.

    The two raw recordings are stacked along time and z-scored together,
    using the per-column mean/std of the concatenated signal, so both halves
    share a single normalisation.

    Args:
        subj: Subject key.
        actA_id: Activity code of the first segment.
        actB_id: Activity code of the second segment.
        config: Dict providing ``"TARGET_ACTIVITIES_MAP"``, ``"ACT_FEATURE_MAP"``
            and, optionally, ``"GT_BY_ACT"`` (``{act_id: {subject: count}}``).
        full_data: ``{subject: {activity_name: DataFrame}}``.

    Returns:
        A dict holding the mixed signal (``"data"``), the summed ground-truth
        count, the A/B boundary sample index, the separately sliced normalised
        segments and descriptive metadata; or ``None`` when a ground-truth
        entry or a recording is missing.
    """
    gt_map = config.get("GT_BY_ACT", {})
    act_ids = (actA_id, actB_id)

    if any(act not in gt_map for act in act_ids):
        return None
    if any(subj not in gt_map[act] for act in act_ids):
        return None

    gtA, gtB = (float(gt_map[act][subj]) for act in act_ids)
    gt_total = gtA + gtB

    # Fetch both segments un-normalised; normalisation is done jointly below.
    trial_A = get_single_trial_from_full_data(
        subj, actA_id, gtA, full_data,
        config["TARGET_ACTIVITIES_MAP"], config["ACT_FEATURE_MAP"],
        normalize=False
    )
    trial_B = get_single_trial_from_full_data(
        subj, actB_id, gtB, full_data,
        config["TARGET_ACTIVITIES_MAP"], config["ACT_FEATURE_MAP"],
        normalize=False
    )
    if trial_A is None or trial_B is None:
        return None

    raw_A, raw_B = trial_A["data"], trial_B["data"]
    boundary = int(raw_A.shape[0])

    raw_mix = np.concatenate([raw_A, raw_B], axis=0).astype(np.float32)

    mu = raw_mix.mean(axis=0)
    sigma = raw_mix.std(axis=0) + 1e-6

    x_mix = (raw_mix - mu) / sigma
    xA = (raw_A - mu) / sigma
    xB = (raw_B - mu) / sigma

    name_map = config["TARGET_ACTIVITIES_MAP"]
    actA_name = name_map.get(actA_id, str(actA_id))
    actB_name = name_map.get(actB_id, str(actB_id))

    meta_detail = {
        "subj": subj,
        "actA_id": actA_id, "actB_id": actB_id,
        "actA_name": actA_name, "actB_name": actB_name,
        "gtA": gtA, "gtB": gtB,
        "gt_total": gt_total,
    }

    return {
        "data": x_mix,
        "count": float(gt_total),
        "meta": f"{subj}__{actA_name}__TO__{actB_name}",
        "boundary": boundary,
        "meta_detail": meta_detail,
        "data_A": xA,
        "data_B": xB,
        "T_A": int(xA.shape[0]),
        "T_B": int(xB.shape[0]),
    }


# ---------------------------------------------------------------------
# 2.5) Windowing
# ---------------------------------------------------------------------
def trial_list_to_windows(trial_list, fs, win_sec=8.0, stride_sec=4.0, drop_last=True):
    """Cut each trial into sliding windows with proportionally scaled counts.

    The trial-level count is converted to a constant rate (count / duration)
    and each window is assigned ``rate * window_duration``.

    Args:
        trial_list: Iterable of dicts with ``"data"`` (time on axis 0),
            ``"count"`` and ``"meta"``.
        fs: Sampling frequency used to convert seconds to samples.
        win_sec: Window length in seconds.
        stride_sec: Hop between window starts in seconds.
        drop_last: When false, a trailing window running from one stride past
            the last regular start up to the end of the trial is appended
            (it may be shorter than a regular window).

    Returns:
        A list of window dicts with keys ``data``, ``count``, ``meta``,
        ``parent_meta``, ``parent_T``, ``win_start`` and ``win_end``. A trial
        shorter than one window yields a single window covering all of it.
    """
    win_len = int(round(win_sec * fs))
    stride = int(round(stride_sec * fs))
    assert win_len > 0 and stride > 0

    fs_f = float(fs)

    def _record(data, meta, parent_T, rate, st, ed):
        # One output window; its count is the trial rate times its duration.
        return {
            "data": data,
            "count": rate * ((ed - st) / fs_f),
            "meta": f"{meta}__win[{st}:{ed}]",
            "parent_meta": meta,
            "parent_T": parent_T,
            "win_start": st,
            "win_end": ed,
        }

    windows = []
    for item in trial_list:
        x = item["data"]
        T = x.shape[0]
        meta = item["meta"]
        rate_trial = float(item["count"]) / max(T / fs_f, 1e-6)

        if T < win_len:
            # Too short for a full window: keep the whole trial as one window.
            windows.append(_record(x, meta, T, rate_trial, 0, T))
            continue

        starts = range(0, T - win_len + 1, stride)
        windows.extend(
            _record(x[st:st + win_len], meta, T, rate_trial, st, st + win_len)
            for st in starts
        )

        if not drop_last:
            tail_st = starts[-1] + stride
            if tail_st < T:
                windows.append(_record(x[tail_st:T], meta, T, rate_trial, tail_st, T))

    return windows


def predict_count_by_windowing(model, x_np, fs, win_sec, stride_sec, device, tau=1.0, batch_size=64):
    """Estimate a trial's repetition count from window-level rate predictions.

    The trial is cut into fixed-length windows, the model predicts one rate
    per window, and the count is ``mean(window rates) * trial duration``.
    A trial no longer than one window is passed to the model as a whole.

    Args:
        model: Callable as ``model(x, mask=None, tau=tau)`` returning a
            4-tuple whose first element is one rate per batch item. It is
            switched to evaluation mode on every call.
        x_np: Array with time on axis 0 and channels on axis 1.
        fs: Sampling frequency.
        win_sec: Window length in seconds.
        stride_sec: Hop between windows in seconds.
        device: Device the input tensors are moved to.
        tau: Temperature forwarded to the model.
        batch_size: Number of windows per forward pass.

    Returns:
        ``(pred_count, rates)`` where ``pred_count`` is a float and ``rates``
        is the array of per-window rates (a single element for short trials).
    """
    win_len = int(round(win_sec * fs))
    stride = int(round(stride_sec * fs))
    T = x_np.shape[0]
    total_dur = T / float(fs)

    # Evaluation mode must hold on both paths; previously only the
    # multi-window path set it, so short trials ran in the caller's mode.
    model.eval()

    if T <= win_len:
        x_tensor = torch.tensor(x_np, dtype=torch.float32).transpose(0, 1).unsqueeze(0).to(device)
        with torch.no_grad():
            rate_hat, _, _, _ = model(x_tensor, mask=None, tau=tau)
        rate = float(rate_hat.item())
        return float(rate * total_dur), np.array([rate], dtype=np.float32)

    starts = list(range(0, T - win_len + 1, stride))
    windows = np.stack([x_np[st:st + win_len] for st in starts], axis=0)

    # (n_windows, win_len, C) -> (n_windows, C, win_len)
    xw = torch.tensor(windows, dtype=torch.float32).permute(0, 2, 1).to(device)

    rates = []
    with torch.no_grad():
        for i in range(0, xw.shape[0], batch_size):
            r_hat, _, _, _ = model(xw[i:i + batch_size], mask=None, tau=tau)
            rates.append(r_hat.detach().cpu().numpy())

    rates = np.concatenate(rates, axis=0)
    pred_count = float(rates.mean()) * total_dur
    return float(pred_count), rates


# ---------------------------------------------------------------------
# Helpers for scenario-avg plots
# ---------------------------------------------------------------------
def get_window_rate_curve(model, x_np, fs, win_sec, stride_sec, device, tau=1.0, batch_size=64):
    """Return window centre times (seconds) and the model's rate per window.

    Args:
        model: Callable as ``model(x, mask=None, tau=tau)`` returning a
            4-tuple whose first element is one rate per batch item.
        x_np: Array with time on axis 0 and channels on axis 1.
        fs: Sampling frequency.
        win_sec: Window length in seconds.
        stride_sec: Hop between windows in seconds.
        device: Device the input tensors are moved to.
        tau: Temperature forwarded to the model.
        batch_size: Number of windows per forward pass.

    Returns:
        ``(centers, rates)`` as float32 arrays of equal length. For a trial
        no longer than one window, a single point at the trial midpoint is
        returned with the rate from ``predict_count_by_windowing``.
    """
    win_len = int(round(win_sec * fs))
    stride = int(round(stride_sec * fs))
    T = x_np.shape[0]
    fs_f = float(fs)

    if T <= win_len:
        _, rates = predict_count_by_windowing(
            model, x_np, fs, win_sec, stride_sec, device, tau=tau, batch_size=batch_size
        )
        midpoint = np.array([0.5 * (T / fs_f)], dtype=np.float32)
        return midpoint, rates.astype(np.float32)

    starts = np.arange(0, T - win_len + 1, stride)
    stacked = np.stack([x_np[st:st + win_len] for st in starts], axis=0)

    # (n_windows, win_len, C) -> (n_windows, C, win_len)
    xw = torch.tensor(stacked, dtype=torch.float32).permute(0, 2, 1).to(device)

    model.eval()
    per_batch = []
    with torch.no_grad():
        for chunk in torch.split(xw, batch_size):
            r_hat = model(chunk, mask=None, tau=tau)[0]
            per_batch.append(r_hat.detach().cpu().numpy())
    rates = np.concatenate(per_batch, axis=0).astype(np.float32)

    centers = ((starts + 0.5 * win_len) / fs_f).astype(np.float32)
    return centers, rates


def resample_1d(y, new_len):
    """Linearly resample a 1-D sequence to ``new_len`` points.

    Both the input and output grids are evenly spaced on [0, 1]. An empty
    input yields zeros and a single value is repeated. The result is float32.
    """
    values = np.asarray(y, dtype=np.float32)
    n = values.size

    if n <= 1:
        fill = float(values[0]) if n == 1 else 0.0
        return np.full((new_len,), fill, dtype=np.float32)

    grid_old = np.linspace(0.0, 1.0, num=n, dtype=np.float32)
    grid_new = np.linspace(0.0, 1.0, num=new_len, dtype=np.float32)
    resampled = np.interp(grid_new, grid_old, values)
    return resampled.astype(np.float32)


def resample_2d(Y, new_len):
    """Linearly resample each column of a 2-D array to ``new_len`` rows.

    Rows are treated as samples on an evenly spaced [0, 1] grid. An input
    with zero rows yields zeros, and a single row is repeated. The result is
    float32 with shape ``(new_len, Y.shape[1])``.
    """
    arr = np.asarray(Y, dtype=np.float32)
    n_rows = arr.shape[0]

    if n_rows == 0:
        return np.zeros((new_len, arr.shape[1]), dtype=np.float32)
    if n_rows == 1:
        return np.repeat(arr, repeats=new_len, axis=0).astype(np.float32)

    grid_old = np.linspace(0.0, 1.0, num=n_rows, dtype=np.float32)
    grid_new = np.linspace(0.0, 1.0, num=new_len, dtype=np.float32)
    columns = [np.interp(grid_new, grid_old, arr[:, d]) for d in range(arr.shape[1])]
    return np.stack(columns, axis=1).astype(np.float32)


def global_pca2(Z_all):
    """Fit a 2-component PCA basis with an SVD of the centred data.

    Args:
        Z_all: Array-like with samples on axis 0 and features on axis 1.

    Returns:
        ``(mu, V2)``: the per-feature mean (shape ``(D,)``) and the first two
        right singular vectors as rows (at most ``(2, D)``), both float32.
    """
    Z = np.asarray(Z_all, dtype=np.float32)
    center = Z.mean(axis=0, keepdims=True)
    # Rows of Vt are the principal directions, ordered by singular value.
    Vt = np.linalg.svd(Z - center, full_matrices=False)[2]
    top2 = Vt[:2].astype(np.float32)
    return center.squeeze(0).astype(np.float32), top2


# ---------------------------------------------------------------------
# ✅ UPDATED PLOTTER: comment-based 2 plots only, no titles, bigger fonts
#   (1) Window mean±std + boundary + shading
#   (2) PCA2 latent: scatter + mean trajectory line + boundary marker + start/end + std circles
# ---------------------------------------------------------------------
from mpl_toolkits.axes_grid1 import make_axes_locatable
from mpl_toolkits.axes_grid1 import make_axes_locatable  # ✅ 추가 import (plot 함수에서만 사용)

def plot_scenario_mean_std_two_plots(
    save_dir,
    scen_name,
    win_t, win_mean, win_std,
    lat_xy_mean, lat_xy_std,
    n_subjects,
    dpi=600,
):
    """Render and save the two per-scenario figures (no titles, large fonts).

    Figure 1 (``<scen>__1_win_rep_rate.png``): ``win_mean`` over ``win_t``
    with a ``±win_std`` band, light shading of the x-ranges 0-1 and 1-2, and
    a dashed vertical line at x = 1.0 marking the A/B boundary.

    Figure 2 (``<scen>__2_latent_pca2.png``): ``lat_xy_mean`` drawn as a line
    whose colour follows a normalised time running from 0 to 2, on square,
    equal-aspect axes with a colour bar.

    Args:
        save_dir: Output directory; ``"scenario_viz"`` is used when ``None``.
        scen_name: Scenario label, sanitised into the file-name prefix.
        win_t: X positions of the window-level curve.
        win_mean: Mean curve; assumed to be a NumPy array since it is
            combined arithmetically with ``win_std`` — TODO confirm.
        win_std: Spread of the curve, same length as ``win_mean``.
        lat_xy_mean: Trajectory points, converted to a float32 array and
            indexed as ``(T, 2)``.
        lat_xy_std: Not used by this function.
        n_subjects: Not used by this function.
        dpi: Resolution passed to ``savefig``.

    NOTE(review): the banner comment above mentions a scatter layer,
    start/end markers and std circles; none of them is drawn here.
    """
    _ensure_dir(save_dir if save_dir is not None else "scenario_viz")
    out_dir = save_dir if save_dir is not None else "scenario_viz"
    scen_tag = _safe_fname(scen_name)

    # -------------------------
    # (1) Window-level rep rate, mean ± std
    # -------------------------
    fig1 = plt.figure(figsize=(26.0, 9.2))
    ax = fig1.gca()

    # Subtle A/B shading: segment A spans x in [0, 1], segment B spans [1, 2].
    ax.axvspan(0.0, 1.0, alpha=0.035, color="k", lw=0)
    ax.axvspan(1.0, 2.0, alpha=0.022, color="k", lw=0)

    ax.plot(
        win_t, win_mean,
        linewidth=7.8,
        solid_capstyle="round",
        label="Window rep rate (mean)"
    )
    ax.fill_between(
        win_t,
        win_mean - win_std, win_mean + win_std,
        alpha=0.16,  # low alpha so the band stays behind the mean line
        linewidth=0,
        label="±1 std"
    )

    ax.axvline(
        1.0,
        linestyle="--",
        linewidth=4.2,  # thinner than the mean line but still bold
        label="Boundary"
    )

    ax.set_xlabel("Normalized time (A:0–1, B:1–2)", fontsize=54, labelpad=10)
    ax.set_ylabel("Rep rate (reps/s)", fontsize=54, labelpad=10)  # label written without underscores
    ax.tick_params(axis="both", labelsize=46, width=1.6, length=7)

    # Paper-like spine thickness.
    for sp in ax.spines.values():
        sp.set_linewidth(1.6)

    # Legend kept slightly smaller than the axis labels.
    ax.legend(
        fontsize=42,
        frameon=True,
        loc="upper left",
        borderpad=0.55,
        labelspacing=0.35,
        handlelength=2.4
    )

    fig1.tight_layout()
    fig1.savefig(os.path.join(out_dir, f"{scen_tag}__1_win_rep_rate.png"), dpi=dpi, bbox_inches="tight")
    plt.close(fig1)

    # -------------------------
    # (2) Latent trajectory PCA2 (gradient line)
    # -------------------------
    fig2 = plt.figure(figsize=(10.0, 10.0))
    ax = fig2.gca()

    lat_xy_mean = np.asarray(lat_xy_mean, dtype=np.float32)
    T = lat_xy_mean.shape[0]
    # One normalised time value per trajectory point, spanning 0..2.
    t_lat = np.linspace(0.0, 2.0, num=T, dtype=np.float32)

    # Build T-1 two-point segments so each can receive its own colour.
    pts = lat_xy_mean.reshape(-1, 1, 2)
    segs = np.concatenate([pts[:-1], pts[1:]], axis=1)

    lc = LineCollection(
        segs,
        array=t_lat[:-1],
        linewidth=2.6,          # make the trajectory stand out
        alpha=0.95
    )
    ax.add_collection(lc)

    ax.set_xlabel("PC1", fontsize=25, labelpad=10)
    ax.set_ylabel("PC2", fontsize=25, labelpad=10)
    ax.tick_params(axis="both", labelsize=20, width=1.4, length=6)

    for sp in ax.spines.values():
        sp.set_linewidth(1.4)

    # Keep an equal aspect ratio and anchor the axes at the centre.
    ax.set_aspect("equal", adjustable="box")
    ax.set_anchor("C")

    # Square bounds: centre on the data and use the larger half-range (+12 %).
    x = lat_xy_mean[:, 0]; y = lat_xy_mean[:, 1]
    xmin, xmax = float(np.min(x)), float(np.max(x))
    ymin, ymax = float(np.min(y)), float(np.max(y))
    cx = 0.5 * (xmin + xmax); cy = 0.5 * (ymin + ymax)
    rx = 0.5 * (xmax - xmin); ry = 0.5 * (ymax - ymin)
    r = max(rx, ry) * 1.12 + 1e-6
    ax.set_xlim(cx - r, cx + r)
    ax.set_ylim(cy - r, cy + r)

    # Give the colour bar its own axes via a divider so it does not visibly shrink the main axes.
    divider = make_axes_locatable(ax)
    cax = divider.append_axes("right", size="4.5%", pad=0.10)
    cbar = fig2.colorbar(lc, cax=cax)
    cbar.set_label("Normalized time (0→2)", fontsize=22, labelpad=10)
    cbar.ax.tick_params(labelsize=18, width=1.2, length=5)

    fig2.tight_layout()
    fig2.savefig(os.path.join(out_dir, f"{scen_tag}__2_latent_pca2.png"), dpi=dpi, bbox_inches="tight")
    plt.close(fig2)


# ---------------------------------------------------------------------
# Dataset / Collate
# ---------------------------------------------------------------------
class TrialDataset(Dataset):
    """Dataset over a list of trial dicts with ``'data'``, ``'count'``, ``'meta'``.

    Each item is returned as ``(signal, count, meta)`` where ``signal`` is the
    float32 tensor of ``'data'`` with its first two axes swapped (time-major
    input becomes channel-major) and ``count`` is a float32 scalar tensor.
    """

    def __init__(self, trial_list):
        self.trials = trial_list

    def __len__(self):
        return len(self.trials)

    def __getitem__(self, idx):
        trial = self.trials[idx]
        signal = torch.tensor(trial['data'], dtype=torch.float32)
        target = torch.tensor(trial['count'], dtype=torch.float32)
        # (T, C) -> (C, T)
        return signal.transpose(0, 1), target, trial['meta']


def collate_variable_length(batch):
    """Collate ``(data, count, meta)`` samples of differing length into a batch.

    Every ``data`` tensor (channels on axis 0, time on axis 1) is right-padded
    with zeros to the longest sample, and a mask marks real time steps.

    Returns:
        Dict with ``"data"`` (B, C, T_max), ``"mask"`` (B, T_max; 1 for real
        samples, 0 for padding), ``"count"`` (B,), ``"length"`` (B,; float32
        original lengths) and ``"meta"`` (list of length B).
    """
    seq_lens = [sample[0].shape[1] for sample in batch]
    max_len = max(seq_lens)
    n_ch = batch[0][0].shape[0]

    padded, masks = [], []
    for (data, _count, _meta), T in zip(batch, seq_lens):
        gap = max_len - T
        if gap > 0:
            data = torch.cat([data, torch.zeros(n_ch, gap)], dim=1)
        padded.append(data)

        valid = torch.zeros(max_len)
        valid[:T] = 1.0
        masks.append(valid)

    return {
        "data": torch.stack(padded),
        "mask": torch.stack(masks),
        "count": torch.stack([sample[1] for sample in batch]),
        "length": torch.tensor(seq_lens, dtype=torch.float32),
        "meta": [sample[2] for sample in batch],
    }


# ---------------------------------------------------------------------
# Model
# ---------------------------------------------------------------------
class ManifoldEncoder(nn.Module):
    """Per-time-step convolutional encoder.

    Two width-5 convolutions with ReLU followed by a 1x1 projection map
    ``(B, input_ch, T)`` to ``latent_dim`` channels; padding keeps the time
    length unchanged. The output is transposed to ``(B, T, latent_dim)``.
    """

    def __init__(self, input_ch, hidden_dim=128, latent_dim=16):
        super().__init__()
        layers = [
            nn.Conv1d(input_ch, hidden_dim, 5, padding=2),
            nn.ReLU(),
            nn.Conv1d(hidden_dim, hidden_dim, 5, padding=2),
            nn.ReLU(),
            nn.Conv1d(hidden_dim, latent_dim, 1),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        # Channel-major conv output -> time-major latent sequence.
        return self.net(x).transpose(1, 2)


class ManifoldDecoder(nn.Module):
    """Convolutional decoder from the latent sequence back to signal channels.

    Takes ``(B, T, latent_dim)``, transposes to channel-major layout and
    applies two width-5 convolutions with ReLU plus a 1x1 projection,
    returning ``(B, out_ch, T)``.
    """

    def __init__(self, latent_dim, hidden_dim, out_ch):
        super().__init__()
        layers = [
            nn.Conv1d(latent_dim, hidden_dim, 5, padding=2),
            nn.ReLU(),
            nn.Conv1d(hidden_dim, hidden_dim, 5, padding=2),
            nn.ReLU(),
            nn.Conv1d(hidden_dim, out_ch, 1),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, z):
        # Time-major latent -> channel-major input expected by Conv1d.
        return self.net(z.transpose(1, 2))


class MultiRateHead(nn.Module):
    """Two-layer MLP head producing an amplitude and a ``K_max``-way phase distribution.

    The last layer emits ``1 + K_max`` values per position: index 0 becomes a
    non-negative amplitude through softplus, and the remaining ``K_max``
    values are phase logits turned into probabilities by a
    temperature-scaled softmax.
    """

    def __init__(self, latent_dim=16, hidden=64, K_max=6):
        super().__init__()
        self.K_max = K_max
        self.net = nn.Sequential(
            nn.Linear(latent_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 1 + K_max),
        )

    def forward(self, z, tau=1.0):
        raw = self.net(z)
        amp_raw, phase_logits = raw[..., 0], raw[..., 1:]
        amp = F.softplus(amp_raw)
        phase = F.softmax(phase_logits / tau, dim=-1)
        return amp, phase, phase_logits


class KAutoCountModel(nn.Module):
    """Autoencoder with a rate head that predicts an average repetition rate.

    The encoder produces a latent sequence, the decoder reconstructs the
    input from it, and the rate head yields a per-step amplitude and phase
    distribution. The effective number of phases ``k_hat`` is the inverse of
    the summed squared time-averaged phase probabilities, and the per-step
    repetition rate is the amplitude divided by ``k_hat``.
    """

    def __init__(self, input_ch, hidden_dim=128, latent_dim=16, K_max=6, k_hidden=64):
        super().__init__()
        self.encoder = ManifoldEncoder(input_ch, hidden_dim, latent_dim)
        self.decoder = ManifoldDecoder(latent_dim, hidden_dim, input_ch)
        # NOTE(review): ``k_hidden`` is accepted but never used; the head is
        # sized with ``hidden_dim``. Left as is to keep the architecture stable.
        self.rate_head = MultiRateHead(latent_dim, hidden=hidden_dim, K_max=K_max)
        self._init_weights()

    def _init_weights(self):
        """Kaiming-initialise conv/linear weights and zero their biases.

        Afterwards the amplitude bias of the head's last layer is set to -2.0.
        """
        for module in self.modules():
            if not isinstance(module, (nn.Conv1d, nn.Linear)):
                continue
            nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)

        with torch.no_grad():
            head_bias = self.rate_head.net[-1].bias
            head_bias.zero_()
            head_bias[0].fill_(-2.0)

    @staticmethod
    def _masked_mean_time(x, mask=None, eps=1e-6):
        """Average ``x`` over dim 1, counting only positions where ``mask`` is set.

        Supports ``x`` of shape (B, T) or (B, T, K) when a mask is given;
        raises ``ValueError`` for other ranks. Without a mask this is a plain
        mean over dim 1.
        """
        if mask is None:
            return x.mean(dim=1)
        if x.dim() not in (2, 3):
            raise ValueError(f"Unsupported dim for masked mean: {x.dim()}")

        weights = mask.to(dtype=x.dtype, device=x.device)
        if x.dim() == 3:
            weights = weights.unsqueeze(-1)
        return (x * weights).sum(dim=1) / (weights.sum(dim=1) + eps)

    def forward(self, x, mask=None, tau=1.0):
        """Run the model.

        Args:
            x: Input batch passed to the encoder.
            mask: Optional (B, T) mask of valid time steps.
            tau: Softmax temperature for the phase distribution.

        Returns:
            ``(avg_rep_rate, z, x_hat, aux)`` where ``aux`` holds
            ``phase_p``, ``phase_logits``, ``micro_rate_t``, ``rep_rate_t``
            and ``k_hat``.
        """
        z = self.encoder(x)
        x_hat = self.decoder(z)

        amp_t, phase_p, phase_logits = self.rate_head(z, tau=tau)

        p_bar = self._masked_mean_time(phase_p, mask)
        k_hat = 1.0 / (p_bar.pow(2).sum(dim=1) + 1e-6)

        rep_rate_t = amp_t / (k_hat.unsqueeze(1) + 1e-6)

        if mask is None:
            avg_rep_rate = rep_rate_t.mean(dim=1)
        else:
            # Zero the padded steps, then average over valid steps only.
            rep_rate_t = rep_rate_t * mask
            avg_rep_rate = (rep_rate_t * mask).sum(dim=1) / (mask.sum(dim=1) + 1e-6)

        aux = {
            "phase_p": phase_p,
            "phase_logits": phase_logits,
            "micro_rate_t": amp_t,
            "rep_rate_t": rep_rate_t,
            "k_hat": k_hat,
        }
        return avg_rep_rate, z, x_hat, aux


# ---------------------------------------------------------------------
# Loss utils
# ---------------------------------------------------------------------
def masked_recon_mse(x_hat, x, mask, eps=1e-6):
    """Mean squared reconstruction error over valid (unmasked) time steps.

    ``x_hat`` and ``x`` are (B, C, T) and ``mask`` is (B, T). The squared
    error is summed over masked-in positions and divided by the number of
    valid steps times the channel count.
    """
    weights = mask.to(dtype=x.dtype, device=x.device)
    masked_sq_err = (x_hat - x).pow(2) * weights.unsqueeze(1)
    n_valid = weights.sum() * x.shape[1] + eps
    return masked_sq_err.sum() / n_valid


def temporal_smoothness(v, mask=None, eps=1e-6):
    """Mean absolute difference between consecutive steps along dim 1.

    With a mask, only step pairs where both neighbours are valid contribute.
    """
    step = (v[:, 1:] - v[:, :-1]).abs()
    if mask is None:
        return step.mean()

    pair_valid = (mask[:, 1:] * mask[:, :-1]).to(dtype=step.dtype, device=step.device)
    return (step * pair_valid).sum() / (pair_valid.sum() + eps)


def phase_entropy_loss(phase_p, mask=None, eps=1e-8):
    """Average Shannon entropy of the phase distribution over time steps.

    The entropy is taken over the last axis of ``phase_p``; with a mask, only
    valid steps are averaged.
    """
    entropy_t = -(phase_p * (phase_p + eps).log()).sum(dim=-1)
    if mask is None:
        return entropy_t.mean()
    return (entropy_t * mask).sum() / (mask.sum() + eps)


def effK_usage_loss(phase_p, mask=None, eps=1e-6):
    """Effective number of phases used, from the time-averaged distribution.

    ``phase_p`` is averaged over dim 1 (masked when a mask is given) and the
    effective count is the inverse of the summed squared averages.

    Returns:
        ``(batch mean of effK, detached per-sample effK)``.
    """
    if mask is None:
        p_bar = phase_p.mean(dim=1)
    else:
        weights = mask.to(dtype=phase_p.dtype, device=phase_p.device).unsqueeze(-1)
        p_bar = (phase_p * weights).sum(dim=1) / (weights.sum(dim=1) + eps)

    concentration = p_bar.pow(2).sum(dim=1)
    effK = 1.0 / (concentration + eps)
    return effK.mean(), effK.detach()


def compute_phase_entropy_mean(phase_p_np, eps=1e-8):
    """NumPy counterpart of the phase entropy: mean over rows of the per-row entropy.

    ``phase_p_np`` is treated as a 2-D array whose axis 1 holds the phase
    probabilities. Returns a Python float.
    """
    probs = np.asarray(phase_p_np, dtype=np.float32)
    per_step = -np.sum(probs * np.log(probs + eps), axis=1)
    return float(np.mean(per_step))


# ---------------------------------------------------------------------
# Train
# ---------------------------------------------------------------------
def train_one_epoch(model, loader, optimizer, config, device):
    """Train ``model`` for one pass over ``loader`` and return averaged stats.

    The total loss is the rate MSE plus weighted reconstruction, temporal
    smoothness, phase-entropy and effective-K terms.

    Args:
        model: Called as ``model(x, mask, tau=tau)`` and expected to return
            ``(rate_hat, z, x_hat, aux)`` with ``aux["rep_rate_t"]`` and
            ``aux["phase_p"]``.
        loader: Iterable of batch dicts with ``"data"``, ``"mask"``,
            ``"count"`` and ``"length"`` tensors; must support ``len()``.
        optimizer: Optimizer stepped once per batch.
        config: Dict with ``"fs"`` and optional ``"tau"`` (default 1.0),
            ``"lambda_recon"`` (1.0), ``"lambda_smooth"`` (0.05),
            ``"lambda_phase_ent"`` (0.01) and ``"lambda_effk"`` (0.005).
        device: Device the batch tensors are moved to.

    Returns:
        Dict of per-batch averages for ``loss``, ``loss_rate``,
        ``loss_recon``, ``loss_smooth``, ``loss_phase_ent``, ``loss_effk``
        and ``mae_count``. All values are 0.0 when ``loader`` is empty.
    """
    model.train()
    stats = {k: 0.0 for k in [
        'loss', 'loss_rate', 'loss_recon', 'loss_smooth', 'loss_phase_ent', 'loss_effk',
        'mae_count'
    ]}

    fs = config["fs"]
    tau = config.get("tau", 1.0)

    lam_recon = config.get("lambda_recon", 1.0)
    lam_smooth = config.get("lambda_smooth", 0.05)
    lam_phase_ent = config.get("lambda_phase_ent", 0.01)
    lam_effk = config.get("lambda_effk", 0.005)

    for batch in loader:
        x = batch["data"].to(device)
        mask = batch["mask"].to(device)
        y_count = batch["count"].to(device)
        length = batch["length"].to(device)

        # Convert the count label into a rate target (count per second).
        duration = torch.clamp(length / fs, min=1e-6)
        y_rate = y_count / duration

        optimizer.zero_grad()
        rate_hat, z, x_hat, aux = model(x, mask, tau=tau)

        loss_rate = F.mse_loss(rate_hat, y_rate)
        loss_recon = masked_recon_mse(x_hat, x, mask)
        loss_smooth = temporal_smoothness(aux["rep_rate_t"], mask)
        loss_phase_ent = phase_entropy_loss(aux["phase_p"], mask)
        loss_effk, _ = effK_usage_loss(aux["phase_p"], mask)

        loss = (loss_rate
                + lam_recon * loss_recon
                + lam_smooth * loss_smooth
                + lam_phase_ent * loss_phase_ent
                + lam_effk * loss_effk)

        loss.backward()
        optimizer.step()

        count_hat = rate_hat * duration
        stats['loss'] += loss.item()
        stats['loss_rate'] += loss_rate.item()
        stats['loss_recon'] += loss_recon.item()
        stats['loss_smooth'] += loss_smooth.item()
        stats['loss_phase_ent'] += loss_phase_ent.item()
        stats['loss_effk'] += loss_effk.item()
        stats['mae_count'] += torch.abs(count_hat - y_count).mean().item()

    # Guard against an empty loader, which used to raise ZeroDivisionError.
    n = max(len(loader), 1)
    return {k: v / n for k, v in stats.items()}


# ---------------------------------------------------------------------
# Main (Scenario-wise LOSO)
# ---------------------------------------------------------------------
def main():
    """Run the scenario-wise leave-one-subject-out (LOSO) experiment.

    For every ``(name, activity A, activity B)`` entry of
    ``CONFIG["MIX_SCENARIOS"]`` and every held-out subject, a fresh
    ``KAutoCountModel`` is trained on activity A of the remaining subjects.
    The trained model is then used to

    * count the held-out subject's activity-A trial by windowing (LOSO MAE),
    * count the mixed A->B trial returned by ``build_mixed_ab_trial`` as a
      whole and per segment, and
    * collect window-level rate curves and the latent trajectory of the
      mixed trial.

    Per-subject result lines and aggregate MAE/RMSE are printed, and the
    per-scenario averages are passed to ``plot_scenario_mean_std_two_plots``
    with ``VIZ_DIR`` as the save directory.
    """
    # Every target activity uses the same chest/ankle/arm accelerometer and
    # ankle/arm gyroscope channels, so the list is defined once.
    motion_features = [
        'acc_chest_x', 'acc_chest_y', 'acc_chest_z',
        'acc_ankle_x', 'acc_ankle_y', 'acc_ankle_z',
        'gyro_ankle_x', 'gyro_ankle_y', 'gyro_ankle_z',
        'acc_arm_x', 'acc_arm_y', 'acc_arm_z',
        'gyro_arm_x', 'gyro_arm_y', 'gyro_arm_z',
    ]

    CONFIG = {
        "seed": 42,
        "data_dir": "/content/drive/MyDrive/Colab Notebooks/HAR_data/MHEALTHDATASET",

        "COLUMN_NAMES": [
            'acc_chest_x', 'acc_chest_y', 'acc_chest_z',
            'ecg_1', 'ecg_2',
            'acc_ankle_x', 'acc_ankle_y', 'acc_ankle_z',
            'gyro_ankle_x', 'gyro_ankle_y', 'gyro_ankle_z',
            'mag_ankle_x', 'mag_ankle_y', 'mag_ankle_z',
            'acc_arm_x', 'acc_arm_y', 'acc_arm_z',
            'gyro_arm_x', 'gyro_arm_y', 'gyro_arm_z',
            'mag_arm_x', 'mag_arm_y', 'mag_arm_z',
            'activity_id'
        ],

        "TARGET_ACTIVITIES_MAP": {
            6:  'Waist bends forward',
            7:  'Frontal elevation of arms',
            8:  'Knees bending',
            10:  'Jogging',
            11: 'Running',
            12: 'Jump front & back',
        },

        # One independent copy of the shared channel list per activity id.
        "ACT_FEATURE_MAP": {
            act_id: list(motion_features) for act_id in (6, 7, 8, 10, 11, 12)
        },

        # Training Params
        "epochs": 100,
        "lr": 5e-4,
        "batch_size": 64,
        "fs": 50,

        # Windowing Params
        "win_sec": 8.0,
        "stride_sec": 4.0,
        "drop_last": True,

        # Model
        "hidden_dim": 128,
        "latent_dim": 16,
        "K_max": 6,

        # Loss Weights
        "lambda_recon": 1.0,
        "lambda_smooth": 0.05,
        "lambda_phase_ent": 0.01,
        "lambda_effk": 0.0075,

        "tau": 1.0,

        "GT_BY_ACT": {
            6:  {"subject1": 21, "subject2": 19, "subject3": 21, "subject4": 20, "subject5": 20,
                 "subject6": 20, "subject7": 20, "subject8": 21, "subject9": 21, "subject10": 20},
            7:  {"subject1": 20, "subject2": 20, "subject3": 20, "subject4": 20, "subject5": 20,
                 "subject6": 20, "subject7": 20, "subject8": 19, "subject9": 19, "subject10": 20},
            8:  {"subject1": 20, "subject2": 21, "subject3": 21, "subject4": 19, "subject5": 20,
                 "subject6": 20, "subject7": 21, "subject8": 21, "subject9": 21, "subject10": 21},
            10: {"subject1": 157, "subject2": 161, "subject3": 154, "subject4": 154, "subject5": 160,
                 "subject6": 156, "subject7": 153, "subject8": 160, "subject9": 166, "subject10": 156},
            11: {"subject1": 165, "subject2": 158, "subject3": 174, "subject4": 163, "subject5": 157,
                 "subject6": 172, "subject7": 149, "subject8": 166, "subject9": 174, "subject10": 172},
            12: {"subject1": 20, "subject2": 22, "subject3": 21, "subject4": 21, "subject5": 20,
                 "subject6": 21, "subject7": 19, "subject8": 20, "subject9": 20, "subject10": 20},
        },

        "MIX_SCENARIOS": [
            ("Knees bending-Frontal elevation of arms", 8, 7),
            ("Waist bends forward-Jump front & back", 6, 12),
            ("Jogging-Running", 10, 11),
        ],
    }

    set_strict_seed(CONFIG["seed"])
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")

    full_data = load_mhealth_dataset(CONFIG["data_dir"], CONFIG["TARGET_ACTIVITIES_MAP"], CONFIG["COLUMN_NAMES"])
    if not full_data:
        return

    subjects = [f"subject{i}" for i in range(1, 11)]

    VIZ_DIR = "scenario_viz"
    os.makedirs(VIZ_DIR, exist_ok=True)

    # Fixed resampling lengths so per-subject curves can be averaged.
    NW = 60     # window-rate points per segment (A,B)
    NL = 200    # latent points per segment (A,B)

    # Keyword arguments shared by every windowed inference call below.
    tau = CONFIG.get("tau", 1.0)
    win_kwargs = dict(
        fs=CONFIG["fs"], win_sec=CONFIG["win_sec"], stride_sec=CONFIG["stride_sec"],
        device=device, tau=tau,
        batch_size=CONFIG.get("batch_size", 64),
    )

    for scen_idx, (scen_name, actA_id, actB_id) in enumerate(CONFIG.get("MIX_SCENARIOS", []), start=1):
        # The model is only ever trained on the first activity of the scenario.
        TRAIN_ACT_ID = actA_id

        print("\n" + "=" * 100)
        print(f"1. 시나리오{scen_idx}: {scen_name}")
        print("=" * 100)

        loso_results = []
        scenario_records = []
        scenario_diffs = []
        scenario_diffs_A = []
        scenario_diffs_B = []

        # Buffers feeding the two scenario-average plots.
        buf_win_concat = []     # (2*NW,)
        buf_z_concat = []       # list of (Tmix,D)
        per_subject_latent = [] # list of {"zA","zB"}

        for test_subj in subjects:
            # Re-seed per fold so every fold starts from the same RNG state.
            set_strict_seed(CONFIG["seed"])

            gt_train_map = CONFIG["GT_BY_ACT"][TRAIN_ACT_ID]
            train_labels = [(s, TRAIN_ACT_ID, gt_train_map[s]) for s in subjects if s != test_subj]
            test_labels  = [(test_subj, TRAIN_ACT_ID, gt_train_map[test_subj])]

            train_trials = prepare_trial_list(train_labels, full_data, CONFIG["TARGET_ACTIVITIES_MAP"], CONFIG["ACT_FEATURE_MAP"])
            test_trials  = prepare_trial_list(test_labels,  full_data, CONFIG["TARGET_ACTIVITIES_MAP"], CONFIG["ACT_FEATURE_MAP"])
            if not test_trials or not train_trials:
                continue

            train_data = trial_list_to_windows(
                train_trials, fs=CONFIG["fs"],
                win_sec=CONFIG["win_sec"], stride_sec=CONFIG["stride_sec"],
                drop_last=CONFIG["drop_last"]
            )

            # Dedicated generator keeps the shuffle order reproducible.
            g = torch.Generator()
            g.manual_seed(CONFIG["seed"])

            train_loader = DataLoader(
                TrialDataset(train_data),
                batch_size=CONFIG["batch_size"],
                shuffle=True,
                collate_fn=collate_variable_length,
                generator=g,
                num_workers=0
            )

            input_ch = train_data[0]['data'].shape[1]
            model = KAutoCountModel(
                input_ch=input_ch,
                hidden_dim=CONFIG["hidden_dim"],
                latent_dim=CONFIG["latent_dim"],
                K_max=CONFIG["K_max"]
            ).to(device)

            optimizer = torch.optim.Adam(model.parameters(), lr=CONFIG["lr"])
            scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.5)

            for _ in range(CONFIG["epochs"]):
                train_one_epoch(model, train_loader, optimizer, CONFIG, device)
                scheduler.step()

            model.eval()

            # (A) LOSO on TRAIN_ACT_ID
            item = test_trials[0]
            count_gt = float(item["count"])
            count_pred_win, _ = predict_count_by_windowing(model, x_np=item["data"], **win_kwargs)
            loso_results.append(float(abs(count_pred_win - count_gt)))

            # (B) Scenario eval on the mixed A->B trial
            mixed_item = build_mixed_ab_trial(test_subj, actA_id, actB_id, CONFIG, full_data)
            if mixed_item is None:
                continue

            x_mix = mixed_item["data"]
            boundary = mixed_item["boundary"]
            gt_total = float(mixed_item["count"])
            md = mixed_item["meta_detail"]

            pred_total, _ = predict_count_by_windowing(model, x_np=x_mix, **win_kwargs)

            diff = float(pred_total - gt_total)
            mae = float(abs(diff))
            scenario_diffs.append(diff)

            # Per-segment (A / B) errors
            xA = mixed_item["data_A"]
            xB = mixed_item["data_B"]
            gtA = float(md["gtA"])
            gtB = float(md["gtB"])

            pred_A, _ = predict_count_by_windowing(model, x_np=xA, **win_kwargs)
            pred_B, _ = predict_count_by_windowing(model, x_np=xB, **win_kwargs)

            diff_A = float(pred_A - gtA)
            diff_B = float(pred_B - gtB)
            mae_A = float(abs(diff_A))
            mae_B = float(abs(diff_B))
            scenario_diffs_A.append(diff_A)
            scenario_diffs_B.append(diff_B)

            # Full-trial forward pass for the latent z plus the k_hat / entropy log values.
            x_tensor = torch.tensor(x_mix, dtype=torch.float32).transpose(0, 1).unsqueeze(0).to(device)
            with torch.no_grad():
                _, z_m, _, aux_m = model(x_tensor, mask=None, tau=tau)

            phase_p_m = aux_m["phase_p"].squeeze(0).detach().cpu().numpy()  # (Tmix,K)
            k_hat_m = float(aux_m["k_hat"].item())
            ent_m = compute_phase_entropy_mean(phase_p_m)

            z_td = z_m.squeeze(0).detach().cpu().numpy()  # (Tmix,D)

            scenario_records.append({
                "subj": test_subj,
                "line": (
                    f"[Scenario] {scen_name} | {test_subj} | "
                    f"{md['actA_name']}({md['gtA']:.0f}) -> {md['actB_name']}({md['gtB']:.0f}) | "
                    f"GT_total={gt_total:.0f} | Pred_total(win)={pred_total:.2f} | Diff_total={diff:+.2f} | "
                    f"MAE_total={mae:.2f} | "
                    f"Pred_A={pred_A:.2f} (Diff_A={diff_A:+.2f}, MAE_A={mae_A:.2f}) | "
                    f"Pred_B={pred_B:.2f} (Diff_B={diff_B:+.2f}, MAE_B={mae_B:.2f}) | "
                    f"k_hat(full)={k_hat_m:.2f} | ent(full)={ent_m:.3f} | boundary={boundary}"
                )
            })

            # (1) window-level rep_rate curves: A/B separately -> resample -> concat
            _, rA_w = get_window_rate_curve(model, xA, **win_kwargs)
            _, rB_w = get_window_rate_curve(model, xB, **win_kwargs)
            buf_win_concat.append(
                np.concatenate([resample_1d(rA_w, NW), resample_1d(rB_w, NW)], axis=0)
            )

            # (2) latent: keep the raw z split at the boundary for the global PCA later
            buf_z_concat.append(z_td)
            per_subject_latent.append({"zA": z_td[:boundary, :], "zB": z_td[boundary:, :]})

        # -------------------------
        # Print summaries
        # -------------------------
        print("-" * 100)
        if loso_results:
            print("TRain 결과 ->")
            print("-" * 100)
            print(f" >>> Final LOSO Result (Average MAE): {np.mean(loso_results):.3f}")
            print(f" >>> Standard Deviation: {np.std(loso_results):.3f}")
            print("-" * 100)
        else:
            print("TRain 결과 -> (no folds computed)")
            print("-" * 100)

        # Print the per-subject lines in canonical subject order.
        subj2line = {r["subj"]: r["line"] for r in scenario_records}
        for s in subjects:
            if s in subj2line:
                print(subj2line[s])

        print("\n" + "-" * 100)
        print(f"시나리오{scen_idx} 결과")
        print("-" * 100)

        if not scenario_diffs:
            print(f"[{scen_name}] No samples (all skipped). Fill GT_BY_ACT for required activities.")
        else:
            # Tags are left-padded to 7 characters so the three lines align.
            for tag, values in (("[TOTAL]", scenario_diffs),
                                ("[A]", scenario_diffs_A),
                                ("[B]", scenario_diffs_B)):
                diffs = np.array(values, dtype=np.float32)
                agg_mae = float(np.mean(np.abs(diffs)))
                agg_rmse = float(np.sqrt(np.mean(diffs ** 2)))
                print(f"{tag:<7} [{scen_name}] N={len(diffs):2d} | MAE={agg_mae:.3f} | RMSE={agg_rmse:.3f} | mean(diff)={diffs.mean():+.3f} | std(diff)={diffs.std():.3f}")

        print("-" * 100)

        # -------------------------
        # Two-plot scenario-average visualization
        # -------------------------
        n_ok = len(buf_win_concat)
        if n_ok == 0:
            print("[Viz Skip] No scenario samples collected for averaging.")
            continue

        # (1) window-level mean/std on the normalized axis [0,2]
        W = np.stack(buf_win_concat, axis=0)  # (N, 2*NW)
        win_mean = W.mean(axis=0)
        win_std  = W.std(axis=0)
        win_t = np.linspace(0.0, 2.0, num=2*NW, dtype=np.float32)

        # (2) latent: one global PCA basis fitted on all collected z
        Z_all = np.concatenate(buf_z_concat, axis=0)  # (sumT, D)
        mu, V2 = global_pca2(Z_all)

        # Project each subject, resample A/B to NL points, then concat -> (2*NL,2)
        lat_list = []
        for d in per_subject_latent:
            pcA_rs = resample_2d((d["zA"] - mu) @ V2.T, NL)
            pcB_rs = resample_2d((d["zB"] - mu) @ V2.T, NL)
            lat_list.append(np.concatenate([pcA_rs, pcB_rs], axis=0))

        LAT = np.stack(lat_list, axis=0)  # (N,2*NL,2)
        lat_xy_mean = LAT.mean(axis=0)
        lat_xy_std  = LAT.std(axis=0)

        plot_scenario_mean_std_two_plots(
            save_dir=VIZ_DIR,
            scen_name=scen_name,
            win_t=win_t, win_mean=win_mean, win_std=win_std,
            lat_xy_mean=lat_xy_mean, lat_xy_std=lat_xy_std,
            n_subjects=n_ok,
            dpi=600
        )


# Script entry point: run the experiment only when executed directly, not on import.
if __name__ == "__main__":
    main()


Device: cuda
Loading 10 subjects from /content/drive/MyDrive/Colab Notebooks/HAR_data/MHEALTHDATASET...

1. 시나리오1: Knees bending-Frontal elevation of arms
----------------------------------------------------------------------------------------------------
TRain 결과 ->
----------------------------------------------------------------------------------------------------
 >>> Final LOSO Result (Average MAE): 4.274
 >>> Standard Deviation: 3.012
----------------------------------------------------------------------------------------------------
[Scenario] Knees bending-Frontal elevation of arms | subject1 | Knees bending(20) -> Frontal elevation of arms(20) | GT_total=40 | Pred_total(win)=66.34 | Diff_total=+26.34 | MAE_total=26.34 | Pred_A=47.42 (Diff_A=+27.42, MAE_A=27.42) | Pred_B=18.71 (Diff_B=-1.29, MAE_B=1.29) | k_hat(full)=1.02 | ent(full)=0.059 | boundary=3379
[Scenario] Knees bending-Frontal elevation of arms | subject2 | Knees bending(21) -> Frontal elevation of arms(20) | GT_total

# 2) Integrator vs. Ours PCA plot
- `Knees bending`, `Jump front & back`, `Jogging`

In [1]:
# =========================
# PCA2D ONLY (Ours vs Integrator) — Activity-wise folders, fold-wise plots
#
# - For each activity:
#   - LOSO fold loop
#   - Train OURS model on train subjects (window-proxy supervision)
#   - Extract OURS latent z(t) from the held-out test trial (full-trial forward)
#   - Build Integrator window-feature trajectory from raw (order-invariant features)
#   - Plot PCA2D comparison in ONE figure with 2 subplots (OURS vs INTEGRATOR)
#   - NO titles, NO start/stop markers
#   - Ticks padding tightened (closer to axes)
#   - Saved as PNG (dpi=600), no plt.show()
# =========================

import os
import glob
import random
import re
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy.ndimage import gaussian_filter1d

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader


# ---------------------------------------------------------------------
# IO helpers
# ---------------------------------------------------------------------
def _safe_filename(s: str) -> str:
    """Turn an arbitrary value into a filename-friendly string of at most 200 characters."""
    text = str(s)
    # Everything except word characters, '-', '_', '.' and spaces becomes '_'.
    text = re.sub(r"[^\w\-_\. ]", "_", text)
    # Surrounding whitespace is dropped, inner spaces become underscores.
    text = text.strip().replace(" ", "_")
    return text[:200]


def _ensure_dir(p: str):
    """Create directory ``p`` (with parents) if needed and return ``p``.

    An empty string or ``None`` is returned untouched instead of being
    passed to ``os.makedirs``, which would raise.
    """
    if p is None or p == "":
        return p
    os.makedirs(p, exist_ok=True)
    return p


# ---------------------------------------------------------------------
# Strict seeding
# ---------------------------------------------------------------------
def set_strict_seed(seed: int):
    """Seed Python, NumPy and PyTorch RNGs and request deterministic cuDNN kernels."""
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    for seed_fn in (np.random.seed, torch.manual_seed):
        seed_fn(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


# ---------------------------------------------------------------------
# Data loading (mHealth)
# ---------------------------------------------------------------------
def load_mhealth_dataset(data_dir, target_activities_map, column_names):
    """Load every ``mHealth_subject*.log`` file in ``data_dir``.

    Args:
        data_dir: directory searched for the log files.
        target_activities_map: dict mapping activity id -> activity name;
            only these activities are kept.
        column_names: names assigned to the first ``len(column_names)``
            columns of each tab-separated file; must include
            ``'activity_id'``.

    Returns:
        dict ``{subject_key: {activity_name: DataFrame}}`` where each frame
        holds the rows of that activity without the ``activity_id`` column.
        Activities with no rows are omitted.  An empty dict is returned when
        no log file is found.  Files that fail to load are reported and
        skipped.
    """
    file_list = sorted(glob.glob(os.path.join(data_dir, "mHealth_subject*.log")))

    if not file_list:
        print(f"[Warning] No mHealth logs found in {data_dir}")
        return {}

    print(f"Loading {len(file_list)} subjects from {data_dir}...")

    full_dataset = {}
    for file_path in file_list:
        file_name = os.path.basename(file_path)
        subj_part = file_name.split('.')[0]
        try:
            subj_id_num = int(''.join(filter(str.isdigit, subj_part)))
            subj_key = f"subject{subj_id_num}"
        except ValueError:
            # No parsable number in the file name: fall back to the bare stem.
            subj_key = subj_part

        try:
            df = pd.read_csv(file_path, sep="\t", header=None)
            df = df.iloc[:, :len(column_names)]
            df.columns = column_names

            subj_data = {}
            for label_code, activity_name in target_activities_map.items():
                activity_df = df[df['activity_id'] == label_code].copy()
                if not activity_df.empty:
                    subj_data[activity_name] = activity_df.drop(columns=['activity_id'])

            full_dataset[subj_key] = subj_data
        except Exception as e:
            # Best-effort loading: report the broken file and keep going.
            print(f"Error loading {file_name}: {e}")

    return full_dataset


def prepare_trial_list(label_config, full_data, target_map, feature_map):
    """
    Build one trial record per (subject, activity id, count) label.

    Each returned dict contains:
      - data    : per-channel z-scored signal (T,C)
      - raw     : unnormalized signal (T,C)
      - count   : trial count as float
      - meta    : "<subject>_<activity name>"
      - subj, act_id, act_name

    Labels whose subject or activity is absent from ``full_data`` are
    reported with a "[Skip]" message and left out.
    """
    trials = []
    for subj, act_id, gt_count in label_config:
        act_id = int(act_id)
        act_name = target_map.get(act_id)
        feats = feature_map.get(act_id)

        if subj not in full_data or act_name not in full_data[subj]:
            print(f"[Skip] Missing data for {subj} - {act_name}")
            continue

        raw_np = full_data[subj][act_name][feats].values.astype(np.float32)

        # Per-trial, per-channel standardization; the epsilon guards flat channels.
        center = raw_np.mean(axis=0, keepdims=True)
        spread = raw_np.std(axis=0, keepdims=True) + 1e-6

        trials.append({
            "data": (raw_np - center) / spread,
            "raw": raw_np,
            "count": float(gt_count),
            "meta": f"{subj}_{act_name}",
            "subj": subj,
            "act_id": act_id,
            "act_name": act_name,
        })

    return trials


# ---------------------------------------------------------------------
# Windowing (TRAIN) + variable-length collate
# ---------------------------------------------------------------------
def trial_list_to_windows(trial_list, fs, win_sec=8.0, stride_sec=4.0, drop_last=True):
    """Cut each trial into sliding windows labelled with a proportional count.

    The trial's average rate (count / duration) is assumed constant, so each
    window's ``"count"`` is that rate times the window duration.  Trials
    shorter than one window are emitted whole.  With ``drop_last=False`` the
    samples remaining after the last stride position are emitted as an extra,
    shorter window.

    Returns a list of dicts with ``"data"``, ``"count"`` and ``"meta"``
    (``"<trial meta>__win[start:end]"``).
    """
    win_len = int(round(win_sec * fs))
    stride = int(round(stride_sec * fs))
    assert win_len > 0 and stride > 0

    out = []
    for trial in trial_list:
        seq = trial["data"]  # (T,C)
        n_samples = seq.shape[0]
        total_count = float(trial["count"])
        label = trial["meta"]

        reps_per_sec = total_count / max(n_samples / float(fs), 1e-6)

        if n_samples < win_len:
            # Too short to slide over: keep the whole trial as one window.
            out.append({
                "data": seq,
                "count": reps_per_sec * (n_samples / float(fs)),
                "meta": f"{label}__win[0:{n_samples}]",
            })
            continue

        spans = [(lo, lo + win_len)
                 for lo in range(0, n_samples - win_len + 1, stride)]

        if not drop_last:
            tail_lo = spans[-1][0] + stride
            if tail_lo < n_samples:
                spans.append((tail_lo, n_samples))

        out.extend(
            {
                "data": seq[lo:hi],
                "count": reps_per_sec * ((hi - lo) / float(fs)),
                "meta": f"{label}__win[{lo}:{hi}]",
            }
            for lo, hi in spans
        )

    return out


class TrialDataset(Dataset):
    """Dataset over a list of window/trial dicts with 'data', 'count' and 'meta' entries."""

    def __init__(self, trial_list):
        self.trials = trial_list

    def __len__(self):
        return len(self.trials)

    def __getitem__(self, idx):
        """Return (channels-first float32 signal, float32 count, meta) for item ``idx``."""
        record = self.trials[idx]
        signal = torch.tensor(record['data'], dtype=torch.float32)
        label = torch.tensor(record['count'], dtype=torch.float32)
        # Stored as (T,C); the model consumes (C,T).
        return signal.transpose(0, 1), label, record['meta']


def collate_variable_length(batch):
    """Right-pad (C,T) samples with zeros to the longest T and stack them.

    ``batch`` is a list of ``(data, count, meta)`` tuples.  The returned dict
    holds the padded data, a validity mask (1 on real samples, 0 on padding),
    the stacked counts, the original lengths and the meta list.
    """
    longest = max([sample[0].shape[1] for sample in batch])
    n_ch = batch[0][0].shape[0]

    padded, valid, counts, metas, lengths = [], [], [], [], []
    for signal, count, meta in batch:
        n_steps = signal.shape[1]
        n_pad = longest - n_steps

        if n_pad > 0:
            signal = torch.cat([signal, torch.zeros(n_ch, n_pad)], dim=1)

        padded.append(signal)
        # zeros(0) contributes nothing when the sample already has full length
        valid.append(torch.cat([torch.ones(n_steps), torch.zeros(n_pad)], dim=0))
        counts.append(count)
        metas.append(meta)
        lengths.append(n_steps)

    return {
        "data": torch.stack(padded),      # (B,C,Tmax)
        "mask": torch.stack(valid),       # (B,Tmax)
        "count": torch.stack(counts),     # (B,)
        "length": torch.tensor(lengths, dtype=torch.float32),
        "meta": metas,
    }


# ---------------------------------------------------------------------
# OURS model (same as before but we only need z(t))
# ---------------------------------------------------------------------
class ManifoldEncoder(nn.Module):
    """Three-layer 1-D convolutional encoder mapping (B,C,T) signals to (B,T,D) latents."""

    def __init__(self, input_ch, hidden_dim=128, latent_dim=16):
        super().__init__()
        # kernel 5 with padding 2 keeps the time length unchanged
        layers = [
            nn.Conv1d(input_ch, hidden_dim, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv1d(hidden_dim, hidden_dim, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv1d(hidden_dim, latent_dim, kernel_size=1),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        # conv output is (B,D,T); callers expect time-major latents (B,T,D)
        return self.net(x).transpose(1, 2)


class ManifoldDecoder(nn.Module):
    """Three-layer 1-D convolutional decoder mapping (B,T,D) latents back to (B,C,T) signals."""

    def __init__(self, latent_dim, hidden_dim, out_ch):
        super().__init__()
        # kernel 5 with padding 2 keeps the time length unchanged
        layers = [
            nn.Conv1d(latent_dim, hidden_dim, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv1d(hidden_dim, hidden_dim, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv1d(hidden_dim, out_ch, kernel_size=1),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, z):
        channels_first = z.transpose(1, 2)   # (B,D,T)
        return self.net(channels_first)      # (B,C,T)


class MultiRateHead(nn.Module):
    """Per-time-step MLP head producing a non-negative amplitude and a K-way phase distribution."""

    def __init__(self, latent_dim=16, hidden=128, K_max=6):
        super().__init__()
        # one amplitude logit followed by K_max phase logits
        self.net = nn.Sequential(
            nn.Linear(latent_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 1 + K_max),
        )

    def forward(self, z, tau=1.0):
        head_out = self.net(z)                                   # (B,T,1+K)
        amp_logit, phase_logits = head_out[..., 0], head_out[..., 1:]
        amplitude = F.softplus(amp_logit)                        # (B,T)
        # tau is the softmax temperature over the phase logits
        phase_probs = F.softmax(phase_logits / tau, dim=-1)      # (B,T,K)
        return amplitude, phase_probs


class KAutoCountModel(nn.Module):
    """Encoder/decoder with a rate head that predicts an average repetition rate.

    ``forward`` returns ``(avg_rep, z, x_hat, aux)`` where ``aux`` carries the
    per-step rate (``"rep_rate_t"``) and the phase distribution (``"phase"``).
    """

    def __init__(self, input_ch, hidden_dim=128, latent_dim=16, K_max=6):
        super().__init__()
        self.encoder = ManifoldEncoder(input_ch, hidden_dim, latent_dim)
        self.decoder = ManifoldDecoder(latent_dim, hidden_dim, input_ch)
        self.rate_head = MultiRateHead(latent_dim, hidden=hidden_dim, K_max=K_max)

        self._init_weights()

    def _init_weights(self):
        """Kaiming-initialise conv/linear weights, zero biases, bias the amplitude output to -2."""
        for module in self.modules():
            if not isinstance(module, (nn.Conv1d, nn.Linear)):
                continue
            nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)

        with torch.no_grad():
            head_bias = self.rate_head.net[-1].bias
            head_bias.zero_()
            # index 0 is the amplitude logit: start with a small softplus output
            head_bias[0].fill_(-2.0)

    def forward(self, x, mask=None, tau=1.0):
        z = self.encoder(x)                          # (B,T,D)
        x_hat = self.decoder(z)                      # (B,C,T)
        amp, phase = self.rate_head(z, tau=tau)      # (B,T), (B,T,K)

        # Time-averaged phase usage; padded steps are excluded when a mask is given.
        if mask is None:
            p_bar = phase.mean(dim=1)
        else:
            p_bar = (phase * mask.unsqueeze(-1)).sum(dim=1) / (mask.sum(dim=1, keepdim=True) + 1e-6)

        # Inverse sum of squared usage: the effective number of phases in use.
        k_hat = 1.0 / (p_bar.pow(2).sum(dim=1) + 1e-6)
        rep_rate_t = amp / (k_hat.unsqueeze(1) + 1e-6)

        # Average rate is the supervision target.
        if mask is None:
            avg_rep = rep_rate_t.mean(dim=1)
        else:
            rep_rate_t = rep_rate_t * mask
            avg_rep = (rep_rate_t * mask).sum(dim=1) / (mask.sum(dim=1) + 1e-6)

        return avg_rep, z, x_hat, {"rep_rate_t": rep_rate_t, "phase": phase}


# ---------------------------------------------------------------------
# Training (minimal)
# ---------------------------------------------------------------------
def masked_recon_mse(x_hat, x, mask, eps=1e-6):
    """Mean squared reconstruction error over valid time steps and all channels.

    ``x_hat`` and ``x`` are (B,C,T); ``mask`` is (B,T) and is broadcast over
    the channel axis.
    """
    weights = mask.to(dtype=x.dtype, device=x.device)
    sq_err = (x_hat - x).pow(2) * weights.unsqueeze(1)
    # number of valid (step, channel) entries
    n_valid = weights.sum() * x.shape[1]
    return sq_err.sum() / (n_valid + eps)


def temporal_smoothness(v, mask=None, eps=1e-6):
    """Mean absolute first difference of ``v`` along axis 1.

    With a mask, only differences whose two neighbouring steps are both
    valid contribute.
    """
    step = (v[:, 1:] - v[:, :-1]).abs()
    if mask is None:
        return step.mean()
    pair_valid = (mask[:, 1:] * mask[:, :-1]).to(dtype=step.dtype, device=step.device)
    return (step * pair_valid).sum() / (pair_valid.sum() + eps)


def phase_entropy_loss(phase, mask=None, eps=1e-8):
    """Average Shannon entropy of the per-step phase distribution.

    ``phase`` is (B,T,K); with a (B,T) mask only valid steps are averaged.
    """
    log_p = torch.log(phase + eps)
    step_entropy = -(phase * log_p).sum(dim=-1)  # (B,T)
    if mask is None:
        return step_entropy.mean()
    return (step_entropy * mask).sum() / (mask.sum() + eps)


def effK_usage_loss(phase, mask=None, eps=1e-6):
    """Batch mean of the effective phase count, 1 / sum_k(p_bar_k ** 2).

    ``p_bar`` is the (optionally mask-weighted) time average of ``phase``
    (B,T,K).
    """
    if mask is not None:
        weights = mask.unsqueeze(-1).to(dtype=phase.dtype, device=phase.device)
        usage = (phase * weights).sum(dim=1) / (weights.sum(dim=1) + eps)
    else:
        usage = phase.mean(dim=1)
    effective_k = 1.0 / (usage.pow(2).sum(dim=1) + eps)
    return effective_k.mean()


def train_one_epoch(model, loader, optimizer, config, device):
    """Run one optimisation pass over ``loader``; nothing is returned.

    Count labels are converted to rates (count / duration) and the model is
    trained on rate MSE plus weighted reconstruction, smoothness, phase
    entropy and effective-K terms.
    """
    model.train()
    sample_rate = config["fs"]
    temperature = config.get("tau", 1.0)

    w_recon = config.get("lambda_recon", 1.0)
    w_smooth = config.get("lambda_smooth", 0.05)
    w_entropy = config.get("lambda_phase_ent", 0.01)
    w_effk = config.get("lambda_effk", 0.0075)

    for batch in loader:
        inputs = batch["data"].to(device)
        valid = batch["mask"].to(device)
        target_count = batch["count"].to(device)
        n_steps = batch["length"].to(device)

        # Rate target: count divided by the window duration in seconds.
        seconds = torch.clamp(n_steps / sample_rate, min=1e-6)
        target_rate = target_count / seconds

        optimizer.zero_grad()
        pred_rate, _, recon, aux = model(inputs, valid, tau=temperature)

        rate_term = F.mse_loss(pred_rate, target_rate)
        recon_term = masked_recon_mse(recon, inputs, valid)
        smooth_term = temporal_smoothness(aux["rep_rate_t"], valid)
        entropy_term = phase_entropy_loss(aux["phase"], valid)
        effk_term = effK_usage_loss(aux["phase"], valid)

        total = rate_term
        total = total + w_recon * recon_term
        total = total + w_smooth * smooth_term
        total = total + w_entropy * entropy_term
        total = total + w_effk * effk_term

        total.backward()
        optimizer.step()


# ---------------------------------------------------------------------
# PCA2D helpers (no sklearn)
# ---------------------------------------------------------------------
def _fit_pca_basis(X, n_comp=2, eps=1e-6):
    """Fit a PCA basis on column-standardized data via SVD.

    Returns ``(mean, std + eps, W)`` where the columns of ``W`` are the
    first ``n_comp`` right singular vectors.
    """
    data = np.asarray(X, dtype=np.float32)
    center = data.mean(axis=0, keepdims=True)
    scale = data.std(axis=0, keepdims=True) + eps
    right_vecs = np.linalg.svd((data - center) / scale, full_matrices=False)[2]
    basis = right_vecs[:n_comp].T
    return center.squeeze(0), scale.squeeze(0), basis


def _pca_project(X, mu, sig, W):
    """Standardize ``X`` with ``mu``/``sig`` and project it onto the basis ``W``."""
    data = np.asarray(X, dtype=np.float32)
    centered = data - mu[None, :]
    scaled = centered / (sig[None, :] + 1e-12)
    return np.matmul(scaled, W)


def _downsample_by_index(X, t, max_N):
    """Keep at most ``max_N`` evenly spaced rows of ``X`` and matching entries of ``t``."""
    n_rows = X.shape[0]
    if n_rows > max_N:
        keep = np.linspace(0, n_rows - 1, max_N).astype(int)
        return X[keep], t[keep]
    # already small enough: hand the inputs back unchanged
    return X, t


# ---------------------------------------------------------------------
# Integrator: build window-feature trajectory (order-invariant)
# ---------------------------------------------------------------------
def integrator_window_feature_vector(x_raw: np.ndarray, fs: int, win_sec: float, stride_sec: float):
    """
    Sliding-window energy/variance features of a (T,C) signal.

    Returns t_cent (N,), F (N, 2C)
      t_cent: window centers in seconds
      F     : per window, log(1+energy) for every channel followed by
              log(1+var) for every channel, both on the mean-removed window
    """
    signal = np.asarray(x_raw, dtype=np.float32)
    n_samples, _ = signal.shape

    # Window is capped at the signal length but never shorter than 10 samples.
    win_len = max(10, min(int(round(win_sec * fs)), n_samples))
    hop = max(1, int(round(stride_sec * fs)))
    energy_scale = max(float(fs), 1e-12)

    t_cent, rows = [], []
    for lo in range(0, max(1, n_samples - win_len + 1), hop):
        hi = lo + win_len
        segment = signal[lo:hi]
        centered = segment - segment.mean(axis=0, keepdims=True)

        log_var = np.log1p(np.var(centered, axis=0) + 1e-12)
        log_energy = np.log1p(np.sum(centered * centered, axis=0) / energy_scale)

        rows.append(np.concatenate([log_energy, log_var], axis=0).astype(np.float32))  # (2C,)
        t_cent.append(0.5 * (lo + hi) / float(fs))

    return np.asarray(t_cent, dtype=np.float32), np.stack(rows, axis=0).astype(np.float32)


# ---------------------------------------------------------------------
# Plot: PCA2D comparison (2 subplots, no titles)
# ---------------------------------------------------------------------
from mpl_toolkits.axes_grid1 import make_axes_locatable
def plot_pca2d_comparison(
    ours_z_tD, ours_t_sec,
    int_F_nD, int_t_sec,
    save_path,
    dpi=600,
    max_points=2000,
    show_colorbar=True,
    tick_pad=0,
    point_size=10,
    alpha=0.9
):
    """
    Save a 1x2 figure of PCA-2D scatters (no titles).

    Left panel: (ours_z_tD, ours_t_sec); right panel: (int_F_nD, int_t_sec).
    Each trajectory is downsampled to at most max_points, projected with its
    OWN PCA basis, and colored by time normalized to [0, 1]. Both panels get
    identical x/y limits (joint min/max plus 5% padding). When show_colorbar is
    true, a colorbar is attached to the right panel.
    """
    def _unit_time(t):
        # Guard against a zero time span.
        return (t - t.min()) / max(1e-6, (t.max() - t.min()))

    projections, colors = [], []
    for feats, times in ((ours_z_tD, ours_t_sec), (int_F_nD, int_t_sec)):
        feats = np.asarray(feats, dtype=np.float32)
        times = np.asarray(times, dtype=np.float32)
        feats, times = _downsample_by_index(feats, times, max_points)
        colors.append(_unit_time(times))
        center, scale, basis = _fit_pca_basis(feats, n_comp=2)
        projections.append(_pca_project(feats, center, scale, basis))

    fig, axes = plt.subplots(
        1, 2,
        figsize=(12.0, 5.0),
        sharey=True,
        gridspec_kw={"wspace": 0.10, "width_ratios": [1, 1]},
    )

    scatters = []
    for ax, pts, col, ylabel in zip(axes, projections, colors, ("PC2", "")):
        scatters.append(
            ax.scatter(pts[:, 0], pts[:, 1], c=col, s=point_size, alpha=alpha, cmap="viridis")
        )
        ax.set_xlabel("PC1", fontsize=16)
        ax.set_ylabel(ylabel, fontsize=16)
        ax.tick_params(axis="both", which="major", labelsize=14, pad=tick_pad)
        ax.grid(False)

    # Apply the same xlim/ylim to both subplots (padding = 5% of the range).
    bounds = []
    for dim in (0, 1):
        joint = np.concatenate([pts[:, dim] for pts in projections], axis=0)
        lo, hi = float(joint.min()), float(joint.max())
        pad = 0.05 * max(1e-6, (hi - lo))
        bounds.append((lo - pad, hi + pad))

    for ax in axes:
        ax.set_xlim(*bounds[0])
    for ax in axes:
        ax.set_ylim(*bounds[1])

    if show_colorbar:
        cax = make_axes_locatable(axes[1]).append_axes("right", size="4.5%", pad=0.10)
        cbar = fig.colorbar(scatters[1], cax=cax)
        cbar.set_label("Normalized time (0→1)", fontsize=14, labelpad=10)
        cbar.ax.tick_params(labelsize=12, pad=tick_pad)
        cbar.outline.set_linewidth(1.2)

    _ensure_dir(os.path.dirname(save_path))
    fig.savefig(save_path, dpi=dpi, bbox_inches="tight")
    plt.close(fig)


# ---------------------------------------------------------------------
# Main (LOSO per activity, PCA2D-only)
# ---------------------------------------------------------------------
def main():
    """
    Run the PCA2D-only comparison (OURS vs INTEGRATOR), leave-one-subject-out,
    for every activity in CONFIG["TARGET_ACTIVITIES_MAP"].

    Per fold: train a fresh KAutoCountModel on the other subjects' windows,
    encode the first test trial into its latent trajectory, build the
    integrator window-feature trajectory from the same trial's raw signal, and
    save a side-by-side PCA scatter as
      <out_dir>/act<ID>_<name>/fold<NN>_<subject>/pca2d_compare.png

    Folds with no test trial or no training windows are skipped with a message.
    Returns None; returns early when no data could be loaded.
    """
    CONFIG = {
        "seed": 42,
        "data_dir": "/content/drive/MyDrive/Colab Notebooks/HAR_data/MHEALTHDATASET",
        "out_dir": "./outputs_pca2d_only_compare",

        "COLUMN_NAMES": [
            'acc_chest_x', 'acc_chest_y', 'acc_chest_z',
            'ecg_1', 'ecg_2',
            'acc_ankle_x', 'acc_ankle_y', 'acc_ankle_z',
            'gyro_ankle_x', 'gyro_ankle_y', 'gyro_ankle_z',
            'mag_ankle_x', 'mag_ankle_y', 'mag_ankle_z',
            'acc_arm_x', 'acc_arm_y', 'acc_arm_z',
            'gyro_arm_x', 'gyro_arm_y', 'gyro_arm_z',
            'mag_arm_x', 'mag_arm_y', 'mag_arm_z',
            'activity_id'
        ],

        "TARGET_ACTIVITIES_MAP": {
            8:  'Knees bending',
            10: 'Jogging',
            12: 'Jump front & back',
        },

        "ACT_FEATURE_MAP": {
            8:  ['acc_chest_x','acc_chest_y','acc_chest_z','acc_ankle_x','acc_ankle_y','acc_ankle_z',
                 'gyro_ankle_x','gyro_ankle_y','gyro_ankle_z','acc_arm_x','acc_arm_y','acc_arm_z',
                 'gyro_arm_x','gyro_arm_y','gyro_arm_z'],
            10: ['acc_chest_x','acc_chest_y','acc_chest_z','acc_ankle_x','acc_ankle_y','acc_ankle_z',
                 'gyro_ankle_x','gyro_ankle_y','gyro_ankle_z','acc_arm_x','acc_arm_y','acc_arm_z',
                 'gyro_arm_x','gyro_arm_y','gyro_arm_z'],
            12: ['acc_chest_x','acc_chest_y','acc_chest_z','acc_ankle_x','acc_ankle_y','acc_ankle_z',
                 'gyro_ankle_x','gyro_ankle_y','gyro_ankle_z','acc_arm_x','acc_arm_y','acc_arm_z',
                 'gyro_arm_x','gyro_arm_y','gyro_arm_z'],
        },

        # Train params
        "epochs": 100,
        "lr": 5e-4,
        "batch_size": 64,
        "fs": 50,

        # Windowing for training
        "win_sec": 8.0,
        "stride_sec": 4.0,
        "drop_last": True,

        # Model
        "hidden_dim": 128,
        "latent_dim": 16,
        "K_max": 6,

        # Loss weights
        "lambda_recon": 1.0,
        "lambda_smooth": 0.05,
        "lambda_phase_ent": 0.01,
        "lambda_effk": 0.0075,

        "tau": 1.0,

        # Integrator feature trajectory
        "INT_WIN_SEC": 0.6,
        "INT_STRIDE_SEC": 0.06,   # denser stride -> more points, giving a blob-like cloud
    }

    # Plot controls
    PLOT = {
        "dpi": 600,
        "max_points": 2500,
        "show_colorbar": True,   # set to False if not needed
        "tick_pad": 3,          # distance between tick labels and the axis
        "point_size": 10,
        "alpha": 0.9,
    }

    # Count-only labels (used as-is): (subject, activity_id, repetition count)
    CONFIG["ALL_LABELS"] = [
        # act 8
        ("subject1", 8, 20), ("subject2", 8, 21), ("subject3", 8, 21), ("subject4", 8, 19), ("subject5", 8, 20),
        ("subject6", 8, 20), ("subject7", 8, 21), ("subject8", 8, 21), ("subject9", 8, 21), ("subject10", 8, 21),

        # act 10
        ("subject1", 10, 157), ("subject2", 10, 161), ("subject3", 10, 154), ("subject4", 10, 154), ("subject5", 10, 160),
        ("subject6", 10, 156), ("subject7", 10, 153), ("subject8", 10, 160), ("subject9", 10, 166), ("subject10", 10, 156),

        # act 12
        ("subject1", 12, 20), ("subject2", 12, 22), ("subject3", 12, 21), ("subject4", 12, 21), ("subject5", 12, 20),
        ("subject6", 12, 21), ("subject7", 12, 19), ("subject8", 12, 20), ("subject9", 12, 20), ("subject10", 12, 20),
    ]

    _ensure_dir(CONFIG["out_dir"])

    set_strict_seed(CONFIG["seed"])
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")

    full_data = load_mhealth_dataset(CONFIG["data_dir"], CONFIG["TARGET_ACTIVITIES_MAP"], CONFIG["COLUMN_NAMES"])
    if not full_data:
        return

    print("\n" + "-" * 90)
    print(" >>> PCA2D ONLY (OURS vs INTEGRATOR) — per activity, per fold")
    print("-" * 90)

    for act_id, act_name in CONFIG["TARGET_ACTIVITIES_MAP"].items():
        labels_act = [x for x in CONFIG["ALL_LABELS"] if int(x[1]) == int(act_id)]
        if len(labels_act) == 0:
            print(f"[Skip] act {act_id}: no labels.")
            continue

        subjects_act = sorted(list(set([x[0] for x in labels_act])))

        # Build the path first and create it afterwards, so this works whether
        # or not _ensure_dir returns the path it was given.
        # NOTE(review): _safe_filename is defined outside the visible part of
        # this file — confirm it is available in this cell.
        act_out_dir = os.path.join(CONFIG["out_dir"], f"act{act_id}_{_safe_filename(act_name)}")
        _ensure_dir(act_out_dir)
        print("\n" + "=" * 90)
        print(f"[Activity] {act_id}: {act_name} | #subjects={len(subjects_act)}")
        print("=" * 90)

        for fold_idx, test_subj in enumerate(subjects_act):
            # Re-seed per fold so every fold starts from the same RNG state.
            set_strict_seed(CONFIG["seed"])
            fold_out_dir = os.path.join(act_out_dir, f"fold{fold_idx+1:02d}_{_safe_filename(test_subj)}")
            _ensure_dir(fold_out_dir)

            train_labels = [x for x in labels_act if x[0] != test_subj]
            test_labels  = [x for x in labels_act if x[0] == test_subj]

            train_trials = prepare_trial_list(train_labels, full_data, CONFIG["TARGET_ACTIVITIES_MAP"], CONFIG["ACT_FEATURE_MAP"])
            test_trials  = prepare_trial_list(test_labels,  full_data, CONFIG["TARGET_ACTIVITIES_MAP"], CONFIG["ACT_FEATURE_MAP"])

            if not test_trials:
                print(f"[Skip] fold {fold_idx+1}: {test_subj} no data.")
                continue

            # Train OURS
            train_windows = trial_list_to_windows(
                train_trials, fs=CONFIG["fs"],
                win_sec=CONFIG["win_sec"], stride_sec=CONFIG["stride_sec"],
                drop_last=CONFIG["drop_last"]
            )

            # Without training windows there is nothing to size the model from.
            if not train_windows:
                print(f"[Skip] fold {fold_idx+1}: {test_subj} no training windows.")
                continue

            g = torch.Generator()
            g.manual_seed(CONFIG["seed"])
            train_loader = DataLoader(
                TrialDataset(train_windows),
                batch_size=CONFIG["batch_size"],
                shuffle=True,
                collate_fn=collate_variable_length,
                generator=g,
                num_workers=0
            )

            input_ch = train_windows[0]["data"].shape[1]
            model = KAutoCountModel(
                input_ch=input_ch,
                hidden_dim=CONFIG["hidden_dim"],
                latent_dim=CONFIG["latent_dim"],
                K_max=CONFIG["K_max"]
            ).to(device)

            optimizer = torch.optim.Adam(model.parameters(), lr=CONFIG["lr"])
            scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.5)

            for _ in range(CONFIG["epochs"]):
                train_one_epoch(model, train_loader, optimizer, CONFIG, device)
                scheduler.step()

            model.eval()

            # Take ONE test trial (usually there is only one)
            item = test_trials[0]
            x_np = item["data"]     # (T,C) z-scored
            x_raw = item["raw"]     # (T,C) raw

            # OURS z(t)
            x_tensor = torch.tensor(x_np, dtype=torch.float32).transpose(0, 1).unsqueeze(0).to(device)  # (1,C,T)
            with torch.no_grad():
                _, z_full, _, _ = model(x_tensor, mask=None, tau=CONFIG.get("tau", 1.0))
            ours_z = z_full.squeeze(0).detach().cpu().numpy()  # (T,D)
            ours_t = np.arange(ours_z.shape[0], dtype=np.float32) / float(CONFIG["fs"])

            # Integrator features trajectory
            int_t, int_F = integrator_window_feature_vector(
                x_raw, fs=CONFIG["fs"],
                win_sec=float(CONFIG["INT_WIN_SEC"]),
                stride_sec=float(CONFIG["INT_STRIDE_SEC"])
            )

            # Plot PCA2D comparison
            save_path = os.path.join(fold_out_dir, "pca2d_compare.png")
            plot_pca2d_comparison(
                ours_z_tD=ours_z, ours_t_sec=ours_t,
                int_F_nD=int_F, int_t_sec=int_t,
                save_path=save_path,
                dpi=int(PLOT["dpi"]),
                max_points=int(PLOT["max_points"]),
                show_colorbar=bool(PLOT["show_colorbar"]),
                tick_pad=float(PLOT["tick_pad"]),
                point_size=float(PLOT["point_size"]),
                alpha=float(PLOT["alpha"]),
            )

            print(f"  Fold {fold_idx+1:02d} | Test {test_subj} | saved: {save_path}")

    print("\nDone.")
    print(f"[Output root] {os.path.abspath(CONFIG['out_dir'])}")


# Script entry point: run the experiment only when this code is executed
# directly (not when it is imported).
if __name__ == "__main__":
    main()


Device: cuda
Loading 10 subjects from /content/drive/MyDrive/Colab Notebooks/HAR_data/MHEALTHDATASET...

------------------------------------------------------------------------------------------
 >>> PCA2D ONLY (OURS vs INTEGRATOR) — per activity, per fold
------------------------------------------------------------------------------------------

[Activity] 8: Knees bending | #subjects=10
  Fold 01 | Test subject1 | saved: ./outputs_pca2d_only_compare/act8_Knees_bending/fold01_subject1/pca2d_compare.png
  Fold 02 | Test subject10 | saved: ./outputs_pca2d_only_compare/act8_Knees_bending/fold02_subject10/pca2d_compare.png
  Fold 03 | Test subject2 | saved: ./outputs_pca2d_only_compare/act8_Knees_bending/fold03_subject2/pca2d_compare.png
  Fold 04 | Test subject3 | saved: ./outputs_pca2d_only_compare/act8_Knees_bending/fold04_subject3/pca2d_compare.png
  Fold 05 | Test subject4 | saved: ./outputs_pca2d_only_compare/act8_Knees_bending/fold05_subject4/pca2d_compare.png
  Fold 06 | Test sub

# 3) Activity shift: success vs. failure case (PCA & line plot)
- `Frontal elevation of arms→Knees bending`
- `Knees bending→Jogging`

In [18]:
# =========================
# A1 Pairwise (SUCCESS + FAILURE only) Full Code  ✅ (requested)
#
# ✅ Only 2 experiments:
#   (SUCCESS) Frontal elevation of arms (7) -> Knees bending (8)
#   (FAILURE) Knees bending (8) -> Jogging (10)
#
# ✅ A1 visualization (per experiment):
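#   (1) Project Test onto the Train-PCA basis and compare the "focus/direction (manifold alignment) difference" in a 2-panel view
#   (2) pred_rate(t) vs gt_rate overlay (saved for both success and failure; mainly used to interpret the failure case)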
#
# -------------------------
# Following the "do not touch anything else" principle,
# ❗ the only change that was strictly necessary:
#   ✅ added/applied "fixed" xlim and ylim, using A1_PCA_train7_test8_subject5 as the reference
# -------------------------
# =========================

import os
import glob
import random
import numpy as np
import pandas as pd

from scipy.ndimage import gaussian_filter1d

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable


# ---------------------------------------------------------------------
# 0) Small helpers (dir / downsample / PCA)
# ---------------------------------------------------------------------
def _ensure_dir(path: str):
    """Create `path` (including parents) if needed; do nothing for None or ""."""
    if path is not None and path != "":
        os.makedirs(path, exist_ok=True)


def _downsample_by_index(X, t, max_points=2000):
    """
    Return X and t as arrays, thinned to at most max_points entries.

    When t is longer than max_points, evenly spaced indices between the first
    and last element are used to pick matching rows of X and t.
    """
    values = np.asarray(X)
    stamps = np.asarray(t)
    count = len(stamps)
    if count > max_points:
        picks = np.linspace(0, count - 1, max_points).astype(np.int64)
        values, stamps = values[picks], stamps[picks]
    return values, stamps


def _fit_pca_basis(z_tD, n_comp=2, eps=1e-6):
    """
    Fit a PCA basis from the covariance of standardized latents.

    z_tD: (T,D)
    returns: mu(D,), sig(D,), W(D,n_comp) — all float32; sig already includes eps.
    """
    latents = np.asarray(z_tD, dtype=np.float32)
    center = latents.mean(axis=0)
    scale = latents.std(axis=0) + eps
    normed = (latents - center) / scale

    # (D,D) covariance; the denominator is clamped so a single row does not divide by zero.
    dof = max(1, (normed.shape[0] - 1))
    cov = (normed.T @ normed) / dof

    # Only the left singular vectors of the covariance are needed.
    left_vecs = np.linalg.svd(cov, full_matrices=False)[0]
    basis = left_vecs[:, :n_comp]
    return center.astype(np.float32), scale.astype(np.float32), basis.astype(np.float32)


def _pca_project(z_tD, mu, sig, W):
    """Standardize z_tD with (mu, sig), project onto W, and return float32 coordinates."""
    latents = np.asarray(z_tD, dtype=np.float32)
    standardized = (latents - mu) / sig
    coords = standardized @ W
    return coords.astype(np.float32)


# ✅ Helper that computes reference xlim/ylim (intended to be stored from the A1_PCA_train7_test8_subject5 run)
def _compute_ref_xlim_ylim_from_train_test_latents(train_z_tD, test_z_tD, pad_ratio=0.05,
                                                   q_low=1.0, q_high=99.0):
    """
    Derive axis limits from train + test latents in the train-fitted PCA plane.

    The PCA basis is fit on train only; both sets are projected into it. For
    each of the two components, the [q_low, q_high] percentiles of the pooled
    points are taken and widened by pad_ratio of that span on both sides.

    Returns ((xmin, xmax), (ymin, ymax)) as Python floats.
    """
    mu, sig, W = _fit_pca_basis(train_z_tD, n_comp=2)
    projected = (_pca_project(train_z_tD, mu, sig, W), _pca_project(test_z_tD, mu, sig, W))

    limits = []
    for comp in (0, 1):
        pooled = np.concatenate([pts[:, comp] for pts in projected])
        lo, hi = np.percentile(pooled, [q_low, q_high])
        pad = pad_ratio * max(1e-6, (hi - lo))
        limits.append((float(lo - pad), float(hi + pad)))

    return limits[0], limits[1]


# ---------------------------------------------------------------------
# 1) Strict Seeding
# ---------------------------------------------------------------------
def set_strict_seed(seed: int):
    """
    Seed every RNG used here (Python `random`, NumPy, PyTorch CPU and, when
    available, CUDA), export PYTHONHASHSEED, and put cuDNN into deterministic,
    non-benchmarking mode.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)

    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


# ---------------------------------------------------------------------
# 2) Data Loading
# ---------------------------------------------------------------------
def load_mhealth_dataset(data_dir, target_activities_map, column_names):
    """
    Load mHealth subject logs and split each one by target activity.

    Files matching "mHealth_subject*.log" inside data_dir are read as
    tab-separated, header-less tables. Only the first len(column_names) columns
    are kept and renamed to column_names.

    Args:
      data_dir: directory containing the log files.
      target_activities_map: {activity_id: activity_name} to extract.
      column_names: names assigned to the kept columns; an 'activity_id'
        column is used for filtering and dropped from the returned frames.

    Returns:
      {subject_key: {activity_name: DataFrame without 'activity_id'}}.
      subject_key is "subject<N>" built from the digits in the file stem, or the
      raw stem when it contains no digits. Activities with no rows are omitted.
      Returns {} (after printing a warning) when no log file is found. A file
      that fails to load is reported and skipped.
    """
    file_list = sorted(glob.glob(os.path.join(data_dir, "mHealth_subject*.log")))
    if not file_list:
        print(f"[Warning] No mHealth logs found in {data_dir}")
        return {}

    full_dataset = {}
    for file_path in file_list:
        file_name = os.path.basename(file_path)
        subj_part = file_name.split('.')[0]
        try:
            subj_id_num = int(''.join(filter(str.isdigit, subj_part)))
            subj_key = f"subject{subj_id_num}"
        except ValueError:
            # int("") — the stem holds no digits; fall back to the stem itself.
            subj_key = subj_part

        try:
            df = pd.read_csv(file_path, sep="\t", header=None)
            df = df.iloc[:, :len(column_names)]
            df.columns = column_names

            subj_data = {}
            for label_code, activity_name in target_activities_map.items():
                activity_df = df[df['activity_id'] == label_code].copy()
                if not activity_df.empty:
                    subj_data[activity_name] = activity_df.drop(columns=['activity_id'])

            full_dataset[subj_key] = subj_data
        except Exception as e:
            # Deliberately best-effort: one unreadable file must not abort the
            # whole load, so report it and move on to the next subject.
            print(f"Error loading {file_name}: {e}")

    return full_dataset


def prepare_trial_list(label_config, full_data, target_map, feature_map):
    """
    Turn (subject, activity_id, gt_count) labels into a list of trial dicts.

    For each label whose subject/activity exists in full_data, the selected
    feature columns are cast to float32 and z-scored per trial (column-wise,
    std offset by 1e-6). Labels without matching data are silently skipped.

    Each trial dict holds: 'data' (T,C), 'count', 'meta', 'subj', 'act_id'.
    """
    trials = []
    for subj, act_id, gt_count in label_config:
        act_name = target_map.get(act_id)
        feats = feature_map.get(act_id)

        # Nothing recorded for this subject/activity pair -> skip it.
        if subj not in full_data or act_name not in full_data[subj]:
            continue

        raw_np = full_data[subj][act_name][feats].values.astype(np.float32)

        # Z-score (trial-wise)
        norm_np = (raw_np - raw_np.mean(axis=0)) / (raw_np.std(axis=0) + 1e-6)

        trials.append({
            'data': norm_np,              # (T,C)
            'count': float(gt_count),      # trial total count
            'meta': f"{subj}_{act_name}",
            'subj': subj,
            'act_id': act_id,
        })

    return trials


# ---------------------------------------------------------------------
# 2.5) Windowing
# ---------------------------------------------------------------------
def trial_list_to_windows(trial_list, fs, win_sec=8.0, stride_sec=4.0, drop_last=True):
    """
    TRAIN only: trial -> sliding windows.
    Window labels are derived from the trial's mean rate:
      rate_trial = count_total / total_duration
      count_window = rate_trial * window_duration

    A trial shorter than one window becomes a single window covering all of it.
    With drop_last=False, the samples after the last full stride are appended
    as one extra (possibly shorter) window.
    """
    win_len = int(round(win_sec * fs))
    stride = int(round(stride_sec * fs))
    assert win_len > 0 and stride > 0

    fs_f = float(fs)
    full_dur = win_len / fs_f

    out = []
    for trial in trial_list:
        x, meta = trial["data"], trial["meta"]  # x: (T,C)
        n = x.shape[0]
        rate_trial = float(trial["count"]) / max(n / fs_f, 1e-6)  # reps/s

        if n < win_len:
            out.append({
                "data": x,
                "count": rate_trial * (n / fs_f),
                "meta": f"{meta}__win[0:{n}]",
            })
            continue

        starts = list(range(0, n - win_len + 1, stride))
        out.extend(
            {
                "data": x[st:st + win_len],
                "count": rate_trial * full_dur,
                "meta": f"{meta}__win[{st}:{st + win_len}]",
            }
            for st in starts
        )

        if drop_last:
            continue

        # Tail window: starts one stride after the last full window.
        tail_st = starts[-1] + stride
        if tail_st < n:
            out.append({
                "data": x[tail_st:n],
                "count": rate_trial * ((n - tail_st) / fs_f),
                "meta": f"{meta}__win[{tail_st}:{n}]",
            })

    return out


def predict_count_by_windowing(model, x_np, fs, win_sec, stride_sec, device, tau=1.0, batch_size=64):
    """
    TEST only: trial -> window inference -> mean window rate -> total count.

    Args:
      model: callable returning a 4-tuple whose first item is the per-sample
        rate; it must also provide .eval().
      x_np: (T,C) numpy array (already normalized).
      fs: sampling rate used to convert samples to seconds.
      win_sec, stride_sec: sliding-window length and stride in seconds.
      device: device the input tensors are moved to.
      tau: forwarded to the model.
      batch_size: number of windows per forward pass.

    Returns:
      Predicted total count as a float (mean window rate * trial duration).
      A trial no longer than one window is passed through the model as a whole.
    """
    win_len = int(round(win_sec * fs))
    stride = int(round(stride_sec * fs))
    T = x_np.shape[0]
    total_dur = T / float(fs)

    # Switch to eval mode up front so the short-trial path is covered as well.
    model.eval()

    if T <= win_len:
        x_tensor = torch.tensor(x_np, dtype=torch.float32).transpose(0, 1).unsqueeze(0).to(device)  # (1,C,T)
        with torch.no_grad():
            rate_hat, _, _, _ = model(x_tensor, mask=None, tau=tau)
        return float(rate_hat.item() * total_dur)

    starts = list(range(0, T - win_len + 1, stride))
    windows = np.stack([x_np[st:st + win_len] for st in starts], axis=0)  # (N, win_len, C)
    xw = torch.tensor(windows, dtype=torch.float32).permute(0, 2, 1).to(device)  # (N,C,win_len)

    rates = []
    with torch.no_grad():
        for i in range(0, xw.shape[0], batch_size):
            r_hat, _, _, _ = model(xw[i:i + batch_size], mask=None, tau=tau)
            rates.append(r_hat.detach().cpu().numpy())

    rates = np.concatenate(rates, axis=0)  # (N,)
    rate_mean = float(rates.mean())
    return float(rate_mean * total_dur)


# ✅ [MOD-1] Added: also returns the pred_rate(t) time series
def predict_count_and_rates_by_windowing(model, x_np, fs, win_sec, stride_sec, device, tau=1.0, batch_size=64):
    """
    Windowed inference that also exposes the per-window rate series.

    Args:
      model: callable returning a 4-tuple whose first item is the per-sample
        rate; it must also provide .eval().
      x_np: (T,C) numpy array (already normalized).
      fs: sampling rate used to convert samples to seconds.
      win_sec, stride_sec: sliding-window length and stride in seconds.
      device: device the input tensors are moved to.
      tau: forwarded to the model.
      batch_size: number of windows per forward pass.

    Returns a 4-tuple:
      pred_count: float, mean window rate * total duration
      t_centers_sec: (N,) window center time in sec
      rates: (N,) predicted rate per window (reps/s)
      total_dur: float, trial duration in sec

    The ground-truth rate is NOT returned; callers can derive it as
    gt_count / total_dur. A trial no longer than one window yields N = 1.
    """
    win_len = int(round(win_sec * fs))
    stride = int(round(stride_sec * fs))
    T = x_np.shape[0]
    total_dur = T / float(fs)

    # Switch to eval mode up front so the short-trial path is covered as well.
    model.eval()

    if T <= win_len:
        x_tensor = torch.tensor(x_np, dtype=torch.float32).transpose(0, 1).unsqueeze(0).to(device)  # (1,C,T)
        with torch.no_grad():
            rate_hat, _, _, _ = model(x_tensor, mask=None, tau=tau)
        rate = float(rate_hat.item())
        pred_count = rate * total_dur
        t_centers = np.array([0.5 * total_dur], dtype=np.float32)
        rates = np.array([rate], dtype=np.float32)
        return float(pred_count), t_centers, rates, float(total_dur)

    starts = list(range(0, T - win_len + 1, stride))
    centers = np.array([st + 0.5 * win_len for st in starts], dtype=np.float32) / float(fs)

    windows = np.stack([x_np[st:st + win_len] for st in starts], axis=0)  # (N, win_len, C)
    xw = torch.tensor(windows, dtype=torch.float32).permute(0, 2, 1).to(device)  # (N,C,win_len)

    rates_list = []
    with torch.no_grad():
        for i in range(0, xw.shape[0], batch_size):
            r_hat, _, _, _ = model(xw[i:i + batch_size], mask=None, tau=tau)
            rates_list.append(r_hat.detach().cpu().numpy())

    rates = np.concatenate(rates_list, axis=0).astype(np.float32)  # (N,)
    pred_count = float(rates.mean() * total_dur)
    return float(pred_count), centers, rates, float(total_dur)


# ---------------------------------------------------------------------
# 2.8) Dataset / Collate
# ---------------------------------------------------------------------
class TrialDataset(Dataset):
    """
    Dataset over a list of trial/window dicts.

    Each item is read through its 'data', 'count' and 'meta' keys and returned
    as (data transposed to (C, T) float32, count as a float32 scalar tensor, meta).
    """

    def __init__(self, trial_list):
        self.trials = trial_list

    def __len__(self):
        return len(self.trials)

    def __getitem__(self, idx):
        record = self.trials[idx]
        # Stored as (T, C); swap the two axes -> (C, T).
        signal = torch.tensor(record['data'], dtype=torch.float32).transpose(0, 1)
        target = torch.tensor(record['count'], dtype=torch.float32)
        return signal, target, record['meta']


def collate_variable_length(batch):
    """
    Collate (data, count, meta) samples of different lengths into one batch.

    Every data tensor (C, T) is zero-padded on the right up to the longest T in
    the batch; a 0/1 mask marks which time steps are real.

    Returns a dict with keys "data", "mask", "count", "length", "meta".
    """
    n_channels = batch[0][0].shape[0]
    longest = max(sample[0].shape[1] for sample in batch)

    padded, masks, counts, metas, lengths = [], [], [], [], []
    for data, count, meta in batch:
        n_steps = data.shape[1]
        missing = longest - n_steps

        if missing > 0:
            data = torch.cat([data, torch.zeros(n_channels, missing)], dim=1)

        padded.append(data)
        # 1 for real samples, 0 for padding.
        masks.append((torch.arange(longest) < n_steps).to(torch.float32))
        counts.append(count)
        metas.append(meta)
        lengths.append(n_steps)

    return {
        "data": torch.stack(padded),              # (B, C, T_max)
        "mask": torch.stack(masks),               # (B, T_max)
        "count": torch.stack(counts),             # (B,)
        "length": torch.tensor(lengths, dtype=torch.float32),  # (B,)
        "meta": metas
    }


# ---------------------------------------------------------------------
# 3) Model
# ---------------------------------------------------------------------
class ManifoldEncoder(nn.Module):
    """
    Conv1d encoder: two kernel-5 convolutions (padding 2, ReLU after each)
    followed by a kernel-1 convolution down to latent_dim channels.

    forward maps (B, C, T) to (B, T, latent_dim).
    """

    def __init__(self, input_ch, hidden_dim=128, latent_dim=16):
        super().__init__()
        layers = [
            nn.Conv1d(input_ch, hidden_dim, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv1d(hidden_dim, hidden_dim, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv1d(hidden_dim, latent_dim, kernel_size=1),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        # net output is (B,D,T); move time in front of the latent axis -> (B,T,D)
        return self.net(x).transpose(1, 2)


class ManifoldDecoder(nn.Module):
    """
    Conv1d decoder mirroring the encoder layout: two kernel-5 convolutions
    (padding 2, ReLU after each) followed by a kernel-1 convolution to out_ch.

    forward maps latents (B, T, latent_dim) to a reconstruction (B, out_ch, T).
    """

    def __init__(self, latent_dim, hidden_dim, out_ch):
        super().__init__()
        layers = [
            nn.Conv1d(latent_dim, hidden_dim, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv1d(hidden_dim, hidden_dim, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv1d(hidden_dim, out_ch, kernel_size=1),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, z):
        # Conv1d wants channels first: (B,T,D) -> (B,D,T); output is (B,C,T)
        return self.net(z.transpose(1, 2))


class MultiRateHead(nn.Module):
    """
    Per-time-step MLP head producing an amplitude and K_max phase logits.

    forward returns (amp, phase, phase_logits):
      amp          (B,T)   softplus of the first output, so >= 0
      phase        (B,T,K) softmax of phase_logits / tau
      phase_logits (B,T,K) raw logits
    """

    def __init__(self, latent_dim=16, hidden=64, K_max=6):
        super().__init__()
        self.K_max = K_max
        layers = [
            nn.Linear(latent_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 1 + K_max),  # [amp | phase_logits...]
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, z, tau=1.0):
        raw = self.net(z)                                  # (B,T,1+K)
        amp_raw, phase_logits = raw[..., 0], raw[..., 1:]
        amp = F.softplus(amp_raw)
        phase = F.softmax(phase_logits / tau, dim=-1)
        return amp, phase, phase_logits


class KAutoCountModel(nn.Module):
    """
    Encoder / decoder / rate-head model that outputs an average repetition rate.

    forward(x, mask, tau) returns (avg_rep_rate, z, x_hat, aux):
      avg_rep_rate (B,)      time-averaged repetition rate (masked if mask given)
      z            (B,T,D)   encoder latents
      x_hat        (B,C,T)   decoder reconstruction
      aux          dict with "phase_p" (B,T,K), "rep_rate_t" (B,T), "k_hat" (B,)
    """

    def __init__(self, input_ch, hidden_dim=128, latent_dim=16, K_max=6):
        super().__init__()
        self.encoder = ManifoldEncoder(input_ch, hidden_dim, latent_dim)
        self.decoder = ManifoldDecoder(latent_dim, hidden_dim, input_ch)
        self.rate_head = MultiRateHead(latent_dim, hidden=hidden_dim, K_max=K_max)
        self._init_weights()

    def _init_weights(self):
        """Kaiming-normal weights and zero biases for every Conv1d/Linear layer,
        then set the amplitude bias of the rate head's last layer to -2."""
        for module in self.modules():
            if not isinstance(module, (nn.Conv1d, nn.Linear)):
                continue
            nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)

        with torch.no_grad():
            head_bias = self.rate_head.net[-1].bias
            head_bias.zero_()
            # Index 0 is the amplitude output (fed through softplus in the head).
            head_bias[0].fill_(-2.0)

    @staticmethod
    def _masked_mean_time(x, mask=None, eps=1e-6):
        """Mean over dim 1 of a 2-D or 3-D tensor, restricted to mask == 1 when a mask is given."""
        if mask is None:
            return x.mean(dim=1)
        if x.dim() not in (2, 3):
            raise ValueError(f"Unsupported dim for masked mean: {x.dim()}")

        weights = mask.to(dtype=x.dtype, device=x.device)
        if x.dim() == 3:
            weights = weights.unsqueeze(-1)
        return (x * weights).sum(dim=1) / (weights.sum(dim=1) + eps)

    def forward(self, x, mask=None, tau=1.0):
        z = self.encoder(x)              # (B,T,D)
        x_hat = self.decoder(z)          # (B,C,T)

        # The amplitude is used directly as the micro-event rate (B,T).
        amp_t, phase_p, _ = self.rate_head(z, tau=tau)

        # Effective number of phases: inverse sum of squared mean phase probabilities.
        p_bar = self._masked_mean_time(phase_p, mask)           # (B,K)
        k_hat = 1.0 / (p_bar.pow(2).sum(dim=1) + 1e-6)          # (B,)

        rep_rate_t = amp_t / (k_hat.unsqueeze(1) + 1e-6)        # (B,T)
        if mask is None:
            avg_rep_rate = rep_rate_t.mean(dim=1)
        else:
            rep_rate_t = rep_rate_t * mask
            avg_rep_rate = (rep_rate_t * mask).sum(dim=1) / (mask.sum(dim=1) + 1e-6)

        aux = {
            "phase_p": phase_p,          # (B,T,K)
            "rep_rate_t": rep_rate_t,    # (B,T)
            "k_hat": k_hat,              # (B,)
        }
        return avg_rep_rate, z, x_hat, aux


# ---------------------------------------------------------------------
# 4) Loss utils
# ---------------------------------------------------------------------
def masked_recon_mse(x_hat, x, mask, eps=1e-6):
    """
    Mean squared reconstruction error over the valid (mask == 1) time steps.

    x_hat, x: (B,C,T); mask: (B,T). The squared-error sum is divided by
    (number of valid steps * C) + eps.
    """
    valid = mask.to(dtype=x.dtype, device=x.device)
    sq_err = (x_hat - x) ** 2 * valid.unsqueeze(1)   # mask broadcast as (B,1,T)
    n_valid = valid.sum() * x.shape[1]
    return sq_err.sum() / (n_valid + eps)


def temporal_smoothness(v, mask=None, eps=1e-6):
    """
    Mean absolute difference between consecutive entries along dim 1.

    With a mask, a difference only counts when both of its neighbouring steps
    are valid; the sum is then divided by (number of such pairs + eps).
    """
    step = (v[:, 1:] - v[:, :-1]).abs()
    if mask is None:
        return step.mean()
    pair_valid = (mask[:, 1:] * mask[:, :-1]).to(dtype=step.dtype, device=step.device)
    return (step * pair_valid).sum() / (pair_valid.sum() + eps)


def phase_entropy_loss(phase_p, mask=None, eps=1e-8):
    """
    Average entropy of the phase distribution along the last axis.

    phase_p: (B,T,K). Without a mask the plain mean over (B,T) is returned;
    with a mask the entropies are weighted by it and divided by mask.sum() + eps.
    """
    log_p = (phase_p + eps).log()
    ent = -(phase_p * log_p).sum(dim=-1)  # (B,T)
    if mask is not None:
        return (ent * mask).sum() / (mask.sum() + eps)
    return ent.mean()


def effK_usage_loss(phase_p, mask=None, eps=1e-6):
    """
    Batch mean of the effective number of phases in use.

    The phase probabilities (B,T,K) are averaged over time (masked when a mask
    is given) and each sample is scored as 1 / (sum_k p_bar_k^2 + eps).
    """
    if mask is not None:
        weights = mask.to(dtype=phase_p.dtype, device=phase_p.device).unsqueeze(-1)
        p_bar = (phase_p * weights).sum(dim=1) / (weights.sum(dim=1) + eps)
    else:
        p_bar = phase_p.mean(dim=1)
    concentration = p_bar.pow(2).sum(dim=1)
    return (1.0 / (concentration + eps)).mean()


def compute_phase_entropy_mean(phase_p_np, eps=1e-8):
    """Return the mean per-row entropy of a (T,K) array of phase probabilities as a float."""
    probs = np.asarray(phase_p_np, dtype=np.float32)  # (T,K)
    per_step = -np.sum(probs * np.log(probs + eps), axis=1)
    return float(per_step.mean())


# ---------------------------------------------------------------------
# 4.5) A1 Plots
# ---------------------------------------------------------------------
def plot_pca2d_train_test_overlay(
    train_z_tD, train_t_sec,
    test_z_tD, test_t_sec,
    save_path,
    dpi=600,
    max_points=2000,
    show_colorbar=True,
    tick_pad=0,
    point_size=10,
    alpha_test=0.9,
    alpha_train=0.35,
    train_lw=2.0,
    # ✅ Fixed axes (intended reference: A1_PCA_train7_test8_subject5)
    fixed_xlim=None,   # (xmin, xmax)
    fixed_ylim=None    # (ymin, ymax)
):
    """
    ✅ PCA basis is FIT on TRAIN, then BOTH train/test are projected into the SAME basis.
    ✅ Single-axis OVERLAY:
       - TRAIN: gray scatter (reference manifold)
       - TEST : time-colored scatter + single colorbar (time)

    Args:
      train_z_tD, train_t_sec: train latents and their time stamps.
      test_z_tD, test_t_sec: test latents and their time stamps.
      save_path: output image path; its parent directory is created if needed.
      dpi, max_points, show_colorbar, tick_pad, point_size: figure controls.
      alpha_test: opacity of the test scatter.
      alpha_train, train_lw: accepted for interface compatibility; the drawing
        code currently does not use them (train opacity is fixed at 0.25).
      fixed_xlim, fixed_ylim: optional (min, max) per axis. Each one is honoured
        on its own; an axis without a fixed limit is auto-scaled to the pooled
        train + test range plus 5% padding.
    """
    train_z_tD = np.asarray(train_z_tD, dtype=np.float32)
    train_t_sec = np.asarray(train_t_sec, dtype=np.float32)
    test_z_tD  = np.asarray(test_z_tD,  dtype=np.float32)
    test_t_sec = np.asarray(test_t_sec, dtype=np.float32)

    train_z_tD, train_t_sec = _downsample_by_index(train_z_tD, train_t_sec, max_points)
    test_z_tD,  test_t_sec  = _downsample_by_index(test_z_tD,  test_t_sec,  max_points)

    # Test time normalized to [0, 1]; the max() guards against a zero span.
    test_c = (test_t_sec - test_t_sec.min()) / max(1e-6, (test_t_sec.max() - test_t_sec.min()))

    mu, sig, W = _fit_pca_basis(train_z_tD, n_comp=2)
    P_tr = _pca_project(train_z_tD, mu, sig, W)   # (T,2)
    P_te = _pca_project(test_z_tD,  mu, sig, W)   # (T,2)

    fig, ax = plt.subplots(1, 1, figsize=(7.0, 6.0))

    ax.scatter(P_tr[:, 0], P_tr[:, 1], s=point_size, alpha=0.25, color="0.6", label="Train (ref)")
    sc = ax.scatter(P_te[:, 0], P_te[:, 1], c=test_c, s=point_size, alpha=alpha_test, cmap="viridis", label="Test (proj)")

    ax.set_xlabel("PC1", fontsize=16)
    ax.set_ylabel("PC2", fontsize=16)
    ax.tick_params(axis="both", which="major", labelsize=14, pad=tick_pad)
    ax.grid(False)

    def _resolve_limits(fixed, column):
        # A fixed limit wins; otherwise use the pooled data range with 5% padding.
        if fixed is not None:
            return float(fixed[0]), float(fixed[1])
        pooled = np.concatenate([P_tr[:, column], P_te[:, column]], axis=0)
        lo, hi = float(pooled.min()), float(pooled.max())
        pad = 0.05 * max(1e-6, (hi - lo))
        return lo - pad, hi + pad

    # ✅ xlim/ylim: each axis uses its fixed limit when one is supplied.
    ax.set_xlim(*_resolve_limits(fixed_xlim, 0))
    ax.set_ylim(*_resolve_limits(fixed_ylim, 1))

    if show_colorbar:
        divider = make_axes_locatable(ax)
        cax = divider.append_axes("right", size="4.5%", pad=0.10)
        cbar = fig.colorbar(sc, cax=cax)
        cbar.set_label("Normalized time (0→1)", fontsize=14, labelpad=10)
        cbar.ax.tick_params(labelsize=12, pad=tick_pad)
        cbar.outline.set_linewidth(1.2)

    _ensure_dir(os.path.dirname(save_path))
    fig.savefig(save_path, dpi=dpi, bbox_inches="tight")
    plt.close(fig)


def plot_predrate_vs_gtrate(
    t_sec,
    pred_rate,
    gt_rate,
    save_path,
    dpi=600,
    tick_pad=0,
    smooth_sigma=1.0,
    lw_pred=4.0,
    lw_gt=4.0,
    legend_loc="best"
):
    """
    Draw the predicted repetition-rate curve against a constant ground-truth
    rate and save the figure to ``save_path``.

    t_sec       : x-axis values for the predicted curve (converted to float32)
    pred_rate   : (N,) window-level predicted repetition rate (reps/s)
    gt_rate     : float, trial-level constant rate = count / duration
    save_path   : output image path; its parent directory is created if needed
    dpi         : resolution passed to ``savefig``
    tick_pad    : padding between tick marks and tick labels
    smooth_sigma: Gaussian sigma applied to ``pred_rate`` before plotting;
                  ``None`` or a value <= 0 plots the raw curve
    lw_pred / lw_gt : line widths of the predicted curve / ground-truth line
    legend_loc  : matplotlib legend location

    Returns nothing; the figure is closed after saving.
    """
    times = np.asarray(t_sec, dtype=np.float32)
    curve = np.asarray(pred_rate, dtype=np.float32)

    # Optional smoothing: only applied for a strictly positive sigma.
    wants_smoothing = smooth_sigma is not None and smooth_sigma > 0
    if wants_smoothing:
        curve = gaussian_filter1d(curve, sigma=float(smooth_sigma))

    fig, ax = plt.subplots(1, 1, figsize=(8.5, 5.5))

    # Predicted curve, then the ground-truth rate as a dashed horizontal line.
    ax.plot(times, curve, linewidth=lw_pred, label="Pred rate (window-level)")
    ax.axhline(y=float(gt_rate), linestyle="--", linewidth=lw_gt, label="GT rate (trial-level)")

    # Axis cosmetics (no title, large fonts, no grid).
    ax.set_xlabel("Time (s)", fontsize=34)
    ax.set_ylabel("Repetition rate (reps/s)", fontsize=34)
    ax.tick_params(axis="both", which="major", labelsize=24, pad=tick_pad)
    ax.grid(False)
    ax.legend(loc=legend_loc, frameon=False, fontsize=18)

    out_dir = os.path.dirname(save_path)
    _ensure_dir(out_dir)
    fig.savefig(save_path, dpi=dpi, bbox_inches="tight")
    plt.close(fig)


# ---------------------------------------------------------------------
# 5) Train
# ---------------------------------------------------------------------
def train_one_epoch(model, loader, optimizer, config, device):
    """
    Run one optimisation pass over ``loader``, updating ``model`` in place.

    Each batch is expected to be a dict with "data", "mask", "count" and
    "length" entries. The count label is converted to a rate target
    (count / duration, with duration = length / fs clamped to >= 1e-6) and the
    total loss is the rate MSE plus weighted reconstruction, smoothness,
    phase-entropy and effective-K terms.

    config keys read: "fs" (required), "tau", "lambda_recon", "lambda_smooth",
    "lambda_phase_ent", "lambda_effk" (each with a default).

    Returns nothing.
    """
    model.train()
    sample_rate = config["fs"]
    temperature = config.get("tau", 1.0)

    # Loss weights; the defaults apply only when the key is absent.
    w_recon = config.get("lambda_recon", 1.0)
    w_smooth = config.get("lambda_smooth", 0.05)
    w_entropy = config.get("lambda_phase_ent", 0.01)
    w_effk = config.get("lambda_effk", 0.005)

    for sample in loader:
        signal = sample["data"].to(device)
        valid = sample["mask"].to(device)
        target_count = sample["count"].to(device)
        n_samples = sample["length"].to(device)

        # Count label -> rate label; the clamp guards against zero-length items.
        target_rate = target_count / torch.clamp(n_samples / sample_rate, min=1e-6)

        optimizer.zero_grad()
        rate_pred, _, recon, extras = model(signal, valid, tau=temperature)

        rate_term = F.mse_loss(rate_pred, target_rate)
        recon_term = masked_recon_mse(recon, signal, valid)
        smooth_term = temporal_smoothness(extras["rep_rate_t"], valid)
        phase_probs = extras["phase_p"]
        entropy_term = phase_entropy_loss(phase_probs, valid)
        effk_term = effK_usage_loss(phase_probs, valid)

        # Accumulate in the same left-to-right order as a single chained sum.
        total = rate_term + w_recon * recon_term
        total = total + w_smooth * smooth_term
        total = total + w_entropy * entropy_term
        total = total + w_effk * effk_term

        total.backward()
        optimizer.step()


# ---------------------------------------------------------------------
# 6) Main (SUCCESS + FAILURE only)
# ---------------------------------------------------------------------
def main():
    """
    Run the mHealth pairwise activity-transfer experiment and print a report.

    For every (train_act, test_act) pair in ``exp_dirs`` the function:
      1. builds per-subject labels from CONFIG["COUNT_TABLE"],
      2. windows the training trials and trains a fresh KAutoCountModel,
      3. predicts a count for every test trial via windowed inference and
         records the absolute error plus full-trial k_hat / phase entropy,
      4. picks one subject (lowest error for the 7->8 pair, highest error
         otherwise) and saves a PCA overlay plot and a rate plot for it,
      5. appends a result block that is printed at the end together with the
         average MAE over experiments.

    Returns None. Returns early (after printing a message) when the dataset
    cannot be loaded or when no experiment produced results.
    """
    CONFIG = {
        "seed": 42,
        "data_dir": "/content/drive/MyDrive/Colab Notebooks/HAR_data/MHEALTHDATASET",
        "fs": 50,  # samples per second; used to convert lengths to durations
        "epochs": 100,
        "lr": 5e-4,
        "batch_size": 64,

        # windowing
        "win_sec": 8.0,
        "stride_sec": 4.0,
        "drop_last": True,

        # model
        "hidden_dim": 128,
        "latent_dim": 16,
        "K_max": 6,

        # loss
        "lambda_recon": 1.0,
        "lambda_smooth": 0.05,
        "lambda_phase_ent": 0.01,
        "lambda_effk": 0.0075,

        "tau": 1.0,

        # outputs
        "plot_dir": "/content/A1_pairwise_A1viz",

        # column names passed to the dataset loader
        "COLUMN_NAMES": [
            'acc_chest_x', 'acc_chest_y', 'acc_chest_z',
            'ecg_1', 'ecg_2',
            'acc_ankle_x', 'acc_ankle_y', 'acc_ankle_z',
            'gyro_ankle_x', 'gyro_ankle_y', 'gyro_ankle_z',
            'mag_ankle_x', 'mag_ankle_y', 'mag_ankle_z',
            'acc_arm_x', 'acc_arm_y', 'acc_arm_z',
            'gyro_arm_x', 'gyro_arm_y', 'gyro_arm_z',
            'mag_arm_x', 'mag_arm_y', 'mag_arm_z',
            'activity_id'
        ],
        # activity id -> display name
        "TARGET_ACTIVITIES_MAP": {
            6: 'Waist bends forward',
            7: 'Frontal elevation of arms',
            8: 'Knees bending',
            10: 'Jogging',
            11: 'Running',
            12: 'Jump front & back'
        },

        # activity id -> feature columns; must be the same list for every
        # activity (enforced by the sanity check below)
        "ACT_FEATURE_MAP": {
            6: ['acc_chest_x', 'acc_chest_y', 'acc_chest_z',
                'acc_ankle_x', 'acc_ankle_y', 'acc_ankle_z',
                'gyro_ankle_x', 'gyro_ankle_y', 'gyro_ankle_z',
                'acc_arm_x', 'acc_arm_y', 'acc_arm_z',
                'gyro_arm_x', 'gyro_arm_y', 'gyro_arm_z'],
            7: ['acc_chest_x', 'acc_chest_y', 'acc_chest_z',
                'acc_ankle_x', 'acc_ankle_y', 'acc_ankle_z',
                'gyro_ankle_x', 'gyro_ankle_y', 'gyro_ankle_z',
                'acc_arm_x', 'acc_arm_y', 'acc_arm_z',
                'gyro_arm_x', 'gyro_arm_y', 'gyro_arm_z'],
            8: ['acc_chest_x', 'acc_chest_y', 'acc_chest_z',
                'acc_ankle_x', 'acc_ankle_y', 'acc_ankle_z',
                'gyro_ankle_x', 'gyro_ankle_y', 'gyro_ankle_z',
                'acc_arm_x', 'acc_arm_y', 'acc_arm_z',
                'gyro_arm_x', 'gyro_arm_y', 'gyro_arm_z'],
            10: ['acc_chest_x', 'acc_chest_y', 'acc_chest_z',
                 'acc_ankle_x', 'acc_ankle_y', 'acc_ankle_z',
                 'gyro_ankle_x', 'gyro_ankle_y', 'gyro_ankle_z',
                 'acc_arm_x', 'acc_arm_y', 'acc_arm_z',
                 'gyro_arm_x', 'gyro_arm_y', 'gyro_arm_z'],
            11: ['acc_chest_x', 'acc_chest_y', 'acc_chest_z',
                 'acc_ankle_x', 'acc_ankle_y', 'acc_ankle_z',
                 'gyro_ankle_x', 'gyro_ankle_y', 'gyro_ankle_z',
                 'acc_arm_x', 'acc_arm_y', 'acc_arm_z',
                 'gyro_arm_x', 'gyro_arm_y', 'gyro_arm_z'],
            12: ['acc_chest_x', 'acc_chest_y', 'acc_chest_z',
                 'acc_ankle_x', 'acc_ankle_y', 'acc_ankle_z',
                 'gyro_ankle_x', 'gyro_ankle_y', 'gyro_ankle_z',
                 'acc_arm_x', 'acc_arm_y', 'acc_arm_z',
                 'gyro_arm_x', 'gyro_arm_y', 'gyro_arm_z'],
        },

        # activity id -> {subject -> count}; these values become the
        # per-trial count labels via build_labels() below
        "COUNT_TABLE": {
            6: {
                "subject1": 21, "subject2": 19, "subject3": 21, "subject4": 20, "subject5": 20,
                "subject6": 20, "subject7": 20, "subject8": 21, "subject9": 21, "subject10": 20,
            },
            7: {
                "subject1": 20, "subject2": 20, "subject3": 20, "subject4": 20, "subject5": 20,
                "subject6": 20, "subject7": 20, "subject8": 19, "subject9": 19, "subject10": 20,
            },
            8: {
                "subject1": 20, "subject2": 21, "subject3": 21, "subject4": 19, "subject5": 20,
                "subject6": 20, "subject7": 21, "subject8": 21, "subject9": 21, "subject10": 21,
            },
            10: {
                "subject1": 157, "subject2": 161, "subject3": 154, "subject4": 154, "subject5": 160,
                "subject6": 156, "subject7": 153, "subject8": 160, "subject9": 166, "subject10": 156,
            },
            11: {
                "subject1": 165, "subject2": 158, "subject3": 174, "subject4": 163, "subject5": 157,
                "subject6": 172, "subject7": 149, "subject8": 166, "subject9": 174, "subject10": 172,
            },
            12: {
                "subject1": 20, "subject2": 22, "subject3": 21, "subject4": 21, "subject5": 20,
                "subject6": 21, "subject7": 19, "subject8": 20, "subject9": 20, "subject10": 20,
            },
        }
    }

    # -----------------------------
    # sanity: every activity must use the identical feature list
    # -----------------------------
    feats_ref = None
    for act_id, feats in CONFIG["ACT_FEATURE_MAP"].items():
        if feats_ref is None:
            feats_ref = tuple(feats)
        elif tuple(feats) != feats_ref:
            raise ValueError(f"[ERROR] ACT_FEATURE_MAP must be identical across activities. mismatch at act_id={act_id}")

    set_strict_seed(CONFIG["seed"])
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    full_data = load_mhealth_dataset(CONFIG["data_dir"], CONFIG["TARGET_ACTIVITIES_MAP"], CONFIG["COLUMN_NAMES"])
    if not full_data:
        print("[ERROR] dataset load failed")
        return

    _ensure_dir(CONFIG["plot_dir"])

    subjects = [f"subject{i}" for i in range(1, 11)]

    # Directed (train activity id, test activity id) pairs to evaluate.
    exp_dirs = [
        (7, 8),
        (8, 10),
    ]

    def build_labels(act_id):
        # Returns [(subject, act_id, count), ...] for the subjects that have
        # an entry in COUNT_TABLE[act_id]; empty list if the activity is absent.
        labels = []
        for s in subjects:
            if act_id not in CONFIG["COUNT_TABLE"]:
                continue
            if s not in CONFIG["COUNT_TABLE"][act_id]:
                continue
            labels.append((s, act_id, CONFIG["COUNT_TABLE"][act_id][s]))
        return labels

    # ✅ Reference axis limits from A1_PCA_train7_test8_subject5
    REF_SUBJ = "subject5"
    REF_XLIM, REF_YLIM = None, None

    final_blocks = []
    exp_idx = 0

    for train_act, test_act in exp_dirs:
        # exp_idx counts every pair, including ones that are skipped below.
        exp_idx += 1

        train_labels = build_labels(train_act)
        test_labels  = build_labels(test_act)
        if len(train_labels) == 0 or len(test_labels) == 0:
            print(f"[Skip] Missing COUNT_TABLE: train_act={train_act}, test_act={test_act}")
            continue

        # Re-seed per experiment so each pair starts from the same RNG state.
        set_strict_seed(CONFIG["seed"])

        train_trials = prepare_trial_list(train_labels, full_data, CONFIG["TARGET_ACTIVITIES_MAP"], CONFIG["ACT_FEATURE_MAP"])
        test_trials  = prepare_trial_list(test_labels,  full_data, CONFIG["TARGET_ACTIVITIES_MAP"], CONFIG["ACT_FEATURE_MAP"])
        if len(train_trials) == 0 or len(test_trials) == 0:
            print(f"[Skip] Empty trials: train_act={train_act}, test_act={test_act}")
            continue

        train_windows = trial_list_to_windows(
            train_trials,
            fs=CONFIG["fs"],
            win_sec=CONFIG["win_sec"],
            stride_sec=CONFIG["stride_sec"],
            drop_last=CONFIG["drop_last"]
        )
        if len(train_windows) == 0:
            print(f"[Skip] Empty train windows: train_act={train_act}")
            continue

        # Dedicated, seeded generator so the shuffle order is reproducible.
        g = torch.Generator()
        g.manual_seed(CONFIG["seed"])
        train_loader = DataLoader(
            TrialDataset(train_windows),
            batch_size=CONFIG["batch_size"],
            shuffle=True,
            collate_fn=collate_variable_length,
            generator=g,
            num_workers=0
        )

        # Channel count is read from axis 1 of the first window's data.
        input_ch = train_windows[0]["data"].shape[1]
        model = KAutoCountModel(
            input_ch=input_ch,
            hidden_dim=CONFIG["hidden_dim"],
            latent_dim=CONFIG["latent_dim"],
            K_max=CONFIG["K_max"]
        ).to(device)

        optimizer = torch.optim.Adam(model.parameters(), lr=CONFIG["lr"])
        # Learning rate is multiplied by 0.5 every 30 scheduler steps (one step per epoch).
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.5)

        for epoch in range(CONFIG["epochs"]):
            train_one_epoch(model, train_loader, optimizer, CONFIG, device)
            scheduler.step()

        model.eval()

        train_name = CONFIG["TARGET_ACTIVITIES_MAP"][train_act]
        test_name  = CONFIG["TARGET_ACTIVITIES_MAP"][test_act]

        block_lines = []
        abs_errs = []
        per_subj = {}

        for item in test_trials:
            subj = item["subj"]
            x_np = item["data"]
            gt = float(item["count"])

            # NOTE(review): `pred - gt` below needs a scalar return value. The
            # MM-Fit cell later in this file defines predict_count_by_windowing
            # returning (count, rates) — confirm which definition is live here.
            pred = predict_count_by_windowing(
                model,
                x_np=x_np,
                fs=CONFIG["fs"],
                win_sec=CONFIG["win_sec"],
                stride_sec=CONFIG["stride_sec"],
                device=device,
                tau=CONFIG.get("tau", 1.0),
                batch_size=CONFIG.get("batch_size", 64)
            )

            diff = pred - gt
            ae = abs(diff)
            abs_errs.append(ae)
            per_subj[subj] = {"ae": ae, "pred": pred, "gt": gt, "x_np": x_np}

            # Second forward pass on the whole trial (no windowing) to read the
            # auxiliary outputs reported as "(full)" in the result line.
            x_tensor = torch.tensor(x_np, dtype=torch.float32).transpose(0, 1).unsqueeze(0).to(device)
            with torch.no_grad():
                _, _, _, aux = model(x_tensor, mask=None, tau=CONFIG.get("tau", 1.0))

            phase_p = aux["phase_p"].squeeze(0).detach().cpu().numpy()
            k_hat = float(aux["k_hat"].item())
            ent = compute_phase_entropy_mean(phase_p)

            line = (f"{subj} | GT={gt:.0f} | Pred(win)={pred:.2f} | Diff={diff:+.2f} | "
                    f"k_hat(full)={k_hat:.2f} | phase_entropy(full)={ent:.3f}")
            block_lines.append(line)

        mae = float(np.mean(abs_errs)) if len(abs_errs) > 0 else float("nan")

        # Subject to visualise: best case (smallest error) for the 7->8 pair,
        # worst case (largest error) for every other pair.
        is_success_pair = (train_act == 7 and test_act == 8)
        if is_success_pair:
            pick_subj = sorted(per_subj.keys(), key=lambda s: per_subj[s]["ae"])[0]
        else:
            pick_subj = sorted(per_subj.keys(), key=lambda s: per_subj[s]["ae"], reverse=True)[0]

        # ------------------
        # ✅ REF axis set once: based on A1_PCA_train7_test8_subject5
        # ------------------
        if REF_XLIM is None and REF_YLIM is None and (train_act == 7 and test_act == 8):
            ref_train_item = next((it for it in train_trials if it["subj"] == REF_SUBJ), None)
            ref_test_item  = next((it for it in test_trials  if it["subj"] == REF_SUBJ), None)

            # fallback (only if subject5 not found)
            if ref_train_item is None or ref_test_item is None:
                ref_train_item = next((it for it in train_trials if it["subj"] == pick_subj), None)
                ref_test_item  = next((it for it in test_trials  if it["subj"] == pick_subj), None)

            if ref_train_item is not None and ref_test_item is not None:
                fs = CONFIG["fs"]
                xtr_ref = torch.tensor(ref_train_item["data"], dtype=torch.float32).transpose(0, 1).unsqueeze(0).to(device)
                xte_ref = torch.tensor(ref_test_item["data"],  dtype=torch.float32).transpose(0, 1).unsqueeze(0).to(device)

                with torch.no_grad():
                    ztr_ref = model.encoder(xtr_ref).squeeze(0).detach().cpu().numpy()
                    zte_ref = model.encoder(xte_ref).squeeze(0).detach().cpu().numpy()

                REF_XLIM, REF_YLIM = _compute_ref_xlim_ylim_from_train_test_latents(ztr_ref, zte_ref, pad_ratio=0.05)
                # NOTE(review): this message always names REF_SUBJ, even when
                # the fallback subject above was the one actually used.
                print(f"[REF AXIS] Using A1_PCA_train7_test8_{REF_SUBJ} xlim={REF_XLIM}, ylim={REF_YLIM}")

        # ------------------
        # A1 visualization for picked subject
        # ------------------
        train_item = next((it for it in train_trials if it["subj"] == pick_subj), None)
        test_item  = next((it for it in test_trials  if it["subj"] == pick_subj), None)

        if train_item is not None and test_item is not None:
            fs = CONFIG["fs"]

            xtr = torch.tensor(train_item["data"], dtype=torch.float32).transpose(0, 1).unsqueeze(0).to(device)
            xte = torch.tensor(test_item["data"],  dtype=torch.float32).transpose(0, 1).unsqueeze(0).to(device)

            with torch.no_grad():
                ztr = model.encoder(xtr).squeeze(0).detach().cpu().numpy()
                zte = model.encoder(xte).squeeze(0).detach().cpu().numpy()

            # Time stamps: index along axis 0 of the latents divided by fs.
            ttr = np.arange(ztr.shape[0], dtype=np.float32) / float(fs)
            tte = np.arange(zte.shape[0], dtype=np.float32) / float(fs)

            pca_path = os.path.join(
                CONFIG["plot_dir"],
                f"A1_PCA_train{train_act}_test{test_act}_{pick_subj}.png"
            )

            plot_pca2d_train_test_overlay(
                train_z_tD=ztr, train_t_sec=ttr,
                test_z_tD=zte,  test_t_sec=tte,
                save_path=pca_path,
                dpi=600,
                max_points=2000,
                show_colorbar=True,
                tick_pad=0,
                point_size=10,
                alpha_test=0.9,
                alpha_train=0.35,
                train_lw=2.0,
                # ✅ apply fixed axis (based on subject5); both stay None if
                # the reference block above never ran
                fixed_xlim=REF_XLIM,
                fixed_ylim=REF_YLIM
            )

            x_np = test_item["data"]
            gt_count = float(test_item["count"])
            pred_count, t_centers, rates, total_dur = predict_count_and_rates_by_windowing(
                model,
                x_np=x_np,
                fs=fs,
                win_sec=CONFIG["win_sec"],
                stride_sec=CONFIG["stride_sec"],
                device=device,
                tau=CONFIG.get("tau", 1.0),
                batch_size=CONFIG.get("batch_size", 64)
            )
            # Trial-level constant rate; the max() guards against a zero duration.
            gt_rate = gt_count / max(1e-6, total_dur)

            rate_path = os.path.join(
                CONFIG["plot_dir"],
                f"A1_Rate_train{train_act}_test{test_act}_{pick_subj}.png"
            )
            plot_predrate_vs_gtrate(
                t_sec=t_centers,
                pred_rate=rates,
                gt_rate=gt_rate,
                save_path=rate_path,
                dpi=600,
                tick_pad=0,
                smooth_sigma=1.0
            )

        final_blocks.append({
            "idx": exp_idx,
            "title": f"{train_name} -> {test_name}",
            "lines": block_lines,
            "mae": mae,
            "picked_subject_for_plot": pick_subj,
            "plot_dir": CONFIG["plot_dir"]
        })

    # ------------------
    # Final report
    # ------------------
    print("\n" + "=" * 110)
    print("A1 Pairwise Activity Transfer Results (SUCCESS + FAILURE only, windowing inference)")
    print("=" * 110)

    if len(final_blocks) == 0:
        print("[No results] COUNT_TABLE이 충분히 채워졌는지 확인해줘.")
        print("=" * 110)
        return

    for b in final_blocks:
        print(f"\n{b['idx']}. {b['title']}  |  MAE={b['mae']:.3f}")
        print(f"   -> A1 plots saved for: {b['picked_subject_for_plot']}  |  dir: {b['plot_dir']}")
        for ln in b["lines"]:
            print(ln)

    # Aggregate only over experiments whose MAE is finite.
    maes = [b["mae"] for b in final_blocks if np.isfinite(b["mae"])]
    if len(maes) > 0:
        print("\n" + "-" * 110)
        print(f"Overall Avg MAE={float(np.mean(maes)):.3f} | Std={float(np.std(maes)):.3f} | #experiments={len(maes)}")
        print("-" * 110)

    print("=" * 110)


# Entry point: run the experiment only when this code is executed directly
# (not when it is imported as a module).
if __name__ == "__main__":
    main()


[REF AXIS] Using A1_PCA_train7_test8_subject5 xlim=(-4.4326539268493645, 4.755654520034789), ylim=(-4.507941735863685, 3.118046556591987)

A1 Pairwise Activity Transfer Results (SUCCESS + FAILURE only, windowing inference)

1. Frontal elevation of arms -> Knees bending  |  MAE=2.275
   -> A1 plots saved for: subject5  |  dir: /content/A1_pairwise_A1viz
subject1 | GT=20 | Pred(win)=24.42 | Diff=+4.42 | k_hat(full)=1.00 | phase_entropy(full)=0.014
subject2 | GT=21 | Pred(win)=23.17 | Diff=+2.17 | k_hat(full)=1.01 | phase_entropy(full)=0.039
subject3 | GT=21 | Pred(win)=22.01 | Diff=+1.01 | k_hat(full)=1.01 | phase_entropy(full)=0.030
subject4 | GT=19 | Pred(win)=19.86 | Diff=+0.86 | k_hat(full)=1.01 | phase_entropy(full)=0.033
subject5 | GT=20 | Pred(win)=19.23 | Diff=-0.77 | k_hat(full)=1.01 | phase_entropy(full)=0.035
subject6 | GT=20 | Pred(win)=15.92 | Diff=-4.08 | k_hat(full)=1.01 | phase_entropy(full)=0.017
subject7 | GT=21 | Pred(win)=19.29 | Diff=-1.71 | k_hat(full)=1.02 | phase_

# for mm-Fit
- 위 1번 3번과 동일한 시각화

In [3]:
# =========================
# ✅ MM-Fit DROP-IN (Scenario-wise LOSO) — FULL CODE
#
# Goal:
# - Replace mHealth data pipeline with MM-Fit
# - Test ALL directed A->B pairs among 3 exercises => 3*2 = 6 scenarios
# - Keep visualization EXACTLY as-is (DO NOT TOUCH plot design)
#
# Protocol (same spirit as your mHealth scenario code):
# - For each scenario (A->B):
#   - TRAIN: single-activity windows from exercise A, subjects except held-out
#   - TEST (LOSO on A): held-out subject's ALL A trials -> fold MAE averaged across its A trials
#   - SCENARIO eval: build ONE mixed trial for held-out subject by concatenating:
#       (selected A trial) + (selected B trial)
#     then run windowing inference & do latent forward on mixed
#
# Notes:
# - MM-Fit subjects (participant) may have different #trials per exercise.
# - For LOSO(A) fold MAE: average over ALL A trials of held-out subject.
# - For mixed scenario: pick ONE deterministic trial per exercise per subject (the first by sort order).
# - If a subject lacks A or B trials, that fold is skipped for that scenario.
# =========================

import os
import glob
import random
import numpy as np
import pandas as pd
import seaborn as sns  # kept
import matplotlib.pyplot as plt
from scipy.ndimage import gaussian_filter1d  # kept
from matplotlib.patches import Circle

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from matplotlib.collections import LineCollection


# ---------------------------------------------------------------------
# ✅ Save helpers
# ---------------------------------------------------------------------
def _ensure_dir(path: str):
    """Create directory ``path`` (with parents) unless it is None or empty.

    An already existing directory is not an error.
    """
    if path is not None and path != "":
        os.makedirs(path, exist_ok=True)

def _safe_fname(s: str):
    """Turn an arbitrary value into a string usable as a file name.

    Path separators, wildcard/reserved characters, newlines and tabs become
    underscores; surrounding whitespace is stripped; spaces become
    underscores; runs of underscores collapse to a single one.
    """
    forbidden = '/\\:*?"<>|\n\t'
    text = str(s).translate(str.maketrans(forbidden, '_' * len(forbidden)))
    text = text.strip().replace(' ', '_')

    # Collapse consecutive underscores by skipping any '_' that follows a '_'.
    kept = []
    for ch in text:
        if ch == '_' and kept and kept[-1] == '_':
            continue
        kept.append(ch)
    return ''.join(kept)


# ---------------------------------------------------------------------
# 1) Strict Seeding
# ---------------------------------------------------------------------
def set_strict_seed(seed: int):
    """Seed Python, NumPy and PyTorch RNGs and request deterministic cuDNN.

    Also exports ``PYTHONHASHSEED`` and, when CUDA is available, seeds the
    current and all CUDA devices.
    """
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    for reseed in (np.random.seed, torch.manual_seed):
        reseed(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    # Deterministic kernels on, autotuner off.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


# ---------------------------------------------------------------------
# 2) MM-Fit Data Loading (meta + npz)
# ---------------------------------------------------------------------
def load_mmfit_meta(meta_csv_path: str,
                    target_exercises: set,
                    only_reps=None,
                    require_device: str = "sw_r"):
    """
    Load the MM-Fit trial index (meta.csv) and keep only the usable rows.

    Expected columns:
      participant, session, set_id, exercise, reps, fs, npz_path
    Optional column:
      device

    Filtering steps, in order:
      1. rows whose participant is not numeric are dropped,
      2. if a "device" column exists and ``require_device`` is not None, only
         rows with that device are kept,
      3. only rows whose exercise is in ``target_exercises`` are kept
         (comparison is case- and whitespace-insensitive),
      4. if ``only_reps`` is given, only rows whose rounded reps equals
         ``int(only_reps)`` are kept (rows with missing/non-numeric reps are
         dropped by this filter instead of raising),
      5. rows whose ``npz_path`` does not exist on disk are dropped.

    Returns:
      DataFrame sorted by (participant, session, exercise, set_id) with a
      fresh 0..N-1 index; may be empty.

    Raises:
      FileNotFoundError: ``meta_csv_path`` does not exist.
      ValueError: a required column is missing.
    """
    if not os.path.exists(meta_csv_path):
        raise FileNotFoundError(f"[MM-Fit] meta csv not found: {meta_csv_path}")

    meta = pd.read_csv(meta_csv_path)

    need_cols = ["participant", "session", "set_id", "exercise", "reps", "fs", "npz_path"]
    for c in need_cols:
        if c not in meta.columns:
            raise ValueError(f"[MM-Fit] meta csv missing column: {c} (have={list(meta.columns)})")

    meta["exercise"] = meta["exercise"].astype(str).str.strip().str.lower()

    meta["session"] = meta["session"].astype(str).str.strip()

    # participant sanitize
    meta["participant"] = pd.to_numeric(meta["participant"], errors="coerce")
    meta = meta.dropna(subset=["participant"]).copy()
    meta["participant"] = meta["participant"].astype(int)

    # device filter if exists
    if "device" in meta.columns and require_device is not None:
        meta = meta[meta["device"].astype(str).str.strip() == require_device].copy()

    # exercise filter: normalise the targets the same way as the column above,
    # otherwise a mixed-case target could never match the lower-cased values
    wanted = {str(e).strip().lower() for e in target_exercises}
    meta = meta[meta["exercise"].isin(wanted)].copy()

    # reps filter: compare as floats so NaN reps simply fail the comparison
    # (casting a NaN-containing column to int would raise)
    if only_reps is not None:
        reps_num = pd.to_numeric(meta["reps"], errors="coerce")
        meta = meta[reps_num.round() == int(only_reps)].copy()

    # path exists: build an explicit boolean array so the mask is boolean
    # even when no rows are left at this point
    meta["npz_path"] = meta["npz_path"].astype(str)
    exists = np.fromiter(
        (os.path.exists(p) for p in meta["npz_path"]),
        dtype=bool,
        count=len(meta),
    )
    meta = meta[exists].copy()

    meta = meta.sort_values(["participant", "session", "exercise", "set_id"]).reset_index(drop=True)
    return meta


def _is_finite_np(x: np.ndarray) -> bool:
    """Return whether every element of ``x`` is finite (no NaN, no +/-Inf)."""
    finite_mask = np.isfinite(x)
    return np.all(finite_mask)


def _load_npz_X_and_fs(npz_path: str, fallback_fs: float):
    """
    Read the signal array "X" and the sampling rate from an .npz archive.

    Args:
      npz_path: path to the .npz file.
      fallback_fs: sampling rate returned when the archive has no "fs" entry.

    Returns:
      (X, fs) where X is a float32 copy of the stored "X" array (expected to
      be (T, C)) and fs is a float; or (None, None) if the archive has no "X".

    Raises:
      Whatever ``np.load`` raises for an unreadable/missing file.
    """
    # NOTE(security): allow_pickle=True lets np.load unpickle object arrays,
    # which can execute arbitrary code. Only use this on trusted files; switch
    # to allow_pickle=False if the archives contain plain numeric arrays only.
    # The context manager closes the archive's file handle on every exit path.
    with np.load(npz_path, allow_pickle=True) as d:
        if "X" not in d.files:
            return None, None
        X = d["X"].astype(np.float32)  # (T,C); astype copies, so X outlives the archive
        fs_npz = float(d["fs"]) if "fs" in d.files else float(fallback_fs)
    return X, fs_npz


def prepare_trial_list_mmfit(meta_rows: pd.DataFrame,
                            expected_fs: float,
                            skip_nonfinite: bool = True,
                            verbose_skip: bool = True):
    """
    Build normalised trial dicts from meta rows (one row = one trial).

    - the count label of a trial is its ``reps`` value
    - every trial is z-scored per channel using its own mean/std
      (std values below 1e-6 are replaced by 1e-6)
    - trials that fail to load, lack "X", or (when ``skip_nonfinite``) contain
      NaN/Inf before or after normalisation are skipped and recorded
    - a sampling-rate mismatch with ``expected_fs`` raises ValueError

    Returns (trial_list, skipped) where ``skipped`` holds (reason, npz_path).
    """
    trial_list = []
    skipped = []  # (reason, npz_path)

    for _, row in meta_rows.iterrows():
        npz_path = str(row["npz_path"])
        n_reps = float(row["reps"])
        subj_id = int(row["participant"])
        sess = str(row["session"])
        set_no = int(row["set_id"])
        exercise = str(row["exercise"])

        try:
            X, fs_npz = _load_npz_X_and_fs(npz_path, fallback_fs=float(row["fs"]))
        except Exception:
            # best-effort: an unreadable file is recorded, not fatal
            skipped.append(("npz_load_fail", npz_path))
            continue

        if X is None:
            skipped.append(("missing_X", npz_path))
            continue

        fs_gap = abs(float(fs_npz) - float(expected_fs))
        if fs_gap > 1e-3:
            raise ValueError(
                f"[MM-Fit] fs mismatch for {npz_path}: fs_npz={fs_npz}, expected_fs={expected_fs}. "
                f"-> extractor 단계에서 TARGET_FS를 expected_fs로 통일하세요."
            )

        if skip_nonfinite and not _is_finite_np(X):
            skipped.append(("nonfinite_X", npz_path))
            continue

        # per-trial z-score
        mean = X.mean(axis=0)
        std  = X.std(axis=0)
        std  = np.where(std < 1e-6, 1e-6, std).astype(np.float32)
        normalised = (X - mean) / std

        if skip_nonfinite and not _is_finite_np(normalised):
            skipped.append(("nonfinite_after_norm", npz_path))
            continue

        trial = {
            "data": normalised,           # (T,C)
            "count": float(n_reps),       # trial total count
            "meta": f"subj{subj_id}_{sess}_set{set_no:02d}_{exercise}",
            "participant": subj_id,
            "session": sess,
            "exercise": exercise,
            "set_id": set_no,
            "npz_path": npz_path,
        }
        trial_list.append(trial)

    if verbose_skip and len(skipped) > 0:
        print(f"[MM-Fit] Skipped {len(skipped)} trials (nonfinite/load issues). Examples:")
        # show at most the first five skipped entries
        for reason, path in skipped[:5]:
            print("  -", reason, ":", path)

    return trial_list, skipped


# ---------------------------------------------------------------------
# ✅ MM-Fit single-trial picker (deterministic) + mixed builder
# ---------------------------------------------------------------------
def _select_first_trial_row(meta_df: pd.DataFrame, participant: int, exercise: str):
    """
    Deterministically pick one meta row for (participant, exercise).

    The matching rows are ordered by (participant, session, exercise, set_id)
    and the first one is returned as a Series; None when nothing matches.
    """
    is_subject = meta_df["participant"] == int(participant)
    is_exercise = meta_df["exercise"] == str(exercise)
    candidates = meta_df[is_subject & is_exercise].copy()

    if len(candidates) == 0:
        return None

    ordered = candidates.sort_values(["participant", "session", "exercise", "set_id"])
    ordered = ordered.reset_index(drop=True)
    return ordered.iloc[0]


def get_single_trial_raw_mmfit(meta_df: pd.DataFrame, participant: int, exercise: str, expected_fs: float):
    """
    Load the first (deterministically chosen) trial of a participant/exercise
    without any per-trial z-scoring.

    Returns a dict with the raw float32 signal ("data_raw"), the reps count
    and identifying fields, or None when no row matches, the file cannot be
    loaded, it has no "X", or the signal contains NaN/Inf.

    Raises ValueError when the file's sampling rate differs from
    ``expected_fs`` by more than 1e-3.
    """
    row = _select_first_trial_row(meta_df, participant, exercise)
    if row is None:
        return None

    npz_path = str(row["npz_path"])
    rep_count = float(row["reps"])
    fs_from_meta = float(row["fs"])

    try:
        X, fs_npz = _load_npz_X_and_fs(npz_path, fallback_fs=fs_from_meta)
    except Exception:
        # best-effort: an unreadable file means "no trial available"
        return None

    if X is None:
        return None

    fs_gap = abs(float(fs_npz) - float(expected_fs))
    if fs_gap > 1e-3:
        raise ValueError(
            f"[MM-Fit] fs mismatch for {npz_path}: fs_npz={fs_npz}, expected_fs={expected_fs}."
        )

    if not _is_finite_np(X):
        return None

    # identifying fields, also used to build the meta string
    subj_id = int(row["participant"])
    sess = str(row["session"])
    set_no = int(row["set_id"])
    ex_name = str(row["exercise"])

    return dict(
        data_raw=X.astype(np.float32),   # raw (T,C)
        count=float(rep_count),
        meta=f"subj{subj_id}_{sess}_set{set_no:02d}_{ex_name}",
        npz_path=npz_path,
        session=sess,
        set_id=set_no,
        exercise=ex_name,
        participant=subj_id,
    )


def build_mixed_ab_trial_mmfit(participant: int, exA: str, exB: str, config: dict, meta_df: pd.DataFrame):
    """
    Build one mixed trial for a participant: concat(raw A, raw B), z-scored
    per channel with the statistics of the concatenated signal.

    The A and B segments normalised with those same (mixed) statistics, their
    lengths, and the sample index where B starts ("boundary") are returned as
    well. Returns None if either exercise has no usable trial.
    """
    trial_a = get_single_trial_raw_mmfit(meta_df, participant, exA, expected_fs=config["fs"])
    trial_b = get_single_trial_raw_mmfit(meta_df, participant, exB, expected_fs=config["fs"])
    if trial_a is None or trial_b is None:
        return None

    raw_a = trial_a["data_raw"]
    raw_b = trial_b["data_raw"]
    boundary = int(raw_a.shape[0])

    raw_mix = np.concatenate([raw_a, raw_b], axis=0).astype(np.float32)

    # global z-score on mixed; the same stats are reused for A and B alone
    mean = raw_mix.mean(axis=0)
    std  = raw_mix.std(axis=0) + 1e-6

    def _normalise(arr):
        return (arr - mean) / std

    norm_mix = _normalise(raw_mix)
    norm_a = _normalise(raw_a)
    norm_b = _normalise(raw_b)

    count_a = float(trial_a["count"])
    count_b = float(trial_b["count"])
    count_total = count_a + count_b

    subject_tag = f"subject{participant}"
    detail = {
        "subj": subject_tag,
        "actA_name": exA, "actB_name": exB,
        "gtA": count_a, "gtB": count_b,
        "gt_total": count_total,
    }

    return {
        "data": norm_mix,                      # (Tmix,C) normalized by mix stats
        "count": float(count_total),
        "meta": f"{subject_tag}__{exA}__TO__{exB}",
        "boundary": boundary,
        "meta_detail": detail,
        "data_A": norm_a,
        "data_B": norm_b,
        "T_A": int(norm_a.shape[0]),
        "T_B": int(norm_b.shape[0]),
    }


# ---------------------------------------------------------------------
# 2.5) Windowing (UNCHANGED)
# ---------------------------------------------------------------------
def trial_list_to_windows(trial_list, fs, win_sec=8.0, stride_sec=4.0, drop_last=True):
    """
    Cut every trial into fixed-length windows with proportional count labels.

    Each trial is assumed to have a constant repetition rate
    (count / duration); a window's "count" is that rate times the window's
    duration.

    - A trial shorter than one window yields a single window holding the
      whole trial.
    - Otherwise windows of ``win_sec`` seconds start every ``stride_sec``
      seconds while a full window still fits.
    - With ``drop_last=False`` one extra, shorter window is added that starts
      one stride after the last full window, if that start lies inside the
      trial.

    Returns a list of dicts with keys: data, count, meta, parent_meta,
    parent_T, win_start, win_end.
    """
    win_len = int(round(win_sec * fs))
    stride = int(round(stride_sec * fs))
    assert win_len > 0 and stride > 0

    fs_f = float(fs)
    result = []

    for trial in trial_list:
        signal = trial["data"]
        n = signal.shape[0]
        trial_count = float(trial["count"])
        parent = trial["meta"]

        # constant trial-level rate; duration is floored to avoid division by zero
        rate = trial_count / max(n / fs_f, 1e-6)

        def emit(chunk, lo, hi):
            result.append({
                "data": chunk,
                "count": rate * ((hi - lo) / fs_f),
                "meta": f"{parent}__win[{lo}:{hi}]",
                "parent_meta": parent,
                "parent_T": n,
                "win_start": lo,
                "win_end": hi,
            })

        if n < win_len:
            # too short for a full window: keep the trial as one window
            emit(signal, 0, n)
            continue

        starts = list(range(0, (n - win_len) + 1, stride))
        for lo in starts:
            hi = lo + win_len
            emit(signal[lo:hi], lo, hi)

        if drop_last:
            continue

        # optional trailing partial window
        tail_start = starts[-1] + stride
        if tail_start < n:
            emit(signal[tail_start:n], tail_start, n)

    return result


def predict_count_by_windowing(model, x_np, fs, win_sec, stride_sec, device, tau=1.0, batch_size=64):
    """Estimate the repetition count of one signal from window-level rate predictions.

    The signal is cut into windows of ``win_sec`` seconds every ``stride_sec``
    seconds; the model predicts one rate per window, and the mean rate times
    the full signal duration is the predicted count. A signal no longer than
    one window is passed through the model in a single forward call.

    Args:
        model: callable module invoked as ``model(x, mask=None, tau=tau)`` whose
            first return value is the per-sample rate.
        x_np: array indexed as (time, channels).
        fs: sampling rate used to convert seconds to samples.
        win_sec: window length in seconds.
        stride_sec: hop between window starts in seconds.
        device: torch device the input tensors are moved to.
        tau: temperature forwarded to the model.
        batch_size: number of windows per forward pass.

    Returns:
        ``(pred_count, rates)`` where ``pred_count`` is a float and ``rates`` is
        a NumPy array holding one predicted rate per window.
    """
    win_len = int(round(win_sec * fs))
    stride = int(round(stride_sec * fs))
    T = x_np.shape[0]
    total_dur = T / float(fs)

    # Switch to eval mode before EITHER branch. Previously only the windowed
    # branch did this, so short signals could be scored by a model still in
    # train mode.
    model.eval()

    if T <= win_len:
        x_tensor = torch.tensor(x_np, dtype=torch.float32).transpose(0, 1).unsqueeze(0).to(device)
        with torch.no_grad():
            rate_hat, _, _, _ = model(x_tensor, mask=None, tau=tau)
        rate = float(rate_hat.item())
        return float(rate * total_dur), np.array([rate], dtype=np.float32)

    # Only full windows are used; samples after the last full window are ignored.
    starts = list(range(0, T - win_len + 1, stride))
    windows = np.stack([x_np[st:st + win_len] for st in starts], axis=0)

    # (n_windows, win_len, C) -> (n_windows, C, win_len), the layout the model is called with.
    xw = torch.tensor(windows, dtype=torch.float32).permute(0, 2, 1).to(device)

    rates = []
    with torch.no_grad():
        for i in range(0, xw.shape[0], batch_size):
            xb = xw[i:i + batch_size]
            r_hat, _, _, _ = model(xb, mask=None, tau=tau)
            rates.append(r_hat.detach().cpu().numpy())

    rates = np.concatenate(rates, axis=0)
    rate_mean = float(rates.mean())
    pred_count = rate_mean * total_dur
    return float(pred_count), rates


# ---------------------------------------------------------------------
# Helpers for scenario-avg plots (UNCHANGED)
# ---------------------------------------------------------------------
def get_window_rate_curve(model, x_np, fs, win_sec, stride_sec, device, tau=1.0, batch_size=64):
    """Return window-centre times and the predicted rate of every window.

    Inference is delegated to ``predict_count_by_windowing`` for both the
    short-signal and the windowed case, so the two functions cannot drift
    apart. This function only adds the time axis.

    Args:
        model: module forwarded to ``predict_count_by_windowing``.
        x_np: array indexed as (time, channels).
        fs: sampling rate used to convert samples to seconds.
        win_sec: window length in seconds.
        stride_sec: hop between window starts in seconds.
        device: torch device used for inference.
        tau: temperature forwarded to the model.
        batch_size: number of windows per forward pass.

    Returns:
        ``(centers, rates)`` as float32 arrays: window-centre times in seconds
        and one predicted rate per window. For a signal no longer than one
        window, both have a single element and the centre is the signal midpoint.
    """
    win_len = int(round(win_sec * fs))
    stride = int(round(stride_sec * fs))
    T = x_np.shape[0]

    # The predicted count is not needed here, only the per-window rates.
    _, rates = predict_count_by_windowing(
        model, x_np, fs, win_sec, stride_sec, device, tau=tau, batch_size=batch_size
    )
    rates = np.asarray(rates).astype(np.float32)

    if T <= win_len:
        centers = np.array([0.5 * (T / float(fs))], dtype=np.float32)
    else:
        # Must mirror the window starts used by predict_count_by_windowing.
        starts = range(0, T - win_len + 1, stride)
        centers = np.array([(st + 0.5 * win_len) / float(fs) for st in starts], dtype=np.float32)

    return centers, rates


def resample_1d(y, new_len):
    """Linearly resample a 1-D sequence to ``new_len`` points on a [0, 1] axis.

    An empty input gives zeros and a single-element input gives a constant
    vector. The result is always float32.
    """
    samples = np.asarray(y, dtype=np.float32)
    n_old = samples.size

    # Degenerate inputs cannot be interpolated; fill with a constant instead.
    if n_old <= 1:
        fill = float(samples[0]) if n_old == 1 else 0.0
        return np.full((new_len,), fill, dtype=np.float32)

    src_axis = np.linspace(0.0, 1.0, num=n_old, dtype=np.float32)
    dst_axis = np.linspace(0.0, 1.0, num=new_len, dtype=np.float32)
    resampled = np.interp(dst_axis, src_axis, samples)
    return resampled.astype(np.float32)


def resample_2d(Y, new_len):
    """Linearly resample the rows of a 2-D array to ``new_len`` rows.

    Each column is interpolated independently on a normalised [0, 1] axis.
    Zero input rows give zeros; one input row is repeated. Output is float32
    with shape ``(new_len, Y.shape[1])``.
    """
    traj = np.asarray(Y, dtype=np.float32)
    n_old = traj.shape[0]

    if n_old == 0:
        return np.zeros((new_len, traj.shape[1]), dtype=np.float32)
    if n_old == 1:
        return np.repeat(traj, repeats=new_len, axis=0).astype(np.float32)

    src_axis = np.linspace(0.0, 1.0, num=n_old, dtype=np.float32)
    dst_axis = np.linspace(0.0, 1.0, num=new_len, dtype=np.float32)
    columns = [np.interp(dst_axis, src_axis, traj[:, d]) for d in range(traj.shape[1])]
    return np.stack(columns, axis=1).astype(np.float32)


def global_pca2(Z_all):
    """Fit a 2-component PCA basis via SVD of the mean-centred samples.

    Returns:
        ``(mu, V2)``: the float32 feature mean with shape ``(D,)`` and the first
        two right singular vectors as rows (at most two rows).
    """
    samples = np.asarray(Z_all, dtype=np.float32)
    center = samples.mean(axis=0, keepdims=True)
    # Rows of the third SVD factor are the principal directions.
    right_vecs = np.linalg.svd(samples - center, full_matrices=False)[2]
    mu = center.squeeze(0).astype(np.float32)
    return mu, right_vecs[:2].astype(np.float32)


# ---------------------------------------------------------------------
# ✅ PLOTTER (UNCHANGED DESIGN)
# ---------------------------------------------------------------------
from mpl_toolkits.axes_grid1 import make_axes_locatable
from mpl_toolkits.axes_grid1 import make_axes_locatable  # ✅ keep duplicated import as you had

def plot_scenario_mean_std_two_plots(
    save_dir,
    scen_name,
    win_t, win_mean, win_std,
    lat_xy_mean, lat_xy_std,
    n_subjects,
    dpi=600,
):
    """Render and save the two scenario-level figures as PNG files.

    Figure 1 plots ``win_mean`` against ``win_t`` with a ``±win_std`` band, a
    dashed vertical line at x=1.0 and two shaded spans covering [0, 1] and
    [1, 2]. Figure 2 draws ``lat_xy_mean`` as a line whose colour follows a
    0..2 time axis, with a colorbar. Neither figure gets a title.

    Args:
        save_dir: output directory; ``None`` falls back to ``"scenario_viz"``.
        scen_name: scenario label, sanitised with ``_safe_fname`` for file names.
        win_t: x positions of the window-rate curve.
        win_mean: mean window-level rate per x position.
        win_std: standard deviation per x position; ``win_mean - win_std`` is
            evaluated, so both must support elementwise arithmetic.
        lat_xy_mean: sequence of 2-D points (converted to a float32 array).
        lat_xy_std: accepted but not read by this function.
        n_subjects: accepted but not read by this function.
        dpi: resolution passed to ``savefig``.

    Side effects:
        Creates the output directory and writes
        ``<scen>__1_win_rep_rate.png`` and ``<scen>__2_latent_pca2.png`` there.
        Both figures are closed after saving.
    """
    _ensure_dir(save_dir if save_dir is not None else "scenario_viz")
    out_dir = save_dir if save_dir is not None else "scenario_viz"
    scen_tag = _safe_fname(scen_name)

    # -------------------------
    # (1) Window-level Rep rate mean±std
    # -------------------------
    fig1 = plt.figure(figsize=(26.0, 9.2))
    ax = fig1.gca()

    # subtle A/B shading (keep but slightly lighter)
    ax.axvspan(0.0, 1.0, alpha=0.035, color="k", lw=0)
    ax.axvspan(1.0, 2.0, alpha=0.022, color="k", lw=0)

    ax.plot(
        win_t, win_mean,
        linewidth=7.8,
        solid_capstyle="round",
        label="Window rep rate (mean)"
    )
    ax.fill_between(
        win_t,
        win_mean - win_std, win_mean + win_std,
        alpha=0.16,
        linewidth=0,
        label="±1 std"
    )

    # Dashed marker at the A/B boundary of the normalised time axis.
    ax.axvline(
        1.0,
        linestyle="--",
        linewidth=4.2,
        label="Boundary"
    )

    ax.set_xlabel("Normalized time (A:0–1, B:1–2)", fontsize=54, labelpad=10)
    ax.set_ylabel("Rep rate (reps/s)", fontsize=54, labelpad=10)
    ax.tick_params(axis="both", labelsize=46, width=1.6, length=7)

    for sp in ax.spines.values():
        sp.set_linewidth(1.6)

    ax.legend(
        fontsize=42,
        frameon=True,
        loc="upper left",
        borderpad=0.55,
        labelspacing=0.35,
        handlelength=2.4
    )

    fig1.tight_layout()
    fig1.savefig(os.path.join(out_dir, f"{scen_tag}__1_win_rep_rate.png"), dpi=dpi, bbox_inches="tight")
    plt.close(fig1)

    # -------------------------
    # (2) Latent trajectory PCA2 (gradient line)
    # -------------------------
    fig2 = plt.figure(figsize=(10.0, 10.0))
    ax = fig2.gca()

    lat_xy_mean = np.asarray(lat_xy_mean, dtype=np.float32)
    T = lat_xy_mean.shape[0]
    # One time value per trajectory point on the same 0..2 axis as figure 1.
    t_lat = np.linspace(0.0, 2.0, num=T, dtype=np.float32)

    # Build T-1 two-point segments so each can be coloured individually.
    pts = lat_xy_mean.reshape(-1, 1, 2)
    segs = np.concatenate([pts[:-1], pts[1:]], axis=1)

    lc = LineCollection(
        segs,
        array=t_lat[:-1],
        linewidth=2.6,
        alpha=0.95
    )
    ax.add_collection(lc)

    ax.set_xlabel("PC1", fontsize=25, labelpad=10)
    ax.set_ylabel("PC2", fontsize=25, labelpad=10)
    ax.tick_params(axis="both", labelsize=20, width=1.4, length=6)

    for sp in ax.spines.values():
        sp.set_linewidth(1.4)

    ax.set_aspect("equal", adjustable="box")
    ax.set_anchor("C")

    # Limits are set explicitly as a square box centred on the trajectory's
    # bounding box, with 12% padding.
    x = lat_xy_mean[:, 0]; y = lat_xy_mean[:, 1]
    xmin, xmax = float(np.min(x)), float(np.max(x))
    ymin, ymax = float(np.min(y)), float(np.max(y))
    cx = 0.5 * (xmin + xmax); cy = 0.5 * (ymin + ymax)
    rx = 0.5 * (xmax - xmin); ry = 0.5 * (ymax - ymin)
    r = max(rx, ry) * 1.12 + 1e-6
    ax.set_xlim(cx - r, cx + r)
    ax.set_ylim(cy - r, cy + r)

    # Colorbar in an axes appended to the right of the main axes.
    divider = make_axes_locatable(ax)
    cax = divider.append_axes("right", size="4.5%", pad=0.10)
    cbar = fig2.colorbar(lc, cax=cax)
    cbar.set_label("Normalized time (0→2)", fontsize=22, labelpad=10)
    cbar.ax.tick_params(labelsize=18, width=1.2, length=5)

    fig2.tight_layout()
    fig2.savefig(os.path.join(out_dir, f"{scen_tag}__2_latent_pca2.png"), dpi=dpi, bbox_inches="tight")
    plt.close(fig2)


# ---------------------------------------------------------------------
# Dataset / Collate (UNCHANGED)
# ---------------------------------------------------------------------
class TrialDataset(Dataset):
    """Dataset view over a list of trial/window dicts.

    Each item yields ``(data, count, meta)`` where ``data`` is the record's
    ``'data'`` entry as a float32 tensor with its first two axes swapped,
    ``count`` is a float32 scalar tensor and ``meta`` is returned unchanged.
    """

    def __init__(self, trial_list):
        self.trials = trial_list

    def __len__(self):
        return len(self.trials)

    def __getitem__(self, idx):
        record = self.trials[idx]
        # Stored as (T, C); the model consumes (C, T).
        signal = torch.tensor(record['data'], dtype=torch.float32)
        target = torch.tensor(record['count'], dtype=torch.float32)
        return signal.transpose(0, 1), target, record['meta']


def collate_variable_length(batch):
    """Zero-pad ``(data, count, meta)`` samples to a common length and stack them.

    ``data`` tensors are indexed as (C, T). Shorter samples are right-padded
    with zeros along the time axis, and a mask marks real (1) versus padded
    (0) positions.

    Returns:
        dict with ``data`` (B, C, T_max), ``mask`` (B, T_max), ``count`` (B,),
        ``length`` (B,) as float32 and ``meta`` (list of the original metas).
    """
    longest = max(sample[0].shape[1] for sample in batch)
    n_ch = batch[0][0].shape[0]

    data_rows, mask_rows, count_rows, meta_rows, n_steps_all = [], [], [], [], []
    for signal, count, meta in batch:
        n_steps = signal.shape[1]
        deficit = longest - n_steps

        valid = torch.ones(n_steps)
        if deficit > 0:
            # Right-pad both the signal and its validity mask.
            signal = torch.cat([signal, torch.zeros(n_ch, deficit)], dim=1)
            valid = torch.cat([valid, torch.zeros(deficit)], dim=0)

        n_steps_all.append(n_steps)
        data_rows.append(signal)
        mask_rows.append(valid)
        count_rows.append(count)
        meta_rows.append(meta)

    return {
        "data": torch.stack(data_rows),
        "mask": torch.stack(mask_rows),
        "count": torch.stack(count_rows),
        "length": torch.tensor(n_steps_all, dtype=torch.float32),
        "meta": meta_rows
    }


# ---------------------------------------------------------------------
# Model (UNCHANGED)
# ---------------------------------------------------------------------
class ManifoldEncoder(nn.Module):
    """Three-layer 1-D convolutional encoder producing a per-timestep latent.

    Input is (B, input_ch, T); output is (B, T, latent_dim). The two 5-tap
    convolutions use padding 2, so the time length is preserved.
    """

    def __init__(self, input_ch, hidden_dim=128, latent_dim=16):
        super().__init__()
        layers = [
            nn.Conv1d(input_ch, hidden_dim, 5, padding=2),
            nn.ReLU(),
            nn.Conv1d(hidden_dim, hidden_dim, 5, padding=2),
            nn.ReLU(),
            nn.Conv1d(hidden_dim, latent_dim, 1),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        # Channels-first conv output -> time-major latent.
        return self.net(x).transpose(1, 2)


class ManifoldDecoder(nn.Module):
    """Three-layer 1-D convolutional decoder mapping latents back to channels.

    Input is (B, T, latent_dim); output is (B, out_ch, T).
    """

    def __init__(self, latent_dim, hidden_dim, out_ch):
        super().__init__()
        layers = [
            nn.Conv1d(latent_dim, hidden_dim, 5, padding=2),
            nn.ReLU(),
            nn.Conv1d(hidden_dim, hidden_dim, 5, padding=2),
            nn.ReLU(),
            nn.Conv1d(hidden_dim, out_ch, 1),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, z):
        # Time-major latent -> channels-first layout for the convolutions.
        return self.net(z.transpose(1, 2))


class MultiRateHead(nn.Module):
    """MLP head emitting a non-negative amplitude and a K_max-way phase distribution.

    The last linear layer has ``1 + K_max`` outputs: index 0 is passed through
    softplus to give the amplitude, and the remaining ``K_max`` values are
    phase logits turned into probabilities by a temperature softmax.
    """

    def __init__(self, latent_dim=16, hidden=64, K_max=6):
        super().__init__()
        self.K_max = K_max
        layers = [
            nn.Linear(latent_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 1 + K_max),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, z, tau=1.0):
        raw = self.net(z)
        phase_logits = raw[..., 1:]
        amp = F.softplus(raw[..., 0])
        # A smaller tau sharpens the phase distribution.
        phase = F.softmax(phase_logits / tau, dim=-1)
        return amp, phase, phase_logits


class KAutoCountModel(nn.Module):
    """Encoder/decoder with a rate head that predicts an average repetition rate.

    The per-timestep amplitude from the rate head is divided by an effective
    phase count ``k_hat`` (inverse sum of squared time-averaged phase
    probabilities) to give a per-timestep repetition rate. The optionally
    masked time average of that rate is the main output.
    """

    def __init__(self, input_ch, hidden_dim=128, latent_dim=16, K_max=6, k_hidden=64):
        super().__init__()
        self.encoder = ManifoldEncoder(input_ch, hidden_dim, latent_dim)
        self.decoder = ManifoldDecoder(latent_dim, hidden_dim, input_ch)
        # NOTE(review): ``k_hidden`` is accepted but never read; the head width
        # is ``hidden_dim``. Confirm whether that is intended.
        self.rate_head = MultiRateHead(latent_dim, hidden=hidden_dim, K_max=K_max)
        self._init_weights()

    def _init_weights(self):
        """Kaiming-init all conv/linear layers, then bias the amplitude output to -2."""
        for layer in self.modules():
            if not isinstance(layer, (nn.Conv1d, nn.Linear)):
                continue
            nn.init.kaiming_normal_(layer.weight, mode='fan_out', nonlinearity='relu')
            if layer.bias is not None:
                nn.init.constant_(layer.bias, 0)

        with torch.no_grad():
            head_bias = self.rate_head.net[-1].bias
            head_bias.zero_()
            # Index 0 feeds the softplus amplitude; start it at a small positive rate.
            head_bias[0].fill_(-2.0)

    @staticmethod
    def _masked_mean_time(x, mask=None, eps=1e-6):
        """Average ``x`` over dim 1, weighting by ``mask`` when one is given."""
        if mask is None:
            return x.mean(dim=1)
        if x.dim() not in (2, 3):
            raise ValueError(f"Unsupported dim for masked mean: {x.dim()}")

        weights = mask.to(dtype=x.dtype, device=x.device)
        if x.dim() == 3:
            # Broadcast the (B, T) mask over the trailing feature axis.
            weights = weights.unsqueeze(-1)
        return (x * weights).sum(dim=1) / (weights.sum(dim=1) + eps)

    def forward(self, x, mask=None, tau=1.0):
        """Return ``(avg_rep_rate, z, x_hat, aux)`` for a batch ``x``.

        ``aux`` carries ``phase_p``, ``phase_logits``, ``micro_rate_t``,
        ``rep_rate_t`` and ``k_hat``.
        """
        z = self.encoder(x)
        x_hat = self.decoder(z)

        micro_rate_t, phase_p, phase_logits = self.rate_head(z, tau=tau)

        # Effective number of phases: inverse participation ratio of the mean phase usage.
        mean_phase = self._masked_mean_time(phase_p, mask)
        k_hat = 1.0 / (mean_phase.pow(2).sum(dim=1) + 1e-6)

        rep_rate_t = micro_rate_t / (k_hat.unsqueeze(1) + 1e-6)

        if mask is None:
            avg_rep_rate = rep_rate_t.mean(dim=1)
        else:
            # Zero out padded timesteps, then average over the valid ones only.
            rep_rate_t = rep_rate_t * mask
            avg_rep_rate = (rep_rate_t * mask).sum(dim=1) / (mask.sum(dim=1) + 1e-6)

        aux = {
            "phase_p": phase_p,
            "phase_logits": phase_logits,
            "micro_rate_t": micro_rate_t,
            "rep_rate_t": rep_rate_t,
            "k_hat": k_hat,
        }
        return avg_rep_rate, z, x_hat, aux


# ---------------------------------------------------------------------
# Loss utils (UNCHANGED)
# ---------------------------------------------------------------------
def masked_recon_mse(x_hat, x, mask, eps=1e-6):
    """Mean squared reconstruction error over valid timesteps only.

    ``x_hat`` and ``x`` are (B, C, T) and ``mask`` is (B, T). The sum of
    squared errors at valid positions is divided by
    ``(number of valid timesteps) * C``.
    """
    valid = mask.to(dtype=x.dtype, device=x.device)
    # Broadcast the time mask across the channel axis.
    sq_err = ((x_hat - x) ** 2) * valid.unsqueeze(1)
    n_terms = (valid.sum() * x.shape[1]) + eps
    return sq_err.sum() / n_terms


def temporal_smoothness(v, mask=None, eps=1e-6):
    """Mean absolute first difference of ``v`` along dim 1.

    With a mask, a difference counts only when both neighbouring timesteps
    are valid.
    """
    step = torch.abs(v[:, 1:] - v[:, :-1])
    if mask is not None:
        pair_valid = mask[:, 1:] * mask[:, :-1]
        pair_valid = pair_valid.to(dtype=step.dtype, device=step.device)
        return (step * pair_valid).sum() / (pair_valid.sum() + eps)
    return step.mean()


def phase_entropy_loss(phase_p, mask=None, eps=1e-8):
    """Average Shannon entropy of the phase distribution over the last axis.

    With a mask, the per-timestep entropies are averaged over valid
    positions only.
    """
    per_step = -(phase_p * (phase_p + eps).log()).sum(dim=-1)
    if mask is not None:
        return (per_step * mask).sum() / (mask.sum() + eps)
    return per_step.mean()


def effK_usage_loss(phase_p, mask=None, eps=1e-6):
    """Effective number of phases in use, as a loss term.

    The phase probabilities are averaged over time (masked if requested) and
    the inverse of their squared sum is taken per sample.

    Returns:
        ``(mean_effK, effK_detached)``: the batch mean, which keeps the graph,
        and the detached per-sample values.
    """
    if mask is not None:
        weights = mask.to(dtype=phase_p.dtype, device=phase_p.device).unsqueeze(-1)
        usage = (phase_p * weights).sum(dim=1) / (weights.sum(dim=1) + eps)
    else:
        usage = phase_p.mean(dim=1)

    eff_k = 1.0 / (usage.pow(2).sum(dim=1) + eps)
    return eff_k.mean(), eff_k.detach()


def compute_phase_entropy_mean(phase_p_np, eps=1e-8):
    """Return the mean per-row Shannon entropy of a (T, K) probability array as a float."""
    probs = np.asarray(phase_p_np, dtype=np.float32)
    per_step = -np.sum(probs * np.log(probs + eps), axis=1)
    return float(np.mean(per_step))


# ---------------------------------------------------------------------
# Train (UNCHANGED)
# ---------------------------------------------------------------------
def train_one_epoch(model, loader, optimizer, config, device):
    """Run one optimisation pass over ``loader`` and return batch-averaged metrics.

    The objective is the rate MSE plus weighted reconstruction, temporal
    smoothness, phase-entropy and effective-K terms. Weights are read from
    ``config`` (``lambda_recon``, ``lambda_smooth``, ``lambda_phase_ent``,
    ``lambda_effk``) with built-in defaults.

    Returns:
        dict with ``loss``, each individual loss term and ``mae_count``, each
        averaged over the number of batches.
    """
    model.train()

    metric_names = [
        'loss', 'loss_rate', 'loss_recon', 'loss_smooth', 'loss_phase_ent', 'loss_effk',
        'mae_count'
    ]
    running = dict.fromkeys(metric_names, 0.0)

    fs = config["fs"]
    tau = config.get("tau", 1.0)

    w_recon = config.get("lambda_recon", 1.0)
    w_smooth = config.get("lambda_smooth", 0.05)
    w_entropy = config.get("lambda_phase_ent", 0.01)
    w_effk = config.get("lambda_effk", 0.005)

    for batch in loader:
        signals = batch["data"].to(device)
        valid = batch["mask"].to(device)
        true_count = batch["count"].to(device)
        n_samples = batch["length"].to(device)

        # The supervision target is a rate: count divided by the unpadded duration.
        seconds = torch.clamp(n_samples / fs, min=1e-6)
        true_rate = true_count / seconds

        optimizer.zero_grad()
        pred_rate, _, recon, aux = model(signals, valid, tau=tau)

        terms = {
            'loss_rate': F.mse_loss(pred_rate, true_rate),
            'loss_recon': masked_recon_mse(recon, signals, valid),
            'loss_smooth': temporal_smoothness(aux["rep_rate_t"], valid),
            'loss_phase_ent': phase_entropy_loss(aux["phase_p"], valid),
            'loss_effk': effK_usage_loss(aux["phase_p"], valid)[0],
        }

        total = (terms['loss_rate']
                 + w_recon * terms['loss_recon']
                 + w_smooth * terms['loss_smooth']
                 + w_entropy * terms['loss_phase_ent']
                 + w_effk * terms['loss_effk'])

        total.backward()
        optimizer.step()

        running['loss'] += total.item()
        for name, value in terms.items():
            running[name] += value.item()
        # Count error is reported by converting the predicted rate back to a count.
        running['mae_count'] += torch.abs(pred_rate * seconds - true_count).mean().item()

    n_batches = len(loader)
    return {name: acc / n_batches for name, acc in running.items()}


# ---------------------------------------------------------------------
# Main (Scenario-wise LOSO) — MM-Fit
# ---------------------------------------------------------------------
def main():
    """Run the scenario-wise leave-one-subject-out experiment and save the plots.

    For every ordered pair (A, B) of distinct exercises in
    ``CONFIG["MMFIT_EXERCISES"]``:

    1. Subjects that have trials of both A and B are selected.
    2. For each such subject a fresh model is trained on windows of exercise A
       from all OTHER participants, then evaluated on
       (a) all of the held-out subject's A trials, and
       (b) one mixed A->B item returned by ``build_mixed_ab_trial_mmfit``, both
           as a whole and on its A and B parts separately.
    3. Per-fold text lines and scenario-level MAE/RMSE summaries are printed.
    4. Window-rate curves and latent trajectories are resampled to fixed
       lengths, averaged across subjects and passed to
       ``plot_scenario_mean_std_two_plots``.

    Note: the per-fold "Fold ..." line is printed only for folds where a mixed
    item was available, because the ``continue`` on a missing mixed item comes
    before the print; the fold MAE is still included in the LOSO summary.

    ``load_mmfit_meta``, ``prepare_trial_list_mmfit`` and
    ``build_mixed_ab_trial_mmfit`` are defined outside this block; their
    contracts are assumed from how their results are used here.
    """
    CONFIG = {
        "seed": 42,

        # ✅ MM-Fit meta csv
        "mmfit_meta_csv": "/content/drive/MyDrive/Colab Notebooks/HAR_data/mmfit_imu_3ex_trials/meta_sw_r_dumbbell_rows_lunges_pushups.csv",

        # ✅ exercises (3) => all directed pairs => 6 scenarios
        "MMFIT_EXERCISES": ["pushups", "lunges", "dumbbell_rows"],

        # optional filters
        "MMFIT_ONLY_REPS": None,
        "MMFIT_REQUIRE_DEVICE": "sw_r",

        # Training Params
        "epochs": 100,
        "lr": 5e-4,
        "batch_size": 64,
        "fs": 100,   # ✅ MM-Fit sw_r acc/gyro is typically 100Hz in your pipeline

        # Windowing Params
        "win_sec": 8.0,
        "stride_sec": 4.0,
        "drop_last": True,

        # Model
        "hidden_dim": 128,
        "latent_dim": 16,
        "K_max": 6,

        # Loss Weights
        "lambda_recon": 1.0,
        "lambda_smooth": 0.05,
        "lambda_phase_ent": 0.01,
        "lambda_effk": 0.0075,

        "tau": 1.0,
    }

    set_strict_seed(CONFIG["seed"])
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")

    # load meta
    # NOTE(review): the result is indexed below like a DataFrame with
    # "exercise" and "participant" columns — confirm against load_mmfit_meta.
    meta_all = load_mmfit_meta(
        meta_csv_path=CONFIG["mmfit_meta_csv"],
        target_exercises=set(CONFIG["MMFIT_EXERCISES"]),
        only_reps=CONFIG["MMFIT_ONLY_REPS"],
        require_device=CONFIG["MMFIT_REQUIRE_DEVICE"],
    )
    if len(meta_all) == 0:
        print("[MM-Fit] meta is empty after filtering. Check meta_csv / filters.")
        return

    # build all directed scenarios (A->B), A != B => 6
    ex_list = list(CONFIG["MMFIT_EXERCISES"])
    MIX_SCENARIOS = []
    for A in ex_list:
        for B in ex_list:
            if A == B:
                continue
            MIX_SCENARIOS.append((f"{A}-{B}", A, B))

    # output dir
    VIZ_DIR = "scenario_viz"
    os.makedirs(VIZ_DIR, exist_ok=True)

    # resample lengths (fixed for averaging)
    NW = 60     # window-rate points per segment (A,B)
    NL = 200    # latent points per segment (A,B)

    # For each scenario (A->B)
    for scen_idx, (scen_name, exA, exB) in enumerate(MIX_SCENARIOS, start=1):
        TRAIN_EX = exA

        print("\n" + "=" * 100)
        print(f"{scen_idx}. 시나리오: {scen_name}   (TRAIN={TRAIN_EX})")
        print("=" * 100)

        # subjects = participants who have at least one A trial
        subjA = sorted(meta_all[meta_all["exercise"] == TRAIN_EX]["participant"].unique().tolist())

        # for scenario, need A and B available (at least 1 trial each)
        subjB = set(meta_all[meta_all["exercise"] == exB]["participant"].unique().tolist())
        subjects = [s for s in subjA if s in subjB]

        if len(subjects) == 0:
            print(f"[Skip scenario] No participants have both A={exA} and B={exB}.")
            continue

        # Per-scenario accumulators: fold MAEs on A, printable per-subject
        # lines, and signed errors (pred - GT) for total / A part / B part.
        loso_results = []
        scenario_records = []
        scenario_diffs = []
        scenario_diffs_A = []
        scenario_diffs_B = []

        # ✅ buffers for TWO plots only
        buf_win_concat = []     # list of (2*NW,)
        buf_z_concat = []       # list of (Tmix,D)
        per_subject_latent = [] # list of {"zA","zB"}

        for fold_idx, test_pid in enumerate(subjects, start=1):
            # Re-seed per fold so every fold starts from the same RNG state.
            set_strict_seed(CONFIG["seed"])

            # ----------------------------------------
            # TRAIN trials: ALL A trials from other subjects
            # TEST(A) trials: ALL A trials from held-out subject
            # ----------------------------------------
            train_meta = meta_all[(meta_all["exercise"] == TRAIN_EX) & (meta_all["participant"] != int(test_pid))].copy()
            test_metaA = meta_all[(meta_all["exercise"] == TRAIN_EX) & (meta_all["participant"] == int(test_pid))].copy()

            train_trials, _ = prepare_trial_list_mmfit(train_meta, expected_fs=CONFIG["fs"], skip_nonfinite=True, verbose_skip=False)
            test_trialsA, _ = prepare_trial_list_mmfit(test_metaA, expected_fs=CONFIG["fs"], skip_nonfinite=True, verbose_skip=False)

            if len(train_trials) == 0 or len(test_trialsA) == 0:
                continue

            train_data = trial_list_to_windows(
                train_trials, fs=CONFIG["fs"],
                win_sec=CONFIG["win_sec"], stride_sec=CONFIG["stride_sec"],
                drop_last=CONFIG["drop_last"]
            )
            if len(train_data) == 0:
                continue

            # Dedicated generator so the shuffle order is reproducible.
            g = torch.Generator()
            g.manual_seed(CONFIG["seed"])

            train_loader = DataLoader(
                TrialDataset(train_data),
                batch_size=CONFIG["batch_size"],
                shuffle=True,
                collate_fn=collate_variable_length,
                generator=g,
                num_workers=0
            )

            # model init
            input_ch = train_data[0]["data"].shape[1]
            model = KAutoCountModel(
                input_ch=input_ch,
                hidden_dim=CONFIG["hidden_dim"],
                latent_dim=CONFIG["latent_dim"],
                K_max=CONFIG["K_max"]
            ).to(device)

            optimizer = torch.optim.Adam(model.parameters(), lr=CONFIG["lr"])
            # Halve the learning rate every 30 epochs.
            scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.5)

            for _ in range(CONFIG["epochs"]):
                _ = train_one_epoch(model, train_loader, optimizer, CONFIG, device)
                scheduler.step()

            model.eval()

            # ----------------------------------------
            # (A) LOSO performance on TRAIN_EX (A) — average over ALL A trials of held-out subject
            # ----------------------------------------
            abs_errs = []
            res_str = ""

            for itemA in test_trialsA:
                x_np = itemA["data"]
                gtA = float(itemA["count"])
                predA, _ = predict_count_by_windowing(
                    model, x_np=x_np,
                    fs=CONFIG["fs"], win_sec=CONFIG["win_sec"], stride_sec=CONFIG["stride_sec"],
                    device=device, tau=CONFIG.get("tau", 1.0),
                    batch_size=CONFIG.get("batch_size", 64)
                )
                abs_errs.append(abs(predA - gtA))
                res_str += f"[Pred(win): {predA:.1f} / GT: {gtA:.0f}]"

            fold_mae = float(np.mean(abs_errs)) if len(abs_errs) > 0 else None
            if fold_mae is not None:
                loso_results.append(fold_mae)

            # ----------------------------------------
            # (B) Scenario eval on mixed A->B for held-out subject (one deterministic A trial + one deterministic B trial)
            # ----------------------------------------
            # NOTE(review): the returned dict is read for "data", "boundary",
            # "count", "meta_detail", "data_A" and "data_B" — confirm against
            # build_mixed_ab_trial_mmfit.
            mixed_item = build_mixed_ab_trial_mmfit(
                participant=int(test_pid),
                exA=exA, exB=exB,
                config=CONFIG,
                meta_df=meta_all
            )
            if mixed_item is None:
                # still print fold line for LOSO(A) if you want; keep minimal
                continue

            x_mix = mixed_item["data"]
            boundary = mixed_item["boundary"]
            gt_total = float(mixed_item["count"])
            md = mixed_item["meta_detail"]

            pred_total, _ = predict_count_by_windowing(
                model, x_np=x_mix,
                fs=CONFIG["fs"], win_sec=CONFIG["win_sec"], stride_sec=CONFIG["stride_sec"],
                device=device, tau=CONFIG.get("tau", 1.0),
                batch_size=CONFIG.get("batch_size", 64)
            )

            diff = float(pred_total - gt_total)
            mae_total = float(abs(diff))
            scenario_diffs.append(diff)

            # A/B split
            xA = mixed_item["data_A"]
            xB = mixed_item["data_B"]
            gtA = float(md["gtA"])
            gtB = float(md["gtB"])

            pred_A, _ = predict_count_by_windowing(
                model, x_np=xA,
                fs=CONFIG["fs"], win_sec=CONFIG["win_sec"], stride_sec=CONFIG["stride_sec"],
                device=device, tau=CONFIG.get("tau", 1.0),
                batch_size=CONFIG.get("batch_size", 64)
            )
            pred_B, _ = predict_count_by_windowing(
                model, x_np=xB,
                fs=CONFIG["fs"], win_sec=CONFIG["win_sec"], stride_sec=CONFIG["stride_sec"],
                device=device, tau=CONFIG.get("tau", 1.0),
                batch_size=CONFIG.get("batch_size", 64)
            )

            diff_A = float(pred_A - gtA)
            diff_B = float(pred_B - gtB)
            mae_A = float(abs(diff_A))
            mae_B = float(abs(diff_B))
            scenario_diffs_A.append(diff_A)
            scenario_diffs_B.append(diff_B)

            # latent forward on mixed
            # Single un-windowed forward pass over the whole mixed signal.
            x_tensor = torch.tensor(x_mix, dtype=torch.float32).transpose(0, 1).unsqueeze(0).to(device)
            with torch.no_grad():
                _, z_m, _, aux_m = model(x_tensor, mask=None, tau=CONFIG.get("tau", 1.0))

            phase_p_m = aux_m["phase_p"].squeeze(0).detach().cpu().numpy()  # (Tmix,K)
            k_hat_m = float(aux_m["k_hat"].item())
            ent_m = compute_phase_entropy_mean(phase_p_m)

            z_td = z_m.squeeze(0).detach().cpu().numpy()  # (Tmix,D)

            scenario_records.append({
                "subj": f"subject{test_pid}",
                "line": (
                    f"[Scenario] {scen_name} | subject{test_pid} | "
                    f"{md['actA_name']}({md['gtA']:.0f}) -> {md['actB_name']}({md['gtB']:.0f}) | "
                    f"GT_total={gt_total:.0f} | Pred_total(win)={pred_total:.2f} | Diff_total={diff:+.2f} | "
                    f"MAE_total={mae_total:.2f} | "
                    f"Pred_A={pred_A:.2f} (Diff_A={diff_A:+.2f}, MAE_A={mae_A:.2f}) | "
                    f"Pred_B={pred_B:.2f} (Diff_B={diff_B:+.2f}, MAE_B={mae_B:.2f}) | "
                    f"k_hat(full)={k_hat_m:.2f} | ent(full)={ent_m:.3f} | boundary={boundary}"
                )
            })

            # (1) window-level rep_rate curves: A/B separately -> resample -> concat
            _, rA_w = get_window_rate_curve(
                model, xA, fs=CONFIG["fs"],
                win_sec=CONFIG["win_sec"], stride_sec=CONFIG["stride_sec"],
                device=device, tau=CONFIG.get("tau", 1.0),
                batch_size=CONFIG.get("batch_size", 64)
            )
            _, rB_w = get_window_rate_curve(
                model, xB, fs=CONFIG["fs"],
                win_sec=CONFIG["win_sec"], stride_sec=CONFIG["stride_sec"],
                device=device, tau=CONFIG.get("tau", 1.0),
                batch_size=CONFIG.get("batch_size", 64)
            )
            rA_rs = resample_1d(rA_w, NW)
            rB_rs = resample_1d(rB_w, NW)
            buf_win_concat.append(np.concatenate([rA_rs, rB_rs], axis=0))

            # (2) latent: store raw z split for global PCA later
            zA = z_td[:boundary, :]
            zB = z_td[boundary:, :]
            buf_z_concat.append(z_td)
            per_subject_latent.append({"zA": zA, "zB": zB})

            # optional: print fold LOSO(A) line in same style (kept minimal)
            if fold_mae is not None:
                print(f"Fold {fold_idx:2d} | Test: subject{test_pid} | MAE(A)={fold_mae:.2f} | {res_str}")

        # -------------------------
        # Print summaries (same style as your original)
        # -------------------------
        print("-" * 100)
        if len(loso_results) > 0:
            print("TRain 결과 ->")
            print("-" * 100)
            print(f" >>> Final LOSO Result (Average MAE): {np.mean(loso_results):.3f}")
            print(f" >>> Standard Deviation: {np.std(loso_results):.3f}")
            print("-" * 100)
        else:
            print("TRain 결과 -> (no folds computed)")
            print("-" * 100)

        # per subject scenario lines (stable ordering)
        subj2line = {r["subj"]: r["line"] for r in scenario_records}
        for pid in subjects:
            key = f"subject{pid}"
            if key in subj2line:
                print(subj2line[key])

        print("\n" + "-" * 100)
        print(f"시나리오{scen_idx} 결과")
        print("-" * 100)

        if len(scenario_diffs) == 0:
            print(f"[{scen_name}] No samples (all skipped). Check availability of A/B trials per subject.")
        else:
            diffs = np.array(scenario_diffs, dtype=np.float32)
            mae = float(np.mean(np.abs(diffs)))
            rmse = float(np.sqrt(np.mean(diffs ** 2)))
            print(f"[TOTAL] [{scen_name}] N={len(diffs):2d} | MAE={mae:.3f} | RMSE={rmse:.3f} | mean(diff)={diffs.mean():+.3f} | std(diff)={diffs.std():.3f}")

            diffsA = np.array(scenario_diffs_A, dtype=np.float32)
            maeA = float(np.mean(np.abs(diffsA)))
            rmseA = float(np.sqrt(np.mean(diffsA ** 2)))
            print(f"[A]     [{scen_name}] N={len(diffsA):2d} | MAE={maeA:.3f} | RMSE={rmseA:.3f} | mean(diff)={diffsA.mean():+.3f} | std(diff)={diffsA.std():.3f}")

            diffsB = np.array(scenario_diffs_B, dtype=np.float32)
            maeB = float(np.mean(np.abs(diffsB)))
            rmseB = float(np.sqrt(np.mean(diffsB ** 2)))
            print(f"[B]     [{scen_name}] N={len(diffsB):2d} | MAE={maeB:.3f} | RMSE={rmseB:.3f} | mean(diff)={diffsB.mean():+.3f} | std(diff)={diffsB.std():.3f}")

        print("-" * 100)

        # -------------------------
        # ✅ TWO-PLOT scenario-avg visualization (UNCHANGED)
        # -------------------------
        n_ok = len(buf_win_concat)
        if n_ok == 0:
            print("[Viz Skip] No scenario samples collected for averaging.")
            continue

        # (1) window-level mean±std on normalized axis [0,2]
        W = np.stack(buf_win_concat, axis=0)  # (N, 2*NW)
        win_mean = W.mean(axis=0)
        win_std  = W.std(axis=0)
        win_t = np.linspace(0.0, 2.0, num=2*NW, dtype=np.float32)

        # (2) latent: global PCA basis from all z
        # One shared basis so every subject is projected into the same 2-D plane.
        Z_all = np.concatenate(buf_z_concat, axis=0)  # (sumT, D)
        mu, V2 = global_pca2(Z_all)

        # project each subject, resample A/B to NL then concat -> (2*NL,2)
        lat_list = []
        for d in per_subject_latent:
            zA = d["zA"]; zB = d["zB"]
            pcA = (zA - mu) @ V2.T
            pcB = (zB - mu) @ V2.T
            pcA_rs = resample_2d(pcA, NL)
            pcB_rs = resample_2d(pcB, NL)
            lat_list.append(np.concatenate([pcA_rs, pcB_rs], axis=0))

        LAT = np.stack(lat_list, axis=0)  # (N,2*NL,2)
        lat_xy_mean = LAT.mean(axis=0)
        lat_xy_std  = LAT.std(axis=0)

        plot_scenario_mean_std_two_plots(
            save_dir=VIZ_DIR,
            scen_name=scen_name,
            win_t=win_t, win_mean=win_mean, win_std=win_std,
            lat_xy_mean=lat_xy_mean, lat_xy_std=lat_xy_std,
            n_subjects=n_ok,
            dpi=600
        )


# Entry point: run the experiment only when this code is executed directly
# (script or notebook cell), not when it is imported as a module.
if __name__ == "__main__":
    main()


Device: cuda

1. 시나리오: pushups-lunges   (TRAIN=pushups)
Fold  1 | Test: subject0 | MAE(A)=1.38 | [Pred(win): 7.3 / GT: 10][Pred(win): 10.0 / GT: 10][Pred(win): 9.8 / GT: 10][Pred(win): 9.5 / GT: 10][Pred(win): 8.4 / GT: 10][Pred(win): 7.5 / GT: 10][Pred(win): 9.7 / GT: 10][Pred(win): 11.9 / GT: 10][Pred(win): 10.5 / GT: 10][Pred(win): 7.5 / GT: 10][Pred(win): 8.2 / GT: 10][Pred(win): 7.8 / GT: 10][Pred(win): 11.4 / GT: 10][Pred(win): 9.3 / GT: 10][Pred(win): 9.0 / GT: 10][Pred(win): 9.1 / GT: 10][Pred(win): 7.6 / GT: 10][Pred(win): 8.2 / GT: 10]
Fold  2 | Test: subject1 | MAE(A)=1.71 | [Pred(win): 8.6 / GT: 10][Pred(win): 8.8 / GT: 10][Pred(win): 8.9 / GT: 10][Pred(win): 6.9 / GT: 10][Pred(win): 9.6 / GT: 10][Pred(win): 11.2 / GT: 10][Pred(win): 8.9 / GT: 10][Pred(win): 12.3 / GT: 10][Pred(win): 12.0 / GT: 10][Pred(win): 8.3 / GT: 10][Pred(win): 9.3 / GT: 10][Pred(win): 13.9 / GT: 10][Pred(win): 9.8 / GT: 10][Pred(win): 8.8 / GT: 10][Pred(win): 13.4 / GT: 10][Pred(win): 5.5 / GT: 10][P

KeyboardInterrupt: 

In [1]:
# ============================================================
# A1 Pairwise Activity Transfer — mm-Fit (ALL combinations)  ✅
#
# - Train: activity A windows (window-proxy supervision)
# - Test : activity B full trial -> windowing inference
#
# ✅ Visualization UNCHANGED:
#   (1) Train-PCA basis -> project Test (overlay)  [fixed xlim/ylim once]
#   (2) pred_rate(t) vs gt_rate overlay
#
# ✅ Runs ALL ordered pairs (A -> B, A!=B)
#
# ------------------------------------------------------------
# Assumption:
# - You already have mm-Fit extracted as:
#     MMFIT_OUT_DIR/
#       npz/*.npz
#       meta_*.csv   (or meta.csv)
# - Each meta row corresponds to one trial and contains:
#     activity/exercise label, participant/subject id, repetition count, npz file path (or filename)
# - Each npz contains time-series array:
#     key in {"X","x","data","imu","signal"} with shape (T,C)
#
# If your meta/npz schema differs, adjust only the small "schema mapping" part.
# ============================================================

import os
import glob
import random
import numpy as np
import pandas as pd

from scipy.ndimage import gaussian_filter1d

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable


# ---------------------------------------------------------------------
# 0) Small helpers (dir / downsample / PCA)
# ---------------------------------------------------------------------
def _ensure_dir(path: str):
    """Create directory ``path`` (with parents) unless it is ``None`` or empty.

    An already-existing directory is left untouched (``exist_ok=True``).
    """
    if path is not None and path != "":
        os.makedirs(path, exist_ok=True)


def _downsample_by_index(X, t, max_points=2000):
    """Thin the paired arrays ``X`` and ``t`` to at most ``max_points`` entries.

    The length of ``t`` decides whether thinning is needed. When it is, the
    same evenly spaced integer indices (first and last sample included) are
    applied to both arrays. Both inputs are converted with ``np.asarray``.

    Returns:
        (X_sub, t_sub)
    """
    values, stamps = np.asarray(X), np.asarray(t)
    total = len(stamps)
    if total > max_points:
        keep = np.linspace(0, total - 1, max_points).astype(np.int64)
        values, stamps = values[keep], stamps[keep]
    return values, stamps


def _fit_pca_basis(z_tD, n_comp=2, eps=1e-6):
    """
    Fit a PCA basis on per-feature standardized latents.

    z_tD: (T,D)
    returns: mu(D,), sig(D,), W(D,n_comp)  -- all float32

    ``sig`` is the per-feature std plus ``eps`` (avoids division by zero);
    ``W`` holds the leading left singular vectors of the covariance of the
    standardized data.
    """
    latents = np.asarray(z_tD, dtype=np.float32)
    mu = latents.mean(axis=0)
    sig = latents.std(axis=0) + eps

    standardized = (latents - mu) / sig
    dof = max(1, (standardized.shape[0] - 1))  # guards the single-sample case
    cov = (standardized.T @ standardized) / dof

    # Only the left singular vectors are needed for the projection basis.
    left_vecs = np.linalg.svd(cov, full_matrices=False)[0]
    W = left_vecs[:, :n_comp]
    return mu.astype(np.float32), sig.astype(np.float32), W.astype(np.float32)


def _pca_project(z_tD, mu, sig, W):
    """Standardize ``z_tD`` with ``(mu, sig)`` and project it onto basis ``W``.

    Returns the projected coordinates as a float32 array.
    """
    standardized = (np.asarray(z_tD, dtype=np.float32) - mu) / sig
    projected = standardized @ W
    return projected.astype(np.float32)


def _compute_ref_xlim_ylim_from_train_test_latents(train_z_tD, test_z_tD, pad_ratio=0.05,
                                                   q_low=1.0, q_high=99.0):
    """
    Compute zero-centered, symmetric axis limits for the 2-D PCA plots.

    - the PCA basis is fit on TRAIN only
    - TRAIN and TEST are both projected onto that basis
    - the range per axis is taken from the q_low..q_high percentiles of the
      pooled projections (use q_low=0, q_high=100 to ignore outlier handling)
    - the larger absolute bound is mirrored around 0 and widened by pad_ratio

    Returns:
        ((-Lx, +Lx), (-Ly, +Ly))
    """
    mu, sig, W = _fit_pca_basis(train_z_tD, n_comp=2)
    pooled = np.concatenate(
        [_pca_project(train_z_tD, mu, sig, W), _pca_project(test_z_tD, mu, sig, W)],
        axis=0,
    )

    def _padded_half_width(values):
        # Robust bounds first, then symmetric half-width with proportional padding.
        lo, hi = np.percentile(values, [q_low, q_high])
        half = max(abs(float(lo)), abs(float(hi)))
        return half * (1.0 + float(pad_ratio))

    Lx = _padded_half_width(pooled[:, 0])
    Ly = _padded_half_width(pooled[:, 1])
    return (-Lx, +Lx), (-Ly, +Ly)



# ---------------------------------------------------------------------
# 1) Strict Seeding
# ---------------------------------------------------------------------
def set_strict_seed(seed: int):
    """Seed Python, NumPy and PyTorch RNGs and request deterministic cuDNN.

    Also exports ``PYTHONHASHSEED`` to the environment. CUDA generators are
    seeded only when CUDA is available.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


# ---------------------------------------------------------------------
# 2) mm-Fit Loading (meta.csv + npz)
# ---------------------------------------------------------------------
def _find_first_existing(paths):
    """Return the first entry of ``paths`` that is an existing file, else ``None``.

    ``None`` entries are skipped.
    """
    return next((p for p in paths if p is not None and os.path.isfile(p)), None)


def _infer_col(df_cols, candidates):
    """Pick the column in ``df_cols`` that best matches one of ``candidates``.

    Matching is case-insensitive and done in two passes:
      1. exact name match, candidates tried in the given order;
      2. substring match, columns scanned in their own order.
    Returns the original column label, or ``None`` when nothing matches.
    """
    by_lower = {col.lower(): col for col in df_cols}
    exact = next(
        (by_lower[key] for key in (cand.lower() for cand in candidates) if key in by_lower),
        None,
    )
    if exact is not None:
        return exact

    # partial match
    return next(
        (col for col in df_cols
         if any(cand.lower() in col.lower() for cand in candidates)),
        None,
    )


def load_mmfit_meta(MMFIT_OUT_DIR):
    """
    Find every ``meta*.csv`` (e.g. ``meta_*.csv`` or ``meta.csv``) directly
    under MMFIT_OUT_DIR and concatenate them into one DataFrame.

    Each non-empty file contributes its rows plus a ``__meta_src__`` column
    holding the file's base name. Files that fail to parse are reported with a
    warning and skipped.

    Raises:
      FileNotFoundError: no meta*.csv file exists in MMFIT_OUT_DIR.
      RuntimeError: files exist but every one was empty or unreadable.
    """
    csv_files = sorted(glob.glob(os.path.join(MMFIT_OUT_DIR, "meta*.csv")))
    if len(csv_files) == 0:
        raise FileNotFoundError(f"[ERROR] No meta*.csv found in: {MMFIT_OUT_DIR}")

    frames = []
    for csv_file in csv_files:
        try:
            frame = pd.read_csv(csv_file)
            if len(frame) == 0:
                continue
            frame["__meta_src__"] = os.path.basename(csv_file)
            frames.append(frame)
        except Exception as e:
            print(f"[Warning] meta read failed: {csv_file} -> {e}")

    if len(frames) == 0:
        raise RuntimeError("[ERROR] meta csv files exist but all failed/empty.")

    return pd.concat(frames, axis=0, ignore_index=True)


def load_mmfit_trials_from_meta(MMFIT_OUT_DIR, meta_df, fs_default=100, npz_subdir="npz"):
    """
    Load every trial listed in ``meta_df`` whose npz file can be found.

    Args:
      MMFIT_OUT_DIR: extraction root; npz paths are resolved relative to it.
      meta_df: DataFrame with one row per trial. The activity / participant /
               count (and optional fs / path) columns are located by name.
      fs_default: sampling rate used when the meta has no usable fs value.
      npz_subdir: sub-folder of MMFIT_OUT_DIR that holds the npz files.

    Returns:
      activities: sorted list of unique activity names
      participants: sorted list of participant IDs (string)
      trials: list[dict] with keys:
        data (T,C) float32  [trial-wise zscore normalized]
        count float
        meta str
        subj str
        act str
        fs int
        raw_path str

    Raises:
      FileNotFoundError: the npz directory does not exist.
      ValueError: activity / participant / count columns cannot be identified.
      RuntimeError: no trial could be loaded.

    A meta row is skipped (never fatal) when its repetition count is missing
    or non-finite, when no npz file can be resolved for it, when the npz holds
    no non-empty 2-D array, or when reading the file fails.
    """
    npz_dir = os.path.join(MMFIT_OUT_DIR, npz_subdir)
    if not os.path.isdir(npz_dir):
        raise FileNotFoundError(f"[ERROR] npz dir not found: {npz_dir}")

    # ---- robust column mapping ----
    col_act = _infer_col(meta_df.columns, ["activity", "exercise", "label", "act", "activity_name", "class"])
    col_sub = _infer_col(meta_df.columns, ["participant", "subject", "user", "pid", "subj", "person"])
    col_cnt = _infer_col(meta_df.columns, ["count", "reps", "rep", "repetition", "rep_count", "total_reps", "n_reps"])
    col_fs  = _infer_col(meta_df.columns, ["fs", "freq", "sampling_rate", "hz"])
    col_pth = _infer_col(meta_df.columns, ["npz_path", "path", "file", "filename", "npz_file", "npz"])

    if col_act is None or col_sub is None or col_cnt is None:
        raise ValueError(
            "[ERROR] meta.csv must contain activity/participant/count columns.\n"
            f"  - found: act={col_act}, subj={col_sub}, count={col_cnt}, fs={col_fs}, path={col_pth}\n"
            "  - please rename columns or add mapping in load_mmfit_trials_from_meta()."
        )

    trials = []
    n_no_file = 0
    for i, row in meta_df.iterrows():
        act = str(row[col_act])
        subj = str(row[col_sub])

        # A NaN / unparsable count would silently turn into NaN training
        # targets (and NaN losses), so such rows are dropped up front.
        try:
            cnt = float(row[col_cnt])
        except (TypeError, ValueError):
            cnt = float("nan")
        if not np.isfinite(cnt):
            print(f"[Warning] invalid repetition count in meta row {i}: {row[col_cnt]!r} -> skipped")
            continue

        fs = fs_default
        if col_fs is not None:
            try:
                fs = int(row[col_fs])
            except (TypeError, ValueError, OverflowError):
                # NaN / inf / non-numeric cell: fall back to the default rate.
                fs = fs_default

        raw_path = row[col_pth] if col_pth is not None else None
        npz_path = _resolve_mmfit_npz_path(MMFIT_OUT_DIR, npz_dir, raw_path, subj, act)
        if npz_path is None:
            # meta row without a matching file; reported in the summary below
            n_no_file += 1
            continue

        try:
            X = _read_mmfit_signal(npz_path)
            if X is None:
                continue

            # trial-wise z-score
            mean = X.mean(axis=0)
            std = X.std(axis=0) + 1e-6
            Xn = (X - mean) / std

            trials.append({
                "data": Xn,                 # (T,C)
                "count": float(cnt),         # total reps
                "meta": f"{subj}_{act}__row{i}",
                "subj": subj,
                "act": act,
                "fs": fs,
                "raw_path": npz_path
            })
        except Exception as e:
            # Best-effort loading: one broken file must not abort the run.
            print(f"[Warning] npz load failed: {npz_path} -> {e}")
            continue

    if n_no_file > 0:
        print(f"[Warning] {n_no_file} meta row(s) skipped: npz file not found.")

    if len(trials) == 0:
        raise RuntimeError("[ERROR] No trials loaded. Check meta columns / npz paths / npz contents.")

    activities = sorted({t["act"] for t in trials})
    participants = sorted({t["subj"] for t in trials})
    return activities, participants, trials


def _resolve_mmfit_npz_path(out_dir, npz_dir, raw_path, subj, act):
    """Return the npz file for one meta row, or ``None`` if none can be found."""
    if raw_path is not None and pd.notna(raw_path):
        p = str(raw_path)
        candidates = []
        if os.path.isabs(p):
            candidates.append(p)
        # relative path or bare filename
        candidates.append(os.path.join(out_dir, p))
        candidates.append(os.path.join(npz_dir, os.path.basename(p)))
        for cand in candidates:
            if os.path.isfile(cand):
                return cand

    # fallback: try to find by pattern (subj + act)
    # NOTE: your extractor naming convention may differ.
    # glob.escape keeps characters such as "[" in the names literal.
    pattern = os.path.join(npz_dir, f"*{glob.escape(subj)}*{glob.escape(act)}*.npz")
    hits = sorted(glob.glob(pattern))
    if len(hits) > 0 and os.path.isfile(hits[0]):
        return hits[0]
    return None


def _read_mmfit_signal(npz_path):
    """Read the (T,C) float32 signal from an npz file; ``None`` if unusable."""
    # NOTE(security): allow_pickle=True can execute arbitrary code embedded in
    # a malicious file. Kept for compatibility with existing extractions; only
    # load npz files from trusted sources.
    with np.load(npz_path, allow_pickle=True) as npz:  # closes the archive handle
        X = None
        for k in ["X", "x", "data", "imu", "signal"]:
            if k in npz.files:
                X = npz[k]
                break
        if X is None:
            # no known key: take the first 2-D array in the archive
            for k in npz.files:
                arr = npz[k]
                if isinstance(arr, np.ndarray) and arr.ndim == 2:
                    X = arr
                    break
    if X is None:
        return None

    X = np.asarray(X, dtype=np.float32)
    # An empty trial cannot be z-scored (mean/std of zero samples is NaN).
    if X.ndim != 2 or X.shape[0] == 0:
        return None
    return X


def filter_trials_by_activity(trials, act_name):
    """Return the trials whose ``"act"`` field equals ``act_name`` (order kept)."""
    selected = []
    for trial in trials:
        if trial["act"] == act_name:
            selected.append(trial)
    return selected


# ---------------------------------------------------------------------
# 2.5) Windowing (same)
# ---------------------------------------------------------------------
def trial_list_to_windows(trial_list, fs, win_sec=8.0, stride_sec=4.0, drop_last=True):
    """Cut trials into fixed-length windows labelled with a proportional count.

    Each trial's repetition rate (count / duration, reps per second) is assumed
    constant, so a window's ``count`` is that rate times the window duration.

    - A trial shorter than one window becomes a single window with all its data.
    - Otherwise windows of ``win_sec`` start every ``stride_sec``.
    - With ``drop_last=False`` the samples after the last stride step are
      emitted as one extra, shorter window.

    Returns:
        list of dicts with keys ``data``, ``count`` and ``meta``; ``meta`` is
        the trial's meta string plus ``__win[start:end]`` in samples.
    """
    win_len = int(round(win_sec * fs))
    stride = int(round(stride_sec * fs))
    assert win_len > 0 and stride > 0

    fs_f = float(fs)

    def _window(chunk, lo, hi, rate, tag):
        return {
            "data": chunk,
            "count": rate * ((hi - lo) / fs_f),
            "meta": f"{tag}__win[{lo}:{hi}]",
        }

    out = []
    for trial in trial_list:
        signal = trial["data"]  # (T,C)
        n = signal.shape[0]
        reps = float(trial["count"])
        tag = trial["meta"]

        rate = reps / max(n / fs_f, 1e-6)  # reps/s

        if n < win_len:
            # Too short for a full window: keep the trial as one window.
            out.append(_window(signal, 0, n, rate, tag))
            continue

        starts = range(0, (n - win_len) + 1, stride)
        out.extend(_window(signal[s:s + win_len], s, s + win_len, rate, tag) for s in starts)

        if drop_last:
            continue
        tail_start = starts[-1] + stride
        if tail_start < n:
            out.append(_window(signal[tail_start:n], tail_start, n, rate, tag))

    return out


def predict_count_by_windowing(model, x_np, fs, win_sec, stride_sec, device, tau=1.0, batch_size=64):
    """Predict the total repetition count of one trial by sliding-window inference.

    Args:
        model: callable as ``model(x, mask=None, tau=tau)`` whose first return
            value is the per-sample rate (reps/s), one entry per batch item.
        x_np: trial signal, indexed as (T, C).
        fs: sampling rate used to convert samples to seconds.
        win_sec, stride_sec: window length and stride in seconds.
        device: device the input batches are moved to.
        tau: temperature forwarded to the model.
        batch_size: number of windows evaluated per forward pass.

    Returns:
        float: mean window rate times the trial duration. A trial no longer
        than one window is evaluated in a single pass over the whole trial.

    The model is switched to eval mode for both the short-trial and the
    windowed path (and is left in eval mode afterwards).
    """
    win_len = int(round(win_sec * fs))
    stride = int(round(stride_sec * fs))
    T = x_np.shape[0]
    total_dur = T / float(fs)

    # Must precede BOTH branches: previously only the windowed path called
    # eval(), so short trials could be scored in training mode.
    model.eval()

    if T <= win_len:
        x_tensor = torch.tensor(x_np, dtype=torch.float32).transpose(0, 1).unsqueeze(0).to(device)  # (1,C,T)
        with torch.no_grad():
            rate_hat, _, _, _ = model(x_tensor, mask=None, tau=tau)
        return float(rate_hat.item() * total_dur)

    starts = range(0, T - win_len + 1, stride)
    windows = np.stack([x_np[st:st + win_len] for st in starts], axis=0)  # (N, win_len, C)
    xw = torch.tensor(windows, dtype=torch.float32).permute(0, 2, 1)       # (N,C,win_len), stays on CPU

    rates = []
    with torch.no_grad():
        for i in range(0, xw.shape[0], batch_size):
            # Move one batch at a time so long trials do not sit on the device whole.
            xb = xw[i:i + batch_size].to(device)
            r_hat, _, _, _ = model(xb, mask=None, tau=tau)
            rates.append(r_hat.detach().cpu().numpy())

    rates = np.concatenate(rates, axis=0)  # (N,)
    rate_mean = float(rates.mean())
    return float(rate_mean * total_dur)


def predict_count_and_rates_by_windowing(model, x_np, fs, win_sec, stride_sec, device, tau=1.0, batch_size=64):
    """Sliding-window count prediction that also returns the per-window rates.

    Args:
        model: callable as ``model(x, mask=None, tau=tau)`` whose first return
            value is the per-sample rate (reps/s), one entry per batch item.
        x_np: trial signal, indexed as (T, C).
        fs: sampling rate used to convert samples to seconds.
        win_sec, stride_sec: window length and stride in seconds.
        device: device the input batches are moved to.
        tau: temperature forwarded to the model.
        batch_size: number of windows evaluated per forward pass.

    Returns:
        (pred_count, t_centers, rates, total_dur)
          pred_count: mean window rate times the trial duration (float)
          t_centers:  float32 array of window-center times in seconds
          rates:      float32 array of per-window rates
          total_dur:  trial duration in seconds (float)
        A trial no longer than one window yields a single center at half the
        trial duration.

    The model is switched to eval mode for both the short-trial and the
    windowed path (and is left in eval mode afterwards).
    """
    win_len = int(round(win_sec * fs))
    stride = int(round(stride_sec * fs))
    T = x_np.shape[0]
    total_dur = T / float(fs)

    # Must precede BOTH branches: previously only the windowed path called
    # eval(), so short trials could be scored in training mode.
    model.eval()

    if T <= win_len:
        x_tensor = torch.tensor(x_np, dtype=torch.float32).transpose(0, 1).unsqueeze(0).to(device)  # (1,C,T)
        with torch.no_grad():
            rate_hat, _, _, _ = model(x_tensor, mask=None, tau=tau)
        rate = float(rate_hat.item())
        pred_count = rate * total_dur
        t_centers = np.array([0.5 * total_dur], dtype=np.float32)
        rates = np.array([rate], dtype=np.float32)
        return float(pred_count), t_centers, rates, float(total_dur)

    starts = range(0, T - win_len + 1, stride)
    centers = np.array([st + 0.5 * win_len for st in starts], dtype=np.float32) / float(fs)

    windows = np.stack([x_np[st:st + win_len] for st in starts], axis=0)  # (N, win_len, C)
    xw = torch.tensor(windows, dtype=torch.float32).permute(0, 2, 1)       # (N,C,win_len), stays on CPU

    rates_list = []
    with torch.no_grad():
        for i in range(0, xw.shape[0], batch_size):
            # Move one batch at a time so long trials do not sit on the device whole.
            xb = xw[i:i + batch_size].to(device)
            r_hat, _, _, _ = model(xb, mask=None, tau=tau)
            rates_list.append(r_hat.detach().cpu().numpy())

    rates = np.concatenate(rates_list, axis=0).astype(np.float32)  # (N,)
    pred_count = float(rates.mean() * total_dur)
    return float(pred_count), centers, rates, float(total_dur)


# ---------------------------------------------------------------------
# 2.8) Dataset / Collate
# ---------------------------------------------------------------------
class TrialDataset(Dataset):
    """Dataset over a list of trial/window dicts.

    Each item is returned as ``(data, count, meta)`` where ``data`` is the
    float32 signal with its first two axes swapped (stored (T, C) -> (C, T)),
    ``count`` is a float32 scalar tensor and ``meta`` is passed through as-is.
    """

    def __init__(self, trial_list):
        self.trials = trial_list

    def __len__(self):
        return len(self.trials)

    def __getitem__(self, idx):
        record = self.trials[idx]
        signal_tc = torch.tensor(record['data'], dtype=torch.float32)
        signal_ct = torch.transpose(signal_tc, 0, 1)  # (C, T)
        target = torch.tensor(record['count'], dtype=torch.float32)
        return signal_ct, target, record['meta']


def collate_variable_length(batch):
    """Collate ``(data, count, meta)`` samples of differing length into a batch.

    Every ``data`` tensor (C, T) is zero-padded on the right up to the longest
    T in the batch; ``mask`` marks real samples with 1 and padding with 0.

    Returns a dict with:
        data   (B, C, T_max)
        mask   (B, T_max)
        count  (B,)
        length (B,) float32, the unpadded lengths
        meta   list of the per-sample meta values
    """
    max_len = max(sample[0].shape[1] for sample in batch)
    n_ch = batch[0][0].shape[0]

    signals, valid_masks, counts, metas, lengths = [], [], [], [], []
    for signal, count, meta in batch:
        n = signal.shape[1]
        deficit = max_len - n

        # 1 for the first n (real) steps, 0 for the padded tail.
        valid = torch.zeros(max_len)
        valid[:n] = 1.0

        if deficit > 0:
            signal = torch.cat([signal, torch.zeros(n_ch, deficit)], dim=1)

        signals.append(signal)
        valid_masks.append(valid)
        counts.append(count)
        metas.append(meta)
        lengths.append(n)

    return {
        "data": torch.stack(signals),            # (B, C, T_max)
        "mask": torch.stack(valid_masks),        # (B, T_max)
        "count": torch.stack(counts),            # (B,)
        "length": torch.tensor(lengths, dtype=torch.float32),  # (B,)
        "meta": metas
    }


# ---------------------------------------------------------------------
# 3) Model (same)
# ---------------------------------------------------------------------
class ManifoldEncoder(nn.Module):
    """1-D convolutional encoder: (B, input_ch, T) -> latent sequence (B, T, latent_dim).

    Two width-5 convolutions (padding 2, so T is preserved) with ReLU, followed
    by a 1x1 convolution down to ``latent_dim`` channels.
    """

    def __init__(self, input_ch, hidden_dim=128, latent_dim=16):
        super().__init__()
        layers = [
            nn.Conv1d(input_ch, hidden_dim, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv1d(hidden_dim, hidden_dim, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv1d(hidden_dim, latent_dim, kernel_size=1),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        # conv output is (B,D,T); callers expect time-major latents (B,T,D)
        return self.net(x).transpose(1, 2)


class ManifoldDecoder(nn.Module):
    """1-D convolutional decoder: latent sequence (B, T, latent_dim) -> (B, out_ch, T).

    Mirrors the encoder layout: two width-5 convolutions (padding 2) with ReLU,
    then a 1x1 convolution to ``out_ch`` channels.
    """

    def __init__(self, latent_dim, hidden_dim, out_ch):
        super().__init__()
        layers = [
            nn.Conv1d(latent_dim, hidden_dim, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv1d(hidden_dim, hidden_dim, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv1d(hidden_dim, out_ch, kernel_size=1),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, z):
        # latents arrive time-major (B,T,D); Conv1d needs channels first (B,D,T)
        channels_first = z.transpose(1, 2)
        return self.net(channels_first)  # (B,C,T)


class MultiRateHead(nn.Module):
    """Per-timestep head producing a non-negative amplitude and K_max phase probabilities.

    The MLP outputs ``1 + K_max`` values per step: channel 0 is passed through
    softplus to give the amplitude, the remaining channels are phase logits
    that are softmax-normalized with temperature ``tau``.
    """

    def __init__(self, latent_dim=16, hidden=64, K_max=6):
        super().__init__()
        self.K_max = K_max
        out_dim = 1 + K_max  # [amp | phase_logits...]
        self.net = nn.Sequential(
            nn.Linear(latent_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, out_dim),
        )

    def forward(self, z, tau=1.0):
        raw = self.net(z)                                  # (B,T,1+K)
        amp_raw, phase_logits = raw[..., 0], raw[..., 1:]  # (B,T), (B,T,K)
        amp = F.softplus(amp_raw)                          # >= 0
        phase = F.softmax(phase_logits / tau, dim=-1)
        return amp, phase, phase_logits


class KAutoCountModel(nn.Module):
    """Encoder / decoder / rate-head model that outputs an average repetition rate.

    ``forward`` returns ``(avg_rep_rate, z, x_hat, aux)``:
      avg_rep_rate (B,)   time-averaged repetition rate (masked if a mask is given)
      z            (B,T,D) encoder latents
      x_hat        (B,C,T) decoder reconstruction
      aux          dict with ``phase_p`` (B,T,K), ``rep_rate_t`` (B,T), ``k_hat`` (B,)
    """

    def __init__(self, input_ch, hidden_dim=128, latent_dim=16, K_max=6):
        super().__init__()
        self.encoder = ManifoldEncoder(input_ch, hidden_dim, latent_dim)
        self.decoder = ManifoldDecoder(latent_dim, hidden_dim, input_ch)
        self.rate_head = MultiRateHead(latent_dim, hidden=hidden_dim, K_max=K_max)
        self._init_weights()

    def _init_weights(self):
        """Kaiming-init all conv/linear weights, zero biases, bias the amplitude output to -2."""
        for module in self.modules():
            if not isinstance(module, (nn.Conv1d, nn.Linear)):
                continue
            nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)

        # Final head layer: all biases zero except the amplitude channel (index 0).
        with torch.no_grad():
            head_bias = self.rate_head.net[-1].bias
            head_bias.zero_()
            head_bias[0].fill_(-2.0)

    @staticmethod
    def _masked_mean_time(x, mask=None, eps=1e-6):
        """Mean over dim 1 of a 2-D or 3-D tensor, restricted to ``mask`` when given."""
        if mask is None:
            return x.mean(dim=1)
        if x.dim() not in (2, 3):
            raise ValueError(f"Unsupported dim for masked mean: {x.dim()}")

        weights = mask.to(dtype=x.dtype, device=x.device)
        if x.dim() == 3:
            weights = weights.unsqueeze(-1)  # broadcast over the feature axis
        return (x * weights).sum(dim=1) / (weights.sum(dim=1) + eps)

    def forward(self, x, mask=None, tau=1.0):
        latent = self.encoder(x)             # (B,T,D)
        recon = self.decoder(latent)         # (B,C,T)

        # The per-step amplitude is used directly as the micro rate.
        micro_rate_t, phase_p, _ = self.rate_head(latent, tau=tau)

        # Inverse sum of squared mean phase usage, one value per sample.
        p_bar = self._masked_mean_time(phase_p, mask)           # (B,K)
        k_hat = 1.0 / (p_bar.pow(2).sum(dim=1) + 1e-6)          # (B,)

        rep_rate_t = micro_rate_t / (k_hat.unsqueeze(1) + 1e-6) # (B,T)
        if mask is None:
            avg_rep_rate = rep_rate_t.mean(dim=1)
        else:
            rep_rate_t = rep_rate_t * mask
            avg_rep_rate = (rep_rate_t * mask).sum(dim=1) / (mask.sum(dim=1) + 1e-6)

        aux = {
            "phase_p": phase_p,          # (B,T,K)
            "rep_rate_t": rep_rate_t,    # (B,T)
            "k_hat": k_hat,              # (B,)
        }
        return avg_rep_rate, latent, recon, aux


# ---------------------------------------------------------------------
# 4) Loss utils (same)
# ---------------------------------------------------------------------
def masked_recon_mse(x_hat, x, mask, eps=1e-6):
    """Mean squared reconstruction error over unmasked time steps.

    ``x_hat`` and ``x`` are (B,C,T); ``mask`` is (B,T). The squared error is
    summed over valid positions and divided by ``mask.sum() * C`` (+ ``eps``).
    """
    valid = mask.to(dtype=x.dtype, device=x.device)
    sq_err = (x_hat - x).pow(2) * valid.unsqueeze(1)   # mask broadcast as (B,1,T)
    n_valid = (valid.sum() * x.shape[1]) + eps
    return sq_err.sum() / n_valid


def temporal_smoothness(v, mask=None, eps=1e-6):
    """Mean absolute difference between consecutive time steps of ``v`` (B,T).

    With a mask, only step pairs whose two endpoints are both valid contribute.
    """
    step = (v[:, 1:] - v[:, :-1]).abs()
    if mask is None:
        return step.mean()

    pair_ok = (mask[:, 1:] * mask[:, :-1]).to(dtype=step.dtype, device=step.device)
    return (step * pair_ok).sum() / (pair_ok.sum() + eps)


def phase_entropy_loss(phase_p, mask=None, eps=1e-8):
    """Average per-step entropy of the phase distribution ``phase_p`` (B,T,K).

    With a mask (B,T), the average is taken over unmasked steps only.
    """
    per_step = -(phase_p * torch.log(phase_p + eps)).sum(dim=-1)  # (B,T)
    if mask is not None:
        weighted = per_step * mask
        return weighted.sum() / (mask.sum() + eps)
    return per_step.mean()


def effK_usage_loss(phase_p, mask=None, eps=1e-6):
    """Batch mean of the effective phase count ``1 / sum_k(p_bar_k^2)``.

    ``p_bar`` is the time-averaged phase distribution of each sample, computed
    over unmasked steps when ``mask`` (B,T) is given.
    """
    if mask is not None:
        weights = mask.to(dtype=phase_p.dtype, device=phase_p.device).unsqueeze(-1)
        usage = (phase_p * weights).sum(dim=1) / (weights.sum(dim=1) + eps)
    else:
        usage = phase_p.mean(dim=1)

    eff_k = 1.0 / (usage.pow(2).sum(dim=1) + eps)
    return eff_k.mean()


def compute_phase_entropy_mean(phase_p_np, eps=1e-8):
    """Return the mean per-step entropy of a (T,K) phase-probability array as a float."""
    probs = np.asarray(phase_p_np, dtype=np.float32)  # (T,K)
    per_step = -np.sum(probs * np.log(probs + eps), axis=1)
    return float(np.mean(per_step))


# ---------------------------------------------------------------------
# 4.5) A1 Plots (UNCHANGED)
# ---------------------------------------------------------------------
def plot_pca2d_train_test_overlay(
    train_z_tD, train_t_sec,
    test_z_tD, test_t_sec,
    save_path,
    dpi=600,
    max_points=2000,
    show_colorbar=True,
    tick_pad=0,
    point_size=10,
    alpha_test=0.9,
    alpha_train=0.35,
    train_lw=2.0,
    fixed_xlim=None,
    fixed_ylim=None
):
    """Scatter TEST latents over TRAIN latents in the 2-D PCA basis fit on TRAIN.

    Both latent sequences are thinned to at most ``max_points`` points. TRAIN
    is drawn in grey as a reference; TEST is coloured by its time normalized
    to 0..1. Axis limits come from ``fixed_xlim`` / ``fixed_ylim`` when both
    are given, otherwise from the data range with 5% padding. The figure is
    saved to ``save_path`` (parent directory created if needed) and closed.

    NOTE(review): ``alpha_train`` and ``train_lw`` are accepted but not used;
    the TRAIN scatter is drawn with a fixed alpha of 0.25.
    """
    tr_z, _ = _downsample_by_index(
        np.asarray(train_z_tD, dtype=np.float32),
        np.asarray(train_t_sec, dtype=np.float32),
        max_points,
    )
    te_z, te_t = _downsample_by_index(
        np.asarray(test_z_tD, dtype=np.float32),
        np.asarray(test_t_sec, dtype=np.float32),
        max_points,
    )

    # Colour value: test time rescaled to [0, 1] (span floored to avoid /0).
    t_first = te_t.min()
    test_c = (te_t - t_first) / max(1e-6, (te_t.max() - t_first))

    mu, sig, W = _fit_pca_basis(tr_z, n_comp=2)
    P_tr = _pca_project(tr_z, mu, sig, W)
    P_te = _pca_project(te_z, mu, sig, W)

    fig, ax = plt.subplots(1, 1, figsize=(6.0, 6.0))

    ax.scatter(P_tr[:, 0], P_tr[:, 1], s=point_size, alpha=0.25, color="0.6", label="Train (ref)")
    sc = ax.scatter(P_te[:, 0], P_te[:, 1], c=test_c, s=point_size, alpha=alpha_test,
                    cmap="viridis", label="Test (proj)")

    ax.set_xlabel("PC1", fontsize=16)
    ax.set_ylabel("PC2", fontsize=16)
    ax.tick_params(axis="both", which="major", labelsize=14, pad=tick_pad)
    ax.grid(False)

    if fixed_xlim is None or fixed_ylim is None:
        # Auto limits: pooled train/test range with 5% padding per axis.
        for set_lim, col in ((ax.set_xlim, 0), (ax.set_ylim, 1)):
            pooled = np.concatenate([P_tr[:, col], P_te[:, col]], axis=0)
            lo, hi = float(pooled.min()), float(pooled.max())
            margin = 0.05 * max(1e-6, (hi - lo))
            set_lim(lo - margin, hi + margin)
    else:
        ax.set_xlim(float(fixed_xlim[0]), float(fixed_xlim[1]))
        ax.set_ylim(float(fixed_ylim[0]), float(fixed_ylim[1]))

    if show_colorbar:
        cax = make_axes_locatable(ax).append_axes("right", size="4.5%", pad=0.10)
        cbar = fig.colorbar(sc, cax=cax)
        cbar.set_label("Normalized time (0→1)", fontsize=14, labelpad=10)
        cbar.ax.tick_params(labelsize=12, pad=tick_pad)
        cbar.outline.set_linewidth(1.2)

    _ensure_dir(os.path.dirname(save_path))
    fig.savefig(save_path, dpi=dpi, bbox_inches="tight")
    plt.close(fig)


def plot_predrate_vs_gtrate(
    t_sec,
    pred_rate,
    gt_rate,
    save_path,
    dpi=600,
    tick_pad=0,
    smooth_sigma=1.0,
    lw_pred=4.0,
    lw_gt=4.0,
    legend_loc="best"
):
    """Plot the window-level predicted rate over time against the trial-level GT rate.

    The predicted curve is Gaussian-smoothed when ``smooth_sigma`` is a
    positive number; the ground-truth rate is drawn as a dashed horizontal
    line. The figure is saved to ``save_path`` (parent directory created if
    needed) and closed.
    """
    times = np.asarray(t_sec, dtype=np.float32)
    rates = np.asarray(pred_rate, dtype=np.float32)

    apply_smoothing = smooth_sigma is not None and smooth_sigma > 0
    curve = gaussian_filter1d(rates, sigma=float(smooth_sigma)) if apply_smoothing else rates

    fig, ax = plt.subplots(1, 1, figsize=(8.5, 5.5))

    ax.plot(times, curve, linewidth=lw_pred, label="Pred rate (window-level)")
    ax.axhline(y=float(gt_rate), linestyle="--", linewidth=lw_gt, label="GT rate (trial-level)")

    ax.set_xlabel("Time (s)", fontsize=34)
    ax.set_ylabel("Repetition rate (reps/s)", fontsize=34)
    ax.tick_params(axis="both", which="major", labelsize=24, pad=tick_pad)
    ax.grid(False)
    ax.legend(loc=legend_loc, frameon=False, fontsize=18)

    _ensure_dir(os.path.dirname(save_path))
    fig.savefig(save_path, dpi=dpi, bbox_inches="tight")
    plt.close(fig)


# ---------------------------------------------------------------------
# 5) Train (same)
# ---------------------------------------------------------------------
def train_one_epoch(model, loader, optimizer, config, device):
    """Run one optimization pass over ``loader``.

    The target rate of each sample is ``count / (length / fs)``. The loss is
    the rate MSE plus weighted reconstruction, temporal-smoothness,
    phase-entropy and effective-K terms; the weights are read from ``config``
    (``lambda_recon``, ``lambda_smooth``, ``lambda_phase_ent``,
    ``lambda_effk``) with the defaults shown below. Returns nothing.
    """
    model.train()
    fs = config["fs"]
    tau = config.get("tau", 1.0)

    w_recon = config.get("lambda_recon", 1.0)
    w_smooth = config.get("lambda_smooth", 0.05)
    w_phase_ent = config.get("lambda_phase_ent", 0.01)
    w_effk = config.get("lambda_effk", 0.005)

    for batch in loader:
        x, mask, y_count, length = (
            batch[key].to(device) for key in ("data", "mask", "count", "length")
        )

        # Convert the count label into a rate using the unpadded duration.
        y_rate = y_count / torch.clamp(length / fs, min=1e-6)

        optimizer.zero_grad()
        rate_hat, _, x_hat, aux = model(x, mask, tau=tau)
        rep_rate_t, phase_p = aux["rep_rate_t"], aux["phase_p"]

        loss = (F.mse_loss(rate_hat, y_rate)
                + w_recon * masked_recon_mse(x_hat, x, mask)
                + w_smooth * temporal_smoothness(rep_rate_t, mask)
                + w_phase_ent * phase_entropy_loss(phase_p, mask)
                + w_effk * effK_usage_loss(phase_p, mask))

        loss.backward()
        optimizer.step()


# ---------------------------------------------------------------------
# 6) Main — mm-Fit ALL combinations
# ---------------------------------------------------------------------
def main():
    CONFIG = {
        "seed": 42,

        # ✅ EDIT THIS
        "MMFIT_OUT_DIR": "/content/drive/MyDrive/Colab Notebooks/HAR_data/mmfit_imu_3ex_trials",

        # If your extractor put npz under different folder name:
        "npz_subdir": "npz",

        # sampling rate (if meta lacks fs col)
        "fs": 100,

        "epochs": 100,
        "lr": 5e-4,
        "batch_size": 64,

        # windowing
        "win_sec": 8.0,
        "stride_sec": 4.0,
        "drop_last": True,

        # model
        "hidden_dim": 128,
        "latent_dim": 16,
        "K_max": 6,

        # loss
        "lambda_recon": 1.0,
        "lambda_smooth": 0.05,
        "lambda_phase_ent": 0.01,
        "lambda_effk": 0.0075,

        "tau": 1.0,

        # outputs
        "plot_dir": "/content/A1_mmfit_pairwise_A1viz",

        # plot policy
        # - each pair selects ONE "picked trial" for A1 plots.
        #   "worst" is usually more diagnostic for failure-like alignment issues.
        "pick_policy": "worst",   # "worst" | "best" | "median"
    }

    set_strict_seed(CONFIG["seed"])
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    _ensure_dir(CONFIG["plot_dir"])

    # -----------------------------
    # Load meta + trials
    # -----------------------------
    meta = load_mmfit_meta(CONFIG["MMFIT_OUT_DIR"])
    activities, participants, all_trials = load_mmfit_trials_from_meta(
        MMFIT_OUT_DIR=CONFIG["MMFIT_OUT_DIR"],
        meta_df=meta,
        fs_default=CONFIG["fs"],
        npz_subdir=CONFIG["npz_subdir"]
    )

    print(f"[mm-Fit] Loaded trials={len(all_trials)} | activities={activities} | participants={participants}")

    # -----------------------------
    # Build all ordered pairs (A->B, A!=B)
    # -----------------------------
    exp_dirs = []
    for a in activities:
        for b in activities:
            if a == b:
                continue
            exp_dirs.append((a, b))

    # ✅ Reference axis limits (compute once)
    REF_XLIM, REF_YLIM = None, None
    REF_PAIR = exp_dirs[0] if len(exp_dirs) > 0 else None
    REF_SUBJ = participants[0] if len(participants) > 0 else None

    final_blocks = []
    exp_idx = 0

    for train_act, test_act in exp_dirs:
        exp_idx += 1
        set_strict_seed(CONFIG["seed"])

        train_trials = filter_trials_by_activity(all_trials, train_act)
        test_trials  = filter_trials_by_activity(all_trials, test_act)

        if len(train_trials) == 0 or len(test_trials) == 0:
            print(f"[Skip] Empty trials: train_act={train_act}, test_act={test_act}")
            continue

        # fs policy:
        # - training uses CONFIG["fs"] (assumed consistent after extraction)
        # - if your trials have mixed fs, you should resample first (not done here to keep logic unchanged).
        fs = CONFIG["fs"]

        # windowize train
        train_windows = trial_list_to_windows(
            train_trials,
            fs=fs,
            win_sec=CONFIG["win_sec"],
            stride_sec=CONFIG["stride_sec"],
            drop_last=CONFIG["drop_last"]
        )
        if len(train_windows) == 0:
            print(f"[Skip] Empty train windows: train_act={train_act}")
            continue

        # check channel consistency
        input_ch = train_windows[0]["data"].shape[1]
        ok = True
        for t in train_trials + test_trials:
            if t["data"].shape[1] != input_ch:
                ok = False
                break
        if not ok:
            print(f"[Skip] Channel mismatch in pair {train_act}->{test_act}. (C differs across trials)")
            continue

        g = torch.Generator()
        g.manual_seed(CONFIG["seed"])
        train_loader = DataLoader(
            TrialDataset(train_windows),
            batch_size=CONFIG["batch_size"],
            shuffle=True,
            collate_fn=collate_variable_length,
            generator=g,
            num_workers=0
        )

        model = KAutoCountModel(
            input_ch=input_ch,
            hidden_dim=CONFIG["hidden_dim"],
            latent_dim=CONFIG["latent_dim"],
            K_max=CONFIG["K_max"]
        ).to(device)

        optimizer = torch.optim.Adam(model.parameters(), lr=CONFIG["lr"])
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.5)

        for epoch in range(CONFIG["epochs"]):
            train_one_epoch(model, train_loader, optimizer, CONFIG, device)
            scheduler.step()

        model.eval()

        # evaluate on test trials
        block_lines = []
        abs_errs = []
        per_trial = {}

        for item in test_trials:
            subj = item["subj"]
            x_np = item["data"]
            gt = float(item["count"])

            pred = predict_count_by_windowing(
                model,
                x_np=x_np,
                fs=fs,
                win_sec=CONFIG["win_sec"],
                stride_sec=CONFIG["stride_sec"],
                device=device,
                tau=CONFIG.get("tau", 1.0),
                batch_size=CONFIG.get("batch_size", 64)
            )

            diff = pred - gt
            ae = abs(diff)
            abs_errs.append(ae)

            # full-trial k_hat/entropy
            x_tensor = torch.tensor(x_np, dtype=torch.float32).transpose(0, 1).unsqueeze(0).to(device)
            with torch.no_grad():
                _, _, _, aux = model(x_tensor, mask=None, tau=CONFIG.get("tau", 1.0))

            phase_p = aux["phase_p"].squeeze(0).detach().cpu().numpy()
            k_hat = float(aux["k_hat"].item())
            ent = compute_phase_entropy_mean(phase_p)

            trial_key = item["meta"]
            per_trial[trial_key] = {"ae": ae, "pred": pred, "gt": gt, "item": item}

            line = (f"{trial_key} | subj={subj} | GT={gt:.0f} | Pred(win)={pred:.2f} | Diff={diff:+.2f} | "
                    f"k_hat(full)={k_hat:.2f} | phase_entropy(full)={ent:.3f}")
            block_lines.append(line)

        mae = float(np.mean(abs_errs)) if len(abs_errs) > 0 else float("nan")

        # pick one trial for A1 plots
        keys = list(per_trial.keys())
        if len(keys) == 0:
            continue

        policy = CONFIG.get("pick_policy", "worst")
        if policy == "best":
            pick_key = sorted(keys, key=lambda k: per_trial[k]["ae"])[0]
        elif policy == "median":
            ks = sorted(keys, key=lambda k: per_trial[k]["ae"])
            pick_key = ks[len(ks)//2]
        else:  # "worst"
            pick_key = sorted(keys, key=lambda k: per_trial[k]["ae"], reverse=True)[0]

        pick_item = per_trial[pick_key]["item"]
        pick_subj = pick_item["subj"]

        # ------------------
        # ✅ REF axis set once (first pair + first participant if exists)
        # ------------------
        if REF_XLIM is None and REF_YLIM is None and REF_PAIR is not None:
            ref_train_act, ref_test_act = REF_PAIR
            if (train_act == ref_train_act) and (test_act == ref_test_act):
                ref_train_trials = filter_trials_by_activity(all_trials, ref_train_act)
                ref_test_trials  = filter_trials_by_activity(all_trials, ref_test_act)

                # try first participant
                ref_train_item = next((it for it in ref_train_trials if it["subj"] == REF_SUBJ), None)
                ref_test_item  = next((it for it in ref_test_trials  if it["subj"] == REF_SUBJ), None)

                # fallback
                if ref_train_item is None:
                    ref_train_item = ref_train_trials[0]
                if ref_test_item is None:
                    ref_test_item = ref_test_trials[0]

                xtr_ref = torch.tensor(ref_train_item["data"], dtype=torch.float32).transpose(0, 1).unsqueeze(0).to(device)
                xte_ref = torch.tensor(ref_test_item["data"],  dtype=torch.float32).transpose(0, 1).unsqueeze(0).to(device)
                with torch.no_grad():
                    ztr_ref = model.encoder(xtr_ref).squeeze(0).detach().cpu().numpy()
                    zte_ref = model.encoder(xte_ref).squeeze(0).detach().cpu().numpy()

                REF_XLIM, REF_YLIM = _compute_ref_xlim_ylim_from_train_test_latents(ztr_ref, zte_ref, pad_ratio=0.05)
                print(f"[REF AXIS] Using REF pair {ref_train_act}->{ref_test_act}, subj={REF_SUBJ} xlim={REF_XLIM}, ylim={REF_YLIM}")

        # ------------------
        # A1 visualization for picked trial
        # ------------------
        # need one train trial (same subject if possible) + picked test trial
        train_item = next((it for it in train_trials if it["subj"] == pick_subj), None)
        if train_item is None:
            train_item = train_trials[0]
        test_item = pick_item

        xtr = torch.tensor(train_item["data"], dtype=torch.float32).transpose(0, 1).unsqueeze(0).to(device)
        xte = torch.tensor(test_item["data"],  dtype=torch.float32).transpose(0, 1).unsqueeze(0).to(device)
        with torch.no_grad():
            ztr = model.encoder(xtr).squeeze(0).detach().cpu().numpy()
            zte = model.encoder(xte).squeeze(0).detach().cpu().numpy()

        ttr = np.arange(ztr.shape[0], dtype=np.float32) / float(fs)
        tte = np.arange(zte.shape[0], dtype=np.float32) / float(fs)

        safe_train = str(train_act).replace(" ", "_").replace("/", "_")
        safe_test  = str(test_act).replace(" ", "_").replace("/", "_")
        safe_subj  = str(pick_subj).replace(" ", "_").replace("/", "_")

        pca_path = os.path.join(
            CONFIG["plot_dir"],
            f"A1_PCA_train{safe_train}_test{safe_test}_{safe_subj}.png"
        )
        plot_pca2d_train_test_overlay(
            train_z_tD=ztr, train_t_sec=ttr,
            test_z_tD=zte,  test_t_sec=tte,
            save_path=pca_path,
            dpi=600,
            max_points=2000,
            show_colorbar=True,
            tick_pad=0,
            point_size=10,
            alpha_test=0.9,
            alpha_train=0.35,
            train_lw=2.0,
            fixed_xlim=REF_XLIM,
            fixed_ylim=REF_YLIM
        )

        x_np = test_item["data"]
        gt_count = float(test_item["count"])
        pred_count, t_centers, rates, total_dur = predict_count_and_rates_by_windowing(
            model,
            x_np=x_np,
            fs=fs,
            win_sec=CONFIG["win_sec"],
            stride_sec=CONFIG["stride_sec"],
            device=device,
            tau=CONFIG.get("tau", 1.0),
            batch_size=CONFIG.get("batch_size", 64)
        )
        gt_rate = gt_count / max(1e-6, total_dur)

        rate_path = os.path.join(
            CONFIG["plot_dir"],
            f"A1_Rate_train{safe_train}_test{safe_test}_{safe_subj}.png"
        )
        plot_predrate_vs_gtrate(
            t_sec=t_centers,
            pred_rate=rates,
            gt_rate=gt_rate,
            save_path=rate_path,
            dpi=600,
            tick_pad=0,
            smooth_sigma=1.0
        )

        final_blocks.append({
            "idx": exp_idx,
            "title": f"{train_act} -> {test_act}",
            "lines": block_lines,
            "mae": mae,
            "picked_subject_for_plot": pick_subj,
            "picked_trial_meta": pick_key,
            "plot_dir": CONFIG["plot_dir"]
        })

        print(f"[Done] {exp_idx:03d} {train_act} -> {test_act} | MAE={mae:.3f} | pick={pick_key}")

    print("\n" + "=" * 110)
    print("A1 Pairwise Activity Transfer Results (mm-Fit ALL combinations, windowing inference)")
    print("=" * 110)

    if len(final_blocks) == 0:
        print("[No results] meta/npz 로딩 또는 activity 필터링을 확인해줘.")
        print("=" * 110)
        return

    # summary print
    for b in final_blocks:
        print(f"\n{b['idx']}. {b['title']}  |  MAE={b['mae']:.3f}")
        print(f"   -> A1 plots saved for: subj={b['picked_subject_for_plot']} | trial={b['picked_trial_meta']}")
        print(f"   -> dir: {b['plot_dir']}")
        # The per-trial lines printed below can be commented out if the output gets too long
        for ln in b["lines"][:20]:
            print(ln)
        if len(b["lines"]) > 20:
            print(f"... ({len(b['lines'])-20} more lines)")

    maes = [b["mae"] for b in final_blocks if np.isfinite(b["mae"])]
    if len(maes) > 0:
        print("\n" + "-" * 110)
        print(f"Overall Avg MAE={float(np.mean(maes)):.3f} | Std={float(np.std(maes)):.3f} | #pairs={len(maes)}")
        print("-" * 110)

    print("=" * 110)


# Script entry point: run the pairwise activity-transfer experiment driver only
# when this file is executed directly, not when it is imported as a module.
if __name__ == "__main__":
    main()


[mm-Fit] Loaded trials=189 | activities=['dumbbell_rows', 'lunges', 'pushups'] | participants=['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
[REF AXIS] Using REF pair dumbbell_rows->lunges, subj=0 xlim=(-7.368919578552251, 7.368919578552251), ylim=(-11.295863071918486, 11.295863071918486)
[Done] 001 dumbbell_rows -> lunges | MAE=6.118 | pick=8_lunges__row174
[Done] 002 dumbbell_rows -> pushups | MAE=1.781 | pick=8_pushups__row173
[Done] 003 lunges -> dumbbell_rows | MAE=3.848 | pick=4_dumbbell_rows__row122
[Done] 004 lunges -> pushups | MAE=4.226 | pick=6_pushups__row155
[Done] 005 pushups -> dumbbell_rows | MAE=1.734 | pick=0_dumbbell_rows__row76
[Done] 006 pushups -> lunges | MAE=6.285 | pick=0_lunges__row92

A1 Pairwise Activity Transfer Results (mm-Fit ALL combinations, windowing inference)

1. dumbbell_rows -> lunges  |  MAE=6.118
   -> A1 plots saved for: subj=8 | trial=8_lunges__row174
   -> dir: /content/A1_mmfit_pairwise_A1viz
2_lunges__row2 | subj=2 | GT=10 | Pred(win)=16