# In [None]:
# ============================================================
# Count-only K-auto (Multi-event) + Windowing (LOSO)
# + Qualitative 3-Panel Figure:
#   (a) Raw ankle acc (3-axis, demean)
#   (b) rep_rate(t) + peak markers
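#       - generate ample candidates with find_peaks
#       - rank candidates by prominence (descending)
#       - greedy selection under a min_gap constraint
#       - select exactly topK = GT peaks (relax the gap if too few are found)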
#   (c) 3D trajectory: latent manifold (PCA2 of z(t)) × rep_rate(t)
#       - trajectory 중심 (line + light scatter)
# ============================================================

import os
import glob
import random
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import matplotlib.pyplot as plt

from scipy.signal import find_peaks, peak_prominences
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401


# ---------------------------------------------------------------------
# 1) Strict Seeding
# ---------------------------------------------------------------------
def set_strict_seed(seed: int):
    """Seed every RNG this script touches so runs are reproducible.

    Covers Python's ``random``, NumPy and PyTorch (CPU and, when present, all
    CUDA devices). Also exports ``PYTHONHASHSEED`` and puts cuDNN in
    deterministic, non-autotuning mode.
    """
    # Python-level RNG. PYTHONHASHSEED is exported for child processes; the
    # running interpreter's str-hash seed was already fixed at start-up.
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)

    # Library RNGs.
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        for seed_fn in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
            seed_fn(seed)

    # cuDNN: reproducible kernels instead of benchmark-selected ones.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


# ---------------------------------------------------------------------
# 2) Data Loading
# ---------------------------------------------------------------------
def load_mhealth_dataset(data_dir, target_activities_map, column_names):
    """Load per-subject, per-activity sensor frames from mHealth ``.log`` files.

    Args:
        data_dir: Directory searched for ``mHealth_subject*.log`` files
            (tab-separated, no header).
        target_activities_map: Mapping ``activity_id -> activity name``. Only
            these activities are kept.
        column_names: Names assigned to the leading columns of each file. The
            list must include ``"activity_id"``. Extra trailing columns are dropped.

    Returns:
        ``{subject_key: {activity_name: DataFrame}}``. Each DataFrame holds the
        rows of that activity and every column of ``column_names`` except
        ``activity_id``. The subject key is ``"subject<N>"`` built from the digits
        in the file stem, or the raw stem when it has no parsable digits.
        Files that fail to load are reported and skipped. An empty dict is
        returned when no log file is found.
    """
    full_dataset = {}
    file_list = sorted(glob.glob(os.path.join(data_dir, "mHealth_subject*.log")))

    if not file_list:
        print(f"[Warning] No mHealth logs found in {data_dir}")
        return {}

    print(f"Loading {len(file_list)} subjects from {data_dir}...")

    for file_path in file_list:
        file_name = os.path.basename(file_path)
        subj_part = file_name.split(".")[0]
        try:
            subj_id_num = int("".join(filter(str.isdigit, subj_part)))
            subj_key = f"subject{subj_id_num}"
        except ValueError:
            # No (int-parsable) digits in the stem: fall back to the stem itself.
            subj_key = subj_part

        # Deliberate best-effort boundary: one unreadable/malformed file should
        # not abort loading of the remaining subjects.
        try:
            df = pd.read_csv(file_path, sep="\t", header=None)
            df = df.iloc[:, :len(column_names)]
            df.columns = column_names

            subj_data = {}
            for label_code, activity_name in target_activities_map.items():
                activity_df = df[df["activity_id"] == label_code].copy()
                if not activity_df.empty:
                    subj_data[activity_name] = activity_df.drop(columns=["activity_id"])

            full_dataset[subj_key] = subj_data
        except Exception as e:
            print(f"Error loading {file_name}: {e}")

    return full_dataset


def prepare_trial_list(label_config, full_data, target_map, feature_map):
    """
    Build one trial record per labelled (subject, activity_id, count) triple.

    Each record is a dict with:
      - data: (T, C) float32, z-scored per channel using this trial's own
        statistics (model input)
      - raw_selected: the same (T, C) columns before normalization (for plotting)
      - count: ground-truth repetition count as a float
      - meta: "<subject>_<activity name>"
    Labels whose subject or activity is absent from ``full_data`` are reported
    and skipped.
    """
    trials = []
    for subj, act_id, gt_count in label_config:
        act_name = target_map.get(act_id)
        feats = feature_map.get(act_id)

        if subj not in full_data or act_name not in full_data[subj]:
            print(f"[Skip] Missing data for {subj} - {act_name}")
            continue

        raw = full_data[subj][act_name][feats].values.astype(np.float32)
        mu = raw.mean(axis=0, keepdims=True)
        sigma = raw.std(axis=0, keepdims=True) + 1e-6  # keeps flat channels finite

        trials.append(dict(
            data=((raw - mu) / sigma).astype(np.float32),
            raw_selected=raw.astype(np.float32),
            count=float(gt_count),
            meta=f"{subj}_{act_name}",
        ))

    return trials


# ---------------------------------------------------------------------
# 3) Window-level supervision (TRAIN) + windowing inference (TEST)
# ---------------------------------------------------------------------
def trial_list_to_windows(trial_list, fs, win_sec=8.0, stride_sec=4.0, drop_last=True):
    """Cut trials into fixed-length windows with proportionally scaled count labels.

    Each trial's count becomes a constant rate (count / duration). Every window
    is labelled ``rate * window_duration``, a proxy that assumes repetitions are
    spread uniformly over the trial.

    Args:
        trial_list: Items with ``"data"`` ((T, C) array), ``"count"`` and ``"meta"``.
        fs: Sampling rate in Hz.
        win_sec: Window length in seconds.
        stride_sec: Hop between window starts in seconds.
        drop_last: When False, append one extra, shorter window that starts one
            stride after the last full window, if that start lies inside the trial.

    Returns:
        List of dicts with ``data``, ``count``, ``meta``, ``parent_meta``,
        ``parent_T``, ``win_start`` and ``win_end``. A trial shorter than one
        window yields a single window spanning the whole trial.

    Raises:
        ValueError: If the window length or stride rounds to fewer than one sample.
    """
    win_len = int(round(win_sec * fs))
    stride = int(round(stride_sec * fs))
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if win_len <= 0 or stride <= 0:
        raise ValueError(
            f"win_sec/stride_sec must span at least one sample at fs={fs} "
            f"(got win_len={win_len}, stride={stride})"
        )

    windows = []
    for item in trial_list:
        x = item["data"]
        T = x.shape[0]
        meta = item["meta"]
        # Uniform-rate assumption: a window's label is its time share of the count.
        rate_trial = float(item["count"]) / max(T / float(fs), 1e-6)

        if T < win_len:
            # Too short for one full window: keep the whole trial as one window.
            spans = [(0, T)]
        else:
            spans = [(st, st + win_len) for st in range(0, T - win_len + 1, stride)]
            if not drop_last:
                tail_st = spans[-1][0] + stride
                if tail_st < T:
                    spans.append((tail_st, T))

        for st, ed in spans:
            windows.append({
                "data": x[st:ed],
                "count": rate_trial * ((ed - st) / float(fs)),
                "meta": f"{meta}__win[{st}:{ed}]",
                "parent_meta": meta,
                "parent_T": T,
                "win_start": st,
                "win_end": ed,
            })

    return windows


def predict_count_by_windowing(model, x_np, fs, win_sec, stride_sec, device, tau=1.0, batch_size=64):
    """Estimate a trial's repetition count from the mean of per-window rates.

    The trial is cut into windows of ``win_sec`` with hop ``stride_sec``. The
    model's average rate is computed per window, averaged over windows and
    multiplied by the full trial duration. A trial no longer than one window is
    scored as a single window covering the whole trial. Samples after the last
    full window are not scored but still count toward the duration.

    Args:
        model: Called as ``model(x, mask=None, tau=tau)`` with ``x`` of shape
            (B, C, T). Its first return value is taken as the per-window rate.
        x_np: (T, C) input array.
        fs: Sampling rate in Hz.
        win_sec: Window length in seconds.
        stride_sec: Hop between window starts in seconds.
        device: Device the window batches are moved to.
        tau: Forwarded to the model.
        batch_size: Number of windows per forward pass.

    Returns:
        ``(pred_count, rates)``, where ``rates`` is the array of per-window rates.
    """
    win_len = int(round(win_sec * fs))
    stride = int(round(stride_sec * fs))
    T = x_np.shape[0]
    total_dur = T / float(fs)

    if T <= win_len:
        windows = np.asarray(x_np)[None]  # one "window": the whole trial
    else:
        starts = range(0, T - win_len + 1, stride)
        windows = np.stack([x_np[st:st + win_len] for st in starts], axis=0)

    # (N, T, C) -> (N, C, T). Kept on CPU and moved to `device` per batch.
    xw = torch.tensor(windows, dtype=torch.float32).permute(0, 2, 1)

    # Always evaluate in inference mode. The short-trial path previously skipped this.
    model.eval()
    rates = []
    with torch.no_grad():
        for i in range(0, xw.shape[0], batch_size):
            r_hat, _, _, _ = model(xw[i:i + batch_size].to(device), mask=None, tau=tau)
            rates.append(r_hat.detach().cpu().numpy())

    rates = np.concatenate(rates, axis=0)
    pred_count = float(rates.mean()) * total_dur
    return float(pred_count), rates


# ---------------------------------------------------------------------
# 4) Dataset / Collate
# ---------------------------------------------------------------------
class TrialDataset(Dataset):
    """Map-style dataset over trial/window dicts.

    Item ``i`` is ``(data, count, meta)``:
      * ``data``: the record's ``"data"`` array as a channel-first (C, T)
        float32 tensor;
      * ``count``: a 0-d float32 tensor with the record's ``"count"``;
      * ``meta``: the record's identifier string, returned unchanged.
    """

    def __init__(self, trial_list):
        self.trials = trial_list

    def __len__(self):
        return len(self.trials)

    def __getitem__(self, idx):
        rec = self.trials[idx]
        signal = torch.tensor(rec["data"], dtype=torch.float32).permute(1, 0)  # (T,C) -> (C,T)
        target = torch.tensor(rec["count"], dtype=torch.float32)
        return signal, target, rec["meta"]


def collate_variable_length(batch):
    """Right-pad a batch of (C, T_i) signals to the longest length in the batch.

    Args:
        batch: Sequence of ``(data, count, meta)`` tuples, as yielded by
            ``TrialDataset``. All ``data`` tensors share the channel count C.

    Returns:
        Dict with:
          * ``data``: (B, C, T_max), zero-padded on the right;
          * ``mask``: (B, T_max), 1 on real samples and 0 on padding;
          * ``count``: (B,);
          * ``length``: (B,) float32 original lengths;
          * ``meta``: list of identifiers.
    """
    t_max = max(sample[0].shape[1] for sample in batch)
    n_ch = batch[0][0].shape[0]

    datas, masks = [], []
    for sig, _, _ in batch:
        n_valid = sig.shape[1]
        n_pad = t_max - n_valid
        if n_pad <= 0:
            datas.append(sig)
            masks.append(torch.ones(n_valid))
            continue
        datas.append(torch.cat([sig, torch.zeros(n_ch, n_pad)], dim=1))
        masks.append(torch.cat([torch.ones(n_valid), torch.zeros(n_pad)], dim=0))

    return {
        "data": torch.stack(datas),
        "mask": torch.stack(masks),
        "count": torch.stack([sample[1] for sample in batch]),
        "length": torch.tensor([sample[0].shape[1] for sample in batch], dtype=torch.float32),
        "meta": [sample[2] for sample in batch],
    }


# ---------------------------------------------------------------------
# 5) Model
# ---------------------------------------------------------------------
class ManifoldEncoder(nn.Module):
    """Map a (B, C, T) signal to a time-major latent sequence (B, T, latent_dim).

    Two 5-tap, same-padded Conv1d+ReLU layers, then a 1x1 projection to
    ``latent_dim`` channels. The time length T is preserved.
    """

    def __init__(self, input_ch, hidden_dim=128, latent_dim=16):
        super().__init__()
        layers = [
            nn.Conv1d(input_ch, hidden_dim, 5, padding=2),
            nn.ReLU(),
            nn.Conv1d(hidden_dim, hidden_dim, 5, padding=2),
            nn.ReLU(),
            nn.Conv1d(hidden_dim, latent_dim, 1),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        # (B, D, T) -> (B, T, D) so per-step heads can apply nn.Linear directly.
        return self.net(x).permute(0, 2, 1)


class ManifoldDecoder(nn.Module):
    """Reconstruct a channel-first (B, out_ch, T) signal from a (B, T, latent_dim) latent.

    Mirrors ``ManifoldEncoder``: two 5-tap, same-padded Conv1d+ReLU layers and
    a 1x1 output projection. T is preserved.
    """

    def __init__(self, latent_dim, hidden_dim, out_ch):
        super().__init__()
        layers = [
            nn.Conv1d(latent_dim, hidden_dim, 5, padding=2),
            nn.ReLU(),
            nn.Conv1d(hidden_dim, hidden_dim, 5, padding=2),
            nn.ReLU(),
            nn.Conv1d(hidden_dim, out_ch, 1),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, z):
        # Time-major latent (B, T, D) -> channel-first (B, D, T) for Conv1d.
        return self.net(z.permute(0, 2, 1))


class MultiRateHead(nn.Module):
    """Per-step head producing an amplitude and a distribution over K_max phases.

    An MLP maps each latent step to ``1 + K_max`` logits. Logit 0 passes through
    softplus to give a non-negative amplitude. The remaining ``K_max`` logits are
    softmax-normalized at temperature ``tau``.
    """

    def __init__(self, latent_dim=16, hidden=64, K_max=6):
        super().__init__()
        self.K_max = K_max
        n_out = 1 + K_max  # 1 amplitude logit + K_max phase logits
        self.net = nn.Sequential(nn.Linear(latent_dim, hidden), nn.ReLU(), nn.Linear(hidden, n_out))

    def forward(self, z, tau=1.0):
        logits = self.net(z)
        amp_logit, phase_logits = logits[..., 0], logits[..., 1:]
        amp = F.softplus(amp_logit)                  # >= 0
        phase = F.softmax(phase_logits / tau, dim=-1)  # sums to 1 over phases
        return amp, phase, phase_logits


class KAutoCountModel(nn.Module):
    """Repetition-rate model with an automatically estimated number of phases K.

    Per forward pass:
      * ``encoder`` maps the (B, C, T) input to a latent sequence z (B, T, D);
      * ``decoder`` reconstructs the input from z (used by a reconstruction loss);
      * ``rate_head`` predicts, per step, a non-negative amplitude and a
        softmax distribution over ``K_max`` phases.

    The effective number of phases per sequence is the inverse participation
    ratio of the time-averaged phase distribution, ``k_hat = 1 / sum_k p_bar_k**2``,
    which lies roughly in [1, K_max]. The per-step repetition rate is
    ``amp_t / k_hat``.
    """

    def __init__(self, input_ch, hidden_dim=128, latent_dim=16, K_max=6):
        super().__init__()
        self.encoder = ManifoldEncoder(input_ch, hidden_dim, latent_dim)
        self.decoder = ManifoldDecoder(latent_dim, hidden_dim, input_ch)
        self.rate_head = MultiRateHead(latent_dim, hidden=hidden_dim, K_max=K_max)
        self._init_weights()

    def _init_weights(self):
        """Kaiming-normal init for every Conv1d/Linear weight with zero biases.

        The rate head's amplitude logit (output 0) then gets a -2 bias, which
        shifts the initial softplus amplitude downward.
        """
        for m in self.modules():
            if isinstance(m, (nn.Conv1d, nn.Linear)):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

        with torch.no_grad():
            b = self.rate_head.net[-1].bias
            b.zero_()
            b[0].fill_(-2.0)

    @staticmethod
    def _masked_mean_time(x, mask=None, eps=1e-6):
        """Mean over the time axis (dim 1) of a (B, T) or (B, T, K) tensor.

        When ``mask`` (B, T) is given, only steps where it is non-zero contribute.
        """
        if mask is None:
            return x.mean(dim=1)
        if x.dim() == 2:
            m = mask.to(dtype=x.dtype, device=x.device)
            return (x * m).sum(dim=1) / (m.sum(dim=1) + eps)
        elif x.dim() == 3:
            m = mask.to(dtype=x.dtype, device=x.device).unsqueeze(-1)
            return (x * m).sum(dim=1) / (m.sum(dim=1) + eps)
        else:
            raise ValueError(f"Unsupported dim for masked mean: {x.dim()}")

    def forward(self, x, mask=None, tau=1.0):
        """Run the model on a (B, C, T) batch.

        Args:
            x: (B, C, T) input.
            mask: Optional (B, T) validity mask (1 = real sample, 0 = padding).
            tau: Softmax temperature for the phase distribution.

        Returns:
            avg_rep_rate: (B,) mean per-step repetition rate over valid steps.
            z: (B, T, D) latent sequence.
            x_hat: (B, C, T) reconstruction.
            aux: Dict with the following entries:
                * ``rates_k_t``: (B, T, K) per-phase rate decomposition;
                * ``phase_p``: (B, T, K) phase probabilities;
                * ``phase_logits``: (B, T, K) raw phase logits;
                * ``micro_rate_t``: (B, T) amplitude;
                * ``rep_rate_t``: (B, T), zeroed on padded steps;
                * ``k_hat``: (B,) effective number of phases.
        """
        z = self.encoder(x)
        x_hat = self.decoder(z)

        amp_t, phase_p, phase_logits = self.rate_head(z, tau=tau)
        rates_k_t = amp_t.unsqueeze(-1) * phase_p

        micro_rate_t = amp_t
        p_bar = self._masked_mean_time(phase_p, mask)
        k_hat = 1.0 / (p_bar.pow(2).sum(dim=1) + 1e-6)

        rep_rate_t = micro_rate_t / (k_hat.unsqueeze(1) + 1e-6)
        if mask is not None:
            # Zero the padded steps so aux["rep_rate_t"] and the losses ignore them.
            rep_rate_t = rep_rate_t * mask.to(dtype=rep_rate_t.dtype, device=rep_rate_t.device)

        # Same masked time-average as p_bar (was a duplicated, double-masked inline copy).
        avg_rep_rate = self._masked_mean_time(rep_rate_t, mask)

        aux = {
            "rates_k_t": rates_k_t,
            "phase_p": phase_p,
            "phase_logits": phase_logits,
            "micro_rate_t": micro_rate_t,
            "rep_rate_t": rep_rate_t,
            "k_hat": k_hat,
        }
        return avg_rep_rate, z, x_hat, aux


# ---------------------------------------------------------------------
# 6) Loss utils
# ---------------------------------------------------------------------
def masked_recon_mse(x_hat, x, mask, eps=1e-6):
    """Mean squared reconstruction error over valid (unpadded) time steps.

    Args:
        x_hat: (B, C, T) reconstruction.
        x: (B, C, T) target.
        mask: (B, T) weights, 1 for valid steps and 0 for padding.
        eps: Guards the division when the mask is all zeros.

    Returns:
        Scalar tensor: masked squared-error sum / (number of valid steps * C).
    """
    weights = mask.to(dtype=x.dtype, device=x.device)
    sq_err = (x_hat - x).pow(2) * weights[:, None, :]
    n_valid = weights.sum() * x.shape[1] + eps
    return sq_err.sum() / n_valid


def temporal_smoothness(v, mask=None, eps=1e-6):
    """Mean absolute first difference along time (dim 1) of a (B, T) tensor.

    With ``mask`` (B, T), only differences between two consecutive valid steps
    are counted.
    """
    diffs = (v[:, 1:] - v[:, :-1]).abs()
    if mask is not None:
        pair_ok = (mask[:, 1:] * mask[:, :-1]).to(dtype=diffs.dtype, device=diffs.device)
        return (diffs * pair_ok).sum() / (pair_ok.sum() + eps)
    return diffs.mean()


def phase_entropy_loss(phase_p, mask=None, eps=1e-8):
    """Mean per-step Shannon entropy (in nats) of the phase distribution.

    Args:
        phase_p: (B, T, K) probabilities over the last axis.
        mask: Optional (B, T) validity weights. Padded steps are excluded.
        eps: Added inside the log (avoids log(0)) and to the normalizer.

    Returns:
        Scalar tensor.
    """
    ent = -(phase_p * (phase_p + eps).log()).sum(dim=-1)
    if mask is None:
        return ent.mean()
    # Cast like the sibling masked losses so bool/int masks, or masks living on
    # another device, are handled.
    m = mask.to(dtype=ent.dtype, device=ent.device)
    return (ent * m).sum() / (m.sum() + eps)


def effK_usage_loss(phase_p, mask=None, eps=1e-6):
    """Effective number of phases in use, returned as a minimizable loss.

    The time-averaged (optionally masked) phase distribution ``p_bar`` (B, K) is
    summarized by its inverse participation ratio, ``1 / sum_k p_bar_k**2``.

    Returns:
        ``(batch-mean effK as a differentiable scalar, per-sample effK detached)``.
    """
    if mask is not None:
        w = mask.to(dtype=phase_p.dtype, device=phase_p.device)[..., None]
        p_bar = (phase_p * w).sum(dim=1) / (w.sum(dim=1) + eps)
    else:
        p_bar = phase_p.mean(dim=1)

    eff_k = 1.0 / (p_bar.pow(2).sum(dim=1) + eps)
    return eff_k.mean(), eff_k.detach()


# ---------------------------------------------------------------------
# 7) Train
# ---------------------------------------------------------------------
def train_one_epoch(model, loader, optimizer, config, device):
    """Run one optimization pass over ``loader`` and return batch-averaged stats.

    Loss = MSE(predicted average rate, target window rate), plus weighted terms
    for reconstruction, rep-rate smoothness, phase entropy and effective-K usage.
    The weights are read from ``config["lambda_*"]`` and fall back to the
    defaults below. ``mae_count`` compares ``rate * duration`` with the target count.

    Returns:
        Dict of per-batch means for: loss, loss_rate, loss_recon, loss_smooth,
        loss_phase_ent, loss_effk, mae_count.
    """
    model.train()
    totals = dict.fromkeys(
        ("loss", "loss_rate", "loss_recon", "loss_smooth",
         "loss_phase_ent", "loss_effk", "mae_count"),
        0.0,
    )

    fs = config["fs"]
    tau = config.get("tau", 1.0)
    w_recon = config.get("lambda_recon", 1.0)
    w_smooth = config.get("lambda_smooth", 0.05)
    w_ent = config.get("lambda_phase_ent", 0.01)
    w_effk = config.get("lambda_effk", 0.005)

    for batch in loader:
        x, mask, y_count, length = (
            batch[key].to(device) for key in ("data", "mask", "count", "length")
        )

        # Window duration in seconds; the regression target is a rate.
        duration = (length / fs).clamp(min=1e-6)
        y_rate = y_count / duration

        optimizer.zero_grad()
        rate_hat, _, x_hat, aux = model(x, mask, tau=tau)

        terms = {
            "loss_rate": F.mse_loss(rate_hat, y_rate),
            "loss_recon": masked_recon_mse(x_hat, x, mask),
            "loss_smooth": temporal_smoothness(aux["rep_rate_t"], mask),
            "loss_phase_ent": phase_entropy_loss(aux["phase_p"], mask),
            "loss_effk": effK_usage_loss(aux["phase_p"], mask)[0],
        }
        loss = (terms["loss_rate"]
                + w_recon * terms["loss_recon"]
                + w_smooth * terms["loss_smooth"]
                + w_ent * terms["loss_phase_ent"]
                + w_effk * terms["loss_effk"])

        loss.backward()
        optimizer.step()

        totals["loss"] += loss.item()
        for name, value in terms.items():
            totals[name] += value.item()
        totals["mae_count"] += (rate_hat * duration - y_count).abs().mean().item()

    n_batches = max(1, len(loader))
    return {name: total / n_batches for name, total in totals.items()}


# ---------------------------------------------------------------------
# 8) Visualization helpers
# ---------------------------------------------------------------------
def _moving_average(x, w):
    if w <= 1:
        return np.asarray(x, dtype=np.float32)
    x = np.asarray(x, dtype=np.float32)
    pad = w // 2
    xp = np.pad(x, (pad, pad), mode="edge")
    k = np.ones(w, dtype=np.float32) / float(w)
    y = np.convolve(xp, k, mode="valid")
    return y.astype(np.float32)


def _pca_project_np(Z, n_components=2):
    """
    Z: (T, D)
    return:
      P: (T, n_components)
    """
    Z = np.asarray(Z, dtype=np.float32)
    mu = Z.mean(axis=0, keepdims=True)
    Zc = Z - mu
    _, _, Vt = np.linalg.svd(Zc, full_matrices=False)
    comps = Vt[:n_components].T  # (D, n_components)
    P = Zc @ comps
    return P.astype(np.float32)


def _select_topk_by_prominence_greedy(peaks, prominences, topk, min_gap):
    if peaks is None or len(peaks) == 0:
        return np.array([], dtype=np.int64)

    peaks = np.asarray(peaks, dtype=np.int64)
    prominences = np.asarray(prominences, dtype=np.float32)

    order = np.argsort(-prominences)  # desc
    picked = []
    for idx in order:
        p = int(peaks[idx])
        if all(abs(p - q) >= min_gap for q in picked):
            picked.append(p)
        if len(picked) >= topk:
            break

    return np.array(sorted(picked), dtype=np.int64)


def pick_peaks_exact_topk(rr_s, fs, topk, min_gap_sec,
                          base_distance_sec=0.15,
                          base_prom_q=0.40):
    """
    Select ``topk`` peaks of a 1-D signal spaced by roughly ``min_gap_sec``.

    1) Generate ample candidates with find_peaks (minimum distance ``base_distance_sec``).
    2) Compute candidate prominences.
    3) Pick greedily by descending prominence subject to ``min_gap``.
    4) If too few are found, progressively relax ``min_gap``.
    5) If still short, ignore the gap and fill with the most prominent remaining candidates.

    Exactly ``topk`` peaks are returned whenever at least ``topk`` candidates
    exist. Otherwise every candidate is returned. ``topk <= 0`` returns an empty array.

    Returns:
        Sorted int64 array of sample indices into ``rr_s``.
    """
    if topk <= 0:
        return np.array([], dtype=np.int64)

    rr_s = np.asarray(rr_s, dtype=np.float32).reshape(-1)
    base_dist = max(1, int(round(base_distance_sec * fs)))

    cand, _ = find_peaks(rr_s, distance=base_dist)
    if len(cand) == 0:
        return np.array([], dtype=np.int64)

    prom = peak_prominences(rr_s, cand)[0].astype(np.float32)

    # (optional) Drop candidates with very small prominence, but only when at
    # least topk candidates would remain.
    if len(prom) >= 10:
        thr = float(np.quantile(prom, base_prom_q))
        keep = prom >= thr
        cand2, prom2 = cand[keep], prom[keep]
        if len(cand2) >= topk:
            cand, prom = cand2, prom2

    min_gap = max(1, int(round(min_gap_sec * fs)))
    picked = _select_topk_by_prominence_greedy(cand, prom, topk=topk, min_gap=min_gap)

    if len(picked) < topk:
        # Relax the gap step by step.
        for relax in [0.85, 0.70, 0.55, 0.40, 0.25, 0.10]:
            mg = max(1, int(round(min_gap * relax)))
            picked = _select_topk_by_prominence_greedy(cand, prom, topk=topk, min_gap=mg)
            if len(picked) >= topk:
                break

    if len(picked) < topk:
        # Final fallback: ignore the gap and fill from the most prominent candidates.
        order = np.argsort(-prom)
        picked_set = set(picked.tolist())
        for idx in order:
            picked_set.add(int(cand[idx]))
            if len(picked_set) >= topk:
                break
        picked = np.array(sorted(picked_set)[:topk], dtype=np.int64)

    return picked


def plot_raw_rep_and_3d_traj(
    raw_ankle_acc_xyz,  # (T,3) raw
    rep_rate_t,         # (T,)
    Z_t,                # (T,D)
    fs,
    title,
    out_png,
    gt_count,
    pred_count,
    smooth_w=11,
    min_gap_sec=2.25,
    max_points_3d=1200,
    # bigger 3D (c)
    fig_w=15.5,
    fig_h=12.0,
    h_ratios=(1.0, 1.0, 3.2),
    cb_shrink=0.95,
    cb_pad=0.03,
):
    """Save a 3-panel qualitative figure for one trial and return the marked peaks.

    Panels:
      (a) demeaned raw ankle accelerometer (3 axes) vs. time;
      (b) smoothed rep_rate(t) with ``round(gt_count)`` peak markers, or fewer
          if the curve has fewer candidate peaks;
      (c) 3-D trajectory of PCA-2D(z(t)) against smoothed rep_rate(t), colored
          by time and downsampled to at most ``max_points_3d`` points.

    The time axis is built from ``len(rep_rate_t)`` samples at ``fs`` Hz, so
    ``raw_ankle_acc_xyz`` and ``Z_t`` are expected to have that same length.
    The parent directory of ``out_png`` is created if needed.

    Returns:
        Sorted sample indices of the marked peaks.
    """
    raw = np.asarray(raw_ankle_acc_xyz, dtype=np.float32)
    rr = np.asarray(rep_rate_t, dtype=np.float32).reshape(-1)
    Z = np.asarray(Z_t, dtype=np.float32)

    T = rr.shape[0]
    t = np.arange(T, dtype=np.float32) / float(fs)

    # (a) demean raw
    raw_demean = raw - raw.mean(axis=0, keepdims=True)

    # (b) smooth rep_rate
    rr_s = _moving_average(rr, smooth_w)

    # (b) peaks: exact topK = GT
    topk = int(round(float(gt_count)))
    peaks = pick_peaks_exact_topk(
        rr_s=rr_s,
        fs=fs,
        topk=topk,
        min_gap_sec=min_gap_sec,
        base_distance_sec=0.15,
        base_prom_q=0.40,
    )

    # (c) 3D: PCA2(z(t)) + rep_rate(t)
    P2 = _pca_project_np(Z, n_components=2)

    # Downsample for a trajectory-centric view.
    stride = max(1, int(np.ceil(T / float(max_points_3d))))
    idx_ds = np.arange(0, T, stride, dtype=np.int64)

    P2_ds = P2[idx_ds]
    rr_ds = rr_s[idx_ds]
    t_ds = t[idx_ds]

    P2_pk = P2[peaks] if len(peaks) > 0 else np.zeros((0, 2), dtype=np.float32)
    rr_pk = rr_s[peaks] if len(peaks) > 0 else np.zeros((0,), dtype=np.float32)

    # ---- figure layout ----
    fig = plt.figure(figsize=(fig_w, fig_h))
    gs = fig.add_gridspec(3, 1, height_ratios=list(h_ratios), hspace=0.28)

    ax1 = fig.add_subplot(gs[0, 0])
    ax2 = fig.add_subplot(gs[1, 0], sharex=ax1)
    ax3 = fig.add_subplot(gs[2, 0], projection="3d")

    # (a)
    ax1.plot(t, raw_demean[:, 0], lw=1.0, label="acc_ankle_x (demean)")
    ax1.plot(t, raw_demean[:, 1], lw=1.0, label="acc_ankle_y (demean)")
    ax1.plot(t, raw_demean[:, 2], lw=1.0, label="acc_ankle_z (demean)")
    ax1.set_title("(a) Raw ankle accelerometer (demean, 3-axis)")
    ax1.set_ylabel("Acceleration (demeaned)")
    ax1.grid(True, alpha=0.25)
    ax1.legend(loc="upper right")

    # (b)
    ax2.plot(t, rr_s, lw=1.4, label="rep_rate(t) (smoothed)")
    if len(peaks) > 0:
        ax2.scatter(t[peaks], rr_s[peaks],
                    marker="*", s=220, c="red",
                    edgecolors="k", linewidths=0.8,
                    label=f"topK peaks (K=GT={topk})")
    ax2.set_title(f"(b) rep_rate(t) with topK=GT peak markers (min_gap≈{min_gap_sec:.2f}s)")
    ax2.set_xlabel("Time (s)")
    ax2.set_ylabel("Reps / second")
    ax2.grid(True, alpha=0.25)
    ax2.legend(loc="upper right")
    ax2.text(
        0.01, 0.96,
        f"GT={gt_count:.0f} | Pred={pred_count:.2f} | K(peaks)={len(peaks)}",
        transform=ax2.transAxes,
        va="top", ha="left",
        bbox=dict(boxstyle="round", alpha=0.12)
    )

    # (c) trajectory-centered: line + light scatter colored by time
    ax3.plot(P2_ds[:, 0], P2_ds[:, 1], rr_ds, lw=1.1, alpha=0.85)
    sc = ax3.scatter(P2_ds[:, 0], P2_ds[:, 1], rr_ds, c=t_ds, s=10, alpha=0.28)

    if len(peaks) > 0:
        ax3.scatter(P2_pk[:, 0], P2_pk[:, 1], rr_pk,
                    marker="*", s=180, c="red",
                    edgecolors="k", linewidths=0.8)

    ax3.set_title("(c) 3D coupling: latent trajectory (PCA-2D) × rep_rate(t)")
    ax3.set_xlabel("PC1")
    ax3.set_ylabel("PC2")
    ax3.set_zlabel("rep_rate(t)")
    ax3.view_init(elev=22, azim=-62)
    ax3.grid(True, alpha=0.25)

    cb = fig.colorbar(sc, ax=ax3, pad=cb_pad, shrink=cb_shrink)
    cb.set_label("Time (s)")

    fig.suptitle(title, fontsize=14)
    fig.tight_layout(rect=[0, 0, 1, 0.96])

    out_dir = os.path.dirname(out_png)
    if out_dir:  # a bare file name has no directory to create (makedirs("") raises)
        os.makedirs(out_dir, exist_ok=True)
    fig.savefig(out_png, dpi=300)
    plt.close(fig)  # close this figure explicitly, not whichever is "current"

    print(f"[Saved] {out_png}")
    print(f"[Peaks] topK={topk} selected (K={len(peaks)}) | min_gap_sec={min_gap_sec}")
    return peaks


# ---------------------------------------------------------------------
# 9) Main (LOSO)
# ---------------------------------------------------------------------
def main():
    """Run leave-one-subject-out (LOSO) training and evaluation, plus optional visualization.

    For each held-out subject:
      * train a fresh ``KAutoCountModel`` on fixed-length windows from the other
        subjects (window labels are proportional count proxies);
      * predict each held-out trial's count by windowed inference;
      * print the fold's MAE.

    For ``VIZ_TARGET_SUBJ`` a 3-panel qualitative figure of its last test trial
    is saved. Finally the mean and standard deviation of the per-fold MAE are
    printed.
    """
    CONFIG = {
        "seed": 42,
        "data_dir": "/content/drive/MyDrive/Colab Notebooks/HAR_data/MHEALTHDATASET",

        "COLUMN_NAMES": [
            "acc_chest_x", "acc_chest_y", "acc_chest_z",
            "ecg_1", "ecg_2",
            "acc_ankle_x", "acc_ankle_y", "acc_ankle_z",
            "gyro_ankle_x", "gyro_ankle_y", "gyro_ankle_z",
            "mag_ankle_x", "mag_ankle_y", "mag_ankle_z",
            "acc_arm_x", "acc_arm_y", "acc_arm_z",
            "gyro_arm_x", "gyro_arm_y", "gyro_arm_z",
            "mag_arm_x", "mag_arm_y", "mag_arm_z",
            "activity_id"
        ],

        "TARGET_ACTIVITIES_MAP": {6: "Waist bends forward"},

        "ACT_FEATURE_MAP": {
            6: [
                "acc_chest_x", "acc_chest_y", "acc_chest_z",
                "acc_ankle_x", "acc_ankle_y", "acc_ankle_z",
                "gyro_ankle_x", "gyro_ankle_y", "gyro_ankle_z",
                "acc_arm_x", "acc_arm_y", "acc_arm_z",
                "gyro_arm_x", "gyro_arm_y", "gyro_arm_z"
            ],
        },

        "ALL_LABELS": [
            ("subject1", 6, 21), ("subject2", 6, 19), ("subject3", 6, 21),
            ("subject4", 6, 20), ("subject5", 6, 20), ("subject6", 6, 20),
            ("subject7", 6, 20), ("subject8", 6, 21), ("subject9", 6, 21),
            ("subject10", 6, 20),
        ],

        "epochs": 100,
        "lr": 5e-4,
        "batch_size": 64,
        "fs": 50,

        "win_sec": 8.0,
        "stride_sec": 4.0,
        "drop_last": True,

        "hidden_dim": 128,
        "latent_dim": 16,
        "K_max": 6,

        "tau": 1.0,

        "lambda_recon": 1.0,
        "lambda_smooth": 0.05,
        "lambda_phase_ent": 0.01,
        "lambda_effk": 0.0075,

        # Visualization
        "VIZ_ENABLE": True,
        "VIZ_TARGET_SUBJ": "subject1",
        "VIZ_ACT_ID": 6,  # activity whose feature list locates the ankle accelerometer
        "VIZ_OUT_DIR": "./viz_raw_rep_3d",
        "VIZ_SMOOTH_W": 11,
        "VIZ_MIN_GAP_SEC": 2.25,
        "VIZ_MAX_POINTS_3D": 1200,
    }

    set_strict_seed(CONFIG["seed"])
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")

    full_data = load_mhealth_dataset(CONFIG["data_dir"], CONFIG["TARGET_ACTIVITIES_MAP"], CONFIG["COLUMN_NAMES"])
    if not full_data:
        return

    # One fold per labelled subject, in order of first appearance in ALL_LABELS.
    subjects = list(dict.fromkeys(label[0] for label in CONFIG["ALL_LABELS"]))
    tau = CONFIG.get("tau", 1.0)

    # Positions of the ankle accelerometer inside the model's feature vector (plots only).
    viz_feats = CONFIG["ACT_FEATURE_MAP"][CONFIG["VIZ_ACT_ID"]]
    ankle_idx = [viz_feats.index(col) for col in ("acc_ankle_x", "acc_ankle_y", "acc_ankle_z")]

    loso_results = []

    print("\n" + "=" * 70)
    print("LOSO subject shift (TRAIN: window-proxy / TEST: windowing inference)")
    print("=" * 70)

    for fold_idx, test_subj in enumerate(subjects):
        set_strict_seed(CONFIG["seed"])

        train_labels = [x for x in CONFIG["ALL_LABELS"] if x[0] != test_subj]
        test_labels = [x for x in CONFIG["ALL_LABELS"] if x[0] == test_subj]

        train_trials = prepare_trial_list(train_labels, full_data, CONFIG["TARGET_ACTIVITIES_MAP"], CONFIG["ACT_FEATURE_MAP"])
        test_trials = prepare_trial_list(test_labels, full_data, CONFIG["TARGET_ACTIVITIES_MAP"], CONFIG["ACT_FEATURE_MAP"])

        if not test_trials:
            print(f"[Skip] Fold {fold_idx + 1}: {test_subj} has no data.")
            continue

        train_data = trial_list_to_windows(
            train_trials,
            fs=CONFIG["fs"],
            win_sec=CONFIG["win_sec"],
            stride_sec=CONFIG["stride_sec"],
            drop_last=CONFIG["drop_last"]
        )
        if not train_data:
            # Without training windows there is nothing to fit (train_data[0] would fail).
            print(f"[Skip] Fold {fold_idx + 1}: no training windows for {test_subj}.")
            continue

        g = torch.Generator()
        g.manual_seed(CONFIG["seed"])

        train_loader = DataLoader(
            TrialDataset(train_data),
            batch_size=CONFIG["batch_size"],
            shuffle=True,
            collate_fn=collate_variable_length,
            generator=g,
            num_workers=0
        )

        input_ch = train_data[0]["data"].shape[1]
        model = KAutoCountModel(
            input_ch=input_ch,
            hidden_dim=CONFIG["hidden_dim"],
            latent_dim=CONFIG["latent_dim"],
            K_max=CONFIG["K_max"]
        ).to(device)

        optimizer = torch.optim.Adam(model.parameters(), lr=CONFIG["lr"])
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.5)

        for _epoch in range(CONFIG["epochs"]):
            _ = train_one_epoch(model, train_loader, optimizer, CONFIG, device)
            scheduler.step()

        model.eval()

        # The full-length forward pass for the figure is only needed for the
        # visualized subject.
        want_viz = CONFIG["VIZ_ENABLE"] and test_subj == CONFIG["VIZ_TARGET_SUBJ"]

        fold_mae = 0.0
        fold_res_str = ""
        viz_payload = None

        for item in test_trials:
            x_np = item["data"]

            count_pred_win, _ = predict_count_by_windowing(
                model,
                x_np=x_np,
                fs=CONFIG["fs"],
                win_sec=CONFIG["win_sec"],
                stride_sec=CONFIG["stride_sec"],
                device=device,
                tau=tau,
                batch_size=CONFIG.get("batch_size", 64)
            )

            count_gt = float(item["count"])
            fold_mae += abs(count_pred_win - count_gt)
            fold_res_str += f"[Pred(win): {count_pred_win:.1f} / GT: {count_gt:.0f}]"

            if not want_viz:
                continue

            # Full-length forward pass for the rep_rate(t) and z(t) curves.
            x_tensor = torch.tensor(x_np, dtype=torch.float32).transpose(0, 1).unsqueeze(0).to(device)
            with torch.no_grad():
                _, z, _, aux = model(x_tensor, mask=None, tau=tau)

            viz_payload = {
                "raw_ankle_xyz": item["raw_selected"][:, ankle_idx],
                "rep_rate": aux["rep_rate_t"].squeeze(0).detach().cpu().numpy(),
                "z_np": z.squeeze(0).detach().cpu().numpy(),
                "gt": count_gt,
                "pred": float(count_pred_win),
                "meta": item["meta"]
            }

        fold_mae /= len(test_trials)
        loso_results.append(fold_mae)
        print(f"Fold {fold_idx + 1:2d} | Test: {test_subj} | MAE: {fold_mae:.2f} | {fold_res_str}")

        if viz_payload is not None:
            out_dir = CONFIG["VIZ_OUT_DIR"]
            os.makedirs(out_dir, exist_ok=True)
            out_png = os.path.join(out_dir, f"qual_raw_rep_3dtraj_topKGT_{test_subj}.png")

            plot_raw_rep_and_3d_traj(
                raw_ankle_acc_xyz=viz_payload["raw_ankle_xyz"],
                rep_rate_t=viz_payload["rep_rate"],
                Z_t=viz_payload["z_np"],
                fs=CONFIG["fs"],
                title=f"Qualitative example ({viz_payload['meta']})",
                out_png=out_png,
                gt_count=viz_payload["gt"],
                pred_count=viz_payload["pred"],
                smooth_w=CONFIG["VIZ_SMOOTH_W"],
                min_gap_sec=CONFIG["VIZ_MIN_GAP_SEC"],
                max_points_3d=CONFIG["VIZ_MAX_POINTS_3D"],
                # Enlarged 3D panel
                fig_w=15.5,
                fig_h=12.0,
                h_ratios=(1.0, 1.0, 3.2),
                cb_shrink=0.95,
                cb_pad=0.03,
            )

    print("=" * 70)
    if loso_results:
        print(f" >>> Final LOSO Result (Average MAE): {np.mean(loso_results):.3f}")
        print(f" >>> Standard Deviation: {np.std(loso_results):.3f}")
    else:
        # np.mean([]) would print nan with a RuntimeWarning.
        print(" >>> No fold was evaluated.")
    print("=" * 70)


if __name__ == "__main__":
    # Script entry point: run the full LOSO experiment when executed directly.
    main()
