# In [1]:  (notebook cell marker left over from the Jupyter export)
# =========================
# Count-only K-auto (Multi-event) + Windowing version
#
# ✅ Windowing added:
# - TRAIN: trial -> sliding windows (window-level count = trial-average rate * window duration)
# - TEST : keep each trial intact; run sliding-window inference and average the window rates -> predict the total count
# - k_hat / entropy / rep_rate / phase heatmap are recorded from a single full-trial forward pass (to inspect the learned representation)
#
# ✅ Requested changes applied:
# (1) "Phase(k_hat) heatmap" uses EXACT style/design of plot_phase_heatmap_and_dominant (provided code)
# (2) PCA2D plot: ❌ REMOVED (as requested)
# (3) All plots saved as PNG (dpi=600), no plt.show()
#
# ✅ Added in this request:
# - (3) Raw signal + predicted rep_rate overlay (short segment) plot added
#
# ✅ Updated in this request:
# - Use EXACT titles (no extra activity info inside the title):
#   1) Phase Usage Heatmap for K-auto Decomposition
#   2) 3D PCA Projection of the Latent Trajectory
#   3) Raw Signal and Predicted rep_rate Overlay (Short Segment)
# - REMOVE mean plot code (time-normalized averaging) and its calls.
#
# ✅ Plot tweaks in this request:
# 1) Phase heatmap colorbar: match plot height, cleaner sizing, larger tick labels
# 2) PCA3D colorbar: no overlap with PC3 label + better view/aspect to reveal shape
# 3) Raw overlay: slightly different colors, smaller legend, label text updates
# =========================

import os
import glob
import random
import re
import numpy as np
import pandas as pd
import seaborn as sns

import matplotlib.pyplot as plt
from scipy.ndimage import gaussian_filter1d

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader


# ---------------------------------------------------------------------
# 0) IO helpers
# ---------------------------------------------------------------------
def _safe_filename(s: str) -> str:
    s = str(s)
    s = re.sub(r"[^\w\-_\. ]", "_", s)
    s = s.strip().replace(" ", "_")
    return s[:200] if len(s) > 200 else s


def _ensure_dir(p: str):
    os.makedirs(p, exist_ok=True)
    return p


# ---------------------------------------------------------------------
# 1) Strict Seeding
# ---------------------------------------------------------------------
def set_strict_seed(seed: int):
    """Seed every random generator used by this script and pin cuDNN behaviour.

    Covers ``PYTHONHASHSEED``, Python's ``random``, NumPy, PyTorch (CPU and,
    when available, all CUDA devices).  cuDNN is switched to deterministic
    mode with benchmarking disabled.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)

    # Host-side generators.
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    # Device-side generators (current device, then every device).
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


# ---------------------------------------------------------------------
# 2) Data Loading
# ---------------------------------------------------------------------
def load_mhealth_dataset(data_dir, target_activities_map, column_names):
    """Load every ``mHealth_subject*.log`` file in ``data_dir``, split by activity.

    Args:
        data_dir: directory containing the tab-separated subject logs.
        target_activities_map: mapping of ``activity_id`` code -> activity
            name; only these activities are kept.
        column_names: names assigned to the leading columns of each log
            (extra columns are dropped).  ``'activity_id'`` must be among
            them, otherwise the file is reported as failed.

    Returns:
        ``{subject_key: {activity_name: DataFrame}}`` where each frame holds
        the rows of that activity without the ``activity_id`` column.
        Activities with no rows are omitted.  Files that fail to load are
        reported on stdout and left out.  An empty dict is returned when no
        log file is found.
    """
    file_list = sorted(glob.glob(os.path.join(data_dir, "mHealth_subject*.log")))

    if not file_list:
        print(f"[Warning] No mHealth logs found in {data_dir}")
        return {}

    print(f"Loading {len(file_list)} subjects from {data_dir}...")

    full_dataset = {}
    for file_path in file_list:
        file_name = os.path.basename(file_path)
        subj_part = file_name.split('.')[0]

        # Normalise "mHealth_subject07" -> "subject7".  Only a failed integer
        # parse (no usable digits in the name) falls back to the raw stem; a
        # bare ``except`` would also swallow KeyboardInterrupt/SystemExit.
        try:
            subj_id_num = int(''.join(filter(str.isdigit, subj_part)))
            subj_key = f"subject{subj_id_num}"
        except ValueError:
            subj_key = subj_part

        # Best effort per file: one unreadable log must not abort the whole
        # load, so the failure is reported and the subject is skipped.
        try:
            df = pd.read_csv(file_path, sep="\t", header=None)
            df = df.iloc[:, :len(column_names)]
            df.columns = column_names

            subj_data = {}
            for label_code, activity_name in target_activities_map.items():
                activity_df = df[df['activity_id'] == label_code].copy()
                if not activity_df.empty:
                    subj_data[activity_name] = activity_df.drop(columns=['activity_id'])

            full_dataset[subj_key] = subj_data
        except Exception as e:
            print(f"Error loading {file_name}: {e}")

    return full_dataset


def prepare_trial_list(label_config, full_data, target_map, feature_map):
    """Build one record per labelled trial: z-scored signal, total count, tag.

    ``label_config`` is iterated as ``(subject, activity_id, gt_count)``
    triples.  For each triple the activity name and feature columns are looked
    up in ``target_map`` / ``feature_map``.  Triples whose subject or activity
    is absent from ``full_data`` are reported on stdout and skipped.

    Returns a list of dicts with keys ``'data'`` ((T, C) float32 array,
    normalised per channel over the whole trial), ``'count'`` (float) and
    ``'meta'`` (``"<subject>_<activity name>"``).
    """
    trials = []
    for subj, act_id, gt_count in label_config:
        act_name = target_map.get(act_id)
        feats = feature_map.get(act_id)

        if subj not in full_data or act_name not in full_data[subj]:
            print(f"[Skip] Missing data for {subj} - {act_name}")
            continue

        signal = full_data[subj][act_name][feats].values.astype(np.float32)

        # Z-score normalisation; the epsilon keeps flat channels finite.
        mu = signal.mean(axis=0)
        sigma = signal.std(axis=0) + 1e-6

        trials.append({
            'data': (signal - mu) / sigma,   # (T, C)
            'count': float(gt_count),        # trial total count
            'meta': f"{subj}_{act_name}"
        })

    return trials


# ---------------------------------------------------------------------
# 2.5) ✅ Windowing (added)
# ---------------------------------------------------------------------
def trial_list_to_windows(trial_list, fs, win_sec=8.0, stride_sec=4.0, drop_last=True):
    """
    TRAIN only: expand each trial into sliding windows.

    Window labels are derived from the trial-level average rate:
      rate_trial   = count_total / total_duration
      count_window = rate_trial * window_duration

    A trial shorter than one window is emitted as a single window holding the
    whole trial.  With ``drop_last=False`` one extra, shorter window is added
    that starts one stride after the last full window and runs to the end of
    the trial, provided that start still lies inside the trial.

    Each window dict carries ``data``, ``count``, ``meta``, ``parent_meta``,
    ``parent_T``, ``win_start`` and ``win_end``.
    """
    win_len = int(round(win_sec * fs))
    stride = int(round(stride_sec * fs))
    assert win_len > 0 and stride > 0

    def _make_window(segment, parent, n_total, rate, begin, end):
        # The label scales with the actual window duration in seconds.
        return {
            "data": segment,
            "count": rate * ((end - begin) / float(fs)),
            "meta": f"{parent}__win[{begin}:{end}]",
            "parent_meta": parent,
            "parent_T": n_total,
            "win_start": begin,
            "win_end": end,
        }

    windows = []
    for trial in trial_list:
        signal = trial["data"]  # (T,C)
        n_samples = signal.shape[0]
        count_total = float(trial["count"])
        parent = trial["meta"]

        # reps/s over the whole trial; max() protects a zero-length trial
        rate = count_total / max(n_samples / float(fs), 1e-6)

        if n_samples < win_len:
            windows.append(_make_window(signal, parent, n_samples, rate, 0, n_samples))
            continue

        begin = 0
        while begin + win_len <= n_samples:
            end = begin + win_len
            windows.append(_make_window(signal[begin:end], parent, n_samples, rate, begin, end))
            begin += stride

        # ``begin`` now sits one stride past the last full window.
        if not drop_last and begin < n_samples:
            windows.append(
                _make_window(signal[begin:n_samples], parent, n_samples, rate, begin, n_samples)
            )

    return windows


def predict_count_by_windowing(model, x_np, fs, win_sec, stride_sec, device, tau=1.0, batch_size=64):
    """
    TEST only: trial -> sliding-window inference -> mean window rate -> total count.

    Args:
        model: callable as ``model(x, mask=None, tau=tau)`` returning a
            4-tuple whose first element is the per-sample average rate.
        x_np: (T, C) numpy array (already normalised).
        fs: sampling rate in Hz.
        win_sec, stride_sec: window length / stride in seconds.
        device: device the input batches are moved to.
        tau: softmax temperature forwarded to the model.
        batch_size: number of windows per forward pass.

    Returns:
        ``(pred_count, window_rates)``: the predicted total count (float) and
        the per-window rates as a float32 array.  A trial that is not longer
        than one window is evaluated in a single forward pass and yields one
        rate.

    The model is switched to eval mode for both paths.
    """
    win_len = int(round(win_sec * fs))
    stride = int(round(stride_sec * fs))
    T = x_np.shape[0]
    total_dur = T / float(fs)

    # Previously only the multi-window path called eval(); do it up front so
    # short trials are evaluated in the same mode.
    model.eval()

    # short trial -> single forward pass
    if T <= win_len:
        x_tensor = torch.tensor(x_np, dtype=torch.float32).transpose(0, 1).unsqueeze(0).to(device)  # (1,C,T)
        with torch.no_grad():
            rate_hat, _, _, _ = model(x_tensor, mask=None, tau=tau)
        rate = float(rate_hat.item())  # one host sync instead of two
        return float(rate * total_dur), np.array([rate], dtype=np.float32)

    starts = range(0, T - win_len + 1, stride)
    windows = np.stack([x_np[st:st + win_len] for st in starts], axis=0)  # (N, win_len, C)

    # Kept on the host; only the current batch is moved to the device.
    xw = torch.tensor(windows, dtype=torch.float32).permute(0, 2, 1)  # (N, C, win_len)

    rates = []
    with torch.no_grad():
        for i in range(0, xw.shape[0], batch_size):
            xb = xw[i:i + batch_size].to(device)
            r_hat, _, _, _ = model(xb, mask=None, tau=tau)  # (B,)
            rates.append(r_hat.detach().cpu().numpy())

    rates = np.concatenate(rates, axis=0)  # (N,)
    # The mean window rate is extrapolated over the full trial duration.
    pred_count = float(rates.mean()) * total_dur
    return float(pred_count), rates


# ---------------------------------------------------------------------
# 2.8) Dataset / Collate
# ---------------------------------------------------------------------
class TrialDataset(Dataset):
    """Map-style dataset over trial (or window) records.

    Every record is a dict providing ``'data'`` (time-major array), ``'count'``
    and ``'meta'``.  Items are returned as ``(data, count, meta)`` with the
    data transposed to channels-first float32 and the count as a float32
    scalar tensor.
    """

    def __init__(self, trial_list):
        self.trials = trial_list

    def __len__(self):
        return len(self.trials)

    def __getitem__(self, idx):
        record = self.trials[idx]
        signal = torch.tensor(record['data'], dtype=torch.float32)
        target = torch.tensor(record['count'], dtype=torch.float32)
        # (T, C) -> (C, T)
        return signal.transpose(0, 1), target, record['meta']


def collate_variable_length(batch):
    """Collate ``(data, count, meta)`` samples of different lengths into a batch.

    Each ``data`` tensor is (C, T).  Shorter samples are zero-padded on the
    right up to the longest T in the batch, and a mask marks real (1) versus
    padded (0) time steps.

    Returns a dict with ``data`` (B, C, T_max), ``mask`` (B, T_max), ``count``
    (B,), ``length`` (B,) holding the unpadded lengths as float32, and ``meta``
    (list of the samples' meta values).
    """
    longest = max(sample[0].shape[1] for sample in batch)
    n_channels = batch[0][0].shape[0]

    padded, masks, counts, metas, lengths = [], [], [], [], []
    for signal, count, meta in batch:
        n_valid = signal.shape[1]
        n_pad = longest - n_valid

        if n_pad > 0:
            signal = torch.cat([signal, torch.zeros(n_channels, n_pad)], dim=1)

        # 1 for real samples, 0 for the padded tail.
        valid = torch.zeros(longest)
        valid[:n_valid] = 1.0

        padded.append(signal)
        masks.append(valid)
        counts.append(count)
        metas.append(meta)
        lengths.append(n_valid)

    return {
        "data": torch.stack(padded),              # (B, C, T_max)
        "mask": torch.stack(masks),               # (B, T_max)
        "count": torch.stack(counts),             # (B,)
        "length": torch.tensor(lengths, dtype=torch.float32),  # (B,)
        "meta": metas
    }


# ---------------------------------------------------------------------
# 3) Model
# ---------------------------------------------------------------------
class ManifoldEncoder(nn.Module):
    """1-D convolutional encoder producing one latent vector per time step.

    Two width-5 convolutions (padding 2, so the time length is preserved),
    each followed by ReLU, then a pointwise convolution down to ``latent_dim``.
    Input is (B, C, T); output is (B, T, latent_dim).
    """

    def __init__(self, input_ch, hidden_dim=128, latent_dim=16):
        super().__init__()
        layers = [
            nn.Conv1d(input_ch, hidden_dim, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv1d(hidden_dim, hidden_dim, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv1d(hidden_dim, latent_dim, kernel_size=1),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        # (B, D, T) from the conv stack -> (B, T, D) for the per-step heads
        return self.net(x).transpose(1, 2)


class ManifoldDecoder(nn.Module):
    """1-D convolutional decoder mapping latents back to signal channels.

    Mirrors the encoder layout: two width-5 convolutions (padding 2) with ReLU
    and a final pointwise convolution to ``out_ch``.  Input is (B, T, D);
    output is (B, out_ch, T).
    """

    def __init__(self, latent_dim, hidden_dim, out_ch):
        super().__init__()
        layers = [
            nn.Conv1d(latent_dim, hidden_dim, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv1d(hidden_dim, hidden_dim, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv1d(hidden_dim, out_ch, kernel_size=1),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, z):
        # Conv1d wants channels first: (B, T, D) -> (B, D, T)
        channels_first = z.transpose(1, 2)
        return self.net(channels_first)   # (B, C, T)


class MultiRateHead(nn.Module):
    """Per-time-step head producing an amplitude and a phase distribution.

    A two-layer MLP maps each latent vector to ``1 + K_max`` values: the first
    is the amplitude logit, the remaining ``K_max`` are phase logits.
    """

    def __init__(self, latent_dim=16, hidden=64, K_max=6):
        super().__init__()
        self.K_max = K_max
        self.net = nn.Sequential(
            nn.Linear(latent_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 1 + K_max),  # [amp_logit | phase_logits...]
        )

    def forward(self, z, tau=1.0):
        """Return ``(amp, phase, phase_logits)`` for latents ``z`` of shape (B, T, D).

        ``amp`` is (B, T) and non-negative (softplus).  ``phase`` is (B, T, K)
        and sums to one over K (softmax of the logits divided by ``tau``).
        ``phase_logits`` are the raw, untempered logits.
        """
        raw = self.net(z)                          # (B,T,1+K)
        amp_logit, phase_logits = raw[..., 0], raw[..., 1:]
        amp = F.softplus(amp_logit)
        phase = (phase_logits / tau).softmax(dim=-1)
        return amp, phase, phase_logits


class KAutoCountModel(nn.Module):
    """
    Count-only model that infers how many micro-events make up one repetition.

    - the rate head yields a non-negative amplitude ``amp(t)`` and a softmax
      phase distribution over ``K_max`` phases; ``amp(t) * phase(t, k)`` are
      the per-phase micro-event rates r_k(t)
    - ``k_hat`` is the inverse of the summed squared time-averaged phase
      probabilities (about 1 when one phase dominates, about ``K_max`` when
      all phases are used equally)
    - ``rep_rate(t) = amp(t) / k_hat``

    ``k_hidden`` is accepted by the constructor but not used.
    """

    def __init__(self, input_ch, hidden_dim=128, latent_dim=16, K_max=6, k_hidden=64):
        super().__init__()
        self.encoder = ManifoldEncoder(input_ch, hidden_dim, latent_dim)
        self.decoder = ManifoldDecoder(latent_dim, hidden_dim, input_ch)
        self.rate_head = MultiRateHead(latent_dim, hidden=hidden_dim, K_max=K_max)

        self._init_weights()

    def _init_weights(self):
        """Kaiming-initialise conv/linear weights, zero the biases, then set
        the amplitude-logit bias of the rate head to -2."""
        for module in self.modules():
            if not isinstance(module, (nn.Conv1d, nn.Linear)):
                continue
            nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)

        head_bias = self.rate_head.net[-1].bias
        with torch.no_grad():
            head_bias.zero_()
            head_bias[0].fill_(-2.0)  # only the amp logit bias is set to -2

    @staticmethod
    def _masked_mean_time(x, mask=None, eps=1e-6):
        """Mean over the time axis (dim 1) of a (B,T) or (B,T,K) tensor,
        counting only positions where ``mask`` is non-zero."""
        if mask is None:
            return x.mean(dim=1)
        if x.dim() not in (2, 3):
            raise ValueError(f"Unsupported dim for masked mean: {x.dim()}")

        weights = mask.to(dtype=x.dtype, device=x.device)
        if x.dim() == 3:
            weights = weights.unsqueeze(-1)   # broadcast over K
        return (x * weights).sum(dim=1) / (weights.sum(dim=1) + eps)

    def forward(self, x, mask=None, tau=1.0):
        """Return ``(avg_rep_rate, z, x_hat, aux)`` for input ``x`` of shape (B,C,T).

        ``avg_rep_rate`` is the (masked) time average of ``rep_rate_t``, ``z``
        the latents (B,T,D), ``x_hat`` the reconstruction (B,C,T) and ``aux``
        a dict with the intermediate per-time-step quantities.
        """
        z = self.encoder(x)              # (B,T,D)
        x_hat = self.decoder(z)          # (B,C,T)

        amp, phase_prob, phase_logits = self.rate_head(z, tau=tau)
        per_phase_rates = amp.unsqueeze(-1) * phase_prob   # (B,T,K)

        # The phase distribution sums to one, so the total micro rate is amp.
        micro_rate = amp                                   # (B,T)

        mean_phase = self._masked_mean_time(phase_prob, mask)     # (B,K)
        k_hat = 1.0 / (mean_phase.pow(2).sum(dim=1) + 1e-6)       # (B,)

        rep_rate = micro_rate / (k_hat.unsqueeze(1) + 1e-6)       # (B,T)

        if mask is None:
            avg_rep_rate = rep_rate.mean(dim=1)
        else:
            # Zero the padded steps, then average over valid steps only.
            rep_rate = rep_rate * mask
            avg_rep_rate = (rep_rate * mask).sum(dim=1) / (mask.sum(dim=1) + 1e-6)

        aux = {
            "rates_k_t": per_phase_rates,
            "phase_p": phase_prob,
            "phase_logits": phase_logits,
            "micro_rate_t": micro_rate,
            "rep_rate_t": rep_rate,
            "k_hat": k_hat,
        }
        return avg_rep_rate, z, x_hat, aux


# ---------------------------------------------------------------------
# 4) Loss utils
# ---------------------------------------------------------------------
def masked_recon_mse(x_hat, x, mask, eps=1e-6):
    """Mean squared reconstruction error over unmasked time steps.

    ``x_hat`` and ``x`` are (B,C,T), ``mask`` is (B,T).  The squared error is
    summed where the mask is set and divided by (mask total * C) + ``eps``.
    """
    valid = mask.to(dtype=x.dtype, device=x.device)
    sq_err = (x_hat - x).pow(2) * valid.unsqueeze(1)   # mask broadcast over C
    n_terms = (valid.sum() * x.shape[1]) + eps
    return sq_err.sum() / n_terms


def temporal_smoothness(v, mask=None, eps=1e-6):
    """Mean absolute first difference of ``v`` (B,T) along time.

    With a mask, only differences between two consecutive unmasked steps
    contribute.
    """
    step = (v[:, 1:] - v[:, :-1]).abs()   # (B,T-1)
    if mask is None:
        return step.mean()

    # A difference is valid only if both of its end points are valid.
    pair_valid = (mask[:, 1:] * mask[:, :-1]).to(dtype=step.dtype, device=step.device)
    return (step * pair_valid).sum() / (pair_valid.sum() + eps)


def phase_entropy_loss(phase_p, mask=None, eps=1e-8):
    """Average Shannon entropy (nats) of the per-step phase distribution.

    ``phase_p`` is (B,T,K).  With a (B,T) mask, the entropy is averaged over
    unmasked steps only.
    """
    log_p = torch.log(phase_p + eps)            # eps keeps log(0) finite
    ent = -(phase_p * log_p).sum(dim=-1)        # (B,T)
    if mask is not None:
        return (ent * mask).sum() / (mask.sum() + eps)
    return ent.mean()


def effK_usage_loss(phase_p, mask=None, eps=1e-6):
    """Effective number of phases in use, per sample and batch-averaged.

    The (masked) time average of ``phase_p`` (B,T,K) is taken per sample and
    turned into ``1 / (sum_k p_bar_k^2 + eps)``.

    Returns ``(mean_effK, effK_detached)``: the batch mean (differentiable)
    and the per-sample values detached from the graph.
    """
    if mask is not None:
        weights = mask.to(dtype=phase_p.dtype, device=phase_p.device).unsqueeze(-1)  # (B,T,1)
        p_bar = (phase_p * weights).sum(dim=1) / (weights.sum(dim=1) + eps)
    else:
        p_bar = phase_p.mean(dim=1)  # (B,K)

    effK = 1.0 / (p_bar.pow(2).sum(dim=1) + eps)
    return effK.mean(), effK.detach()


# ---------------------------------------------------------------------
# 5) Train
# ---------------------------------------------------------------------
def train_one_epoch(model, loader, optimizer, config, device):
    """Run one optimisation pass over ``loader`` and return averaged statistics.

    Args:
        model: called as ``model(x, mask, tau=tau)``; must return
            ``(rate_hat, z, x_hat, aux)`` with ``aux["rep_rate_t"]`` and
            ``aux["phase_p"]``.
        loader: iterable of batches with keys ``data``, ``mask``, ``count``
            and ``length``; must support ``len()``.
        optimizer: optimizer stepping the model parameters.
        config: dict providing ``fs`` and optionally ``tau`` and the loss
            weights ``lambda_recon``, ``lambda_smooth``, ``lambda_phase_ent``,
            ``lambda_effk``.
        device: device the batch tensors are moved to.

    Returns:
        Dict of per-batch averages for ``loss``, ``loss_rate``, ``loss_recon``,
        ``loss_smooth``, ``loss_phase_ent``, ``loss_effk`` and ``mae_count``.
        All values are 0.0 when the loader yields no batches.
    """
    model.train()
    stats = {k: 0.0 for k in [
        'loss', 'loss_rate', 'loss_recon', 'loss_smooth', 'loss_phase_ent', 'loss_effk',
        'mae_count'
    ]}

    fs = config["fs"]
    tau = config.get("tau", 1.0)

    lam_recon = config.get("lambda_recon", 1.0)
    lam_smooth = config.get("lambda_smooth", 0.05)
    lam_phase_ent = config.get("lambda_phase_ent", 0.01)
    lam_effk = config.get("lambda_effk", 0.005)

    for batch in loader:
        x = batch["data"].to(device)         # (B,C,T)
        mask = batch["mask"].to(device)      # (B,T)
        y_count = batch["count"].to(device)  # (B,)
        length = batch["length"].to(device)  # (B,)

        # The regression target is a rate, so windows of different length
        # are comparable.
        duration = torch.clamp(length / fs, min=1e-6)  # sec
        y_rate = y_count / duration                    # reps/s

        optimizer.zero_grad()

        rate_hat, z, x_hat, aux = model(x, mask, tau=tau)

        loss_rate = F.mse_loss(rate_hat, y_rate)
        loss_recon = masked_recon_mse(x_hat, x, mask)
        loss_smooth = temporal_smoothness(aux["rep_rate_t"], mask)
        loss_phase_ent = phase_entropy_loss(aux["phase_p"], mask)
        loss_effk, _ = effK_usage_loss(aux["phase_p"], mask)

        loss = (loss_rate
                + lam_recon * loss_recon
                + lam_smooth * loss_smooth
                + lam_phase_ent * loss_phase_ent
                + lam_effk * loss_effk)

        loss.backward()
        optimizer.step()

        count_hat = rate_hat * duration
        batch_stats = {
            'loss': loss,
            'loss_rate': loss_rate,
            'loss_recon': loss_recon,
            'loss_smooth': loss_smooth,
            'loss_phase_ent': loss_phase_ent,
            'loss_effk': loss_effk,
            'mae_count': torch.abs(count_hat - y_count).mean(),
        }
        for key, value in batch_stats.items():
            stats[key] += value.item()

    # An empty loader used to raise ZeroDivisionError here.
    n = max(len(loader), 1)
    return {k: v / n for k, v in stats.items()}


# ---------------------------------------------------------------------
# 6) Visualization helpers
# ---------------------------------------------------------------------
def _smooth_1d(y, sigma=2.0):
    y = np.asarray(y, dtype=np.float32)
    return gaussian_filter1d(y, sigma=sigma)


def compute_phase_entropy_mean(phase_p_np, eps=1e-8):
    """Return the time-averaged entropy (nats) of a (T,K) phase-probability array."""
    probs = np.asarray(phase_p_np, dtype=np.float32)
    per_step = -np.sum(probs * np.log(probs + eps), axis=1)  # (T,)
    return float(np.mean(per_step))


def downsample_time_axis(arr, max_T=2000):
    """Thin ``arr`` along axis 0 to at most ``max_T`` evenly spaced rows.

    Returns ``(rows, idx)`` where ``idx`` holds the original row indices that
    were kept.  Arrays that already fit are returned unchanged together with
    ``np.arange(T)``.
    """
    n_rows = arr.shape[0]
    if n_rows > max_T:
        # Evenly spaced positions from first to last row, truncated to ints.
        keep = np.linspace(0, n_rows - 1, max_T).astype(int)
        return arr[keep], keep
    return arr, np.arange(n_rows)


# -----------------------------
# PCA utils (no sklearn)
# -----------------------------
def _fit_pca_basis(X, n_comp=3):
    X = np.asarray(X, dtype=np.float32)
    mu = X.mean(axis=0, keepdims=True)
    Xc = X - mu
    _, _, Vt = np.linalg.svd(Xc, full_matrices=False)
    W = Vt[:n_comp].T
    return mu.squeeze(0), W


def _pca_project(X, mu, W):
    X = np.asarray(X, dtype=np.float32)
    return (X - mu[None, :]) @ W


# ---------------------------------------------------------------------
# ✅ (1) Phase heatmap + PNG saving  (cbar aligned + larger tick labels)
# ---------------------------------------------------------------------
def plot_phase_heatmap_and_dominant(
    phase_p_np,
    fs,
    title="phase_p heatmap + dominant phase",
    max_T=2000,
    save_path=None,
    dpi=600
):
    """Draw the phase-probability heatmap (phase index vs. time) and save it.

    Args:
        phase_p_np: (T, K) array of per-step phase probabilities.
        fs: sampling rate in Hz, used to label the time axis in seconds.
        title: axes title; nothing is drawn when empty.
        max_T: the time axis is thinned to at most this many columns.
        save_path: output image path; the figure is not written when ``None``.
            A bare file name (no directory part) saves into the current
            directory.
        dpi: resolution passed to ``savefig``.

    Only the heatmap is drawn.  The figure is always closed before returning
    and is never shown.
    """
    from mpl_toolkits.axes_grid1 import make_axes_locatable

    phase_p_np = np.asarray(phase_p_np, dtype=np.float32)
    assert phase_p_np.ndim == 2, f"phase_p_np must be (T,K), got {phase_p_np.shape}"

    phase_ds, idx = downsample_time_axis(phase_p_np, max_T=max_T)  # (T',K)
    K = phase_ds.shape[1]
    t_sec = idx / float(fs)

    fig = plt.figure(figsize=(12.5, 5))
    cmap = sns.color_palette("magma", as_cmap=True)

    ax0 = fig.add_subplot(1, 1, 1)
    im = ax0.imshow(
        phase_ds.T,                     # (K, T')
        aspect="auto",
        origin="lower",
        interpolation="nearest",
        extent=[t_sec[0], t_sec[-1], 0, K],
        cmap=cmap
    )
    if title:
        ax0.set_title(title, fontsize=24, pad=10)
    ax0.set_ylabel("Phase k", fontsize=26)
    ax0.set_xlabel("Time (sec)", fontsize=26)
    ax0.tick_params(axis="both", labelsize=18)

    # colorbar: attached to the axes so it matches the plot height, with
    # larger tick labels
    divider = make_axes_locatable(ax0)
    cax = divider.append_axes("right", size="2.6%", pad=0.15)
    cbar = fig.colorbar(im, cax=cax)
    cbar.set_label(r"$\hat{p}_t(k)$", fontsize=26, labelpad=10)
    cbar.ax.tick_params(labelsize=18)

    plt.tight_layout()

    if save_path is not None:
        # dirname is "" for a bare file name, and os.makedirs("") raises;
        # fall back to the current directory.
        _ensure_dir(os.path.dirname(save_path) or ".")
        fig.savefig(save_path, dpi=dpi, bbox_inches="tight")
    plt.close(fig)


# ---------------------------------------------------------------------
# ✅ (2) PCA3D + PNG saving  (cbar no-overlap + better view/aspect)
# ---------------------------------------------------------------------
def plot_pca3d_z(z_np, fs, title="PCA3D of z(t)", max_T=2000, save_path=None, dpi=600):
    """Plot the latent trajectory projected onto its first three principal components.

    Args:
        z_np: (T, D) latent trajectory.
        fs: sampling rate in Hz, used to colour points by time in seconds.
        title: axes title; nothing is drawn when empty.
        max_T: the trajectory is thinned to at most this many points before
            the PCA basis is fitted.
        save_path: output image path; the figure is not written when ``None``.
            A bare file name (no directory part) saves into the current
            directory.
        dpi: resolution passed to ``savefig``.

    The first and last points are highlighted with larger markers.  The figure
    is always closed before returning and is never shown.
    """
    from mpl_toolkits.mplot3d import Axes3D  # noqa: F401

    z_np = np.asarray(z_np, dtype=np.float32)
    z_ds, idx = downsample_time_axis(z_np, max_T=max_T)  # (T',D)
    t_sec = idx / float(fs)

    mu, W = _fit_pca_basis(z_ds, n_comp=3)
    Z3 = _pca_project(z_ds, mu, W)  # (T',3)

    # Figure is slightly wider than tall so there is room on the right for
    # the colorbar.
    fig = plt.figure(figsize=(6.6, 5.2))
    ax = fig.add_subplot(111, projection="3d")

    sc = ax.scatter(Z3[:, 0], Z3[:, 1], Z3[:, 2], c=t_sec, s=12, alpha=0.9)
    ax.plot(Z3[:, 0], Z3[:, 1], Z3[:, 2], color="k", alpha=0.15, linewidth=0.8)

    # Start (circle) and end (triangle) markers.
    ax.scatter(Z3[0, 0],  Z3[0, 1],  Z3[0, 2],  marker="o", s=70, edgecolor="k", linewidth=0.8)
    ax.scatter(Z3[-1, 0], Z3[-1, 1], Z3[-1, 2], marker="^", s=90, edgecolor="k", linewidth=0.8)

    if title:
        ax.set_title(title, fontsize=14, pad=10)

    ax.set_xlabel("PC1", fontsize=10)
    ax.set_ylabel("PC2", fontsize=10)

    # Overlap fix 1: pull the z label towards the axis (negative labelpad).
    ax.set_zlabel("PC3", fontsize=10, labelpad=-2)

    # Fixed viewing angle chosen to reveal the trajectory shape; adjust to taste.
    ax.view_init(elev=18, azim=-62)

    ax.tick_params(axis='x', which='major', pad=-2)
    ax.tick_params(axis='y', which='major', pad=-2)
    ax.tick_params(axis='z', which='major', pad=-2)

    ax.grid(False)

    # Overlap fix 2: put the colorbar on its own axes so it is separated from
    # the PC3 label.  right=0.80 leaves the right 20% of the figure free and
    # the colorbar is placed in that space.
    fig.subplots_adjust(right=0.80)
    cax = fig.add_axes([0.84, 0.15, 0.03, 0.70])  # [left, bottom, width, height] in figure coords
    cbar = fig.colorbar(sc, cax=cax)
    cbar.ax.tick_params(labelsize=9)
    cbar.set_label("time (sec)", fontsize=12)

    if save_path is not None:
        # dirname is "" for a bare file name, and os.makedirs("") raises;
        # fall back to the current directory.
        _ensure_dir(os.path.dirname(save_path) or ".")
        fig.savefig(save_path, dpi=dpi, bbox_inches="tight")
    plt.close(fig)


# ---------------------------------------------------------------------
# ✅ (3) NEW: Raw signal + predicted rep_rate overlay (short segment)
#     - colors slightly different
#     - smaller legend
#     - label text updates
# ---------------------------------------------------------------------
def plot_raw_and_rep_rate_overlay_short_segment(
    x_np,                 # (T,C) z-scored input
    rep_rate,             # (T,) predicted rep_rate_t
    fs,
    title="",
    seg_sec=6.0,
    channel_idx=0,
    select_mode="max_rate",   # "max_rate" or "center"
    # smoothing controls
    smooth_sigma_raw=0.8,     # light smoothing for the raw signal
    smooth_sigma_rate=2.0,    # somewhat stronger for rep_rate
    clip_rate_nonneg=True,
    save_path=None,
    dpi=600
):
    """Overlay one raw channel and the predicted rep rate on a short segment.

    Args:
        x_np: (T, C) input signal.
        rep_rate: (T,) predicted repetition rate per time step.
        fs: sampling rate in Hz.
        title: axes title; nothing is drawn when empty.
        seg_sec: segment length in seconds (at least 10 samples, at most the
            whole trial).
        channel_idx: channel to plot; clipped into the valid range.
        select_mode: ``"center"`` takes the middle of the trial; any other
            value picks the segment with the highest rolling-mean rep rate.
        smooth_sigma_raw, smooth_sigma_rate: Gaussian sigmas applied to the
            plotted curves only; falsy or non-positive disables smoothing.
        clip_rate_nonneg: clip the plotted rate at zero from below.
        save_path: output image path; the figure is not written when ``None``.
            A bare file name (no directory part) saves into the current
            directory.
        dpi: resolution passed to ``savefig``.

    The raw signal uses the left y-axis and the rate a twin right y-axis.  The
    figure is always closed before returning and is never shown.
    """
    x_np = np.asarray(x_np, dtype=np.float32)
    rep_rate = np.asarray(rep_rate, dtype=np.float32)

    T = x_np.shape[0]
    C = x_np.shape[1]
    ch = int(np.clip(channel_idx, 0, C - 1))

    L = int(round(seg_sec * fs))
    L = max(10, min(L, T))

    if T <= L:
        st, ed = 0, T
    else:
        if select_mode == "center":
            st = max(0, (T - L) // 2)
            ed = st + L
        else:
            # segment where rep_rate is high (rolling mean peak)
            rr_for_pick = gaussian_filter1d(rep_rate, sigma=max(0.5, smooth_sigma_rate))
            kernel = np.ones(L, dtype=np.float32) / float(L)
            roll = np.convolve(rr_for_pick, kernel, mode="valid")
            best = int(np.argmax(roll))
            st = best
            ed = st + L

    raw_seg = x_np[st:ed, ch]
    rr_seg = rep_rate[st:ed]

    # smoothing (visual only)
    if smooth_sigma_raw and smooth_sigma_raw > 0:
        raw_seg_plot = gaussian_filter1d(raw_seg, sigma=smooth_sigma_raw)
    else:
        raw_seg_plot = raw_seg

    if smooth_sigma_rate and smooth_sigma_rate > 0:
        rr_seg_plot = gaussian_filter1d(rr_seg, sigma=smooth_sigma_rate)
    else:
        rr_seg_plot = rr_seg

    if clip_rate_nonneg:
        rr_seg_plot = np.clip(rr_seg_plot, 0.0, None)

    t = np.arange(st, ed, dtype=np.float32) / float(fs)

    fig = plt.figure(figsize=(11.5, 5))
    ax1 = fig.add_subplot(1, 1, 1)

    # raw: slightly thinner line
    ax1.plot(t, raw_seg_plot, linewidth=1.85, alpha=0.90, color="tab:blue", label="Raw (z-scored)")
    ax1.set_xlabel("Time (sec)", fontsize=19)
    ax1.set_ylabel("Raw signal", fontsize=19)
    ax1.grid(True, linestyle="--", alpha=0.25)

    ax2 = ax1.twinx()

    # rep_rate: thicker line
    ax2.plot(t, rr_seg_plot, linewidth=2.4, alpha=0.95, color="tab:orange", label="Predicted (Rep rate)")
    ax2.set_ylabel("Rep rate (reps/s)", fontsize=19)

    if title:
        ax1.set_title(title, fontsize=19, pad=6)

    # legend (merged from both axes), slightly smaller
    h1, l1 = ax1.get_legend_handles_labels()
    h2, l2 = ax2.get_legend_handles_labels()
    ax1.legend(h1 + h2, l1 + l2, loc="upper right", fontsize=15, frameon=True)

    ax1.tick_params(axis="both", labelsize=14)
    ax2.tick_params(axis="y", labelsize=14)

    plt.tight_layout()
    if save_path is not None:
        # dirname is "" for a bare file name, and os.makedirs("") raises;
        # fall back to the current directory.
        _ensure_dir(os.path.dirname(save_path) or ".")
        fig.savefig(save_path, dpi=dpi, bbox_inches="tight")
    plt.close(fig)



# ---------------------------------------------------------------------
# 7) Main (LOSO) - ALL activities + requested plots (PNG export)
# ---------------------------------------------------------------------
def main():
    """Run leave-one-subject-out (LOSO) count regression for every target activity.

    For each activity in ``CONFIG["TARGET_ACTIVITIES_MAP"]`` and each labelled
    subject, a fresh ``KAutoCountModel`` is trained on sliding windows cut from
    the remaining subjects' trials and evaluated on the held-out subject's full
    trial through ``predict_count_by_windowing``.  One extra full-trial forward
    pass supplies ``phase_p``, ``rep_rate_t``, the latent ``z`` and ``k_hat``,
    which feed the three per-fold PNG plots.  Per-fold and per-activity MAE are
    printed; nothing is returned.

    Folds without test data, and folds without any training data, are skipped
    with a printed notice instead of raising.
    """
    CONFIG = {
        "seed": 42,
        "data_dir": "/content/drive/MyDrive/Colab Notebooks/HAR_data/MHEALTHDATASET",

        "out_dir": "./outputs_loso_activity_plots",

        "COLUMN_NAMES": [
            'acc_chest_x', 'acc_chest_y', 'acc_chest_z',
            'ecg_1', 'ecg_2',
            'acc_ankle_x', 'acc_ankle_y', 'acc_ankle_z',
            'gyro_ankle_x', 'gyro_ankle_y', 'gyro_ankle_z',
            'mag_ankle_x', 'mag_ankle_y', 'mag_ankle_z',
            'acc_arm_x', 'acc_arm_y', 'acc_arm_z',
            'gyro_arm_x', 'gyro_arm_y', 'gyro_arm_z',
            'mag_arm_x', 'mag_arm_y', 'mag_arm_z',
            'activity_id'
        ],

        "TARGET_ACTIVITIES_MAP": {
            8:  'Knees bending',
            10:  'Jogging',
            12: 'Jump front & back',
        },

        "ACT_FEATURE_MAP": {
            8:  ['acc_chest_x','acc_chest_y','acc_chest_z','acc_ankle_x','acc_ankle_y','acc_ankle_z',
                 'gyro_ankle_x','gyro_ankle_y','gyro_ankle_z','acc_arm_x','acc_arm_y','acc_arm_z',
                 'gyro_arm_x','gyro_arm_y','gyro_arm_z'],
            10:  ['acc_chest_x','acc_chest_y','acc_chest_z','acc_ankle_x','acc_ankle_y','acc_ankle_z',
                 'gyro_ankle_x','gyro_ankle_y','gyro_ankle_z','acc_arm_x','acc_arm_y','acc_arm_z',
                 'gyro_arm_x','gyro_arm_y','gyro_arm_z'],
            12: ['acc_chest_x','acc_chest_y','acc_chest_z','acc_ankle_x','acc_ankle_y','acc_ankle_z',
                 'gyro_ankle_x','gyro_ankle_y','gyro_ankle_z','acc_arm_x','acc_arm_y','acc_arm_z',
                 'gyro_arm_x','gyro_arm_y','gyro_arm_z'],
        },

        # Training Params
        "epochs": 100,
        "lr": 5e-4,
        "batch_size": 64,
        "fs": 50,

        # Windowing Params
        "win_sec": 8.0,
        "stride_sec": 4.0,
        "drop_last": True,

        # Model
        "hidden_dim": 128,
        "latent_dim": 16,
        "K_max": 6,

        # Loss Weights
        "lambda_recon": 1.0,
        "lambda_smooth": 0.05,
        "lambda_phase_ent": 0.01,
        "lambda_effk": 0.0075,

        # temperature
        "tau": 1.0,

        # plotting controls
        "plot_max_T": 2000,     # per-fold plot downsample cap
        "dpi": 600,

        # overlay controls
        "overlay_seg_sec": 6.0,        # short segment length (sec)
        "overlay_channel_idx": 0,      # which channel to show from x_np (z-scored)
        "overlay_select_mode": "max_rate",  # "max_rate" or "center"

        # Count-only labels: (subject, activity_id, ground-truth repetition count)
        "ALL_LABELS": [
            # --- act 8 ---
            ("subject1", 8, 20), ("subject2", 8, 21), ("subject3", 8, 21), ("subject4", 8, 19), ("subject5", 8, 20),
            ("subject6", 8, 20), ("subject7", 8, 21), ("subject8", 8, 21), ("subject9", 8, 21), ("subject10", 8, 21),

            # --- act 10 ---
            ("subject1", 10, 157), ("subject2", 10, 161), ("subject3", 10, 154), ("subject4", 10, 154), ("subject5", 10, 160),
            ("subject6", 10, 156), ("subject7", 10, 153), ("subject8", 10, 160), ("subject9", 10, 166), ("subject10", 10, 156),

            # --- act 12 ---
            ("subject1", 12, 20), ("subject2", 12, 22), ("subject3", 12, 21), ("subject4", 12, 21), ("subject5", 12, 20),
            ("subject6", 12, 21), ("subject7", 12, 19), ("subject8", 12, 20), ("subject9", 12, 20), ("subject10", 12, 20),
        ],
    }

    # -----------------------------------------------------------------
    # Plot titles. They are currently empty strings, which are passed
    # unchanged as the `title` argument of every plotting helper.
    # -----------------------------------------------------------------
    TITLE_PHASE = ""
    TITLE_PCA3D = ""
    TITLE_RAW_OVERLAY = ""

    _ensure_dir(CONFIG["out_dir"])

    set_strict_seed(CONFIG["seed"])
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")

    full_data = load_mhealth_dataset(CONFIG["data_dir"], CONFIG["TARGET_ACTIVITIES_MAP"], CONFIG["COLUMN_NAMES"])
    if not full_data:
        return

    print("\n" + "-" * 90)
    print(" >>> Starting LOSO (count-only, K-auto) + WINDOWING — ALL ACTIVITIES (PNG export)")
    print("-" * 90)

    all_results = {}

    # Evaluate each activity independently (LOSO within activity)
    for act_id, act_name in CONFIG["TARGET_ACTIVITIES_MAP"].items():
        labels_act = [x for x in CONFIG["ALL_LABELS"] if x[1] == act_id]
        if not labels_act:
            print(f"\n[Skip activity] act_id={act_id} ({act_name}): no labels in ALL_LABELS.")
            continue

        # Lexicographic order, so "subject10" sorts before "subject2".
        subjects_act = sorted({x[0] for x in labels_act})

        print("\n" + "=" * 90)
        print(f" [Activity] {act_id}: {act_name} | #subjects={len(subjects_act)}")
        print("=" * 90)

        act_out_dir = _ensure_dir(os.path.join(CONFIG["out_dir"], f"act{act_id}_{_safe_filename(act_name)}"))

        loso_maes = []

        for fold_idx, test_subj in enumerate(subjects_act):
            # Re-seed per fold so every fold starts from the same RNG state.
            set_strict_seed(CONFIG["seed"])

            fold_out_dir = _ensure_dir(os.path.join(act_out_dir, f"fold{fold_idx+1:02d}_{_safe_filename(test_subj)}"))

            train_labels = [x for x in labels_act if x[0] != test_subj]
            test_labels  = [x for x in labels_act if x[0] == test_subj]

            train_trials = prepare_trial_list(train_labels, full_data, CONFIG["TARGET_ACTIVITIES_MAP"], CONFIG["ACT_FEATURE_MAP"])
            test_trials  = prepare_trial_list(test_labels,  full_data, CONFIG["TARGET_ACTIVITIES_MAP"], CONFIG["ACT_FEATURE_MAP"])

            if not test_trials:
                print(f"[Skip] Activity {act_id} | Fold {fold_idx+1}: {test_subj} has no data.")
                continue

            # TRAIN only: trial -> windows
            train_data = trial_list_to_windows(
                train_trials,
                fs=CONFIG["fs"],
                win_sec=CONFIG["win_sec"],
                stride_sec=CONFIG["stride_sec"],
                drop_last=CONFIG["drop_last"]
            )

            # Without this guard `train_data[0]` below raises IndexError when
            # none of the remaining subjects has data for this activity.
            if not train_data:
                print(f"[Skip] Activity {act_id} | Fold {fold_idx+1}: no training data (held-out: {test_subj}).")
                continue

            # TEST: keep the trial whole (windowing happens at inference time)
            test_data = test_trials

            g = torch.Generator()
            g.manual_seed(CONFIG["seed"])

            train_loader = DataLoader(
                TrialDataset(train_data),
                batch_size=CONFIG["batch_size"],
                shuffle=True,
                collate_fn=collate_variable_length,
                generator=g,
                num_workers=0
            )

            input_ch = train_data[0]['data'].shape[1]
            model = KAutoCountModel(
                input_ch=input_ch,
                hidden_dim=CONFIG["hidden_dim"],
                latent_dim=CONFIG["latent_dim"],
                K_max=CONFIG["K_max"]
            ).to(device)

            optimizer = torch.optim.Adam(model.parameters(), lr=CONFIG["lr"])
            scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.5)

            for _ in range(CONFIG["epochs"]):
                train_one_epoch(model, train_loader, optimizer, CONFIG, device)
                scheduler.step()

            model.eval()

            # ---- fold test ----
            fold_mae = 0.0
            fold_res_str = ""

            # Values of the last evaluated test item; used for the fold
            # summary line and the per-fold plots.
            test_gt = None
            test_pred = None
            test_diff = None
            test_khat = None
            test_entropy = None

            test_x_np = None
            test_rep_rate = None
            test_phase_p = None
            test_z_np = None
            test_meta = None

            for item in test_data:
                x_np = item["data"]  # (T,C)

                # pred: windowing inference
                count_pred_win, _win_rates = predict_count_by_windowing(
                    model,
                    x_np=x_np,
                    fs=CONFIG["fs"],
                    win_sec=CONFIG["win_sec"],
                    stride_sec=CONFIG["stride_sec"],
                    device=device,
                    tau=CONFIG.get("tau", 1.0),
                    batch_size=CONFIG.get("batch_size", 64)
                )

                count_gt = float(item["count"])
                abs_err = abs(count_pred_win - count_gt)
                fold_mae += abs_err
                fold_res_str += f"[Pred(win): {count_pred_win:.1f} / GT: {count_gt:.0f}]"

                # Single full-trial forward pass (yields rep_rate, phase_p and z)
                x_tensor = torch.tensor(x_np, dtype=torch.float32).transpose(0, 1).unsqueeze(0).to(device)  # (1,C,T)
                with torch.no_grad():
                    _, z_full, _, aux = model(x_tensor, mask=None, tau=CONFIG.get("tau", 1.0))

                phase_p = aux["phase_p"].squeeze(0).detach().cpu().numpy()      # (T,K)
                rep_rate = aux["rep_rate_t"].squeeze(0).detach().cpu().numpy()  # (T,)
                z_np_full = z_full.squeeze(0).detach().cpu().numpy()            # (T,D)

                k_hat = float(aux["k_hat"].item())
                ent = compute_phase_entropy_mean(phase_p)

                test_gt = count_gt
                test_pred = float(count_pred_win)
                test_diff = float(count_pred_win - count_gt)
                test_khat = k_hat
                test_entropy = ent

                test_x_np = x_np
                test_rep_rate = rep_rate
                test_phase_p = phase_p
                test_z_np = z_np_full
                test_meta = item.get("meta", f"{test_subj}_{act_name}")

            fold_mae /= len(test_data)
            loso_maes.append(fold_mae)

            print(f"  Fold {fold_idx+1:2d} | Test: {test_subj:<9s} | MAE: {fold_mae:.2f} | {fold_res_str}")
            if (test_gt is not None) and (test_pred is not None):
                print(
                    f"    [Fold Summary] act={act_id} | {test_subj} | GT={test_gt:.0f} | Pred(win)={test_pred:.2f} | "
                    f"Diff={test_diff:+.2f} | k_hat(full)={test_khat:.2f} | phase_entropy(full)={test_entropy:.3f}"
                )

            # =========================
            # Three plots per fold -> PNG
            #   1) Phase heatmap
            #   2) PCA3D
            #   3) Raw + rep_rate overlay (short segment)
            # =========================
            if (test_phase_p is not None) and (test_z_np is not None) and (test_rep_rate is not None) and (test_x_np is not None):
                # (1) Phase heatmap
                plot_phase_heatmap_and_dominant(
                    test_phase_p,
                    fs=CONFIG["fs"],
                    title=TITLE_PHASE,
                    max_T=int(CONFIG["plot_max_T"]),
                    save_path=os.path.join(fold_out_dir, "01_phase_heatmap.png"),
                    dpi=int(CONFIG["dpi"])
                )

                # (2) PCA3D
                plot_pca3d_z(
                    test_z_np,
                    fs=CONFIG["fs"],
                    title=TITLE_PCA3D,
                    max_T=int(CONFIG["plot_max_T"]),
                    save_path=os.path.join(fold_out_dir, "02_pca3d.png"),
                    dpi=int(CONFIG["dpi"])
                )

                # (3) Raw + rep_rate overlay (short segment)
                plot_raw_and_rep_rate_overlay_short_segment(
                    x_np=test_x_np,
                    rep_rate=test_rep_rate,
                    fs=CONFIG["fs"],
                    title=TITLE_RAW_OVERLAY,
                    seg_sec=float(CONFIG["overlay_seg_sec"]),
                    channel_idx=int(CONFIG["overlay_channel_idx"]),
                    select_mode=str(CONFIG["overlay_select_mode"]),
                    save_path=os.path.join(fold_out_dir, "03_raw_rep_rate_overlay.png"),
                    dpi=int(CONFIG["dpi"])
                )

        # store results
        if loso_maes:
            all_results[act_id] = loso_maes
            print("-" * 90)
            print(f" [Activity Result] {act_name} | Mean MAE: {np.mean(loso_maes):.3f} | Std: {np.std(loso_maes):.3f}")
            print("-" * 90)

    # final summary across activities
    print("\n" + "#" * 90)
    print(" >>> Summary (per activity)")
    print("#" * 90)
    for act_id, maes in all_results.items():
        print(f"  - act {act_id:>2d} ({CONFIG['TARGET_ACTIVITIES_MAP'][act_id]}): mean={np.mean(maes):.3f}, std={np.std(maes):.3f}")
    print("#" * 90)
    print(f"[Saved plots] {os.path.abspath(CONFIG['out_dir'])}")

# Script entry point: run the LOSO experiment only when this file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    main()


Device: cuda
Loading 10 subjects from /content/drive/MyDrive/Colab Notebooks/HAR_data/MHEALTHDATASET...

------------------------------------------------------------------------------------------
 >>> Starting LOSO (count-only, K-auto) + WINDOWING — ALL ACTIVITIES (PNG export)
------------------------------------------------------------------------------------------

 [Activity] 8: Knees bending | #subjects=10
  Fold  1 | Test: subject1  | MAE: 12.18 | [Pred(win): 32.2 / GT: 20]
    [Fold Summary] act=8 | subject1 | GT=20 | Pred(win)=32.18 | Diff=+12.18 | k_hat(full)=1.04 | phase_entropy(full)=0.084
  Fold  2 | Test: subject10 | MAE: 2.31 | [Pred(win): 18.7 / GT: 21]
    [Fold Summary] act=8 | subject10 | GT=21 | Pred(win)=18.69 | Diff=-2.31 | k_hat(full)=1.02 | phase_entropy(full)=0.059
  Fold  3 | Test: subject2  | MAE: 1.54 | [Pred(win): 22.5 / GT: 21]
    [Fold Summary] act=8 | subject2 | GT=21 | Pred(win)=22.54 | Diff=+1.54 | k_hat(full)=1.05 | phase_entropy(full)=0.103
  Fold  4 

In [1]:
# =========================
# ✅ UPDATED: only 2 plots per scenario
#   1) Window-level rep_rate mean±std (+ boundary + A/B shading)
#   2) Latent trajectory PCA2 (scatter of the mean trajectory, colored by normalized time)
#   - NO titles
#   - Bigger fonts
# =========================

import os
import glob
import random
import numpy as np
import pandas as pd
import seaborn as sns  # kept
import matplotlib.pyplot as plt
from scipy.ndimage import gaussian_filter1d  # kept
from matplotlib.patches import Circle

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader


# ---------------------------------------------------------------------
# ✅ Save helpers
# ---------------------------------------------------------------------
def _ensure_dir(path: str):
    if path is None or path == "":
        return
    os.makedirs(path, exist_ok=True)

def _safe_fname(s: str):
    s = str(s)
    for ch in ['/', '\\', ':', '*', '?', '"', '<', '>', '|', '\n', '\t']:
        s = s.replace(ch, '_')
    s = s.strip().replace(' ', '_')
    while '__' in s:
        s = s.replace('__', '_')
    return s


# ---------------------------------------------------------------------
# 1) Strict Seeding
# ---------------------------------------------------------------------
def set_strict_seed(seed: int):
    """Seed Python, NumPy and PyTorch RNGs and switch cuDNN to deterministic mode.

    Also exports ``PYTHONHASHSEED`` to the process environment.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)

    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


# ---------------------------------------------------------------------
# 2) Data Loading
# ---------------------------------------------------------------------
def load_mhealth_dataset(data_dir, target_activities_map, column_names):
    """Load every ``mHealth_subject*.log`` file in ``data_dir`` into per-activity frames.

    Args:
        data_dir: Directory that is globbed for ``mHealth_subject*.log`` files.
        target_activities_map: Mapping ``activity_id -> activity name``; only
            these activity ids are extracted.
        column_names: Column names assigned to the leading columns of each
            tab-separated log; must contain ``'activity_id'``.

    Returns:
        ``{subject_key: {activity_name: DataFrame}}`` where every DataFrame holds
        the rows of one activity with the ``activity_id`` column dropped.
        Activities with no rows for a subject are omitted.  An empty dict is
        returned when no log file is found.  Files that fail to load are
        reported on stdout and skipped.
    """
    file_list = sorted(glob.glob(os.path.join(data_dir, "mHealth_subject*.log")))

    if not file_list:
        print(f"[Warning] No mHealth logs found in {data_dir}")
        return {}

    print(f"Loading {len(file_list)} subjects from {data_dir}...")

    full_dataset = {}
    for file_path in file_list:
        file_name = os.path.basename(file_path)
        subj_part = file_name.split('.')[0]
        try:
            subj_id_num = int(''.join(filter(str.isdigit, subj_part)))
            subj_key = f"subject{subj_id_num}"
        except ValueError:
            # No (parsable) digits in the file stem: fall back to the stem itself.
            subj_key = subj_part

        # Best-effort loading: a broken file is reported and skipped so the
        # remaining subjects can still be used.
        try:
            df = pd.read_csv(file_path, sep="\t", header=None)
            df = df.iloc[:, :len(column_names)]
            df.columns = column_names

            subj_data = {}
            for label_code, activity_name in target_activities_map.items():
                activity_df = df[df['activity_id'] == label_code]
                if not activity_df.empty:
                    # drop() already returns a new frame, no explicit copy needed.
                    subj_data[activity_name] = activity_df.drop(columns=['activity_id'])

            full_dataset[subj_key] = subj_data
        except Exception as e:
            print(f"Error loading {file_name}: {e}")

    return full_dataset


def prepare_trial_list(label_config, full_data, target_map, feature_map):
    """Build z-scored trial dicts for every ``(subject, activity_id, count)`` label.

    Each trial's selected feature columns are cast to float32 and normalized
    per channel with that trial's own mean and std.  Labels whose subject or
    activity is absent from ``full_data`` are reported on stdout and skipped.

    Returns:
        List of dicts with keys ``'data'`` (normalized array), ``'count'``
        (float) and ``'meta'`` (``"<subject>_<activity name>"``).
    """
    trials = []
    for subj, act_id, gt_count in label_config:
        act_name = target_map.get(act_id)
        feats = feature_map.get(act_id)

        available = subj in full_data and act_name in full_data[subj]
        if not available:
            print(f"[Skip] Missing data for {subj} - {act_name}")
            continue

        signal = full_data[subj][act_name][feats].values.astype(np.float32)
        mu = signal.mean(axis=0)
        sigma = signal.std(axis=0) + 1e-6  # epsilon guards constant channels

        trials.append({
            'data': (signal - mu) / sigma,
            'count': float(gt_count),
            'meta': f"{subj}_{act_name}",
        })

    return trials


# ---------------------------------------------------------------------
# ✅ Mixed A-B builder
# ---------------------------------------------------------------------
def get_single_trial_from_full_data(subj, act_id, gt_count, full_data, target_map, feature_map, normalize=True):
    """Fetch one subject/activity trial as a dict, optionally z-scored per channel.

    Returns ``None`` when the subject or the activity is missing from
    ``full_data``; otherwise a dict with keys ``"data"`` (float32 array),
    ``"count"`` (float) and ``"meta"`` (``"<subject>_<activity name>"``).
    """
    act_name = target_map.get(act_id)
    feats = feature_map.get(act_id)

    per_subject = full_data[subj] if subj in full_data else None
    if per_subject is None or act_name not in per_subject:
        return None

    signal = per_subject[act_name][feats].values.astype(np.float32)

    if normalize:
        mu = signal.mean(axis=0)
        sigma = signal.std(axis=0) + 1e-6  # epsilon guards constant channels
        signal = (signal - mu) / sigma

    return {"data": signal, "count": float(gt_count), "meta": f"{subj}_{act_name}"}


def build_mixed_ab_trial(subj, actA_id, actB_id, config, full_data):
    """Concatenate a subject's activity-A and activity-B trials into one mixed trial.

    Ground-truth counts come from ``config["GT_BY_ACT"][act_id][subj]``.  The two
    raw (un-normalized) segments are joined along time, and the mean/std of the
    joined signal is used to z-score the mixed signal as well as each segment.

    Returns:
        ``None`` if a ground-truth entry or either trial is missing; otherwise a
        dict holding the mixed data, the summed count, the A->B boundary index
        (length of segment A), the individually normalized segments and a
        ``meta_detail`` sub-dict describing both activities.
    """
    gt_map = config.get("GT_BY_ACT", {})
    act_pair = (actA_id, actB_id)
    if any(act not in gt_map for act in act_pair):
        return None
    if any(subj not in gt_map[act] for act in act_pair):
        return None

    gtA = float(gt_map[actA_id][subj])
    gtB = float(gt_map[actB_id][subj])
    gt_total = gtA + gtB

    name_map = config["TARGET_ACTIVITIES_MAP"]
    feat_map = config["ACT_FEATURE_MAP"]
    # Segments are fetched raw; normalization uses statistics of the joined signal.
    seg_a, seg_b = [
        get_single_trial_from_full_data(
            subj, act, gt, full_data, name_map, feat_map, normalize=False
        )
        for act, gt in ((actA_id, gtA), (actB_id, gtB))
    ]
    if seg_a is None or seg_b is None:
        return None

    raw_a, raw_b = seg_a["data"], seg_b["data"]
    joined = np.concatenate([raw_a, raw_b], axis=0).astype(np.float32)

    mu = joined.mean(axis=0)
    sigma = joined.std(axis=0) + 1e-6

    norm_a = (raw_a - mu) / sigma
    norm_b = (raw_b - mu) / sigma

    actA_name = name_map.get(actA_id, str(actA_id))
    actB_name = name_map.get(actB_id, str(actB_id))

    detail = {
        "subj": subj,
        "actA_id": actA_id, "actB_id": actB_id,
        "actA_name": actA_name, "actB_name": actB_name,
        "gtA": gtA, "gtB": gtB,
        "gt_total": gt_total,
    }

    return {
        "data": (joined - mu) / sigma,
        "count": float(gt_total),
        "meta": f"{subj}__{actA_name}__TO__{actB_name}",
        "boundary": int(raw_a.shape[0]),
        "meta_detail": detail,
        "data_A": norm_a,
        "data_B": norm_b,
        "T_A": int(norm_a.shape[0]),
        "T_B": int(norm_b.shape[0]),
    }


# ---------------------------------------------------------------------
# 2.5) Windowing
# ---------------------------------------------------------------------
def trial_list_to_windows(trial_list, fs, win_sec=8.0, stride_sec=4.0, drop_last=True):
    """Cut every trial into sliding windows labelled with a proportional count.

    A window's count is the trial-average rate (trial count / trial duration)
    multiplied by the window duration.  Trials shorter than one window are
    emitted whole as a single window.  With ``drop_last=False`` the partial
    segment starting one stride after the last full window is emitted too.

    Returns:
        List of window dicts with keys ``data``, ``count``, ``meta``,
        ``parent_meta``, ``parent_T``, ``win_start`` and ``win_end``.
    """
    win_len = int(round(win_sec * fs))
    stride = int(round(stride_sec * fs))
    assert win_len > 0 and stride > 0

    sample_rate = float(fs)

    def _record(data, rate, parent_meta, parent_len, st, ed):
        # One window entry; duration is derived from the sample span.
        return {
            "data": data,
            "count": rate * ((ed - st) / sample_rate),
            "meta": f"{parent_meta}__win[{st}:{ed}]",
            "parent_meta": parent_meta,
            "parent_T": parent_len,
            "win_start": st,
            "win_end": ed,
        }

    windows = []
    for item in trial_list:
        x = item["data"]
        T = x.shape[0]
        total_count = float(item["count"])
        meta = item["meta"]

        rate_trial = total_count / max(T / sample_rate, 1e-6)

        # Short trial: keep it as one (shorter) window.
        if T < win_len:
            windows.append(_record(x, rate_trial, meta, T, 0, T))
            continue

        starts = list(range(0, T - win_len + 1, stride))
        windows.extend(
            _record(x[st:st + win_len], rate_trial, meta, T, st, st + win_len)
            for st in starts
        )

        if drop_last:
            continue

        # Optional trailing partial window.
        tail_st = starts[-1] + stride
        if tail_st < T:
            windows.append(_record(x[tail_st:T], rate_trial, meta, T, tail_st, T))

    return windows


def predict_count_by_windowing(model, x_np, fs, win_sec, stride_sec, device, tau=1.0, batch_size=64):
    """Estimate a trial's total count from the mean of window-level rate predictions.

    The trial ``x_np`` (time along axis 0) is cut into sliding windows, each
    window is passed through ``model`` to obtain a rate, and the mean rate is
    multiplied by the trial duration.  A trial no longer than one window is
    passed through the model in a single forward call.

    Returns:
        ``(predicted_count, rates)`` where ``rates`` holds the per-window rates.
    """
    win_len = int(round(win_sec * fs))
    stride = int(round(stride_sec * fs))
    T = x_np.shape[0]
    total_dur = T / float(fs)

    if T <= win_len:
        whole = torch.tensor(x_np, dtype=torch.float32).transpose(0, 1).unsqueeze(0).to(device)
        with torch.no_grad():
            rate_hat, _, _, _ = model(whole, mask=None, tau=tau)
        single_rate = np.array([float(rate_hat.item())], dtype=np.float32)
        return float(rate_hat.item() * total_dur), single_rate

    starts = list(range(0, T - win_len + 1, stride))
    stacked = np.stack([x_np[st:st + win_len] for st in starts], axis=0)

    # (num_windows, win_len, C) -> (num_windows, C, win_len)
    xw = torch.tensor(stacked, dtype=torch.float32).permute(0, 2, 1).to(device)
    n_windows = xw.shape[0]

    per_batch = []
    model.eval()
    with torch.no_grad():
        for lo in range(0, n_windows, batch_size):
            r_hat, _, _, _ = model(xw[lo:lo + batch_size], mask=None, tau=tau)
            per_batch.append(r_hat.detach().cpu().numpy())

    rates = np.concatenate(per_batch, axis=0)
    pred_count = float(rates.mean()) * total_dur
    return float(pred_count), rates


# ---------------------------------------------------------------------
# Helpers for scenario-avg plots
# ---------------------------------------------------------------------
def get_window_rate_curve(model, x_np, fs, win_sec, stride_sec, device, tau=1.0, batch_size=64):
    """Return window-center times (seconds) and the model's rate for each window.

    For a trial no longer than one window, the single rate comes from
    ``predict_count_by_windowing`` and the center is the middle of the trial.

    Returns:
        ``(centers, rates)`` as float32 arrays, one entry per window.
    """
    win_len = int(round(win_sec * fs))
    stride = int(round(stride_sec * fs))
    T = x_np.shape[0]
    sample_rate = float(fs)

    if T <= win_len:
        _, rates = predict_count_by_windowing(
            model, x_np, fs, win_sec, stride_sec, device, tau=tau, batch_size=batch_size
        )
        mid_time = np.array([0.5 * (T / sample_rate)], dtype=np.float32)
        return mid_time, rates.astype(np.float32)

    starts = list(range(0, T - win_len + 1, stride))
    stacked = np.stack([x_np[st:st + win_len] for st in starts], axis=0)

    # (num_windows, win_len, C) -> (num_windows, C, win_len)
    xw = torch.tensor(stacked, dtype=torch.float32).permute(0, 2, 1).to(device)

    per_batch = []
    model.eval()
    with torch.no_grad():
        for lo in range(0, xw.shape[0], batch_size):
            r_hat, _, _, _ = model(xw[lo:lo + batch_size], mask=None, tau=tau)
            per_batch.append(r_hat.detach().cpu().numpy())
    rates = np.concatenate(per_batch, axis=0).astype(np.float32)

    centers = np.array(
        [(st + 0.5 * win_len) / sample_rate for st in starts], dtype=np.float32
    )
    return centers, rates


def resample_1d(y, new_len):
    """Linearly resample a 1-D sequence to ``new_len`` points (float32).

    An empty input yields zeros; a single-element input is broadcast.
    """
    samples = np.asarray(y, dtype=np.float32)
    count = samples.size

    if count == 0:
        return np.zeros((new_len,), dtype=np.float32)
    if count == 1:
        return np.full((new_len,), float(samples[0]), dtype=np.float32)

    # Both grids span [0, 1], so the end points are preserved.
    src_grid = np.linspace(0.0, 1.0, num=count, dtype=np.float32)
    dst_grid = np.linspace(0.0, 1.0, num=new_len, dtype=np.float32)
    resampled = np.interp(dst_grid, src_grid, samples)
    return resampled.astype(np.float32)


def resample_2d(Y, new_len):
    """Linearly resample each column of a 2-D array along axis 0 to ``new_len`` rows.

    An input with zero rows yields zeros; a single row is repeated.
    """
    Y = np.asarray(Y, dtype=np.float32)
    n_rows = Y.shape[0]

    if n_rows == 0:
        return np.zeros((new_len, Y.shape[1]), dtype=np.float32)
    if n_rows == 1:
        return np.repeat(Y, repeats=new_len, axis=0).astype(np.float32)

    src_grid = np.linspace(0.0, 1.0, num=n_rows, dtype=np.float32)
    dst_grid = np.linspace(0.0, 1.0, num=new_len, dtype=np.float32)
    # np.interp is 1-D only, so interpolate column by column.
    columns = [np.interp(dst_grid, src_grid, Y[:, d]) for d in range(Y.shape[1])]
    return np.stack(columns, axis=1).astype(np.float32)


def global_pca2(Z_all):
    """Compute the mean and the first two right singular vectors of centered ``Z_all``.

    Returns:
        ``(mu, V2)``: the column-wise mean with its leading axis squeezed, and
        the first two rows of ``Vt`` from the SVD of the centered data, both
        float32.
    """
    points = np.asarray(Z_all, dtype=np.float32)
    center = points.mean(axis=0, keepdims=True)
    # Rows of Vt from the thin SVD of the centered data.
    right_vectors = np.linalg.svd(points - center, full_matrices=False)[2]
    top_two = right_vectors[:2].astype(np.float32)
    return center.squeeze(0).astype(np.float32), top_two


# ---------------------------------------------------------------------
# ✅ UPDATED PLOTTER: only (1) and (2), no titles, bigger fonts
# ---------------------------------------------------------------------
def plot_scenario_mean_std_two_plots(
    save_dir,
    scen_name,
    win_t, win_mean, win_std,
    lat_xy_mean, lat_xy_std,   # lat_xy_std is no longer used; kept for signature compatibility
    n_subjects,
    dpi=600,
):
    """Save the two per-scenario figures (no titles) as PNG files.

    1) ``<scenario>__1_win_rep_rate.png``: window-level rep_rate mean with a
       +/-1 std band, a dashed A->B boundary at x=1 and light shading of the
       [0, 1] and [1, 2] halves.
    2) ``<scenario>__2_latent_pca2.png``: scatter of ``lat_xy_mean`` colored by
       a 0..2 time axis, drawn inside a square bounding box.

    Files go to ``save_dir``, or to ``"scenario_viz"`` when ``save_dir`` is None.
    ``lat_xy_std`` and ``n_subjects`` are not used by this function.
    """
    out_dir = "scenario_viz" if save_dir is None else save_dir
    _ensure_dir(out_dir)
    scen_tag = _safe_fname(scen_name)

    # -------------------------
    # (1) Window-level rep_rate mean±std (+ boundary + A/B shading)
    # - no "A"/"B" text labels
    # - large figure and fonts
    # -------------------------
    fig_rate = plt.figure(figsize=(26.0, 9.2))
    ax_rate = fig_rate.gca()

    # Segment A is shaded slightly darker than segment B.
    for (left, right), shade in (((0.0, 1.0), 0.06), ((1.0, 2.0), 0.03)):
        ax_rate.axvspan(left, right, alpha=shade, color="k", lw=0)

    ax_rate.plot(win_t, win_mean, linewidth=6.4, label="Window rep_rate (mean)")
    band_low = win_mean - win_std
    band_high = win_mean + win_std
    ax_rate.fill_between(win_t, band_low, band_high, alpha=0.22, label="±1 std")
    ax_rate.axvline(1.0, linestyle="--", linewidth=6.0, label="Boundary (A→B)")

    ax_rate.set_xlabel("Normalized time (A:0→1, B:1→2)", fontsize=46, labelpad=10)
    ax_rate.set_ylabel("rep_rate (reps/s)", fontsize=46, labelpad=10)
    ax_rate.tick_params(axis="both", labelsize=37)
    ax_rate.legend(fontsize=38, frameon=True, loc="upper left")

    fig_rate.tight_layout()
    rate_png = os.path.join(out_dir, f"{scen_tag}__1_win_rep_rate.png")
    fig_rate.savefig(rate_png, dpi=dpi, bbox_inches="tight")
    plt.close(fig_rate)

    # -------------------------
    # (2) Latent trajectory PCA2 (scatter only; no std circles)
    # -------------------------
    fig_lat = plt.figure(figsize=(8.8, 8.8))
    ax_lat = fig_lat.gca()

    pc1 = lat_xy_mean[:, 0]
    pc2 = lat_xy_mean[:, 1]
    n_points = lat_xy_mean.shape[0]
    time_axis = np.linspace(0.0, 2.0, num=n_points, dtype=np.float32)

    scatter = ax_lat.scatter(pc1, pc2, c=time_axis, s=22, alpha=0.95)

    ax_lat.set_xlabel("PC1", fontsize=24, labelpad=10)
    ax_lat.set_ylabel("PC2", fontsize=24, labelpad=10)
    ax_lat.tick_params(axis="both", labelsize=18)
    ax_lat.set_aspect("equal", adjustable="box")

    # Square bounds: center on the data and use the larger half-range (+12%).
    x_lo, x_hi = float(np.min(pc1)), float(np.max(pc1))
    y_lo, y_hi = float(np.min(pc2)), float(np.max(pc2))
    mid_x = 0.5 * (x_lo + x_hi)
    mid_y = 0.5 * (y_lo + y_hi)
    half_span = max(0.5 * (x_hi - x_lo), 0.5 * (y_hi - y_lo)) * 1.12 + 1e-6
    ax_lat.set_xlim(mid_x - half_span, mid_x + half_span)
    ax_lat.set_ylim(mid_y - half_span, mid_y + half_span)

    cbar = fig_lat.colorbar(scatter, ax=ax_lat, fraction=0.046, pad=0.04)
    cbar.set_label("normalized time (0→2)", fontsize=20, labelpad=10)
    cbar.ax.tick_params(labelsize=16)

    fig_lat.tight_layout()
    latent_png = os.path.join(out_dir, f"{scen_tag}__2_latent_pca2.png")
    fig_lat.savefig(latent_png, dpi=dpi, bbox_inches="tight")
    plt.close(fig_lat)


# ---------------------------------------------------------------------
# Dataset / Collate
# ---------------------------------------------------------------------
class TrialDataset(Dataset):
    """Dataset over a list of trial/window dicts with 'data', 'count' and 'meta' keys."""

    def __init__(self, trial_list):
        self.trials = trial_list

    def __len__(self):
        return len(self.trials)

    def __getitem__(self, idx):
        """Return ``(data, count, meta)`` with ``data`` transposed to (C, T)."""
        record = self.trials[idx]
        signal = torch.tensor(record['data'], dtype=torch.float32)
        target = torch.tensor(record['count'], dtype=torch.float32)
        return signal.transpose(0, 1), target, record['meta']


def collate_variable_length(batch):
    """Zero-pad ``(data, count, meta)`` samples along time and build validity masks.

    Each ``data`` tensor is (C, T); all samples are right-padded to the longest
    T in the batch.  The mask is 1 on real samples and 0 on padding.

    Returns:
        Dict with ``"data"`` (B, C, T_max), ``"mask"`` (B, T_max), ``"count"``
        (B,), ``"length"`` (original lengths as float32) and ``"meta"`` (list).
    """
    seq_lens = [sample[0].shape[1] for sample in batch]
    longest = max(seq_lens)
    n_channels = batch[0][0].shape[0]

    data_rows, mask_rows, count_rows, meta_rows = [], [], [], []
    for (data, count, meta), T in zip(batch, seq_lens):
        deficit = longest - T
        mask = torch.ones(T)
        if deficit > 0:
            data = torch.cat([data, torch.zeros(n_channels, deficit)], dim=1)
            mask = torch.cat([mask, torch.zeros(deficit)], dim=0)

        data_rows.append(data)
        mask_rows.append(mask)
        count_rows.append(count)
        meta_rows.append(meta)

    return {
        "data": torch.stack(data_rows),
        "mask": torch.stack(mask_rows),
        "count": torch.stack(count_rows),
        "length": torch.tensor(seq_lens, dtype=torch.float32),
        "meta": meta_rows,
    }


# ---------------------------------------------------------------------
# Model
# ---------------------------------------------------------------------
class ManifoldEncoder(nn.Module):
    """1-D convolutional encoder mapping (B, input_ch, T) to latents of shape (B, T, latent_dim)."""

    def __init__(self, input_ch, hidden_dim=128, latent_dim=16):
        super().__init__()
        # Two length-preserving k=5 convolutions followed by a 1x1 projection.
        stages = [
            nn.Conv1d(input_ch, hidden_dim, 5, padding=2),
            nn.ReLU(),
            nn.Conv1d(hidden_dim, hidden_dim, 5, padding=2),
            nn.ReLU(),
            nn.Conv1d(hidden_dim, latent_dim, 1),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        # Swap channel and time axes so the latent dimension comes last.
        return self.net(x).transpose(1, 2)


class ManifoldDecoder(nn.Module):
    """1-D convolutional decoder mapping latents (B, T, latent_dim) back to (B, out_ch, T)."""

    def __init__(self, latent_dim, hidden_dim, out_ch):
        super().__init__()
        # Mirror of the encoder: two k=5 convolutions and a 1x1 projection.
        stages = [
            nn.Conv1d(latent_dim, hidden_dim, 5, padding=2),
            nn.ReLU(),
            nn.Conv1d(hidden_dim, hidden_dim, 5, padding=2),
            nn.ReLU(),
            nn.Conv1d(hidden_dim, out_ch, 1),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, z):
        # Conv1d expects channels first, so move the latent axis to dim 1.
        return self.net(z.transpose(1, 2))


class MultiRateHead(nn.Module):
    """Per-timestep head producing a non-negative amplitude and a K_max-way phase distribution."""

    def __init__(self, latent_dim=16, hidden=64, K_max=6):
        super().__init__()
        self.K_max = K_max
        # Output layout: index 0 -> amplitude logit, indices 1..K_max -> phase logits.
        layers = [
            nn.Linear(latent_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 1 + K_max),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, z, tau=1.0):
        """Return ``(amp, phase, phase_logits)``; ``tau`` is the softmax temperature."""
        raw = self.net(z)
        amp_logit, phase_logits = raw[..., 0], raw[..., 1:]
        amp = F.softplus(amp_logit)
        phase = F.softmax(phase_logits / tau, dim=-1)
        return amp, phase, phase_logits


class KAutoCountModel(nn.Module):
    """Encoder/decoder with a rate head that estimates a repetition rate.

    The encoder produces a latent sequence, the decoder reconstructs the
    input from it, and the rate head yields a per-step amplitude plus phase
    probabilities. The time-averaged phase usage gives an effective phase
    count ``k_hat`` (inverse of the sum of squared mean probabilities); the
    per-step repetition rate is the amplitude divided by ``k_hat``.

    Args:
        input_ch: number of input (and reconstructed) channels.
        hidden_dim: hidden width for encoder, decoder and the rate head.
        latent_dim: latent channel count.
        K_max: number of phase slots in the rate head.
        k_hidden: accepted for interface compatibility but not used here;
            the rate head's hidden width is taken from ``hidden_dim``.
    """

    def __init__(self, input_ch, hidden_dim=128, latent_dim=16, K_max=6, k_hidden=64):
        super().__init__()
        self.encoder = ManifoldEncoder(input_ch, hidden_dim, latent_dim)
        self.decoder = ManifoldDecoder(latent_dim, hidden_dim, input_ch)
        self.rate_head = MultiRateHead(latent_dim, hidden=hidden_dim, K_max=K_max)
        self._init_weights()

    def _init_weights(self):
        """Kaiming-init every Conv1d/Linear, then bias the amplitude output low."""
        for module in self.modules():
            if not isinstance(module, (nn.Conv1d, nn.Linear)):
                continue
            nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)

        # Final rate-head bias: all zeros except the amplitude slot (index 0),
        # which starts at -2.0 (pre-softplus).
        with torch.no_grad():
            final_bias = self.rate_head.net[-1].bias
            final_bias.zero_()
            final_bias[0].fill_(-2.0)

    @staticmethod
    def _masked_mean_time(x, mask=None, eps=1e-6):
        """Average ``x`` over dim 1, weighting by ``mask`` when one is given.

        Supports 2-D and 3-D ``x``; for 3-D input the mask is broadcast over
        the last axis. Raises ``ValueError`` for any other rank when a mask
        is supplied.
        """
        if mask is None:
            return x.mean(dim=1)

        rank = x.dim()
        if rank not in (2, 3):
            raise ValueError(f"Unsupported dim for masked mean: {x.dim()}")

        weights = mask.to(dtype=x.dtype, device=x.device)
        if rank == 3:
            weights = weights.unsqueeze(-1)
        return (x * weights).sum(dim=1) / (weights.sum(dim=1) + eps)

    def forward(self, x, mask=None, tau=1.0):
        """Return ``(avg_rep_rate, z, x_hat, aux)``.

        ``aux`` carries ``phase_p``, ``phase_logits``, ``micro_rate_t``,
        ``rep_rate_t`` and ``k_hat``.
        """
        z = self.encoder(x)
        x_hat = self.decoder(z)

        amp_t, phase_p, phase_logits = self.rate_head(z, tau=tau)
        micro_rate_t = amp_t

        # Effective number of phases used: inverse participation ratio of the
        # time-averaged phase distribution.
        mean_phase = self._masked_mean_time(phase_p, mask)
        k_hat = 1.0 / (mean_phase.pow(2).sum(dim=1) + 1e-6)

        rep_rate_t = micro_rate_t / (k_hat.unsqueeze(1) + 1e-6)

        if mask is None:
            avg_rep_rate = rep_rate_t.mean(dim=1)
        else:
            # The mask is applied to the per-step rate and again inside the
            # average, exactly as in the established behavior.
            rep_rate_t = rep_rate_t * mask
            avg_rep_rate = (rep_rate_t * mask).sum(dim=1) / (mask.sum(dim=1) + 1e-6)

        aux = {
            "phase_p": phase_p,
            "phase_logits": phase_logits,
            "micro_rate_t": micro_rate_t,
            "rep_rate_t": rep_rate_t,
            "k_hat": k_hat,
        }
        return avg_rep_rate, z, x_hat, aux


# ---------------------------------------------------------------------
# Loss utils
# ---------------------------------------------------------------------
def masked_recon_mse(x_hat, x, mask, eps=1e-6):
    """Mean squared reconstruction error over masked-in positions.

    ``mask`` is broadcast across dim 1 of ``x``; the normalizer is the mask
    sum times ``x.shape[1]`` (plus ``eps`` to avoid division by zero).
    """
    valid = mask.to(dtype=x.dtype, device=x.device)
    squared_err = ((x_hat - x) ** 2) * valid.unsqueeze(1)
    n_valid = (valid.sum() * x.shape[1]) + eps
    return squared_err.sum() / n_valid


def temporal_smoothness(v, mask=None, eps=1e-6):
    """Mean absolute first difference of ``v`` along dim 1.

    With a mask, a difference counts only when both neighbouring steps are
    masked in (product of the two mask values).
    """
    step_change = (v[:, 1:] - v[:, :-1]).abs()
    if mask is None:
        return step_change.mean()

    pair_valid = (mask[:, 1:] * mask[:, :-1]).to(dtype=step_change.dtype, device=step_change.device)
    return (step_change * pair_valid).sum() / (pair_valid.sum() + eps)


def phase_entropy_loss(phase_p, mask=None, eps=1e-8):
    """Average Shannon entropy of the phase distribution over the last axis.

    Without a mask this is a plain mean over all remaining positions; with a
    mask it is the mask-weighted sum divided by the mask sum.
    """
    log_p = torch.log(phase_p + eps)
    entropy = -(phase_p * log_p).sum(dim=-1)
    if mask is None:
        return entropy.mean()

    weighted = entropy * mask
    return weighted.sum() / (mask.sum() + eps)


def effK_usage_loss(phase_p, mask=None, eps=1e-6):
    """Effective number of phases used, from time-averaged phase usage.

    Returns a pair: the batch mean of ``1 / sum(p_bar ** 2)`` (differentiable)
    and the detached per-sample values.
    """
    if mask is not None:
        weights = mask.to(dtype=phase_p.dtype, device=phase_p.device).unsqueeze(-1)
        usage = (phase_p * weights).sum(dim=1) / (weights.sum(dim=1) + eps)
    else:
        usage = phase_p.mean(dim=1)

    eff_k = 1.0 / (usage.pow(2).sum(dim=1) + eps)
    return eff_k.mean(), eff_k.detach()


def compute_phase_entropy_mean(phase_p_np, eps=1e-8):
    """Return the mean per-row entropy of a 2-D array of phase probabilities.

    The input is cast to float32; entropy is summed along axis 1 and then
    averaged over rows. The result is a Python float.
    """
    probs = np.asarray(phase_p_np, dtype=np.float32)
    row_entropy = -np.sum(probs * np.log(probs + eps), axis=1)
    return float(np.mean(row_entropy))


# ---------------------------------------------------------------------
# Train
# ---------------------------------------------------------------------
def train_one_epoch(model, loader, optimizer, config, device):
    """Train ``model`` for one pass over ``loader`` and return averaged stats.

    Each batch is a dict with ``"data"``, ``"mask"``, ``"count"`` and
    ``"length"`` tensors. The count target is converted to a rate by dividing
    by the duration (``length / fs``); the total loss is the rate MSE plus
    weighted reconstruction, smoothness, phase-entropy and effective-K terms.

    Args:
        model: called as ``model(x, mask, tau=tau)``; must return
            ``(rate_hat, z, x_hat, aux)`` with ``aux["rep_rate_t"]`` and
            ``aux["phase_p"]``.
        loader: iterable of batches.
        optimizer: optimizer stepping ``model``'s parameters.
        config: dict; ``"fs"`` is required, ``"tau"`` and the ``"lambda_*"``
            weights are optional.
        device: device the batch tensors are moved to.

    Returns:
        Dict of per-batch averages for every loss term and ``"mae_count"``.
        If the loader yields no batches, all values are ``0.0`` (instead of
        failing with a division by zero).
    """
    model.train()
    stats = {k: 0.0 for k in [
        'loss', 'loss_rate', 'loss_recon', 'loss_smooth', 'loss_phase_ent', 'loss_effk',
        'mae_count'
    ]}

    fs = config["fs"]
    tau = config.get("tau", 1.0)

    lam_recon = config.get("lambda_recon", 1.0)
    lam_smooth = config.get("lambda_smooth", 0.05)
    lam_phase_ent = config.get("lambda_phase_ent", 0.01)
    lam_effk = config.get("lambda_effk", 0.005)

    # Count the batches actually seen so the average also works for loaders
    # without __len__ and for empty loaders.
    n_batches = 0
    for batch in loader:
        x = batch["data"].to(device)
        mask = batch["mask"].to(device)
        y_count = batch["count"].to(device)
        length = batch["length"].to(device)

        # Clamp so a zero-length sample cannot produce an infinite target rate.
        duration = torch.clamp(length / fs, min=1e-6)
        y_rate = y_count / duration

        optimizer.zero_grad()
        rate_hat, z, x_hat, aux = model(x, mask, tau=tau)

        loss_rate = F.mse_loss(rate_hat, y_rate)
        loss_recon = masked_recon_mse(x_hat, x, mask)
        loss_smooth = temporal_smoothness(aux["rep_rate_t"], mask)
        loss_phase_ent = phase_entropy_loss(aux["phase_p"], mask)
        loss_effk, _ = effK_usage_loss(aux["phase_p"], mask)

        loss = (loss_rate
                + lam_recon * loss_recon
                + lam_smooth * loss_smooth
                + lam_phase_ent * loss_phase_ent
                + lam_effk * loss_effk)

        loss.backward()
        optimizer.step()

        # Convert the predicted rate back to a count for the MAE metric.
        count_hat = rate_hat * duration
        stats['loss'] += loss.item()
        stats['loss_rate'] += loss_rate.item()
        stats['loss_recon'] += loss_recon.item()
        stats['loss_smooth'] += loss_smooth.item()
        stats['loss_phase_ent'] += loss_phase_ent.item()
        stats['loss_effk'] += loss_effk.item()
        stats['mae_count'] += torch.abs(count_hat - y_count).mean().item()
        n_batches += 1

    if n_batches == 0:
        return stats
    return {k: v / n_batches for k, v in stats.items()}


# ---------------------------------------------------------------------
# Main (Scenario-wise LOSO)
# ---------------------------------------------------------------------
def main():
    """Run the scenario-wise leave-one-subject-out (LOSO) experiment.

    For every scenario ``(name, actA, actB)`` in ``MIX_SCENARIOS`` and every
    held-out subject:

    1. Train a fresh ``KAutoCountModel`` on sliding windows of activity A
       from all other subjects.
    2. Evaluate the held-out subject's activity-A trial by windowed
       inference (LOSO MAE).
    3. Build a mixed A->B trial for the held-out subject and evaluate the
       total count as well as the A and B segments separately.
    4. Run one full-trial forward pass on the mixed trial to log ``k_hat``,
       phase entropy and the latent trajectory.

    After all folds of a scenario, summary statistics are printed and the
    window-rate curves and 2-D PCA latent trajectories are averaged across
    subjects and handed to ``plot_scenario_mean_std_two_plots``.
    """
    # Every target activity uses the same 15 acc/gyro columns; build the list
    # once and give each activity its own copy.
    imu_features = [
        'acc_chest_x', 'acc_chest_y', 'acc_chest_z',
        'acc_ankle_x', 'acc_ankle_y', 'acc_ankle_z',
        'gyro_ankle_x', 'gyro_ankle_y', 'gyro_ankle_z',
        'acc_arm_x', 'acc_arm_y', 'acc_arm_z',
        'gyro_arm_x', 'gyro_arm_y', 'gyro_arm_z',
    ]
    target_activity_ids = (6, 7, 8, 10, 11, 12)

    CONFIG = {
        "seed": 42,
        "data_dir": "/content/drive/MyDrive/Colab Notebooks/HAR_data/MHEALTHDATASET",

        "COLUMN_NAMES": [
            'acc_chest_x', 'acc_chest_y', 'acc_chest_z',
            'ecg_1', 'ecg_2',
            'acc_ankle_x', 'acc_ankle_y', 'acc_ankle_z',
            'gyro_ankle_x', 'gyro_ankle_y', 'gyro_ankle_z',
            'mag_ankle_x', 'mag_ankle_y', 'mag_ankle_z',
            'acc_arm_x', 'acc_arm_y', 'acc_arm_z',
            'gyro_arm_x', 'gyro_arm_y', 'gyro_arm_z',
            'mag_arm_x', 'mag_arm_y', 'mag_arm_z',
            'activity_id'
        ],

        "TARGET_ACTIVITIES_MAP": {
            6:  'Waist bends forward',
            7:  'Frontal elevation of arms',
            8:  'Knees bending',
            10:  'Jogging',
            11: 'Running',
            12: 'Jump front & back',
        },

        "ACT_FEATURE_MAP": {act_id: list(imu_features) for act_id in target_activity_ids},

        # Training Params
        "epochs": 100,
        "lr": 5e-4,
        "batch_size": 64,
        "fs": 50,

        # Windowing Params
        "win_sec": 8.0,
        "stride_sec": 4.0,
        "drop_last": True,

        # Model
        "hidden_dim": 128,
        "latent_dim": 16,
        "K_max": 6,

        # Loss Weights
        "lambda_recon": 1.0,
        "lambda_smooth": 0.05,
        "lambda_phase_ent": 0.01,
        "lambda_effk": 0.0075,

        "tau": 1.0,

        # Ground-truth repetition counts per activity and subject.
        "GT_BY_ACT": {
            6:  {"subject1": 21, "subject2": 19, "subject3": 21, "subject4": 20, "subject5": 20,
                 "subject6": 20, "subject7": 20, "subject8": 21, "subject9": 21, "subject10": 20},
            7:  {"subject1": 20, "subject2": 20, "subject3": 20, "subject4": 20, "subject5": 20,
                 "subject6": 20, "subject7": 20, "subject8": 19, "subject9": 19, "subject10": 20},
            8:  {"subject1": 20, "subject2": 21, "subject3": 21, "subject4": 19, "subject5": 20,
                 "subject6": 20, "subject7": 21, "subject8": 21, "subject9": 21, "subject10": 21},
            10: {"subject1": 157, "subject2": 161, "subject3": 154, "subject4": 154, "subject5": 160,
                 "subject6": 156, "subject7": 153, "subject8": 160, "subject9": 166, "subject10": 156},
            11: {"subject1": 165, "subject2": 158, "subject3": 174, "subject4": 163, "subject5": 157,
                 "subject6": 172, "subject7": 149, "subject8": 166, "subject9": 174, "subject10": 172},
            12: {"subject1": 20, "subject2": 22, "subject3": 21, "subject4": 21, "subject5": 20,
                 "subject6": 21, "subject7": 19, "subject8": 20, "subject9": 20, "subject10": 20},
        },

        "MIX_SCENARIOS": [
            ("Knees bending-Frontal elevation of arms", 8, 7),
            ("Waist bends forward-Jump front & back", 6, 12),
            ("Jogging-Running", 10, 11),
        ],
    }

    set_strict_seed(CONFIG["seed"])
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")

    full_data = load_mhealth_dataset(CONFIG["data_dir"], CONFIG["TARGET_ACTIVITIES_MAP"], CONFIG["COLUMN_NAMES"])
    if not full_data:
        return

    subjects = [f"subject{i}" for i in range(1, 11)]

    VIZ_DIR = "scenario_viz"
    os.makedirs(VIZ_DIR, exist_ok=True)

    # resample lengths (fixed for averaging)
    NW = 60     # window-rate points per segment (A,B)
    NL = 200    # latent points per segment (A,B)

    # Keyword arguments shared by every windowed-inference call below.
    win_kwargs = dict(
        fs=CONFIG["fs"], win_sec=CONFIG["win_sec"], stride_sec=CONFIG["stride_sec"],
        device=device, tau=CONFIG.get("tau", 1.0),
        batch_size=CONFIG.get("batch_size", 64),
    )

    for scen_idx, (scen_name, actA_id, actB_id) in enumerate(CONFIG.get("MIX_SCENARIOS", []), start=1):
        # The model is trained on activity A only; activity B is unseen.
        TRAIN_ACT_ID = actA_id

        print("\n" + "=" * 100)
        print(f"1. 시나리오{scen_idx}: {scen_name}")
        print("=" * 100)

        loso_results = []
        scenario_records = []
        scenario_diffs = []
        scenario_diffs_A = []
        scenario_diffs_B = []

        # Buffers for the two scenario-average plots.
        buf_win_concat = []     # (2*NW,)
        buf_z_concat = []       # list of (Tmix,D)
        per_subject_latent = [] # list of {"zA","zB"}

        for test_subj in subjects:
            # Re-seed per fold so every fold starts from the same RNG state.
            set_strict_seed(CONFIG["seed"])

            gt_train_map = CONFIG["GT_BY_ACT"][TRAIN_ACT_ID]
            train_labels = [(s, TRAIN_ACT_ID, gt_train_map[s]) for s in subjects if s != test_subj]
            test_labels  = [(test_subj, TRAIN_ACT_ID, gt_train_map[test_subj])]

            train_trials = prepare_trial_list(train_labels, full_data, CONFIG["TARGET_ACTIVITIES_MAP"], CONFIG["ACT_FEATURE_MAP"])
            test_trials  = prepare_trial_list(test_labels,  full_data, CONFIG["TARGET_ACTIVITIES_MAP"], CONFIG["ACT_FEATURE_MAP"])
            if not test_trials or not train_trials:
                continue

            train_data = trial_list_to_windows(
                train_trials, fs=CONFIG["fs"],
                win_sec=CONFIG["win_sec"], stride_sec=CONFIG["stride_sec"],
                drop_last=CONFIG["drop_last"]
            )

            # Dedicated generator so the shuffle order is reproducible.
            g = torch.Generator()
            g.manual_seed(CONFIG["seed"])

            train_loader = DataLoader(
                TrialDataset(train_data),
                batch_size=CONFIG["batch_size"],
                shuffle=True,
                collate_fn=collate_variable_length,
                generator=g,
                num_workers=0
            )

            input_ch = train_data[0]['data'].shape[1]
            model = KAutoCountModel(
                input_ch=input_ch,
                hidden_dim=CONFIG["hidden_dim"],
                latent_dim=CONFIG["latent_dim"],
                K_max=CONFIG["K_max"]
            ).to(device)

            optimizer = torch.optim.Adam(model.parameters(), lr=CONFIG["lr"])
            scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.5)

            for _ in range(CONFIG["epochs"]):
                train_one_epoch(model, train_loader, optimizer, CONFIG, device)
                scheduler.step()

            model.eval()

            # (A) LOSO on TRAIN_ACT_ID
            item = test_trials[0]
            x_np = item["data"]
            count_gt = float(item["count"])
            count_pred_win, _ = predict_count_by_windowing(model, x_np=x_np, **win_kwargs)
            fold_mae = float(abs(count_pred_win - count_gt))
            loso_results.append(fold_mae)

            # (B) Scenario eval
            mixed_item = build_mixed_ab_trial(test_subj, actA_id, actB_id, CONFIG, full_data)
            if mixed_item is None:
                continue

            x_mix = mixed_item["data"]
            boundary = mixed_item["boundary"]
            gt_total = float(mixed_item["count"])
            md = mixed_item["meta_detail"]

            pred_total, _ = predict_count_by_windowing(model, x_np=x_mix, **win_kwargs)

            diff = float(pred_total - gt_total)
            mae = float(abs(diff))
            scenario_diffs.append(diff)

            # A/B split MAE
            xA = mixed_item["data_A"]
            xB = mixed_item["data_B"]
            gtA = float(md["gtA"])
            gtB = float(md["gtB"])

            pred_A, _ = predict_count_by_windowing(model, x_np=xA, **win_kwargs)
            pred_B, _ = predict_count_by_windowing(model, x_np=xB, **win_kwargs)

            diff_A = float(pred_A - gtA)
            diff_B = float(pred_B - gtB)
            mae_A = float(abs(diff_A))
            mae_B = float(abs(diff_B))
            scenario_diffs_A.append(diff_A)
            scenario_diffs_B.append(diff_B)

            # One full-trial forward pass for the latent z, k_hat and entropy.
            x_tensor = torch.tensor(x_mix, dtype=torch.float32).transpose(0, 1).unsqueeze(0).to(device)
            with torch.no_grad():
                _, z_m, _, aux_m = model(x_tensor, mask=None, tau=CONFIG.get("tau", 1.0))

            phase_p_m = aux_m["phase_p"].squeeze(0).detach().cpu().numpy()  # (Tmix,K)
            k_hat_m = float(aux_m["k_hat"].item())
            ent_m = compute_phase_entropy_mean(phase_p_m)

            z_td = z_m.squeeze(0).detach().cpu().numpy()  # (Tmix,D)

            scenario_records.append({
                "subj": test_subj,
                "line": (
                    f"[Scenario] {scen_name} | {test_subj} | "
                    f"{md['actA_name']}({md['gtA']:.0f}) -> {md['actB_name']}({md['gtB']:.0f}) | "
                    f"GT_total={gt_total:.0f} | Pred_total(win)={pred_total:.2f} | Diff_total={diff:+.2f} | "
                    f"MAE_total={mae:.2f} | "
                    f"Pred_A={pred_A:.2f} (Diff_A={diff_A:+.2f}, MAE_A={mae_A:.2f}) | "
                    f"Pred_B={pred_B:.2f} (Diff_B={diff_B:+.2f}, MAE_B={mae_B:.2f}) | "
                    f"k_hat(full)={k_hat_m:.2f} | ent(full)={ent_m:.3f} | boundary={boundary}"
                )
            })

            # (1) window-level rep_rate curves: A/B separately -> resample -> concat
            _, rA_w = get_window_rate_curve(model, xA, **win_kwargs)
            _, rB_w = get_window_rate_curve(model, xB, **win_kwargs)
            rA_rs = resample_1d(rA_w, NW)
            rB_rs = resample_1d(rB_w, NW)
            buf_win_concat.append(np.concatenate([rA_rs, rB_rs], axis=0))

            # (2) latent: store raw z split at the A/B boundary for global PCA later
            zA = z_td[:boundary, :]
            zB = z_td[boundary:, :]
            buf_z_concat.append(z_td)
            per_subject_latent.append({"zA": zA, "zB": zB})

        # -------------------------
        # Print summaries
        # -------------------------
        print("-" * 100)
        if len(loso_results) > 0:
            print("TRain 결과 ->")
            print("-" * 100)
            print(f" >>> Final LOSO Result (Average MAE): {np.mean(loso_results):.3f}")
            print(f" >>> Standard Deviation: {np.std(loso_results):.3f}")
            print("-" * 100)
        else:
            print("TRain 결과 -> (no folds computed)")
            print("-" * 100)

        # Print per-subject lines in canonical subject order.
        subj2line = {r["subj"]: r["line"] for r in scenario_records}
        for s in subjects:
            if s in subj2line:
                print(subj2line[s])

        print("\n" + "-" * 100)
        print(f"시나리오{scen_idx} 결과")
        print("-" * 100)

        if len(scenario_diffs) == 0:
            print(f"[{scen_name}] No samples (all skipped). Fill GT_BY_ACT for required activities.")
        else:
            # Same statistics for the total count and for segments A and B;
            # the tag is left-padded to 7 chars so the columns line up.
            for tag, raw_diffs in (("[TOTAL]", scenario_diffs),
                                   ("[A]", scenario_diffs_A),
                                   ("[B]", scenario_diffs_B)):
                diffs = np.array(raw_diffs, dtype=np.float32)
                mae = float(np.mean(np.abs(diffs)))
                rmse = float(np.sqrt(np.mean(diffs ** 2)))
                print(f"{tag:<7} [{scen_name}] N={len(diffs):2d} | MAE={mae:.3f} | RMSE={rmse:.3f} | mean(diff)={diffs.mean():+.3f} | std(diff)={diffs.std():.3f}")

        print("-" * 100)

        # -------------------------
        # Two-plot scenario-average visualization
        # -------------------------
        n_ok = len(buf_win_concat)
        if n_ok == 0:
            print("[Viz Skip] No scenario samples collected for averaging.")
            continue

        # (1) window-level mean±std on normalized axis [0,2]
        W = np.stack(buf_win_concat, axis=0)  # (N, 2*NW)
        win_mean = W.mean(axis=0)
        win_std  = W.std(axis=0)
        win_t = np.linspace(0.0, 2.0, num=2*NW, dtype=np.float32)

        # (2) latent: global PCA basis from all z
        Z_all = np.concatenate(buf_z_concat, axis=0)  # (sumT, D)
        mu, V2 = global_pca2(Z_all)

        # project each subject, resample A/B to NL then concat -> (2*NL,2)
        lat_list = []
        for d in per_subject_latent:
            pcA = (d["zA"] - mu) @ V2.T
            pcB = (d["zB"] - mu) @ V2.T
            pcA_rs = resample_2d(pcA, NL)
            pcB_rs = resample_2d(pcB, NL)
            lat_list.append(np.concatenate([pcA_rs, pcB_rs], axis=0))

        LAT = np.stack(lat_list, axis=0)  # (N,2*NL,2)
        lat_xy_mean = LAT.mean(axis=0)
        lat_xy_std  = LAT.std(axis=0)

        plot_scenario_mean_std_two_plots(
            save_dir=VIZ_DIR,
            scen_name=scen_name,
            win_t=win_t, win_mean=win_mean, win_std=win_std,
            lat_xy_mean=lat_xy_mean, lat_xy_std=lat_xy_std,
            n_subjects=n_ok,
            dpi=600
        )


# Script entry point: run the experiment only when this file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    main()


Device: cuda
Loading 10 subjects from /content/drive/MyDrive/Colab Notebooks/HAR_data/MHEALTHDATASET...

1. 시나리오1: Knees bending-Frontal elevation of arms
----------------------------------------------------------------------------------------------------
TRain 결과 ->
----------------------------------------------------------------------------------------------------
 >>> Final LOSO Result (Average MAE): 4.274
 >>> Standard Deviation: 3.012
----------------------------------------------------------------------------------------------------
[Scenario] Knees bending-Frontal elevation of arms | subject1 | Knees bending(20) -> Frontal elevation of arms(20) | GT_total=40 | Pred_total(win)=66.34 | Diff_total=+26.34 | MAE_total=26.34 | Pred_A=47.42 (Diff_A=+27.42, MAE_A=27.42) | Pred_B=18.71 (Diff_B=-1.29, MAE_B=1.29) | k_hat(full)=1.02 | ent(full)=0.059 | boundary=3379
[Scenario] Knees bending-Frontal elevation of arms | subject2 | Knees bending(21) -> Frontal elevation of arms(20) | GT_total

In [2]:
import os
import glob
import random
import numpy as np
import pandas as pd
import seaborn as sns  # kept
import matplotlib.pyplot as plt
from scipy.ndimage import gaussian_filter1d  # kept

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader


# ---------------------------------------------------------------------
# ✅ Save helpers (needed)
# ---------------------------------------------------------------------
def _ensure_dir(path: str):
    """Create directory ``path`` (and parents) unless it is ``None`` or empty.

    An already existing directory is not an error.
    """
    if path is not None and path != "":
        os.makedirs(path, exist_ok=True)

def _safe_fname(s: str):
    """Turn an arbitrary value into a string usable as a file name.

    Path separators, characters that are reserved on common file systems,
    newlines and tabs become underscores; surrounding whitespace is stripped;
    remaining spaces become underscores; runs of underscores collapse to one.
    """
    forbidden = ['/', '\\', ':', '*', '?', '"', '<', '>', '|', '\n', '\t']
    text = str(s).translate(str.maketrans({ch: '_' for ch in forbidden}))
    text = text.strip().replace(' ', '_')

    # Collapse every run of underscores into a single underscore.
    pieces = []
    for ch in text:
        if ch == '_' and pieces and pieces[-1] == '_':
            continue
        pieces.append(ch)
    return ''.join(pieces)


# ---------------------------------------------------------------------
# 1) Strict Seeding
# ---------------------------------------------------------------------
def set_strict_seed(seed: int):
    """Seed Python, NumPy and PyTorch RNGs and force deterministic cuDNN.

    Also exports ``PYTHONHASHSEED`` and, when CUDA is available, seeds the
    CUDA generators explicitly.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)

    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


# ---------------------------------------------------------------------
# 2) Data Loading
# ---------------------------------------------------------------------
def load_mhealth_dataset(data_dir, target_activities_map, column_names):
    """Load ``mHealth_subject*.log`` files and split them by target activity.

    Each log is read as a tab-separated table without a header, truncated to
    ``len(column_names)`` columns and labelled with ``column_names`` (which
    must include ``'activity_id'``).

    Args:
        data_dir: directory containing the log files.
        target_activities_map: mapping of activity code -> activity name;
            only these activities are kept.
        column_names: column labels to assign to the table.

    Returns:
        ``{subject_key: {activity_name: DataFrame}}`` where each DataFrame
        holds the rows of that activity without the ``activity_id`` column.
        Activities with no rows are omitted. Returns ``{}`` (after printing a
        warning) when no log files are found. Files that fail to load are
        reported and skipped.
    """
    file_list = sorted(glob.glob(os.path.join(data_dir, "mHealth_subject*.log")))

    if not file_list:
        print(f"[Warning] No mHealth logs found in {data_dir}")
        return {}

    print(f"Loading {len(file_list)} subjects from {data_dir}...")

    full_dataset = {}
    for file_path in file_list:
        file_name = os.path.basename(file_path)
        subj_part = file_name.split('.')[0]
        # Normalize to "subject<N>"; fall back to the raw stem when the digits
        # cannot be parsed as an integer. Only ValueError is expected here, so
        # unrelated errors (e.g. KeyboardInterrupt) are no longer swallowed.
        try:
            subj_id_num = int(''.join(filter(str.isdigit, subj_part)))
            subj_key = f"subject{subj_id_num}"
        except ValueError:
            subj_key = subj_part

        # Best-effort load: a broken file is reported and skipped so the
        # remaining subjects are still available.
        try:
            df = pd.read_csv(file_path, sep="\t", header=None)
            df = df.iloc[:, :len(column_names)]
            df.columns = column_names

            subj_data = {}
            for label_code, activity_name in target_activities_map.items():
                activity_df = df[df['activity_id'] == label_code].copy()
                if not activity_df.empty:
                    subj_data[activity_name] = activity_df.drop(columns=['activity_id'])

            full_dataset[subj_key] = subj_data
        except Exception as e:
            print(f"Error loading {file_name}: {e}")

    return full_dataset


def prepare_trial_list(label_config, full_data, target_map, feature_map):
    """Build per-trial records with z-scored signals.

    Args:
        label_config: iterable of ``(subject, activity_id, gt_count)``.
        full_data: ``{subject: {activity_name: DataFrame}}``.
        target_map: activity id -> activity name.
        feature_map: activity id -> list of column names to select.

    Returns:
        List of dicts with ``'data'`` (float32 array, z-scored per column
        over the trial), ``'count'`` (float) and ``'meta'``
        (``"<subject>_<activity name>"``). Entries whose subject or activity
        is missing from ``full_data`` are skipped with a printed notice.
    """
    trials = []
    for subj, act_id, gt_count in label_config:
        act_name = target_map.get(act_id)
        feats = feature_map.get(act_id)

        if subj not in full_data or act_name not in full_data[subj]:
            print(f"[Skip] Missing data for {subj} - {act_name}")
            continue

        samples = full_data[subj][act_name][feats].values.astype(np.float32)

        # Per-trial z-score; the small constant guards constant columns.
        col_mean = samples.mean(axis=0)
        col_std = samples.std(axis=0) + 1e-6

        trials.append({
            'data': (samples - col_mean) / col_std,   # (T, C)
            'count': float(gt_count),                 # trial total count
            'meta': f"{subj}_{act_name}"
        })

    return trials


# ---------------------------------------------------------------------
# ✅ Mixed builder helpers
# ---------------------------------------------------------------------
def get_single_trial_from_full_data(subj, act_id, gt_count, full_data, target_map, feature_map, normalize=True):
    """Fetch one subject/activity trial as a float32 array.

    Returns ``None`` when the subject or the activity is not present in
    ``full_data``. Otherwise returns a dict with ``"data"`` (selected feature
    columns, z-scored per column when ``normalize`` is true), ``"count"``
    (float) and ``"meta"`` (``"<subject>_<activity name>"``).
    """
    act_name = target_map.get(act_id)
    feats = feature_map.get(act_id)

    available = subj in full_data and act_name in full_data[subj]
    if not available:
        return None

    signal = full_data[subj][act_name][feats].values.astype(np.float32)

    if normalize:
        col_mean = signal.mean(axis=0)
        col_std = signal.std(axis=0) + 1e-6
        signal = (signal - col_mean) / col_std

    return {
        "data": signal,
        "count": float(gt_count),
        "meta": f"{subj}_{act_name}"
    }


# ---------------------------------------------------------------------
# ✅ Mixed A-B builder (kept; not used in ABAB run but left intact)
# ---------------------------------------------------------------------
def build_mixed_ab_trial(subj, actA_id, actB_id, config, full_data):
    """Concatenate a subject's activity A and B trials into one mixed trial.

    The raw A and B signals are joined along time and z-scored once using the
    statistics of the joined signal; the separate A and B segments are
    normalized with those same statistics.

    Returns ``None`` when ground-truth counts are missing for either activity
    or the subject, or when either raw trial cannot be fetched. Otherwise
    returns a dict with the mixed data, the total count, the A/B boundary
    index, per-segment data and descriptive metadata.
    """
    gt_lookup = config.get("GT_BY_ACT", {})
    act_pair = (actA_id, actB_id)
    if any(act not in gt_lookup for act in act_pair):
        return None
    if any(subj not in gt_lookup[act] for act in act_pair):
        return None

    gtA = float(gt_lookup[actA_id][subj])
    gtB = float(gt_lookup[actB_id][subj])
    gt_total = gtA + gtB

    name_map = config["TARGET_ACTIVITIES_MAP"]
    feat_map = config["ACT_FEATURE_MAP"]

    # Fetch both trials un-normalized; normalization happens on the mix.
    trial_A = get_single_trial_from_full_data(
        subj, actA_id, gtA, full_data, name_map, feat_map, normalize=False
    )
    trial_B = get_single_trial_from_full_data(
        subj, actB_id, gtB, full_data, name_map, feat_map, normalize=False
    )
    if trial_A is None or trial_B is None:
        return None

    raw_A = trial_A["data"]
    raw_B = trial_B["data"]
    boundary = int(raw_A.shape[0])

    raw_mix = np.concatenate([raw_A, raw_B], axis=0).astype(np.float32)

    mix_mean = raw_mix.mean(axis=0)
    mix_std = raw_mix.std(axis=0) + 1e-6
    x_mix = (raw_mix - mix_mean) / mix_std

    xA = (raw_A - mix_mean) / mix_std
    xB = (raw_B - mix_mean) / mix_std

    actA_name = name_map.get(actA_id, str(actA_id))
    actB_name = name_map.get(actB_id, str(actB_id))

    meta_detail = {
        "subj": subj,
        "actA_id": actA_id, "actB_id": actB_id,
        "actA_name": actA_name, "actB_name": actB_name,
        "gtA": gtA, "gtB": gtB,
        "gt_total": gt_total,
    }

    return {
        "data": x_mix,
        "count": float(gt_total),
        "meta": f"{subj}__{actA_name}__TO__{actB_name}",
        "boundary": boundary,
        "meta_detail": meta_detail,
        "data_A": xA,
        "data_B": xB,
        "T_A": int(xA.shape[0]),
        "T_B": int(xB.shape[0]),
    }


# ---------------------------------------------------------------------
# ✅ Mixed A-B-A-B builder
#     - raw concat: A | B | A | B
#     - z-score normalization applied once over the whole mixed signal
#     - boundaries: [TA, TA+TB, TA+TB+TA]
#     - GT_total = 2*gtA + 2*gtB
# ---------------------------------------------------------------------
def build_mixed_abab_trial(subj, actA_id, actB_id, config, full_data):
    """Build an A|B|A|B mixed trial for one subject.

    The raw A and B signals are concatenated in the order A, B, A, B and
    z-scored once with the statistics of the whole concatenation. The
    segment boundaries are ``[TA, TA+TB, TA+TB+TA]`` and the ground-truth
    total is ``2*gtA + 2*gtB``.

    Returns ``None`` when ground-truth counts are missing for either activity
    or the subject, or when either raw trial cannot be fetched.
    """
    gt_lookup = config.get("GT_BY_ACT", {})
    act_pair = (actA_id, actB_id)
    if any(act not in gt_lookup for act in act_pair):
        return None
    if any(subj not in gt_lookup[act] for act in act_pair):
        return None

    gtA = float(gt_lookup[actA_id][subj])
    gtB = float(gt_lookup[actB_id][subj])
    gt_total = 2.0 * gtA + 2.0 * gtB

    name_map = config["TARGET_ACTIVITIES_MAP"]
    feat_map = config["ACT_FEATURE_MAP"]

    # Fetch both trials un-normalized; normalization happens on the mix.
    trial_A = get_single_trial_from_full_data(
        subj, actA_id, gtA, full_data, name_map, feat_map, normalize=False
    )
    trial_B = get_single_trial_from_full_data(
        subj, actB_id, gtB, full_data, name_map, feat_map, normalize=False
    )
    if trial_A is None or trial_B is None:
        return None

    raw_A = trial_A["data"]
    raw_B = trial_B["data"]
    TA = int(raw_A.shape[0])
    TB = int(raw_B.shape[0])

    # End indices of the first A, first B and second A segments.
    boundaries = [TA, TA + TB, TA + TB + TA]

    raw_mix = np.concatenate([raw_A, raw_B, raw_A, raw_B], axis=0).astype(np.float32)

    mix_mean = raw_mix.mean(axis=0)
    mix_std = raw_mix.std(axis=0) + 1e-6
    x_mix = (raw_mix - mix_mean) / mix_std

    xA1 = (raw_A - mix_mean) / mix_std
    xB1 = (raw_B - mix_mean) / mix_std
    # The second occurrences are independent copies of the first ones.
    xA2 = xA1.copy()
    xB2 = xB1.copy()

    actA_name = name_map.get(actA_id, str(actA_id))
    actB_name = name_map.get(actB_id, str(actB_id))

    meta_detail = {
        "subj": subj,
        "actA_id": actA_id, "actB_id": actB_id,
        "actA_name": actA_name, "actB_name": actB_name,
        "gtA": gtA, "gtB": gtB,
        "gt_total": gt_total,
    }

    return {
        "data": x_mix,                   # (TA+TB+TA+TB, C)
        "count": float(gt_total),
        "meta": f"{subj}__{actA_name}__TO__{actB_name}__TO__{actA_name}__TO__{actB_name}",
        "boundaries": boundaries,
        "meta_detail": meta_detail,
        "data_A1": xA1,
        "data_B1": xB1,
        "data_A2": xA2,
        "data_B2": xB2,
        "T_A": TA,
        "T_B": TB,
    }


# ---------------------------------------------------------------------
# 2.5) ✅ Windowing (added)
# ---------------------------------------------------------------------
def trial_list_to_windows(trial_list, fs, win_sec=8.0, stride_sec=4.0, drop_last=True):
    """Slice every trial into sliding windows labelled with a rate-scaled count.

    Each window's ``count`` is the trial-average repetition rate
    (``count / duration``) multiplied by the window's own duration.

    Args:
        trial_list: iterable of dicts with keys ``"data"`` (indexable along
            time, exposes ``.shape[0]``), ``"count"`` and ``"meta"``.
        fs: sampling rate in samples per second.
        win_sec: window length in seconds.
        stride_sec: hop between consecutive window starts in seconds.
        drop_last: when False, one extra (shorter) window covering the samples
            from ``last_start + stride`` to the end of the trial is appended.

    Returns:
        A flat list of window dicts with keys ``data``, ``count``, ``meta``,
        ``parent_meta``, ``parent_T``, ``win_start`` and ``win_end``.
    """
    win_len = int(round(win_sec * fs))
    stride = int(round(stride_sec * fs))
    assert win_len > 0 and stride > 0

    fs_f = float(fs)

    def _pack(payload, parent, parent_len, rate, lo, hi):
        # One window record; its label is the trial rate times the window duration.
        return {
            "data": payload,
            "count": rate * ((hi - lo) / fs_f),
            "meta": f"{parent}__win[{lo}:{hi}]",
            "parent_meta": parent,
            "parent_T": parent_len,
            "win_start": lo,
            "win_end": hi,
        }

    collected = []
    for trial in trial_list:
        signal = trial["data"]
        n_samples = signal.shape[0]
        reps = float(trial["count"])
        parent = trial["meta"]

        # Average repetition rate over the whole trial (duration clamped away from 0).
        rate_trial = reps / max(n_samples / fs_f, 1e-6)

        # Trials shorter than one window are kept whole as a single window.
        if n_samples < win_len:
            collected.append(_pack(signal, parent, n_samples, rate_trial, 0, n_samples))
            continue

        spans = [(lo, lo + win_len) for lo in range(0, n_samples - win_len + 1, stride)]

        # Optional trailing partial window starting one stride after the last full one.
        if not drop_last:
            tail_lo = spans[-1][0] + stride
            if tail_lo < n_samples:
                spans.append((tail_lo, n_samples))

        collected.extend(
            _pack(signal[lo:hi], parent, n_samples, rate_trial, lo, hi)
            for lo, hi in spans
        )

    return collected


def predict_count_by_windowing(model, x_np, fs, win_sec, stride_sec, device, tau=1.0, batch_size=64):
    """Predict the total repetition count of one trial via sliding-window inference.

    The model is run on fixed-length windows; the mean of the per-window rates
    is multiplied by the full trial duration (``T / fs``) to obtain the count.
    Samples after the last full window are not fed to the model, but their
    duration is still included in ``T / fs``.

    Args:
        model: callable returning a 4-tuple whose first element is the
            per-sample-in-batch rate; must also expose ``eval()``.
        x_np: array indexed as (time, channels).
        fs: sampling rate in samples per second.
        win_sec: window length in seconds.
        stride_sec: hop between window starts in seconds.
        device: torch device (or device string) used for inference.
        tau: temperature forwarded to the model.
        batch_size: number of windows per forward pass.

    Returns:
        ``(pred_count, rates)`` where ``pred_count`` is a Python float and
        ``rates`` is a 1-D numpy array of per-window rates (a single element
        when the trial is no longer than one window).
    """
    win_len = int(round(win_sec * fs))
    stride = int(round(stride_sec * fs))
    T = x_np.shape[0]
    total_dur = T / float(fs)

    # Put the model in inference mode on BOTH paths (the short-trial path
    # previously relied on the caller having done so).
    model.eval()

    if T <= win_len:
        # Trial fits inside one window: a single full-length forward pass.
        x_tensor = torch.tensor(x_np, dtype=torch.float32).transpose(0, 1).unsqueeze(0).to(device)
        with torch.no_grad():
            rate_hat, _, _, _ = model(x_tensor, mask=None, tau=tau)
        rate_value = float(rate_hat.item())
        return float(rate_value * total_dur), np.array([rate_value], dtype=np.float32)

    starts = range(0, T - win_len + 1, stride)
    windows = np.stack([x_np[st:st + win_len] for st in starts], axis=0)

    # (N, win_len, C) -> (N, C, win_len); kept on the host and moved to the
    # device one batch at a time so batching also bounds device memory.
    xw = torch.tensor(windows, dtype=torch.float32).permute(0, 2, 1)

    rates = []
    with torch.no_grad():
        for i in range(0, xw.shape[0], batch_size):
            xb = xw[i:i + batch_size].to(device)
            r_hat, _, _, _ = model(xb, mask=None, tau=tau)
            rates.append(r_hat.detach().cpu().numpy())

    rates = np.concatenate(rates, axis=0)
    rate_mean = float(rates.mean())
    pred_count = rate_mean * total_dur
    return float(pred_count), rates


# ---------------------------------------------------------------------
# ✅ Helpers for scenario-avg plots
# ---------------------------------------------------------------------
def get_window_rate_curve(model, x_np, fs, win_sec, stride_sec, device, tau=1.0, batch_size=64):
    """Return per-window predicted rates together with the window centre times.

    Args:
        model: callable returning a 4-tuple whose first element is the rate per
            batch item; must expose ``eval()``.
        x_np: array indexed as (time, channels).
        fs: sampling rate in samples per second.
        win_sec: window length in seconds.
        stride_sec: hop between window starts in seconds.
        device: torch device used for inference.
        tau: temperature forwarded to the model.
        batch_size: number of windows per forward pass.

    Returns:
        ``(centers, rates)`` as float32 numpy arrays. For a trial no longer
        than one window, a single centre at half the trial duration is
        returned with the single rate from ``predict_count_by_windowing``.
    """
    win_len = int(round(win_sec * fs))
    hop = int(round(stride_sec * fs))
    n_samples = x_np.shape[0]
    fs_f = float(fs)

    if n_samples <= win_len:
        # Short trial: delegate to the single-pass path; only the rate is needed here.
        _, single_rate = predict_count_by_windowing(
            model, x_np, fs, win_sec, stride_sec, device, tau=tau, batch_size=batch_size
        )
        midpoint = np.array([0.5 * (n_samples / fs_f)], dtype=np.float32)
        return midpoint, single_rate.astype(np.float32)

    onsets = list(range(0, n_samples - win_len + 1, hop))
    stacked = np.stack([x_np[s:s + win_len] for s in onsets], axis=0)

    # (N, win_len, C) -> (N, C, win_len)
    batch_all = torch.tensor(stacked, dtype=torch.float32).permute(0, 2, 1).to(device)

    model.eval()
    with torch.no_grad():
        per_batch = [
            model(batch_all[lo:lo + batch_size], mask=None, tau=tau)[0].detach().cpu().numpy()
            for lo in range(0, batch_all.shape[0], batch_size)
        ]
    window_rates = np.concatenate(per_batch, axis=0).astype(np.float32)

    # Centre time (seconds) of every window.
    window_centers = np.array(
        [(s + 0.5 * win_len) / fs_f for s in onsets], dtype=np.float32
    )
    return window_centers, window_rates


def resample_1d(y, new_len):
    """Linearly resample a 1-D sequence to ``new_len`` points on a [0, 1] axis.

    An empty input yields zeros; a single-element input is broadcast.
    The result is always a float32 array of shape ``(new_len,)``.
    """
    values = np.asarray(y, dtype=np.float32)
    n_old = values.size

    if n_old == 0:
        return np.zeros((new_len,), dtype=np.float32)
    if n_old == 1:
        return np.full((new_len,), float(values[0]), dtype=np.float32)

    src_axis = np.linspace(0.0, 1.0, num=n_old, dtype=np.float32)
    dst_axis = np.linspace(0.0, 1.0, num=new_len, dtype=np.float32)
    resampled = np.interp(dst_axis, src_axis, values)
    return resampled.astype(np.float32)


def resample_2d(Y, new_len):
    """Linearly resample each column of a (T, D) array to ``new_len`` rows.

    Zero rows give an all-zero ``(new_len, D)`` array; a single row is
    repeated. The result is always float32.
    """
    arr = np.asarray(Y, dtype=np.float32)
    n_rows = arr.shape[0]

    if n_rows == 0:
        return np.zeros((new_len, arr.shape[1]), dtype=np.float32)
    if n_rows == 1:
        return np.repeat(arr, repeats=new_len, axis=0).astype(np.float32)

    src_axis = np.linspace(0.0, 1.0, num=n_rows, dtype=np.float32)
    dst_axis = np.linspace(0.0, 1.0, num=new_len, dtype=np.float32)
    # Interpolate column by column, then stack back into (new_len, D).
    columns = [np.interp(dst_axis, src_axis, arr[:, j]) for j in range(arr.shape[1])]
    return np.stack(columns, axis=1).astype(np.float32)


def global_pca2(Z_all):
    """Fit a 2-component PCA basis with an SVD of the mean-centred data.

    Z_all: (N, D)
    returns: mean (D,), V2 (2, D)  such that proj = (Z - mean) @ V2.T
    Both outputs are float32.
    """
    samples = np.asarray(Z_all, dtype=np.float32)
    center = samples.mean(axis=0, keepdims=True)
    # Rows of the right-singular-vector matrix are the principal directions.
    right_vectors = np.linalg.svd(samples - center, full_matrices=False)[2]
    top_two = right_vectors[:2].astype(np.float32)  # (2, D)
    return center.squeeze(0).astype(np.float32), top_two


# ---------------------------------------------------------------------
# ✅ NEW: Two-plot visualization for ABAB
#   - only (1) rep_rate mean±std + boundaries + shading
#   - only (2) PCA2 scatter (no line, no boundary markers, no std circles)
# ---------------------------------------------------------------------
def plot_scenario_mean_std_two_plots(
    save_dir,
    scen_name,
    win_t, win_mean, win_std,
    lat_xy_mean, lat_xy_std,   # (std unused but keep signature)
    n_subjects,
    dpi=600,
):
    """Save the two scenario-level ABAB figures as PNG files.

    Figure 1 draws the window-level rep_rate mean with a +/-1 std band on the
    normalized time axis, shading the four segments and marking the boundaries
    at 1, 2 and 3. Figure 2 is a time-coloured scatter of the mean latent
    trajectory in the first two PCA coordinates.

    Args:
        save_dir: output directory; ``"scenario_viz"`` is used when None.
        scen_name: scenario name, sanitized into the file-name prefix.
        win_t, win_mean, win_std: x-axis, mean and std of the window rates.
        lat_xy_mean: array whose first two columns are the PC1/PC2 means.
        lat_xy_std: accepted for signature compatibility; not used.
        n_subjects: accepted for signature compatibility; not used.
        dpi: resolution of the saved PNGs.
    """
    out_dir = "scenario_viz" if save_dir is None else save_dir
    _ensure_dir(out_dir)
    scen_tag = _safe_fname(scen_name)

    # -------------------------
    # (1) Window-level rep_rate mean +/- std (+ boundaries + shading)
    # - no A/B letters on the plot
    # - large figsize / fonts
    # -------------------------
    rate_fig = plt.figure(figsize=(26.0, 9.2))
    rate_ax = rate_fig.gca()

    # ABAB segments: 0~1, 1~2, 2~3, 3~4 with alternating shade strength
    for seg_idx, shade in enumerate((0.06, 0.03, 0.06, 0.03)):
        rate_ax.axvspan(float(seg_idx), float(seg_idx + 1), alpha=shade, color="k", lw=0)

    rate_ax.plot(win_t, win_mean, linewidth=6.4, label="Window rep_rate (mean)")
    lower_band = win_mean - win_std
    upper_band = win_mean + win_std
    rate_ax.fill_between(win_t, lower_band, upper_band, alpha=0.22, label="±1 std")

    # boundaries at 1, 2, 3 (only the first one gets a legend entry)
    for x_pos, legend_text in ((1.0, "Boundary"), (2.0, "_nolegend_"), (3.0, "_nolegend_")):
        rate_ax.axvline(x_pos, linestyle="--", linewidth=6.0, label=legend_text)

    rate_ax.set_xlabel("Normalized time", fontsize=46, labelpad=10)
    rate_ax.set_ylabel("rep_rate (reps/s)", fontsize=46, labelpad=10)
    rate_ax.tick_params(axis="both", labelsize=37)

    rate_ax.legend(fontsize=38, frameon=True, loc="upper left")
    rate_fig.tight_layout()
    rate_path = os.path.join(out_dir, f"{scen_tag}__1_win_rep_rate.png")
    rate_fig.savefig(rate_path, dpi=dpi, bbox_inches="tight")
    plt.close(rate_fig)

    # -------------------------
    # (2) Latent trajectory PCA2
    # - std circles removed
    # - no connecting line (scatter only)
    # - no boundary/start/end markers
    # -------------------------
    lat_fig = plt.figure(figsize=(8.8, 8.8))
    lat_ax = lat_fig.gca()

    pc1 = lat_xy_mean[:, 0]
    pc2 = lat_xy_mean[:, 1]
    n_points = lat_xy_mean.shape[0]
    time_axis = np.linspace(0.0, 4.0, num=n_points, dtype=np.float32)

    scatter = lat_ax.scatter(pc1, pc2, c=time_axis, s=22, alpha=0.95)

    lat_ax.set_xlabel("PC1", fontsize=24, labelpad=10)
    lat_ax.set_ylabel("PC2", fontsize=24, labelpad=10)
    lat_ax.tick_params(axis="both", labelsize=18)
    lat_ax.set_aspect("equal", adjustable="box")

    # keep square bounds centred on the data, padded by 12%
    lo_x, hi_x = float(np.min(pc1)), float(np.max(pc1))
    lo_y, hi_y = float(np.min(pc2)), float(np.max(pc2))
    mid_x = 0.5 * (lo_x + hi_x)
    mid_y = 0.5 * (lo_y + hi_y)
    half_x = 0.5 * (hi_x - lo_x)
    half_y = 0.5 * (hi_y - lo_y)
    half_span = max(half_x, half_y) * 1.12 + 1e-6
    lat_ax.set_xlim(mid_x - half_span, mid_x + half_span)
    lat_ax.set_ylim(mid_y - half_span, mid_y + half_span)

    colorbar = lat_fig.colorbar(scatter, ax=lat_ax, fraction=0.046, pad=0.04)
    colorbar.set_label("normalized time (0→4)", fontsize=20, labelpad=10)
    colorbar.ax.tick_params(labelsize=16)

    lat_fig.tight_layout()
    lat_path = os.path.join(out_dir, f"{scen_tag}__2_latent_pca2.png")
    lat_fig.savefig(lat_path, dpi=dpi, bbox_inches="tight")
    plt.close(lat_fig)


# ---------------------------------------------------------------------
# 2.8) Dataset / Collate
# ---------------------------------------------------------------------
class TrialDataset(Dataset):
    """Dataset wrapper around a list of trial/window dicts.

    Every element must provide ``'data'`` (time-major, i.e. (T, C)),
    ``'count'`` and ``'meta'``.
    """

    def __init__(self, trial_list):
        self.trials = trial_list

    def __len__(self):
        return len(self.trials)

    def __getitem__(self, idx):
        """Return ``(signal, count, meta)`` with the signal transposed to (C, T)."""
        record = self.trials[idx]
        signal = torch.tensor(record["data"], dtype=torch.float32)
        signal = torch.transpose(signal, 0, 1)  # (T, C) -> (C, T)
        target = torch.tensor(record["count"], dtype=torch.float32)
        return signal, target, record["meta"]


def collate_variable_length(batch):
    """Right-pad variable-length (C, T) samples with zeros and build validity masks.

    Args:
        batch: list of ``(data, count, meta)`` tuples where ``data`` is (C, T).

    Returns:
        Dict with ``data`` (B, C, T_max), ``mask`` (B, T_max; 1 on real samples,
        0 on padding), ``count`` (stacked targets), ``length`` (float32 tensor
        of the original lengths) and ``meta`` (list).
    """
    longest = max([sample[0].shape[1] for sample in batch])
    n_channels = batch[0][0].shape[0]

    signals, targets, tags = zip(*batch)

    def _pad_with_mask(seq):
        # Returns the zero-padded sequence together with its 1/0 validity mask.
        n_valid = seq.shape[1]
        deficit = longest - n_valid
        if deficit <= 0:
            return seq, torch.ones(n_valid)
        filler = torch.zeros(n_channels, deficit)
        validity = torch.cat([torch.ones(n_valid), torch.zeros(deficit)], dim=0)
        return torch.cat([seq, filler], dim=1), validity

    padded_pairs = [_pad_with_mask(seq) for seq in signals]

    return {
        "data": torch.stack([padded for padded, _ in padded_pairs]),
        "mask": torch.stack([validity for _, validity in padded_pairs]),
        "count": torch.stack(list(targets)),
        "length": torch.tensor([seq.shape[1] for seq in signals], dtype=torch.float32),
        "meta": list(tags)
    }


# ---------------------------------------------------------------------
# 3) Model
# ---------------------------------------------------------------------
class ManifoldEncoder(nn.Module):
    """Length-preserving 1-D conv encoder: (B, C, T) signal -> (B, T, D) latent.

    Two kernel-5 convolutions (padding 2) with ReLU are followed by a
    kernel-1 projection onto ``latent_dim`` channels.
    """

    def __init__(self, input_ch, hidden_dim=128, latent_dim=16):
        super().__init__()
        stack = [
            nn.Conv1d(input_ch, hidden_dim, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv1d(hidden_dim, hidden_dim, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv1d(hidden_dim, latent_dim, kernel_size=1),
        ]
        self.net = nn.Sequential(*stack)

    def forward(self, x):
        latent_ct = self.net(x)                    # (B, D, T)
        return torch.transpose(latent_ct, 1, 2)    # (B, T, D)


class ManifoldDecoder(nn.Module):
    """Length-preserving 1-D conv decoder: (B, T, D) latent -> (B, C, T) signal.

    Mirrors the encoder layout: two kernel-5 convolutions (padding 2) with
    ReLU followed by a kernel-1 projection onto ``out_ch`` channels.
    """

    def __init__(self, latent_dim, hidden_dim, out_ch):
        super().__init__()
        stack = [
            nn.Conv1d(latent_dim, hidden_dim, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv1d(hidden_dim, hidden_dim, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv1d(hidden_dim, out_ch, kernel_size=1),
        ]
        self.net = nn.Sequential(*stack)

    def forward(self, z):
        # Conv1d expects channels first: (B, T, D) -> (B, D, T).
        return self.net(torch.transpose(z, 1, 2))


class MultiRateHead(nn.Module):
    """Per-timestep head producing an amplitude and a K_max-way phase distribution.

    The two-layer MLP emits ``1 + K_max`` values per latent vector: channel 0
    goes through softplus (non-negative amplitude), the remaining channels are
    phase logits turned into probabilities by a temperature softmax.
    """

    def __init__(self, latent_dim=16, hidden=64, K_max=6):
        super().__init__()
        self.K_max = K_max
        layers = [
            nn.Linear(latent_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 1 + K_max),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, z, tau=1.0):
        """Return ``(amp, phase, phase_logits)`` for latent ``z`` (last dim = latent_dim)."""
        raw = self.net(z)
        amp_raw, phase_logits = raw[..., 0], raw[..., 1:]
        amp = F.softplus(amp_raw)
        phase = F.softmax(phase_logits / tau, dim=-1)
        return amp, phase, phase_logits


class KAutoCountModel(nn.Module):
    """Autoencoder with a rate head that estimates an average repetition rate.

    The encoder maps the input to a latent sequence, the decoder reconstructs
    the input from it, and the rate head yields a per-step amplitude and phase
    distribution. The effective phase count ``k_hat`` is the inverse sum of
    squared time-averaged phase probabilities; the repetition rate is the
    amplitude divided by ``k_hat``.

    NOTE(review): ``k_hidden`` is accepted but not used; the rate head is built
    with ``hidden_dim`` as its hidden width.
    """

    def __init__(self, input_ch, hidden_dim=128, latent_dim=16, K_max=6, k_hidden=64):
        super().__init__()
        self.encoder = ManifoldEncoder(input_ch, hidden_dim, latent_dim)
        self.decoder = ManifoldDecoder(latent_dim, hidden_dim, input_ch)
        self.rate_head = MultiRateHead(latent_dim, hidden=hidden_dim, K_max=K_max)
        self._init_weights()

    def _init_weights(self):
        """Kaiming-init every conv/linear layer, then bias the amplitude output to -2."""
        for layer in self.modules():
            if not isinstance(layer, (nn.Conv1d, nn.Linear)):
                continue
            nn.init.kaiming_normal_(layer.weight, mode='fan_out', nonlinearity='relu')
            if layer.bias is not None:
                nn.init.constant_(layer.bias, 0)

        with torch.no_grad():
            # Last layer of the head: zero all biases, then set the amplitude
            # channel (index 0, pre-softplus) to -2.0.
            head_bias = self.rate_head.net[-1].bias
            head_bias.zero_()
            head_bias[0].fill_(-2.0)

    @staticmethod
    def _masked_mean_time(x, mask=None, eps=1e-6):
        """Mean over dim 1 of a 2-D or 3-D tensor, optionally weighted by ``mask``."""
        if mask is None:
            return x.mean(dim=1)
        if x.dim() not in (2, 3):
            raise ValueError(f"Unsupported dim for masked mean: {x.dim()}")

        weights = mask.to(dtype=x.dtype, device=x.device)
        if x.dim() == 3:
            # Broadcast the per-step weights over the trailing feature axis.
            weights = weights.unsqueeze(-1)
        return (x * weights).sum(dim=1) / (weights.sum(dim=1) + eps)

    def forward(self, x, mask=None, tau=1.0):
        """Return ``(avg_rep_rate, z, x_hat, aux)`` for input ``x`` and optional ``mask``."""
        z = self.encoder(x)
        x_hat = self.decoder(z)

        amp_t, phase_p, phase_logits = self.rate_head(z, tau=tau)
        micro_rate_t = amp_t

        # Effective number of phases from the time-averaged phase usage.
        p_bar = self._masked_mean_time(phase_p, mask)
        k_hat = 1.0 / (p_bar.pow(2).sum(dim=1) + 1e-6)

        rep_rate_t = micro_rate_t / (k_hat.unsqueeze(1) + 1e-6)

        if mask is None:
            avg_rep_rate = rep_rate_t.mean(dim=1)
        else:
            # Zero padded steps, then average over the valid steps only.
            rep_rate_t = rep_rate_t * mask
            avg_rep_rate = (rep_rate_t * mask).sum(dim=1) / (mask.sum(dim=1) + 1e-6)

        aux = {
            "phase_p": phase_p,
            "phase_logits": phase_logits,
            "micro_rate_t": micro_rate_t,
            "rep_rate_t": rep_rate_t,
            "k_hat": k_hat,
        }
        return avg_rep_rate, z, x_hat, aux


# ---------------------------------------------------------------------
# 4) Loss utils
# ---------------------------------------------------------------------
def masked_recon_mse(x_hat, x, mask, eps=1e-6):
    """Mean squared reconstruction error over valid (unmasked) time steps.

    ``mask`` is (B, T) and is broadcast over the channel axis of the
    (B, C, T) tensors; the normaliser is ``mask.sum() * C + eps``.
    """
    valid = mask.to(dtype=x.dtype, device=x.device)
    squared_error = (x_hat - x).pow(2) * valid.unsqueeze(1)
    normaliser = (valid.sum() * x.shape[1]) + eps
    return squared_error.sum() / normaliser


def temporal_smoothness(v, mask=None, eps=1e-6):
    """Mean absolute first difference of ``v`` along dim 1.

    With a mask, only differences whose two neighbouring steps are both valid
    contribute to the average.
    """
    step_change = (v[:, 1:] - v[:, :-1]).abs()
    if mask is None:
        return step_change.mean()

    pair_valid = mask[:, 1:] * mask[:, :-1]
    pair_valid = pair_valid.to(dtype=step_change.dtype, device=step_change.device)
    return (step_change * pair_valid).sum() / (pair_valid.sum() + eps)


def phase_entropy_loss(phase_p, mask=None, eps=1e-8):
    """Average Shannon entropy of the phase distribution over the last axis.

    Without a mask the entropy is averaged over all entries; with a mask it is
    weighted by ``mask`` and normalised by ``mask.sum() + eps``.
    """
    log_p = torch.log(phase_p + eps)
    entropy = -torch.sum(phase_p * log_p, dim=-1)
    if mask is not None:
        return (entropy * mask).sum() / (mask.sum() + eps)
    return entropy.mean()


def effK_usage_loss(phase_p, mask=None, eps=1e-6):
    """Effective phase count from the time-averaged phase usage (kept for the training loss).

    Returns:
        ``(effK.mean(), effK.detach())`` where ``effK`` is
        ``1 / (sum_k p_bar_k^2 + eps)`` per batch item.
    """
    if mask is not None:
        weights = mask.to(dtype=phase_p.dtype, device=phase_p.device).unsqueeze(-1)
        usage = (phase_p * weights).sum(dim=1) / (weights.sum(dim=1) + eps)
    else:
        usage = phase_p.mean(dim=1)

    eff_k = 1.0 / (usage.pow(2).sum(dim=1) + eps)
    return eff_k.mean(), eff_k.detach()


# ---------------------------------------------------------------------
# 5) Train
# ---------------------------------------------------------------------
def train_one_epoch(model, loader, optimizer, config, device):
    """Run one optimisation pass over ``loader`` and return per-batch-averaged stats.

    The total loss is the rate MSE plus weighted reconstruction, temporal
    smoothness, phase-entropy and effective-K terms (weights read from
    ``config`` with the defaults below).

    Args:
        model: model returning ``(rate_hat, z, x_hat, aux)``; ``aux`` must
            contain ``"rep_rate_t"`` and ``"phase_p"``.
        loader: iterable of batch dicts with ``data``, ``mask``, ``count`` and
            ``length`` tensors; must support ``len()``.
        optimizer: optimizer stepping the model parameters.
        config: dict with ``"fs"`` and optional ``tau`` / ``lambda_*`` entries.
        device: device the batch tensors are moved to.

    Returns:
        Dict mapping each tracked quantity to its mean over batches. An empty
        loader yields all zeros instead of raising ZeroDivisionError.
    """
    model.train()
    stats = {k: 0.0 for k in [
        'loss', 'loss_rate', 'loss_recon', 'loss_smooth', 'loss_phase_ent', 'loss_effk',
        'mae_count'
    ]}

    fs = config["fs"]
    tau = config.get("tau", 1.0)

    lam_recon = config.get("lambda_recon", 1.0)
    lam_smooth = config.get("lambda_smooth", 0.05)
    lam_phase_ent = config.get("lambda_phase_ent", 0.01)
    lam_effk = config.get("lambda_effk", 0.005)

    for batch in loader:
        x = batch["data"].to(device)
        mask = batch["mask"].to(device)
        y_count = batch["count"].to(device)
        length = batch["length"].to(device)

        # Convert the count target into a rate using each sample's true duration.
        duration = torch.clamp(length / fs, min=1e-6)
        y_rate = y_count / duration

        optimizer.zero_grad()
        rate_hat, z, x_hat, aux = model(x, mask, tau=tau)

        loss_rate = F.mse_loss(rate_hat, y_rate)
        loss_recon = masked_recon_mse(x_hat, x, mask)
        loss_smooth = temporal_smoothness(aux["rep_rate_t"], mask)
        loss_phase_ent = phase_entropy_loss(aux["phase_p"], mask)
        loss_effk, _ = effK_usage_loss(aux["phase_p"], mask)

        loss = (loss_rate
                + lam_recon * loss_recon
                + lam_smooth * loss_smooth
                + lam_phase_ent * loss_phase_ent
                + lam_effk * loss_effk)

        loss.backward()
        optimizer.step()

        # Count-space MAE is tracked for monitoring only.
        count_hat = rate_hat * duration
        stats['loss'] += loss.item()
        stats['loss_rate'] += loss_rate.item()
        stats['loss_recon'] += loss_recon.item()
        stats['loss_smooth'] += loss_smooth.item()
        stats['loss_phase_ent'] += loss_phase_ent.item()
        stats['loss_effk'] += loss_effk.item()
        stats['mae_count'] += torch.abs(count_hat - y_count).mean().item()

    # Guard against an empty loader: nothing was accumulated, so report zeros.
    n = max(len(loader), 1)
    return {k: v / n for k, v in stats.items()}


# ---------------------------------------------------------------------
# 6) Main (Scenario-wise LOSO)  - ABAB
# ---------------------------------------------------------------------
def main():
    """Run the scenario-wise leave-one-subject-out (LOSO) ABAB experiment.

    For every entry of ``CONFIG["MIX_SCENARIOS"]`` and every held-out subject:
      1. train a fresh ``KAutoCountModel`` on sliding windows of activity A
         (``TRAIN_ACT_ID = actA_id``) taken from the remaining subjects;
      2. evaluate the held-out subject's activity-A trial with windowed
         inference (count MAE);
      3. build the A-B-A-B mixed trial for the held-out subject, predict its
         total count, and buffer per-segment window rates and latents.
    Per scenario, the results are printed and two figures (window rep_rate
    mean±std and a 2-D PCA scatter of the mean latent trajectory) are written
    to ``scenario_viz``.

    NOTE(review): ``prepare_trial_list`` and ``build_mixed_abab_trial`` are
    defined outside this section; their exact return schema is assumed from
    how the results are used below.
    """
    CONFIG = {
        "seed": 42,
        "data_dir": "/content/drive/MyDrive/Colab Notebooks/HAR_data/MHEALTHDATASET",

        "COLUMN_NAMES": [
            'acc_chest_x', 'acc_chest_y', 'acc_chest_z',
            'ecg_1', 'ecg_2',
            'acc_ankle_x', 'acc_ankle_y', 'acc_ankle_z',
            'gyro_ankle_x', 'gyro_ankle_y', 'gyro_ankle_z',
            'mag_ankle_x', 'mag_ankle_y', 'mag_ankle_z',
            'acc_arm_x', 'acc_arm_y', 'acc_arm_z',
            'gyro_arm_x', 'gyro_arm_y', 'gyro_arm_z',
            'mag_arm_x', 'mag_arm_y', 'mag_arm_z',
            'activity_id'
        ],
        "TARGET_ACTIVITIES_MAP": {
            7: 'Frontal elevation of arms',
            8: 'Knees bending',
            10: 'Jogging',
            11: 'Running',
        },
        # Every activity uses the same 15 accelerometer/gyroscope channels.
        "ACT_FEATURE_MAP": {
            7: ['acc_chest_x', 'acc_chest_y', 'acc_chest_z',
                'acc_ankle_x', 'acc_ankle_y', 'acc_ankle_z',
                'gyro_ankle_x', 'gyro_ankle_y', 'gyro_ankle_z',
                'acc_arm_x', 'acc_arm_y', 'acc_arm_z',
                'gyro_arm_x', 'gyro_arm_y', 'gyro_arm_z'],
            8: ['acc_chest_x', 'acc_chest_y', 'acc_chest_z',
                'acc_ankle_x', 'acc_ankle_y', 'acc_ankle_z',
                'gyro_ankle_x', 'gyro_ankle_y', 'gyro_ankle_z',
                'acc_arm_x', 'acc_arm_y', 'acc_arm_z',
                'gyro_arm_x', 'gyro_arm_y', 'gyro_arm_z'],
            10: ['acc_chest_x', 'acc_chest_y', 'acc_chest_z',
                 'acc_ankle_x', 'acc_ankle_y', 'acc_ankle_z',
                 'gyro_ankle_x', 'gyro_ankle_y', 'gyro_ankle_z',
                 'acc_arm_x', 'acc_arm_y', 'acc_arm_z',
                 'gyro_arm_x', 'gyro_arm_y', 'gyro_arm_z'],
            11: ['acc_chest_x', 'acc_chest_y', 'acc_chest_z',
                 'acc_ankle_x', 'acc_ankle_y', 'acc_ankle_z',
                 'gyro_ankle_x', 'gyro_ankle_y', 'gyro_ankle_z',
                 'acc_arm_x', 'acc_arm_y', 'acc_arm_z',
                 'gyro_arm_x', 'gyro_arm_y', 'gyro_arm_z'],
        },

        # Training Params
        "epochs": 100,
        "lr": 5e-4,
        "batch_size": 64,
        "fs": 50,

        # Windowing Params
        "win_sec": 8.0,
        "stride_sec": 4.0,
        "drop_last": True,

        # Model
        "hidden_dim": 128,
        "latent_dim": 16,
        "K_max": 6,

        # Loss Weights
        "lambda_recon": 1.0,
        "lambda_smooth": 0.05,
        "lambda_phase_ent": 0.01,
        "lambda_effk": 0.0075,

        "tau": 1.0,

        # Ground-truth repetition counts per activity id and subject.
        "GT_BY_ACT": {
            7: {
                "subject1": 20, "subject2": 20, "subject3": 20, "subject4": 20, "subject5": 20,
                "subject6": 20, "subject7": 20, "subject8": 19, "subject9": 19, "subject10": 20,
            },
            8: {
                "subject1": 20, "subject2": 21, "subject3": 21, "subject4": 19, "subject5": 20,
                "subject6": 20, "subject7": 21, "subject8": 21, "subject9": 21, "subject10": 21,
            },
            10: {
                "subject1": 157, "subject2": 161, "subject3": 154, "subject4": 154, "subject5": 160,
                "subject6": 156, "subject7": 153, "subject8": 160, "subject9": 166, "subject10": 156,
            },
            11: {
                "subject1": 165, "subject2": 158, "subject3": 174, "subject4": 163, "subject5": 157,
                "subject6": 172, "subject7": 149, "subject8": 166, "subject9": 174, "subject10": 172,
            },
        },

        # (scenario name, activity A id, activity B id); the model is trained on A only.
        "MIX_SCENARIOS": [
            ("Frontal elevation of arms -> Knees bending (ABAB)", 7, 8),
            ("Knees bending -> Frontal elevation of arms (ABAB)", 8, 7),
            ("Jogging -> Running (ABAB)", 10, 11),
            ("Running -> Jogging (ABAB)", 11, 10),
        ],
    }

    set_strict_seed(CONFIG["seed"])
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")

    full_data = load_mhealth_dataset(CONFIG["data_dir"], CONFIG["TARGET_ACTIVITIES_MAP"], CONFIG["COLUMN_NAMES"])
    # Nothing to do when the dataset could not be loaded.
    if not full_data:
        return

    subjects = [f"subject{i}" for i in range(1, 11)]

    VIZ_DIR = "scenario_viz"
    os.makedirs(VIZ_DIR, exist_ok=True)

    # resample lengths (fixed for averaging)  # per-segment
    NW = 60     # window-rate points per segment
    NL = 200    # latent points per segment

    for scen_idx, (scen_name, actA_id, actB_id) in enumerate(CONFIG.get("MIX_SCENARIOS", []), start=1):
        # Training (and the plain LOSO evaluation) only ever sees activity A.
        TRAIN_ACT_ID = actA_id

        print("\n" + "=" * 100)
        print(f"1. 시나리오{scen_idx}: {scen_name}")
        print("=" * 100)

        loso_results = []
        scenario_records = []
        scenario_diffs = []

        # buffers for scenario-avg plots (ABAB => 4 segments)
        buf_win_concat = []    # (4*NW,)
        buf_z_concat = []      # raw z(t) list for global PCA: (Tmix,D) per subject

        # per-subject z splits for later projection & resample
        per_subject_latent = []  # dicts with zA1,zB1,zA2,zB2

        for fold_idx, test_subj in enumerate(subjects):
            # Re-seed at the start of every fold so each fold starts from the same RNG state.
            set_strict_seed(CONFIG["seed"])

            gt_train_map = CONFIG["GT_BY_ACT"][TRAIN_ACT_ID]
            train_labels = [(s, TRAIN_ACT_ID, gt_train_map[s]) for s in subjects if s != test_subj]
            test_labels  = [(test_subj, TRAIN_ACT_ID, gt_train_map[test_subj])]

            train_trials = prepare_trial_list(train_labels, full_data, CONFIG["TARGET_ACTIVITIES_MAP"], CONFIG["ACT_FEATURE_MAP"])
            test_trials  = prepare_trial_list(test_labels,  full_data, CONFIG["TARGET_ACTIVITIES_MAP"], CONFIG["ACT_FEATURE_MAP"])
            # Skip the fold when either side has no usable trial.
            if not test_trials or not train_trials:
                continue

            # Training uses sliding windows; the test trial stays whole.
            train_data = trial_list_to_windows(
                train_trials, fs=CONFIG["fs"],
                win_sec=CONFIG["win_sec"], stride_sec=CONFIG["stride_sec"],
                drop_last=CONFIG["drop_last"]
            )

            # Dedicated generator so the shuffle order is seeded explicitly.
            g = torch.Generator()
            g.manual_seed(CONFIG["seed"])

            train_loader = DataLoader(
                TrialDataset(train_data),
                batch_size=CONFIG["batch_size"],
                shuffle=True,
                collate_fn=collate_variable_length,
                generator=g,
                num_workers=0
            )

            input_ch = train_data[0]['data'].shape[1]
            model = KAutoCountModel(
                input_ch=input_ch,
                hidden_dim=CONFIG["hidden_dim"],
                latent_dim=CONFIG["latent_dim"],
                K_max=CONFIG["K_max"]
            ).to(device)

            optimizer = torch.optim.Adam(model.parameters(), lr=CONFIG["lr"])
            # Halve the learning rate every 30 epochs.
            scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.5)

            for epoch in range(CONFIG["epochs"]):
                _ = train_one_epoch(model, train_loader, optimizer, CONFIG, device)
                scheduler.step()

            model.eval()

            # (A) LOSO on TRAIN_ACT_ID
            item = test_trials[0]
            x_np = item["data"]
            count_gt = float(item["count"])
            count_pred_win, _ = predict_count_by_windowing(
                model, x_np=x_np,
                fs=CONFIG["fs"], win_sec=CONFIG["win_sec"], stride_sec=CONFIG["stride_sec"],
                device=device, tau=CONFIG.get("tau", 1.0),
                batch_size=CONFIG.get("batch_size", 64)
            )
            fold_mae = float(abs(count_pred_win - count_gt))
            loso_results.append(fold_mae)

            # (B) Scenario eval (ABAB)
            mixed_item = build_mixed_abab_trial(test_subj, actA_id, actB_id, CONFIG, full_data)
            # The LOSO result above is kept even when the mixed trial cannot be built.
            if mixed_item is None:
                continue

            x_mix = mixed_item["data"]
            boundaries = mixed_item["boundaries"]  # [b1,b2,b3]
            gt_total = float(mixed_item["count"])
            md = mixed_item["meta_detail"]

            pred_total, _ = predict_count_by_windowing(
                model, x_np=x_mix,
                fs=CONFIG["fs"], win_sec=CONFIG["win_sec"], stride_sec=CONFIG["stride_sec"],
                device=device, tau=CONFIG.get("tau", 1.0),
                batch_size=CONFIG.get("batch_size", 64)
            )

            diff = float(pred_total - gt_total)
            mae = float(abs(diff))
            scenario_diffs.append(diff)

            # full forward for z only (no effK/entropy plots)
            x_tensor = torch.tensor(x_mix, dtype=torch.float32).transpose(0, 1).unsqueeze(0).to(device)
            with torch.no_grad():
                _, z_m, _, _ = model(x_tensor, mask=None, tau=CONFIG.get("tau", 1.0))

            z_td = z_m.squeeze(0).detach().cpu().numpy()  # (Tmix,D)

            b1, b2, b3 = boundaries
            scenario_records.append({
                "subj": test_subj,
                "line": (
                    f"[Scenario-ABAB] {scen_name} | {test_subj} | "
                    f"{md['actA_name']}({md['gtA']:.0f}) -> {md['actB_name']}({md['gtB']:.0f}) -> "
                    f"{md['actA_name']}({md['gtA']:.0f}) -> {md['actB_name']}({md['gtB']:.0f}) | "
                    f"GT_total={gt_total:.0f} | Pred(win)={pred_total:.2f} | Diff={diff:+.2f} | MAE={mae:.2f} | "
                    f"boundaries={boundaries}"
                )
            })

            # ---------------------------
            # buffers for scenario-avg plots (ABAB => 4 segments)
            # ---------------------------
            xA1 = mixed_item["data_A1"]
            xB1 = mixed_item["data_B1"]
            xA2 = mixed_item["data_A2"]
            xB2 = mixed_item["data_B2"]

            # 1) window-level rep_rate curves per segment -> resample -> concat (0~4)
            _, rA1_w = get_window_rate_curve(
                model, xA1, fs=CONFIG["fs"],
                win_sec=CONFIG["win_sec"], stride_sec=CONFIG["stride_sec"],
                device=device, tau=CONFIG.get("tau", 1.0),
                batch_size=CONFIG.get("batch_size", 64)
            )
            _, rB1_w = get_window_rate_curve(
                model, xB1, fs=CONFIG["fs"],
                win_sec=CONFIG["win_sec"], stride_sec=CONFIG["stride_sec"],
                device=device, tau=CONFIG.get("tau", 1.0),
                batch_size=CONFIG.get("batch_size", 64)
            )
            _, rA2_w = get_window_rate_curve(
                model, xA2, fs=CONFIG["fs"],
                win_sec=CONFIG["win_sec"], stride_sec=CONFIG["stride_sec"],
                device=device, tau=CONFIG.get("tau", 1.0),
                batch_size=CONFIG.get("batch_size", 64)
            )
            _, rB2_w = get_window_rate_curve(
                model, xB2, fs=CONFIG["fs"],
                win_sec=CONFIG["win_sec"], stride_sec=CONFIG["stride_sec"],
                device=device, tau=CONFIG.get("tau", 1.0),
                batch_size=CONFIG.get("batch_size", 64)
            )

            # Every segment is resampled to NW points so subjects can be averaged.
            rA1_rs = resample_1d(rA1_w, NW)
            rB1_rs = resample_1d(rB1_w, NW)
            rA2_rs = resample_1d(rA2_w, NW)
            rB2_rs = resample_1d(rB2_w, NW)
            buf_win_concat.append(np.concatenate([rA1_rs, rB1_rs, rA2_rs, rB2_rs], axis=0))  # (4*NW,)

            # 2) latent: store raw z split for global PCA projection later
            zA1 = z_td[:b1, :]
            zB1 = z_td[b1:b2, :]
            zA2 = z_td[b2:b3, :]
            zB2 = z_td[b3:, :]
            buf_z_concat.append(z_td)
            per_subject_latent.append({"zA1": zA1, "zB1": zB1, "zA2": zA2, "zB2": zB2})

        # ---------------------------------------------------------
        # Printing
        # ---------------------------------------------------------
        print("-" * 100)
        if len(loso_results) > 0:
            print("TRain 결과 ->")
            print("-" * 100)
            print(f" >>> Final LOSO Result (Average MAE): {np.mean(loso_results):.3f}")
            print(f" >>> Standard Deviation: {np.std(loso_results):.3f}")
            print("-" * 100)
        else:
            print("TRain 결과 -> (no folds computed)")
            print("-" * 100)

        # Print the per-subject scenario lines in subject order.
        subj2line = {r["subj"]: r["line"] for r in scenario_records}
        for s in subjects:
            if s in subj2line:
                print(subj2line[s])

        print("\n" + "-" * 100)
        print(f"시나리오{scen_idx} 결과")
        print("-" * 100)

        if len(scenario_diffs) == 0:
            print(f"[{scen_name}] No samples (all skipped). Fill GT_BY_ACT for required activities.")
        else:
            diffs = np.array(scenario_diffs, dtype=np.float32)
            mae = float(np.mean(np.abs(diffs)))
            rmse = float(np.sqrt(np.mean(diffs ** 2)))
            print(f"[{scen_name}] N={len(diffs):2d} | MAE={mae:.3f} | RMSE={rmse:.3f} | mean(diff)={diffs.mean():+.3f} | std(diff)={diffs.std():.3f}")
        print("-" * 100)

        # ---------------------------------------------------------
        # ✅ Scenario-level mean±std plots (ABAB) — TWO plots only
        # ---------------------------------------------------------
        n_ok = len(buf_win_concat)
        if n_ok == 0:
            print("[Viz Skip] No scenario samples collected for averaging.")
            continue

        # (1) window-level mean±std on normalized axis [0,4]
        W = np.stack(buf_win_concat, axis=0)  # (N, 4*NW)
        win_mean = W.mean(axis=0)
        win_std  = W.std(axis=0)
        win_t = np.linspace(0.0, 4.0, num=4 * NW, dtype=np.float32)

        # (2) latent: global PCA basis from all z
        Z_all = np.concatenate(buf_z_concat, axis=0)  # (sumT, D)
        mu, V2 = global_pca2(Z_all)

        # project each subject, resample each segment to NL then concat -> (4*NL,2)
        lat_list = []
        for d in per_subject_latent:
            zA1 = d["zA1"]; zB1 = d["zB1"]; zA2 = d["zA2"]; zB2 = d["zB2"]
            pcA1 = (zA1 - mu) @ V2.T
            pcB1 = (zB1 - mu) @ V2.T
            pcA2 = (zA2 - mu) @ V2.T
            pcB2 = (zB2 - mu) @ V2.T
            pcA1_rs = resample_2d(pcA1, NL)
            pcB1_rs = resample_2d(pcB1, NL)
            pcA2_rs = resample_2d(pcA2, NL)
            pcB2_rs = resample_2d(pcB2, NL)
            lat_list.append(np.concatenate([pcA1_rs, pcB1_rs, pcA2_rs, pcB2_rs], axis=0))  # (4*NL,2)

        LAT = np.stack(lat_list, axis=0)  # (N,4*NL,2)
        lat_xy_mean = LAT.mean(axis=0)    # (4*NL,2)
        lat_xy_std  = LAT.std(axis=0)     # (4*NL,2) (unused)

        plot_scenario_mean_std_two_plots(
            save_dir=VIZ_DIR,
            scen_name=scen_name,
            win_t=win_t, win_mean=win_mean, win_std=win_std,
            lat_xy_mean=lat_xy_mean, lat_xy_std=lat_xy_std,
            n_subjects=n_ok,
            dpi=600
        )


# Entry point: run the full scenario-wise LOSO experiment when executed directly.
if __name__ == "__main__":
    main()


Device: cuda
Loading 10 subjects from /content/drive/MyDrive/Colab Notebooks/HAR_data/MHEALTHDATASET...

1. 시나리오1: Frontal elevation of arms -> Knees bending (ABAB)
----------------------------------------------------------------------------------------------------
TRain 결과 ->
----------------------------------------------------------------------------------------------------
 >>> Final LOSO Result (Average MAE): 2.324
 >>> Standard Deviation: 1.910
----------------------------------------------------------------------------------------------------
[Scenario-ABAB] Frontal elevation of arms -> Knees bending (ABAB) | subject1 | Frontal elevation of arms(20) -> Knees bending(20) -> Frontal elevation of arms(20) -> Knees bending(20) | GT_total=80 | Pred(win)=101.02 | Diff=+21.02 | MAE=21.02 | boundaries=[3072, 6451, 9523]
[Scenario-ABAB] Frontal elevation of arms -> Knees bending (ABAB) | subject2 | Frontal elevation of arms(20) -> Knees bending(21) -> Frontal elevation of arms(20) -> Knee

In [3]:
import os
import glob
import random
import re
import numpy as np
import pandas as pd
import seaborn as sns  # kept
import matplotlib.pyplot as plt
from scipy.ndimage import gaussian_filter1d  # kept

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader


# =============================================================================
# Small IO helpers
# =============================================================================
def _ensure_dir(d):
    """Create directory ``d`` together with any missing parents.

    An already existing directory is not an error.
    """
    os.makedirs(name=d, exist_ok=True)


def _safe_fname(s: str) -> str:
    """Turn an arbitrary value into a string usable as a file name.

    Runs of characters that are illegal in common file systems become a
    single underscore, whitespace runs are collapsed and turned into
    underscores, and the result is truncated to at most 180 characters.
    """
    text = re.sub(r"[\\/:*?\"<>|]+", "_", str(s))
    text = re.sub(r"\s+", " ", text).strip().replace(" ", "_")
    return text[:180]


# =============================================================================
# 1) Strict Seeding
# =============================================================================
def set_strict_seed(seed: int):
    """Seed every random number generator used by this file.

    Seeds Python's ``random``, NumPy and PyTorch (CPU and, when available,
    all CUDA devices), exports ``PYTHONHASHSEED`` and switches cuDNN to its
    deterministic, non-benchmarking mode.
    """
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    cuda_present = torch.cuda.is_available()
    if cuda_present:
        for cuda_seeder in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
            cuda_seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


# =============================================================================
# 2) Data Loading
# =============================================================================
def load_mhealth_dataset(data_dir, target_activities_map, column_names):
    """Load every ``mHealth_subject*.log`` file found in ``data_dir``.

    Each log is read as a header-less, tab-separated table, truncated to the
    first ``len(column_names)`` columns and labelled with ``column_names``
    (which must contain ``'activity_id'``).

    Args:
        data_dir: Directory that holds the ``mHealth_subject*.log`` files.
        target_activities_map: Mapping ``activity code -> activity name``;
            only these activities are extracted.
        column_names: Column labels to assign to the loaded table.

    Returns:
        ``{subject_key: {activity_name: DataFrame}}`` where each DataFrame
        holds the rows of that activity without the ``activity_id`` column.
        ``subject_key`` is ``"subject<N>"`` when a number can be parsed from
        the file name, otherwise the file name stem. An empty dict is
        returned when no log file is found. Files that fail to load are
        reported on stdout and skipped.
    """
    file_list = sorted(glob.glob(os.path.join(data_dir, "mHealth_subject*.log")))

    if not file_list:
        print(f"[Warning] No mHealth logs found in {data_dir}")
        return {}

    print(f"Loading {len(file_list)} subjects from {data_dir}...")

    full_dataset = {}
    for file_path in file_list:
        file_name = os.path.basename(file_path)
        subj_part = file_name.split('.')[0]
        try:
            subj_id_num = int(''.join(filter(str.isdigit, subj_part)))
            subj_key = f"subject{subj_id_num}"
        except ValueError:
            # No parsable number in the file name: fall back to the raw stem.
            # (Only ValueError is expected here; a bare ``except`` would also
            # swallow KeyboardInterrupt / SystemExit.)
            subj_key = subj_part

        try:
            df = pd.read_csv(file_path, sep="\t", header=None)
            df = df.iloc[:, :len(column_names)]
            df.columns = column_names

            subj_data = {}
            for label_code, activity_name in target_activities_map.items():
                activity_df = df[df['activity_id'] == label_code].copy()
                if not activity_df.empty:
                    subj_data[activity_name] = activity_df.drop(columns=['activity_id'])

            full_dataset[subj_key] = subj_data
        except Exception as e:
            # Deliberate best-effort: one unreadable file must not abort the
            # whole load, so report it and continue with the next subject.
            print(f"Error loading {file_name}: {e}")

    return full_dataset


def prepare_trial_list(label_config, full_data, target_map, feature_map):
    """Build the list of per-trial samples described by ``label_config``.

    Args:
        label_config: Iterable of ``(subject, activity_id, gt_count)``.
        full_data: ``{subject: {activity_name: DataFrame}}``.
        target_map: ``activity_id -> activity_name``.
        feature_map: ``activity_id -> list of feature columns``.

    Returns:
        A list of dicts with keys ``'data'`` (per-channel z-scored float32
        array, one row per sample), ``'count'`` (float) and ``'meta'``
        (``"<subject>_<activity_name>"``). Entries whose data is missing are
        reported on stdout and skipped.
    """
    trials = []
    for subj, act_id, gt_count in label_config:
        act_name = target_map.get(act_id)
        feats = feature_map.get(act_id)

        if subj not in full_data or act_name not in full_data[subj]:
            print(f"[Skip] Missing data for {subj} - {act_name}")
            continue

        signal = full_data[subj][act_name][feats].values.astype(np.float32)

        # Per-channel z-score (zero mean, unit std); the offset guards
        # against division by zero on constant channels.
        mu = signal.mean(axis=0)
        sigma = signal.std(axis=0) + 1e-6

        trials.append({
            'data': (signal - mu) / sigma,
            'count': float(gt_count),
            'meta': f"{subj}_{act_name}"
        })

    return trials


# =============================================================================
# ✅ Mixed builder helpers
# =============================================================================
def get_single_trial_from_full_data(subj, act_id, gt_count, full_data, target_map, feature_map, normalize=True):
    """Fetch one (subject, activity) trial as a sample dict.

    Args:
        subj: Subject key in ``full_data``.
        act_id: Activity code, looked up in ``target_map`` / ``feature_map``.
        gt_count: Ground-truth repetition count of the trial.
        full_data: ``{subject: {activity_name: DataFrame}}``.
        target_map: ``activity_id -> activity_name``.
        feature_map: ``activity_id -> list of feature columns``.
        normalize: When true, z-score every channel over the trial.

    Returns:
        ``{"data", "count", "meta"}`` or ``None`` when the subject or the
        activity is not present in ``full_data``.
    """
    act_name = target_map.get(act_id)
    feats = feature_map.get(act_id)

    available = subj in full_data and act_name in full_data[subj]
    if not available:
        return None

    signal = full_data[subj][act_name][feats].values.astype(np.float32)

    if normalize:
        mu = signal.mean(axis=0)
        sigma = signal.std(axis=0) + 1e-6
        signal = (signal - mu) / sigma

    return {
        "data": signal,
        "count": float(gt_count),
        "meta": f"{subj}_{act_name}"
    }


# =============================================================================
# ✅ Mixed A-B builder (kept; not used in pattern runs but left intact)
# =============================================================================
def build_mixed_ab_trial(subj, actA_id, actB_id, config, full_data):
    """Concatenate activity A and activity B of one subject into a single trial.

    The two raw (un-normalised) recordings are joined along time and z-scored
    once with the statistics of the joined signal, so both halves share the
    same scale.

    Args:
        subj: Subject key.
        actA_id, actB_id: Activity codes of the first / second segment.
        config: Needs ``"TARGET_ACTIVITIES_MAP"`` and ``"ACT_FEATURE_MAP"``;
            ground-truth counts are read from ``config["GT_BY_ACT"][act][subj]``.
        full_data: ``{subject: {activity_name: DataFrame}}``.

    Returns:
        A dict describing the mixed trial (``data``, total ``count``,
        ``boundary`` sample index, per-segment views, ...), or ``None`` when a
        ground-truth entry or one of the recordings is missing.
    """
    gt_map = config.get("GT_BY_ACT", {})
    act_ids = (actA_id, actB_id)
    if any(act_id not in gt_map for act_id in act_ids):
        return None
    if any(subj not in gt_map[act_id] for act_id in act_ids):
        return None

    gtA = float(gt_map[actA_id][subj])
    gtB = float(gt_map[actB_id][subj])
    gt_total = gtA + gtB

    parts = [
        get_single_trial_from_full_data(
            subj, act_id, gt, full_data,
            config["TARGET_ACTIVITIES_MAP"], config["ACT_FEATURE_MAP"],
            normalize=False
        )
        for act_id, gt in ((actA_id, gtA), (actB_id, gtB))
    ]
    if any(part is None for part in parts):
        return None

    xA_raw, xB_raw = (part["data"] for part in parts)
    boundary = int(xA_raw.shape[0])

    x_mix_raw = np.concatenate([xA_raw, xB_raw], axis=0).astype(np.float32)

    # One shared z-score computed on the concatenated signal.
    mean = x_mix_raw.mean(axis=0)
    std = x_mix_raw.std(axis=0) + 1e-6

    def _standardize(arr):
        return (arr - mean) / std

    x_mix = _standardize(x_mix_raw)
    xA = _standardize(xA_raw)
    xB = _standardize(xB_raw)

    name_of = config["TARGET_ACTIVITIES_MAP"]
    actA_name = name_of.get(actA_id, str(actA_id))
    actB_name = name_of.get(actB_id, str(actB_id))

    meta_detail = {
        "subj": subj,
        "actA_id": actA_id, "actB_id": actB_id,
        "actA_name": actA_name, "actB_name": actB_name,
        "gtA": gtA, "gtB": gtB,
        "gt_total": gt_total,
    }

    return {
        "data": x_mix,
        "count": float(gt_total),
        "meta": f"{subj}__{actA_name}__TO__{actB_name}",
        "boundary": boundary,
        "meta_detail": meta_detail,
        "data_A": xA,
        "data_B": xB,
        "T_A": int(xA.shape[0]),
        "T_B": int(xB.shape[0]),
    }


# =============================================================================
# ✅ Mixed builder with arbitrary pattern (e.g., "ABAB", "AABB", "ABBA", ... / now AAAAB used)
# =============================================================================
def build_mixed_pattern_trial(subj, actA_id, actB_id, pattern, config, full_data):
    """Build a mixed trial by chaining activities A and B following ``pattern``.

    ``pattern`` is a string such as ``"ABAB"`` or ``"AAAAB"``: every letter
    appends the full raw recording of the corresponding activity. The chained
    signal is z-scored once using its own per-channel statistics.

    Args:
        subj: Subject key.
        actA_id, actB_id: Activity codes mapped to the letters ``A`` / ``B``.
        pattern: Sequence of ``A`` / ``B`` letters (case-insensitive,
            surrounding whitespace ignored), length >= 2.
        config: Needs ``"TARGET_ACTIVITIES_MAP"`` and ``"ACT_FEATURE_MAP"``;
            ground-truth counts are read from ``config["GT_BY_ACT"][act][subj]``.
        full_data: ``{subject: {activity_name: DataFrame}}``.

    Returns:
        A dict describing the mixed trial, or ``None`` when a ground-truth
        entry or one of the recordings is missing.

    Raises:
        ValueError: If ``pattern`` is shorter than 2 or has letters other
            than ``A`` / ``B``.
    """
    gt_map = config.get("GT_BY_ACT", {})
    act_ids = (actA_id, actB_id)
    if any(act_id not in gt_map for act_id in act_ids):
        return None
    if any(subj not in gt_map[act_id] for act_id in act_ids):
        return None

    pattern = pattern.strip().upper()
    only_ab = set(pattern).issubset({"A", "B"})
    if len(pattern) < 2 or not only_ab:
        raise ValueError(f"pattern must be a string of A/B with len>=2, got: {pattern}")

    gtA = float(gt_map[actA_id][subj])
    gtB = float(gt_map[actB_id][subj])
    gt_total = (pattern.count("A") * gtA) + (pattern.count("B") * gtB)

    parts = [
        get_single_trial_from_full_data(
            subj, act_id, gt, full_data,
            config["TARGET_ACTIVITIES_MAP"], config["ACT_FEATURE_MAP"],
            normalize=False
        )
        for act_id, gt in ((actA_id, gtA), (actB_id, gtB))
    ]
    if any(part is None for part in parts):
        return None

    xA_raw, xB_raw = (part["data"] for part in parts)
    TA = int(xA_raw.shape[0])
    TB = int(xB_raw.shape[0])

    # Raw segments and their lengths, in pattern order.
    seg_raw = [xA_raw if ch == "A" else xB_raw for ch in pattern]
    seg_lens = [TA if ch == "A" else TB for ch in pattern]

    x_mix_raw = np.concatenate(seg_raw, axis=0).astype(np.float32)

    # Single z-score over the whole mixed signal.
    mean = x_mix_raw.mean(axis=0)
    std = x_mix_raw.std(axis=0) + 1e-6
    x_mix = (x_mix_raw - mean) / std

    # Segment boundaries = running sum of segment lengths, last one excluded.
    boundaries = [int(sum(seg_lens[:i])) for i in range(1, len(seg_lens))]

    # Per-segment views normalised with the same mean/std as the mix.
    segments = [(raw - mean) / std for raw in seg_raw]

    name_of = config["TARGET_ACTIVITIES_MAP"]
    actA_name = name_of.get(actA_id, str(actA_id))
    actB_name = name_of.get(actB_id, str(actB_id))

    # Human-readable identifier listing the activity sequence.
    name_seq = [(actA_name if ch == "A" else actB_name) for ch in pattern]
    meta = f"{subj}__" + "__TO__".join(name_seq)

    meta_detail = {
        "subj": subj,
        "actA_id": actA_id, "actB_id": actB_id,
        "actA_name": actA_name, "actB_name": actB_name,
        "gtA": gtA, "gtB": gtB,
        "gt_total": gt_total,
        "pattern": pattern,
    }

    return {
        "data": x_mix,
        "count": float(gt_total),
        "meta": meta,
        "pattern": pattern,
        "boundaries": boundaries,
        "meta_detail": meta_detail,
        "segments": segments,     # list of np arrays, len = len(pattern)
        "seg_lens": seg_lens,     # list of ints
        "T_A": TA,
        "T_B": TB,
    }


# =============================================================================
# 2.5) ✅ Windowing
# =============================================================================
def trial_list_to_windows(trial_list, fs, win_sec=8.0, stride_sec=4.0, drop_last=True):
    """Cut every trial into sliding windows with window-level count labels.

    The label of a window is the trial-average repetition rate
    (``count / duration``) multiplied by the window duration.

    Args:
        trial_list: List of dicts with ``"data"`` (time on axis 0),
            ``"count"`` and ``"meta"``.
        fs: Sampling rate in Hz.
        win_sec: Window length in seconds.
        stride_sec: Hop between window starts in seconds.
        drop_last: When false, the shorter leftover that begins one stride
            after the last full window is emitted as an extra window.

    Returns:
        A flat list of window dicts (``data``, ``count``, ``meta``,
        ``parent_meta``, ``parent_T``, ``win_start``, ``win_end``). A trial
        shorter than one window is emitted whole as a single window.
    """
    win_len = int(round(win_sec * fs))
    stride = int(round(stride_sec * fs))
    assert win_len > 0 and stride > 0

    fs_f = float(fs)
    windows = []
    for item in trial_list:
        x = item["data"]
        T = x.shape[0]
        meta = item["meta"]
        rate_trial = float(item["count"]) / max(T / fs_f, 1e-6)

        is_short = T < win_len
        if is_short:
            spans = [(0, T)]
        else:
            starts = range(0, T - win_len + 1, stride)
            spans = [(st, st + win_len) for st in starts]
            if not drop_last:
                tail_start = starts[-1] + stride
                if tail_start < T:
                    spans.append((tail_start, T))

        for st, ed in spans:
            windows.append({
                # a short trial is passed through as-is, not as a slice
                "data": x if is_short else x[st:ed],
                "count": rate_trial * ((ed - st) / fs_f),
                "meta": f"{meta}__win[{st}:{ed}]",
                "parent_meta": meta,
                "parent_T": T,
                "win_start": st,
                "win_end": ed,
            })

    return windows


def predict_count_by_windowing(model, x_np, fs, win_sec, stride_sec, device, tau=1.0, batch_size=64):
    """Predict the total repetition count of a trial from window-level rates.

    The trial is cut into sliding windows, the model predicts one rate per
    window, and the total count is the mean window rate multiplied by the
    full trial duration. A trial not longer than one window is evaluated in
    a single forward pass.

    Args:
        model: Callable ``model(x, mask=None, tau=...)`` whose first return
            value is the per-sample rate; ``x`` is passed as
            ``(batch, channels, time)``.
        x_np: Array with time on axis 0 and channels on axis 1.
        fs: Sampling rate in Hz.
        win_sec: Window length in seconds.
        stride_sec: Hop between window starts in seconds.
        device: Torch device the inputs are moved to.
        tau: Value forwarded to the model as ``tau``.
        batch_size: Number of windows per forward pass.

    Returns:
        ``(pred_count, rates)``: the predicted count as a float and the
        per-window rates as a NumPy array.

    Note:
        The model is left in evaluation mode.
    """
    win_len = int(round(win_sec * fs))
    stride = int(round(stride_sec * fs))
    T = x_np.shape[0]
    total_dur = T / float(fs)

    # Inference must run in eval mode on BOTH paths; previously only the
    # windowed path switched modes, so short trials were scored in whatever
    # mode the caller had left the model in.
    model.eval()

    if T <= win_len:
        x_tensor = torch.tensor(x_np, dtype=torch.float32).transpose(0, 1).unsqueeze(0).to(device)
        with torch.no_grad():
            rate_hat, _, _, _ = model(x_tensor, mask=None, tau=tau)
        rate = float(rate_hat.item())
        return float(rate * total_dur), np.array([rate], dtype=np.float32)

    starts = list(range(0, T - win_len + 1, stride))
    windows = np.stack([x_np[st:st + win_len] for st in starts], axis=0)

    # (n_windows, win_len, C) -> (n_windows, C, win_len)
    xw = torch.tensor(windows, dtype=torch.float32).permute(0, 2, 1).to(device)

    rates = []
    with torch.no_grad():
        for i in range(0, xw.shape[0], batch_size):
            xb = xw[i:i + batch_size]
            r_hat, _, _, _ = model(xb, mask=None, tau=tau)
            rates.append(r_hat.detach().cpu().numpy())

    rates = np.concatenate(rates, axis=0)
    rate_mean = float(rates.mean())
    pred_count = rate_mean * total_dur
    return float(pred_count), rates


# =============================================================================
# ✅ Helpers for scenario-avg plots
# =============================================================================
def get_window_rate_curve(model, x_np, fs, win_sec, stride_sec, device, tau=1.0, batch_size=64):
    """Return the predicted rate of every sliding window and its centre time.

    Args:
        model: Callable ``model(x, mask=None, tau=...)`` whose first return
            value is the per-sample rate.
        x_np: Array with time on axis 0 and channels on axis 1.
        fs: Sampling rate in Hz.
        win_sec: Window length in seconds.
        stride_sec: Hop between window starts in seconds.
        device: Torch device the inputs are moved to.
        tau: Value forwarded to the model as ``tau``.
        batch_size: Number of windows per forward pass.

    Returns:
        ``(centers, rates)`` as float32 arrays: window centre times in
        seconds and the matching predicted rates. A trial not longer than
        one window yields a single point at the middle of the trial.
    """
    win_len = int(round(win_sec * fs))
    stride = int(round(stride_sec * fs))
    T = x_np.shape[0]
    fs_f = float(fs)

    if T <= win_len:
        # Single-window case: delegate and keep only the rate array.
        _, short_rates = predict_count_by_windowing(
            model, x_np, fs, win_sec, stride_sec, device, tau=tau, batch_size=batch_size
        )
        mid_time = np.array([0.5 * (T / fs_f)], dtype=np.float32)
        return mid_time, short_rates.astype(np.float32)

    starts = np.arange(0, T - win_len + 1, stride)
    stacked = np.stack([x_np[st:st + win_len] for st in starts], axis=0)
    xw = torch.tensor(stacked, dtype=torch.float32).permute(0, 2, 1).to(device)

    model.eval()
    chunks = []
    with torch.no_grad():
        for lo in range(0, xw.shape[0], batch_size):
            r_hat, _, _, _ = model(xw[lo:lo + batch_size], mask=None, tau=tau)
            chunks.append(r_hat.detach().cpu().numpy())
    rates = np.concatenate(chunks, axis=0).astype(np.float32)

    centers = ((starts + 0.5 * win_len) / fs_f).astype(np.float32)
    return centers, rates


def resample_1d(y, new_len):
    """Linearly resample a 1-D sequence to ``new_len`` points.

    Input and output are both treated as uniformly spaced over ``[0, 1]``.
    An empty input gives zeros and a single value is broadcast. The result
    is always float32.
    """
    src = np.asarray(y, dtype=np.float32)
    n_src = src.size

    if n_src <= 1:
        fill = float(src[0]) if n_src == 1 else 0.0
        return np.full((new_len,), fill, dtype=np.float32)

    grid_src = np.linspace(0.0, 1.0, num=n_src, dtype=np.float32)
    grid_dst = np.linspace(0.0, 1.0, num=new_len, dtype=np.float32)
    resampled = np.interp(grid_dst, grid_src, src)
    return resampled.astype(np.float32)


def resample_2d(Y, new_len):
    """Linearly resample the rows of a 2-D array to ``new_len`` rows.

    Every column is interpolated independently on a uniform ``[0, 1]`` grid.
    An input without rows gives zeros and a single row is repeated. The
    result is always float32 with shape ``(new_len, Y.shape[1])``.
    """
    src = np.asarray(Y, dtype=np.float32)
    n_rows = src.shape[0]

    if n_rows == 0:
        return np.zeros((new_len, src.shape[1]), dtype=np.float32)
    if n_rows == 1:
        return np.repeat(src, repeats=new_len, axis=0).astype(np.float32)

    grid_src = np.linspace(0.0, 1.0, num=n_rows, dtype=np.float32)
    grid_dst = np.linspace(0.0, 1.0, num=new_len, dtype=np.float32)
    columns = [np.interp(grid_dst, grid_src, src[:, col]) for col in range(src.shape[1])]
    return np.stack(columns, axis=1).astype(np.float32)


def global_pca2(Z_all):
    """Fit a 2-component PCA via SVD of the centred data.

    Z_all: (N,D)
    returns: (mean, V2) with mean of shape (D,) and V2 of shape (2,D),
             both float32, such that proj = (Z - mean) @ V2.T
    """
    data = np.asarray(Z_all, dtype=np.float32)
    center = data.mean(axis=0, keepdims=True)
    right_vecs = np.linalg.svd(data - center, full_matrices=False)[2]
    top2 = right_vecs[:2].astype(np.float32)  # (2,D)
    return center.squeeze(0).astype(np.float32), top2


def compute_phase_entropy_mean(phase_p_np, eps=1e-8):
    """Return the mean over rows of the entropy ``-sum(p * log(p + eps))``.

    ``phase_p_np`` is a 2-D array whose rows are probability vectors; the
    entropy is taken along axis 1 and averaged into a Python float.
    """
    probs = np.asarray(phase_p_np, dtype=np.float32)
    row_entropy = -np.sum(probs * np.log(probs + eps), axis=1)
    return float(np.mean(row_entropy))


# =============================================================================
# ✅ NEW: Scenario plot (TWO PLOTS) — keeps the "two_plots" style as-is
# =============================================================================
def plot_scenario_mean_std_two_plots(
    save_dir,
    scen_name,
    win_t, win_mean, win_std,
    lat_xy_mean, lat_xy_std,   # (std is not used right now; kept to preserve the signature)
    n_subjects,
    pattern_len,
    dpi=600,
):
    """Save two PNG figures summarising one scenario.

    Figure 1 (``<scenario>__1_win_rep_rate.png``): window-level rep_rate
    mean with a ±1 std band, shaded segments and dashed segment boundaries.
    Figure 2 (``<scenario>__2_latent_pca2.png``): scatter of the mean 2-D
    latent trajectory, coloured by normalised time, inside square axes.

    Args:
        save_dir: Output directory; ``"scenario_viz"`` is used when None.
            The directory is created if needed.
        scen_name: Scenario name; sanitised and used as the file-name prefix.
        win_t: X positions of the rate curve. Segments are shaded over
            ``[i, i + 1]`` and boundaries drawn at integers, so the values
            are presumably normalised to ``[0, pattern_len]`` — confirm
            against the caller.
        win_mean, win_std: Mean / std of the rate at each ``win_t``; they are
            added and subtracted, so array-like (e.g. NumPy) inputs are
            expected.
        lat_xy_mean: Array indexed as ``[:, 0]`` (PC1) and ``[:, 1]`` (PC2).
        lat_xy_std: Not referenced in the body.
        n_subjects: Not referenced in the body.
        pattern_len: Number of segments; sets shading, boundaries and the
            time axis range.
        dpi: Resolution of the saved figures.
    """
    _ensure_dir(save_dir if save_dir is not None else "scenario_viz")
    out_dir = save_dir if save_dir is not None else "scenario_viz"
    scen_tag = _safe_fname(scen_name)

    # -------------------------
    # (1) Window-level rep_rate mean±std (+ boundary + segment shading)
    # - no A/B letters on the plot
    # - large figsize / fonts
    # -------------------------
    fig1 = plt.figure(figsize=(26.0, 9.2))
    ax = fig1.gca()

    # segment shading (alternating alpha, subtle)
    for i in range(pattern_len):
        a0 = float(i)
        a1 = float(i + 1)
        alpha = 0.06 if (i % 2 == 0) else 0.03
        ax.axvspan(a0, a1, alpha=alpha, color="k", lw=0)

    ax.plot(win_t, win_mean, linewidth=6.4, label="Window rep_rate (mean)")
    ax.fill_between(win_t, win_mean - win_std, win_mean + win_std, alpha=0.22, label="±1 std")

    # boundaries at integers; only the first line gets a legend label so the
    # legend shows a single "Boundary" entry
    for xline in range(1, pattern_len):
        ax.axvline(float(xline), linestyle="--", linewidth=6.0, alpha=0.95, label=("Boundary" if xline == 1 else None))

    ax.set_xlabel(f"Normalized time (0→{pattern_len})", fontsize=46, labelpad=10)
    ax.set_ylabel("rep_rate (reps/s)", fontsize=46, labelpad=10)
    ax.tick_params(axis="both", labelsize=37)

    ax.legend(fontsize=38, frameon=True, loc="upper left")
    fig1.tight_layout()
    fig1.savefig(os.path.join(out_dir, f"{scen_tag}__1_win_rep_rate.png"), dpi=dpi, bbox_inches="tight")
    plt.close(fig1)

    # -------------------------
    # (2) Latent trajectory PCA2
    # - no std circles
    # - no trajectory line (scatter only)
    # - square bounds kept
    # -------------------------
    fig2 = plt.figure(figsize=(8.8, 8.8))
    ax = fig2.gca()

    T = lat_xy_mean.shape[0]
    # colour value of each point: evenly spaced over [0, pattern_len]
    t_lat = np.linspace(0.0, float(pattern_len), num=T, dtype=np.float32)
    sc = ax.scatter(lat_xy_mean[:, 0], lat_xy_mean[:, 1], c=t_lat, s=22, alpha=0.95)

    ax.set_xlabel("PC1", fontsize=24, labelpad=10)
    ax.set_ylabel("PC2", fontsize=24, labelpad=10)
    ax.tick_params(axis="both", labelsize=18)
    ax.set_aspect("equal", adjustable="box")

    # Square limits centred on the data: half-width is the larger of the two
    # half-ranges plus a 12% margin (the tiny offset avoids a zero-size box).
    x = lat_xy_mean[:, 0]
    y = lat_xy_mean[:, 1]
    xmin, xmax = float(np.min(x)), float(np.max(x))
    ymin, ymax = float(np.min(y)), float(np.max(y))
    cx = 0.5 * (xmin + xmax)
    cy = 0.5 * (ymin + ymax)
    rx = 0.5 * (xmax - xmin)
    ry = 0.5 * (ymax - ymin)
    r = max(rx, ry) * 1.12 + 1e-6
    ax.set_xlim(cx - r, cx + r)
    ax.set_ylim(cy - r, cy + r)

    cbar = fig2.colorbar(sc, ax=ax, fraction=0.046, pad=0.04)
    cbar.set_label(f"normalized time (0→{pattern_len})", fontsize=20, labelpad=10)
    cbar.ax.tick_params(labelsize=16)

    fig2.tight_layout()
    fig2.savefig(os.path.join(out_dir, f"{scen_tag}__2_latent_pca2.png"), dpi=dpi, bbox_inches="tight")
    plt.close(fig2)


# =============================================================================
# 2.8) Dataset / Collate
# =============================================================================
class TrialDataset(Dataset):
    """Dataset wrapper around a list of trial / window dicts.

    Every entry must provide ``'data'`` (time on axis 0, channels on axis 1),
    ``'count'`` and ``'meta'``.
    """

    def __init__(self, trial_list):
        self.trials = trial_list

    def __len__(self):
        return len(self.trials)

    def __getitem__(self, idx):
        """Return ``(signal, count, meta)`` with ``signal`` laid out as (C, T)."""
        trial = self.trials[idx]
        signal = torch.tensor(trial['data'], dtype=torch.float32)
        target = torch.tensor(trial['count'], dtype=torch.float32)
        return signal.transpose(0, 1), target, trial['meta']


def collate_variable_length(batch):
    """Collate ``(signal, count, meta)`` samples of different lengths.

    Signals of shape (C, T_i) are right-padded with zeros to the longest
    length in the batch. The returned dict holds ``"data"`` (B, C, T_max),
    ``"mask"`` (B, T_max; 1 on real samples, 0 on padding), ``"count"`` (B,),
    ``"length"`` (B,; float32) and ``"meta"`` (list).
    """
    seq_lens = [sample[0].shape[1] for sample in batch]
    longest = max(seq_lens)
    n_ch = batch[0][0].shape[0]

    padded, masks = [], []
    for (signal, _, _), T in zip(batch, seq_lens):
        gap = longest - T
        valid = torch.ones(T)
        if gap > 0:
            signal = torch.cat([signal, torch.zeros(n_ch, gap)], dim=1)
            valid = torch.cat([valid, torch.zeros(gap)], dim=0)
        padded.append(signal)
        masks.append(valid)

    return {
        "data": torch.stack(padded),
        "mask": torch.stack(masks),
        "count": torch.stack([sample[1] for sample in batch]),
        "length": torch.tensor(seq_lens, dtype=torch.float32),
        "meta": [sample[2] for sample in batch]
    }


# =============================================================================
# 3) Model
# =============================================================================
class ManifoldEncoder(nn.Module):
    """1-D convolutional encoder producing a per-time-step latent.

    Two width-5 convolutions (padding 2, so the time length is preserved)
    with ReLU are followed by a width-1 projection to ``latent_dim``.
    """

    def __init__(self, input_ch, hidden_dim=128, latent_dim=16):
        super().__init__()
        conv_in = nn.Conv1d(input_ch, hidden_dim, kernel_size=5, padding=2)
        conv_mid = nn.Conv1d(hidden_dim, hidden_dim, kernel_size=5, padding=2)
        conv_out = nn.Conv1d(hidden_dim, latent_dim, kernel_size=1)
        self.net = nn.Sequential(conv_in, nn.ReLU(), conv_mid, nn.ReLU(), conv_out)

    def forward(self, x):
        """Map ``x`` of shape (B, C, T) to latents of shape (B, T, D)."""
        return self.net(x).transpose(1, 2)


class ManifoldDecoder(nn.Module):
    """1-D convolutional decoder mapping latents back to signal channels.

    Mirrors the encoder layout: two width-5 convolutions (padding 2) with
    ReLU, then a width-1 projection to ``out_ch``.
    """

    def __init__(self, latent_dim, hidden_dim, out_ch):
        super().__init__()
        conv_in = nn.Conv1d(latent_dim, hidden_dim, kernel_size=5, padding=2)
        conv_mid = nn.Conv1d(hidden_dim, hidden_dim, kernel_size=5, padding=2)
        conv_out = nn.Conv1d(hidden_dim, out_ch, kernel_size=1)
        self.net = nn.Sequential(conv_in, nn.ReLU(), conv_mid, nn.ReLU(), conv_out)

    def forward(self, z):
        """Map latents ``z`` of shape (B, T, D) to a (B, out_ch, T) signal."""
        return self.net(z.transpose(1, 2))


class MultiRateHead(nn.Module):
    """MLP head producing an amplitude and a distribution over K_max phases.

    The last linear layer has ``1 + K_max`` outputs: output 0 is the raw
    amplitude, the remaining ``K_max`` are phase logits.
    """

    def __init__(self, latent_dim=16, hidden=64, K_max=6):
        super().__init__()
        self.K_max = K_max
        fc_in = nn.Linear(latent_dim, hidden)
        fc_out = nn.Linear(hidden, 1 + K_max)
        self.net = nn.Sequential(fc_in, nn.ReLU(), fc_out)

    def forward(self, z, tau=1.0):
        """Return ``(amp, phase, phase_logits)`` for latents ``z``.

        ``amp`` is the softplus of output 0 (non-negative); ``phase`` is the
        softmax of the logits divided by the temperature ``tau``.
        """
        raw = self.net(z)
        phase_logits = raw[..., 1:]
        amp = F.softplus(raw[..., 0])
        phase = F.softmax(phase_logits / tau, dim=-1)
        return amp, phase, phase_logits


class KAutoCountModel(nn.Module):
    """Encoder / decoder / rate-head model that predicts a repetition rate.

    The encoder turns the input into per-time-step latents, the decoder
    reconstructs the input from them, and the rate head yields a per-step
    amplitude plus a phase distribution. The amplitude divided by the
    effective number of phases in use (``k_hat``) is the per-step
    repetition rate; its time average is the model's main output.
    """

    def __init__(self, input_ch, hidden_dim=128, latent_dim=16, K_max=6, k_hidden=64):
        """Build the sub-modules and initialise their weights.

        NOTE(review): ``k_hidden`` is accepted but never referenced; the rate
        head is built with ``hidden=hidden_dim``. Confirm whether that is
        intended before changing it, since it alters the architecture.
        """
        super().__init__()
        self.encoder = ManifoldEncoder(input_ch, hidden_dim, latent_dim)
        self.decoder = ManifoldDecoder(latent_dim, hidden_dim, input_ch)
        self.rate_head = MultiRateHead(latent_dim, hidden=hidden_dim, K_max=K_max)
        self._init_weights()

    def _init_weights(self):
        """Kaiming-initialise all Conv1d/Linear weights and zero their biases."""
        for m in self.modules():
            if isinstance(m, (nn.Conv1d, nn.Linear)):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

        # Final rate-head layer: all biases zero except output 0 (the raw
        # amplitude fed to softplus), which starts at -2.0 — presumably so
        # the initial predicted rate is small.
        with torch.no_grad():
            b = self.rate_head.net[-1].bias
            b.zero_()
            b[0].fill_(-2.0)

    @staticmethod
    def _masked_mean_time(x, mask=None, eps=1e-6):
        """Average ``x`` over dim 1, counting only positions where mask is 1.

        Supports 2-D ``x`` (mask applied directly) and 3-D ``x`` (mask gets a
        trailing singleton dim). Without a mask this is a plain mean over
        dim 1.

        Raises:
            ValueError: If a mask is given and ``x`` is neither 2-D nor 3-D.
        """
        if mask is None:
            return x.mean(dim=1)
        if x.dim() == 2:
            m = mask.to(dtype=x.dtype, device=x.device)
            return (x * m).sum(dim=1) / (m.sum(dim=1) + eps)
        elif x.dim() == 3:
            m = mask.to(dtype=x.dtype, device=x.device).unsqueeze(-1)
            return (x * m).sum(dim=1) / (m.sum(dim=1) + eps)
        else:
            raise ValueError(f"Unsupported dim for masked mean: {x.dim()}")

    def forward(self, x, mask=None, tau=1.0):
        """Run the model.

        Args:
            x: Input signal; the encoder treats it as (B, C, T).
            mask: Optional validity mask over time (1 = real sample,
                0 = padding); assumed to be (B, T) — confirm against caller.
            tau: Softmax temperature forwarded to the rate head.

        Returns:
            ``(avg_rep_rate, z, x_hat, aux)``: the time-averaged repetition
            rate per sample, the latents, the reconstruction, and a dict
            with the intermediate tensors (``rates_k_t``, ``phase_p``,
            ``phase_logits``, ``micro_rate_t``, ``rep_rate_t``, ``k_hat``).
        """
        z = self.encoder(x)
        x_hat = self.decoder(z)

        amp_t, phase_p, phase_logits = self.rate_head(z, tau=tau)
        # amplitude distributed over the phases
        rates_k_t = amp_t.unsqueeze(-1) * phase_p

        micro_rate_t = amp_t

        # Effective number of phases in use: inverse of the sum of squared
        # time-averaged phase probabilities.
        p_bar = self._masked_mean_time(phase_p, mask)
        k_hat = 1.0 / (p_bar.pow(2).sum(dim=1) + 1e-6)

        # Per-step repetition rate = amplitude divided by k_hat.
        rep_rate_t = micro_rate_t / (k_hat.unsqueeze(1) + 1e-6)
        if mask is not None:
            # zero the rate on padded steps
            rep_rate_t = rep_rate_t * mask

        if mask is None:
            avg_rep_rate = rep_rate_t.mean(dim=1)
        else:
            # mask is applied again here; rep_rate_t was already masked above
            avg_rep_rate = (rep_rate_t * mask).sum(dim=1) / (mask.sum(dim=1) + 1e-6)

        aux = {
            "rates_k_t": rates_k_t,
            "phase_p": phase_p,
            "phase_logits": phase_logits,
            "micro_rate_t": micro_rate_t,
            "rep_rate_t": rep_rate_t,
            "k_hat": k_hat,
        }
        return avg_rep_rate, z, x_hat, aux


# =============================================================================
# 4) Loss utils
# =============================================================================
def masked_recon_mse(x_hat, x, mask, eps=1e-6):
    """Mean squared reconstruction error over valid time steps only.

    ``x_hat`` and ``x`` are (B, C, T); ``mask`` is (B, T) with 1 on valid
    steps. The summed squared error is divided by the number of valid
    steps times the number of channels (plus ``eps``).
    """
    valid = mask.to(dtype=x.dtype, device=x.device)
    sq_err = (x_hat - x).pow(2) * valid.unsqueeze(1)
    n_terms = valid.sum() * x.shape[1] + eps
    return sq_err.sum() / n_terms


def temporal_smoothness(v, mask=None, eps=1e-6):
    """Mean absolute difference between consecutive steps along dim 1.

    With a mask, only step pairs whose both ends are valid contribute.
    """
    step = (v[:, 1:] - v[:, :-1]).abs()
    if mask is None:
        return step.mean()

    pair_valid = (mask[:, 1:] * mask[:, :-1]).to(dtype=step.dtype, device=step.device)
    return (step * pair_valid).sum() / (pair_valid.sum() + eps)


def phase_entropy_loss(phase_p, mask=None, eps=1e-8):
    """Average entropy of the phase distribution (last dim of ``phase_p``).

    With a mask, the entropy is averaged over valid positions only.
    """
    entropy = -(phase_p * torch.log(phase_p + eps)).sum(dim=-1)
    if mask is not None:
        return (entropy * mask).sum() / (mask.sum() + eps)
    return entropy.mean()


def effK_usage_loss(phase_p, mask=None, eps=1e-6):
    """Effective number of phases in use, from time-averaged probabilities.

    Per sample: ``effK = 1 / (sum_k p_bar_k^2 + eps)`` where ``p_bar`` is the
    (optionally masked) mean of ``phase_p`` over dim 1.

    Returns:
        ``(effK.mean(), effK.detach())`` — the batch-mean loss term and the
        per-sample values detached from the graph.
    """
    if mask is not None:
        weights = mask.to(dtype=phase_p.dtype, device=phase_p.device).unsqueeze(-1)
        usage = (phase_p * weights).sum(dim=1) / (weights.sum(dim=1) + eps)
    else:
        usage = phase_p.mean(dim=1)

    eff_k = 1.0 / (usage.pow(2).sum(dim=1) + eps)
    return eff_k.mean(), eff_k.detach()


# =============================================================================
# 5) Train
# =============================================================================
def train_one_epoch(model, loader, optimizer, config, device):
    """Train ``model`` for one pass over ``loader`` and return mean statistics.

    The total loss is the rate MSE plus weighted reconstruction, temporal
    smoothness, phase-entropy and effective-K terms.

    Args:
        model: Called as ``model(x, mask, tau=...)``; must return
            ``(rate_hat, z, x_hat, aux)`` with ``aux["rep_rate_t"]`` and
            ``aux["phase_p"]``.
        loader: Iterable of batches with keys ``"data"``, ``"mask"``,
            ``"count"`` and ``"length"``.
        optimizer: Optimizer updating the model parameters.
        config: Needs ``"fs"``; optional ``"tau"``, ``"lambda_recon"``,
            ``"lambda_smooth"``, ``"lambda_phase_ent"``, ``"lambda_effk"``.
        device: Device the batch tensors are moved to.

    Returns:
        Dict of per-batch averages for ``loss``, ``loss_rate``,
        ``loss_recon``, ``loss_smooth``, ``loss_phase_ent``, ``loss_effk``
        and ``mae_count``. All values are 0.0 when the loader yields no
        batch (previously this raised ZeroDivisionError).
    """
    model.train()
    stats = {k: 0.0 for k in [
        'loss', 'loss_rate', 'loss_recon', 'loss_smooth', 'loss_phase_ent', 'loss_effk',
        'mae_count'
    ]}

    fs = config["fs"]
    tau = config.get("tau", 1.0)

    lam_recon = config.get("lambda_recon", 1.0)
    lam_smooth = config.get("lambda_smooth", 0.05)
    lam_phase_ent = config.get("lambda_phase_ent", 0.01)
    lam_effk = config.get("lambda_effk", 0.005)

    # Count the batches actually processed instead of relying on len(loader):
    # this is safe for an empty loader and for iterables without __len__.
    n_batches = 0
    for batch in loader:
        x = batch["data"].to(device)
        mask = batch["mask"].to(device)
        y_count = batch["count"].to(device)
        length = batch["length"].to(device)

        # The regression target is a rate (count per second), not a count.
        duration = torch.clamp(length / fs, min=1e-6)
        y_rate = y_count / duration

        optimizer.zero_grad()
        rate_hat, z, x_hat, aux = model(x, mask, tau=tau)

        loss_rate = F.mse_loss(rate_hat, y_rate)
        loss_recon = masked_recon_mse(x_hat, x, mask)
        loss_smooth = temporal_smoothness(aux["rep_rate_t"], mask)
        loss_phase_ent = phase_entropy_loss(aux["phase_p"], mask)
        loss_effk, _ = effK_usage_loss(aux["phase_p"], mask)

        loss = (loss_rate
                + lam_recon * loss_recon
                + lam_smooth * loss_smooth
                + lam_phase_ent * loss_phase_ent
                + lam_effk * loss_effk)

        loss.backward()
        optimizer.step()

        # Convert the predicted rate back to a count for the MAE statistic.
        count_hat = rate_hat * duration
        stats['loss'] += loss.item()
        stats['loss_rate'] += loss_rate.item()
        stats['loss_recon'] += loss_recon.item()
        stats['loss_smooth'] += loss_smooth.item()
        stats['loss_phase_ent'] += loss_phase_ent.item()
        stats['loss_effk'] += loss_effk.item()
        stats['mae_count'] += torch.abs(count_hat - y_count).mean().item()
        n_batches += 1

    if n_batches == 0:
        return stats
    return {k: v / n_batches for k, v in stats.items()}


# =============================================================================
# 6) Main (Scenario-wise LOSO)  - Pattern runs (AAAAB)
# =============================================================================
def main() -> None:
    """Run the scenario-wise leave-one-subject-out (LOSO) pattern experiment.

    For every (scenario, pattern) pair in ``CONFIG`` and for every subject
    held out in turn, this function:

    1. Trains a fresh ``KAutoCountModel`` on windows built from the other
       subjects' trials of activity A (``trial_list_to_windows``).
    2. Predicts the repetition count of the held-out subject's activity-A
       trial with ``predict_count_by_windowing`` and records the absolute
       error (the "LOSO" result).
    3. Builds a mixed trial that follows ``pattern`` (one segment per
       character) with ``build_mixed_pattern_trial`` and predicts its total
       count, plus the counts of its first and last segment on their own.
    4. Runs a single full-length forward pass on the mixed trial to record
       the phase probabilities, ``k_hat``, the mean phase entropy and the
       latent trajectory ``z``.

    After all folds of a scenario it prints the LOSO summary, one line per
    subject, and the scenario error statistics, then calls
    ``plot_scenario_mean_std_two_plots`` with the across-subject mean/std of
    the window-rate curves and of the PCA-projected latent paths.

    Returns:
        None. Returns early when ``load_mhealth_dataset`` yields a falsy
        result.
    """
    CONFIG = {
        "seed": 42,
        "data_dir": "/content/drive/MyDrive/Colab Notebooks/HAR_data/MHEALTHDATASET",

        # Column names handed to load_mhealth_dataset (23 sensor channels + label).
        "COLUMN_NAMES": [
            'acc_chest_x', 'acc_chest_y', 'acc_chest_z',
            'ecg_1', 'ecg_2',
            'acc_ankle_x', 'acc_ankle_y', 'acc_ankle_z',
            'gyro_ankle_x', 'gyro_ankle_y', 'gyro_ankle_z',
            'mag_ankle_x', 'mag_ankle_y', 'mag_ankle_z',
            'acc_arm_x', 'acc_arm_y', 'acc_arm_z',
            'gyro_arm_x', 'gyro_arm_y', 'gyro_arm_z',
            'mag_arm_x', 'mag_arm_y', 'mag_arm_z',
            'activity_id'
        ],

        # activity_id -> human-readable name for the activities used here.
        "TARGET_ACTIVITIES_MAP": {
            6: 'Waist bends forward',
            7: 'Frontal elevation of arms',
            8: 'Knees bending',
            10: 'Jogging',
            11: 'Running',
        },

        # Per-activity input channels. All five activities currently list the
        # same 15 accelerometer/gyroscope channels (no ECG, no magnetometer).
        "ACT_FEATURE_MAP": {
            6: ['acc_chest_x', 'acc_chest_y', 'acc_chest_z',
                'acc_ankle_x', 'acc_ankle_y', 'acc_ankle_z',
                'gyro_ankle_x', 'gyro_ankle_y', 'gyro_ankle_z',
                'acc_arm_x', 'acc_arm_y', 'acc_arm_z',
                'gyro_arm_x', 'gyro_arm_y', 'gyro_arm_z'
                ],
            7: ['acc_chest_x', 'acc_chest_y', 'acc_chest_z',
                'acc_ankle_x', 'acc_ankle_y', 'acc_ankle_z',
                'gyro_ankle_x', 'gyro_ankle_y', 'gyro_ankle_z',
                'acc_arm_x', 'acc_arm_y', 'acc_arm_z',
                'gyro_arm_x', 'gyro_arm_y', 'gyro_arm_z'
                ],
            8: ['acc_chest_x', 'acc_chest_y', 'acc_chest_z',
                'acc_ankle_x', 'acc_ankle_y', 'acc_ankle_z',
                'gyro_ankle_x', 'gyro_ankle_y', 'gyro_ankle_z',
                'acc_arm_x', 'acc_arm_y', 'acc_arm_z',
                'gyro_arm_x', 'gyro_arm_y', 'gyro_arm_z'
                ],
            10: ['acc_chest_x', 'acc_chest_y', 'acc_chest_z',
                 'acc_ankle_x', 'acc_ankle_y', 'acc_ankle_z',
                 'gyro_ankle_x', 'gyro_ankle_y', 'gyro_ankle_z',
                 'acc_arm_x', 'acc_arm_y', 'acc_arm_z',
                 'gyro_arm_x', 'gyro_arm_y', 'gyro_arm_z'
                 ],
            11: ['acc_chest_x', 'acc_chest_y', 'acc_chest_z',
                 'acc_ankle_x', 'acc_ankle_y', 'acc_ankle_z',
                 'gyro_ankle_x', 'gyro_ankle_y', 'gyro_ankle_z',
                 'acc_arm_x', 'acc_arm_y', 'acc_arm_z',
                 'gyro_arm_x', 'gyro_arm_y', 'gyro_arm_z'
                 ],
        },

        # Training Params
        "epochs": 100,
        "lr": 5e-4,
        "batch_size": 64,
        "fs": 50,  # presumably the sampling rate in Hz — TODO confirm

        # Windowing Params (passed to trial_list_to_windows and the
        # windowed-inference helpers)
        "win_sec": 8.0,
        "stride_sec": 4.0,
        "drop_last": True,

        # Model
        "hidden_dim": 128,
        "latent_dim": 16,
        "K_max": 6,

        # Loss Weights (read by train_one_epoch through the config dict)
        "lambda_recon": 1.0,
        "lambda_smooth": 0.05,
        "lambda_phase_ent": 0.01,
        "lambda_effk": 0.0075,

        "tau": 1.0,

        # Ground-truth repetition counts: activity_id -> {subject -> count}.
        "GT_BY_ACT": {
            6: {
                "subject1": 21, "subject2": 19, "subject3": 21, "subject4": 20, "subject5": 20,
                "subject6": 20, "subject7": 20, "subject8": 21, "subject9": 21, "subject10": 20,
            },
            7: {
                "subject1": 20, "subject2": 20, "subject3": 20, "subject4": 20, "subject5": 20,
                "subject6": 20, "subject7": 20, "subject8": 19, "subject9": 19, "subject10": 20,
            },
            8: {
                "subject1": 20, "subject2": 21, "subject3": 21, "subject4": 19, "subject5": 20,
                "subject6": 20, "subject7": 21, "subject8": 21, "subject9": 21, "subject10": 21,
            },
            10: {
                "subject1": 157, "subject2": 161, "subject3": 154, "subject4": 154, "subject5": 160,
                "subject6": 156, "subject7": 153, "subject8": 160, "subject9": 166, "subject10": 156,
            },
            11: {
                "subject1": 165, "subject2": 158, "subject3": 174, "subject4": 163, "subject5": 157,
                "subject6": 172, "subject7": 149, "subject8": 166, "subject9": 174, "subject10": 172,
            }
        },

        # Segment patterns: each character selects activity A or B for one segment.
        "PATTERNS": ["AAAAB"],

        # (display name, activity A id, activity B id). The model is always
        # trained on activity A only (see TRAIN_ACT_ID below).
        "MIX_SCENARIOS": [
            ("Running -> Jogging", 11, 10),
            ("Jogging -> Running", 10, 11),

            ("Running -> Knees bending", 11, 8),
            ("Knees bending -> Running", 8, 11),

            ("Frontal elevation of arms -> Waist bends forward", 7, 6),
            ("Waist bends forward -> Frontal elevation of arms", 6, 7),
        ],
    }

    def build_labels_for_act(act_id, subjects_list, gt_by_act, exclude_subj=None):
        """Return ``(subject, act_id, gt_count)`` tuples for one activity.

        Subjects without a ground-truth entry for ``act_id`` are dropped, as
        is ``exclude_subj`` when it is given.
        """
        gt_map = gt_by_act.get(act_id, {})
        labels = []
        for s in subjects_list:
            if exclude_subj is not None and s == exclude_subj:
                continue
            if s in gt_map:
                labels.append((s, act_id, gt_map[s]))
        return labels

    set_strict_seed(CONFIG["seed"])
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")

    full_data = load_mhealth_dataset(CONFIG["data_dir"], CONFIG["TARGET_ACTIVITIES_MAP"], CONFIG["COLUMN_NAMES"])
    if not full_data:
        # Nothing was loaded; stop here (no message is printed by main itself).
        return

    subjects = [f"subject{i}" for i in range(1, 11)]

    VIZ_DIR = "scenario_viz"
    os.makedirs(VIZ_DIR, exist_ok=True)

    # Fixed resample lengths so per-subject curves can be stacked and averaged
    # even though segment lengths differ.
    NW = 60     # window-rate points per segment
    NL = 200    # latent points per segment

    # =========================================================
    # Scenario-wise execution (scenario x pattern)
    # =========================================================
    scen_global_idx = 0
    for base_idx, (base_name, actA_id, actB_id) in enumerate(CONFIG.get("MIX_SCENARIOS", []), start=1):
        for pattern in CONFIG.get("PATTERNS", ["AAAAB"]):
            scen_global_idx += 1
            scen_name = f"{base_name} | pattern={pattern}"
            TRAIN_ACT_ID = actA_id  # ✅ always train on A

            print("\n" + "=" * 100)
            print(f"1. 시나리오{scen_global_idx}: {scen_name}")
            print("=" * 100)

            loso_results = []      # per-fold |pred - GT| on the held-out activity-A trial
            scenario_records = []  # per-subject report lines for the mixed trial
            scenario_diffs = []    # per-subject signed error (pred - GT) on the mixed trial

            pattern_len = len(pattern)

            # buffers for scenario-avg plots
            buf_win_concat = []    # (pattern_len*NW,)
            buf_z_concat = []      # raw z(t): (Tmix,D) per subject
            per_subject_latent = []  # {"z_segs": [...]}

            for fold_idx, test_subj in enumerate(subjects):
                # Re-seed at the start of every fold with the same seed, so each
                # fold starts from the same RNG state
                # (presumably set_strict_seed resets all RNGs — TODO confirm).
                set_strict_seed(CONFIG["seed"])

                train_labels = build_labels_for_act(TRAIN_ACT_ID, subjects, CONFIG["GT_BY_ACT"], exclude_subj=test_subj)
                test_labels  = build_labels_for_act(TRAIN_ACT_ID, subjects, CONFIG["GT_BY_ACT"], exclude_subj=None)
                test_labels  = [t for t in test_labels if t[0] == test_subj]

                if len(train_labels) == 0 or len(test_labels) == 0:
                    print(f"[Skip Fold] Missing GT for TRAIN_ACT_ID={TRAIN_ACT_ID} ({CONFIG['TARGET_ACTIVITIES_MAP'].get(TRAIN_ACT_ID)}) at {test_subj}")
                    continue

                train_trials = prepare_trial_list(train_labels, full_data, CONFIG["TARGET_ACTIVITIES_MAP"], CONFIG["ACT_FEATURE_MAP"])
                test_trials  = prepare_trial_list(test_labels,  full_data, CONFIG["TARGET_ACTIVITIES_MAP"], CONFIG["ACT_FEATURE_MAP"])
                if not test_trials or not train_trials:
                    # Fold is skipped silently when either side has no usable trial.
                    continue

                # Training uses windows cut from the trials; testing keeps the
                # trial whole and windows it at inference time.
                train_data = trial_list_to_windows(
                    train_trials, fs=CONFIG["fs"],
                    win_sec=CONFIG["win_sec"], stride_sec=CONFIG["stride_sec"],
                    drop_last=CONFIG["drop_last"]
                )
                if len(train_data) == 0:
                    continue

                # Dedicated, seeded generator so the shuffle order of the
                # DataLoader is the same in every fold/run.
                g = torch.Generator()
                g.manual_seed(CONFIG["seed"])

                train_loader = DataLoader(
                    TrialDataset(train_data),
                    batch_size=CONFIG["batch_size"],
                    shuffle=True,
                    collate_fn=collate_variable_length,
                    generator=g,
                    num_workers=0
                )

                # Channel count is read from axis 1 of the first training window.
                input_ch = train_data[0]['data'].shape[1]
                model = KAutoCountModel(
                    input_ch=input_ch,
                    hidden_dim=CONFIG["hidden_dim"],
                    latent_dim=CONFIG["latent_dim"],
                    K_max=CONFIG["K_max"]
                ).to(device)

                optimizer = torch.optim.Adam(model.parameters(), lr=CONFIG["lr"])
                # Halve the learning rate every 30 epochs (stepped once per epoch below).
                scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.5)

                for epoch in range(CONFIG["epochs"]):
                    # Per-epoch training statistics are discarded.
                    _ = train_one_epoch(model, train_loader, optimizer, CONFIG, device)
                    scheduler.step()

                model.eval()

                # (A) LOSO on TRAIN_ACT_ID
                # Only the first trial of the held-out subject is evaluated.
                item = test_trials[0]
                x_np = item["data"]
                count_gt = float(item["count"])
                count_pred_win, _ = predict_count_by_windowing(
                    model, x_np=x_np,
                    fs=CONFIG["fs"], win_sec=CONFIG["win_sec"], stride_sec=CONFIG["stride_sec"],
                    device=device, tau=CONFIG.get("tau", 1.0),
                    batch_size=CONFIG.get("batch_size", 64)
                )
                fold_mae = float(abs(count_pred_win - count_gt))
                loso_results.append(fold_mae)

                # (B) Scenario eval (pattern)
                mixed_item = build_mixed_pattern_trial(test_subj, actA_id, actB_id, pattern, CONFIG, full_data)
                if mixed_item is None:
                    # The LOSO result above is kept; only the mixed-trial part
                    # (records, diffs, plot buffers) is skipped for this subject.
                    print(f"[Skip Mix] Missing GT or data for mix: subj={test_subj}, A={actA_id}, B={actB_id}")
                    continue

                x_mix = mixed_item["data"]
                boundaries = mixed_item["boundaries"]  # length = pattern_len-1
                segments = mixed_item["segments"]      # list length = pattern_len
                gt_total = float(mixed_item["count"])
                md = mixed_item["meta_detail"]

                pred_total, _ = predict_count_by_windowing(
                    model, x_np=x_mix,
                    fs=CONFIG["fs"], win_sec=CONFIG["win_sec"], stride_sec=CONFIG["stride_sec"],
                    device=device, tau=CONFIG.get("tau", 1.0),
                    batch_size=CONFIG.get("batch_size", 64)
                )

                diff = float(pred_total - gt_total)
                mae = float(abs(diff))
                scenario_diffs.append(diff)

                # optional per-segment A/B sanity (first seg vs last seg)
                # NOTE(review): reported as "A"/"B" below, which matches the
                # pattern only when it starts with A and ends with B (true for "AAAAB").
                predA_seg, _ = predict_count_by_windowing(
                    model, x_np=segments[0],
                    fs=CONFIG["fs"], win_sec=CONFIG["win_sec"], stride_sec=CONFIG["stride_sec"],
                    device=device, tau=CONFIG.get("tau", 1.0),
                    batch_size=CONFIG.get("batch_size", 64)
                )
                predB_seg, _ = predict_count_by_windowing(
                    model, x_np=segments[-1],
                    fs=CONFIG["fs"], win_sec=CONFIG["win_sec"], stride_sec=CONFIG["stride_sec"],
                    device=device, tau=CONFIG.get("tau", 1.0),
                    batch_size=CONFIG.get("batch_size", 64)
                )

                # full forward for phase/z
                # x_mix is transposed on its first two axes and given a leading
                # batch axis of size 1 (assumes x_mix is (T, C) -> (1, C, T) — TODO confirm).
                x_tensor = torch.tensor(x_mix, dtype=torch.float32).transpose(0, 1).unsqueeze(0).to(device)
                with torch.no_grad():
                    _, z_m, _, aux_m = model(x_tensor, mask=None, tau=CONFIG.get("tau", 1.0))

                phase_p_m = aux_m["phase_p"].squeeze(0).detach().cpu().numpy()  # (Tmix,K)
                k_hat_m = float(aux_m["k_hat"].item())
                ent_m = compute_phase_entropy_mean(phase_p_m)
                z_td = z_m.squeeze(0).detach().cpu().numpy()  # (Tmix,D)

                # pretty sequence string with GT per segment
                # Any pattern character other than "A" is treated as activity B.
                seq_names = []
                seq_gts = []
                for ch in pattern:
                    if ch == "A":
                        seq_names.append(md["actA_name"])
                        seq_gts.append(md["gtA"])
                    else:
                        seq_names.append(md["actB_name"])
                        seq_gts.append(md["gtB"])
                # The comprehension variable `g` is scoped to the comprehension and
                # does not overwrite the torch.Generator `g` created above.
                seq_str = " -> ".join([f"{n}({g:.0f})" for n, g in zip(seq_names, seq_gts)])

                scenario_records.append({
                    "subj": test_subj,
                    "line": (
                        f"[Scenario-PAT] {base_name} | pattern={pattern} | {test_subj} | "
                        f"{seq_str} | "
                        f"GT_total={gt_total:.0f} | Pred(win)={pred_total:.2f} | Diff={diff:+.2f} | "
                        f"MAE={mae:.2f} | "
                        f"A_GT={md['gtA']:.0f} | A_pred(seg)={predA_seg:.2f} | "
                        f"B_GT={md['gtB']:.0f} | B_pred(seg)={predB_seg:.2f} | "
                        f"k_hat(full)={k_hat_m:.2f} | ent(full)={ent_m:.3f} | "
                        f"boundaries={boundaries}"
                    )
                })

                # ===========================
                # buffers for scenario-avg plots
                # ===========================

                # 1) window-level rep_rate curves per segment -> resample -> concat (0~pattern_len)
                seg_rate_rs = []
                for seg_np in segments:
                    _, r_w = get_window_rate_curve(
                        model, seg_np, fs=CONFIG["fs"],
                        win_sec=CONFIG["win_sec"], stride_sec=CONFIG["stride_sec"],
                        device=device, tau=CONFIG.get("tau", 1.0),
                        batch_size=CONFIG.get("batch_size", 64)
                    )
                    seg_rate_rs.append(resample_1d(r_w, NW))
                buf_win_concat.append(np.concatenate(seg_rate_rs, axis=0))  # (pattern_len*NW,)

                # 2) latent: store raw z split for global PCA projection later
                # Segment cuts reuse the input-sample boundaries, so this assumes
                # z has one row per input time step — TODO confirm against the model.
                z_cuts = [0] + boundaries + [z_td.shape[0]]
                z_segs = [z_td[z_cuts[i]:z_cuts[i + 1], :] for i in range(len(z_cuts) - 1)]
                buf_z_concat.append(z_td)
                per_subject_latent.append({"z_segs": z_segs})

            # ---------------------------------------------------------
            # Printing: Train summary -> subject lines -> scenario summary
            # ---------------------------------------------------------
            print("-" * 100)
            if len(loso_results) > 0:
                print("TRain 결과 ->")
                print("-" * 100)
                print(f" >>> Final LOSO Result (Average MAE): {np.mean(loso_results):.3f}")
                print(f" >>> Standard Deviation: {np.std(loso_results):.3f}")
                print("-" * 100)
            else:
                print("TRain 결과 -> (no folds computed)")
                print("-" * 100)

            # Print the per-subject lines in the canonical subject order.
            subj2line = {r["subj"]: r["line"] for r in scenario_records}
            for s in subjects:
                if s in subj2line:
                    print(subj2line[s])

            print("\n" + "-" * 100)
            print(f"시나리오{scen_global_idx} 결과")
            print("-" * 100)

            if len(scenario_diffs) == 0:
                print(f"[{scen_name}] No samples (all skipped). Fill GT_BY_ACT for required activities.")
            else:
                diffs = np.array(scenario_diffs, dtype=np.float32)
                mae_v = float(np.mean(np.abs(diffs)))
                rmse = float(np.sqrt(np.mean(diffs ** 2)))
                print(f"[{scen_name}] N={len(diffs):2d} | MAE={mae_v:.3f} | RMSE={rmse:.3f} | mean(diff)={diffs.mean():+.3f} | std(diff)={diffs.std():.3f}")
            print("-" * 100)

            # ---------------------------------------------------------
            # Scenario-level mean±std plots (two_plots style)
            # ---------------------------------------------------------
            n_ok = len(buf_win_concat)
            if n_ok == 0:
                # No subject produced a mixed trial: skip plotting and move on
                # to the next pattern/scenario.
                print("[Viz Skip] No scenario samples collected for averaging.")
                continue

            # (1) window-level mean±std on normalized axis [0, pattern_len]
            # np.std uses its default ddof=0 (population standard deviation).
            W = np.stack(buf_win_concat, axis=0)  # (N, pattern_len*NW)
            win_mean = W.mean(axis=0)
            win_std  = W.std(axis=0)
            win_t = np.linspace(0.0, float(pattern_len), num=pattern_len * NW, dtype=np.float32)

            # (2) latent: global PCA basis from all z
            # One shared basis across subjects keeps the projected paths comparable.
            Z_all = np.concatenate(buf_z_concat, axis=0)  # (sumT, D)
            mu, V2 = global_pca2(Z_all)

            # project each subject, resample each segment to NL then concat -> (pattern_len*NL,2)
            lat_list = []
            for d in per_subject_latent:
                pcs = []
                for zseg in d["z_segs"]:
                    pc = (zseg - mu) @ V2.T
                    pcs.append(resample_2d(pc, NL))
                lat_list.append(np.concatenate(pcs, axis=0))  # (pattern_len*NL,2)

            LAT = np.stack(lat_list, axis=0)  # (N,pattern_len*NL,2)
            lat_xy_mean = LAT.mean(axis=0)    # (pattern_len*NL,2)
            lat_xy_std  = LAT.std(axis=0)     # (pattern_len*NL,2)

            plot_scenario_mean_std_two_plots(
                save_dir=VIZ_DIR,
                scen_name=scen_name,
                win_t=win_t, win_mean=win_mean, win_std=win_std,
                lat_xy_mean=lat_xy_mean, lat_xy_std=lat_xy_std,
                n_subjects=n_ok,
                pattern_len=pattern_len,
                dpi=600
            )

            # NOTE(review): these messages rebuild the output paths by hand;
            # presumably they match what plot_scenario_mean_std_two_plots writes — confirm.
            # NOTE(review): relies on `_safe_fname` being defined elsewhere in this
            # file; the helper visible at the top of the file is `_safe_filename` — confirm.
            print(f"[Saved] {VIZ_DIR}/{_safe_fname(scen_name)}__1_win_rep_rate.png")
            print(f"[Saved] {VIZ_DIR}/{_safe_fname(scen_name)}__2_latent_pca2.png")


# Script entry point: run the full experiment only when this file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    main()


Device: cuda
Loading 10 subjects from /content/drive/MyDrive/Colab Notebooks/HAR_data/MHEALTHDATASET...

1. 시나리오1: Running -> Jogging | pattern=AAAAB
----------------------------------------------------------------------------------------------------
TRain 결과 ->
----------------------------------------------------------------------------------------------------
 >>> Final LOSO Result (Average MAE): 15.211
 >>> Standard Deviation: 7.371
----------------------------------------------------------------------------------------------------
[Scenario-PAT] Running -> Jogging | pattern=AAAAB | subject1 | Running(165) -> Running(165) -> Running(165) -> Running(165) -> Jogging(157) | GT_total=817 | Pred(win)=812.40 | Diff=-4.60 | MAE=4.60 | A_GT=165 | A_pred(seg)=165.02 | B_GT=157 | B_pred(seg)=148.30 | k_hat(full)=1.62 | ent(full)=0.317 | boundaries=[3072, 6144, 9216, 12288]
[Scenario-PAT] Running -> Jogging | pattern=AAAAB | subject2 | Running(158) -> Running(158) -> Running(158) -> Running(15