# In [1]:
# =========================
# Count-only K-auto (Multi-event) version  (NO manual Pair/lag/overlap/balance)
#
# 핵심 아이디어
# - Micro-event Rate 예측: 모델은 하나의 통합된 속도가 아니라, K_max개의 서로 다른 '작은 단위 동작'의 속도 흐름(r_k(t))을 예측
# - 샘플마다 "rep당 micro-event 개수" k_hat(>=1)을 스스로 추정
# - 우리가 주는 감독은 오직 rep count(=20) -> rep rate만 맞추게
#
# 1) rate head를 amp(t) * softmax(phase) 형태로 바꿔서 K-stream 간 "경쟁"이 생기게 함
# 2) k_hat을 따로 head로 예측하지 않고, phase 사용 분포로부터 effK(=effective K)로 정의 (자동 K)
# 3) phase sparsity/exclusivity를 위해 (a) phase entropy loss, (b) effK usage loss 추가
#
# ✅ Added (ONLY): Time-warp consistency loss
#   - 입력을 시간축으로 stretch/compress (time-warp)해도 "count"는 동일해야 한다는 제약
#   - 모델 출력은 rate이므로, pred_count = rate_hat * duration 을 이용해 count 일관성을 강제
# =========================

import os
import glob
import random
import numpy as np
import pandas as pd
import seaborn as sns

import matplotlib.pyplot as plt
from scipy.ndimage import gaussian_filter1d

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader


# ---------------------------------------------------------------------
# 1) Strict Seeding
# ---------------------------------------------------------------------
def set_strict_seed(seed: int):
    """Seed every random number generator used by this script.

    Seeds Python's ``random``, NumPy and PyTorch (CPU, plus the CUDA
    generators when CUDA is available), exports ``PYTHONHASHSEED`` and puts
    cuDNN into deterministic, non-benchmarking mode.

    Args:
        seed: Integer seed shared by all generators.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)

    # Host-side generators.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    # Device-side generators (current device, then all devices).
    if torch.cuda.is_available():
        for cuda_seed_fn in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
            cuda_seed_fn(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


# ---------------------------------------------------------------------
# 2) Data Loading
# ---------------------------------------------------------------------
def load_mhealth_dataset(data_dir, target_activities_map, column_names):
    """Load per-subject mHealth logs and split them by target activity.

    Every file matching ``mHealth_subject*.log`` in ``data_dir`` is read as a
    header-less, tab-separated table. Only the first ``len(column_names)``
    columns are kept and renamed to ``column_names`` (which must contain
    ``'activity_id'``).

    Args:
        data_dir: Directory containing the ``mHealth_subject*.log`` files.
        target_activities_map: Mapping ``activity_id -> activity name`` of the
            activities to extract.
        column_names: Column names assigned to the leading columns of each log.

    Returns:
        ``{subject_key: {activity_name: DataFrame}}`` where each DataFrame
        holds the rows of that activity with the ``'activity_id'`` column
        dropped. Activities with no rows are omitted. Returns ``{}`` when no
        log file is found. Files that fail to load are reported and skipped.
    """
    file_list = sorted(glob.glob(os.path.join(data_dir, "mHealth_subject*.log")))

    if not file_list:
        print(f"[Warning] No mHealth logs found in {data_dir}")
        return {}

    print(f"Loading {len(file_list)} subjects from {data_dir}...")

    full_dataset = {}
    for file_path in file_list:
        file_name = os.path.basename(file_path)
        subj_part = file_name.split('.')[0]

        # Normalise the key to "subject<N>"; fall back to the raw file stem
        # when the name contains no parseable number. Only ValueError is
        # expected from int(); anything else should propagate.
        try:
            subj_id_num = int(''.join(filter(str.isdigit, subj_part)))
            subj_key = f"subject{subj_id_num}"
        except ValueError:
            subj_key = subj_part

        # Best-effort loading: a malformed file is reported and skipped so
        # that the remaining subjects can still be used.
        try:
            df = pd.read_csv(file_path, sep="\t", header=None)
            df = df.iloc[:, :len(column_names)]
            df.columns = column_names

            subj_data = {}
            for label_code, activity_name in target_activities_map.items():
                activity_df = df[df['activity_id'] == label_code].copy()
                if not activity_df.empty:
                    subj_data[activity_name] = activity_df.drop(columns=['activity_id'])

            full_dataset[subj_key] = subj_data
        except Exception as e:
            print(f"Error loading {file_name}: {e}")

    return full_dataset


def prepare_trial_list(label_config, full_data, target_map, feature_map):
    """Build z-score-normalised trials from the loaded dataset.

    Args:
        label_config: Iterable of ``(subject_key, activity_id, gt_count)``.
        full_data: ``{subject_key: {activity_name: DataFrame}}`` as returned by
            ``load_mhealth_dataset``.
        target_map: Mapping ``activity_id -> activity name``.
        feature_map: Mapping ``activity_id -> list of column names`` to use.

    Returns:
        List of dicts with keys ``'data'`` (float32 array, one row per sample
        and one column per selected feature, standardised per column over the
        whole trial), ``'count'`` (float ground-truth count) and ``'meta'``
        (``"<subject>_<activity name>"``). Entries whose data or feature list
        is missing are reported and skipped.
    """
    trial_list = []
    for subj, act_id, gt_count in label_config:
        act_name = target_map.get(act_id)
        feats = feature_map.get(act_id)

        if subj not in full_data or act_name not in full_data[subj]:
            print(f"[Skip] Missing data for {subj} - {act_name}")
            continue

        # Without a feature list the column selection below would raise
        # KeyError; treat it like any other missing-configuration case.
        if feats is None:
            print(f"[Skip] No feature list configured for activity {act_id} ({subj})")
            continue

        raw_np = full_data[subj][act_name][feats].values.astype(np.float32)

        # Z-score standardisation per column (mean 0, std 1); the small
        # constant avoids division by zero for constant columns.
        mean = raw_np.mean(axis=0)
        std = raw_np.std(axis=0) + 1e-6
        norm_np = (raw_np - mean) / std

        trial_list.append({
            'data': norm_np,
            'count': float(gt_count),
            'meta': f"{subj}_{act_name}"
        })

    return trial_list


# ---------------------------------------------------------------------
# 2.5) Windowing (added)
# ---------------------------------------------------------------------
def trial_list_to_windows(trial_list, fs, win_sec=4.0, stride_sec=2.0, drop_last=True):
    """Cut every trial into sliding windows with proportionally scaled counts.

    The label of a window is derived from the trial-level average rate:
    ``rate_trial = count_total / total_duration`` and
    ``count_window = rate_trial * window_duration``.

    Args:
        trial_list: Output of ``prepare_trial_list`` (items with ``'data'``
            of shape (T, C), ``'count'`` and ``'meta'``).
        fs: Sampling rate in Hz.
        win_sec: Window length in seconds.
        stride_sec: Hop between consecutive window starts in seconds.
        drop_last: When False, a shorter trailing window starting one stride
            after the last full window is also emitted (if it starts inside
            the trial).

    Returns:
        List of window dicts with keys ``data``, ``count``, ``meta``,
        ``parent_meta``, ``parent_T``, ``win_start`` and ``win_end``. A trial
        shorter than one window is emitted whole as a single window.
    """
    win_len = int(round(win_sec * fs))
    stride = int(round(stride_sec * fs))
    assert win_len > 0 and stride > 0

    fs_hz = float(fs)

    def _entry(chunk, rate, parent, parent_len, lo, hi):
        # One window record; the count is the trial rate times window duration.
        return {
            "data": chunk,
            "count": rate * ((hi - lo) / fs_hz),
            "meta": f"{parent}__win[{lo}:{hi}]",
            "parent_meta": parent,
            "parent_T": parent_len,
            "win_start": lo,
            "win_end": hi,
        }

    windows = []
    for trial in trial_list:
        signal = trial["data"]           # (T, C)
        n_samples = signal.shape[0]
        parent = trial["meta"]

        # Average reps/s over the whole trial.
        rate_trial = float(trial["count"]) / max(n_samples / fs_hz, 1e-6)

        if n_samples < win_len:
            # Too short for a full window: keep the trial as one
            # variable-length window (no padding).
            windows.append(_entry(signal, rate_trial, parent, n_samples, 0, n_samples))
            continue

        starts = range(0, n_samples - win_len + 1, stride)
        windows.extend(
            _entry(signal[lo:lo + win_len], rate_trial, parent, n_samples, lo, lo + win_len)
            for lo in starts
        )

        if drop_last:
            continue

        tail_start = starts[-1] + stride
        if tail_start < n_samples:
            windows.append(
                _entry(signal[tail_start:n_samples], rate_trial, parent,
                       n_samples, tail_start, n_samples)
            )

    return windows


def predict_count_by_windowing(model, x_np, fs, win_sec, stride_sec, device, tau=1.0, batch_size=64):
    """Estimate the repetition count of one trial from sliding-window rates.

    The trial is cut into full windows of ``win_sec`` seconds with a hop of
    ``stride_sec`` seconds. The model predicts one rate per window, and the
    count is the mean window rate multiplied by the full trial duration.
    Samples after the last full window are therefore covered only through
    that mean rate. A trial no longer than one window is passed to the model
    whole.

    The model is switched to eval mode on both paths and is left in eval mode.

    Args:
        model: Callable as ``model(x, mask=None, tau=...)`` returning a tuple
            whose first element is the per-sample rate of shape (B,).
        x_np: Normalised signal of shape (T, C).
        fs: Sampling rate in Hz.
        win_sec: Window length in seconds.
        stride_sec: Hop between windows in seconds.
        device: Torch device on which the model is evaluated.
        tau: Softmax temperature forwarded to the model.
        batch_size: Number of windows per forward pass.

    Returns:
        ``(pred_count, window_rates)`` — a float and a 1-D NumPy array with
        one rate per evaluated window.
    """
    win_len = int(round(win_sec * fs))
    stride = int(round(stride_sec * fs))
    T = x_np.shape[0]
    total_dur = T / float(fs)

    # Inference must run in eval mode regardless of which path is taken.
    model.eval()

    if T <= win_len:
        x_tensor = torch.tensor(x_np, dtype=torch.float32).transpose(0, 1).unsqueeze(0).to(device)  # (1,C,T)
        with torch.no_grad():
            rate_hat, _, _, _ = model(x_tensor, mask=None, tau=tau)
        rate = float(rate_hat.item())
        return float(rate * total_dur), np.array([rate], dtype=np.float32)

    starts = range(0, T - win_len + 1, stride)
    windows = np.stack([x_np[st:st + win_len] for st in starts], axis=0)  # (N, win_len, C)
    xw = torch.tensor(windows, dtype=torch.float32).permute(0, 2, 1)      # (N, C, win_len)

    rates = []
    with torch.no_grad():
        for i in range(0, xw.shape[0], batch_size):
            # Move one batch at a time so long trials do not have to fit on
            # the device all at once.
            xb = xw[i:i + batch_size].to(device)
            r_hat, _, _, _ = model(xb, mask=None, tau=tau)  # (B,)
            rates.append(r_hat.detach().cpu().numpy())

    rates = np.concatenate(rates, axis=0)  # (N,)
    pred_count = float(rates.mean()) * total_dur
    return float(pred_count), rates


class TrialDataset(Dataset):
    """Dataset wrapper around a list of trial/window dicts.

    Each item of ``trial_list`` must provide ``'data'`` (array of shape
    (T, C)), ``'count'`` and ``'meta'``.
    """

    def __init__(self, trial_list):
        self.trials = trial_list

    def __len__(self):
        return len(self.trials)

    def __getitem__(self, idx):
        """Return ``(signal (C, T) float32, count float32 scalar, meta)``."""
        trial = self.trials[idx]
        signal_tc = torch.tensor(trial['data'], dtype=torch.float32)
        signal_ct = torch.transpose(signal_tc, 0, 1)  # channels first: (C, T)
        target = torch.tensor(trial['count'], dtype=torch.float32)
        return signal_ct, target, trial['meta']


def collate_variable_length(batch):
    """Collate ``(signal (C, T_i), count, meta)`` samples of differing length.

    Signals are right-padded with zeros to the longest sample in the batch
    and a matching validity mask (1 = real sample, 0 = padding) is built.

    Returns:
        Dict with ``data`` (B, C, T_max), ``mask`` (B, T_max), ``count`` (B,),
        ``length`` (B,) float32 holding the unpadded lengths, and ``meta``
        (list of the per-sample meta values).
    """
    n_ch = batch[0][0].shape[0]
    longest = max(sample[0].shape[1] for sample in batch)

    data_rows, mask_rows, counts, metas, lengths = [], [], [], [], []
    for signal, count, meta in batch:
        n_valid = signal.shape[1]
        n_pad = longest - n_valid

        valid_mask = torch.ones(n_valid)
        if n_pad > 0:
            signal = torch.cat([signal, torch.zeros(n_ch, n_pad)], dim=1)
            valid_mask = torch.cat([valid_mask, torch.zeros(n_pad)], dim=0)

        data_rows.append(signal)
        mask_rows.append(valid_mask)
        counts.append(count)
        metas.append(meta)
        lengths.append(n_valid)

    collated = {}
    collated["data"] = torch.stack(data_rows)                          # (B, C, T_max)
    collated["mask"] = torch.stack(mask_rows)                          # (B, T_max)
    collated["count"] = torch.stack(counts)                            # (B,)
    collated["length"] = torch.tensor(lengths, dtype=torch.float32)    # (B,)
    collated["meta"] = metas
    return collated


# ---------------------------------------------------------------------
# 3) Model
# ---------------------------------------------------------------------
class ManifoldEncoder(nn.Module):
    """Per-timestep convolutional encoder: (B, C, T) -> (B, T, latent_dim).

    Two width-5 convolutions (padding 2, so T is preserved) with ReLU,
    followed by a 1x1 convolution projecting to the latent dimension.
    """

    def __init__(self, input_ch, hidden_dim=128, latent_dim=16):
        super().__init__()
        layers = [
            nn.Conv1d(input_ch, hidden_dim, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv1d(hidden_dim, hidden_dim, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv1d(hidden_dim, latent_dim, kernel_size=1),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Encode ``x`` of shape (B, C, T) into latents of shape (B, T, D)."""
        latent_bdt = self.net(x)              # (B, D, T)
        return latent_bdt.transpose(1, 2)     # (B, T, D)


class ManifoldDecoder(nn.Module):
    """Convolutional decoder mirroring the encoder: (B, T, D) -> (B, C, T).

    Two width-5 convolutions (padding 2) with ReLU, followed by a 1x1
    convolution projecting to ``out_ch`` channels.
    """

    def __init__(self, latent_dim, hidden_dim, out_ch):
        super().__init__()
        layers = [
            nn.Conv1d(latent_dim, hidden_dim, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv1d(hidden_dim, hidden_dim, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv1d(hidden_dim, out_ch, kernel_size=1),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, z):
        """Decode latents ``z`` of shape (B, T, D) into a signal (B, C, T)."""
        latent_bdt = z.transpose(1, 2)   # (B, D, T)
        return self.net(latent_bdt)      # (B, C, T)


class MultiRateHead(nn.Module):
    """Per-timestep head producing a total amplitude and a phase distribution.

    The MLP outputs ``1 + K_max`` values per timestep: the first is an
    amplitude logit, the remaining ``K_max`` are phase logits.
    """

    def __init__(self, latent_dim=16, hidden=64, K_max=6):
        super().__init__()
        self.K_max = K_max
        self.net = nn.Sequential(
            nn.Linear(latent_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 1 + K_max),  # [amp_logit | phase_logits...]
        )

    def forward(self, z, tau=1.0):
        """Map latents ``z`` (B, T, D) to ``(amp, phase, phase_logits)``.

        Returns:
            amp: (B, T), non-negative (softplus of the amplitude logit).
            phase: (B, T, K), softmax over phases with temperature ``tau``.
            phase_logits: (B, T, K), raw phase logits.
        """
        head_out = self.net(z)                               # (B, T, 1+K)
        amp_logit = head_out[..., 0]
        phase_logits = head_out[..., 1:]

        amp = F.softplus(amp_logit)                          # total micro intensity
        phase = F.softmax(phase_logits / tau, dim=-1)        # sums to 1 over K
        return amp, phase, phase_logits


class KAutoCountModel(nn.Module):
    """Count model with an automatically inferred number of micro-events.

    - The rate head yields an amplitude ``amp(t)`` and a phase distribution;
      the K_max micro-event rates are ``r_k(t) = amp(t) * phase_k(t)``, so
      their sum over k equals ``amp(t)``.
    - ``k_hat`` (effective K, roughly in [1, K_max]) is not predicted by a
      separate head: it is the inverse of the sum of squares of the
      time-averaged phase usage.
    - ``rep_rate(t) = amp(t) / k_hat``.

    ``k_hidden`` is accepted by the constructor but not used.
    """

    def __init__(self, input_ch, hidden_dim=128, latent_dim=16, K_max=6, k_hidden=64):
        super().__init__()
        self.encoder = ManifoldEncoder(input_ch, hidden_dim, latent_dim)
        self.decoder = ManifoldDecoder(latent_dim, hidden_dim, input_ch)
        self.rate_head = MultiRateHead(latent_dim, hidden=hidden_dim, K_max=K_max)

        self._init_weights()

    def _init_weights(self):
        """Kaiming-initialise conv/linear weights and zero their biases."""
        for module in self.modules():
            if not isinstance(module, (nn.Conv1d, nn.Linear)):
                continue
            nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)

        # Final rate-head bias: all zeros except the amplitude logit, which
        # starts at -2.
        with torch.no_grad():
            head_bias = self.rate_head.net[-1].bias
            head_bias.zero_()
            head_bias[0].fill_(-2.0)

    @staticmethod
    def _masked_mean_time(x, mask=None, eps=1e-6):
        """Mean over the time axis (dim 1) of ``x`` ((B,T) or (B,T,K)).

        With ``mask`` (B, T), only positions weighted by the mask contribute.
        Raises ``ValueError`` for any other rank when a mask is given.
        """
        if mask is None:
            return x.mean(dim=1)

        rank = x.dim()
        if rank not in (2, 3):
            raise ValueError(f"Unsupported dim for masked mean: {x.dim()}")

        weights = mask.to(dtype=x.dtype, device=x.device)      # (B,T)
        if rank == 3:
            weights = weights.unsqueeze(-1)                    # (B,T,1)
        return (x * weights).sum(dim=1) / (weights.sum(dim=1) + eps)

    def forward(self, x, mask=None, tau=1.0):
        """Run the model on ``x`` (B, C, T) with optional ``mask`` (B, T).

        Returns:
            ``(avg_rep_rate (B,), z (B,T,D), x_hat (B,C,T), aux)`` where
            ``aux`` holds the per-timestep rates, phase tensors and ``k_hat``.
        """
        z = self.encoder(x)              # (B,T,D)
        x_hat = self.decoder(z)          # (B,C,T)

        amp_t, phase_p, phase_logits = self.rate_head(z, tau=tau)

        # Effective number of phases used by each sample.
        usage = self._masked_mean_time(phase_p, mask)               # (B,K)
        k_hat = 1.0 / (usage.pow(2).sum(dim=1) + 1e-6)              # (B,)

        # The per-phase rates sum to amp_t, so amp_t is the total micro rate.
        micro_rate_t = amp_t
        rep_rate_t = micro_rate_t / (k_hat.unsqueeze(1) + 1e-6)     # (B,T)

        if mask is None:
            avg_rep_rate = rep_rate_t.mean(dim=1)
        else:
            # Zero the padded timesteps, then average over valid ones only.
            rep_rate_t = rep_rate_t * mask
            avg_rep_rate = (rep_rate_t * mask).sum(dim=1) / (mask.sum(dim=1) + 1e-6)

        aux = dict(
            rates_k_t=amp_t.unsqueeze(-1) * phase_p,   # (B,T,K)
            phase_p=phase_p,                           # (B,T,K)
            phase_logits=phase_logits,                 # (B,T,K)
            micro_rate_t=micro_rate_t,                 # (B,T)
            rep_rate_t=rep_rate_t,                     # (B,T)
            k_hat=k_hat,                               # (B,)
        )
        return avg_rep_rate, z, x_hat, aux


# ---------------------------------------------------------------------
# 4) Loss utils
# ---------------------------------------------------------------------
def masked_recon_mse(x_hat, x, mask, eps=1e-6):
    """Mean squared reconstruction error over valid (unmasked) timesteps.

    Args:
        x_hat: Reconstruction, shape (B, C, T).
        x: Target signal, shape (B, C, T).
        mask: Validity mask, shape (B, T); broadcast over channels.
        eps: Added to the denominator to avoid division by zero.

    Returns:
        Scalar tensor: summed masked squared error divided by
        ``mask.sum() * C + eps``.
    """
    valid = mask.to(dtype=x.dtype, device=x.device)
    masked_sq_err = (x_hat - x).pow(2) * valid.unsqueeze(1)   # (B,C,T)
    n_valid_elems = valid.sum() * x.shape[1] + eps            # valid (B*T) * C
    return masked_sq_err.sum() / n_valid_elems


def temporal_smoothness(v, mask=None, eps=1e-6):
    """L1 penalty on the first temporal difference of ``v`` (B, T).

    Without a mask this is the plain mean of ``|v[t+1] - v[t]|``. With a mask
    (B, T), a difference only counts when both of its timesteps are valid.
    """
    step = (v[:, 1:] - v[:, :-1]).abs()  # (B,T-1)
    if mask is None:
        return step.mean()

    pair_valid = mask[:, 1:] * mask[:, :-1]
    pair_valid = pair_valid.to(dtype=step.dtype, device=step.device)
    total = (step * pair_valid).sum()
    return total / (pair_valid.sum() + eps)


def phase_entropy_loss(phase_p, mask=None, eps=1e-8):
    """Mean per-timestep entropy of the phase distribution.

    Minimising it pushes the phase at each timestep towards one-hot
    (time-wise exclusivity).

    Args:
        phase_p: Phase probabilities, shape (B, T, K).
        mask: Optional validity mask (B, T); masked-out steps are ignored.
        eps: Stabiliser inside the log and in the masked denominator.
    """
    log_p = torch.log(phase_p + eps)
    entropy_t = -(phase_p * log_p).sum(dim=-1)  # (B,T)
    if mask is None:
        return entropy_t.mean()
    return (entropy_t * mask).sum() / (mask.sum() + eps)


def effK_usage_loss(phase_p, mask=None, eps=1e-6):
    """Effective number of phases used over time (overall usage sparsity).

    ``effK = 1 / sum_k(p_bar_k^2)`` where ``p_bar`` is the (masked)
    time-average of ``phase_p`` (B, T, K); it ranges roughly over [1, K].

    Returns:
        ``(effK.mean(), effK.detach())`` — the scalar loss and the detached
        per-sample values of shape (B,).
    """
    if mask is not None:
        weights = mask.to(dtype=phase_p.dtype, device=phase_p.device).unsqueeze(-1)  # (B,T,1)
        usage = (phase_p * weights).sum(dim=1) / (weights.sum(dim=1) + eps)
    else:
        usage = phase_p.mean(dim=1)  # (B,K)

    eff_k = 1.0 / (usage.pow(2).sum(dim=1) + eps)
    return eff_k.mean(), eff_k.detach()


def time_warp_batch(x, mask, length, warp_min=0.7, warp_max=1.3, min_len=8):
    """Stretch or compress each sample along time by a random factor.

    For every sample a factor is drawn uniformly from ``[warp_min, warp_max]``
    and the valid part of the signal is linearly resampled to
    ``max(round(valid_len * factor), min_len)`` steps. The warped samples are
    then zero-padded to the longest warped length in the batch.

    Args:
        x: Padded batch, shape (B, C, Tmax).
        mask: Validity mask, shape (B, Tmax); only its dtype is used.
        length: Valid length per sample, shape (B,), float or int.
        warp_min: Lower bound of the warp factor.
        warp_max: Upper bound of the warp factor.
        min_len: Minimum warped length in samples.

    Returns:
        ``(xw, mw, lw)``: warped batch (B, C, Tw_max), its validity mask
        (B, Tw_max) and the warped valid lengths (B,) in ``length``'s dtype.
    """
    n_batch, n_ch, _ = x.shape
    device, dtype = x.device, x.dtype

    # Per-sample warp (tempo change); one random draw per sample, in order.
    warped = []
    for b in range(n_batch):
        valid_len = max(int(length[b].item()), 1)
        segment = x[b:b + 1, :, :valid_len]  # (1,C,Tb)

        factor = float(torch.empty(1, device=device).uniform_(warp_min, warp_max).item())
        new_len = max(int(round(valid_len * factor)), min_len)

        resampled = F.interpolate(segment, size=new_len, mode="linear", align_corners=False)
        warped.append(resampled.squeeze(0))  # (C,Tw)

    new_lens = [seg.shape[1] for seg in warped]
    padded_len = max(new_lens, default=0)

    # Zero-pad every warped sample to the common length.
    xw = torch.zeros((n_batch, n_ch, padded_len), device=device, dtype=dtype)
    mw = torch.zeros((n_batch, padded_len), device=device, dtype=mask.dtype)
    for b, (seg, n_valid) in enumerate(zip(warped, new_lens)):
        xw[b, :, :n_valid] = seg
        mw[b, :n_valid] = 1

    lw = torch.tensor([float(n) for n in new_lens], device=device, dtype=length.dtype)
    return xw, mw, lw


# ---------------------------------------------------------------------
# 5) Train  (ONLY MODIFIED PART: remove shift, add time-warp consistency)
# ---------------------------------------------------------------------
def train_one_epoch(model, loader, optimizer, config, device):
    """Train ``model`` for one pass over ``loader`` and return averaged stats.

    The total loss is
    ``rate MSE + lambda_recon * recon + lambda_smooth * smooth
    + lambda_phase_ent * phase entropy + lambda_effk * effK
    + lambda_warp * time-warp consistency``.
    The time-warp term is only computed when ``lambda_warp > 0``.

    Args:
        model: Network called as ``model(x, mask, tau=...)`` and returning
            ``(rate_hat, z, x_hat, aux)``.
        loader: Iterable of batches as produced by ``collate_variable_length``.
        optimizer: Optimizer stepped once per batch.
        config: Dict with ``"fs"`` and optional ``tau``, ``lambda_*``,
            ``warp_min``, ``warp_max`` and ``warp_detach_ref`` entries.
        device: Device the batch tensors are moved to.

    Returns:
        Dict of per-batch averages for each loss term and the count MAE. All
        values are 0.0 when the loader yields no batch.
    """
    model.train()
    stats = {k: 0.0 for k in [
        'loss', 'loss_rate', 'loss_recon', 'loss_smooth', 'loss_phase_ent', 'loss_effk', 'loss_warp',
        'mae_count'
    ]}

    fs = config["fs"]
    tau = config.get("tau", 1.0)

    lam_recon = config.get("lambda_recon", 1.0)
    lam_smooth = config.get("lambda_smooth", 0.05)
    lam_phase_ent = config.get("lambda_phase_ent", 0.01)
    lam_effk = config.get("lambda_effk", 0.005)

    # Time-warp consistency weight and warp-factor range.
    lam_warp = config.get("lambda_warp", 0.0)
    warp_min = config.get("warp_min", 0.7)
    warp_max = config.get("warp_max", 1.3)
    warp_detach_ref = bool(config.get("warp_detach_ref", True))

    # Count the batches actually processed instead of relying on
    # len(loader), so an empty loader cannot cause a division by zero.
    n_batches = 0

    for batch in loader:
        n_batches += 1

        x = batch["data"].to(device)         # (B,C,T)
        mask = batch["mask"].to(device)      # (B,T)
        y_count = batch["count"].to(device)  # (B,)
        length = batch["length"].to(device)  # (B,)

        duration = torch.clamp(length / fs, min=1e-6)  # sec
        y_rate = y_count / duration                    # reps/s

        optimizer.zero_grad()

        rate_hat, z, x_hat, aux = model(x, mask, tau=tau)

        # (1) match the average rate (MSE)
        loss_rate = F.mse_loss(rate_hat, y_rate)

        # (2) reconstruction
        loss_recon = masked_recon_mse(x_hat, x, mask)

        # (3) smoothness of the per-timestep rep rate
        loss_smooth = temporal_smoothness(aux["rep_rate_t"], mask)

        # (4) phase exclusivity (entropy)
        loss_phase_ent = phase_entropy_loss(aux["phase_p"], mask)

        # (5) effective-K usage (overall)
        loss_effk, _ = effK_usage_loss(aux["phase_p"], mask)

        # (6) time-warp consistency (count level):
        #     pred_count = rate_hat * duration should be unchanged by a
        #     time-warp of the input.
        if lam_warp > 0.0:
            xw, mw, lw = time_warp_batch(
                x=x, mask=mask, length=length,
                warp_min=warp_min, warp_max=warp_max, min_len=8
            )
            dur_w = torch.clamp(lw / fs, min=1e-6)

            rate_hat_w, _, _, _ = model(xw, mw, tau=tau)

            pred_count = rate_hat * duration
            pred_count_w = rate_hat_w * dur_w

            # With a detached reference only the warped branch is pulled
            # towards the un-warped prediction.
            pred_count_ref = pred_count.detach() if warp_detach_ref else pred_count

            loss_warp = F.l1_loss(pred_count_w, pred_count_ref)
        else:
            loss_warp = torch.zeros((), device=device)

        loss = (loss_rate
                + lam_recon * loss_recon
                + lam_smooth * loss_smooth
                + lam_phase_ent * loss_phase_ent
                + lam_effk * loss_effk
                + lam_warp * loss_warp)

        loss.backward()
        optimizer.step()

        # MAE on count: predicted rate * duration = predicted count
        count_hat = rate_hat * duration
        stats['loss'] += loss.item()
        stats['loss_rate'] += loss_rate.item()
        stats['loss_recon'] += loss_recon.item()
        stats['loss_smooth'] += loss_smooth.item()
        stats['loss_phase_ent'] += loss_phase_ent.item()
        stats['loss_effk'] += loss_effk.item()
        stats['loss_warp'] += float(loss_warp.item())
        stats['mae_count'] += torch.abs(count_hat - y_count).mean().item()

    if n_batches == 0:
        return stats
    return {k: v / n_batches for k, v in stats.items()}


# ---------------------------------------------------------------------
# 6) Visualization helpers (subject-wise subplot)
# ---------------------------------------------------------------------
def _smooth_1d(y, sigma=2.0):
    """
    지저분한 노이즈를 다듬어서 시각화
    """
    y = np.asarray(y, dtype=np.float32)
    return gaussian_filter1d(y, sigma=sigma)


def compute_phase_entropy_mean(phase_p_np, eps=1e-8):
    """Time-averaged entropy of a phase distribution.

    Args:
        phase_p_np: Array-like of shape (T, K) with per-timestep phase
            probabilities.
        eps: Stabiliser added inside the log.

    Returns:
        Python float: mean over T of ``-sum_k p * log(p + eps)``.
    """
    probs = np.asarray(phase_p_np, dtype=np.float32)
    entropy_per_step = -np.sum(probs * np.log(probs + eps), axis=1)  # (T,)
    return float(np.mean(entropy_per_step))


def downsample_time_axis(arr, max_T=2000):
    """Subsample the first axis of ``arr`` to at most ``max_T`` entries.

    Very long sequences are hard to visualise, so when ``arr`` (shape
    (T, ...)) has more than ``max_T`` rows, ``max_T`` evenly spaced rows from
    the first to the last are kept.

    Returns:
        ``(arr_ds, idx)`` where ``idx`` holds the original time indices of
        the rows kept. ``arr`` itself is returned when no subsampling is
        needed.
    """
    n_steps = arr.shape[0]
    if n_steps > max_T:
        keep = np.linspace(0, n_steps - 1, max_T).astype(int)
        return arr[keep], keep
    return arr, np.arange(n_steps)


def plot_phase_heatmap_and_dominant(
    phase_p_np,
    fs,
    title="phase_p heatmap + dominant phase",
    max_T=2000
):
    """Show a phase-probability heatmap above a dominant-phase timeline.

    Args:
        phase_p_np: Array-like of shape (T, K) with phase probabilities.
        fs: Sampling rate in Hz, used to convert indices to seconds.
        title: Title of the heatmap panel.
        max_T: Maximum number of timesteps drawn (longer inputs are
            subsampled for display).
    """
    phase_p_np = np.asarray(phase_p_np, dtype=np.float32)
    assert phase_p_np.ndim == 2, f"phase_p_np must be (T,K), got {phase_p_np.shape}"

    # (1) Subsample for display and convert kept indices to seconds.
    phase_ds, kept_idx = downsample_time_axis(phase_p_np, max_T=max_T)  # (T',K)
    n_phases = phase_ds.shape[1]
    t_sec = kept_idx / float(fs)
    t_first, t_last = t_sec[0], t_sec[-1]

    # (2) Index of the most probable phase at each displayed timestep.
    dominant = np.argmax(phase_ds, axis=1)  # (T',)

    # (3) Two stacked panels sharing the time axis.
    image_style = dict(aspect="auto", origin="lower", interpolation="nearest")

    fig = plt.figure(figsize=(30, 10))
    grid = fig.add_gridspec(2, 1, height_ratios=[4, 1], hspace=0.25)

    # --- Heatmap (top): x = time, y = phase index range ---
    ax_heat = fig.add_subplot(grid[0, 0])
    heat = ax_heat.imshow(
        phase_ds.T,                 # (K, T')
        extent=[t_first, t_last, 0, n_phases],
        **image_style
    )
    ax_heat.set_title(title, fontsize=24, pad=10)
    ax_heat.set_ylabel("Phase k", fontsize=18)
    ax_heat.set_xlabel("Time (sec)", fontsize=18)
    colorbar = fig.colorbar(heat, ax=ax_heat, fraction=0.015, pad=0.01)
    colorbar.set_label("phase_p(t,k)", fontsize=14)

    # --- Dominant phase timeline (bottom) ---
    ax_dom = fig.add_subplot(grid[1, 0], sharex=ax_heat)
    ax_dom.imshow(
        dominant[None, :],          # (1, T')
        extent=[t_first, t_last, 0, 1],
        **image_style
    )
    ax_dom.set_yticks([])
    ax_dom.set_ylabel("dominant", fontsize=14)
    ax_dom.set_xlabel("Time (sec)", fontsize=18)

    plt.tight_layout()
    plt.show()


def plot_folds_test_subplot(viz_cache, fs, title="Fold-wise TEST visualization (only test_subj)"):
    """Draw one panel per cached test item: rep rate and cumulative count.

    Args:
        viz_cache: List of dicts, each providing ``'fold'``, ``'test_subj'``,
            ``'t'``, ``'rep_rate'``, ``'gt'``, ``'pred'``, ``'diff'``,
            ``'k_hat'`` and ``'entropy'``.
        fs: Sampling rate in Hz, used to integrate the rate into a count.
        title: Figure-level title.

    Prints a notice and returns without plotting when ``viz_cache`` is
    ``None`` or empty.
    """
    if viz_cache is None or len(viz_cache) == 0:
        print("[plot_folds_test_subplot] viz_cache is empty")
        return

    sns.set_theme(style="whitegrid", context="notebook", font_scale=2.0)
    palette = sns.color_palette("muted")
    c_rate, c_count = palette[0], palette[1]

    n_panels = len(viz_cache)
    fig, axes = plt.subplots(n_panels, 1, figsize=(36, 9 * n_panels), sharex=False)
    # A single panel comes back as a bare Axes; normalise to a flat array.
    axes = np.array([axes] if n_panels == 1 else axes).flatten()

    fig.suptitle(title, fontsize=40, y=0.995)

    for ax, item in zip(axes, viz_cache):
        t = item["t"]
        rep_rate = item["rep_rate"]
        gt_count = item["gt"]
        pred_count = item["pred"]
        diff = item["diff"]
        k_hat = item["k_hat"]
        entropy = item["entropy"]
        test_subj = item["test_subj"]
        fold = item["fold"]

        smoothed_rate = _smooth_1d(rep_rate, sigma=2.0)
        # Integrate the (unsmoothed) rate over time to get a running count.
        running_count = np.cumsum(rep_rate) / fs

        # Left axis: rep rate
        ax.plot(t, smoothed_rate, color=c_rate, linewidth=2.5, alpha=0.9)
        ax.fill_between(t, smoothed_rate, color=c_rate, alpha=0.15)
        ax.set_ylabel("Rep Rate (reps/s)", color=c_rate, fontweight='bold', fontsize=24)
        ax.grid(True, linestyle='--', alpha=0.5)
        ax.tick_params(axis='both', which='major', labelsize=20)

        # Right axis: cumulative count with the ground truth as a dotted line
        ax_count = ax.twinx()
        ax_count.plot(t, running_count, color=c_count, linewidth=3.5, alpha=1.0)
        ax_count.axhline(gt_count, linestyle=":", alpha=0.7)
        ax_count.set_ylabel("Count", color=c_count, fontweight='bold', fontsize=24)
        ax_count.tick_params(axis='y', labelcolor=c_count, labelsize=20)
        ax_count.grid(False)

        ax.set_title(
            f"Fold {fold:2d} | Test: {test_subj} | Pred {pred_count:.2f} / GT {gt_count:.0f} (Diff {diff:+.2f})\n"
            f"k_hat={k_hat:.2f} | phase_entropy={entropy:.3f}",
            fontsize=34, pad=10
        )
        ax.set_xlabel("Time (sec)", fontweight='bold', fontsize=24)

    plt.tight_layout(rect=[0, 0, 1, 0.985])
    plt.subplots_adjust(hspace=0.5)
    plt.show()


# ---------------------------------------------------------------------
# 7) Main (A1: Unseen Activity only, WITH windowing)
# ---------------------------------------------------------------------
def build_label_tuples_from_table(subjects, act_id, count_table):
    """Collect ``(subject, act_id, gt_count)`` labels for one activity.

    Args:
        subjects: Iterable of subject keys, processed in order.
        act_id: Activity id to look up.
        count_table: Dict like ``{"subject1": {6: 21, 7: 20}, ...}``.

    Returns:
        List of ``(subj, act_id, float(count))`` for every subject that has
        an entry for ``act_id``; subjects or activities missing from the
        table are silently skipped.
    """
    return [
        (subj, act_id, float(count_table[subj][act_id]))
        for subj in subjects
        if subj in count_table and act_id in count_table[subj]
    ]


def main():
    CONFIG = {
        "seed": 42,
        "data_dir": "/content/drive/MyDrive/Colab Notebooks/HAR_data/MHEALTHDATASET",

        "COLUMN_NAMES": [
            'acc_chest_x', 'acc_chest_y', 'acc_chest_z',
            'ecg_1', 'ecg_2',
            'acc_ankle_x', 'acc_ankle_y', 'acc_ankle_z',
            'gyro_ankle_x', 'gyro_ankle_y', 'gyro_ankle_z',
            'mag_ankle_x', 'mag_ankle_y', 'mag_ankle_z',
            'acc_arm_x', 'acc_arm_y', 'acc_arm_z',
            'gyro_arm_x', 'gyro_arm_y', 'gyro_arm_z',
            'mag_arm_x', 'mag_arm_y', 'mag_arm_z',
            'activity_id'
        ],
        "TARGET_ACTIVITIES_MAP": {
            6: 'Waist bends forward',
            7: 'Frontal elevation of arms',
        },

        # ✅ train/test 모두 C를 동일하게 유지해야 같은 모델로 평가 가능
        "ACT_FEATURE_MAP": {
            6: ['acc_chest_x', 'acc_chest_y', 'acc_chest_z',
                'acc_arm_x', 'acc_arm_y', 'acc_arm_z',
                'acc_ankle_x', 'acc_ankle_y', 'acc_ankle_z'],
            7: ['acc_chest_x', 'acc_chest_y', 'acc_chest_z',
                'acc_arm_x', 'acc_arm_y', 'acc_arm_z',
                'acc_ankle_x', 'acc_ankle_y', 'acc_ankle_z'],
        },

        # Training Params
        "epochs": 100,
        "lr": 5e-4,
        "batch_size": 64,
        "fs": 50,

        # Model
        "hidden_dim": 128,
        "latent_dim": 16,
        "K_max": 6,

        # Loss Weights
        "lambda_recon": 1.0,
        "lambda_smooth": 0.05,
        "lambda_phase_ent": 0.01,
        "lambda_effk": 0.0005,

        # ✅ Added: Time-warp consistency (ONLY NEW)
        "lambda_warp": 0.2,      # 시작은 1e-2~5e-2 추천
        "warp_min": 0.5,          # 느리게(늘리기)
        "warp_max": 1.5,          # 빠르게(줄이기)
        "warp_detach_ref": True,  # warped 쪽만 원본 count로 맞추게(안정)

        # temperature (phase 경쟁 강도)
        "tau": 1.0,

        # A1 setting
        "TRAIN_ACT_ID": 6,
        "TEST_ACT_ID": 7,

        # -------------------------
        # ✅ Windowing (added)
        # -------------------------
        "USE_WINDOWING": True,
        "WIN_SEC": 8.0,
        "STRIDE_SEC": 4.0,
        "DROP_LAST": True,

        # ✅ dict 형태로 고쳐야 함
        "COUNT_TABLE": {
            "subject1":  {6: 21, 7: 20},
            "subject2":  {6: 19, 7: 20},
            "subject3":  {6: 21, 7: 20},
            "subject4":  {6: 20, 7: 20},
            "subject5":  {6: 20, 7: 20},
            "subject6":  {6: 20, 7: 20},
            "subject7":  {6: 20, 7: 20},
            "subject8":  {6: 21, 7: 19},
            "subject9":  {6: 21, 7: 19},
            "subject10": {6: 20, 7: 20},
        },
    }

    set_strict_seed(CONFIG["seed"])
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")

    full_data = load_mhealth_dataset(
        CONFIG["data_dir"],
        CONFIG["TARGET_ACTIVITIES_MAP"],
        CONFIG["COLUMN_NAMES"]
    )
    if not full_data:
        return

    subjects = [f"subject{i}" for i in range(1, 11)]

    # -------------------------
    # A1 split (NO LOSO)
    # Train: all subjects x TRAIN_ACT_ID
    # Test : all subjects x TEST_ACT_ID (unseen activity)
    # -------------------------
    train_labels = build_label_tuples_from_table(subjects, CONFIG["TRAIN_ACT_ID"], CONFIG["COUNT_TABLE"])
    test_labels  = build_label_tuples_from_table(subjects, CONFIG["TEST_ACT_ID"],  CONFIG["COUNT_TABLE"])

    if len(train_labels) == 0:
        print("[Error] No train labels. Check COUNT_TABLE / TRAIN_ACT_ID.")
        return
    if len(test_labels) == 0:
        print("[Error] No test labels. Check COUNT_TABLE / TEST_ACT_ID.")
        return

    train_data = prepare_trial_list(train_labels, full_data, CONFIG["TARGET_ACTIVITIES_MAP"], CONFIG["ACT_FEATURE_MAP"])
    test_data  = prepare_trial_list(test_labels,  full_data, CONFIG["TARGET_ACTIVITIES_MAP"], CONFIG["ACT_FEATURE_MAP"])

    if len(train_data) == 0:
        print("[Error] train_data empty. (missing files or ACT_FEATURE_MAP mismatch)")
        return
    if len(test_data) == 0:
        print("[Error] test_data empty. (missing files or ACT_FEATURE_MAP mismatch)")
        return

    # -------------------------
    # ✅ Windowing 적용 (Train에만)
    # -------------------------
    if CONFIG.get("USE_WINDOWING", False):
        train_windows = trial_list_to_windows(
            train_data,
            fs=CONFIG["fs"],
            win_sec=CONFIG["WIN_SEC"],
            stride_sec=CONFIG["STRIDE_SEC"],
            drop_last=CONFIG["DROP_LAST"],
        )
        print(f"[Windowing] train trials={len(train_data)} -> train windows={len(train_windows)}")
        train_data_for_loader = train_windows
    else:
        train_data_for_loader = train_data

    train_loader = DataLoader(
        TrialDataset(train_data_for_loader),
        batch_size=CONFIG["batch_size"],
        shuffle=True,
        collate_fn=collate_variable_length,
        num_workers=0
    )

    input_ch = train_data[0]['data'].shape[1]
    model = KAutoCountModel(
        input_ch=input_ch,
        hidden_dim=CONFIG["hidden_dim"],
        latent_dim=CONFIG["latent_dim"],
        K_max=CONFIG["K_max"]
    ).to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=CONFIG["lr"])
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.5)

    print("\n" + "-"*80)
    print(f" >>> A1 Train: all subjects x act{CONFIG['TRAIN_ACT_ID']}  |  Test: all subjects x act{CONFIG['TEST_ACT_ID']} (unseen activity)")
    print("-"*80)

    # ---- Train ----
    for epoch in range(CONFIG["epochs"]):
        _ = train_one_epoch(model, train_loader, optimizer, CONFIG, device)
        scheduler.step()

    # ---- Test (per-subject) ----
    model.eval()
    maes = []
    viz_cache = []

    for idx, item in enumerate(test_data):
        x_np = item["data"]  # (T,C)
        gt_count = float(item["count"])

        # ✅ windowing으로 pred_count 계산 (Test)
        pred_count, win_rates = predict_count_by_windowing(
            model,
            x_np=x_np,
            fs=CONFIG["fs"],
            win_sec=CONFIG["WIN_SEC"],
            stride_sec=CONFIG["STRIDE_SEC"],
            device=device,
            tau=CONFIG["tau"],
            batch_size=CONFIG["batch_size"]
        )

        mae = abs(pred_count - gt_count)
        maes.append(mae)

        # 시각화용(원래처럼 full forward 1회)
        with torch.no_grad():
            x_tensor = torch.tensor(x_np, dtype=torch.float32).transpose(0, 1).unsqueeze(0).to(device)
            rate_hat_full, _, _, aux = model(x_tensor, mask=None, tau=CONFIG["tau"])

            phase_p = aux["phase_p"].squeeze(0).detach().cpu().numpy()
            k_hat = float(aux["k_hat"].item())
            ent = compute_phase_entropy_mean(phase_p)
            rep_rate = aux["rep_rate_t"].squeeze(0).detach().cpu().numpy()
            T = rep_rate.shape[0]
            t = np.arange(T) / float(CONFIG["fs"])

        viz_cache.append({
            "fold": idx + 1,   # 그냥 인덱스
            "test_subj": item["meta"],  # subject+act name
            "t": t,
            "rep_rate": rep_rate,
            "gt": gt_count,
            "pred": float(pred_count),   # ✅ windowing pred
            "diff": float(pred_count - gt_count),
            "k_hat": k_hat,
            "entropy": ent,
            "phase_p": phase_p,
        })

        print(
            f"[Test {idx+1:02d}] {item['meta']} | Pred(win)={pred_count:.2f} / GT={gt_count:.2f} | "
            f"MAE={mae:.2f} | k_hat={k_hat:.2f} | ent={ent:.3f} | win_rate_mean={win_rates.mean():.3f}"
        )

    print("-"*80)
    print(f" >>> A1 Final MAE mean: {np.mean(maes):.3f}")
    print(f" >>> A1 Final MAE std : {np.std(maes):.3f}")
    print("-"*80)

    # 시각화(원하면 유지)
    # plot_folds_test_subplot(
    #     viz_cache,
    #     fs=CONFIG["fs"],
    #     title=f"A1 TEST visualization | Train act{CONFIG['TRAIN_ACT_ID']} -> Test act{CONFIG['TEST_ACT_ID']}"
    # )

    # for item in viz_cache:
    #     plot_phase_heatmap_and_dominant(
    #         item["phase_p"],
    #         fs=CONFIG["fs"],
    #         title=f"{item['test_subj']} | k_hat={item['k_hat']:.2f} | ent={item['entropy']:.3f}",
    #         max_T=2000
    #     )


# Entry point: call main() only when this module is run directly, not on import.
if __name__ == "__main__":
    main()


Device: cuda
Loading 10 subjects from /content/drive/MyDrive/Colab Notebooks/HAR_data/MHEALTHDATASET...
[Windowing] train trials=10 -> train windows=127

--------------------------------------------------------------------------------
 >>> A1 Train: all subjects x act6  |  Test: all subjects x act7 (unseen activity)
--------------------------------------------------------------------------------
[Test 01] subject1_Frontal elevation of arms | Pred(win)=39.38 / GT=20.00 | MAE=19.38 | k_hat=1.96 | ent=0.728 | win_rate_mean=0.641
[Test 02] subject2_Frontal elevation of arms | Pred(win)=26.87 / GT=20.00 | MAE=6.87 | k_hat=2.65 | ent=0.860 | win_rate_mean=0.404
[Test 03] subject3_Frontal elevation of arms | Pred(win)=36.53 / GT=20.00 | MAE=16.53 | k_hat=2.16 | ent=0.789 | win_rate_mean=0.540
[Test 04] subject4_Frontal elevation of arms | Pred(win)=37.74 / GT=20.00 | MAE=17.74 | k_hat=2.06 | ent=0.741 | win_rate_mean=0.576
[Test 05] subject5_Frontal elevation of arms | Pred(win)=36.86 / GT=20

In [2]:
# =========================
# Count-only K-auto (Multi-event) version  (NO manual Pair/lag/overlap/balance)
#
# 핵심 아이디어
# - Micro-event Rate 예측: 모델은 하나의 통합된 속도가 아니라, K_max개의 서로 다른 '작은 단위 동작'의 속도 흐름(r_k(t))을 예측
# - 샘플마다 "rep당 micro-event 개수" k_hat(>=1)을 스스로 추정
# - 우리가 주는 감독은 오직 rep count(=20) -> rep rate만 맞추게
#
# 1) rate head를 amp(t) * softmax(phase) 형태로 바꿔서 K-stream 간 "경쟁"이 생기게 함
# 2) k_hat을 따로 head로 예측하지 않고, phase 사용 분포로부터 effK(=effective K)로 정의 (자동 K)
# 3) phase sparsity/exclusivity를 위해 (a) phase entropy loss, (b) effK usage loss 추가
#
# ✅ Added (ONLY): Time-warp consistency loss
#   - 입력을 시간축으로 stretch/compress (time-warp)해도 "count"는 동일해야 한다는 제약
#   - 모델 출력은 rate이므로, pred_count = rate_hat * duration 을 이용해 count 일관성을 강제
# =========================

import os
import glob
import random
import numpy as np
import pandas as pd
import seaborn as sns

import matplotlib.pyplot as plt
from scipy.ndimage import gaussian_filter1d

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader


# ---------------------------------------------------------------------
# 1) Strict Seeding
# ---------------------------------------------------------------------
def set_strict_seed(seed: int):
    """Seed every random source used by this script and force deterministic cuDNN.

    Seeds Python's ``random``, NumPy and PyTorch (CPU, plus all CUDA devices
    when CUDA is available), exports ``PYTHONHASHSEED`` and turns off the
    cuDNN benchmark mode in favour of deterministic kernels.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)

    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


# ---------------------------------------------------------------------
# 2) Data Loading
# ---------------------------------------------------------------------
def load_mhealth_dataset(data_dir, target_activities_map, column_names):
    """Load per-subject mHealth logs and split each one by target activity.

    Args:
        data_dir: Directory searched for ``mHealth_subject*.log`` files.
        target_activities_map: Mapping ``activity_id -> activity name``; only
            rows whose ``activity_id`` is a key of this mapping are kept.
        column_names: Names assigned to the first ``len(column_names)`` columns
            of every log. Expected to contain ``'activity_id'``.

    Returns:
        ``{subject_key: {activity_name: DataFrame}}``. Each DataFrame holds the
        rows of one activity with the ``activity_id`` column dropped. An empty
        dict is returned when no log file matches. A file that fails to load is
        reported on stdout and skipped.
    """
    file_list = sorted(glob.glob(os.path.join(data_dir, "mHealth_subject*.log")))

    if not file_list:
        print(f"[Warning] No mHealth logs found in {data_dir}")
        return {}

    print(f"Loading {len(file_list)} subjects from {data_dir}...")

    full_dataset = {}
    for file_path in file_list:
        file_name = os.path.basename(file_path)
        subj_part = file_name.split('.')[0]

        # Normalise e.g. "mHealth_subject3" -> "subject3". Only a failed integer
        # parse (no usable digits in the stem) falls back to the raw stem;
        # anything else, including KeyboardInterrupt, must propagate.
        try:
            subj_id_num = int(''.join(filter(str.isdigit, subj_part)))
            subj_key = f"subject{subj_id_num}"
        except ValueError:
            subj_key = subj_part

        # Best-effort loading: one unreadable or malformed log must not abort
        # the whole dataset, so the error is reported and the file is skipped.
        try:
            df = pd.read_csv(file_path, sep="\t", header=None)
            df = df.iloc[:, :len(column_names)]
            df.columns = column_names

            subj_data = {}
            for label_code, activity_name in target_activities_map.items():
                activity_df = df[df['activity_id'] == label_code].copy()
                if not activity_df.empty:
                    subj_data[activity_name] = activity_df.drop(columns=['activity_id'])

            full_dataset[subj_key] = subj_data
        except Exception as e:
            print(f"Error loading {file_name}: {e}")

    return full_dataset


def prepare_trial_list(label_config, full_data, target_map, feature_map):
    """Build one z-scored trial per ``(subject, activity, count)`` label.

    Args:
        label_config: Iterable of ``(subj, act_id, gt_count)`` tuples.
        full_data: ``{subject: {activity_name: DataFrame}}``.
        target_map: ``act_id -> activity_name``.
        feature_map: ``act_id -> list of column names`` to select.

    Returns:
        List of dicts with keys ``'data'`` (float32 array, standardised per
        column over the whole trial), ``'count'`` (float) and ``'meta'``
        (``"<subj>_<activity_name>"``). Labels with no matching data are
        reported on stdout and skipped.
    """
    trials = []
    for subj, act_id, gt_count in label_config:
        act_name = target_map.get(act_id)
        columns = feature_map.get(act_id)

        if subj not in full_data or act_name not in full_data[subj]:
            print(f"[Skip] Missing data for {subj} - {act_name}")
            continue

        values = full_data[subj][act_name][columns].values.astype(np.float32)

        # Per-channel standardisation; the small constant keeps constant
        # channels from dividing by zero.
        channel_mean = values.mean(axis=0)
        channel_std = values.std(axis=0) + 1e-6

        trials.append({
            'data': (values - channel_mean) / channel_std,
            'count': float(gt_count),
            'meta': f"{subj}_{act_name}"
        })

    return trials


# ---------------------------------------------------------------------
# 2.5) Windowing (added)
# ---------------------------------------------------------------------
def trial_list_to_windows(trial_list, fs, win_sec=4.0, stride_sec=2.0, drop_last=True):
    """Slice each trial into sliding windows labelled from the trial-level rate.

    The per-window label assumes a constant repetition rate across the trial:
    ``rate_trial = count_total / total_duration`` and
    ``count_window = rate_trial * window_duration``.

    Args:
        trial_list: Output of ``prepare_trial_list`` (items with ``'data'`` of
            shape (T, C), ``'count'`` and ``'meta'``).
        fs: Sampling rate in Hz.
        win_sec: Window length in seconds.
        stride_sec: Hop between window starts in seconds.
        drop_last: When False, one extra (shorter) window running from the next
            stride position to the end of the trial is appended if that
            position lies inside the trial.

    Returns:
        List of window dicts with keys ``data``, ``count``, ``meta``,
        ``parent_meta``, ``parent_T``, ``win_start`` and ``win_end``.
    """
    win_len = int(round(win_sec * fs))
    stride = int(round(stride_sec * fs))
    assert win_len > 0 and stride > 0

    fs_hz = float(fs)
    windows = []

    for trial in trial_list:
        signal = trial["data"]           # (T, C)
        n_samples = signal.shape[0]
        total_count = float(trial["count"])
        parent = trial["meta"]

        # Mean repetition rate (reps/s) over the whole trial.
        rate_trial = total_count / max(n_samples / fs_hz, 1e-6)

        if n_samples < win_len:
            # Shorter than one window: keep the trial whole, unpadded.
            pieces = [(0, n_samples, signal)]
        else:
            pieces = [
                (st, st + win_len, signal[st:st + win_len])
                for st in range(0, n_samples - win_len + 1, stride)
            ]
            if not drop_last:
                tail_start = pieces[-1][0] + stride
                if tail_start < n_samples:
                    pieces.append((tail_start, n_samples, signal[tail_start:n_samples]))

        windows.extend(
            {
                "data": chunk,
                "count": rate_trial * ((ed - st) / fs_hz),
                "meta": f"{parent}__win[{st}:{ed}]",
                "parent_meta": parent,
                "parent_T": n_samples,
                "win_start": st,
                "win_end": ed,
            }
            for st, ed, chunk in pieces
        )

    return windows


def predict_count_by_windowing(model, x_np, fs, win_sec, stride_sec, device, tau=1.0, batch_size=64):
    """Estimate a trial's repetition count from per-window rate predictions.

    The trial is cut into fixed-length windows, the model predicts one rate
    per window, and the count is ``mean(window rates) * trial duration``.
    A trial no longer than one window is passed to the model whole.

    Args:
        model: Callable as ``model(x, mask=None, tau=...)`` whose first return
            value is a per-sample rate of shape (B,). It is switched to eval
            mode on every call path.
        x_np: Array of shape (T, C), already normalised.
        fs: Sampling rate in Hz.
        win_sec: Window length in seconds.
        stride_sec: Hop between window starts in seconds.
        device: Device the input tensors are moved to.
        tau: Temperature forwarded to the model.
        batch_size: Number of windows per forward pass.

    Returns:
        ``(pred_count, window_rates)`` - a float and a 1-D array holding one
        predicted rate per evaluated window.

    Raises:
        ValueError: If the window length or stride rounds to less than one
            sample.
    """
    win_len = int(round(win_sec * fs))
    stride = int(round(stride_sec * fs))
    if win_len <= 0 or stride <= 0:
        raise ValueError(
            f"win_sec and stride_sec must span at least one sample "
            f"(got win_len={win_len}, stride={stride})"
        )

    T = x_np.shape[0]
    total_dur = T / float(fs)

    # Inference mode for BOTH paths; previously only the multi-window path set it.
    model.eval()

    if T <= win_len:
        x_tensor = torch.tensor(x_np, dtype=torch.float32).transpose(0, 1).unsqueeze(0).to(device)  # (1,C,T)
        with torch.no_grad():
            rate_hat, _, _, _ = model(x_tensor, mask=None, tau=tau)
        rate = float(rate_hat.item())
        return float(rate * total_dur), np.array([rate], dtype=np.float32)

    # Only full windows are evaluated; samples after the last full window
    # contribute through total_dur but not through the rate average.
    starts = range(0, T - win_len + 1, stride)
    windows = np.stack([x_np[st:st + win_len] for st in starts], axis=0)  # (N, win_len, C)
    xw = torch.tensor(windows, dtype=torch.float32).permute(0, 2, 1).to(device)  # (N, C, win_len)

    rates = []
    with torch.no_grad():
        for i in range(0, xw.shape[0], batch_size):
            r_hat, _, _, _ = model(xw[i:i + batch_size], mask=None, tau=tau)  # (B,)
            rates.append(r_hat.detach().cpu().numpy())

    rates = np.concatenate(rates, axis=0)  # (N,)
    pred_count = float(rates.mean()) * total_dur
    return float(pred_count), rates


class TrialDataset(Dataset):
    """Dataset view over a list of trial/window dicts.

    Each item is returned as ``(data, count, meta)`` where ``data`` is a
    float32 tensor transposed from (T, C) to (C, T), ``count`` is a float32
    scalar tensor and ``meta`` is passed through unchanged.
    """

    def __init__(self, trial_list):
        self.trials = trial_list

    def __len__(self):
        return len(self.trials)

    def __getitem__(self, idx):
        record = self.trials[idx]
        series_tc = torch.tensor(record['data'], dtype=torch.float32)   # (T, C)
        label = torch.tensor(record['count'], dtype=torch.float32)
        return series_tc.transpose(0, 1), label, record['meta']        # data as (C, T)


def collate_variable_length(batch):
    """Zero-pad variable-length samples to a common length and build masks.

    Args:
        batch: List of ``(data, count, meta)`` tuples with ``data`` of shape
            (C, T_i).

    Returns:
        Dict with ``data`` (B, C, T_max), ``mask`` (B, T_max; 1 on valid
        steps, 0 on padding), ``count`` (B,), ``length`` (B,; float32 valid
        lengths) and ``meta`` (list of the per-sample meta values).
    """
    seq_lens = [sample[0].shape[1] for sample in batch]
    longest = max(seq_lens)
    n_channels = batch[0][0].shape[0]

    data_rows, mask_rows = [], []
    for (series, _, _), n_valid in zip(batch, seq_lens):
        n_pad = longest - n_valid
        if n_pad > 0:
            series = torch.cat([series, torch.zeros(n_channels, n_pad)], dim=1)

        valid_flags = torch.zeros(longest)
        valid_flags[:n_valid] = 1.0

        data_rows.append(series)
        mask_rows.append(valid_flags)

    return {
        "data": torch.stack(data_rows),                         # (B, C, T_max)
        "mask": torch.stack(mask_rows),                         # (B, T_max)
        "count": torch.stack([sample[1] for sample in batch]),  # (B,)
        "length": torch.tensor(seq_lens, dtype=torch.float32),  # (B,)
        "meta": [sample[2] for sample in batch]
    }


# ---------------------------------------------------------------------
# 3) Model
# ---------------------------------------------------------------------
class ManifoldEncoder(nn.Module):
    """1-D convolutional encoder mapping (B, C, T) signals to (B, T, D) latents.

    Two width-5 convolutions (padding 2, so T is preserved) with ReLU are
    followed by a pointwise convolution down to ``latent_dim`` channels.
    """

    def __init__(self, input_ch, hidden_dim=128, latent_dim=16):
        super().__init__()
        layers = [
            nn.Conv1d(input_ch, hidden_dim, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv1d(hidden_dim, hidden_dim, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv1d(hidden_dim, latent_dim, kernel_size=1),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        # (B, C, T) -> conv stack -> (B, D, T) -> time-major (B, T, D)
        latent_dt = self.net(x)
        return latent_dt.transpose(1, 2)


class ManifoldDecoder(nn.Module):
    """1-D convolutional decoder mapping (B, T, D) latents back to (B, C, T).

    Mirrors the encoder layout: two width-5 convolutions (padding 2) with
    ReLU, then a pointwise convolution to ``out_ch`` channels.
    """

    def __init__(self, latent_dim, hidden_dim, out_ch):
        super().__init__()
        layers = [
            nn.Conv1d(latent_dim, hidden_dim, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv1d(hidden_dim, hidden_dim, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv1d(hidden_dim, out_ch, kernel_size=1),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, z):
        # (B, T, D) -> channel-major (B, D, T) -> conv stack -> (B, C, T)
        latent_dt = z.transpose(1, 2)
        return self.net(latent_dt)


class MultiRateHead(nn.Module):
    """Per-timestep head producing a total intensity and a K-way phase split.

    A two-layer MLP emits ``1 + K_max`` values per step: the first is an
    amplitude logit, the remaining ``K_max`` are phase logits.
    """

    def __init__(self, latent_dim=16, hidden=64, K_max=6):
        super().__init__()
        self.K_max = K_max
        layers = [
            nn.Linear(latent_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 1 + K_max),  # [amp_logit | phase_logits...]
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, z, tau=1.0):
        """Return ``(amp, phase, phase_logits)`` for latents ``z`` of shape (B, T, D).

        ``amp`` (B, T) is non-negative via softplus; ``phase`` (B, T, K) is a
        softmax over the phase logits at temperature ``tau``.
        """
        raw = self.net(z)                                    # (B, T, 1+K)
        amp_logit, phase_logits = raw[..., 0], raw[..., 1:]
        amp = F.softplus(amp_logit)                          # (B, T), >= 0
        phase = torch.softmax(phase_logits / tau, dim=-1)    # (B, T, K), sums to 1
        return amp, phase, phase_logits


class KAutoCountModel(nn.Module):
    """Count-only model with an automatically derived micro-event multiplicity.

    - The rate head yields a total micro-event intensity ``amp(t)`` and a
      phase distribution over ``K_max`` streams; ``r_k(t) = amp(t) * phase_k(t)``.
    - ``k_hat`` (per sample) is the effective number of phases in use,
      ``1 / sum_k(p_bar_k^2)`` with ``p_bar`` the time-averaged phase usage.
    - ``rep_rate(t) = amp(t) / k_hat`` (the phases sum to one, so ``amp`` is
      the sum of all micro-event rates).

    ``k_hidden`` is accepted by the constructor but not used by this class.
    """

    def __init__(self, input_ch, hidden_dim=128, latent_dim=16, K_max=6, k_hidden=64):
        super().__init__()
        self.encoder = ManifoldEncoder(input_ch, hidden_dim, latent_dim)
        self.decoder = ManifoldDecoder(latent_dim, hidden_dim, input_ch)
        self.rate_head = MultiRateHead(latent_dim, hidden=hidden_dim, K_max=K_max)

        self._init_weights()

    def _init_weights(self):
        """Kaiming-initialise conv/linear weights and zero their biases.

        The amplitude-logit bias of the rate head is then set to -2 so the
        initial softplus amplitude starts small.
        """
        for layer in self.modules():
            if not isinstance(layer, (nn.Conv1d, nn.Linear)):
                continue
            nn.init.kaiming_normal_(layer.weight, mode='fan_out', nonlinearity='relu')
            if layer.bias is not None:
                nn.init.constant_(layer.bias, 0)

        with torch.no_grad():
            head_bias = self.rate_head.net[-1].bias
            head_bias.zero_()
            head_bias[0].fill_(-2.0)  # only the amp logit gets the -2 offset

    @staticmethod
    def _masked_mean_time(x, mask=None, eps=1e-6):
        """Mean over the time axis (dim 1) of ``x`` (B,T) or (B,T,K), honouring ``mask`` (B,T)."""
        if mask is None:
            return x.mean(dim=1)
        if x.dim() not in (2, 3):
            raise ValueError(f"Unsupported dim for masked mean: {x.dim()}")

        weights = mask.to(dtype=x.dtype, device=x.device)       # (B,T)
        if x.dim() == 3:
            weights = weights.unsqueeze(-1)                     # (B,T,1)
        return (x * weights).sum(dim=1) / (weights.sum(dim=1) + eps)

    def forward(self, x, mask=None, tau=1.0):
        """Run the model on ``x`` (B,C,T) with an optional validity ``mask`` (B,T).

        Returns:
            ``(avg_rep_rate, z, x_hat, aux)`` where ``avg_rep_rate`` is (B,),
            ``z`` the latents (B,T,D), ``x_hat`` the reconstruction (B,C,T) and
            ``aux`` a dict of intermediate tensors (``rates_k_t``, ``phase_p``,
            ``phase_logits``, ``micro_rate_t``, ``rep_rate_t``, ``k_hat``).
        """
        z = self.encoder(x)              # (B,T,D)
        x_hat = self.decoder(z)          # (B,C,T)

        amp_t, phase_p, phase_logits = self.rate_head(z, tau=tau)

        # Effective number of phases in use, in [1, K].
        phase_usage = self._masked_mean_time(phase_p, mask)        # (B,K)
        k_hat = 1.0 / (phase_usage.pow(2).sum(dim=1) + 1e-6)       # (B,)

        # Micro-event intensity divided by events-per-rep gives the rep rate.
        rep_rate_t = amp_t / (k_hat.unsqueeze(1) + 1e-6)           # (B,T)
        if mask is not None:
            rep_rate_t = rep_rate_t * mask

        avg_rep_rate = self._masked_mean_time(rep_rate_t, mask)    # (B,)

        aux = {
            "rates_k_t": amp_t.unsqueeze(-1) * phase_p,   # (B,T,K)
            "phase_p": phase_p,                           # (B,T,K)
            "phase_logits": phase_logits,                 # (B,T,K)
            "micro_rate_t": amp_t,                        # (B,T)
            "rep_rate_t": rep_rate_t,                     # (B,T)
            "k_hat": k_hat,                               # (B,)
        }
        return avg_rep_rate, z, x_hat, aux


# ---------------------------------------------------------------------
# 4) Loss utils
# ---------------------------------------------------------------------
def masked_recon_mse(x_hat, x, mask, eps=1e-6):
    """Mean squared reconstruction error over valid (unpadded) time steps.

    ``x_hat`` and ``x`` are (B, C, T); ``mask`` is (B, T). The squared error
    is summed over valid steps and divided by ``valid_steps * C`` (+ ``eps``).
    """
    valid = mask.to(dtype=x.dtype, device=x.device)
    masked_sq_err = (x_hat - x).pow(2) * valid.unsqueeze(1)   # broadcast mask over channels
    n_valid_elems = valid.sum() * x.shape[1] + eps
    return masked_sq_err.sum() / n_valid_elems


def temporal_smoothness(v, mask=None, eps=1e-6):
    """L1 penalty on the first temporal difference of ``v`` (B, T).

    With a ``mask`` (B, T), only differences between two valid neighbouring
    steps are averaged.
    """
    step_change = (v[:, 1:] - v[:, :-1]).abs()  # (B, T-1)
    if mask is not None:
        both_valid = (mask[:, 1:] * mask[:, :-1]).to(dtype=step_change.dtype, device=step_change.device)
        return (step_change * both_valid).sum() / (both_valid.sum() + eps)
    return step_change.mean()


def phase_entropy_loss(phase_p, mask=None, eps=1e-8):
    """Mean per-timestep entropy of the phase distribution ``phase_p`` (B, T, K).

    Minimising it pushes each time step towards a one-hot phase assignment.
    With a ``mask`` (B, T) the average runs over valid steps only.
    """
    log_p = torch.log(phase_p + eps)
    entropy_t = -(phase_p * log_p).sum(dim=-1)  # (B, T)
    if mask is not None:
        return (entropy_t * mask).sum() / (mask.sum() + eps)
    return entropy_t.mean()


def effK_usage_loss(phase_p, mask=None, eps=1e-6):
    """Effective number of phases used over time, as a sparsity penalty.

    ``effK = 1 / sum_k(p_bar_k^2)`` lies in [1, K], where ``p_bar`` is the
    (masked) time average of ``phase_p`` (B, T, K).

    Returns:
        ``(effK.mean(), effK.detach())`` - the batch-mean loss and the
        per-sample values without gradient.
    """
    if mask is not None:
        weights = mask.to(dtype=phase_p.dtype, device=phase_p.device).unsqueeze(-1)  # (B,T,1)
        usage = (phase_p * weights).sum(dim=1) / (weights.sum(dim=1) + eps)
    else:
        usage = phase_p.mean(dim=1)  # (B,K)

    concentration = usage.pow(2).sum(dim=1)
    eff_k = 1.0 / (concentration + eps)
    return eff_k.mean(), eff_k.detach()


def time_warp_batch(x, mask, length, warp_min=0.7, warp_max=1.3, min_len=8):
    """Resample every sample of a padded batch to a randomly scaled length.

    Args:
        x: (B, C, Tmax) padded inputs.
        mask: (B, Tmax) validity mask; only its dtype is used here.
        length: (B,) valid length of each sample (float or int tensor).
        warp_min, warp_max: Bounds of the per-sample uniform length factor;
            the new length is ``round(valid_len * factor)``.
        min_len: Lower bound on the warped length.

    Returns:
        ``(xw, mw, lw)``: warped inputs (B, C, Tw_max) zero-padded on the
        right, their validity mask (B, Tw_max) and the warped valid lengths
        (B,) in ``length``'s dtype.
    """
    n_batch, n_ch, _ = x.shape
    dev = x.device

    warped_segments, warped_lens = [], []
    for b in range(n_batch):
        valid_len = max(int(length[b].item()), 1)

        # One uniform draw per sample, taken on the input's device.
        factor = float(torch.empty(1, device=dev).uniform_(warp_min, warp_max).item())
        new_len = max(int(round(valid_len * factor)), min_len)

        # Linear resampling of the valid part only: (1,C,valid_len) -> (1,C,new_len)
        resampled = F.interpolate(
            x[b:b + 1, :, :valid_len], size=new_len, mode="linear", align_corners=False
        )
        warped_segments.append(resampled.squeeze(0))
        warped_lens.append(new_len)

    longest = max(warped_lens, default=0)

    # Right-pad every warped sample to the longest one.
    xw = torch.zeros((n_batch, n_ch, longest), device=dev, dtype=x.dtype)
    mw = torch.zeros((n_batch, longest), device=dev, dtype=mask.dtype)
    for b, (segment, new_len) in enumerate(zip(warped_segments, warped_lens)):
        xw[b, :, :new_len] = segment
        mw[b, :new_len] = 1

    lw = torch.tensor([float(n) for n in warped_lens], device=dev, dtype=length.dtype)
    return xw, mw, lw


# ---------------------------------------------------------------------
# 5) Train  (ONLY MODIFIED PART: remove shift, add time-warp consistency)
# ---------------------------------------------------------------------
def train_one_epoch(model, loader, optimizer, config, device):
    """Run one optimisation pass over ``loader`` and return per-batch mean stats.

    The total loss is the rate MSE plus weighted reconstruction, smoothness,
    phase-entropy, effective-K and (optional) time-warp consistency terms.
    Weights are read from ``config`` (``lambda_*``); the time-warp term is
    active only when ``lambda_warp > 0``.

    Returns:
        Dict of the accumulated values divided by the number of batches, for
        keys ``loss``, ``loss_rate``, ``loss_recon``, ``loss_smooth``,
        ``loss_phase_ent``, ``loss_effk``, ``loss_warp`` and ``mae_count``.
    """
    model.train()

    stat_names = (
        'loss', 'loss_rate', 'loss_recon', 'loss_smooth',
        'loss_phase_ent', 'loss_effk', 'loss_warp', 'mae_count',
    )
    stats = dict.fromkeys(stat_names, 0.0)

    fs = config["fs"]
    tau = config.get("tau", 1.0)

    lam_recon = config.get("lambda_recon", 1.0)
    lam_smooth = config.get("lambda_smooth", 0.05)
    lam_phase_ent = config.get("lambda_phase_ent", 0.01)
    lam_effk = config.get("lambda_effk", 0.005)

    # Time-warp consistency: weight, factor range, and whether the un-warped
    # prediction is treated as a fixed target.
    lam_warp = config.get("lambda_warp", 0.0)
    warp_min = config.get("warp_min", 0.7)
    warp_max = config.get("warp_max", 1.3)
    warp_detach_ref = bool(config.get("warp_detach_ref", True))

    for batch in loader:
        x = batch["data"].to(device)         # (B,C,T)
        mask = batch["mask"].to(device)      # (B,T)
        y_count = batch["count"].to(device)  # (B,)
        length = batch["length"].to(device)  # (B,)

        duration = (length / fs).clamp(min=1e-6)   # seconds
        y_rate = y_count / duration                # reps/s

        optimizer.zero_grad()

        rate_hat, _, x_hat, aux = model(x, mask, tau=tau)

        # Predicted count = predicted rate * duration; reused by the warp term
        # and by the MAE statistic.
        count_hat = rate_hat * duration

        loss_rate = F.mse_loss(rate_hat, y_rate)                      # (1) rate regression
        loss_recon = masked_recon_mse(x_hat, x, mask)                 # (2) reconstruction
        loss_smooth = temporal_smoothness(aux["rep_rate_t"], mask)    # (3) smooth rep rate
        loss_phase_ent = phase_entropy_loss(aux["phase_p"], mask)     # (4) phase exclusivity
        loss_effk = effK_usage_loss(aux["phase_p"], mask)[0]          # (5) effective-K usage

        # (6) The predicted count must stay the same after time-warping the input.
        if lam_warp > 0.0:
            xw, mw, lw = time_warp_batch(
                x=x, mask=mask, length=length,
                warp_min=warp_min, warp_max=warp_max, min_len=8
            )
            rate_hat_w = model(xw, mw, tau=tau)[0]
            count_hat_w = rate_hat_w * (lw / fs).clamp(min=1e-6)

            count_ref = count_hat.detach() if warp_detach_ref else count_hat
            loss_warp = F.l1_loss(count_hat_w, count_ref)
        else:
            loss_warp = torch.zeros((), device=device)

        loss = (loss_rate
                + lam_recon * loss_recon
                + lam_smooth * loss_smooth
                + lam_phase_ent * loss_phase_ent
                + lam_effk * loss_effk
                + lam_warp * loss_warp)

        loss.backward()
        optimizer.step()

        batch_values = {
            'loss': loss,
            'loss_rate': loss_rate,
            'loss_recon': loss_recon,
            'loss_smooth': loss_smooth,
            'loss_phase_ent': loss_phase_ent,
            'loss_effk': loss_effk,
            'loss_warp': loss_warp,
            'mae_count': torch.abs(count_hat - y_count).mean(),
        }
        for name, value in batch_values.items():
            stats[name] += value.item()

    n_batches = len(loader)
    return {name: total / n_batches for name, total in stats.items()}


# ---------------------------------------------------------------------
# 6) Visualization helpers (subject-wise subplot)
# ---------------------------------------------------------------------
def _smooth_1d(y, sigma=2.0):
    """Gaussian-smooth a 1-D signal (cast to float32) for cleaner plots."""
    signal = np.asarray(y, dtype=np.float32)
    smoothed = gaussian_filter1d(signal, sigma=sigma)
    return smoothed


def compute_phase_entropy_mean(phase_p_np, eps=1e-8):
    """Return the time-averaged entropy of a (T, K) phase-probability array as a float."""
    probs = np.asarray(phase_p_np, dtype=np.float32)
    entropy_per_step = -np.sum(probs * np.log(probs + eps), axis=1)  # (T,)
    return float(np.mean(entropy_per_step))


def downsample_time_axis(arr, max_T=2000):
    """Subsample the first axis of ``arr`` to at most ``max_T`` evenly spaced rows.

    Returns:
        ``(arr_ds, idx)`` where ``idx`` holds the original row indices that
        were kept. Arrays already short enough are returned unchanged.
    """
    n_rows = arr.shape[0]
    if n_rows > max_T:
        keep = np.linspace(0, n_rows - 1, max_T).astype(int)
        return arr[keep], keep
    return arr, np.arange(n_rows)


def plot_phase_heatmap_and_dominant(
    phase_p_np,
    fs,
    title="phase_p heatmap + dominant phase",
    max_T=2000
):
    """Plot the phase probabilities as a heatmap with a dominant-phase strip below.

    Args:
        phase_p_np: (T, K) array of phase probabilities.
        fs: Sampling rate, used to convert sample indices to seconds.
        title: Title of the heatmap panel.
        max_T: Maximum number of time steps drawn; longer inputs are
            subsampled before plotting.
    """
    phase_p_np = np.asarray(phase_p_np, dtype=np.float32)
    assert phase_p_np.ndim == 2, f"phase_p_np must be (T,K), got {phase_p_np.shape}"

    # Subsample for display and recover the time stamps of the kept rows.
    phase_ds, kept_idx = downsample_time_axis(phase_p_np, max_T=max_T)  # (T',K)
    n_phases = phase_ds.shape[1]
    t_sec = kept_idx / float(fs)
    t_first, t_last = t_sec[0], t_sec[-1]

    # Index of the most probable phase at each displayed step.
    dominant = np.argmax(phase_ds, axis=1)  # (T',)

    image_style = dict(aspect="auto", origin="lower", interpolation="nearest")

    fig = plt.figure(figsize=(30, 10))
    grid = fig.add_gridspec(2, 1, height_ratios=[4, 1], hspace=0.25)

    # Top panel: phase probabilities, one row per phase.
    ax_heat = fig.add_subplot(grid[0, 0])
    heat = ax_heat.imshow(
        phase_ds.T,                                  # (K, T')
        extent=[t_first, t_last, 0, n_phases],       # x = time, y = phase index range
        **image_style
    )
    ax_heat.set_title(title, fontsize=24, pad=10)
    ax_heat.set_ylabel("Phase k", fontsize=18)
    ax_heat.set_xlabel("Time (sec)", fontsize=18)
    colorbar = fig.colorbar(heat, ax=ax_heat, fraction=0.015, pad=0.01)
    colorbar.set_label("phase_p(t,k)", fontsize=14)

    # Bottom panel: dominant phase over time, sharing the time axis.
    ax_dom = fig.add_subplot(grid[1, 0], sharex=ax_heat)
    ax_dom.imshow(
        dominant[None, :],                           # (1, T')
        extent=[t_first, t_last, 0, 1],
        **image_style
    )
    ax_dom.set_yticks([])
    ax_dom.set_ylabel("dominant", fontsize=14)
    ax_dom.set_xlabel("Time (sec)", fontsize=18)

    plt.tight_layout()
    plt.show()


def plot_folds_test_subplot(viz_cache, fs, title="Fold-wise TEST visualization (only test_subj)"):
    """Draw one panel per cached test result: smoothed rep rate and cumulative count.

    Args:
        viz_cache: List of dicts, each with keys ``'fold'``, ``'test_subj'``,
            ``'t'``, ``'rep_rate'``, ``'gt'``, ``'pred'``, ``'diff'``,
            ``'k_hat'`` and ``'entropy'``.
        fs: Sampling rate, used to integrate the rate into a cumulative count.
        title: Figure-level title.
    """
    if viz_cache is None or len(viz_cache) == 0:
        print("[plot_folds_test_subplot] viz_cache is empty")
        return

    sns.set_theme(style="whitegrid", context="notebook", font_scale=2.0)
    palette = sns.color_palette("muted")
    c_rate, c_count = palette[0], palette[1]

    n = len(viz_cache)
    fig, axes = plt.subplots(n, 1, figsize=(36, 9 * n), sharex=False)
    # plt.subplots returns a bare Axes for a single row; normalise to a flat array.
    axes = np.array([axes] if n == 1 else axes).flatten()

    fig.suptitle(title, fontsize=40, y=0.995)

    for ax, item in zip(axes, viz_cache):
        t = item["t"]
        rep_rate = item["rep_rate"]
        gt_count = item["gt"]
        pred_count = item["pred"]
        diff = item["diff"]
        k_hat = item["k_hat"]
        entropy = item["entropy"]
        test_subj = item["test_subj"]
        fold = item["fold"]

        # The smoothed rate is only for display; the cumulative count
        # integrates the raw rate.
        rep_s = _smooth_1d(rep_rate, sigma=2.0)
        cum = np.cumsum(rep_rate) / fs

        # Left axis: repetition rate.
        ax.plot(t, rep_s, color=c_rate, linewidth=2.5, alpha=0.9)
        ax.fill_between(t, rep_s, color=c_rate, alpha=0.15)
        ax.set_ylabel("Rep Rate (reps/s)", color=c_rate, fontweight='bold', fontsize=24)
        ax.grid(True, linestyle='--', alpha=0.5)
        ax.tick_params(axis='both', which='major', labelsize=20)

        # Right axis: cumulative count with the ground-truth level marked.
        ax_count = ax.twinx()
        ax_count.plot(t, cum, color=c_count, linewidth=3.5, alpha=1.0)
        ax_count.axhline(gt_count, linestyle=":", alpha=0.7)
        ax_count.set_ylabel("Count", color=c_count, fontweight='bold', fontsize=24)
        ax_count.tick_params(axis='y', labelcolor=c_count, labelsize=20)
        ax_count.grid(False)

        ax.set_title(
            f"Fold {fold:2d} | Test: {test_subj} | Pred {pred_count:.2f} / GT {gt_count:.0f} (Diff {diff:+.2f})\n"
            f"k_hat={k_hat:.2f} | phase_entropy={entropy:.3f}",
            fontsize=34, pad=10
        )
        ax.set_xlabel("Time (sec)", fontweight='bold', fontsize=24)

    plt.tight_layout(rect=[0, 0, 1, 0.985])
    plt.subplots_adjust(hspace=0.5)
    plt.show()


# ---------------------------------------------------------------------
# 7) Main (A1: Unseen Activity only, WITH windowing)
# ---------------------------------------------------------------------
def build_label_tuples_from_table(subjects, act_id, count_table):
    """Collect ``(subject, act_id, gt_count)`` labels for one activity.

    Args:
        subjects: Iterable of subject keys, in the order the labels should appear.
        act_id: Activity id to look up.
        count_table: Dict such as ``{"subject1": {6: 21, 7: 20}, ...}``.

    Returns:
        List of ``(subj, act_id, float(count))``; subjects missing from the
        table, or lacking an entry for ``act_id``, are left out.
    """
    return [
        (subj, act_id, float(count_table[subj][act_id]))
        for subj in subjects
        if subj in count_table and act_id in count_table[subj]
    ]


def main():
    """
    Run the A1 "unseen activity" experiment end to end.

    Steps, in order:
      1. Build the CONFIG dict, seed all RNGs and pick the device.
      2. Load the mHealth logs from CONFIG["data_dir"]; return early if nothing is found.
      3. Build label tuples for TRAIN_ACT_ID and TEST_ACT_ID from COUNT_TABLE and
         turn them into normalized trials; return early if either side is empty.
      4. Optionally cut the training trials into fixed-length windows.
      5. Train KAutoCountModel for CONFIG["epochs"] epochs (Adam + StepLR).
      6. For every test trial, predict the count with predict_count_by_windowing,
         run one full-length forward pass to collect k_hat / phase statistics,
         and print the per-trial absolute error followed by its mean and std.

    Returns None; results are only printed.
    """
    CONFIG = {
        "seed": 42,
        "data_dir": "/content/drive/MyDrive/Colab Notebooks/HAR_data/MHEALTHDATASET",

        "COLUMN_NAMES": [
            'acc_chest_x', 'acc_chest_y', 'acc_chest_z',
            'ecg_1', 'ecg_2',
            'acc_ankle_x', 'acc_ankle_y', 'acc_ankle_z',
            'gyro_ankle_x', 'gyro_ankle_y', 'gyro_ankle_z',
            'mag_ankle_x', 'mag_ankle_y', 'mag_ankle_z',
            'acc_arm_x', 'acc_arm_y', 'acc_arm_z',
            'gyro_arm_x', 'gyro_arm_y', 'gyro_arm_z',
            'mag_arm_x', 'mag_arm_y', 'mag_arm_z',
            'activity_id'
        ],
        "TARGET_ACTIVITIES_MAP": {
            6: 'Waist bends forward',
            7: 'Frontal elevation of arms',
        },

        # Train and test activities must use the same number of channels C,
        # otherwise the model trained on one cannot be evaluated on the other.
        "ACT_FEATURE_MAP": {
            6: ['acc_chest_x', 'acc_chest_y', 'acc_chest_z',
                'acc_arm_x', 'acc_arm_y', 'acc_arm_z',
                'acc_ankle_x', 'acc_ankle_y', 'acc_ankle_z'],
            7: ['acc_chest_x', 'acc_chest_y', 'acc_chest_z',
                'acc_arm_x', 'acc_arm_y', 'acc_arm_z',
                'acc_ankle_x', 'acc_ankle_y', 'acc_ankle_z'],
        },

        # Training Params
        "epochs": 100,
        "lr": 5e-4,
        "batch_size": 64,
        "fs": 50,

        # Model
        "hidden_dim": 128,
        "latent_dim": 16,
        "K_max": 6,

        # Loss Weights
        "lambda_recon": 1.0,
        "lambda_smooth": 0.05,
        "lambda_phase_ent": 0.01,
        "lambda_effk": 0.0005,

        # Time-warp consistency (the only newly added loss term)
        "lambda_warp": 0.2,      # suggested starting range: 1e-2 to 5e-2
        "warp_min": 0.5,          # lower bound of the length factor (warped length = round(T * w), so w < 1 shortens)
        "warp_max": 1.5,          # upper bound of the length factor (w > 1 lengthens)
        "warp_detach_ref": True,  # detach the un-warped count so only the warped branch is pulled toward it (stability)

        # temperature (strength of the competition between phases)
        "tau": 1.0,

        # A1 setting
        "TRAIN_ACT_ID": 7,
        "TEST_ACT_ID": 6,

        # -------------------------
        # Windowing (added)
        # -------------------------
        "USE_WINDOWING": True,
        "WIN_SEC": 8.0,
        "STRIDE_SEC": 4.0,
        "DROP_LAST": True,

        # Must be a dict of dicts: subject key -> {activity id: repetition count}
        "COUNT_TABLE": {
            "subject1":  {6: 21, 7: 20},
            "subject2":  {6: 19, 7: 20},
            "subject3":  {6: 21, 7: 20},
            "subject4":  {6: 20, 7: 20},
            "subject5":  {6: 20, 7: 20},
            "subject6":  {6: 20, 7: 20},
            "subject7":  {6: 20, 7: 20},
            "subject8":  {6: 21, 7: 19},
            "subject9":  {6: 21, 7: 19},
            "subject10": {6: 20, 7: 20},
        },
    }

    set_strict_seed(CONFIG["seed"])
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")

    full_data = load_mhealth_dataset(
        CONFIG["data_dir"],
        CONFIG["TARGET_ACTIVITIES_MAP"],
        CONFIG["COLUMN_NAMES"]
    )
    # An empty result means no log files were found; nothing to do.
    if not full_data:
        return

    subjects = [f"subject{i}" for i in range(1, 11)]

    # -------------------------
    # A1 split (NO LOSO)
    # Train: all subjects x TRAIN_ACT_ID
    # Test : all subjects x TEST_ACT_ID (unseen activity)
    # -------------------------
    train_labels = build_label_tuples_from_table(subjects, CONFIG["TRAIN_ACT_ID"], CONFIG["COUNT_TABLE"])
    test_labels  = build_label_tuples_from_table(subjects, CONFIG["TEST_ACT_ID"],  CONFIG["COUNT_TABLE"])

    if len(train_labels) == 0:
        print("[Error] No train labels. Check COUNT_TABLE / TRAIN_ACT_ID.")
        return
    if len(test_labels) == 0:
        print("[Error] No test labels. Check COUNT_TABLE / TEST_ACT_ID.")
        return

    train_data = prepare_trial_list(train_labels, full_data, CONFIG["TARGET_ACTIVITIES_MAP"], CONFIG["ACT_FEATURE_MAP"])
    test_data  = prepare_trial_list(test_labels,  full_data, CONFIG["TARGET_ACTIVITIES_MAP"], CONFIG["ACT_FEATURE_MAP"])

    if len(train_data) == 0:
        print("[Error] train_data empty. (missing files or ACT_FEATURE_MAP mismatch)")
        return
    if len(test_data) == 0:
        print("[Error] test_data empty. (missing files or ACT_FEATURE_MAP mismatch)")
        return

    # -------------------------
    # Apply windowing to the training trials only
    # (test trials are windowed later, inside predict_count_by_windowing)
    # -------------------------
    if CONFIG.get("USE_WINDOWING", False):
        train_windows = trial_list_to_windows(
            train_data,
            fs=CONFIG["fs"],
            win_sec=CONFIG["WIN_SEC"],
            stride_sec=CONFIG["STRIDE_SEC"],
            drop_last=CONFIG["DROP_LAST"],
        )
        print(f"[Windowing] train trials={len(train_data)} -> train windows={len(train_windows)}")
        train_data_for_loader = train_windows
    else:
        train_data_for_loader = train_data

    train_loader = DataLoader(
        TrialDataset(train_data_for_loader),
        batch_size=CONFIG["batch_size"],
        shuffle=True,
        collate_fn=collate_variable_length,
        num_workers=0
    )

    # Channel count is read from the (T, C) array of the first training trial.
    input_ch = train_data[0]['data'].shape[1]
    model = KAutoCountModel(
        input_ch=input_ch,
        hidden_dim=CONFIG["hidden_dim"],
        latent_dim=CONFIG["latent_dim"],
        K_max=CONFIG["K_max"]
    ).to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=CONFIG["lr"])
    # Halve the learning rate every 30 epochs.
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.5)

    print("\n" + "-"*80)
    print(f" >>> A1 Train: all subjects x act{CONFIG['TRAIN_ACT_ID']}  |  Test: all subjects x act{CONFIG['TEST_ACT_ID']} (unseen activity)")
    print("-"*80)

    # ---- Train ----
    # The per-epoch statistics returned by train_one_epoch are discarded here.
    for epoch in range(CONFIG["epochs"]):
        _ = train_one_epoch(model, train_loader, optimizer, CONFIG, device)
        scheduler.step()

    # ---- Test (per-subject) ----
    model.eval()
    maes = []
    viz_cache = []

    for idx, item in enumerate(test_data):
        x_np = item["data"]  # (T,C)
        gt_count = float(item["count"])

        # Predicted count for the test trial, computed from sliding windows.
        # Note: this is used regardless of the USE_WINDOWING flag.
        pred_count, win_rates = predict_count_by_windowing(
            model,
            x_np=x_np,
            fs=CONFIG["fs"],
            win_sec=CONFIG["WIN_SEC"],
            stride_sec=CONFIG["STRIDE_SEC"],
            device=device,
            tau=CONFIG["tau"],
            batch_size=CONFIG["batch_size"]
        )

        # Absolute count error of this single trial.
        mae = abs(pred_count - gt_count)
        maes.append(mae)

        # For visualization / diagnostics: one forward pass over the full-length trial.
        # Its rate output is not used for the reported count.
        with torch.no_grad():
            x_tensor = torch.tensor(x_np, dtype=torch.float32).transpose(0, 1).unsqueeze(0).to(device)
            rate_hat_full, _, _, aux = model(x_tensor, mask=None, tau=CONFIG["tau"])

            phase_p = aux["phase_p"].squeeze(0).detach().cpu().numpy()
            k_hat = float(aux["k_hat"].item())
            ent = compute_phase_entropy_mean(phase_p)
            rep_rate = aux["rep_rate_t"].squeeze(0).detach().cpu().numpy()
            T = rep_rate.shape[0]
            t = np.arange(T) / float(CONFIG["fs"])

        viz_cache.append({
            "fold": idx + 1,   # just the running index of the test trial
            "test_subj": item["meta"],  # subject+act name
            "t": t,
            "rep_rate": rep_rate,
            "gt": gt_count,
            "pred": float(pred_count),   # windowed prediction
            "diff": float(pred_count - gt_count),
            "k_hat": k_hat,
            "entropy": ent,
            "phase_p": phase_p,
        })

        print(
            f"[Test {idx+1:02d}] {item['meta']} | Pred(win)={pred_count:.2f} / GT={gt_count:.2f} | "
            f"MAE={mae:.2f} | k_hat={k_hat:.2f} | ent={ent:.3f} | win_rate_mean={win_rates.mean():.3f}"
        )

    print("-"*80)
    print(f" >>> A1 Final MAE mean: {np.mean(maes):.3f}")
    print(f" >>> A1 Final MAE std : {np.std(maes):.3f}")
    print("-"*80)

    # Visualization (disabled; uncomment to plot from viz_cache)
    # plot_folds_test_subplot(
    #     viz_cache,
    #     fs=CONFIG["fs"],
    #     title=f"A1 TEST visualization | Train act{CONFIG['TRAIN_ACT_ID']} -> Test act{CONFIG['TEST_ACT_ID']}"
    # )

    # for item in viz_cache:
    #     plot_phase_heatmap_and_dominant(
    #         item["phase_p"],
    #         fs=CONFIG["fs"],
    #         title=f"{item['test_subj']} | k_hat={item['k_hat']:.2f} | ent={item['entropy']:.3f}",
    #         max_T=2000
    #     )


# Script entry point: run the experiment only when this file is executed directly.
if __name__ == "__main__":
    main()


Device: cuda
Loading 10 subjects from /content/drive/MyDrive/Colab Notebooks/HAR_data/MHEALTHDATASET...
[Windowing] train trials=10 -> train windows=132

--------------------------------------------------------------------------------
 >>> A1 Train: all subjects x act7  |  Test: all subjects x act6 (unseen activity)
--------------------------------------------------------------------------------
[Test 01] subject1_Waist bends forward | Pred(win)=21.85 / GT=21.00 | MAE=0.85 | k_hat=1.44 | ent=0.504 | win_rate_mean=0.356
[Test 02] subject2_Waist bends forward | Pred(win)=23.00 / GT=19.00 | MAE=4.00 | k_hat=1.38 | ent=0.469 | win_rate_mean=0.362
[Test 03] subject3_Waist bends forward | Pred(win)=24.15 / GT=21.00 | MAE=3.15 | k_hat=1.43 | ent=0.495 | win_rate_mean=0.374
[Test 04] subject4_Waist bends forward | Pred(win)=22.25 / GT=20.00 | MAE=2.25 | k_hat=1.52 | ent=0.549 | win_rate_mean=0.334
[Test 05] subject5_Waist bends forward | Pred(win)=17.34 / GT=20.00 | MAE=2.66 | k_hat=1.82 | ent

In [3]:
# =========================
# Count-only K-auto (Multi-event) version  (NO manual Pair/lag/overlap/balance)
#
# 핵심 아이디어
# - Micro-event Rate 예측: 모델은 하나의 통합된 속도가 아니라, K_max개의 서로 다른 '작은 단위 동작'의 속도 흐름(r_k(t))을 예측
# - 샘플마다 "rep당 micro-event 개수" k_hat(>=1)을 스스로 추정
# - 우리가 주는 감독은 오직 rep count(=20) -> rep rate만 맞추게
#
# 1) rate head를 amp(t) * softmax(phase) 형태로 바꿔서 K-stream 간 "경쟁"이 생기게 함
# 2) k_hat을 따로 head로 예측하지 않고, phase 사용 분포로부터 effK(=effective K)로 정의 (자동 K)
# 3) phase sparsity/exclusivity를 위해 (a) phase entropy loss, (b) effK usage loss 추가
#
# ✅ Added (ONLY): Time-warp consistency loss
#   - 입력을 시간축으로 stretch/compress (time-warp)해도 "count"는 동일해야 한다는 제약
#   - 모델 출력은 rate이므로, pred_count = rate_hat * duration 을 이용해 count 일관성을 강제
# =========================

import os
import glob
import random
import numpy as np
import pandas as pd
import seaborn as sns

import matplotlib.pyplot as plt
from scipy.ndimage import gaussian_filter1d

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader


# ---------------------------------------------------------------------
# 1) Strict Seeding
# ---------------------------------------------------------------------
def set_strict_seed(seed: int):
    """
    Seed every random number generator used in this file with the same value.

    Covers Python's `random`, the PYTHONHASHSEED environment variable, NumPy,
    PyTorch (CPU and, when available, all CUDA devices), and switches cuDNN to
    deterministic mode with autotuning disabled.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)

    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    if torch.cuda.is_available():
        for cuda_seed_fn in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
            cuda_seed_fn(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


# ---------------------------------------------------------------------
# 2) Data Loading
# ---------------------------------------------------------------------
def load_mhealth_dataset(data_dir, target_activities_map, column_names):
    """
    Load the mHealth per-subject logs and split each one by activity.

    Every file matching "mHealth_subject*.log" in data_dir is read as a
    tab-separated table without a header. Only the first len(column_names)
    columns are kept and named after column_names, which must contain
    'activity_id'.

    data_dir: directory that holds the log files
    target_activities_map: {activity code: activity name}; only these codes are extracted
    column_names: names for the leading columns of each log

    returns: {subject key: {activity name: DataFrame without 'activity_id'}}
             The subject key is "subject<N>" when the file stem contains digits,
             otherwise the file stem itself. Activities with no rows are omitted.
             Files that fail to load are reported and skipped. Returns {} when
             no log file is found.
    """
    file_list = sorted(glob.glob(os.path.join(data_dir, "mHealth_subject*.log")))

    if not file_list:
        print(f"[Warning] No mHealth logs found in {data_dir}")
        return {}

    print(f"Loading {len(file_list)} subjects from {data_dir}...")

    full_dataset = {}
    for file_path in file_list:
        file_name = os.path.basename(file_path)
        subj_part = file_name.split('.')[0]

        # int() raises ValueError when the stem carries no usable digits;
        # fall back to the raw stem in that case (and only in that case).
        try:
            subj_id_num = int(''.join(filter(str.isdigit, subj_part)))
            subj_key = f"subject{subj_id_num}"
        except ValueError:
            subj_key = subj_part

        try:
            df = pd.read_csv(file_path, sep="\t", header=None)
            df = df.iloc[:, :len(column_names)]
            df.columns = column_names

            subj_data = {}
            for label_code, activity_name in target_activities_map.items():
                activity_df = df[df['activity_id'] == label_code]
                if not activity_df.empty:
                    # drop() returns a new frame, so the caller never aliases df
                    subj_data[activity_name] = activity_df.drop(columns=['activity_id'])
        except Exception as e:
            # Deliberate best-effort: one unreadable or malformed log must not
            # abort the whole load, so report it and move on to the next file.
            print(f"Error loading {file_name}: {e}")
            continue

        full_dataset[subj_key] = subj_data

    return full_dataset


def prepare_trial_list(label_config, full_data, target_map, feature_map):
    """
    Turn (subject, activity id, count) labels into normalized trial records.

    For each label the activity name and feature columns are looked up in
    target_map / feature_map, the matching DataFrame is taken from
    full_data[subject][activity name], and every selected channel is z-scored
    over the whole trial (mean 0, std 1; 1e-6 is added to the std).

    returns: list of dicts with
        'data'  : (T, C) float32 array, per-channel standardized
        'count' : ground-truth count as float
        'meta'  : "<subject>_<activity name>"
    Labels without matching data are reported with a "[Skip]" line and dropped.
    """
    trials = []
    for subj, act_id, gt_count in label_config:
        act_name = target_map.get(act_id)
        feats = feature_map.get(act_id)

        # Guard clause: nothing recorded for this subject/activity pair.
        if subj not in full_data or act_name not in full_data[subj]:
            print(f"[Skip] Missing data for {subj} - {act_name}")
            continue

        signal = full_data[subj][act_name][feats].values.astype(np.float32)

        # Per-channel standardization over the time axis.
        mu = signal.mean(axis=0)
        sigma = signal.std(axis=0) + 1e-6

        trials.append(dict(
            data=(signal - mu) / sigma,
            count=float(gt_count),
            meta=f"{subj}_{act_name}",
        ))

    return trials


# ---------------------------------------------------------------------
# 2.5) Windowing (added)
# ---------------------------------------------------------------------
def trial_list_to_windows(trial_list, fs, win_sec=4.0, stride_sec=2.0, drop_last=True):
    """
    Cut every trial into sliding windows and give each window a count label.

    trial_list: output of prepare_trial_list() (items with 'data' (T,C), 'count', 'meta')
    fs: sampling rate used to convert seconds to samples
    win_sec / stride_sec: window length and hop, in seconds
    drop_last: when False, the samples after the last full window's start + stride
               are emitted as one extra, shorter window

    Window labels come from the trial-level average rate:
      rate_trial   = count_total / total_duration
      count_window = rate_trial * window_duration
    A trial shorter than one window is kept whole as a single window.

    returns: list of dicts with 'data', 'count', 'meta', 'parent_meta',
             'parent_T', 'win_start', 'win_end'
    """
    win_len = int(round(win_sec * fs))
    stride = int(round(stride_sec * fs))
    assert win_len > 0 and stride > 0

    def _record(chunk, rate, parent, parent_len, st, ed):
        # One window entry; its count is the trial rate times the window duration.
        return {
            "data": chunk,
            "count": rate * ((ed - st) / float(fs)),
            "meta": f"{parent}__win[{st}:{ed}]",
            "parent_meta": parent,
            "parent_T": parent_len,
            "win_start": st,
            "win_end": ed,
        }

    windows = []
    for trial in trial_list:
        signal = trial["data"]           # (T, C)
        n_samples = signal.shape[0]
        parent = trial["meta"]

        # reps/s averaged over the whole trial
        rate_trial = float(trial["count"]) / max(n_samples / float(fs), 1e-6)

        if n_samples < win_len:
            # Too short for a full window: keep the trial as-is (no padding).
            windows.append(_record(signal, rate_trial, parent, n_samples, 0, n_samples))
            continue

        starts = range(0, n_samples - win_len + 1, stride)
        windows.extend(
            _record(signal[st:st + win_len], rate_trial, parent, n_samples, st, st + win_len)
            for st in starts
        )

        if drop_last:
            continue

        # Optional trailing partial window.
        tail_start = starts[-1] + stride
        if tail_start < n_samples:
            windows.append(
                _record(signal[tail_start:n_samples], rate_trial, parent, n_samples, tail_start, n_samples)
            )

    return windows


def predict_count_by_windowing(model, x_np, fs, win_sec, stride_sec, device, tau=1.0, batch_size=64):
    """
    Predict the repetition count of one trial by averaging window-level rates.

    x_np: (T, C) normalized signal
    fs: sampling rate (samples per second)
    win_sec / stride_sec: window length and hop, in seconds
    device: device the input tensors are moved to
    tau: temperature forwarded to the model
    batch_size: number of windows per forward pass

    The model is called as model(x, mask=None, tau=tau) and the first element
    of its output is taken as the per-sample rate. The predicted count is
    mean(window rates) * total trial duration. A trial no longer than one
    window is passed through the model whole.

    The model is switched to eval mode on both paths (previously only on the
    multi-window path) and is left in eval mode on return.

    return: pred_count (float), window_rates (np.ndarray of shape (N,))
    """
    win_len = int(round(win_sec * fs))
    stride = int(round(stride_sec * fs))
    T = x_np.shape[0]
    total_dur = T / float(fs)

    # Inference must not depend on the mode the caller left the model in.
    model.eval()

    if T <= win_len:
        x_tensor = torch.tensor(x_np, dtype=torch.float32).transpose(0, 1).unsqueeze(0).to(device)  # (1,C,T)
        with torch.no_grad():
            rate_hat, _, _, _ = model(x_tensor, mask=None, tau=tau)
        rate = rate_hat.item()
        return float(rate * total_dur), np.array([float(rate)], dtype=np.float32)

    # Only full windows are used; samples after the last full window do not
    # contribute to the rate estimate but still count toward total_dur.
    starts = list(range(0, T - win_len + 1, stride))
    windows = np.stack([x_np[st:st + win_len] for st in starts], axis=0)  # (N, win_len, C)

    xw = torch.tensor(windows, dtype=torch.float32).permute(0, 2, 1).to(device)  # (N, C, win_len)

    rates = []
    with torch.no_grad():
        for i in range(0, xw.shape[0], batch_size):
            xb = xw[i:i + batch_size]
            r_hat, _, _, _ = model(xb, mask=None, tau=tau)  # (B,)
            rates.append(r_hat.detach().cpu().numpy())

    rates = np.concatenate(rates, axis=0)  # (N,)
    rate_mean = float(rates.mean())
    pred_count = rate_mean * total_dur
    return float(pred_count), rates


class TrialDataset(Dataset):
    """
    Thin Dataset wrapper around a list of trial/window dicts.

    Each item of the list must provide 'data' (time-major 2-D array),
    'count' and 'meta'. Indexing yields (data as float32 tensor with the two
    axes swapped, count as float32 scalar tensor, meta unchanged).
    """

    def __init__(self, trial_list):
        self.trials = trial_list

    def __len__(self):
        return len(self.trials)

    def __getitem__(self, idx):
        trial = self.trials[idx]
        signal_tc = torch.tensor(trial['data'], dtype=torch.float32)   # (T, C)
        signal_ct = torch.transpose(signal_tc, 0, 1)                   # (C, T)
        target = torch.tensor(trial['count'], dtype=torch.float32)
        return signal_ct, target, trial['meta']


def collate_variable_length(batch):
    """
    Collate (data, count, meta) samples of different lengths into one batch.

    Every data tensor is (C, T_i). Shorter samples are right-padded with zeros
    up to the longest T in the batch, and a 0/1 mask marks the valid steps.

    returns dict with
        "data"   : (B, C, T_max)
        "mask"   : (B, T_max), 1 on valid steps and 0 on padding
        "count"  : (B,)
        "length" : (B,) float32, the unpadded lengths
        "meta"   : list of the per-sample meta values
    """
    lengths = [sample[0].shape[1] for sample in batch]
    max_len = max(lengths)
    C = batch[0][0].shape[0]

    padded_data, masks = [], []
    for (data, _, _), T in zip(batch, lengths):
        gap = max_len - T
        if gap > 0:
            data = torch.cat([data, torch.zeros(C, gap)], dim=1)
        padded_data.append(data)
        # ones over the valid part followed by zeros over the padding
        masks.append(torch.cat([torch.ones(T), torch.zeros(gap)], dim=0))

    counts = [count for _, count, _ in batch]
    metas = [meta for _, _, meta in batch]

    return {
        "data": torch.stack(padded_data),         # (B, C, T_max)
        "mask": torch.stack(masks),               # (B, T_max)
        "count": torch.stack(counts),             # (B,)
        "length": torch.tensor(lengths, dtype=torch.float32),  # (B,)
        "meta": metas
    }


# ---------------------------------------------------------------------
# 3) Model
# ---------------------------------------------------------------------
class ManifoldEncoder(nn.Module):
    """
    Per-timestep convolutional encoder.

    Two width-5 convolutions (padding 2, so the time length is kept) with ReLU,
    followed by a 1x1 convolution down to latent_dim channels.
    """

    def __init__(self, input_ch, hidden_dim=128, latent_dim=16):
        super().__init__()
        layers = [
            nn.Conv1d(input_ch, hidden_dim, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv1d(hidden_dim, hidden_dim, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv1d(hidden_dim, latent_dim, kernel_size=1),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        # x: (B, C, T) -> latent (B, D, T) -> time-major (B, T, D)
        latent = self.net(x)
        return torch.transpose(latent, 1, 2)


class ManifoldDecoder(nn.Module):
    """
    Convolutional decoder that maps the latent sequence back to signal channels.

    Mirrors the encoder layout: two width-5 convolutions (padding 2) with ReLU,
    then a 1x1 convolution to out_ch channels. The time length is preserved.
    """

    def __init__(self, latent_dim, hidden_dim, out_ch):
        super().__init__()
        layers = [
            nn.Conv1d(latent_dim, hidden_dim, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv1d(hidden_dim, hidden_dim, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv1d(hidden_dim, out_ch, kernel_size=1),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, z):
        # z: (B, T, D) -> channel-major (B, D, T) -> reconstruction (B, C, T)
        channel_major = torch.transpose(z, 1, 2)
        return self.net(channel_major)


class MultiRateHead(nn.Module):
    """
    Per-timestep head producing a non-negative amplitude and a K-way phase distribution.

    A two-layer MLP maps each latent vector to 1 + K_max values: the first is
    the amplitude logit, the remaining K_max are phase logits.
    """

    def __init__(self, latent_dim=16, hidden=64, K_max=6):
        super().__init__()
        self.K_max = K_max
        layers = [
            nn.Linear(latent_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 1 + K_max),  # [amp_logit | phase_logits...]
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, z, tau=1.0):
        # z: (B,T,D) -> raw: (B,T,1+K)
        raw = self.net(z)

        # softplus keeps the total micro-event intensity >= 0
        amp = F.softplus(raw.select(-1, 0))               # (B,T)

        phase_logits = raw[..., 1:]                       # (B,T,K)
        # temperature-scaled softmax: each timestep's phases sum to 1
        phase = F.softmax(phase_logits / tau, dim=-1)     # (B,T,K)

        return amp, phase, phase_logits


class KAutoCountModel(nn.Module):
    """
    Count model with an automatically derived number of micro-events per rep.

    - the rate head yields an amplitude amp(t) and a K_max-way phase
      distribution; micro-event rates are r_k(t) = amp(t) * phase_k(t)
    - k_hat is not predicted by a separate head: it is the effective number of
      phases in use, 1 / sum_k(p_bar_k^2), with p_bar the time-averaged phase
    - rep_rate(t) = amp(t) / k_hat

    The k_hidden constructor argument is accepted but not used.
    """

    def __init__(self, input_ch, hidden_dim=128, latent_dim=16, K_max=6, k_hidden=64):
        super().__init__()
        self.encoder = ManifoldEncoder(input_ch, hidden_dim, latent_dim)
        self.decoder = ManifoldDecoder(latent_dim, hidden_dim, input_ch)
        self.rate_head = MultiRateHead(latent_dim, hidden=hidden_dim, K_max=K_max)

        self._init_weights()

    def _init_weights(self):
        """Kaiming-init all conv/linear weights, zero the biases, then bias the amplitude logit to -2."""
        for layer in self.modules():
            if not isinstance(layer, (nn.Conv1d, nn.Linear)):
                continue
            nn.init.kaiming_normal_(layer.weight, mode='fan_out', nonlinearity='relu')
            if layer.bias is not None:
                nn.init.constant_(layer.bias, 0)

        with torch.no_grad():
            head_bias = self.rate_head.net[-1].bias
            head_bias.zero_()
            head_bias[0].fill_(-2.0)  # only the amplitude logit starts at -2

    @staticmethod
    def _masked_mean_time(x, mask=None, eps=1e-6):
        """Mean over the time axis (dim 1) of x: (B,T) or (B,T,K), restricted to mask when given."""
        if mask is None:
            return x.mean(dim=1)
        if x.dim() not in (2, 3):
            raise ValueError(f"Unsupported dim for masked mean: {x.dim()}")

        weights = mask.to(dtype=x.dtype, device=x.device)   # (B,T)
        if x.dim() == 3:
            weights = weights.unsqueeze(-1)                 # (B,T,1)
        return (x * weights).sum(dim=1) / (weights.sum(dim=1) + eps)

    def forward(self, x, mask=None, tau=1.0):
        """
        x: (B,C,T), mask: (B,T) or None
        return:
          avg_rep_rate: (B,) time-averaged rep rate (masked mean when mask is given)
          z:            (B,T,D) latent sequence
          x_hat:        (B,C,T) reconstruction
          aux:          dict of intermediate tensors (see keys below)
        """
        z = self.encoder(x)              # (B,T,D)
        x_hat = self.decoder(z)          # (B,C,T)

        amp_t, phase_p, phase_logits = self.rate_head(z, tau=tau)

        # effective number of phases in use, in [1, K]
        p_bar = self._masked_mean_time(phase_p, mask)           # (B,K)
        k_hat = 1.0 / (p_bar.pow(2).sum(dim=1) + 1e-6)          # (B,)

        # the phases sum to 1, so the summed micro-event rate is amp_t itself
        rep_rate_t = amp_t / (k_hat.unsqueeze(1) + 1e-6)        # (B,T)

        if mask is None:
            avg_rep_rate = rep_rate_t.mean(dim=1)
        else:
            rep_rate_t = rep_rate_t * mask
            avg_rep_rate = (rep_rate_t * mask).sum(dim=1) / (mask.sum(dim=1) + 1e-6)

        aux = {
            "rates_k_t": amp_t.unsqueeze(-1) * phase_p,   # (B,T,K)
            "phase_p": phase_p,                           # (B,T,K)
            "phase_logits": phase_logits,                 # (B,T,K)
            "micro_rate_t": amp_t,                        # (B,T)
            "rep_rate_t": rep_rate_t,                     # (B,T)
            "k_hat": k_hat,                               # (B,)
        }
        return avg_rep_rate, z, x_hat, aux


# ---------------------------------------------------------------------
# 4) Loss utils
# ---------------------------------------------------------------------
def masked_recon_mse(x_hat, x, mask, eps=1e-6):
    """
    Mean squared reconstruction error over valid (unpadded) positions only.

    x_hat, x: (B,C,T); mask: (B,T) with 1 on valid steps.
    The sum of squared errors is divided by (number of valid steps * C) + eps.
    """
    valid = mask.to(dtype=x.dtype, device=x.device)
    sq_err = (x_hat - x).pow(2) * valid[:, None, :]   # zero out padded steps on every channel
    n_valid = valid.sum() * x.shape[1] + eps
    return sq_err.sum() / n_valid


def temporal_smoothness(v, mask=None, eps=1e-6):
    """
    L1 penalty on the first difference of v along time.

    v: (B,T); mask: optional (B,T). With a mask, a difference only counts when
    both of its neighbouring steps are valid.
    """
    step = (v[:, 1:] - v[:, :-1]).abs()  # (B,T-1)
    if mask is not None:
        pair_ok = (mask[:, 1:] * mask[:, :-1]).to(dtype=step.dtype, device=step.device)
        return (step * pair_ok).sum() / (pair_ok.sum() + eps)
    return step.mean()


def phase_entropy_loss(phase_p, mask=None, eps=1e-8):
    """
    Time-wise exclusivity term: average entropy of the phase distribution.

    Minimizing it pushes each timestep's phase vector toward one-hot.
    phase_p: (B,T,K); mask: optional (B,T) selecting the steps that count.
    """
    ent = -(phase_p * torch.log(phase_p + eps)).sum(dim=-1)  # (B,T)
    if mask is not None:
        return (ent * mask).sum() / (mask.sum() + eps)
    return ent.mean()


def effK_usage_loss(phase_p, mask=None, eps=1e-6):
    """
    Overall usage sparsity: effective number of phases used across time.

    effK = 1 / sum_k(p_bar_k^2), with p_bar the (masked) time average of
    phase_p; it lies in [1, K]. Minimizing it favours using fewer phases.

    phase_p: (B,T,K); mask: optional (B,T)
    returns: (batch mean of effK, detached per-sample effK of shape (B,))
    """
    if mask is not None:
        weights = mask.to(dtype=phase_p.dtype, device=phase_p.device).unsqueeze(-1)  # (B,T,1)
        usage = (phase_p * weights).sum(dim=1) / (weights.sum(dim=1) + eps)
    else:
        usage = phase_p.mean(dim=1)  # (B,K)

    eff_k = 1.0 / (usage.pow(2).sum(dim=1) + eps)
    return eff_k.mean(), eff_k.detach()


def time_warp_batch(x, mask, length, warp_min=0.7, warp_max=1.3, min_len=8):
    """
    Resample every sample of a padded batch to a randomly scaled length.

    x: (B,C,Tmax) padded
    mask: (B,Tmax); only its dtype is used here, validity comes from `length`
    length: (B,) float or int, valid length per sample
    warp_min / warp_max: range of the uniform length factor w
    min_len: lower bound on the warped length

    For each sample the valid prefix is linearly interpolated to
    max(round(T * w), min_len) steps, so w < 1 shortens and w > 1 lengthens.

    return:
      xw: (B,C,Tw_max) padded time-warped batch
      mw: (B,Tw_max) mask of the warped valid steps
      lw: (B,) warped valid lengths, in the dtype of `length`
    """
    B, C, _ = x.shape
    device, dtype = x.device, x.dtype

    # One independent warp factor per sample (tempo change).
    warped = []
    for b in range(B):
        Tb = max(int(length[b].item()), 1)
        w = float(torch.empty(1, device=device).uniform_(warp_min, warp_max).item())
        Tw = max(int(round(Tb * w)), min_len)

        segment = F.interpolate(
            x[b:b+1, :, :Tb], size=Tw, mode="linear", align_corners=False
        )  # (1,C,Tw)
        warped.append(segment.squeeze(0))  # (C,Tw)

    warped_lens = [seg.shape[1] for seg in warped]
    Tw_max = max(warped_lens, default=0)

    # Re-pad to the longest warped sample.
    xw = torch.zeros((B, C, Tw_max), device=device, dtype=dtype)
    mw = torch.zeros((B, Tw_max), device=device, dtype=mask.dtype)
    for b, (seg, Tw) in enumerate(zip(warped, warped_lens)):
        xw[b, :, :Tw] = seg
        mw[b, :Tw] = 1

    lw = torch.tensor([float(Tw) for Tw in warped_lens], device=device, dtype=length.dtype)
    return xw, mw, lw


# ---------------------------------------------------------------------
# 5) Train  (ONLY MODIFIED PART: remove shift, add time-warp consistency)
# ---------------------------------------------------------------------
def train_one_epoch(model, loader, optimizer, config, device):
    """
    Run one optimization pass over `loader` and return the averaged statistics.

    model: called as model(x, mask, tau=tau) and expected to return
           (rate_hat, z, x_hat, aux) where aux provides "rep_rate_t" and "phase_p"
    loader: iterable of batches with keys "data", "mask", "count", "length"
    optimizer: optimizer stepping the model parameters
    config: dict; "fs" is required, the loss weights ("lambda_*"), "tau" and the
            warp settings ("warp_min", "warp_max", "warp_detach_ref") are optional
    device: device the batch tensors are moved to

    Total loss per batch:
        rate MSE
        + lambda_recon     * masked reconstruction MSE
        + lambda_smooth    * L1 smoothness of rep_rate_t
        + lambda_phase_ent * phase entropy
        + lambda_effk      * effective-K usage
        + lambda_warp      * time-warp count consistency (L1)

    returns: dict of per-batch averages for 'loss', 'loss_rate', 'loss_recon',
             'loss_smooth', 'loss_phase_ent', 'loss_effk', 'loss_warp' and
             'mae_count'. An empty loader yields all zeros instead of raising
             ZeroDivisionError.
    """
    model.train()
    stats = {k: 0.0 for k in [
        'loss', 'loss_rate', 'loss_recon', 'loss_smooth', 'loss_phase_ent', 'loss_effk', 'loss_warp',
        'mae_count'
    ]}

    fs = config["fs"]
    tau = config.get("tau", 1.0)

    lam_recon = config.get("lambda_recon", 1.0)
    lam_smooth = config.get("lambda_smooth", 0.05)
    lam_phase_ent = config.get("lambda_phase_ent", 0.01)
    lam_effk = config.get("lambda_effk", 0.005)

    # Time-warp consistency weight and factor range
    lam_warp = config.get("lambda_warp", 0.0)
    warp_min = config.get("warp_min", 0.7)
    warp_max = config.get("warp_max", 1.3)
    warp_detach_ref = bool(config.get("warp_detach_ref", True))

    for batch in loader:
        x = batch["data"].to(device)         # (B,C,T)
        mask = batch["mask"].to(device)      # (B,T)
        y_count = batch["count"].to(device)  # (B,)
        length = batch["length"].to(device)  # (B,)

        # The model predicts a rate, so the count target is converted to reps/s.
        duration = torch.clamp(length / fs, min=1e-6)  # sec
        y_rate = y_count / duration                    # reps/s

        optimizer.zero_grad()

        rate_hat, z, x_hat, aux = model(x, mask, tau=tau)

        # (1) match the average rate (MSE)
        loss_rate = F.mse_loss(rate_hat, y_rate)

        # (2) reconstruction
        loss_recon = masked_recon_mse(x_hat, x, mask)

        # (3) smoothness of the per-timestep rep rate
        loss_smooth = temporal_smoothness(aux["rep_rate_t"], mask)

        # (4) phase exclusivity (entropy)
        loss_phase_ent = phase_entropy_loss(aux["phase_p"], mask)

        # (5) effective-K usage (overall)
        loss_effk, _ = effK_usage_loss(aux["phase_p"], mask)

        # (6) Time-warp consistency (count level):
        #     pred_count = rate_hat * duration must stay the same after warping.
        if lam_warp > 0.0:
            xw, mw, lw = time_warp_batch(
                x=x, mask=mask, length=length,
                warp_min=warp_min, warp_max=warp_max, min_len=8
            )
            dur_w = torch.clamp(lw / fs, min=1e-6)

            rate_hat_w, _, _, _ = model(xw, mw, tau=tau)

            pred_count = rate_hat * duration
            pred_count_w = rate_hat_w * dur_w

            # With a detached reference only the warped branch receives gradient.
            if warp_detach_ref:
                pred_count_ref = pred_count.detach()
            else:
                pred_count_ref = pred_count

            loss_warp = F.l1_loss(pred_count_w, pred_count_ref)
        else:
            loss_warp = torch.zeros((), device=device)

        loss = (loss_rate
                + lam_recon * loss_recon
                + lam_smooth * loss_smooth
                + lam_phase_ent * loss_phase_ent
                + lam_effk * loss_effk
                + lam_warp * loss_warp)

        loss.backward()
        optimizer.step()

        # MAE on count: predicted rate * duration = predicted count
        count_hat = rate_hat * duration
        stats['loss'] += loss.item()
        stats['loss_rate'] += loss_rate.item()
        stats['loss_recon'] += loss_recon.item()
        stats['loss_smooth'] += loss_smooth.item()
        stats['loss_phase_ent'] += loss_phase_ent.item()
        stats['loss_effk'] += loss_effk.item()
        stats['loss_warp'] += float(loss_warp.item())
        stats['mae_count'] += torch.abs(count_hat - y_count).mean().item()

    # Guard the average against an empty loader (all sums are still 0.0 then).
    n = max(len(loader), 1)
    return {k: v / n for k, v in stats.items()}


# ---------------------------------------------------------------------
# 6) Visualization helpers (subject-wise subplot)
# ---------------------------------------------------------------------
def _smooth_1d(y, sigma=2.0):
    """
    지저분한 노이즈를 다듬어서 시각화
    """
    y = np.asarray(y, dtype=np.float32)
    return gaussian_filter1d(y, sigma=sigma)


def compute_phase_entropy_mean(phase_p_np, eps=1e-8):
    """
    Time-averaged entropy of a phase distribution sequence.

    phase_p_np: (T,K) array-like of per-timestep phase probabilities
    eps: added inside the log to avoid log(0)
    return: mean over T of -sum_k p * log(p + eps), as a Python float
    """
    probs = np.asarray(phase_p_np, dtype=np.float32)
    per_step = -np.sum(probs * np.log(probs + eps), axis=1)  # (T,)
    return float(np.mean(per_step))


def downsample_time_axis(arr, max_T=2000):
    """
    Thin out the first (time) axis of an array for plotting.

    arr: (T, ...) numpy array
    max_T: maximum number of time steps to keep
    return: (arr_ds, idx) where idx holds the original time indices of the
            kept rows. When T <= max_T the array is returned unchanged with
            idx = 0..T-1; otherwise max_T evenly spaced rows are picked.
    """
    T = arr.shape[0]
    if T > max_T:
        picks = np.linspace(0, T - 1, max_T).astype(int)
        return arr[picks], picks
    return arr, np.arange(T)


def plot_phase_heatmap_and_dominant(
    phase_p_np,
    fs,
    title="phase_p heatmap + dominant phase",
    max_T=2000
):
    """
    Show a phase-probability heatmap with the dominant phase timeline below it.

    phase_p_np: (T, K) array of phase probabilities
    fs: sampling rate, used to label the time axis in seconds
    title: title of the heatmap panel
    max_T: the time axis is downsampled to at most this many steps before plotting
    """
    phase_p_np = np.asarray(phase_p_np, dtype=np.float32)
    assert phase_p_np.ndim == 2, f"phase_p_np must be (T,K), got {phase_p_np.shape}"

    # Downsample for display and convert the kept indices to seconds.
    phase_ds, idx = downsample_time_axis(phase_p_np, max_T=max_T)  # (T',K)
    _, K = phase_ds.shape
    t_sec = idx / float(fs)
    t_lo, t_hi = t_sec[0], t_sec[-1]

    # Index of the most probable phase at every kept timestep.
    dom = np.argmax(phase_ds, axis=1)  # (T',)

    # Image settings shared by both panels.
    image_kw = dict(aspect="auto", origin="lower", interpolation="nearest")

    fig = plt.figure(figsize=(30, 10))
    grid = fig.add_gridspec(2, 1, height_ratios=[4, 1], hspace=0.25)

    # --- Top panel: heatmap, x = time, y = phase index ---
    ax_heat = fig.add_subplot(grid[0, 0])
    heat = ax_heat.imshow(
        phase_ds.T,                 # (K, T')
        extent=[t_lo, t_hi, 0, K],
        **image_kw
    )
    ax_heat.set_title(title, fontsize=24, pad=10)
    ax_heat.set_ylabel("Phase k", fontsize=18)
    ax_heat.set_xlabel("Time (sec)", fontsize=18)
    cbar = fig.colorbar(heat, ax=ax_heat, fraction=0.015, pad=0.01)
    cbar.set_label("phase_p(t,k)", fontsize=14)

    # --- Bottom panel: dominant phase as a single-row image ---
    ax_dom = fig.add_subplot(grid[1, 0], sharex=ax_heat)
    ax_dom.imshow(
        dom[None, :],               # (1, T')
        extent=[t_lo, t_hi, 0, 1],
        **image_kw
    )
    ax_dom.set_yticks([])
    ax_dom.set_ylabel("dominant", fontsize=14)
    ax_dom.set_xlabel("Time (sec)", fontsize=18)

    plt.tight_layout()
    plt.show()


def plot_folds_test_subplot(viz_cache, fs, title="Fold-wise TEST visualization (only test_subj)"):
    """
    Draw one panel per cached test item: smoothed rep rate plus cumulative count.

    viz_cache: list of dict, each providing
        'fold', 'test_subj', 't', 'rep_rate', 'gt', 'pred', 'diff', 'k_hat', 'entropy'
    fs: sampling rate used to integrate rep_rate into a cumulative count
    title: figure-level title
    """
    n_panels = 0 if viz_cache is None else len(viz_cache)
    if n_panels == 0:
        print("[plot_folds_test_subplot] viz_cache is empty")
        return

    sns.set_theme(style="whitegrid", context="notebook", font_scale=2.0)
    palette = sns.color_palette("muted")
    rate_color, count_color = palette[0], palette[1]

    fig, axes = plt.subplots(n_panels, 1, figsize=(36, 9 * n_panels), sharex=False)
    # plt.subplots returns a bare Axes for a single panel; normalise to a flat array.
    axes = np.array([axes] if n_panels == 1 else axes).flatten()

    fig.suptitle(title, fontsize=40, y=0.995)

    for rate_ax, item in zip(axes, viz_cache):
        t = item["t"]
        rep_rate = item["rep_rate"]
        gt_count = item["gt"]
        pred_count = item["pred"]
        diff = item["diff"]
        k_hat = item["k_hat"]
        entropy = item["entropy"]
        test_subj = item["test_subj"]
        fold = item["fold"]

        smoothed = _smooth_1d(rep_rate, sigma=2.0)
        running_count = np.cumsum(rep_rate) / fs

        # Left axis: smoothed rep rate.
        rate_ax.plot(t, smoothed, color=rate_color, linewidth=2.5, alpha=0.9)
        rate_ax.fill_between(t, smoothed, color=rate_color, alpha=0.15)
        rate_ax.set_ylabel("Rep Rate (reps/s)", color=rate_color, fontweight='bold', fontsize=24)
        rate_ax.grid(True, linestyle='--', alpha=0.5)
        rate_ax.tick_params(axis='both', which='major', labelsize=20)

        # Right axis: cumulative count with the ground-truth count as a reference line.
        count_ax = rate_ax.twinx()
        count_ax.plot(t, running_count, color=count_color, linewidth=3.5, alpha=1.0)
        count_ax.axhline(gt_count, linestyle=":", alpha=0.7)
        count_ax.set_ylabel("Count", color=count_color, fontweight='bold', fontsize=24)
        count_ax.tick_params(axis='y', labelcolor=count_color, labelsize=20)
        count_ax.grid(False)

        rate_ax.set_title(
            f"Fold {fold:2d} | Test: {test_subj} | Pred {pred_count:.2f} / GT {gt_count:.0f} (Diff {diff:+.2f})\n"
            f"k_hat={k_hat:.2f} | phase_entropy={entropy:.3f}",
            fontsize=34, pad=10
        )
        rate_ax.set_xlabel("Time (sec)", fontweight='bold', fontsize=24)

    plt.tight_layout(rect=[0, 0, 1, 0.985])
    plt.subplots_adjust(hspace=0.5)
    plt.show()


# ---------------------------------------------------------------------
# 7) Main (A1: Unseen Activity only, WITH windowing)
# ---------------------------------------------------------------------
def build_label_tuples_from_table(subjects, act_id, count_table):
    """
    Collect (subject, activity id, count) labels for one activity.

    count_table: dict like { "subject1": {6:21, 7:20}, ... }
    Subjects missing from the table, or lacking an entry for act_id, are skipped.
    returns: list of (subj, act_id, gt_count) with gt_count converted to float
    """
    return [
        (subj, act_id, float(count_table[subj][act_id]))
        for subj in subjects
        if subj in count_table and act_id in count_table[subj]
    ]


def main():
    """
    Run the A1 experiment: train on one activity, evaluate on another.

    Builds the configuration, loads the mHealth logs, trains on every subject's
    TRAIN_ACT_ID trial (optionally split into windows), then evaluates each
    subject's TEST_ACT_ID trial with windowed inference and prints per-trial
    errors plus the mean/std of the absolute count error.
    """
    CONFIG = {
        "seed": 42,
        "data_dir": "/content/drive/MyDrive/Colab Notebooks/HAR_data/MHEALTHDATASET",

        "COLUMN_NAMES": [
            'acc_chest_x', 'acc_chest_y', 'acc_chest_z',
            'ecg_1', 'ecg_2',
            'acc_ankle_x', 'acc_ankle_y', 'acc_ankle_z',
            'gyro_ankle_x', 'gyro_ankle_y', 'gyro_ankle_z',
            'mag_ankle_x', 'mag_ankle_y', 'mag_ankle_z',
            'acc_arm_x', 'acc_arm_y', 'acc_arm_z',
            'gyro_arm_x', 'gyro_arm_y', 'gyro_arm_z',
            'mag_arm_x', 'mag_arm_y', 'mag_arm_z',
            'activity_id'
        ],
        "TARGET_ACTIVITIES_MAP": {
            6: 'Waist bends forward',
            8: 'Knees bending',
        },

        # Train and test must use the same number of channels (C) so that the
        # same model can be evaluated on both activities.
        "ACT_FEATURE_MAP": {
            6: ['acc_chest_x', 'acc_chest_y', 'acc_chest_z',
                'acc_arm_x', 'acc_arm_y', 'acc_arm_z',
                'acc_ankle_x', 'acc_ankle_y', 'acc_ankle_z'],
            8: ['acc_chest_x', 'acc_chest_y', 'acc_chest_z',
                'acc_arm_x', 'acc_arm_y', 'acc_arm_z',
                'acc_ankle_x', 'acc_ankle_y', 'acc_ankle_z'],
        },

        # Training Params
        "epochs": 100,
        "lr": 5e-4,
        "batch_size": 64,
        "fs": 50,

        # Model
        "hidden_dim": 128,
        "latent_dim": 16,
        "K_max": 6,

        # Loss Weights
        "lambda_recon": 1.0,
        "lambda_smooth": 0.05,
        "lambda_phase_ent": 0.01,
        "lambda_effk": 0.0005,

        # Added: time-warp consistency (the only new component)
        "lambda_warp": 0.2,      # author's note: 1e-2 to 5e-2 is the suggested starting range
        "warp_min": 0.5,          # lower bound of the random time-warp factor
        "warp_max": 1.5,          # upper bound of the random time-warp factor
        "warp_detach_ref": True,  # detach the un-warped count so only the warped branch is pulled toward it (stability)

        # Softmax temperature (strength of the competition between phases)
        "tau": 1.0,

        # A1 setting
        "TRAIN_ACT_ID": 6,
        "TEST_ACT_ID": 8,

        # -------------------------
        # Windowing (added)
        # -------------------------
        "USE_WINDOWING": True,
        "WIN_SEC": 8.0,
        "STRIDE_SEC": 4.0,
        "DROP_LAST": True,

        # Ground-truth rep counts per subject, keyed by activity id (must be a dict of dicts)
        "COUNT_TABLE": {
            "subject1":  {6: 21, 8: 20},
            "subject2":  {6: 19, 8: 21},
            "subject3":  {6: 21, 8: 21},
            "subject4":  {6: 20, 8: 19},
            "subject5":  {6: 20, 8: 20},
            "subject6":  {6: 20, 8: 20},
            "subject7":  {6: 20, 8: 21},
            "subject8":  {6: 21, 8: 21},
            "subject9":  {6: 21, 8: 21},
            "subject10": {6: 20, 8: 21},
        },
    }

    set_strict_seed(CONFIG["seed"])
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")

    full_data = load_mhealth_dataset(
        CONFIG["data_dir"],
        CONFIG["TARGET_ACTIVITIES_MAP"],
        CONFIG["COLUMN_NAMES"]
    )
    # Nothing was loaded (load_mhealth_dataset already printed a warning).
    if not full_data:
        return

    subjects = [f"subject{i}" for i in range(1, 11)]

    # -------------------------
    # A1 split (NO LOSO)
    # Train: all subjects x TRAIN_ACT_ID
    # Test : all subjects x TEST_ACT_ID (unseen activity)
    # -------------------------
    train_labels = build_label_tuples_from_table(subjects, CONFIG["TRAIN_ACT_ID"], CONFIG["COUNT_TABLE"])
    test_labels  = build_label_tuples_from_table(subjects, CONFIG["TEST_ACT_ID"],  CONFIG["COUNT_TABLE"])

    if len(train_labels) == 0:
        print("[Error] No train labels. Check COUNT_TABLE / TRAIN_ACT_ID.")
        return
    if len(test_labels) == 0:
        print("[Error] No test labels. Check COUNT_TABLE / TEST_ACT_ID.")
        return

    train_data = prepare_trial_list(train_labels, full_data, CONFIG["TARGET_ACTIVITIES_MAP"], CONFIG["ACT_FEATURE_MAP"])
    test_data  = prepare_trial_list(test_labels,  full_data, CONFIG["TARGET_ACTIVITIES_MAP"], CONFIG["ACT_FEATURE_MAP"])

    if len(train_data) == 0:
        print("[Error] train_data empty. (missing files or ACT_FEATURE_MAP mismatch)")
        return
    if len(test_data) == 0:
        print("[Error] test_data empty. (missing files or ACT_FEATURE_MAP mismatch)")
        return

    # -------------------------
    # Apply windowing (training set only)
    # -------------------------
    if CONFIG.get("USE_WINDOWING", False):
        train_windows = trial_list_to_windows(
            train_data,
            fs=CONFIG["fs"],
            win_sec=CONFIG["WIN_SEC"],
            stride_sec=CONFIG["STRIDE_SEC"],
            drop_last=CONFIG["DROP_LAST"],
        )
        print(f"[Windowing] train trials={len(train_data)} -> train windows={len(train_windows)}")
        train_data_for_loader = train_windows
    else:
        train_data_for_loader = train_data

    train_loader = DataLoader(
        TrialDataset(train_data_for_loader),
        batch_size=CONFIG["batch_size"],
        shuffle=True,
        collate_fn=collate_variable_length,
        num_workers=0
    )

    # Channel count is read from the first training trial's (T, C) array.
    input_ch = train_data[0]['data'].shape[1]
    model = KAutoCountModel(
        input_ch=input_ch,
        hidden_dim=CONFIG["hidden_dim"],
        latent_dim=CONFIG["latent_dim"],
        K_max=CONFIG["K_max"]
    ).to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=CONFIG["lr"])
    # Halve the learning rate every 30 epochs.
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.5)

    print("\n" + "-"*80)
    print(f" >>> A1 Train: all subjects x act{CONFIG['TRAIN_ACT_ID']}  |  Test: all subjects x act{CONFIG['TEST_ACT_ID']} (unseen activity)")
    print("-"*80)

    # ---- Train ----
    for epoch in range(CONFIG["epochs"]):
        # Per-epoch statistics are computed but not logged here.
        _ = train_one_epoch(model, train_loader, optimizer, CONFIG, device)
        scheduler.step()

    # ---- Test (per-subject) ----
    model.eval()
    maes = []
    viz_cache = []

    for idx, item in enumerate(test_data):
        x_np = item["data"]  # (T,C)
        gt_count = float(item["count"])

        # Predicted count comes from windowed inference (test)
        pred_count, win_rates = predict_count_by_windowing(
            model,
            x_np=x_np,
            fs=CONFIG["fs"],
            win_sec=CONFIG["WIN_SEC"],
            stride_sec=CONFIG["STRIDE_SEC"],
            device=device,
            tau=CONFIG["tau"],
            batch_size=CONFIG["batch_size"]
        )

        # Absolute count error for this single trial.
        mae = abs(pred_count - gt_count)
        maes.append(mae)

        # For visualization only: one forward pass over the full-length trial
        with torch.no_grad():
            x_tensor = torch.tensor(x_np, dtype=torch.float32).transpose(0, 1).unsqueeze(0).to(device)
            rate_hat_full, _, _, aux = model(x_tensor, mask=None, tau=CONFIG["tau"])

            phase_p = aux["phase_p"].squeeze(0).detach().cpu().numpy()
            k_hat = float(aux["k_hat"].item())
            ent = compute_phase_entropy_mean(phase_p)
            rep_rate = aux["rep_rate_t"].squeeze(0).detach().cpu().numpy()
            T = rep_rate.shape[0]
            t = np.arange(T) / float(CONFIG["fs"])

        viz_cache.append({
            "fold": idx + 1,   # just a running index
            "test_subj": item["meta"],  # subject+act name
            "t": t,
            "rep_rate": rep_rate,
            "gt": gt_count,
            "pred": float(pred_count),   # windowed prediction
            "diff": float(pred_count - gt_count),
            "k_hat": k_hat,
            "entropy": ent,
            "phase_p": phase_p,
        })

        print(
            f"[Test {idx+1:02d}] {item['meta']} | Pred(win)={pred_count:.2f} / GT={gt_count:.2f} | "
            f"MAE={mae:.2f} | k_hat={k_hat:.2f} | ent={ent:.3f} | win_rate_mean={win_rates.mean():.3f}"
        )

    print("-"*80)
    print(f" >>> A1 Final MAE mean: {np.mean(maes):.3f}")
    print(f" >>> A1 Final MAE std : {np.std(maes):.3f}")
    print("-"*80)

    # Visualization (optional; uncomment to enable)
    # plot_folds_test_subplot(
    #     viz_cache,
    #     fs=CONFIG["fs"],
    #     title=f"A1 TEST visualization | Train act{CONFIG['TRAIN_ACT_ID']} -> Test act{CONFIG['TEST_ACT_ID']}"
    # )

    # for item in viz_cache:
    #     plot_phase_heatmap_and_dominant(
    #         item["phase_p"],
    #         fs=CONFIG["fs"],
    #         title=f"{item['test_subj']} | k_hat={item['k_hat']:.2f} | ent={item['entropy']:.3f}",
    #         max_T=2000
    #     )


# Run the experiment only when this file is executed directly, not when imported.
if __name__ == "__main__":
    main()


Device: cuda
Loading 10 subjects from /content/drive/MyDrive/Colab Notebooks/HAR_data/MHEALTHDATASET...
[Windowing] train trials=10 -> train windows=127

--------------------------------------------------------------------------------
 >>> A1 Train: all subjects x act6  |  Test: all subjects x act8 (unseen activity)
--------------------------------------------------------------------------------
[Test 01] subject1_Knees bending | Pred(win)=41.27 / GT=20.00 | MAE=21.27 | k_hat=1.91 | ent=0.762 | win_rate_mean=0.611
[Test 02] subject2_Knees bending | Pred(win)=39.64 / GT=21.00 | MAE=18.64 | k_hat=2.12 | ent=0.847 | win_rate_mean=0.578
[Test 03] subject3_Knees bending | Pred(win)=23.01 / GT=21.00 | MAE=2.01 | k_hat=3.00 | ent=1.039 | win_rate_mean=0.362
[Test 04] subject4_Knees bending | Pred(win)=17.42 / GT=19.00 | MAE=1.58 | k_hat=3.15 | ent=1.057 | win_rate_mean=0.279
[Test 05] subject5_Knees bending | Pred(win)=20.05 / GT=20.00 | MAE=0.05 | k_hat=2.94 | ent=0.980 | win_rate_mean=0.369

In [4]:
# =========================
# Count-only K-auto (Multi-event) version  (NO manual Pair/lag/overlap/balance)
#
# 핵심 아이디어
# - Micro-event Rate 예측: 모델은 하나의 통합된 속도가 아니라, K_max개의 서로 다른 '작은 단위 동작'의 속도 흐름(r_k(t))을 예측
# - 샘플마다 "rep당 micro-event 개수" k_hat(>=1)을 스스로 추정
# - 우리가 주는 감독은 오직 rep count(=20) -> rep rate만 맞추게
#
# 1) rate head를 amp(t) * softmax(phase) 형태로 바꿔서 K-stream 간 "경쟁"이 생기게 함
# 2) k_hat을 따로 head로 예측하지 않고, phase 사용 분포로부터 effK(=effective K)로 정의 (자동 K)
# 3) phase sparsity/exclusivity를 위해 (a) phase entropy loss, (b) effK usage loss 추가
#
# ✅ Added (ONLY): Time-warp consistency loss
#   - 입력을 시간축으로 stretch/compress (time-warp)해도 "count"는 동일해야 한다는 제약
#   - 모델 출력은 rate이므로, pred_count = rate_hat * duration 을 이용해 count 일관성을 강제
# =========================

import os
import glob
import random
import numpy as np
import pandas as pd
import seaborn as sns

import matplotlib.pyplot as plt
from scipy.ndimage import gaussian_filter1d

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader


# ---------------------------------------------------------------------
# 1) Strict Seeding
# ---------------------------------------------------------------------
def set_strict_seed(seed: int):
    """
    Seed every random source used by the script and request deterministic cuDNN.

    Seeds Python's `random`, NumPy and PyTorch (CPU, plus CUDA when available),
    exports PYTHONHASHSEED, and turns off cuDNN benchmarking.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)

    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


# ---------------------------------------------------------------------
# 2) Data Loading
# ---------------------------------------------------------------------
def load_mhealth_dataset(data_dir, target_activities_map, column_names):
    """
    Load every `mHealth_subject*.log` file in a directory, split by activity.

    data_dir: directory that is searched for the log files
    target_activities_map: dict {activity_id: activity_name}; only these
        activity ids are extracted
    column_names: names assigned to the leading columns of each log; must
        contain 'activity_id'
    returns: dict {subject_key: {activity_name: DataFrame without 'activity_id'}}.
        An empty dict is returned when no log file is found. A file that fails
        to load is reported on stdout and skipped (best effort).
    """
    full_dataset = {}
    file_list = sorted(glob.glob(os.path.join(data_dir, "mHealth_subject*.log")))

    if not file_list:
        print(f"[Warning] No mHealth logs found in {data_dir}")
        return {}

    print(f"Loading {len(file_list)} subjects from {data_dir}...")

    for file_path in file_list:
        file_name = os.path.basename(file_path)
        subj_part = file_name.split('.')[0]

        # Normalise the key to "subject<N>" using the digits in the file stem.
        # int('') raises ValueError when the stem has no digits; in that case
        # fall back to the raw stem. Only that error is expected here, so
        # nothing else (e.g. KeyboardInterrupt) is swallowed.
        try:
            subj_id_num = int(''.join(filter(str.isdigit, subj_part)))
            subj_key = f"subject{subj_id_num}"
        except ValueError:
            subj_key = subj_part

        try:
            df = pd.read_csv(file_path, sep="\t", header=None)
            # Keep only as many leading columns as there are names.
            df = df.iloc[:, :len(column_names)]
            df.columns = column_names

            subj_data = {}
            for label_code, activity_name in target_activities_map.items():
                activity_df = df[df['activity_id'] == label_code].copy()
                if not activity_df.empty:
                    subj_data[activity_name] = activity_df.drop(columns=['activity_id'])

            full_dataset[subj_key] = subj_data
        except Exception as e:
            # Deliberate best effort: report the broken file and continue.
            print(f"Error loading {file_name}: {e}")

    return full_dataset


def prepare_trial_list(label_config, full_data, target_map, feature_map):
    """
    Turn (subject, activity id, count) labels into normalised trial records.

    label_config: iterable of (subj, act_id, gt_count)
    full_data: {subject: {activity_name: DataFrame}}
    target_map: {act_id: activity_name}
    feature_map: {act_id: list of column names to keep}
    returns: list of dicts with 'data' (float32, z-scored per column),
             'count' (float) and 'meta' ("<subj>_<activity_name>").
             Labels without matching data are reported and skipped.
    """
    trials = []
    for subj, act_id, gt_count in label_config:
        act_name = target_map.get(act_id)
        feats = feature_map.get(act_id)

        if subj not in full_data or act_name not in full_data[subj]:
            print(f"[Skip] Missing data for {subj} - {act_name}")
            continue

        raw = full_data[subj][act_name][feats].values.astype(np.float32)

        # Per-column standardisation; the small constant keeps the division finite
        # for constant columns.
        col_mean = raw.mean(axis=0)
        col_std = raw.std(axis=0) + 1e-6
        standardised = (raw - col_mean) / col_std

        trials.append({
            'data': standardised,
            'count': float(gt_count),
            'meta': f"{subj}_{act_name}"
        })

    return trials


# ---------------------------------------------------------------------
# 2.5) Windowing (added)
# ---------------------------------------------------------------------
def trial_list_to_windows(trial_list, fs, win_sec=4.0, stride_sec=2.0, drop_last=True):
    """
    Slice trials into fixed-length windows with proportionally scaled counts.

    trial_list: output of prepare_trial_list() (items with 'data' (T,C), 'count', 'meta')
    Window labels come from the trial-level average rate:
      rate_trial   = count_total / total_duration
      count_window = rate_trial * window_duration
    A trial shorter than one window becomes a single variable-length window.
    With drop_last=False, the samples starting one stride after the last full
    window are emitted as an extra shorter window when any remain.
    """
    win_len = int(round(win_sec * fs))
    stride = int(round(stride_sec * fs))
    assert win_len > 0 and stride > 0

    out = []
    for trial in trial_list:
        signal = trial["data"]           # (T, C)
        n_samples = signal.shape[0]
        parent = trial["meta"]

        # Average reps/s over the whole trial.
        rate_trial = float(trial["count"]) / max(n_samples / float(fs), 1e-6)

        def window_record(st, ed):
            # One window covering samples [st, ed) of the current trial.
            return {
                "data": signal[st:ed],
                "count": rate_trial * ((ed - st) / float(fs)),
                "meta": f"{parent}__win[{st}:{ed}]",
                "parent_meta": parent,
                "parent_T": n_samples,
                "win_start": st,
                "win_end": ed,
            }

        if n_samples < win_len:
            # Too short for a full window: keep the trial as-is, no padding.
            out.append(window_record(0, n_samples))
            continue

        starts = list(range(0, n_samples - win_len + 1, stride))
        out.extend(window_record(st, st + win_len) for st in starts)

        if drop_last:
            continue

        tail_start = starts[-1] + stride
        if tail_start < n_samples:
            out.append(window_record(tail_start, n_samples))

    return out


def predict_count_by_windowing(model, x_np, fs, win_sec, stride_sec, device, tau=1.0, batch_size=64):
    """
    Estimate a trial's total count from the mean rate predicted over sliding windows.

    model: callable returning (rate, ..., ..., ...) for an input of shape (B, C, T);
           must expose eval()
    x_np: (T, C) normalized array
    fs: sampling rate in Hz
    win_sec / stride_sec: window length and hop in seconds
    device: torch device the inputs are moved to
    tau: temperature forwarded to the model
    batch_size: number of windows per forward pass
    return: pred_count (float), window_rates (np.ndarray of per-window rates)

    pred_count = mean(window rates) * trial duration. A trial no longer than
    one window is passed to the model in a single forward call.
    """
    win_len = int(round(win_sec * fs))
    stride = int(round(stride_sec * fs))
    T = x_np.shape[0]
    total_dur = T / float(fs)

    # Switch to eval mode up front so both the short-trial path and the
    # windowed path run inference under the same module mode.
    model.eval()

    if T <= win_len:
        x_tensor = torch.tensor(x_np, dtype=torch.float32).transpose(0, 1).unsqueeze(0).to(device)  # (1,C,T)
        with torch.no_grad():
            rate_hat, _, _, _ = model(x_tensor, mask=None, tau=tau)
        rate = float(rate_hat.item())
        return float(rate * total_dur), np.array([rate], dtype=np.float32)

    # Only full windows are used; trailing samples that do not fill a window
    # still contribute through total_dur below.
    starts = list(range(0, T - win_len + 1, stride))
    windows = np.stack([x_np[st:st + win_len] for st in starts], axis=0)  # (N, win_len, C)

    xw = torch.tensor(windows, dtype=torch.float32).permute(0, 2, 1).to(device)  # (N, C, win_len)

    rates = []
    with torch.no_grad():
        for i in range(0, xw.shape[0], batch_size):
            xb = xw[i:i + batch_size]
            r_hat, _, _, _ = model(xb, mask=None, tau=tau)  # (B,)
            rates.append(r_hat.detach().cpu().numpy())

    rates = np.concatenate(rates, axis=0)  # (N,)
    rate_mean = float(rates.mean())
    pred_count = rate_mean * total_dur
    return float(pred_count), rates


class TrialDataset(Dataset):
    """Dataset over trial/window dicts that yields (signal, count, meta) triples."""

    def __init__(self, trial_list):
        # Sequence of dicts providing 'data', 'count' and 'meta'.
        self.trials = trial_list

    def __len__(self):
        return len(self.trials)

    def __getitem__(self, idx):
        record = self.trials[idx]
        # Stored as (T, C); the model expects channels first, so swap to (C, T).
        signal = torch.tensor(record['data'], dtype=torch.float32).transpose(0, 1)
        target = torch.tensor(record['count'], dtype=torch.float32)
        return signal, target, record['meta']


def collate_variable_length(batch):
    """
    Zero-pad variable-length (C, T) samples to a common length and stack them.

    batch: list of (data (C,T), count, meta) triples
    returns a dict with
      'data'   (B, C, T_max)  zero-padded signals
      'mask'   (B, T_max)     1 for real samples, 0 for padding
      'count'  (B,)           stacked counts
      'length' (B,) float32   un-padded lengths
      'meta'   list of the meta entries
    """
    longest = max([sample[0].shape[1] for sample in batch])
    n_ch = batch[0][0].shape[0]

    signals, valid_masks, lengths = [], [], []
    for signal, _, _ in batch:
        n_valid = signal.shape[1]
        n_pad = longest - n_valid
        lengths.append(n_valid)

        valid = torch.ones(n_valid)
        if n_pad > 0:
            signal = torch.cat([signal, torch.zeros(n_ch, n_pad)], dim=1)
            valid = torch.cat([valid, torch.zeros(n_pad)], dim=0)

        signals.append(signal)
        valid_masks.append(valid)

    return {
        "data": torch.stack(signals),                           # (B, C, T_max)
        "mask": torch.stack(valid_masks),                       # (B, T_max)
        "count": torch.stack([count for _, count, _ in batch]),  # (B,)
        "length": torch.tensor(lengths, dtype=torch.float32),   # (B,)
        "meta": [meta for _, _, meta in batch]
    }


# ---------------------------------------------------------------------
# 3) Model
# ---------------------------------------------------------------------
class ManifoldEncoder(nn.Module):
    """Three-layer 1-D conv encoder mapping (B, C, T) signals to (B, T, D) latents."""

    def __init__(self, input_ch, hidden_dim=128, latent_dim=16):
        super().__init__()
        # Kernel 5 with padding 2 (and the final kernel 1) keeps the time length unchanged.
        layers = [
            nn.Conv1d(input_ch, hidden_dim, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv1d(hidden_dim, hidden_dim, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv1d(hidden_dim, latent_dim, kernel_size=1),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        # x: (B, C, T) -> conv output (B, D, T) -> time-major (B, T, D)
        return self.net(x).transpose(1, 2)


class ManifoldDecoder(nn.Module):
    """Three-layer 1-D conv decoder mapping (B, T, D) latents back to (B, C, T) signals."""

    def __init__(self, latent_dim, hidden_dim, out_ch):
        super().__init__()
        # Same layer layout as the encoder with the channel sizes reversed.
        layers = [
            nn.Conv1d(latent_dim, hidden_dim, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv1d(hidden_dim, hidden_dim, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv1d(hidden_dim, out_ch, kernel_size=1),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, z):
        # z: (B, T, D) -> channel-major (B, D, T) -> reconstruction (B, C, T)
        return self.net(z.transpose(1, 2))


class MultiRateHead(nn.Module):
    """Per-time-step MLP producing a non-negative amplitude and a K_max-way phase distribution."""

    def __init__(self, latent_dim=16, hidden=64, K_max=6):
        super().__init__()
        self.K_max = K_max
        # Output layout along the last axis: [amp_logit | phase_logits (K_max)].
        layers = [
            nn.Linear(latent_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 1 + K_max),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, z, tau=1.0):
        # z: (B,T,D) -> raw: (B,T,1+K)
        raw = self.net(z)
        amp_logit, phase_logits = raw[..., 0], raw[..., 1:]

        amp = F.softplus(amp_logit)                          # (B,T), >= 0
        phase = torch.softmax(phase_logits / tau, dim=-1)    # (B,T,K), sums to 1 over K
        return amp, phase, phase_logits


class KAutoCountModel(nn.Module):
    """
    Auto-encoder with a multi-phase rate head that outputs a repetition rate.

    - the rate head yields amp(t) and a K_max-way phase distribution; the
      per-phase rates are amp(t) * phase(t, k)
    - k_hat is not a separate head: it is the inverse of the summed squared
      time-averaged phase usage
    - rep_rate(t) = amp(t) / k_hat
    """
    def __init__(self, input_ch, hidden_dim=128, latent_dim=16, K_max=6, k_hidden=64):
        # k_hidden is accepted but not used by this constructor.
        super().__init__()
        self.encoder = ManifoldEncoder(input_ch, hidden_dim, latent_dim)
        self.decoder = ManifoldDecoder(latent_dim, hidden_dim, input_ch)
        self.rate_head = MultiRateHead(latent_dim, hidden=hidden_dim, K_max=K_max)

        self._init_weights()

    def _init_weights(self):
        # Kaiming-normal weights and zero biases for every conv / linear layer.
        for layer in self.modules():
            if not isinstance(layer, (nn.Conv1d, nn.Linear)):
                continue
            nn.init.kaiming_normal_(layer.weight, mode='fan_out', nonlinearity='relu')
            if layer.bias is not None:
                nn.init.constant_(layer.bias, 0)

        # Final head layer: all biases zero except the amplitude logit, set to -2.
        with torch.no_grad():
            head_bias = self.rate_head.net[-1].bias
            head_bias.zero_()
            head_bias[0].fill_(-2.0)

    @staticmethod
    def _masked_mean_time(x, mask=None, eps=1e-6):
        # Mean over the time axis (dim 1) of x: (B,T) or (B,T,K), honouring mask (B,T).
        if mask is None:
            return x.mean(dim=1)
        if x.dim() not in (2, 3):
            raise ValueError(f"Unsupported dim for masked mean: {x.dim()}")

        weights = mask.to(dtype=x.dtype, device=x.device)   # (B,T)
        if x.dim() == 3:
            weights = weights.unsqueeze(-1)                 # (B,T,1)
        return (x * weights).sum(dim=1) / (weights.sum(dim=1) + eps)

    def forward(self, x, mask=None, tau=1.0):
        """
        x: (B,C,T), mask: (B,T) or None
        return: (avg_rep_rate (B,), z (B,T,D), x_hat (B,C,T), aux dict)
        """
        z = self.encoder(x)              # (B,T,D)
        x_hat = self.decoder(z)          # (B,C,T)

        amp_t, phase_p, phase_logits = self.rate_head(z, tau=tau)
        rates_k_t = amp_t.unsqueeze(-1) * phase_p

        # phase_p sums to 1 over K, so the summed micro-event rate is amp itself.
        micro_rate_t = amp_t

        # Effective number of phases from the time-averaged phase usage.
        usage = self._masked_mean_time(phase_p, mask)            # (B,K)
        k_hat = 1.0 / (usage.pow(2).sum(dim=1) + 1e-6)           # (B,)

        rep_rate_t = micro_rate_t / (k_hat.unsqueeze(1) + 1e-6)  # (B,T)

        if mask is None:
            avg_rep_rate = rep_rate_t.mean(dim=1)
        else:
            # Zero the padded steps, then average over the valid ones only.
            rep_rate_t = rep_rate_t * mask
            avg_rep_rate = (rep_rate_t * mask).sum(dim=1) / (mask.sum(dim=1) + 1e-6)

        aux = {
            "rates_k_t": rates_k_t,          # (B,T,K)
            "phase_p": phase_p,              # (B,T,K)
            "phase_logits": phase_logits,    # (B,T,K)
            "micro_rate_t": micro_rate_t,    # (B,T)
            "rep_rate_t": rep_rate_t,        # (B,T)
            "k_hat": k_hat,                  # (B,)
        }
        return avg_rep_rate, z, x_hat, aux


# ---------------------------------------------------------------------
# 4) Loss utils
# ---------------------------------------------------------------------
def masked_recon_mse(x_hat, x, mask, eps=1e-6):
    """
    Mean squared reconstruction error over valid time steps only.

    x_hat, x: (B,C,T); mask: (B,T) with 1 for valid steps.
    The sum of masked squared errors is divided by (valid steps * C) + eps.
    """
    valid = mask.to(dtype=x.dtype, device=x.device)
    masked_sq_err = (x_hat - x).pow(2) * valid.unsqueeze(1)   # (B,C,T), mask broadcast over C
    n_valid_elems = (valid.sum() * x.shape[1]) + eps
    return masked_sq_err.sum() / n_valid_elems


def temporal_smoothness(v, mask=None, eps=1e-6):
    """
    L1 penalty on the first temporal difference of v: (B,T).

    Without a mask the absolute differences are averaged directly; with a
    mask (B,T) only pairs whose two steps are both valid are averaged.
    """
    step = (v[:, 1:] - v[:, :-1]).abs()   # (B,T-1)
    if mask is None:
        return step.mean()

    pair_valid = mask[:, 1:] * mask[:, :-1]
    pair_valid = pair_valid.to(dtype=step.dtype, device=step.device)
    return (step * pair_valid).sum() / (pair_valid.sum() + eps)


def phase_entropy_loss(phase_p, mask=None, eps=1e-8):
    """
    Mean per-step entropy of the phase distribution (lower = closer to one-hot).

    phase_p: (B,T,K); mask: optional (B,T). With a mask, entropies of masked
    steps are zeroed and the sum is divided by the number of valid steps + eps.
    """
    per_step = -torch.sum(phase_p * torch.log(phase_p + eps), dim=-1)  # (B,T)
    if mask is None:
        return per_step.mean()

    per_step = per_step * mask
    return per_step.sum() / (mask.sum() + eps)


def effK_usage_loss(phase_p, mask=None, eps=1e-6):
    """
    Effective number of phases used over time, as a sparsity penalty.

    usage = time-average of phase_p (masked when mask is given)
    effK  = 1 / (sum_k usage_k^2 + eps)
    returns: (batch mean of effK, detached per-sample effK)
    """
    if mask is not None:
        valid = mask.to(dtype=phase_p.dtype, device=phase_p.device).unsqueeze(-1)  # (B,T,1)
        usage = (phase_p * valid).sum(dim=1) / (valid.sum(dim=1) + eps)
    else:
        usage = phase_p.mean(dim=1)  # (B,K)

    eff_k = 1.0 / (usage.pow(2).sum(dim=1) + eps)
    return eff_k.mean(), eff_k.detach()


def time_warp_batch(x, mask, length, warp_min=0.7, warp_max=1.3, min_len=8):
    """
    Randomly stretch or compress each sample along time, then re-pad the batch.

    x: (B,C,Tmax) padded
    mask: (B,Tmax); only its dtype is used for the returned mask
    length: (B,) float or int, valid length of each sample
    warp_min / warp_max: range of the per-sample uniform warp factor; the
        warped length is round(valid_length * factor), floored at min_len
    return:
      xw: (B,C,Tw_max) zero-padded, linearly resampled signals
      mw: (B,Tw_max) validity mask of the warped signals
      lw: (B,) warped valid lengths, same dtype as `length`
    """
    B, C, Tmax = x.shape
    device, dtype = x.device, x.dtype

    # Resample each sample's valid prefix independently (per-sample tempo change).
    warped = []
    for b in range(B):
        n_valid = max(int(length[b].item()), 1)

        factor = float(torch.empty(1, device=device).uniform_(warp_min, warp_max).item())
        n_warped = max(int(round(n_valid * factor)), min_len)

        segment = x[b:b+1, :, :n_valid]  # (1,C,n_valid)
        resampled = F.interpolate(segment, size=n_warped, mode="linear", align_corners=False)
        warped.append(resampled.squeeze(0))  # (C,n_warped)

    warped_lens = [w.shape[1] for w in warped]
    Tw_max = max(warped_lens, default=0)

    # Re-pad to the longest warped sample.
    xw = torch.zeros((B, C, Tw_max), device=device, dtype=dtype)
    mw = torch.zeros((B, Tw_max), device=device, dtype=mask.dtype)
    for b, (sample, n_warped) in enumerate(zip(warped, warped_lens)):
        xw[b, :, :n_warped] = sample
        mw[b, :n_warped] = 1

    lw = torch.tensor([float(n) for n in warped_lens], device=device, dtype=length.dtype)
    return xw, mw, lw


# ---------------------------------------------------------------------
# 5) Train  (ONLY MODIFIED PART: remove shift, add time-warp consistency)
# ---------------------------------------------------------------------
def train_one_epoch(model, loader, optimizer, config, device):
    """
    Run one training pass over `loader` and return batch-averaged statistics.

    model: called as model(x, mask, tau=...) and expected to return
           (rate_hat, z, x_hat, aux) with aux['rep_rate_t'] and aux['phase_p']
    loader: iterable of batches (dicts with 'data', 'mask', 'count', 'length')
            that also supports len()
    optimizer: torch optimizer stepped once per batch
    config: dict; 'fs' is required, the loss weights ('lambda_*'), 'tau' and
            the warp settings ('warp_min', 'warp_max', 'warp_detach_ref') are optional
    device: device the batch tensors are moved to
    returns: dict of per-batch means for 'loss', 'loss_rate', 'loss_recon',
             'loss_smooth', 'loss_phase_ent', 'loss_effk', 'loss_warp', 'mae_count'.
             All values are 0.0 when the loader yields no batch.
    """
    model.train()
    stats = {k: 0.0 for k in [
        'loss', 'loss_rate', 'loss_recon', 'loss_smooth', 'loss_phase_ent', 'loss_effk', 'loss_warp',
        'mae_count'
    ]}

    fs = config["fs"]
    tau = config.get("tau", 1.0)

    lam_recon = config.get("lambda_recon", 1.0)
    lam_smooth = config.get("lambda_smooth", 0.05)
    lam_phase_ent = config.get("lambda_phase_ent", 0.01)
    lam_effk = config.get("lambda_effk", 0.005)

    # Time-warp consistency weight and warp-factor range
    lam_warp = config.get("lambda_warp", 0.0)
    warp_min = config.get("warp_min", 0.7)
    warp_max = config.get("warp_max", 1.3)
    warp_detach_ref = bool(config.get("warp_detach_ref", True))

    for batch in loader:
        x = batch["data"].to(device)         # (B,C,T)
        mask = batch["mask"].to(device)      # (B,T)
        y_count = batch["count"].to(device)  # (B,)
        length = batch["length"].to(device)  # (B,)

        duration = torch.clamp(length / fs, min=1e-6)  # sec
        y_rate = y_count / duration                    # reps/s

        optimizer.zero_grad()

        rate_hat, _, x_hat, aux = model(x, mask, tau=tau)

        # (1) match the average rate (MSE)
        loss_rate = F.mse_loss(rate_hat, y_rate)

        # (2) reconstruction
        loss_recon = masked_recon_mse(x_hat, x, mask)

        # (3) smoothness of rep_rate_t
        loss_smooth = temporal_smoothness(aux["rep_rate_t"], mask)

        # (4) phase exclusivity (entropy)
        loss_phase_ent = phase_entropy_loss(aux["phase_p"], mask)

        # (5) effective-K usage (overall)
        loss_effk, _ = effK_usage_loss(aux["phase_p"], mask)

        # (6) time-warp consistency (count level):
        #     pred_count = rate_hat * duration should be unchanged by a time-warp
        if lam_warp > 0.0:
            xw, mw, lw = time_warp_batch(
                x=x, mask=mask, length=length,
                warp_min=warp_min, warp_max=warp_max, min_len=8
            )
            dur_w = torch.clamp(lw / fs, min=1e-6)

            rate_hat_w, _, _, _ = model(xw, mw, tau=tau)

            pred_count = rate_hat * duration
            pred_count_w = rate_hat_w * dur_w

            # Detaching the reference lets gradients flow only through the warped branch.
            if warp_detach_ref:
                pred_count_ref = pred_count.detach()
            else:
                pred_count_ref = pred_count

            loss_warp = F.l1_loss(pred_count_w, pred_count_ref)
        else:
            loss_warp = torch.zeros((), device=device)

        loss = (loss_rate
                + lam_recon * loss_recon
                + lam_smooth * loss_smooth
                + lam_phase_ent * loss_phase_ent
                + lam_effk * loss_effk
                + lam_warp * loss_warp)

        loss.backward()
        optimizer.step()

        # MAE on count: predicted rate * duration = predicted count
        count_hat = rate_hat * duration
        stats['loss'] += loss.item()
        stats['loss_rate'] += loss_rate.item()
        stats['loss_recon'] += loss_recon.item()
        stats['loss_smooth'] += loss_smooth.item()
        stats['loss_phase_ent'] += loss_phase_ent.item()
        stats['loss_effk'] += loss_effk.item()
        stats['loss_warp'] += float(loss_warp.item())
        stats['mae_count'] += torch.abs(count_hat - y_count).mean().item()

    # Guard the average against an empty loader, which would otherwise raise
    # ZeroDivisionError; the accumulated sums are all 0.0 in that case.
    n = max(len(loader), 1)
    return {k: v / n for k, v in stats.items()}


# ---------------------------------------------------------------------
# 6) Visualization helpers (subject-wise subplot)
# ---------------------------------------------------------------------
def _smooth_1d(y, sigma=2.0):
    """Return a Gaussian-smoothed float32 copy of ``y`` (used to tidy noisy curves for plots).

    Args:
        y: 1-D array-like signal.
        sigma: Standard deviation of the Gaussian kernel, in samples.
    """
    signal = np.asarray(y, dtype=np.float32)
    smoothed = gaussian_filter1d(signal, sigma=sigma)
    return smoothed


def compute_phase_entropy_mean(phase_p_np, eps=1e-8):
    """Return the time-averaged entropy of a phase distribution as a Python float.

    Args:
        phase_p_np: Array-like of shape (T, K); entropy is taken over axis 1.
        eps: Small constant added inside the logarithm to avoid log(0).
    """
    probs = np.asarray(phase_p_np, dtype=np.float32)
    log_probs = np.log(probs + eps)
    # Entropy of each time step, shape (T,).
    per_step_entropy = -np.sum(probs * log_probs, axis=1)
    return float(np.mean(per_step_entropy))


def downsample_time_axis(arr, max_T=2000):
    """Limit the leading (time) axis of ``arr`` to at most ``max_T`` entries for plotting.

    Args:
        arr: numpy array whose first axis is time, shape (T, ...).
        max_T: Maximum number of time steps to keep.

    Returns:
        ``(arr_ds, idx)`` where ``idx`` holds the original time indices that
        were kept. When ``T <= max_T`` the input array is returned unchanged.
    """
    n_steps = arr.shape[0]
    if n_steps > max_T:
        # Evenly spaced picks across the whole range, truncated to integers.
        keep = np.linspace(0, n_steps - 1, max_T).astype(int)
        return arr[keep], keep
    return arr, np.arange(n_steps)


def plot_phase_heatmap_and_dominant(
    phase_p_np,
    fs,
    title="phase_p heatmap + dominant phase",
    max_T=2000
):
    """Show a phase-probability heatmap with the per-step dominant phase underneath.

    Args:
        phase_p_np: Array-like of shape (T, K).
        fs: Sampling rate used to convert sample indices to seconds.
        title: Title of the heatmap panel.
        max_T: Maximum number of time steps drawn (longer inputs are downsampled).
    """
    probs = np.asarray(phase_p_np, dtype=np.float32)
    assert probs.ndim == 2, f"phase_p_np must be (T,K), got {probs.shape}"

    # Downsample for display and convert the kept indices to seconds.
    phase_ds, kept_idx = downsample_time_axis(probs, max_T=max_T)  # (T', K)
    n_phases = phase_ds.shape[1]
    t_sec = kept_idx / float(fs)

    # Index of the most probable phase at every displayed step, shape (T',).
    dominant = np.argmax(phase_ds, axis=1)

    fig = plt.figure(figsize=(30, 10))
    grid = fig.add_gridspec(2, 1, height_ratios=[4, 1], hspace=0.25)
    image_style = dict(aspect="auto", origin="lower", interpolation="nearest")

    # Top panel: heatmap with time on x and phase index on y.
    ax_heat = fig.add_subplot(grid[0, 0])
    heat = ax_heat.imshow(
        phase_ds.T,  # (K, T')
        extent=[t_sec[0], t_sec[-1], 0, n_phases],
        **image_style
    )
    ax_heat.set_title(title, fontsize=24, pad=10)
    ax_heat.set_ylabel("Phase k", fontsize=18)
    ax_heat.set_xlabel("Time (sec)", fontsize=18)
    colorbar = fig.colorbar(heat, ax=ax_heat, fraction=0.015, pad=0.01)
    colorbar.set_label("phase_p(t,k)", fontsize=14)

    # Bottom panel: one-row strip coloured by the dominant phase.
    ax_dom = fig.add_subplot(grid[1, 0], sharex=ax_heat)
    ax_dom.imshow(
        dominant[None, :],  # (1, T')
        extent=[t_sec[0], t_sec[-1], 0, 1],
        **image_style
    )
    ax_dom.set_yticks([])
    ax_dom.set_ylabel("dominant", fontsize=14)
    ax_dom.set_xlabel("Time (sec)", fontsize=18)

    plt.tight_layout()
    plt.show()


def plot_folds_test_subplot(viz_cache, fs, title="Fold-wise TEST visualization (only test_subj)"):
    """Draw one subplot per cached test result: smoothed rep rate plus cumulative count.

    Args:
        viz_cache: List of dicts, each providing the keys 'fold', 'test_subj',
            't', 'rep_rate', 'gt', 'pred', 'diff', 'k_hat' and 'entropy'.
        fs: Sampling rate; the cumulative count is ``cumsum(rep_rate) / fs``.
        title: Figure-level title.

    Prints a notice and returns without plotting when ``viz_cache`` is None or empty.
    """
    if viz_cache is None or len(viz_cache) == 0:
        print("[plot_folds_test_subplot] viz_cache is empty")
        return

    sns.set_theme(style="whitegrid", context="notebook", font_scale=2.0)
    palette = sns.color_palette("muted")
    c_rate, c_count = palette[0], palette[1]

    n_rows = len(viz_cache)
    # squeeze=False always yields a 2-D axes array, so a single row needs no special case.
    fig, axes_grid = plt.subplots(n_rows, 1, figsize=(36, 9 * n_rows), sharex=False, squeeze=False)
    axes = axes_grid.ravel()

    fig.suptitle(title, fontsize=40, y=0.995)

    for ax, item in zip(axes, viz_cache):
        t = item["t"]
        rep_rate = item["rep_rate"]
        gt_count = item["gt"]
        pred_count = item["pred"]
        diff = item["diff"]
        k_hat = item["k_hat"]
        entropy = item["entropy"]
        test_subj = item["test_subj"]
        fold = item["fold"]

        smoothed_rate = _smooth_1d(rep_rate, sigma=2.0)
        cumulative = np.cumsum(rep_rate) / fs

        # Left axis: smoothed repetition rate.
        ax.plot(t, smoothed_rate, color=c_rate, linewidth=2.5, alpha=0.9)
        ax.fill_between(t, smoothed_rate, color=c_rate, alpha=0.15)
        ax.set_ylabel("Rep Rate (reps/s)", color=c_rate, fontweight='bold', fontsize=24)
        ax.grid(True, linestyle='--', alpha=0.5)
        ax.tick_params(axis='both', which='major', labelsize=20)

        # Right axis: cumulative count with the ground-truth level as a dotted line.
        ax_count = ax.twinx()
        ax_count.plot(t, cumulative, color=c_count, linewidth=3.5, alpha=1.0)
        ax_count.axhline(gt_count, linestyle=":", alpha=0.7)
        ax_count.set_ylabel("Count", color=c_count, fontweight='bold', fontsize=24)
        ax_count.tick_params(axis='y', labelcolor=c_count, labelsize=20)
        ax_count.grid(False)

        panel_title = (
            f"Fold {fold:2d} | Test: {test_subj} | Pred {pred_count:.2f} / GT {gt_count:.0f} (Diff {diff:+.2f})\n"
            f"k_hat={k_hat:.2f} | phase_entropy={entropy:.3f}"
        )
        ax.set_title(panel_title, fontsize=34, pad=10)
        ax.set_xlabel("Time (sec)", fontweight='bold', fontsize=24)

    plt.tight_layout(rect=[0, 0, 1, 0.985])
    plt.subplots_adjust(hspace=0.5)
    plt.show()


# ---------------------------------------------------------------------
# 7) Main (A1: Unseen Activity only, WITH windowing)
# ---------------------------------------------------------------------
def build_label_tuples_from_table(subjects, act_id, count_table):
    """Collect ``(subject, act_id, gt_count)`` tuples for one activity.

    Args:
        subjects: Iterable of subject keys, e.g. ``"subject1"``.
        act_id: Activity id to look up for every subject.
        count_table: Nested dict like ``{"subject1": {6: 21, 7: 20}, ...}``.

    Returns:
        List of ``(subj, act_id, float(count))`` in the order of ``subjects``;
        subjects or activities missing from the table are left out.
    """
    return [
        (subj, act_id, float(count_table[subj][act_id]))
        for subj in subjects
        if subj in count_table and act_id in count_table[subj]
    ]


def main():
    """Run the "A1" experiment: train on one activity, test on another.

    Steps: build the configuration, seed all RNGs, load the mHealth logs,
    create train/test label tuples for TRAIN_ACT_ID / TEST_ACT_ID, cut the
    training trials into windows, train ``KAutoCountModel`` for
    ``CONFIG["epochs"]`` epochs, then evaluate every test trial with windowed
    prediction and print the per-trial error plus the mean/std of the errors.

    Returns None. Returns early (after printing a message where one exists)
    when no data, labels or trials are available.
    """
    CONFIG = {
        "seed": 42,
        "data_dir": "/content/drive/MyDrive/Colab Notebooks/HAR_data/MHEALTHDATASET",

        "COLUMN_NAMES": [
            'acc_chest_x', 'acc_chest_y', 'acc_chest_z',
            'ecg_1', 'ecg_2',
            'acc_ankle_x', 'acc_ankle_y', 'acc_ankle_z',
            'gyro_ankle_x', 'gyro_ankle_y', 'gyro_ankle_z',
            'mag_ankle_x', 'mag_ankle_y', 'mag_ankle_z',
            'acc_arm_x', 'acc_arm_y', 'acc_arm_z',
            'gyro_arm_x', 'gyro_arm_y', 'gyro_arm_z',
            'mag_arm_x', 'mag_arm_y', 'mag_arm_z',
            'activity_id'
        ],
        "TARGET_ACTIVITIES_MAP": {
            6: 'Waist bends forward',
            8: 'Knees bending',
        },

        # Train and test must use the same channel count C so that one model
        # can be evaluated on both activities.
        "ACT_FEATURE_MAP": {
            6: ['acc_chest_x', 'acc_chest_y', 'acc_chest_z',
                'acc_arm_x', 'acc_arm_y', 'acc_arm_z',
                'acc_ankle_x', 'acc_ankle_y', 'acc_ankle_z'],
            8: ['acc_chest_x', 'acc_chest_y', 'acc_chest_z',
                'acc_arm_x', 'acc_arm_y', 'acc_arm_z',
                'acc_ankle_x', 'acc_ankle_y', 'acc_ankle_z'],
        },

        # Training Params
        "epochs": 100,
        "lr": 5e-4,
        "batch_size": 64,
        "fs": 50,

        # Model
        "hidden_dim": 128,
        "latent_dim": 16,
        "K_max": 6,

        # Loss Weights
        "lambda_recon": 1.0,
        "lambda_smooth": 0.05,
        "lambda_phase_ent": 0.01,
        "lambda_effk": 0.0005,

        # Time-warp consistency loss settings
        "lambda_warp": 0.2,      # author's note: 1e-2 to 5e-2 was recommended as a starting value
        "warp_min": 0.5,          # lower bound of the length scale factor (factor < 1 shortens the sequence)
        "warp_max": 1.5,          # upper bound of the length scale factor (factor > 1 lengthens the sequence)
        "warp_detach_ref": True,  # only the warped branch is pulled towards the original count (for stability)

        # temperature (controls how strongly the phases compete)
        "tau": 1.0,

        # A1 setting
        "TRAIN_ACT_ID": 8,
        "TEST_ACT_ID": 6,

        # -------------------------
        # Windowing
        # -------------------------
        "USE_WINDOWING": True,
        "WIN_SEC": 8.0,
        "STRIDE_SEC": 4.0,
        "DROP_LAST": True,

        # Ground-truth repetition counts: {subject: {activity_id: count}}
        "COUNT_TABLE": {
            "subject1":  {6: 21, 8: 20},
            "subject2":  {6: 19, 8: 21},
            "subject3":  {6: 21, 8: 21},
            "subject4":  {6: 20, 8: 19},
            "subject5":  {6: 20, 8: 20},
            "subject6":  {6: 20, 8: 20},
            "subject7":  {6: 20, 8: 21},
            "subject8":  {6: 21, 8: 21},
            "subject9":  {6: 21, 8: 21},
            "subject10": {6: 20, 8: 21},
        },
    }

    # Seed before anything random happens (model init, loader shuffling).
    set_strict_seed(CONFIG["seed"])
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")

    full_data = load_mhealth_dataset(
        CONFIG["data_dir"],
        CONFIG["TARGET_ACTIVITIES_MAP"],
        CONFIG["COLUMN_NAMES"]
    )
    # The loader returns an empty dict (and prints a warning) when no logs are found.
    if not full_data:
        return

    subjects = [f"subject{i}" for i in range(1, 11)]

    # -------------------------
    # A1 split (NO LOSO)
    # Train: all subjects x TRAIN_ACT_ID
    # Test : all subjects x TEST_ACT_ID (unseen activity)
    # -------------------------
    train_labels = build_label_tuples_from_table(subjects, CONFIG["TRAIN_ACT_ID"], CONFIG["COUNT_TABLE"])
    test_labels  = build_label_tuples_from_table(subjects, CONFIG["TEST_ACT_ID"],  CONFIG["COUNT_TABLE"])

    if len(train_labels) == 0:
        print("[Error] No train labels. Check COUNT_TABLE / TRAIN_ACT_ID.")
        return
    if len(test_labels) == 0:
        print("[Error] No test labels. Check COUNT_TABLE / TEST_ACT_ID.")
        return

    train_data = prepare_trial_list(train_labels, full_data, CONFIG["TARGET_ACTIVITIES_MAP"], CONFIG["ACT_FEATURE_MAP"])
    test_data  = prepare_trial_list(test_labels,  full_data, CONFIG["TARGET_ACTIVITIES_MAP"], CONFIG["ACT_FEATURE_MAP"])

    if len(train_data) == 0:
        print("[Error] train_data empty. (missing files or ACT_FEATURE_MAP mismatch)")
        return
    if len(test_data) == 0:
        print("[Error] test_data empty. (missing files or ACT_FEATURE_MAP mismatch)")
        return

    # -------------------------
    # Apply windowing to the training trials only; test trials are windowed
    # later inside predict_count_by_windowing.
    # -------------------------
    if CONFIG.get("USE_WINDOWING", False):
        train_windows = trial_list_to_windows(
            train_data,
            fs=CONFIG["fs"],
            win_sec=CONFIG["WIN_SEC"],
            stride_sec=CONFIG["STRIDE_SEC"],
            drop_last=CONFIG["DROP_LAST"],
        )
        print(f"[Windowing] train trials={len(train_data)} -> train windows={len(train_windows)}")
        train_data_for_loader = train_windows
    else:
        train_data_for_loader = train_data

    train_loader = DataLoader(
        TrialDataset(train_data_for_loader),
        batch_size=CONFIG["batch_size"],
        shuffle=True,
        collate_fn=collate_variable_length,
        num_workers=0
    )

    # Trial arrays are (T, C), so axis 1 is the number of input channels.
    input_ch = train_data[0]['data'].shape[1]
    model = KAutoCountModel(
        input_ch=input_ch,
        hidden_dim=CONFIG["hidden_dim"],
        latent_dim=CONFIG["latent_dim"],
        K_max=CONFIG["K_max"]
    ).to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=CONFIG["lr"])
    # Halve the learning rate every 30 epochs.
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.5)

    print("\n" + "-"*80)
    print(f" >>> A1 Train: all subjects x act{CONFIG['TRAIN_ACT_ID']}  |  Test: all subjects x act{CONFIG['TEST_ACT_ID']} (unseen activity)")
    print("-"*80)

    # ---- Train ----
    for epoch in range(CONFIG["epochs"]):
        # Per-epoch statistics are returned but intentionally not used here.
        _ = train_one_epoch(model, train_loader, optimizer, CONFIG, device)
        scheduler.step()

    # ---- Test (per-subject) ----
    model.eval()
    maes = []
    viz_cache = []

    for idx, item in enumerate(test_data):
        x_np = item["data"]  # (T,C)
        gt_count = float(item["count"])

        # Predicted count for the whole trial, computed from windowed rates.
        pred_count, win_rates = predict_count_by_windowing(
            model,
            x_np=x_np,
            fs=CONFIG["fs"],
            win_sec=CONFIG["WIN_SEC"],
            stride_sec=CONFIG["STRIDE_SEC"],
            device=device,
            tau=CONFIG["tau"],
            batch_size=CONFIG["batch_size"]
        )

        mae = abs(pred_count - gt_count)
        maes.append(mae)

        # One full-length forward pass, used only to collect visualization data.
        with torch.no_grad():
            x_tensor = torch.tensor(x_np, dtype=torch.float32).transpose(0, 1).unsqueeze(0).to(device)
            rate_hat_full, _, _, aux = model(x_tensor, mask=None, tau=CONFIG["tau"])

            phase_p = aux["phase_p"].squeeze(0).detach().cpu().numpy()
            k_hat = float(aux["k_hat"].item())
            ent = compute_phase_entropy_mean(phase_p)
            rep_rate = aux["rep_rate_t"].squeeze(0).detach().cpu().numpy()
            T = rep_rate.shape[0]
            t = np.arange(T) / float(CONFIG["fs"])

        viz_cache.append({
            "fold": idx + 1,   # simply the 1-based index of the test trial
            "test_subj": item["meta"],  # subject+act name
            "t": t,
            "rep_rate": rep_rate,
            "gt": gt_count,
            "pred": float(pred_count),   # windowed prediction
            "diff": float(pred_count - gt_count),
            "k_hat": k_hat,
            "entropy": ent,
            "phase_p": phase_p,
        })

        print(
            f"[Test {idx+1:02d}] {item['meta']} | Pred(win)={pred_count:.2f} / GT={gt_count:.2f} | "
            f"MAE={mae:.2f} | k_hat={k_hat:.2f} | ent={ent:.3f} | win_rate_mean={win_rates.mean():.3f}"
        )

    print("-"*80)
    print(f" >>> A1 Final MAE mean: {np.mean(maes):.3f}")
    print(f" >>> A1 Final MAE std : {np.std(maes):.3f}")
    print("-"*80)

    # Optional visualization (currently disabled; uncomment to enable)
    # plot_folds_test_subplot(
    #     viz_cache,
    #     fs=CONFIG["fs"],
    #     title=f"A1 TEST visualization | Train act{CONFIG['TRAIN_ACT_ID']} -> Test act{CONFIG['TEST_ACT_ID']}"
    # )

    # for item in viz_cache:
    #     plot_phase_heatmap_and_dominant(
    #         item["phase_p"],
    #         fs=CONFIG["fs"],
    #         title=f"{item['test_subj']} | k_hat={item['k_hat']:.2f} | ent={item['entropy']:.3f}",
    #         max_T=2000
    #     )


# Entry point: run the experiment only when this file is executed directly.
if __name__ == "__main__":
    main()


Device: cuda
Loading 10 subjects from /content/drive/MyDrive/Colab Notebooks/HAR_data/MHEALTHDATASET...
[Windowing] train trials=10 -> train windows=131

--------------------------------------------------------------------------------
 >>> A1 Train: all subjects x act8  |  Test: all subjects x act6 (unseen activity)
--------------------------------------------------------------------------------
[Test 01] subject1_Waist bends forward | Pred(win)=50.75 / GT=21.00 | MAE=29.75 | k_hat=1.73 | ent=0.668 | win_rate_mean=0.826
[Test 02] subject2_Waist bends forward | Pred(win)=30.40 / GT=19.00 | MAE=11.40 | k_hat=2.36 | ent=0.758 | win_rate_mean=0.479
[Test 03] subject3_Waist bends forward | Pred(win)=41.26 / GT=21.00 | MAE=20.26 | k_hat=2.00 | ent=0.729 | win_rate_mean=0.639
[Test 04] subject4_Waist bends forward | Pred(win)=48.15 / GT=20.00 | MAE=28.15 | k_hat=2.01 | ent=0.693 | win_rate_mean=0.723
[Test 05] subject5_Waist bends forward | Pred(win)=26.95 / GT=20.00 | MAE=6.95 | k_hat=2.43 |

In [5]:
# =========================
# Count-only K-auto (Multi-event) version  (NO manual Pair/lag/overlap/balance)
#
# 핵심 아이디어
# - Micro-event Rate 예측: 모델은 하나의 통합된 속도가 아니라, K_max개의 서로 다른 '작은 단위 동작'의 속도 흐름(r_k(t))을 예측
# - 샘플마다 "rep당 micro-event 개수" k_hat(>=1)을 스스로 추정
# - 우리가 주는 감독은 오직 rep count(=20) -> rep rate만 맞추게
#
# 1) rate head를 amp(t) * softmax(phase) 형태로 바꿔서 K-stream 간 "경쟁"이 생기게 함
# 2) k_hat을 따로 head로 예측하지 않고, phase 사용 분포로부터 effK(=effective K)로 정의 (자동 K)
# 3) phase sparsity/exclusivity를 위해 (a) phase entropy loss, (b) effK usage loss 추가
#
# ✅ Added (ONLY): Time-warp consistency loss
#   - 입력을 시간축으로 stretch/compress (time-warp)해도 "count"는 동일해야 한다는 제약
#   - 모델 출력은 rate이므로, pred_count = rate_hat * duration 을 이용해 count 일관성을 강제
# =========================

import os
import glob
import random
import numpy as np
import pandas as pd
import seaborn as sns

import matplotlib.pyplot as plt
from scipy.ndimage import gaussian_filter1d

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader


# ---------------------------------------------------------------------
# 1) Strict Seeding
# ---------------------------------------------------------------------
def set_strict_seed(seed: int):
    """Seed Python, NumPy and PyTorch RNGs and put cuDNN in deterministic mode.

    Also sets the ``PYTHONHASHSEED`` environment variable to ``str(seed)``.
    CUDA generators are seeded only when CUDA is available.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)

    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    # Disable cuDNN autotuning and request deterministic kernels.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


# ---------------------------------------------------------------------
# 2) Data Loading
# ---------------------------------------------------------------------
def load_mhealth_dataset(data_dir, target_activities_map, column_names):
    """Load mHealth subject logs and split each one by target activity.

    Args:
        data_dir: Directory searched for files matching ``mHealth_subject*.log``.
        target_activities_map: Mapping of activity id (value in the
            ``activity_id`` column) to activity name.
        column_names: Names assigned to the leading columns of every log;
            must contain ``'activity_id'``.

    Returns:
        ``{subject_key: {activity_name: DataFrame}}`` where each DataFrame
        holds the rows of that activity without the ``activity_id`` column.
        Activities with no rows are omitted. Returns ``{}`` (after printing a
        warning) when no log file is found. A file that fails to load is
        reported on stdout and skipped.
    """
    file_list = sorted(glob.glob(os.path.join(data_dir, "mHealth_subject*.log")))

    if not file_list:
        print(f"[Warning] No mHealth logs found in {data_dir}")
        return {}

    print(f"Loading {len(file_list)} subjects from {data_dir}...")

    full_dataset = {}
    for file_path in file_list:
        file_name = os.path.basename(file_path)
        subj_part = file_name.split('.')[0]
        try:
            subj_id_num = int(''.join(filter(str.isdigit, subj_part)))
            subj_key = f"subject{subj_id_num}"
        except ValueError:
            # The stem has no parsable digits: fall back to the raw stem.
            # Only ValueError is caught so KeyboardInterrupt/SystemExit propagate.
            subj_key = subj_part

        try:
            df = pd.read_csv(file_path, sep="\t", header=None)
            # Keep only as many leading columns as there are names.
            df = df.iloc[:, :len(column_names)]
            df.columns = column_names

            subj_data = {}
            for label_code, activity_name in target_activities_map.items():
                activity_df = df[df['activity_id'] == label_code].copy()
                if not activity_df.empty:
                    subj_data[activity_name] = activity_df.drop(columns=['activity_id'])

            full_dataset[subj_key] = subj_data
        except Exception as e:
            # Best-effort loading: report the broken file and continue with the rest.
            print(f"Error loading {file_name}: {e}")

    return full_dataset


def prepare_trial_list(label_config, full_data, target_map, feature_map):
    """Build per-trial records with z-scored sensor data.

    Args:
        label_config: Iterable of ``(subj, act_id, gt_count)`` tuples.
        full_data: ``{subject: {activity_name: DataFrame}}``.
        target_map: Mapping of activity id to activity name.
        feature_map: Mapping of activity id to the list of column names to use.

    Returns:
        List of dicts with keys ``'data'`` (float32 array of shape (T, C),
        standardized per channel over the trial), ``'count'`` (float) and
        ``'meta'`` (``"<subj>_<activity name>"``). Entries whose data or
        feature list is missing are skipped with a printed notice.
    """
    trial_list = []
    for subj, act_id, gt_count in label_config:
        act_name = target_map.get(act_id)
        feats = feature_map.get(act_id)

        if subj not in full_data or act_name not in full_data[subj]:
            print(f"[Skip] Missing data for {subj} - {act_name}")
            continue
        if feats is None:
            # Without a feature list, indexing the DataFrame with None would
            # raise KeyError; skip so callers can handle an empty result.
            print(f"[Skip] No feature list for activity {act_id} ({subj} - {act_name})")
            continue

        raw_np = full_data[subj][act_name][feats].values.astype(np.float32)

        # Per-channel z-score; the small constant guards against zero variance.
        mean = raw_np.mean(axis=0)
        std = raw_np.std(axis=0) + 1e-6
        norm_np = (raw_np - mean) / std

        trial_list.append({
            'data': norm_np,
            'count': float(gt_count),
            'meta': f"{subj}_{act_name}"
        })

    return trial_list


# ---------------------------------------------------------------------
# 2.5) Windowing (added)
# ---------------------------------------------------------------------
def trial_list_to_windows(trial_list, fs, win_sec=4.0, stride_sec=2.0, drop_last=True):
    """Cut trials into sliding windows and give each window a proportional count label.

    The window label comes from the trial-level average rate:
      rate_trial   = count_total / total_duration
      count_window = rate_trial * window_duration

    Args:
        trial_list: Output of ``prepare_trial_list`` (items with 'data' of
            shape (T, C), 'count' and 'meta').
        fs: Sampling rate in Hz.
        win_sec: Window length in seconds.
        stride_sec: Hop between window starts in seconds.
        drop_last: When False, an extra shorter window starting one stride
            after the last full window is added if that start lies inside
            the trial.

    Returns:
        List of window dicts with keys 'data', 'count', 'meta',
        'parent_meta', 'parent_T', 'win_start' and 'win_end'.
    """
    win_len = int(round(win_sec * fs))
    stride = int(round(stride_sec * fs))
    assert win_len > 0 and stride > 0

    windows = []
    for trial in trial_list:
        signal = trial["data"]  # (T, C)
        n_samples = signal.shape[0]
        total_count = float(trial["count"])
        parent = trial["meta"]

        # Average repetitions per second over the whole trial.
        rate_trial = total_count / max(n_samples / float(fs), 1e-6)

        def make_window(data, st, ed):
            # One window record covering samples [st, ed) of the parent trial.
            return {
                "data": data,
                "count": rate_trial * ((ed - st) / float(fs)),
                "meta": f"{parent}__win[{st}:{ed}]",
                "parent_meta": parent,
                "parent_T": n_samples,
                "win_start": st,
                "win_end": ed,
            }

        if n_samples < win_len:
            # Trial shorter than one window: keep it whole (variable length, no padding).
            windows.append(make_window(signal, 0, n_samples))
            continue

        starts = range(0, n_samples - win_len + 1, stride)
        windows.extend(make_window(signal[st:st + win_len], st, st + win_len) for st in starts)

        if drop_last:
            continue

        tail_start = starts[-1] + stride
        if tail_start < n_samples:
            windows.append(make_window(signal[tail_start:n_samples], tail_start, n_samples))

    return windows


def predict_count_by_windowing(model, x_np, fs, win_sec, stride_sec, device, tau=1.0, batch_size=64):
    """Predict the repetition count of one trial from windowed rate estimates.

    The trial is cut into fixed-length windows, the model predicts an average
    rate for every window, and the count is ``mean(window rates) * duration``.
    Trials no longer than one window are passed through the model whole.

    Args:
        model: Module called as ``model(x, mask=None, tau=tau)`` whose first
            return value is the per-sample average rate.
        x_np: Normalized trial data of shape (T, C).
        fs: Sampling rate in Hz.
        win_sec: Window length in seconds.
        stride_sec: Hop between window starts in seconds.
        device: Device the input tensors are moved to.
        tau: Temperature forwarded to the model.
        batch_size: Number of windows per forward pass.

    Returns:
        ``(pred_count, window_rates)``: a float and a numpy array holding the
        rate of each evaluated window.

    Side effect: the model is switched to eval mode and left there.
    """
    win_len = int(round(win_sec * fs))
    stride = int(round(stride_sec * fs))
    T = x_np.shape[0]
    total_dur = T / float(fs)

    # Use inference mode on both paths, so short trials are not evaluated
    # in whatever mode the model happened to be left in.
    model.eval()

    if T <= win_len:
        x_tensor = torch.tensor(x_np, dtype=torch.float32).transpose(0, 1).unsqueeze(0).to(device)  # (1,C,T)
        with torch.no_grad():
            rate_hat, _, _, _ = model(x_tensor, mask=None, tau=tau)
        rate = float(rate_hat.item())
        return float(rate * total_dur), np.array([rate], dtype=np.float32)

    starts = list(range(0, T - win_len + 1, stride))
    windows = np.stack([x_np[st:st + win_len] for st in starts], axis=0)  # (N, win_len, C)

    xw = torch.tensor(windows, dtype=torch.float32).permute(0, 2, 1).to(device)  # (N, C, win_len)

    rates = []
    with torch.no_grad():
        for i in range(0, xw.shape[0], batch_size):
            xb = xw[i:i + batch_size]
            r_hat, _, _, _ = model(xb, mask=None, tau=tau)  # (B,)
            rates.append(r_hat.detach().cpu().numpy())

    rates = np.concatenate(rates, axis=0)  # (N,)
    rate_mean = float(rates.mean())
    pred_count = rate_mean * total_dur
    return float(pred_count), rates


class TrialDataset(Dataset):
    """Dataset over trial/window dicts that provide 'data', 'count' and 'meta'."""

    def __init__(self, trial_list):
        self.trials = trial_list

    def __len__(self):
        return len(self.trials)

    def __getitem__(self, idx):
        """Return ``(data, count, meta)`` with ``data`` transposed to (C, T)."""
        trial = self.trials[idx]
        signal = torch.tensor(trial['data'], dtype=torch.float32)  # (T, C)
        target = torch.tensor(trial['count'], dtype=torch.float32)
        return signal.transpose(0, 1), target, trial['meta']


def collate_variable_length(batch):
    """Zero-pad variable-length samples to a common length and build validity masks.

    Args:
        batch: List of ``(data, count, meta)`` with ``data`` of shape (C, T).

    Returns:
        Dict with 'data' (B, C, T_max), 'mask' (B, T_max; 1 for real samples,
        0 for padding), 'count' (B,), 'length' (B,) float tensor of the
        original lengths, and 'meta' (list).
    """
    seq_lens = [sample[0].shape[1] for sample in batch]
    longest = max(seq_lens)
    n_ch = batch[0][0].shape[0]

    padded_data, masks = [], []
    for (data, _count, _meta), n_valid in zip(batch, seq_lens):
        n_pad = longest - n_valid
        if n_pad > 0:
            data = torch.cat([data, torch.zeros(n_ch, n_pad)], dim=1)

        valid_mask = torch.zeros(longest)
        valid_mask[:n_valid] = 1.0

        padded_data.append(data)
        masks.append(valid_mask)

    counts = [sample[1] for sample in batch]
    metas = [sample[2] for sample in batch]

    return {
        "data": torch.stack(padded_data),         # (B, C, T_max)
        "mask": torch.stack(masks),               # (B, T_max)
        "count": torch.stack(counts),             # (B,)
        "length": torch.tensor(seq_lens, dtype=torch.float32),  # (B,)
        "meta": metas
    }


# ---------------------------------------------------------------------
# 3) Model
# ---------------------------------------------------------------------
class ManifoldEncoder(nn.Module):
    """1-D convolutional encoder mapping (B, C, T) signals to (B, T, latent_dim) latents."""

    def __init__(self, input_ch, hidden_dim=128, latent_dim=16):
        super().__init__()
        # Two width-5 convolutions (padding keeps T unchanged) and a 1x1 projection.
        layers = [
            nn.Conv1d(input_ch, hidden_dim, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv1d(hidden_dim, hidden_dim, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv1d(hidden_dim, latent_dim, kernel_size=1),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Encode ``x`` of shape (B, C, T) and return time-major latents (B, T, D)."""
        features = self.net(x)  # (B, D, T)
        return features.transpose(1, 2)


class ManifoldDecoder(nn.Module):
    """1-D convolutional decoder mapping (B, T, latent_dim) latents back to (B, out_ch, T)."""

    def __init__(self, latent_dim, hidden_dim, out_ch):
        super().__init__()
        # Two width-5 convolutions (padding keeps T unchanged) and a 1x1 projection.
        layers = [
            nn.Conv1d(latent_dim, hidden_dim, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv1d(hidden_dim, hidden_dim, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv1d(hidden_dim, out_ch, kernel_size=1),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, z):
        """Decode time-major latents ``z`` of shape (B, T, D) into (B, C, T)."""
        channels_first = z.transpose(1, 2)  # (B, D, T)
        return self.net(channels_first)


class MultiRateHead(nn.Module):
    """Per-time-step head producing a non-negative amplitude and a K-way phase distribution."""

    def __init__(self, latent_dim=16, hidden=64, K_max=6):
        super().__init__()
        self.K_max = K_max
        # Output layout of the last layer: [amp_logit | K_max phase logits].
        layers = [
            nn.Linear(latent_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 1 + K_max),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, z, tau=1.0):
        """Map latents (B, T, D) to ``(amp, phase, phase_logits)``.

        ``amp`` is (B, T) and >= 0 (softplus of the first output);
        ``phase`` is (B, T, K), a softmax of ``phase_logits / tau`` summing to 1
        over the last axis; ``phase_logits`` are the raw (B, T, K) logits.
        """
        raw = self.net(z)  # (B, T, 1+K)
        amp_logit, phase_logits = raw[..., 0], raw[..., 1:]
        amp = F.softplus(amp_logit)
        phase = F.softmax(phase_logits / tau, dim=-1)
        return amp, phase, phase_logits


class KAutoCountModel(nn.Module):
    """Autoencoder with a multi-phase rate head for count regression.

    - The head yields a total micro-event intensity ``amp(t)`` and a phase
      distribution over ``K_max`` streams; ``r_k(t) = amp(t) * phase_k(t)``.
    - ``k_hat`` is not predicted by a separate head: it is the effective
      number of phases, ``1 / sum_k(p_bar_k^2)``, of the time-averaged phase usage.
    - ``rep_rate(t) = amp(t) / k_hat``.

    The ``k_hidden`` constructor argument is accepted but not used.
    """

    def __init__(self, input_ch, hidden_dim=128, latent_dim=16, K_max=6, k_hidden=64):
        super().__init__()
        self.encoder = ManifoldEncoder(input_ch, hidden_dim, latent_dim)
        self.decoder = ManifoldDecoder(latent_dim, hidden_dim, input_ch)
        self.rate_head = MultiRateHead(latent_dim, hidden=hidden_dim, K_max=K_max)

        self._init_weights()

    def _init_weights(self):
        """Kaiming-initialise conv/linear weights, zero biases, then bias the amplitude logit."""
        for module in self.modules():
            if not isinstance(module, (nn.Conv1d, nn.Linear)):
                continue
            nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)

        with torch.no_grad():
            head_bias = self.rate_head.net[-1].bias
            head_bias.zero_()
            # Only the amplitude logit starts at -2; phase logits start at 0.
            head_bias[0].fill_(-2.0)

    @staticmethod
    def _masked_mean_time(x, mask=None, eps=1e-6):
        """Average ``x`` of shape (B, T) or (B, T, K) over time, honouring ``mask`` (B, T) if given."""
        if mask is None:
            return x.mean(dim=1)
        if x.dim() not in (2, 3):
            raise ValueError(f"Unsupported dim for masked mean: {x.dim()}")

        weights = mask.to(dtype=x.dtype, device=x.device)  # (B, T)
        if x.dim() == 3:
            weights = weights.unsqueeze(-1)                # (B, T, 1)
        return (x * weights).sum(dim=1) / (weights.sum(dim=1) + eps)

    def forward(self, x, mask=None, tau=1.0):
        """Run the model on ``x`` (B, C, T) with an optional validity ``mask`` (B, T).

        Returns:
            ``(avg_rep_rate, z, x_hat, aux)`` where ``avg_rep_rate`` is (B,),
            ``z`` the latents (B, T, D), ``x_hat`` the reconstruction (B, C, T)
            and ``aux`` a dict with 'rates_k_t', 'phase_p', 'phase_logits',
            'micro_rate_t', 'rep_rate_t' and 'k_hat'.
        """
        latent = self.encoder(x)        # (B, T, D)
        recon = self.decoder(latent)    # (B, C, T)

        amp_t, phase_p, phase_logits = self.rate_head(latent, tau=tau)
        rates_k_t = amp_t.unsqueeze(-1) * phase_p

        # Phases sum to one, so the summed micro-event rate equals the amplitude.
        micro_rate_t = amp_t

        # Effective number of phases used over the sequence, in [1, K].
        usage = self._masked_mean_time(phase_p, mask)          # (B, K)
        k_hat = 1.0 / (usage.pow(2).sum(dim=1) + 1e-6)         # (B,)

        rep_rate_t = micro_rate_t / (k_hat.unsqueeze(1) + 1e-6)  # (B, T)
        if mask is None:
            avg_rep_rate = rep_rate_t.mean(dim=1)
        else:
            # Zero the padded steps, then average over the valid ones only.
            rep_rate_t = rep_rate_t * mask
            avg_rep_rate = (rep_rate_t * mask).sum(dim=1) / (mask.sum(dim=1) + 1e-6)

        aux = {
            "rates_k_t": rates_k_t,          # (B,T,K)
            "phase_p": phase_p,              # (B,T,K)
            "phase_logits": phase_logits,    # (B,T,K)
            "micro_rate_t": micro_rate_t,    # (B,T)
            "rep_rate_t": rep_rate_t,        # (B,T)
            "k_hat": k_hat,                  # (B,)
        }
        return avg_rep_rate, latent, recon, aux


# ---------------------------------------------------------------------
# 4) Loss utils
# ---------------------------------------------------------------------
def masked_recon_mse(x_hat, x, mask, eps=1e-6):
    """Mean squared reconstruction error over valid (unmasked) time steps.

    Args:
        x_hat: Reconstruction, shape (B, C, T).
        x: Target, shape (B, C, T).
        mask: Validity mask, shape (B, T).
        eps: Added to the denominator to avoid division by zero.
    """
    valid = mask.to(dtype=x.dtype, device=x.device)
    squared_err = (x_hat - x).pow(2) * valid.unsqueeze(1)  # mask broadcast as (B,1,T)
    # Number of valid (batch, time) positions times the channel count.
    n_valid = valid.sum() * x.shape[1] + eps
    return squared_err.sum() / n_valid


def temporal_smoothness(v, mask=None, eps=1e-6):
    """L1 penalty on the first temporal difference of ``v`` (B, T).

    With a ``mask`` (B, T), only differences between two valid neighbouring
    steps contribute to the average.
    """
    step_change = (v[:, 1:] - v[:, :-1]).abs()  # (B, T-1)
    if mask is None:
        return step_change.mean()

    pair_valid = (mask[:, 1:] * mask[:, :-1]).to(dtype=step_change.dtype, device=step_change.device)
    return (step_change * pair_valid).sum() / (pair_valid.sum() + eps)


def phase_entropy_loss(phase_p, mask=None, eps=1e-8):
    """Mean per-step entropy of the phase distribution ``phase_p`` (B, T, K).

    Minimising it pushes each time step towards a one-hot phase
    (time-wise exclusivity). With a ``mask`` (B, T), padded steps are excluded.
    """
    step_entropy = -(phase_p * torch.log(phase_p + eps)).sum(dim=-1)  # (B, T)
    if mask is None:
        return step_entropy.mean()

    masked_entropy = step_entropy * mask
    return masked_entropy.sum() / (mask.sum() + eps)


def effK_usage_loss(phase_p, mask=None, eps=1e-6):
    """Penalise the effective number of phases used over a whole sequence.

    The time-averaged phase usage ``p_bar`` (B, K) is turned into
    ``effK = 1 / sum_k(p_bar_k^2)``, which lies in [1, K].

    Returns:
        ``(effK.mean(), effK.detach())``: the scalar loss and the per-sample
        values with gradients detached.
    """
    if mask is not None:
        weights = mask.to(dtype=phase_p.dtype, device=phase_p.device).unsqueeze(-1)  # (B,T,1)
        usage = (phase_p * weights).sum(dim=1) / (weights.sum(dim=1) + eps)
    else:
        usage = phase_p.mean(dim=1)  # (B, K)

    eff_k = 1.0 / (usage.pow(2).sum(dim=1) + eps)
    return eff_k.mean(), eff_k.detach()


def time_warp_batch(x, mask, length, warp_min=0.7, warp_max=1.3, min_len=8):
    """Resample every sample of a padded batch to a randomly scaled length.

    For each sample a factor ``w`` is drawn uniformly from
    ``[warp_min, warp_max]`` and its valid part is linearly interpolated to
    ``max(round(T * w), min_len)`` steps.

    Args:
        x: Padded batch, shape (B, C, Tmax).
        mask: Validity mask (B, Tmax); only its dtype is used here.
        length: Valid length of every sample, shape (B,), float or int.
        warp_min: Lower bound of the length scale factor.
        warp_max: Upper bound of the length scale factor.
        min_len: Minimum warped length.

    Returns:
        ``(xw, mw, lw)``: warped batch (B, C, Tw_max) zero-padded on the
        right, its mask (B, Tw_max) and the warped valid lengths (B,) in the
        dtype of ``length``.
    """
    batch_size, n_ch, _ = x.shape
    device, dtype = x.device, x.dtype

    # Warp each sample independently (one random draw per sample, in batch order).
    warped = []
    for b in range(batch_size):
        valid_len = max(int(length[b].item()), 1)
        segment = x[b:b+1, :, :valid_len]  # (1, C, T_b)

        factor = float(torch.empty(1, device=device).uniform_(warp_min, warp_max).item())
        new_len = max(int(round(valid_len * factor)), min_len)

        resampled = F.interpolate(segment, size=new_len, mode="linear", align_corners=False)
        warped.append(resampled.squeeze(0))  # (C, new_len)

    new_lens = [seg.shape[1] for seg in warped]
    longest = max(new_lens, default=0)

    # Right-pad everything to the longest warped sequence.
    xw = torch.zeros((batch_size, n_ch, longest), device=device, dtype=dtype)
    mw = torch.zeros((batch_size, longest), device=device, dtype=mask.dtype)
    for b, (seg, n) in enumerate(zip(warped, new_lens)):
        xw[b, :, :n] = seg
        mw[b, :n] = 1

    lw = torch.tensor([float(n) for n in new_lens], device=device, dtype=length.dtype)
    return xw, mw, lw


# ---------------------------------------------------------------------
# 5) Train  (ONLY MODIFIED PART: remove shift, add time-warp consistency)
# ---------------------------------------------------------------------
def train_one_epoch(model, loader, optimizer, config, device):
    """
    Train `model` for one pass over `loader` and return averaged statistics.

    Each batch is a dict with "data" (B,C,T), "mask" (B,T), "count" (B,) and
    "length" (B,). The model is called as model(x, mask, tau=tau) and must
    return (rate_hat, z, x_hat, aux) where aux holds "rep_rate_t" and
    "phase_p".

    Total loss = rate MSE
               + lambda_recon     * masked reconstruction MSE
               + lambda_smooth    * temporal smoothness of rep_rate_t
               + lambda_phase_ent * phase entropy
               + lambda_effk      * effective-K usage
               + lambda_warp      * time-warp count consistency (L1)

    config keys: "fs" (required); optional "tau", "lambda_recon",
    "lambda_smooth", "lambda_phase_ent", "lambda_effk", "lambda_warp",
    "warp_min", "warp_max", "warp_detach_ref", "warp_min_len".

    Returns a dict with the per-batch mean of every loss term and of the
    count MAE. An empty loader yields all zeros instead of raising
    ZeroDivisionError.
    """
    model.train()
    stats = {k: 0.0 for k in [
        'loss', 'loss_rate', 'loss_recon', 'loss_smooth', 'loss_phase_ent', 'loss_effk', 'loss_warp',
        'mae_count'
    ]}

    fs = config["fs"]
    tau = config.get("tau", 1.0)

    lam_recon = config.get("lambda_recon", 1.0)
    lam_smooth = config.get("lambda_smooth", 0.05)
    lam_phase_ent = config.get("lambda_phase_ent", 0.01)
    lam_effk = config.get("lambda_effk", 0.005)

    # Time-warp consistency: weight, warp-factor range and shortest warped length.
    lam_warp = config.get("lambda_warp", 0.0)
    warp_min = config.get("warp_min", 0.7)
    warp_max = config.get("warp_max", 1.3)
    warp_min_len = config.get("warp_min_len", 8)
    warp_detach_ref = bool(config.get("warp_detach_ref", True))

    n_batches = 0
    for batch in loader:
        x = batch["data"].to(device)         # (B,C,T)
        mask = batch["mask"].to(device)      # (B,T)
        y_count = batch["count"].to(device)  # (B,)
        length = batch["length"].to(device)  # (B,)

        duration = torch.clamp(length / fs, min=1e-6)  # sec
        y_rate = y_count / duration                    # reps/s

        optimizer.zero_grad()

        rate_hat, z, x_hat, aux = model(x, mask, tau=tau)

        # (1) match the average rep rate (MSE)
        loss_rate = F.mse_loss(rate_hat, y_rate)

        # (2) reconstruction on valid time steps only
        loss_recon = masked_recon_mse(x_hat, x, mask)

        # (3) smoothness of the per-step rep rate
        loss_smooth = temporal_smoothness(aux["rep_rate_t"], mask)

        # (4) phase exclusivity (entropy)
        loss_phase_ent = phase_entropy_loss(aux["phase_p"], mask)

        # (5) effective-K usage (overall)
        loss_effk, _ = effK_usage_loss(aux["phase_p"], mask)

        # (6) time-warp consistency (count level):
        #     pred_count = rate_hat * duration should not change after warping.
        if lam_warp > 0.0:
            xw, mw, lw = time_warp_batch(
                x=x, mask=mask, length=length,
                warp_min=warp_min, warp_max=warp_max, min_len=warp_min_len
            )
            dur_w = torch.clamp(lw / fs, min=1e-6)

            rate_hat_w, _, _, _ = model(xw, mw, tau=tau)

            pred_count = rate_hat * duration
            pred_count_w = rate_hat_w * dur_w

            # With a detached reference only the warped branch is pulled
            # towards the original prediction.
            pred_count_ref = pred_count.detach() if warp_detach_ref else pred_count

            loss_warp = F.l1_loss(pred_count_w, pred_count_ref)
        else:
            loss_warp = torch.zeros((), device=device)

        loss = (loss_rate
                + lam_recon * loss_recon
                + lam_smooth * loss_smooth
                + lam_phase_ent * loss_phase_ent
                + lam_effk * loss_effk
                + lam_warp * loss_warp)

        loss.backward()
        optimizer.step()

        # Count MAE is a metric only, so keep it out of autograd.
        with torch.no_grad():
            count_hat = rate_hat * duration  # predicted rate * time = predicted count
            mae_count = torch.abs(count_hat - y_count).mean().item()

        stats['loss'] += loss.item()
        stats['loss_rate'] += loss_rate.item()
        stats['loss_recon'] += loss_recon.item()
        stats['loss_smooth'] += loss_smooth.item()
        stats['loss_phase_ent'] += loss_phase_ent.item()
        stats['loss_effk'] += loss_effk.item()
        stats['loss_warp'] += float(loss_warp.item())
        stats['mae_count'] += mae_count
        n_batches += 1

    if n_batches == 0:
        # Nothing was processed: report zeros rather than dividing by zero.
        return stats
    return {k: v / n_batches for k, v in stats.items()}


# ---------------------------------------------------------------------
# 6) Visualization helpers (subject-wise subplot)
# ---------------------------------------------------------------------
def _smooth_1d(y, sigma=2.0):
    """
    지저분한 노이즈를 다듬어서 시각화
    """
    y = np.asarray(y, dtype=np.float32)
    return gaussian_filter1d(y, sigma=sigma)


def compute_phase_entropy_mean(phase_p_np, eps=1e-8):
    """
    Time-averaged Shannon entropy of per-step phase distributions.

    phase_p_np: (T,K) array-like of phase probabilities
    return: mean over T of -sum_k p*log(p+eps), as a Python float
    """
    probs = np.asarray(phase_p_np, dtype=np.float32)
    log_probs = np.log(probs + eps)
    per_step_entropy = -np.sum(probs * log_probs, axis=1)  # (T,)
    return float(per_step_entropy.mean())


def downsample_time_axis(arr, max_T=2000):
    """
    Thin out the leading (time) axis so plots stay readable.

    arr: (T, ...) numpy array
    return: (arr_ds, idx) where idx holds the original time indices kept.
            Inputs with T <= max_T are returned unchanged with idx = 0..T-1;
            longer inputs keep max_T evenly spaced indices (truncated to int).
    """
    n_steps = arr.shape[0]
    if n_steps > max_T:
        keep = np.linspace(0, n_steps - 1, max_T).astype(int)
        return arr[keep], keep
    return arr, np.arange(n_steps)


def plot_phase_heatmap_and_dominant(
    phase_p_np,
    fs,
    title="phase_p heatmap + dominant phase",
    max_T=2000
):
    """
    Draw a phase-probability heatmap with the arg-max phase strip below it.

    phase_p_np: (T, K) array-like of phase probabilities
    fs: sampling rate, used to convert sample indices to seconds
    title: title of the heatmap panel
    max_T: maximum number of time steps kept for plotting
    """
    phase_p_np = np.asarray(phase_p_np, dtype=np.float32)
    assert phase_p_np.ndim == 2, f"phase_p_np must be (T,K), got {phase_p_np.shape}"

    # Thin out long sequences, then express the kept indices in seconds.
    phase_ds, kept_idx = downsample_time_axis(phase_p_np, max_T=max_T)  # (T',K)
    n_phases = phase_ds.shape[1]
    seconds = kept_idx / float(fs)
    t_start, t_stop = seconds[0], seconds[-1]

    # Index of the most probable phase at every kept time step.
    dominant = phase_ds.argmax(axis=1)  # (T',)

    # Both panels share the same image settings.
    shared_imshow = dict(aspect="auto", origin="lower", interpolation="nearest")

    fig = plt.figure(figsize=(30, 10))
    grid = fig.add_gridspec(2, 1, height_ratios=[4, 1], hspace=0.25)

    # Top panel: heatmap with phases on y and time on x.
    heat_ax = fig.add_subplot(grid[0, 0])
    heat_img = heat_ax.imshow(
        phase_ds.T,  # (K, T')
        extent=[t_start, t_stop, 0, n_phases],
        **shared_imshow
    )
    heat_ax.set_title(title, fontsize=24, pad=10)
    heat_ax.set_ylabel("Phase k", fontsize=18)
    heat_ax.set_xlabel("Time (sec)", fontsize=18)
    colorbar = fig.colorbar(heat_img, ax=heat_ax, fraction=0.015, pad=0.01)
    colorbar.set_label("phase_p(t,k)", fontsize=14)

    # Bottom panel: one-row image of the dominant phase, sharing the time axis.
    strip_ax = fig.add_subplot(grid[1, 0], sharex=heat_ax)
    strip_ax.imshow(
        dominant[None, :],  # (1, T')
        extent=[t_start, t_stop, 0, 1],
        **shared_imshow
    )
    strip_ax.set_yticks([])
    strip_ax.set_ylabel("dominant", fontsize=14)
    strip_ax.set_xlabel("Time (sec)", fontsize=18)

    plt.tight_layout()
    plt.show()


def plot_folds_test_subplot(viz_cache, fs, title="Fold-wise TEST visualization (only test_subj)"):
    """
    Plot one panel per cached test item: smoothed rep rate plus cumulative count.

    viz_cache: list of dict, each with
        'fold', 'test_subj', 't', 'rep_rate', 'gt', 'pred', 'diff', 'k_hat', 'entropy'
    fs: sampling rate, used to integrate rep_rate into a cumulative count
    title: figure-level title

    Prints a notice and returns without plotting when viz_cache is None or empty.
    """
    if viz_cache is None or len(viz_cache) == 0:
        print("[plot_folds_test_subplot] viz_cache is empty")
        return

    sns.set_theme(style="whitegrid", context="notebook", font_scale=2.0)
    palette = sns.color_palette("muted")
    c_rate, c_count = palette[0], palette[1]

    n_panels = len(viz_cache)
    # squeeze=False always yields an (n, 1) grid, so no single-panel special case.
    fig, axes = plt.subplots(n_panels, 1, figsize=(36, 9 * n_panels), sharex=False, squeeze=False)
    fig.suptitle(title, fontsize=40, y=0.995)

    for ax, item in zip(axes[:, 0], viz_cache):
        t = item["t"]
        rep_rate = item["rep_rate"]
        gt_count = item["gt"]
        pred_count = item["pred"]
        diff = item["diff"]
        k_hat = item["k_hat"]
        entropy = item["entropy"]
        test_subj = item["test_subj"]
        fold = item["fold"]

        smoothed_rate = _smooth_1d(rep_rate, sigma=2.0)
        # Integrating the (unsmoothed) rate over time gives the running count.
        running_count = np.cumsum(rep_rate) / fs

        # Left axis: rep rate
        ax.plot(t, smoothed_rate, color=c_rate, linewidth=2.5, alpha=0.9)
        ax.fill_between(t, smoothed_rate, color=c_rate, alpha=0.15)
        ax.set_ylabel("Rep Rate (reps/s)", color=c_rate, fontweight='bold', fontsize=24)
        ax.grid(True, linestyle='--', alpha=0.5)
        ax.tick_params(axis='both', which='major', labelsize=20)

        # Right axis: cumulative count with the ground truth as a dotted line
        count_ax = ax.twinx()
        count_ax.plot(t, running_count, color=c_count, linewidth=3.5, alpha=1.0)
        count_ax.axhline(gt_count, linestyle=":", alpha=0.7)
        count_ax.set_ylabel("Count", color=c_count, fontweight='bold', fontsize=24)
        count_ax.tick_params(axis='y', labelcolor=c_count, labelsize=20)
        count_ax.grid(False)

        ax.set_title(
            f"Fold {fold:2d} | Test: {test_subj} | Pred {pred_count:.2f} / GT {gt_count:.0f} (Diff {diff:+.2f})\n"
            f"k_hat={k_hat:.2f} | phase_entropy={entropy:.3f}",
            fontsize=34, pad=10
        )
        ax.set_xlabel("Time (sec)", fontweight='bold', fontsize=24)

    plt.tight_layout(rect=[0, 0, 1, 0.985])
    plt.subplots_adjust(hspace=0.5)
    plt.show()


# ---------------------------------------------------------------------
# 7) Main (A1: Unseen Activity only, WITH windowing)
# ---------------------------------------------------------------------
def build_label_tuples_from_table(subjects, act_id, count_table):
    """
    Collect (subject, activity, count) labels for one activity.

    count_table: dict like { "subject1": {6:21, 7:20}, ... }
    returns: list of (subj, act_id, gt_count) with gt_count as float, in the
             order of `subjects`; subjects missing from the table, or lacking
             an entry for act_id, are left out.
    """
    return [
        (subj, act_id, float(count_table[subj][act_id]))
        for subj in subjects
        if subj in count_table and act_id in count_table[subj]
    ]


def main():
    """
    Run the A1 "unseen activity" experiment end to end.

    Steps, as implemented below:
      1. seed everything and load the mHealth logs from CONFIG["data_dir"];
      2. build train labels (TRAIN_ACT_ID) and test labels (TEST_ACT_ID) for
         subject1..subject10 from COUNT_TABLE and turn them into trial lists;
      3. optionally cut the training trials into fixed-length windows;
      4. train KAutoCountModel for CONFIG["epochs"] epochs (Adam + StepLR);
      5. predict every test trial with predict_count_by_windowing, run one
         extra full-trial forward pass to collect k_hat / phase statistics,
         and print per-trial and aggregate MAE.

    Returns None; returns early (after printing a message, where one exists)
    when data or labels are missing.
    """
    CONFIG = {
        "seed": 42,
        "data_dir": "/content/drive/MyDrive/Colab Notebooks/HAR_data/MHEALTHDATASET",

        "COLUMN_NAMES": [
            'acc_chest_x', 'acc_chest_y', 'acc_chest_z',
            'ecg_1', 'ecg_2',
            'acc_ankle_x', 'acc_ankle_y', 'acc_ankle_z',
            'gyro_ankle_x', 'gyro_ankle_y', 'gyro_ankle_z',
            'mag_ankle_x', 'mag_ankle_y', 'mag_ankle_z',
            'acc_arm_x', 'acc_arm_y', 'acc_arm_z',
            'gyro_arm_x', 'gyro_arm_y', 'gyro_arm_z',
            'mag_arm_x', 'mag_arm_y', 'mag_arm_z',
            'activity_id'
        ],
        "TARGET_ACTIVITIES_MAP": {
            7: 'Frontal elevation of arms',
            8: 'Knees bending',
        },

        # Train and test must use the same number of channels C so that a
        # single model can be evaluated on both activities.
        "ACT_FEATURE_MAP": {
            7: ['acc_chest_x', 'acc_chest_y', 'acc_chest_z',
                'acc_arm_x', 'acc_arm_y', 'acc_arm_z',
                'acc_ankle_x', 'acc_ankle_y', 'acc_ankle_z'],
            8: ['acc_chest_x', 'acc_chest_y', 'acc_chest_z',
                'acc_arm_x', 'acc_arm_y', 'acc_arm_z',
                'acc_ankle_x', 'acc_ankle_y', 'acc_ankle_z'],
        },

        # Training Params
        "epochs": 100,
        "lr": 5e-4,
        "batch_size": 64,
        "fs": 50,

        # Model
        "hidden_dim": 128,
        "latent_dim": 16,
        "K_max": 6,

        # Loss Weights
        "lambda_recon": 1.0,
        "lambda_smooth": 0.05,
        "lambda_phase_ent": 0.01,
        "lambda_effk": 0.0005,

        # Time-warp consistency loss
        "lambda_warp": 0.2,      # NOTE(review): an earlier note recommended starting at 1e-2~5e-2; 0.2 is above that range
        "warp_min": 0.5,          # factor < 1 shortens the resampled sequence (compressed, i.e. faster tempo)
        "warp_max": 1.5,          # factor > 1 lengthens the resampled sequence (stretched, i.e. slower tempo)
        "warp_detach_ref": True,  # only the warped branch is pulled towards the original count (more stable)

        # softmax temperature (strength of the competition between phases)
        "tau": 1.0,

        # A1 setting
        "TRAIN_ACT_ID": 7,
        "TEST_ACT_ID": 8,

        # -------------------------
        # Windowing
        # -------------------------
        "USE_WINDOWING": True,
        "WIN_SEC": 8.0,
        "STRIDE_SEC": 4.0,
        "DROP_LAST": True,

        # Ground-truth rep counts: {subject: {activity_id: count}} (must be a dict of dicts)
        "COUNT_TABLE": {
            "subject1":  {7: 20, 8: 20},
            "subject2":  {7: 20, 8: 21},
            "subject3":  {7: 20, 8: 21},
            "subject4":  {7: 20, 8: 19},
            "subject5":  {7: 20, 8: 20},
            "subject6":  {7: 20, 8: 20},
            "subject7":  {7: 20, 8: 21},
            "subject8":  {7: 19, 8: 21},
            "subject9":  {7: 19, 8: 21},
            "subject10": {7: 20, 8: 21},
        },
    }

    set_strict_seed(CONFIG["seed"])
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")

    full_data = load_mhealth_dataset(
        CONFIG["data_dir"],
        CONFIG["TARGET_ACTIVITIES_MAP"],
        CONFIG["COLUMN_NAMES"]
    )
    if not full_data:
        return

    subjects = [f"subject{i}" for i in range(1, 11)]

    # -------------------------
    # A1 split (NO LOSO)
    # Train: all subjects x TRAIN_ACT_ID
    # Test : all subjects x TEST_ACT_ID (unseen activity)
    # -------------------------
    train_labels = build_label_tuples_from_table(subjects, CONFIG["TRAIN_ACT_ID"], CONFIG["COUNT_TABLE"])
    test_labels  = build_label_tuples_from_table(subjects, CONFIG["TEST_ACT_ID"],  CONFIG["COUNT_TABLE"])

    if len(train_labels) == 0:
        print("[Error] No train labels. Check COUNT_TABLE / TRAIN_ACT_ID.")
        return
    if len(test_labels) == 0:
        print("[Error] No test labels. Check COUNT_TABLE / TEST_ACT_ID.")
        return

    train_data = prepare_trial_list(train_labels, full_data, CONFIG["TARGET_ACTIVITIES_MAP"], CONFIG["ACT_FEATURE_MAP"])
    test_data  = prepare_trial_list(test_labels,  full_data, CONFIG["TARGET_ACTIVITIES_MAP"], CONFIG["ACT_FEATURE_MAP"])

    if len(train_data) == 0:
        print("[Error] train_data empty. (missing files or ACT_FEATURE_MAP mismatch)")
        return
    if len(test_data) == 0:
        print("[Error] test_data empty. (missing files or ACT_FEATURE_MAP mismatch)")
        return

    # -------------------------
    # Window the training trials.
    # USE_WINDOWING only gates the training side; the test loop below always
    # predicts through predict_count_by_windowing.
    # -------------------------
    if CONFIG.get("USE_WINDOWING", False):
        train_windows = trial_list_to_windows(
            train_data,
            fs=CONFIG["fs"],
            win_sec=CONFIG["WIN_SEC"],
            stride_sec=CONFIG["STRIDE_SEC"],
            drop_last=CONFIG["DROP_LAST"],
        )
        print(f"[Windowing] train trials={len(train_data)} -> train windows={len(train_windows)}")
        train_data_for_loader = train_windows
    else:
        train_data_for_loader = train_data

    train_loader = DataLoader(
        TrialDataset(train_data_for_loader),
        batch_size=CONFIG["batch_size"],
        shuffle=True,
        collate_fn=collate_variable_length,
        num_workers=0
    )

    # Channel count is read from the first (T, C) training trial.
    input_ch = train_data[0]['data'].shape[1]
    model = KAutoCountModel(
        input_ch=input_ch,
        hidden_dim=CONFIG["hidden_dim"],
        latent_dim=CONFIG["latent_dim"],
        K_max=CONFIG["K_max"]
    ).to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=CONFIG["lr"])
    # Halve the learning rate every 30 epochs.
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.5)

    print("\n" + "-"*80)
    print(f" >>> A1 Train: all subjects x act{CONFIG['TRAIN_ACT_ID']}  |  Test: all subjects x act{CONFIG['TEST_ACT_ID']} (unseen activity)")
    print("-"*80)

    # ---- Train ----
    # Per-epoch statistics are discarded; nothing is logged during training.
    for epoch in range(CONFIG["epochs"]):
        _ = train_one_epoch(model, train_loader, optimizer, CONFIG, device)
        scheduler.step()

    # ---- Test (per-subject) ----
    model.eval()
    maes = []
    viz_cache = []

    for idx, item in enumerate(test_data):
        x_np = item["data"]  # (T,C)
        gt_count = float(item["count"])

        # Predicted count for the test trial, computed through windowing.
        pred_count, win_rates = predict_count_by_windowing(
            model,
            x_np=x_np,
            fs=CONFIG["fs"],
            win_sec=CONFIG["WIN_SEC"],
            stride_sec=CONFIG["STRIDE_SEC"],
            device=device,
            tau=CONFIG["tau"],
            batch_size=CONFIG["batch_size"]
        )

        mae = abs(pred_count - gt_count)
        maes.append(mae)

        # For visualization/diagnostics: one forward pass over the whole
        # trial (no windowing). Its rate output is not used for the metric.
        with torch.no_grad():
            x_tensor = torch.tensor(x_np, dtype=torch.float32).transpose(0, 1).unsqueeze(0).to(device)
            rate_hat_full, _, _, aux = model(x_tensor, mask=None, tau=CONFIG["tau"])

            phase_p = aux["phase_p"].squeeze(0).detach().cpu().numpy()
            k_hat = float(aux["k_hat"].item())
            ent = compute_phase_entropy_mean(phase_p)
            rep_rate = aux["rep_rate_t"].squeeze(0).detach().cpu().numpy()
            T = rep_rate.shape[0]
            t = np.arange(T) / float(CONFIG["fs"])

        viz_cache.append({
            "fold": idx + 1,   # just the running index (A1 has no real folds)
            "test_subj": item["meta"],  # subject + activity name
            "t": t,
            "rep_rate": rep_rate,
            "gt": gt_count,
            "pred": float(pred_count),   # windowed prediction
            "diff": float(pred_count - gt_count),
            "k_hat": k_hat,
            "entropy": ent,
            "phase_p": phase_p,
        })

        print(
            f"[Test {idx+1:02d}] {item['meta']} | Pred(win)={pred_count:.2f} / GT={gt_count:.2f} | "
            f"MAE={mae:.2f} | k_hat={k_hat:.2f} | ent={ent:.3f} | win_rate_mean={win_rates.mean():.3f}"
        )

    print("-"*80)
    print(f" >>> A1 Final MAE mean: {np.mean(maes):.3f}")
    print(f" >>> A1 Final MAE std : {np.std(maes):.3f}")
    print("-"*80)

    # Visualization (optional; uncomment to enable)
    # plot_folds_test_subplot(
    #     viz_cache,
    #     fs=CONFIG["fs"],
    #     title=f"A1 TEST visualization | Train act{CONFIG['TRAIN_ACT_ID']} -> Test act{CONFIG['TEST_ACT_ID']}"
    # )

    # for item in viz_cache:
    #     plot_phase_heatmap_and_dominant(
    #         item["phase_p"],
    #         fs=CONFIG["fs"],
    #         title=f"{item['test_subj']} | k_hat={item['k_hat']:.2f} | ent={item['entropy']:.3f}",
    #         max_T=2000
    #     )


# Run the experiment only when this file is executed as a script/cell,
# not when it is imported.
if __name__ == "__main__":
    main()


Device: cuda
Loading 10 subjects from /content/drive/MyDrive/Colab Notebooks/HAR_data/MHEALTHDATASET...
[Windowing] train trials=10 -> train windows=132

--------------------------------------------------------------------------------
 >>> A1 Train: all subjects x act7  |  Test: all subjects x act8 (unseen activity)
--------------------------------------------------------------------------------
[Test 01] subject1_Knees bending | Pred(win)=22.03 / GT=20.00 | MAE=2.03 | k_hat=1.55 | ent=0.595 | win_rate_mean=0.326
[Test 02] subject2_Knees bending | Pred(win)=18.82 / GT=21.00 | MAE=2.18 | k_hat=1.58 | ent=0.605 | win_rate_mean=0.274
[Test 03] subject3_Knees bending | Pred(win)=18.17 / GT=21.00 | MAE=2.83 | k_hat=1.95 | ent=0.755 | win_rate_mean=0.286
[Test 04] subject4_Knees bending | Pred(win)=17.12 / GT=19.00 | MAE=1.88 | k_hat=1.78 | ent=0.666 | win_rate_mean=0.274
[Test 05] subject5_Knees bending | Pred(win)=12.37 / GT=20.00 | MAE=7.63 | k_hat=2.07 | ent=0.767 | win_rate_mean=0.228
[

In [6]:
# =========================
# Count-only K-auto (Multi-event) version  (NO manual Pair/lag/overlap/balance)
#
# 핵심 아이디어
# - Micro-event Rate 예측: 모델은 하나의 통합된 속도가 아니라, K_max개의 서로 다른 '작은 단위 동작'의 속도 흐름(r_k(t))을 예측
# - 샘플마다 "rep당 micro-event 개수" k_hat(>=1)을 스스로 추정
# - 우리가 주는 감독은 오직 rep count(=20) -> rep rate만 맞추게
#
# 1) rate head를 amp(t) * softmax(phase) 형태로 바꿔서 K-stream 간 "경쟁"이 생기게 함
# 2) k_hat을 따로 head로 예측하지 않고, phase 사용 분포로부터 effK(=effective K)로 정의 (자동 K)
# 3) phase sparsity/exclusivity를 위해 (a) phase entropy loss, (b) effK usage loss 추가
#
# ✅ Added (ONLY): Time-warp consistency loss
#   - 입력을 시간축으로 stretch/compress (time-warp)해도 "count"는 동일해야 한다는 제약
#   - 모델 출력은 rate이므로, pred_count = rate_hat * duration 을 이용해 count 일관성을 강제
# =========================

import os
import glob
import random
import numpy as np
import pandas as pd
import seaborn as sns

import matplotlib.pyplot as plt
from scipy.ndimage import gaussian_filter1d

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader


# ---------------------------------------------------------------------
# 1) Strict Seeding
# ---------------------------------------------------------------------
def set_strict_seed(seed: int):
    """
    Seed the Python, NumPy and PyTorch RNGs and force deterministic cuDNN.

    Also exports PYTHONHASHSEED and, when CUDA is available, seeds the
    CUDA generators.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


# ---------------------------------------------------------------------
# 2) Data Loading
# ---------------------------------------------------------------------
def load_mhealth_dataset(data_dir, target_activities_map, column_names):
    """
    Load every `mHealth_subject*.log` file in `data_dir`, split by activity.

    data_dir: directory that holds the tab-separated log files
    target_activities_map: {activity_id: activity_name} of activities to keep
    column_names: names for the leading columns of each log; must include
                  'activity_id' (extra trailing columns are dropped)

    return: {subject_key: {activity_name: DataFrame without 'activity_id'}}
            where subject_key is "subject<N>" when the file name contains
            digits and the bare file stem otherwise. Returns {} when no log
            file is found. A file that fails to load is reported and skipped.
    """
    full_dataset = {}
    file_list = sorted(glob.glob(os.path.join(data_dir, "mHealth_subject*.log")))

    if not file_list:
        print(f"[Warning] No mHealth logs found in {data_dir}")
        return {}

    print(f"Loading {len(file_list)} subjects from {data_dir}...")

    for file_path in file_list:
        file_name = os.path.basename(file_path)
        subj_part = file_name.split('.')[0]
        try:
            subj_id_num = int(''.join(filter(str.isdigit, subj_part)))
            subj_key = f"subject{subj_id_num}"
        except ValueError:
            # No parseable number in the name: fall back to the raw stem.
            subj_key = subj_part

        try:
            df = pd.read_csv(file_path, sep="\t", header=None)
            df = df.iloc[:, :len(column_names)]
            df.columns = column_names

            subj_data = {}
            for label_code, activity_name in target_activities_map.items():
                activity_df = df[df['activity_id'] == label_code]
                if not activity_df.empty:
                    subj_data[activity_name] = activity_df.drop(columns=['activity_id'])

            full_dataset[subj_key] = subj_data
        except Exception as e:
            # Best effort: one unreadable file must not abort the whole load.
            print(f"Error loading {file_name}: {e}")

    return full_dataset


def prepare_trial_list(label_config, full_data, target_map, feature_map):
    """
    Build per-trial records with z-scored sensor data.

    label_config: iterable of (subject, activity_id, gt_count)
    full_data: {subject: {activity_name: DataFrame}}
    target_map: {activity_id: activity_name}
    feature_map: {activity_id: list of column names to keep}

    return: list of dicts with
        'data'  (T, C) float32, standardised per channel over the trial,
        'count' float ground-truth count,
        'meta'  "<subject>_<activity_name>".
    Entries whose data, feature list or feature columns are missing are
    reported with a "[Skip]" line and left out instead of raising KeyError.
    """
    trial_list = []
    for subj, act_id, gt_count in label_config:
        act_name = target_map.get(act_id)
        feats = feature_map.get(act_id)

        if subj not in full_data or act_name not in full_data[subj]:
            print(f"[Skip] Missing data for {subj} - {act_name}")
            continue

        act_df = full_data[subj][act_name]
        if feats is None:
            print(f"[Skip] No feature list for activity {act_id} ({subj} - {act_name})")
            continue
        absent = [col for col in feats if col not in act_df.columns]
        if absent:
            print(f"[Skip] Missing feature columns {absent} for {subj} - {act_name}")
            continue

        raw_np = act_df[feats].values.astype(np.float32)

        # Z-score per channel (mean 0, std 1); the epsilon guards constant channels.
        mean = raw_np.mean(axis=0)
        std = raw_np.std(axis=0) + 1e-6
        norm_np = (raw_np - mean) / std

        trial_list.append({
            'data': norm_np,
            'count': float(gt_count),
            'meta': f"{subj}_{act_name}"
        })

    return trial_list


# ---------------------------------------------------------------------
# 2.5) Windowing (added)
# ---------------------------------------------------------------------
def trial_list_to_windows(trial_list, fs, win_sec=4.0, stride_sec=2.0, drop_last=True):
    """
    Cut each trial into fixed-length windows with proportional count labels.

    trial_list: output of prepare_trial_list() (items with 'data' (T,C), 'count', 'meta')
    Window labels come from the trial-level average rate:
      rate_trial   = count_total / total_duration
      count_window = rate_trial * window_duration

    A trial shorter than one window is emitted unchanged as a single window.
    With drop_last=False, a shorter trailing window starting one stride after
    the last full window is added when that start is still inside the trial.
    """
    win_len = int(round(win_sec * fs))
    stride = int(round(stride_sec * fs))
    assert win_len > 0 and stride > 0

    def _window(chunk, rate, parent, parent_len, start, stop):
        # One window record; its count is the trial rate times its own duration.
        return {
            "data": chunk,
            "count": rate * ((stop - start) / float(fs)),
            "meta": f"{parent}__win[{start}:{stop}]",
            "parent_meta": parent,
            "parent_T": parent_len,
            "win_start": start,
            "win_end": stop,
        }

    windows = []
    for item in trial_list:
        x = item["data"]  # (T, C)
        T = x.shape[0]
        total_count = float(item["count"])
        meta = item["meta"]

        # Average reps/s over the whole trial.
        rate_trial = total_count / max(T / float(fs), 1e-6)

        if T < win_len:
            # Too short to slice: keep it as one variable-length window (no padding).
            windows.append(_window(x, rate_trial, meta, T, 0, T))
            continue

        starts = range(0, T - win_len + 1, stride)
        windows.extend(
            _window(x[st:st + win_len], rate_trial, meta, T, st, st + win_len)
            for st in starts
        )

        if not drop_last:
            tail_start = starts[-1] + stride
            if tail_start < T:
                windows.append(_window(x[tail_start:T], rate_trial, meta, T, tail_start, T))

    return windows


def predict_count_by_windowing(model, x_np, fs, win_sec, stride_sec, device, tau=1.0, batch_size=64):
    """
    Predict the rep count of one trial by averaging window-level rates.

    x_np: (T, C) normalized
    fs: sampling rate (Hz)
    win_sec / stride_sec: window length and hop in seconds
    device: device the windows are moved to before the forward pass
    tau: temperature forwarded to the model
    batch_size: number of windows per forward pass

    return: pred_count (float), window_rates (np.ndarray)
      pred_count = mean(window_rates) * (T / fs)

    The model is switched to eval mode on both code paths (previously only the
    multi-window path did so) and is left in eval mode afterwards.
    Samples after the last full window are not seen by any window.
    """
    win_len = int(round(win_sec * fs))
    stride = int(round(stride_sec * fs))
    T = x_np.shape[0]
    total_dur = T / float(fs)

    model.eval()

    if T <= win_len:
        # The trial fits into a single window: one forward pass over all of it.
        x_tensor = torch.tensor(x_np, dtype=torch.float32).transpose(0, 1).unsqueeze(0).to(device)  # (1,C,T)
        with torch.no_grad():
            rate_hat, _, _, _ = model(x_tensor, mask=None, tau=tau)
        rate = float(rate_hat.item())
        return float(rate * total_dur), np.array([rate], dtype=np.float32)

    starts = list(range(0, T - win_len + 1, stride))
    windows = np.stack([x_np[st:st + win_len] for st in starts], axis=0)  # (N, win_len, C)

    xw = torch.tensor(windows, dtype=torch.float32).permute(0, 2, 1).to(device)  # (N, C, win_len)

    rates = []
    with torch.no_grad():
        for i in range(0, xw.shape[0], batch_size):
            xb = xw[i:i + batch_size]
            r_hat, _, _, _ = model(xb, mask=None, tau=tau)  # (B,)
            rates.append(r_hat.detach().cpu().numpy())

    rates = np.concatenate(rates, axis=0)  # (N,)
    rate_mean = float(rates.mean())
    pred_count = rate_mean * total_dur
    return float(pred_count), rates


class TrialDataset(Dataset):
    """Map-style dataset over a list of trial/window dicts ('data', 'count', 'meta')."""

    def __init__(self, trial_list):
        self.trials = trial_list

    def __len__(self):
        return len(self.trials)

    def __getitem__(self, idx):
        """Return (data (C, T) float32 tensor, count float32 scalar tensor, meta)."""
        record = self.trials[idx]
        time_major = torch.tensor(record['data'], dtype=torch.float32)  # (T, C)
        channel_major = torch.transpose(time_major, 0, 1)               # (C, T)
        label = torch.tensor(record['count'], dtype=torch.float32)
        return channel_major, label, record['meta']


def collate_variable_length(batch):
    """
    Zero-pad variable-length (C, T) samples to a common length.

    batch: list of (data (C, T), count, meta) as produced by TrialDataset
    return: dict with
        "data"   (B, C, T_max) right-padded with zeros,
        "mask"   (B, T_max) 1 on valid steps and 0 on padding,
        "count"  (B,),
        "length" (B,) float32 valid lengths,
        "meta"   list of the per-sample meta values.
    """
    max_len = max(sample[0].shape[1] for sample in batch)
    n_ch = batch[0][0].shape[0]

    padded, masks, counts, metas, lengths = [], [], [], [], []
    for data, count, meta in batch:
        n_steps = data.shape[1]
        deficit = max_len - n_steps
        if deficit > 0:
            data = torch.cat([data, torch.zeros(n_ch, deficit)], dim=1)

        valid = torch.zeros(max_len)
        valid[:n_steps] = 1.0

        padded.append(data)
        masks.append(valid)
        counts.append(count)
        metas.append(meta)
        lengths.append(n_steps)

    return {
        "data": torch.stack(padded),                            # (B, C, T_max)
        "mask": torch.stack(masks),                             # (B, T_max)
        "count": torch.stack(counts),                           # (B,)
        "length": torch.tensor(lengths, dtype=torch.float32),   # (B,)
        "meta": metas
    }


# ---------------------------------------------------------------------
# 3) Model
# ---------------------------------------------------------------------
class ManifoldEncoder(nn.Module):
    """Per-timestep 1-D conv encoder: (B, C, T) -> (B, T, latent_dim)."""

    def __init__(self, input_ch, hidden_dim=128, latent_dim=16):
        super().__init__()
        # kernel 5 with padding 2 (and the final 1x1 conv) keep the time length T.
        layers = [
            nn.Conv1d(input_ch, hidden_dim, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv1d(hidden_dim, hidden_dim, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv1d(hidden_dim, latent_dim, kernel_size=1),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        # x: (B, C, T) -> latent (B, D, T) -> time-major (B, T, D)
        latent = self.net(x)
        return latent.permute(0, 2, 1)


class ManifoldDecoder(nn.Module):
    """1-D conv decoder mapping latents back to signals: (B, T, D) -> (B, C, T)."""

    def __init__(self, latent_dim, hidden_dim, out_ch):
        super().__init__()
        # Mirror of the encoder stack; every layer preserves the time length.
        layers = [
            nn.Conv1d(latent_dim, hidden_dim, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv1d(hidden_dim, hidden_dim, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv1d(hidden_dim, out_ch, kernel_size=1),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, z):
        # z: (B, T, D) -> channel-major (B, D, T) -> reconstruction (B, C, T)
        channel_major = z.permute(0, 2, 1)
        return self.net(channel_major)


class MultiRateHead(nn.Module):
    """
    Per-timestep head producing a non-negative amplitude and a K_max-way
    phase distribution from the latent sequence.
    """

    def __init__(self, latent_dim=16, hidden=64, K_max=6):
        super().__init__()
        self.K_max = K_max
        # Output layout per step: [amp_logit | K_max phase logits]
        layers = [
            nn.Linear(latent_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 1 + K_max),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, z, tau=1.0):
        """
        z: (B,T,D)
        return: amp (B,T) >= 0, phase (B,T,K) summing to 1 over K,
                phase_logits (B,T,K) before temperature scaling
        """
        raw = self.net(z)  # (B,T,1+K)
        amp_logit, phase_logits = raw[..., 0], raw[..., 1:]
        amp = F.softplus(amp_logit)                        # total micro intensity
        phase = torch.softmax(phase_logits / tau, dim=-1)  # temperature-scaled softmax
        return amp, phase, phase_logits


class KAutoCountModel(nn.Module):
    """
    Encoder/decoder with a multi-phase rate head that derives its own K.

    - the head gives amp(t) and a softmax phase distribution over K_max phases;
      micro-event rates are r_k(t) = amp(t) * phase_k(t), so sum_k r_k(t) = amp(t)
    - k_hat = 1 / sum_k p_bar_k^2, with p_bar the (masked) time-mean phase usage
    - rep_rate(t) = amp(t) / k_hat
    """

    def __init__(self, input_ch, hidden_dim=128, latent_dim=16, K_max=6, k_hidden=64):
        # k_hidden is accepted but not used anywhere in this class.
        super().__init__()
        self.encoder = ManifoldEncoder(input_ch, hidden_dim, latent_dim)
        self.decoder = ManifoldDecoder(latent_dim, hidden_dim, input_ch)
        self.rate_head = MultiRateHead(latent_dim, hidden=hidden_dim, K_max=K_max)

        self._init_weights()

    def _init_weights(self):
        """Kaiming-init all conv/linear layers, then bias the amplitude logit to -2."""
        for module in self.modules():
            if not isinstance(module, (nn.Conv1d, nn.Linear)):
                continue
            nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)

        with torch.no_grad():
            head_bias = self.rate_head.net[-1].bias
            head_bias.zero_()
            head_bias[0].fill_(-2.0)  # only the amp logit bias; phase biases stay 0

    @staticmethod
    def _masked_mean_time(x, mask=None, eps=1e-6):
        """Mean over the time axis (dim 1) of x: (B,T) or (B,T,K), honouring mask (B,T)."""
        if mask is None:
            return x.mean(dim=1)

        n_dim = x.dim()
        if n_dim not in (2, 3):
            raise ValueError(f"Unsupported dim for masked mean: {x.dim()}")

        weights = mask.to(dtype=x.dtype, device=x.device)  # (B,T)
        if n_dim == 3:
            weights = weights.unsqueeze(-1)                # (B,T,1)
        return (x * weights).sum(dim=1) / (weights.sum(dim=1) + eps)

    def forward(self, x, mask=None, tau=1.0):
        """
        x: (B,C,T), mask: (B,T) or None
        return: (avg_rep_rate (B,), z (B,T,D), x_hat (B,C,T), aux dict)
        """
        z = self.encoder(x)      # (B,T,D)
        x_hat = self.decoder(z)  # (B,C,T)

        amp_t, phase_p, phase_logits = self.rate_head(z, tau=tau)
        rates_k_t = amp_t.unsqueeze(-1) * phase_p  # (B,T,K)

        # phase_p sums to 1 over K, so the summed micro-event rate is amp itself.
        micro_rate_t = amp_t

        # Effective number of phases in use (inverse participation ratio).
        p_bar = self._masked_mean_time(phase_p, mask)   # (B,K)
        k_hat = 1.0 / (p_bar.pow(2).sum(dim=1) + 1e-6)  # (B,)

        # One rep consists of k_hat micro-events.
        rep_rate_t = micro_rate_t / (k_hat.unsqueeze(1) + 1e-6)  # (B,T)

        if mask is None:
            avg_rep_rate = rep_rate_t.mean(dim=1)
        else:
            # Zero the padded steps, then average over the valid ones only.
            rep_rate_t = rep_rate_t * mask
            avg_rep_rate = (rep_rate_t * mask).sum(dim=1) / (mask.sum(dim=1) + 1e-6)

        aux = {
            "rates_k_t": rates_k_t,          # (B,T,K)
            "phase_p": phase_p,              # (B,T,K)
            "phase_logits": phase_logits,    # (B,T,K)
            "micro_rate_t": micro_rate_t,    # (B,T)
            "rep_rate_t": rep_rate_t,        # (B,T)
            "k_hat": k_hat,                  # (B,)
        }
        return avg_rep_rate, z, x_hat, aux


# ---------------------------------------------------------------------
# 4) Loss utils
# ---------------------------------------------------------------------
def masked_recon_mse(x_hat, x, mask, eps=1e-6):
    """
    Mean squared reconstruction error over valid time steps only.

    x_hat, x: (B,C,T) reconstruction and target
    mask:     (B,T), 1 for real samples and 0 for padding
    eps:      added to the denominator so an all-zero mask does not divide by 0
    return:   scalar tensor = sum of masked squared errors / (#valid steps * C)
    """
    valid = mask.to(dtype=x.dtype, device=x.device)
    # Broadcast the (B,T) mask over the channel axis.
    sq_err = (x_hat - x).pow(2) * valid.unsqueeze(1)
    n_terms = valid.sum() * x.shape[1] + eps
    return sq_err.sum() / n_terms


def temporal_smoothness(v, mask=None, eps=1e-6):
    """
    L1 smoothness penalty on the first temporal difference of `v`.

    v:    (B,T) sequence
    mask: optional (B,T) validity mask; a difference is counted only when
          both of its neighbouring steps are valid
    eps:  added to the denominator of the masked average
    return: scalar tensor, mean |v[t+1] - v[t]| over the counted differences.
            Returns 0 when there are no differences at all (T < 2), instead of
            the NaN that averaging an empty tensor would produce.
    """
    dv = torch.abs(v[:, 1:] - v[:, :-1])  # (B,T-1)
    if dv.numel() == 0:
        # Sum of an empty tensor is a well-defined 0 that stays on v's
        # device/dtype and in the autograd graph.
        return dv.sum()
    if mask is None:
        return dv.mean()
    pair_valid = (mask[:, 1:] * mask[:, :-1]).to(dtype=dv.dtype, device=dv.device)
    return (dv * pair_valid).sum() / (pair_valid.sum() + eps)


def phase_entropy_loss(phase_p, mask=None, eps=1e-8):
    """
    Time-wise exclusivity penalty: the Shannon entropy of the phase
    distribution at each step, averaged over (valid) steps. Minimising it
    pushes each step's phase towards one-hot.

    phase_p: (B,T,K) per-step phase probabilities
    mask:    optional (B,T) validity mask
    eps:     stabiliser inside the log and in the masked denominator
    return:  scalar tensor
    """
    step_entropy = -(phase_p * torch.log(phase_p + eps)).sum(dim=-1)  # (B,T)
    if mask is not None:
        return (step_entropy * mask).sum() / (mask.sum() + eps)
    return step_entropy.mean()


def effK_usage_loss(phase_p, mask=None, eps=1e-6):
    """
    Overall usage-sparsity penalty: the effective number of phases in use.

    The phase probabilities are averaged over (valid) time to get a usage
    vector p_bar, and effK = 1 / sum(p_bar^2), which lies in [1,K].

    phase_p: (B,T,K) per-step phase probabilities
    mask:    optional (B,T) validity mask
    return:  (batch-mean effK as a scalar tensor, per-sample effK detached)
    """
    if mask is not None:
        weight = mask.to(dtype=phase_p.dtype, device=phase_p.device).unsqueeze(-1)  # (B,T,1)
        usage = (phase_p * weight).sum(dim=1) / (weight.sum(dim=1) + eps)
    else:
        usage = phase_p.mean(dim=1)  # (B,K)

    eff_k = 1.0 / (usage.pow(2).sum(dim=1) + eps)
    return eff_k.mean(), eff_k.detach()


def time_warp_batch(x, mask, length, warp_min=0.7, warp_max=1.3, min_len=8):
    """
    Resample every sample of a padded batch to a randomly scaled length
    (a per-sample tempo change) and re-pad the results.

    x:      (B,C,Tmax) padded batch
    mask:   (B,Tmax); only its dtype is used here, for the new mask
    length: (B,) valid length of each sample (float or int tensor)
    warp_min, warp_max: range of the uniform length scale factor; the warped
            length is round(valid_length * factor), so factors below 1 shorten
            the sequence and factors above 1 lengthen it
    min_len: lower bound on the warped length
    return:
      xw: (B,C,Tw_max) warped signals, zero-padded on the right
      mw: (B,Tw_max)   validity mask of the warped signals
      lw: (B,)         warped valid lengths, same dtype as `length`
    """
    n_batch, n_ch, _ = x.shape
    dev = x.device

    # Pass 1: decide source and target lengths. One factor is drawn per sample,
    # in batch order, from torch's global generator for `dev`.
    src_lens, dst_lens = [], []
    for b in range(n_batch):
        src = max(int(length[b].item()), 1)
        factor = float(torch.empty(1, device=dev).uniform_(warp_min, warp_max).item())
        src_lens.append(src)
        dst_lens.append(max(int(round(src * factor)), min_len))

    longest = max([0] + dst_lens)
    xw = torch.zeros((n_batch, n_ch, longest), device=dev, dtype=x.dtype)
    mw = torch.zeros((n_batch, longest), device=dev, dtype=mask.dtype)

    # Pass 2: linearly resample the valid part of each sample into place.
    for b, (src, dst) in enumerate(zip(src_lens, dst_lens)):
        resampled = F.interpolate(
            x[b:b + 1, :, :src], size=dst, mode="linear", align_corners=False
        )  # (1,C,dst)
        xw[b, :, :dst] = resampled.squeeze(0)
        mw[b, :dst] = 1

    lw = torch.tensor([float(d) for d in dst_lens], device=dev, dtype=length.dtype)
    return xw, mw, lw


# ---------------------------------------------------------------------
# 5) Train  (ONLY MODIFIED PART: remove shift, add time-warp consistency)
# ---------------------------------------------------------------------
def train_one_epoch(model, loader, optimizer, config, device):
    """
    Run one optimisation pass over `loader` and return averaged statistics.

    Each batch dict must provide "data" (B,C,T), "mask" (B,T), "count" (B,)
    and "length" (B,). The target is the repetition *rate* count / duration,
    with duration = length / fs.

    config keys read here:
      fs (required), tau, lambda_recon, lambda_smooth, lambda_phase_ent,
      lambda_effk, lambda_warp, warp_min, warp_max, warp_detach_ref

    Total loss = rate MSE
               + lambda_recon     * masked reconstruction MSE
               + lambda_smooth    * L1 smoothness of rep_rate_t
               + lambda_phase_ent * per-step phase entropy
               + lambda_effk      * effective-K usage
               + lambda_warp      * time-warp count consistency (L1)

    return: dict of per-batch averages for every loss term plus 'mae_count'.
            An empty loader yields all zeros instead of dividing by zero.
    """
    model.train()
    stats = {k: 0.0 for k in [
        'loss', 'loss_rate', 'loss_recon', 'loss_smooth', 'loss_phase_ent', 'loss_effk', 'loss_warp',
        'mae_count'
    ]}

    fs = config["fs"]
    tau = config.get("tau", 1.0)

    lam_recon = config.get("lambda_recon", 1.0)
    lam_smooth = config.get("lambda_smooth", 0.05)
    lam_phase_ent = config.get("lambda_phase_ent", 0.01)
    lam_effk = config.get("lambda_effk", 0.005)

    # Time-warp consistency: weight, warp range, and whether the un-warped
    # prediction is treated as a fixed (detached) reference.
    lam_warp = config.get("lambda_warp", 0.0)
    warp_min = config.get("warp_min", 0.7)
    warp_max = config.get("warp_max", 1.3)
    warp_detach_ref = bool(config.get("warp_detach_ref", True))

    for batch in loader:
        x = batch["data"].to(device)         # (B,C,T)
        mask = batch["mask"].to(device)      # (B,T)
        y_count = batch["count"].to(device)  # (B,)
        length = batch["length"].to(device)  # (B,)

        duration = torch.clamp(length / fs, min=1e-6)  # seconds
        y_rate = y_count / duration                    # reps/s

        optimizer.zero_grad()

        rate_hat, z, x_hat, aux = model(x, mask, tau=tau)

        # (1) match the average repetition rate
        loss_rate = F.mse_loss(rate_hat, y_rate)

        # (2) reconstruction
        loss_recon = masked_recon_mse(x_hat, x, mask)

        # (3) smoothness of the per-step repetition rate
        loss_smooth = temporal_smoothness(aux["rep_rate_t"], mask)

        # (4) per-step phase exclusivity (entropy)
        loss_phase_ent = phase_entropy_loss(aux["phase_p"], mask)

        # (5) overall effective-K usage
        loss_effk, _ = effK_usage_loss(aux["phase_p"], mask)

        # (6) time-warp consistency at the count level:
        #     rate * duration should not change when the tempo is warped.
        if lam_warp > 0.0:
            xw, mw, lw = time_warp_batch(
                x=x, mask=mask, length=length,
                warp_min=warp_min, warp_max=warp_max, min_len=8
            )
            dur_w = torch.clamp(lw / fs, min=1e-6)

            rate_hat_w, _, _, _ = model(xw, mw, tau=tau)

            pred_count = rate_hat * duration
            pred_count_w = rate_hat_w * dur_w

            # With a detached reference only the warped branch receives
            # gradient from this term.
            pred_count_ref = pred_count.detach() if warp_detach_ref else pred_count

            loss_warp = F.l1_loss(pred_count_w, pred_count_ref)
        else:
            loss_warp = torch.zeros((), device=device)

        loss = (loss_rate
                + lam_recon * loss_recon
                + lam_smooth * loss_smooth
                + lam_phase_ent * loss_phase_ent
                + lam_effk * loss_effk
                + lam_warp * loss_warp)

        loss.backward()
        optimizer.step()

        # Count-level MAE is a metric only, so keep it out of autograd.
        with torch.no_grad():
            count_hat = rate_hat * duration  # predicted rate * time = predicted count
            mae_count = torch.abs(count_hat - y_count).mean().item()

        stats['loss'] += loss.item()
        stats['loss_rate'] += loss_rate.item()
        stats['loss_recon'] += loss_recon.item()
        stats['loss_smooth'] += loss_smooth.item()
        stats['loss_phase_ent'] += loss_phase_ent.item()
        stats['loss_effk'] += loss_effk.item()
        stats['loss_warp'] += float(loss_warp.item())
        stats['mae_count'] += mae_count

    n = len(loader)
    if n == 0:
        # Nothing was accumulated; return the zero-initialised stats as-is.
        return stats
    return {k: v / n for k, v in stats.items()}


# ---------------------------------------------------------------------
# 6) Visualization helpers (subject-wise subplot)
# ---------------------------------------------------------------------
def _smooth_1d(y, sigma=2.0):
    """
    지저분한 노이즈를 다듬어서 시각화
    """
    y = np.asarray(y, dtype=np.float32)
    return gaussian_filter1d(y, sigma=sigma)


def compute_phase_entropy_mean(phase_p_np, eps=1e-8):
    """
    Time-averaged Shannon entropy of a phase-probability sequence.

    phase_p_np: (T,K) array-like of per-step phase probabilities
    eps:        stabiliser added inside the log
    return:     Python float, mean over T of -sum_k p * log(p + eps)
    """
    probs = np.asarray(phase_p_np, dtype=np.float32)
    neg_entropy_t = np.sum(probs * np.log(probs + eps), axis=1)  # (T,)
    return float(-neg_entropy_t.mean())


def downsample_time_axis(arr, max_T=2000):
    """
    Thin out the leading (time) axis to at most max_T evenly spaced rows so
    that very long sequences remain plottable.

    arr:    (T, ...) numpy array
    return: (arr_ds, idx) where idx holds the original time indices kept;
            when T <= max_T the input is returned unchanged with idx = 0..T-1
    """
    n_steps = arr.shape[0]
    if n_steps > max_T:
        keep = np.linspace(0, n_steps - 1, max_T).astype(int)
        return arr[keep], keep
    return arr, np.arange(n_steps)


def plot_phase_heatmap_and_dominant(
    phase_p_np,
    fs,
    title="phase_p heatmap + dominant phase",
    max_T=2000
):
    """
    Show a two-panel figure: a heatmap of the phase probabilities over time
    (top) and a strip with the arg-max phase at each step (bottom).

    phase_p_np: (T, K) array of per-step phase probabilities
    fs:         sampling rate used to convert sample indices to seconds
    title:      title of the heatmap panel
    max_T:      the time axis is down-sampled to at most this many steps
    """
    phase_p_np = np.asarray(phase_p_np, dtype=np.float32)
    assert phase_p_np.ndim == 2, f"phase_p_np must be (T,K), got {phase_p_np.shape}"

    # Down-sample for display and convert the kept indices to seconds.
    sampled, kept_idx = downsample_time_axis(phase_p_np, max_T=max_T)  # (T',K)
    n_phase = sampled.shape[1]
    seconds = kept_idx / float(fs)
    t_first, t_last = seconds[0], seconds[-1]

    # Index of the most probable phase at each displayed step.
    dominant = np.argmax(sampled, axis=1)  # (T',)

    fig = plt.figure(figsize=(30, 10))
    grid = fig.add_gridspec(2, 1, height_ratios=[4, 1], hspace=0.25)

    # Top panel: probability heatmap, phases on the y axis.
    heat_ax = fig.add_subplot(grid[0, 0])
    heat_img = heat_ax.imshow(
        sampled.T,  # (K, T')
        aspect="auto",
        origin="lower",
        interpolation="nearest",
        extent=[t_first, t_last, 0, n_phase],
    )
    heat_ax.set_title(title, fontsize=24, pad=10)
    heat_ax.set_ylabel("Phase k", fontsize=18)
    heat_ax.set_xlabel("Time (sec)", fontsize=18)
    colorbar = fig.colorbar(heat_img, ax=heat_ax, fraction=0.015, pad=0.01)
    colorbar.set_label("phase_p(t,k)", fontsize=14)

    # Bottom panel: dominant-phase timeline sharing the time axis.
    strip_ax = fig.add_subplot(grid[1, 0], sharex=heat_ax)
    strip_ax.imshow(
        dominant.reshape(1, -1),  # (1, T')
        aspect="auto",
        origin="lower",
        interpolation="nearest",
        extent=[t_first, t_last, 0, 1],
    )
    strip_ax.set_yticks([])
    strip_ax.set_ylabel("dominant", fontsize=14)
    strip_ax.set_xlabel("Time (sec)", fontsize=18)

    plt.tight_layout()
    plt.show()


def plot_folds_test_subplot(viz_cache, fs, title="Fold-wise TEST visualization (only test_subj)"):
    """
    Draw one stacked subplot per cached test result: the smoothed repetition
    rate on the left axis and the cumulative predicted count on the right
    axis, with the ground-truth count as a dotted horizontal line.

    viz_cache: list of dict, each containing
        'fold', 'test_subj', 't', 'rep_rate', 'gt', 'pred', 'diff',
        'k_hat', 'entropy'
    fs:        sampling rate, used to integrate the rate into a count
    title:     figure-level title

    Note: calls sns.set_theme, which changes the global plotting style.
    """
    if not viz_cache:
        print("[plot_folds_test_subplot] viz_cache is empty")
        return

    sns.set_theme(style="whitegrid", context="notebook", font_scale=2.0)
    palette = sns.color_palette("muted")
    c_rate, c_count = palette[0], palette[1]

    n_panels = len(viz_cache)
    fig, axes = plt.subplots(n_panels, 1, figsize=(36, 9 * n_panels), sharex=False)
    # plt.subplots returns a bare Axes for a single panel; normalise to 1-D.
    axes = np.atleast_1d(axes).ravel()

    fig.suptitle(title, fontsize=40, y=0.995)

    for ax, item in zip(axes, viz_cache):
        t = item["t"]
        rep_rate = item["rep_rate"]
        gt_count = item["gt"]
        pred_count = item["pred"]
        diff = item["diff"]
        k_hat = item["k_hat"]
        entropy = item["entropy"]
        test_subj = item["test_subj"]
        fold = item["fold"]

        smoothed_rate = _smooth_1d(rep_rate, sigma=2.0)
        # Integrate the raw (unsmoothed) rate over time to get a running count.
        running_count = np.cumsum(rep_rate) / fs

        # Left axis: repetition rate.
        ax.plot(t, smoothed_rate, color=c_rate, linewidth=2.5, alpha=0.9)
        ax.fill_between(t, smoothed_rate, color=c_rate, alpha=0.15)
        ax.set_ylabel("Rep Rate (reps/s)", color=c_rate, fontweight='bold', fontsize=24)
        ax.grid(True, linestyle='--', alpha=0.5)
        ax.tick_params(axis='both', which='major', labelsize=20)

        # Right axis: cumulative count against the ground truth.
        count_ax = ax.twinx()
        count_ax.plot(t, running_count, color=c_count, linewidth=3.5, alpha=1.0)
        count_ax.axhline(gt_count, linestyle=":", alpha=0.7)
        count_ax.set_ylabel("Count", color=c_count, fontweight='bold', fontsize=24)
        count_ax.tick_params(axis='y', labelcolor=c_count, labelsize=20)
        count_ax.grid(False)

        ax.set_title(
            f"Fold {fold:2d} | Test: {test_subj} | Pred {pred_count:.2f} / GT {gt_count:.0f} (Diff {diff:+.2f})\n"
            f"k_hat={k_hat:.2f} | phase_entropy={entropy:.3f}",
            fontsize=34, pad=10
        )
        ax.set_xlabel("Time (sec)", fontweight='bold', fontsize=24)

    plt.tight_layout(rect=[0, 0, 1, 0.985])
    plt.subplots_adjust(hspace=0.5)
    plt.show()


# ---------------------------------------------------------------------
# 7) Main (A1: Unseen Activity only, WITH windowing)
# ---------------------------------------------------------------------
def build_label_tuples_from_table(subjects, act_id, count_table):
    """
    Collect the ground-truth labels of one activity for the given subjects.

    count_table: dict like { "subject1": {6:21, 7:20}, ... }
    returns: list of (subj, act_id, gt_count) with gt_count as float, in the
             order of `subjects`; subjects without an entry for `act_id` are
             silently left out
    """
    return [
        (subj, act_id, float(count_table[subj][act_id]))
        for subj in subjects
        if subj in count_table and act_id in count_table[subj]
    ]


def main():
    """
    Run the "A1" unseen-activity experiment end to end.

    Steps:
      1. Build the experiment CONFIG (paths, columns, hyper-parameters,
         per-subject ground-truth counts), seed all RNGs, pick the device.
      2. Load the mHealth logs and build trial lists: every subject's
         TRAIN_ACT_ID trial for training, every subject's TEST_ACT_ID trial
         for testing.
      3. Optionally cut the training trials into fixed-length windows.
      4. Train a KAutoCountModel with Adam and a StepLR schedule.
      5. For each test trial, predict the count by windowing, run one full
         forward pass to collect diagnostics, and print the per-trial error
         followed by the mean/std of the absolute errors.

    Returns None. Returns early (after printing a message where one exists)
    when the dataset, the labels or the trial lists come back empty.
    """
    CONFIG = {
        "seed": 42,
        "data_dir": "/content/drive/MyDrive/Colab Notebooks/HAR_data/MHEALTHDATASET",

        "COLUMN_NAMES": [
            'acc_chest_x', 'acc_chest_y', 'acc_chest_z',
            'ecg_1', 'ecg_2',
            'acc_ankle_x', 'acc_ankle_y', 'acc_ankle_z',
            'gyro_ankle_x', 'gyro_ankle_y', 'gyro_ankle_z',
            'mag_ankle_x', 'mag_ankle_y', 'mag_ankle_z',
            'acc_arm_x', 'acc_arm_y', 'acc_arm_z',
            'gyro_arm_x', 'gyro_arm_y', 'gyro_arm_z',
            'mag_arm_x', 'mag_arm_y', 'mag_arm_z',
            'activity_id'
        ],
        "TARGET_ACTIVITIES_MAP": {
            7: 'Frontal elevation of arms',
            8: 'Knees bending',
        },

        # Train and test must use the same number of channels C so that a
        # single model can be evaluated on both activities.
        "ACT_FEATURE_MAP": {
            7: ['acc_chest_x', 'acc_chest_y', 'acc_chest_z',
                'acc_arm_x', 'acc_arm_y', 'acc_arm_z',
                'acc_ankle_x', 'acc_ankle_y', 'acc_ankle_z'],
            8: ['acc_chest_x', 'acc_chest_y', 'acc_chest_z',
                'acc_arm_x', 'acc_arm_y', 'acc_arm_z',
                'acc_ankle_x', 'acc_ankle_y', 'acc_ankle_z'],
        },

        # Training Params
        "epochs": 100,
        "lr": 5e-4,
        "batch_size": 64,
        "fs": 50,

        # Model
        "hidden_dim": 128,
        "latent_dim": 16,
        "K_max": 6,

        # Loss Weights
        "lambda_recon": 1.0,
        "lambda_smooth": 0.05,
        "lambda_phase_ent": 0.01,
        "lambda_effk": 0.0005,

        # Time-warp consistency (the only new loss term)
        "lambda_warp": 0.2,      # author's note: 1e-2 to 5e-2 is the recommended starting range
        "warp_min": 0.5,          # smallest length scale; time_warp_batch shortens the sequence for factors < 1
        "warp_max": 1.5,          # largest length scale; time_warp_batch lengthens the sequence for factors > 1
        "warp_detach_ref": True,  # only the warped branch is pulled toward the original count (stability)

        # temperature (strength of the competition between phases)
        "tau": 1.0,

        # A1 setting
        "TRAIN_ACT_ID": 8,
        "TEST_ACT_ID": 7,

        # -------------------------
        # Windowing (added)
        # -------------------------
        "USE_WINDOWING": True,
        "WIN_SEC": 8.0,
        "STRIDE_SEC": 4.0,
        "DROP_LAST": True,

        # Ground-truth repetition counts: {subject: {activity_id: count}}.
        # Must be given in this dict form (build_label_tuples_from_table indexes it that way).
        "COUNT_TABLE": {
            "subject1":  {7: 20, 8: 20},
            "subject2":  {7: 20, 8: 21},
            "subject3":  {7: 20, 8: 21},
            "subject4":  {7: 20, 8: 19},
            "subject5":  {7: 20, 8: 20},
            "subject6":  {7: 20, 8: 20},
            "subject7":  {7: 20, 8: 21},
            "subject8":  {7: 19, 8: 21},
            "subject9":  {7: 19, 8: 21},
            "subject10": {7: 20, 8: 21},
        },
    }

    set_strict_seed(CONFIG["seed"])
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")

    full_data = load_mhealth_dataset(
        CONFIG["data_dir"],
        CONFIG["TARGET_ACTIVITIES_MAP"],
        CONFIG["COLUMN_NAMES"]
    )
    # load_mhealth_dataset already printed a warning if nothing was found.
    if not full_data:
        return

    subjects = [f"subject{i}" for i in range(1, 11)]

    # -------------------------
    # A1 split (NO LOSO)
    # Train: all subjects x TRAIN_ACT_ID
    # Test : all subjects x TEST_ACT_ID (unseen activity)
    # -------------------------
    train_labels = build_label_tuples_from_table(subjects, CONFIG["TRAIN_ACT_ID"], CONFIG["COUNT_TABLE"])
    test_labels  = build_label_tuples_from_table(subjects, CONFIG["TEST_ACT_ID"],  CONFIG["COUNT_TABLE"])

    if len(train_labels) == 0:
        print("[Error] No train labels. Check COUNT_TABLE / TRAIN_ACT_ID.")
        return
    if len(test_labels) == 0:
        print("[Error] No test labels. Check COUNT_TABLE / TEST_ACT_ID.")
        return

    train_data = prepare_trial_list(train_labels, full_data, CONFIG["TARGET_ACTIVITIES_MAP"], CONFIG["ACT_FEATURE_MAP"])
    test_data  = prepare_trial_list(test_labels,  full_data, CONFIG["TARGET_ACTIVITIES_MAP"], CONFIG["ACT_FEATURE_MAP"])

    if len(train_data) == 0:
        print("[Error] train_data empty. (missing files or ACT_FEATURE_MAP mismatch)")
        return
    if len(test_data) == 0:
        print("[Error] test_data empty. (missing files or ACT_FEATURE_MAP mismatch)")
        return

    # -------------------------
    # Build windowed samples for the training set only
    # (test trials are windowed later, inside predict_count_by_windowing)
    # -------------------------
    if CONFIG.get("USE_WINDOWING", False):
        train_windows = trial_list_to_windows(
            train_data,
            fs=CONFIG["fs"],
            win_sec=CONFIG["WIN_SEC"],
            stride_sec=CONFIG["STRIDE_SEC"],
            drop_last=CONFIG["DROP_LAST"],
        )
        print(f"[Windowing] train trials={len(train_data)} -> train windows={len(train_windows)}")
        train_data_for_loader = train_windows
    else:
        train_data_for_loader = train_data

    train_loader = DataLoader(
        TrialDataset(train_data_for_loader),
        batch_size=CONFIG["batch_size"],
        shuffle=True,
        collate_fn=collate_variable_length,
        num_workers=0
    )

    # Trial data is stored as (T, C), so axis 1 is the channel count.
    input_ch = train_data[0]['data'].shape[1]
    model = KAutoCountModel(
        input_ch=input_ch,
        hidden_dim=CONFIG["hidden_dim"],
        latent_dim=CONFIG["latent_dim"],
        K_max=CONFIG["K_max"]
    ).to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=CONFIG["lr"])
    # Halve the learning rate every 30 epochs.
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.5)

    print("\n" + "-"*80)
    print(f" >>> A1 Train: all subjects x act{CONFIG['TRAIN_ACT_ID']}  |  Test: all subjects x act{CONFIG['TEST_ACT_ID']} (unseen activity)")
    print("-"*80)

    # ---- Train ----
    # The per-epoch statistics returned by train_one_epoch are discarded.
    for epoch in range(CONFIG["epochs"]):
        _ = train_one_epoch(model, train_loader, optimizer, CONFIG, device)
        scheduler.step()

    # ---- Test (per-subject) ----
    model.eval()
    maes = []
    viz_cache = []

    for idx, item in enumerate(test_data):
        x_np = item["data"]  # (T,C)
        gt_count = float(item["count"])

        # Predicted count for the test trial, computed via windowing
        pred_count, win_rates = predict_count_by_windowing(
            model,
            x_np=x_np,
            fs=CONFIG["fs"],
            win_sec=CONFIG["WIN_SEC"],
            stride_sec=CONFIG["STRIDE_SEC"],
            device=device,
            tau=CONFIG["tau"],
            batch_size=CONFIG["batch_size"]
        )

        mae = abs(pred_count - gt_count)
        maes.append(mae)

        # For visualization/diagnostics: one full forward pass over the whole
        # trial (as before). rate_hat_full itself is not used below; only the
        # aux tensors are.
        with torch.no_grad():
            x_tensor = torch.tensor(x_np, dtype=torch.float32).transpose(0, 1).unsqueeze(0).to(device)
            rate_hat_full, _, _, aux = model(x_tensor, mask=None, tau=CONFIG["tau"])

            phase_p = aux["phase_p"].squeeze(0).detach().cpu().numpy()
            k_hat = float(aux["k_hat"].item())
            ent = compute_phase_entropy_mean(phase_p)
            rep_rate = aux["rep_rate_t"].squeeze(0).detach().cpu().numpy()
            T = rep_rate.shape[0]
            t = np.arange(T) / float(CONFIG["fs"])

        viz_cache.append({
            "fold": idx + 1,   # just a running index (there are no real folds here)
            "test_subj": item["meta"],  # subject+act name
            "t": t,
            "rep_rate": rep_rate,
            "gt": gt_count,
            "pred": float(pred_count),   # windowing-based prediction
            "diff": float(pred_count - gt_count),
            "k_hat": k_hat,
            "entropy": ent,
            "phase_p": phase_p,
        })

        print(
            f"[Test {idx+1:02d}] {item['meta']} | Pred(win)={pred_count:.2f} / GT={gt_count:.2f} | "
            f"MAE={mae:.2f} | k_hat={k_hat:.2f} | ent={ent:.3f} | win_rate_mean={win_rates.mean():.3f}"
        )

    print("-"*80)
    print(f" >>> A1 Final MAE mean: {np.mean(maes):.3f}")
    print(f" >>> A1 Final MAE std : {np.std(maes):.3f}")
    print("-"*80)

    # Visualization (disabled; re-enable if desired)
    # plot_folds_test_subplot(
    #     viz_cache,
    #     fs=CONFIG["fs"],
    #     title=f"A1 TEST visualization | Train act{CONFIG['TRAIN_ACT_ID']} -> Test act{CONFIG['TEST_ACT_ID']}"
    # )

    # for item in viz_cache:
    #     plot_phase_heatmap_and_dominant(
    #         item["phase_p"],
    #         fs=CONFIG["fs"],
    #         title=f"{item['test_subj']} | k_hat={item['k_hat']:.2f} | ent={item['entropy']:.3f}",
    #         max_T=2000
    #     )


# Entry point: run the experiment only when this file is executed directly
# (or as a notebook cell), not when it is imported.
if __name__ == "__main__":
    main()


Device: cuda
Loading 10 subjects from /content/drive/MyDrive/Colab Notebooks/HAR_data/MHEALTHDATASET...
[Windowing] train trials=10 -> train windows=131

--------------------------------------------------------------------------------
 >>> A1 Train: all subjects x act8  |  Test: all subjects x act7 (unseen activity)
--------------------------------------------------------------------------------
[Test 01] subject1_Frontal elevation of arms | Pred(win)=46.19 / GT=20.00 | MAE=26.19 | k_hat=1.82 | ent=0.631 | win_rate_mean=0.752
[Test 02] subject2_Frontal elevation of arms | Pred(win)=33.54 / GT=20.00 | MAE=13.54 | k_hat=2.34 | ent=0.636 | win_rate_mean=0.504
[Test 03] subject3_Frontal elevation of arms | Pred(win)=51.67 / GT=20.00 | MAE=31.67 | k_hat=1.93 | ent=0.586 | win_rate_mean=0.765
[Test 04] subject4_Frontal elevation of arms | Pred(win)=53.56 / GT=20.00 | MAE=33.56 | k_hat=1.86 | ent=0.580 | win_rate_mean=0.817
[Test 05] subject5_Frontal elevation of arms | Pred(win)=46.92 / GT=2

In [7]:
# =========================
# Count-only K-auto (Multi-event) version  (NO manual Pair/lag/overlap/balance)
#
# 핵심 아이디어
# - Micro-event Rate 예측: 모델은 하나의 통합된 속도가 아니라, K_max개의 서로 다른 '작은 단위 동작'의 속도 흐름(r_k(t))을 예측
# - 샘플마다 "rep당 micro-event 개수" k_hat(>=1)을 스스로 추정
# - 우리가 주는 감독은 오직 rep count(=20) -> rep rate만 맞추게
#
# 1) rate head를 amp(t) * softmax(phase) 형태로 바꿔서 K-stream 간 "경쟁"이 생기게 함
# 2) k_hat을 따로 head로 예측하지 않고, phase 사용 분포로부터 effK(=effective K)로 정의 (자동 K)
# 3) phase sparsity/exclusivity를 위해 (a) phase entropy loss, (b) effK usage loss 추가
#
# ✅ Added (ONLY): Time-warp consistency loss
#   - 입력을 시간축으로 stretch/compress (time-warp)해도 "count"는 동일해야 한다는 제약
#   - 모델 출력은 rate이므로, pred_count = rate_hat * duration 을 이용해 count 일관성을 강제
# =========================

import os
import glob
import random
import numpy as np
import pandas as pd
import seaborn as sns

import matplotlib.pyplot as plt
from scipy.ndimage import gaussian_filter1d

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader


# ---------------------------------------------------------------------
# 1) Strict Seeding
# ---------------------------------------------------------------------
def set_strict_seed(seed: int):
    """
    Seed every random source used by this script and switch cuDNN to its
    deterministic, non-benchmarking mode.

    Covers Python's `random`, the PYTHONHASHSEED environment variable, NumPy,
    and PyTorch (CPU, plus all CUDA devices when CUDA is available).
    """
    os.environ["PYTHONHASHSEED"] = str(seed)

    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


# ---------------------------------------------------------------------
# 2) Data Loading
# ---------------------------------------------------------------------
def load_mhealth_dataset(data_dir, target_activities_map, column_names):
    """
    Load every `mHealth_subject*.log` file in `data_dir` and split it by activity.

    data_dir:              directory containing the tab-separated log files
    target_activities_map: {activity_id: activity_name}; only these activities
                           are kept
    column_names:          names for the leading columns of each file; must
                           include 'activity_id'

    return: {subject_key: {activity_name: DataFrame without 'activity_id'}}.
            subject_key is "subject<N>" built from the digits in the file
            name, or the bare file stem when it contains no usable number.
            Returns {} (after printing a warning) when no log file is found.

    A file that fails to load is reported on stdout and skipped, so one bad
    file does not abort the whole load.
    """
    full_dataset = {}
    file_list = sorted(glob.glob(os.path.join(data_dir, "mHealth_subject*.log")))

    if not file_list:
        print(f"[Warning] No mHealth logs found in {data_dir}")
        return {}

    print(f"Loading {len(file_list)} subjects from {data_dir}...")

    for file_path in file_list:
        file_name = os.path.basename(file_path)
        subj_part = file_name.split('.')[0]
        try:
            subj_id_num = int(''.join(filter(str.isdigit, subj_part)))
            subj_key = f"subject{subj_id_num}"
        except ValueError:
            # No parsable number in the file name: fall back to the stem.
            subj_key = subj_part

        try:
            df = pd.read_csv(file_path, sep="\t", header=None)
            # Keep only as many leading columns as we have names for.
            df = df.iloc[:, :len(column_names)]
            df.columns = column_names

            subj_data = {}
            for label_code, activity_name in target_activities_map.items():
                # Boolean indexing already yields a new frame, and drop()
                # returns another, so no explicit copy is needed.
                activity_df = df[df['activity_id'] == label_code]
                if not activity_df.empty:
                    subj_data[activity_name] = activity_df.drop(columns=['activity_id'])

            full_dataset[subj_key] = subj_data
        except Exception as e:
            # Best-effort load: report the broken file and continue.
            print(f"Error loading {file_name}: {e}")

    return full_dataset


def prepare_trial_list(label_config, full_data, target_map, feature_map):
    """
    Turn (subject, activity, count) labels into normalised trial records.

    label_config: iterable of (subj, act_id, gt_count)
    full_data:    {subj: {activity_name: DataFrame}} as built by the loader
    target_map:   {act_id: activity_name}
    feature_map:  {act_id: list of column names to use as channels}

    return: list of dicts with
      'data':  (T,C) float32 array, z-scored per channel within the trial
      'count': float ground-truth count
      'meta':  "<subj>_<activity_name>"

    A label is skipped with a "[Skip]" message (instead of raising KeyError)
    when the subject/activity has no data, when `feature_map` has no entry for
    the activity, or when a listed feature column is absent from the data.
    """
    trial_list = []
    for subj, act_id, gt_count in label_config:
        act_name = target_map.get(act_id)
        feats = feature_map.get(act_id)

        if subj not in full_data or act_name not in full_data[subj]:
            print(f"[Skip] Missing data for {subj} - {act_name}")
            continue

        if feats is None or len(feats) == 0:
            print(f"[Skip] No feature list for activity {act_id} ({subj} - {act_name})")
            continue

        act_df = full_data[subj][act_name]
        missing_cols = [c for c in feats if c not in act_df.columns]
        if missing_cols:
            print(f"[Skip] Missing feature columns {missing_cols} for {subj} - {act_name}")
            continue

        raw_np = act_df[feats].values.astype(np.float32)

        # Per-trial z-score: each channel gets mean 0 and std 1; the small
        # constant keeps constant channels from dividing by zero.
        mean = raw_np.mean(axis=0)
        std = raw_np.std(axis=0) + 1e-6
        norm_np = (raw_np - mean) / std

        trial_list.append({
            'data': norm_np,
            'count': float(gt_count),
            'meta': f"{subj}_{act_name}"
        })

    return trial_list


# ---------------------------------------------------------------------
# 2.5) Windowing (added)
# ---------------------------------------------------------------------
def trial_list_to_windows(trial_list, fs, win_sec=4.0, stride_sec=2.0, drop_last=True):
    """
    Cut each trial into sliding windows and give every window a count label.

    trial_list: output of prepare_trial_list() (items with 'data' (T,C),
                'count', 'meta')
    fs:         sampling rate; win_sec / stride_sec are converted to samples
    drop_last:  when False, the leftover tail after the last full window is
                added as one extra, shorter window

    Window labels come from the trial-level average rate:
      rate_trial   = count_total / total_duration
      count_window = rate_trial * window_duration

    A trial shorter than one window becomes a single window holding the whole
    trial (no padding). Each window dict carries 'data', 'count', 'meta',
    'parent_meta', 'parent_T', 'win_start' and 'win_end'.
    """
    win_len = int(round(win_sec * fs))
    stride = int(round(stride_sec * fs))
    assert win_len > 0 and stride > 0

    rate_hz = float(fs)
    windows = []

    for trial in trial_list:
        signal = trial["data"]  # (T, C)
        n_samples = signal.shape[0]
        parent = trial["meta"]

        # Average reps/s over the whole trial.
        rate_trial = float(trial["count"]) / max(n_samples / rate_hz, 1e-6)

        # Collect (start, end, data) for every window of this trial.
        if n_samples < win_len:
            spans = [(0, n_samples, signal)]
        else:
            spans = [
                (st, st + win_len, signal[st:st + win_len])
                for st in range(0, n_samples - win_len + 1, stride)
            ]
            if not drop_last:
                tail_start = spans[-1][0] + stride
                if tail_start < n_samples:
                    spans.append((tail_start, n_samples, signal[tail_start:n_samples]))

        windows.extend(
            {
                "data": chunk,
                "count": rate_trial * ((ed - st) / rate_hz),
                "meta": f"{parent}__win[{st}:{ed}]",
                "parent_meta": parent,
                "parent_T": n_samples,
                "win_start": st,
                "win_end": ed,
            }
            for st, ed, chunk in spans
        )

    return windows


def predict_count_by_windowing(model, x_np, fs, win_sec, stride_sec, device, tau=1.0, batch_size=64):
    """
    Estimate the repetition count of one trial from window-level rates.

    The trial is cut into full-length sliding windows, the model predicts an
    average rate for each window, and the count is
    mean(window rates) * total trial duration. A trial no longer than one
    window is passed to the model whole.

    model:  callable returning (rate, z, x_hat, aux) for a (B,C,T) input
    x_np:   (T, C) normalised trial
    fs:     sampling rate
    win_sec, stride_sec: window length and hop in seconds
    device: device the input tensors are moved to
    tau:    softmax temperature forwarded to the model
    batch_size: number of windows per forward pass

    return: (pred_count as float, float32 array of the window rates)

    Side effect: puts `model` into eval mode. This now happens on both paths;
    previously the short-trial path left the model in whatever mode the
    caller had set.
    """
    win_len = int(round(win_sec * fs))
    stride = int(round(stride_sec * fs))
    T = x_np.shape[0]
    total_dur = T / float(fs)

    model.eval()

    if T <= win_len:
        x_tensor = torch.tensor(x_np, dtype=torch.float32).transpose(0, 1).unsqueeze(0).to(device)  # (1,C,T)
        with torch.no_grad():
            rate_hat, _, _, _ = model(x_tensor, mask=None, tau=tau)
        rate = float(rate_hat.item())
        return float(rate * total_dur), np.array([rate], dtype=np.float32)

    # Only full windows are used; a shorter remainder at the end is ignored.
    starts = list(range(0, T - win_len + 1, stride))
    windows = np.stack([x_np[st:st + win_len] for st in starts], axis=0)  # (N, win_len, C)

    xw = torch.tensor(windows, dtype=torch.float32).permute(0, 2, 1).to(device)  # (N, C, win_len)

    rates = []
    with torch.no_grad():
        for i in range(0, xw.shape[0], batch_size):
            xb = xw[i:i + batch_size]
            r_hat, _, _, _ = model(xb, mask=None, tau=tau)  # (B,)
            rates.append(r_hat.detach().cpu().numpy())

    rates = np.concatenate(rates, axis=0)  # (N,)
    rate_mean = float(rates.mean())
    pred_count = rate_mean * total_dur
    return float(pred_count), rates


class TrialDataset(Dataset):
    """
    Dataset over a list of trial/window dicts.

    Each item of `trial_list` must provide 'data' as a (T, C) array, a numeric
    'count' and a 'meta' label. Indexing yields (data, count, meta) where data
    is a float32 tensor transposed to (C, T) and count is a float32 scalar
    tensor.
    """

    def __init__(self, trial_list):
        self.trials = trial_list

    def __len__(self):
        return len(self.trials)

    def __getitem__(self, idx):
        record = self.trials[idx]
        signal = torch.tensor(record['data'], dtype=torch.float32)  # (T, C)
        target = torch.tensor(record['count'], dtype=torch.float32)
        return signal.transpose(0, 1), target, record['meta']


def collate_variable_length(batch):
    """
    Collate (data, count, meta) samples of different lengths into one batch.

    Every (C, T) signal is zero-padded on the right to the longest T in the
    batch, and a matching 0/1 mask marks the real samples.

    return: dict with
      "data":   (B, C, T_max) padded signals
      "mask":   (B, T_max), 1 for real steps and 0 for padding
      "count":  (B,) stacked count tensors
      "length": (B,) float32 original lengths
      "meta":   list of the samples' meta labels
    """
    longest = max(sample[0].shape[1] for sample in batch)
    n_ch = batch[0][0].shape[0]
    step_index = torch.arange(longest)

    signals, valid_masks, counts, metas, lengths = [], [], [], [], []
    for signal, count, meta in batch:
        n_steps = signal.shape[1]
        n_pad = longest - n_steps

        if n_pad > 0:
            signal = torch.cat([signal, torch.zeros(n_ch, n_pad)], dim=1)

        signals.append(signal)
        valid_masks.append((step_index < n_steps).to(torch.float32))
        counts.append(count)
        metas.append(meta)
        lengths.append(n_steps)

    return {
        "data": torch.stack(signals),
        "mask": torch.stack(valid_masks),
        "count": torch.stack(counts),
        "length": torch.tensor(lengths, dtype=torch.float32),
        "meta": metas
    }


# ---------------------------------------------------------------------
# 3) Model
# ---------------------------------------------------------------------
class ManifoldEncoder(nn.Module):
    """
    Per-time-step convolutional encoder.

    Two width-5 convolutions (padding 2, so the length T is preserved) with
    ReLU, followed by a 1x1 convolution down to `latent_dim` channels.
    Maps (B, C, T) -> (B, T, latent_dim).
    """

    def __init__(self, input_ch, hidden_dim=128, latent_dim=16):
        super().__init__()
        layers = [
            nn.Conv1d(input_ch, hidden_dim, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv1d(hidden_dim, hidden_dim, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv1d(hidden_dim, latent_dim, kernel_size=1),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        # (B, C, T) -> conv stack -> (B, D, T) -> time-major (B, T, D)
        return self.net(x).transpose(1, 2)


class ManifoldDecoder(nn.Module):
    """
    Convolutional decoder mirroring the encoder.

    Two width-5 convolutions (padding 2) with ReLU, followed by a 1x1
    convolution to `out_ch` channels. Maps (B, T, latent_dim) -> (B, out_ch, T).
    """

    def __init__(self, latent_dim, hidden_dim, out_ch):
        super().__init__()
        layers = [
            nn.Conv1d(latent_dim, hidden_dim, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv1d(hidden_dim, hidden_dim, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv1d(hidden_dim, out_ch, kernel_size=1),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, z):
        # (B, T, D) -> channel-major (B, D, T) -> conv stack -> (B, C, T)
        channel_major = z.transpose(1, 2)
        return self.net(channel_major)


class MultiRateHead(nn.Module):
    """Per-timestep head producing a non-negative amplitude and a K-way phase distribution.

    A two-layer MLP maps each latent vector to ``1 + K_max`` values: the first
    is an amplitude logit, the remaining ``K_max`` are phase logits.
    """

    def __init__(self, latent_dim=16, hidden=64, K_max=6):
        super().__init__()
        self.K_max = K_max
        layers = [
            nn.Linear(latent_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 1 + K_max),  # [amp_logit | phase_logits...]
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, z, tau=1.0):
        """Return ``(amp, phase, phase_logits)`` for latents ``z`` of shape (B, T, D).

        amp:          (B, T), softplus of the first output, so >= 0.
        phase:        (B, T, K), softmax of ``phase_logits / tau`` (sums to 1 over K).
        phase_logits: (B, T, K), raw logits before temperature scaling.
        """
        raw = self.net(z)                       # (B, T, 1 + K)
        amp_logit, phase_logits = raw[..., 0], raw[..., 1:]
        amp = F.softplus(amp_logit)
        phase = F.softmax(phase_logits / tau, dim=-1)
        return amp, phase, phase_logits


class KAutoCountModel(nn.Module):
    """Count model with K_max micro-event streams and an automatically derived K.

    - The rate head yields an amplitude amp(t) and a phase distribution p_k(t);
      the micro-event rates are r_k(t) = amp(t) * p_k(t).
    - k_hat is not predicted by a separate head: it is the effective number of
      phases, 1 / sum_k(p_bar_k^2), where p_bar is the (masked) time average of
      the phase distribution.
    - rep_rate(t) = amp(t) / k_hat.
    """

    def __init__(self, input_ch, hidden_dim=128, latent_dim=16, K_max=6, k_hidden=64):
        super().__init__()
        # NOTE: ``k_hidden`` is accepted but not used anywhere in this class.
        self.encoder = ManifoldEncoder(input_ch, hidden_dim, latent_dim)
        self.decoder = ManifoldDecoder(latent_dim, hidden_dim, input_ch)
        self.rate_head = MultiRateHead(latent_dim, hidden=hidden_dim, K_max=K_max)

        self._init_weights()

    def _init_weights(self):
        """Kaiming-init every Conv1d/Linear, zero the biases, then bias the amp logit."""
        for module in self.modules():
            if not isinstance(module, (nn.Conv1d, nn.Linear)):
                continue
            nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)

        with torch.no_grad():
            head_bias = self.rate_head.net[-1].bias
            head_bias.zero_()
            # Only the amplitude logit (index 0) gets a -2 bias.
            head_bias[0].fill_(-2.0)

    @staticmethod
    def _masked_mean_time(x, mask=None, eps=1e-6):
        """Average ``x`` over the time axis (dim 1), weighting by ``mask`` if given.

        x is (B, T) or (B, T, K); mask is (B, T). Without a mask this is a
        plain mean over dim 1. With a mask, any other rank raises ValueError.
        """
        if mask is None:
            return x.mean(dim=1)
        if x.dim() not in (2, 3):
            raise ValueError(f"Unsupported dim for masked mean: {x.dim()}")

        weights = mask.to(dtype=x.dtype, device=x.device)   # (B, T)
        if x.dim() == 3:
            weights = weights.unsqueeze(-1)                  # (B, T, 1)
        return (x * weights).sum(dim=1) / (weights.sum(dim=1) + eps)

    def forward(self, x, mask=None, tau=1.0):
        """Run the model on ``x`` of shape (B, C, T) with an optional (B, T) mask.

        Returns ``(avg_rep_rate, z, x_hat, aux)``:
          avg_rep_rate: (B,) masked time-average of the rep rate.
          z:            (B, T, D) encoder latents.
          x_hat:        (B, C, T) reconstruction from the decoder.
          aux:          dict with the per-timestep tensors and ``k_hat``.
        """
        z = self.encoder(x)              # (B, T, D)
        x_hat = self.decoder(z)          # (B, C, T)

        amp_t, phase_p, phase_logits = self.rate_head(z, tau=tau)
        rates_k_t = amp_t.unsqueeze(-1) * phase_p

        # The total micro-event rate is the amplitude itself.
        micro_rate_t = amp_t

        # Effective number of phases from the time-averaged phase usage.
        p_bar = self._masked_mean_time(phase_p, mask)            # (B, K)
        k_hat = 1.0 / (p_bar.pow(2).sum(dim=1) + 1e-6)           # (B,)

        rep_rate_t = micro_rate_t / (k_hat.unsqueeze(1) + 1e-6)  # (B, T)
        if mask is None:
            avg_rep_rate = rep_rate_t.mean(dim=1)
        else:
            # Zero out padded timesteps, then take the masked mean.
            rep_rate_t = rep_rate_t * mask
            avg_rep_rate = (rep_rate_t * mask).sum(dim=1) / (mask.sum(dim=1) + 1e-6)

        aux = {
            "rates_k_t": rates_k_t,          # (B, T, K)
            "phase_p": phase_p,              # (B, T, K)
            "phase_logits": phase_logits,    # (B, T, K)
            "micro_rate_t": micro_rate_t,    # (B, T)
            "rep_rate_t": rep_rate_t,        # (B, T)
            "k_hat": k_hat,                  # (B,)
        }
        return avg_rep_rate, z, x_hat, aux


# ---------------------------------------------------------------------
# 4) Loss utils
# ---------------------------------------------------------------------
def masked_recon_mse(x_hat, x, mask, eps=1e-6):
    """Mean squared reconstruction error over the valid (masked-in) timesteps.

    x_hat, x: (B, C, T); mask: (B, T). The squared error is summed over
    timesteps where the mask is non-zero and divided by
    ``mask.sum() * C + eps``.
    """
    valid = mask.to(dtype=x.dtype, device=x.device)
    # Broadcast the (B, T) mask over the channel axis.
    sq_err = (x_hat - x).pow(2) * valid.unsqueeze(1)
    n_valid = valid.sum() * x.shape[1] + eps
    return sq_err.sum() / n_valid


def temporal_smoothness(v, mask=None, eps=1e-6):
    """L1 smoothness penalty on the first difference of ``v`` along time.

    v: (B, T). mask: optional (B, T); a difference is counted only when both
    of its neighbouring timesteps are masked in.

    Returns a scalar tensor. When there are no differences to average
    (T < 2 or an empty batch) the result is 0 instead of NaN, which matches
    what the masked path already yields for such input.
    """
    dv = torch.abs(v[:, 1:] - v[:, :-1])  # (B, T-1)
    if dv.numel() == 0:
        # The sum of an empty tensor is 0 and keeps dtype, device and graph;
        # ``dv.mean()`` would be NaN here and poison the total loss.
        return dv.sum()
    if mask is None:
        return dv.mean()
    pair_mask = (mask[:, 1:] * mask[:, :-1]).to(dtype=dv.dtype, device=dv.device)
    return (dv * pair_mask).sum() / (pair_mask.sum() + eps)


def phase_entropy_loss(phase_p, mask=None, eps=1e-8):
    """Mean per-timestep entropy of the phase distribution.

    Minimizing it pushes the phase at each timestep towards one-hot
    (time-wise exclusivity).

    phase_p: (B, T, K). mask: optional (B, T); when given, the entropy is
    averaged over masked-in timesteps only.
    """
    log_p = torch.log(phase_p + eps)
    ent = -(phase_p * log_p).sum(dim=-1)  # (B, T)
    if mask is not None:
        return (ent * mask).sum() / (mask.sum() + eps)
    return ent.mean()


def effK_usage_loss(phase_p, mask=None, eps=1e-6):
    """Overall usage-sparsity loss: the effective number of phases in use.

    The phase distribution is averaged over time (masked if ``mask`` is
    given) to get the usage p_bar, and effK = 1 / (sum_k p_bar_k^2 + eps).

    phase_p: (B, T, K). mask: optional (B, T).
    Returns ``(effK.mean(), effK.detach())``: the scalar loss and the
    per-sample values of shape (B,) cut from the graph.
    """
    if mask is not None:
        w = mask.to(dtype=phase_p.dtype, device=phase_p.device).unsqueeze(-1)  # (B, T, 1)
        usage = (phase_p * w).sum(dim=1) / (w.sum(dim=1) + eps)
    else:
        usage = phase_p.mean(dim=1)  # (B, K)

    eff_k = 1.0 / (usage.pow(2).sum(dim=1) + eps)
    return eff_k.mean(), eff_k.detach()


def time_warp_batch(x, mask, length, warp_min=0.7, warp_max=1.3, min_len=8):
    """Randomly resample each sequence in a padded batch along the time axis.

    For every sample, a factor w is drawn uniformly from
    [warp_min, warp_max] and the valid part of the sequence is linearly
    resampled to ``max(round(T_valid * w), min_len)`` timesteps. A factor
    below 1 therefore shortens the sequence and a factor above 1 lengthens it.

    x:      (B, C, Tmax) padded input.
    mask:   (B, Tmax); only its dtype is used, for the output mask.
    length: (B,) tensor of valid lengths (float or int).

    Returns ``(xw, mw, lw)``:
      xw: (B, C, Tw_max) zero-padded warped batch.
      mw: (B, Tw_max) mask, 1 on the warped valid timesteps.
      lw: (B,) warped valid lengths, in ``length.dtype``.
    """
    B, C, _ = x.shape
    device, dtype = x.device, x.dtype

    # Pass 1: warp each sample on its own (per-sample tempo change).
    warped = []
    for b in range(B):
        n_valid = max(int(length[b].item()), 1)
        segment = x[b:b + 1, :, :n_valid]  # (1, C, n_valid)

        factor = float(torch.empty(1, device=device).uniform_(warp_min, warp_max).item())
        n_warp = max(int(round(n_valid * factor)), min_len)

        resampled = F.interpolate(segment, size=n_warp, mode="linear", align_corners=False)
        warped.append(resampled.squeeze(0))  # (C, n_warp)

    warped_lens = [seq.shape[1] for seq in warped]
    Tw_max = max(warped_lens, default=0)

    # Pass 2: zero-pad everything to the longest warped length.
    xw = torch.zeros((B, C, Tw_max), device=device, dtype=dtype)
    mw = torch.zeros((B, Tw_max), device=device, dtype=mask.dtype)
    for b, (seq, n_warp) in enumerate(zip(warped, warped_lens)):
        xw[b, :, :n_warp] = seq
        mw[b, :n_warp] = 1

    lw = torch.tensor([float(n) for n in warped_lens], device=device, dtype=length.dtype)
    return xw, mw, lw


# ---------------------------------------------------------------------
# 5) Train  (ONLY MODIFIED PART: remove shift, add time-warp consistency)
# ---------------------------------------------------------------------
def train_one_epoch(model, loader, optimizer, config, device):
    """Train ``model`` for one pass over ``loader`` and return averaged statistics.

    Each batch is a dict with "data" (B, C, T), "mask" (B, T), "count" (B,)
    and "length" (B,). The target is the rep rate ``count / (length / fs)``.

    Config keys read here:
      fs (required), tau, lambda_recon, lambda_smooth, lambda_phase_ent,
      lambda_effk, lambda_warp, warp_min, warp_max, warp_detach_ref.

    Total loss = rate MSE
                 + lambda_recon * masked reconstruction MSE
                 + lambda_smooth * L1 smoothness of rep_rate_t
                 + lambda_phase_ent * phase entropy
                 + lambda_effk * effective-K usage
                 + lambda_warp * time-warp count consistency.

    Returns a dict of per-batch averages for every loss term plus
    ``mae_count``. If the loader yields no batches, all values are 0.0
    (previously this raised ZeroDivisionError).
    """
    model.train()
    stats = {k: 0.0 for k in [
        'loss', 'loss_rate', 'loss_recon', 'loss_smooth', 'loss_phase_ent', 'loss_effk', 'loss_warp',
        'mae_count'
    ]}

    fs = config["fs"]
    tau = config.get("tau", 1.0)

    lam_recon = config.get("lambda_recon", 1.0)
    lam_smooth = config.get("lambda_smooth", 0.05)
    lam_phase_ent = config.get("lambda_phase_ent", 0.01)
    lam_effk = config.get("lambda_effk", 0.005)

    # Time-warp consistency: weight, warp range, and whether the un-warped
    # prediction is detached when used as the reference.
    lam_warp = config.get("lambda_warp", 0.0)
    warp_min = config.get("warp_min", 0.7)
    warp_max = config.get("warp_max", 1.3)
    warp_detach_ref = bool(config.get("warp_detach_ref", True))

    # Count batches while iterating instead of calling len(loader), so that
    # loaders without __len__ work and an empty loader does not divide by zero.
    n_batches = 0

    for batch in loader:
        x = batch["data"].to(device)         # (B, C, T)
        mask = batch["mask"].to(device)      # (B, T)
        y_count = batch["count"].to(device)  # (B,)
        length = batch["length"].to(device)  # (B,)

        duration = torch.clamp(length / fs, min=1e-6)  # seconds
        y_rate = y_count / duration                    # reps/s

        optimizer.zero_grad()

        rate_hat, z, x_hat, aux = model(x, mask, tau=tau)

        # (1) match the average rep rate (MSE)
        loss_rate = F.mse_loss(rate_hat, y_rate)

        # (2) reconstruction
        loss_recon = masked_recon_mse(x_hat, x, mask)

        # (3) smoothness of the per-timestep rep rate
        loss_smooth = temporal_smoothness(aux["rep_rate_t"], mask)

        # (4) phase exclusivity (entropy)
        loss_phase_ent = phase_entropy_loss(aux["phase_p"], mask)

        # (5) effective-K usage (overall)
        loss_effk, _ = effK_usage_loss(aux["phase_p"], mask)

        # (6) time-warp consistency at the count level:
        #     pred_count = rate_hat * duration should not change after warping.
        if lam_warp > 0.0:
            xw, mw, lw = time_warp_batch(
                x=x, mask=mask, length=length,
                warp_min=warp_min, warp_max=warp_max, min_len=8
            )
            dur_w = torch.clamp(lw / fs, min=1e-6)

            rate_hat_w, _, _, _ = model(xw, mw, tau=tau)

            pred_count = rate_hat * duration
            pred_count_w = rate_hat_w * dur_w

            # With a detached reference only the warped branch receives
            # gradient from this term.
            pred_count_ref = pred_count.detach() if warp_detach_ref else pred_count

            loss_warp = F.l1_loss(pred_count_w, pred_count_ref)
        else:
            loss_warp = torch.zeros((), device=device)

        loss = (loss_rate
                + lam_recon * loss_recon
                + lam_smooth * loss_smooth
                + lam_phase_ent * loss_phase_ent
                + lam_effk * loss_effk
                + lam_warp * loss_warp)

        loss.backward()
        optimizer.step()

        # Count MAE is a metric only, so keep it out of autograd.
        with torch.no_grad():
            count_hat = rate_hat * duration  # predicted rate * time = predicted count
            mae_count = torch.abs(count_hat - y_count).mean().item()

        stats['loss'] += loss.item()
        stats['loss_rate'] += loss_rate.item()
        stats['loss_recon'] += loss_recon.item()
        stats['loss_smooth'] += loss_smooth.item()
        stats['loss_phase_ent'] += loss_phase_ent.item()
        stats['loss_effk'] += loss_effk.item()
        stats['loss_warp'] += float(loss_warp.item())
        stats['mae_count'] += mae_count
        n_batches += 1

    if n_batches == 0:
        return stats
    return {k: v / n_batches for k, v in stats.items()}


# ---------------------------------------------------------------------
# 6) Visualization helpers (subject-wise subplot)
# ---------------------------------------------------------------------
def _smooth_1d(y, sigma=2.0):
    """Gaussian-smooth a 1-D signal so noisy curves are easier to read in plots.

    The input is converted to a float32 array before filtering.
    """
    values = np.asarray(y, dtype=np.float32)
    smoothed = gaussian_filter1d(values, sigma=sigma)
    return smoothed


def compute_phase_entropy_mean(phase_p_np, eps=1e-8):
    """Return the time-averaged entropy of a (T, K) phase distribution as a float.

    The input is converted to a float32 array; ``eps`` is added inside the
    logarithm to avoid log(0).
    """
    probs = np.asarray(phase_p_np, dtype=np.float32)
    per_step = -np.sum(probs * np.log(probs + eps), axis=1)  # (T,)
    return float(np.mean(per_step))


def downsample_time_axis(arr, max_T=2000):
    """Subsample ``arr`` along axis 0 to at most ``max_T`` entries for plotting.

    arr: numpy array of shape (T, ...).
    Returns ``(arr_ds, idx)`` where ``idx`` holds the original time indices of
    the kept rows. If T <= max_T the array is returned unchanged with
    ``idx = arange(T)``; otherwise ``max_T`` indices are spread evenly over
    [0, T - 1].
    """
    n_steps = arr.shape[0]
    if n_steps > max_T:
        keep = np.linspace(0, n_steps - 1, max_T).astype(int)
        return arr[keep], keep
    return arr, np.arange(n_steps)


def plot_phase_heatmap_and_dominant(
    phase_p_np,
    fs,
    title="phase_p heatmap + dominant phase",
    max_T=2000
):
    """Plot a phase-probability heatmap with a dominant-phase strip underneath.

    phase_p_np: (T, K) array of phase probabilities.
    fs:         sampling rate, used to convert sample indices to seconds.
    max_T:      maximum number of timesteps drawn (longer inputs are subsampled).
    """
    phase_p_np = np.asarray(phase_p_np, dtype=np.float32)
    assert phase_p_np.ndim == 2, f"phase_p_np must be (T,K), got {phase_p_np.shape}"

    # Subsample for display and keep the original indices for the time axis.
    phase_ds, idx = downsample_time_axis(phase_p_np, max_T=max_T)  # (T', K)
    n_phases = phase_ds.shape[1]
    t_sec = idx / float(fs)

    # Index of the most probable phase at every displayed timestep.
    dominant = np.argmax(phase_ds, axis=1)  # (T',)

    # Both panels are drawn with the same raster settings.
    raster_style = dict(aspect="auto", origin="lower", interpolation="nearest")

    fig = plt.figure(figsize=(30, 10))
    grid = fig.add_gridspec(2, 1, height_ratios=[4, 1], hspace=0.25)

    # Top panel: heatmap with time on x and phase index on y.
    heat_ax = fig.add_subplot(grid[0, 0])
    heat_img = heat_ax.imshow(
        phase_ds.T,  # (K, T')
        extent=[t_sec[0], t_sec[-1], 0, n_phases],
        **raster_style
    )
    heat_ax.set_title(title, fontsize=24, pad=10)
    heat_ax.set_ylabel("Phase k", fontsize=18)
    heat_ax.set_xlabel("Time (sec)", fontsize=18)
    colorbar = fig.colorbar(heat_img, ax=heat_ax, fraction=0.015, pad=0.01)
    colorbar.set_label("phase_p(t,k)", fontsize=14)

    # Bottom panel: one-row image of the dominant phase, sharing the time axis.
    strip_ax = fig.add_subplot(grid[1, 0], sharex=heat_ax)
    strip_ax.imshow(
        dominant[None, :],  # (1, T')
        extent=[t_sec[0], t_sec[-1], 0, 1],
        **raster_style
    )
    strip_ax.set_yticks([])
    strip_ax.set_ylabel("dominant", fontsize=14)
    strip_ax.set_xlabel("Time (sec)", fontsize=18)

    plt.tight_layout()
    plt.show()


def plot_folds_test_subplot(viz_cache, fs, title="Fold-wise TEST visualization (only test_subj)"):
    """Draw one panel per cached test item: smoothed rep rate and cumulative count.

    viz_cache: list of dicts, each with the keys
        'fold', 'test_subj', 't', 'rep_rate', 'gt', 'pred', 'diff',
        'k_hat', 'entropy'.
    fs: sampling rate; the cumulative count is ``cumsum(rep_rate) / fs``.

    Prints a message and returns without plotting when ``viz_cache`` is
    None or empty.
    """
    if viz_cache is None or len(viz_cache) == 0:
        print("[plot_folds_test_subplot] viz_cache is empty")
        return

    sns.set_theme(style="whitegrid", context="notebook", font_scale=2.0)
    palette = sns.color_palette("muted")
    rate_color, count_color = palette[0], palette[1]

    n_panels = len(viz_cache)
    fig, axes = plt.subplots(n_panels, 1, figsize=(36, 9 * n_panels), sharex=False)
    # plt.subplots returns a bare Axes for a single panel; normalise to a 1-D array.
    axes = np.array([axes] if n_panels == 1 else axes).flatten()

    fig.suptitle(title, fontsize=40, y=0.995)

    for ax, entry in zip(axes, viz_cache):
        t = entry["t"]
        rep_rate = entry["rep_rate"]
        gt_count = entry["gt"]
        pred_count = entry["pred"]
        diff = entry["diff"]
        k_hat = entry["k_hat"]
        entropy = entry["entropy"]
        test_subj = entry["test_subj"]
        fold = entry["fold"]

        smoothed_rate = _smooth_1d(rep_rate, sigma=2.0)
        # The running count integrates the raw (unsmoothed) rate.
        running_count = np.cumsum(rep_rate) / fs

        # Left axis: smoothed rep rate.
        ax.plot(t, smoothed_rate, color=rate_color, linewidth=2.5, alpha=0.9)
        ax.fill_between(t, smoothed_rate, color=rate_color, alpha=0.15)
        ax.set_ylabel("Rep Rate (reps/s)", color=rate_color, fontweight='bold', fontsize=24)
        ax.grid(True, linestyle='--', alpha=0.5)
        ax.tick_params(axis='both', which='major', labelsize=20)

        # Right axis: cumulative count, with the ground-truth count as a dotted line.
        count_ax = ax.twinx()
        count_ax.plot(t, running_count, color=count_color, linewidth=3.5, alpha=1.0)
        count_ax.axhline(gt_count, linestyle=":", alpha=0.7)
        count_ax.set_ylabel("Count", color=count_color, fontweight='bold', fontsize=24)
        count_ax.tick_params(axis='y', labelcolor=count_color, labelsize=20)
        count_ax.grid(False)

        ax.set_title(
            f"Fold {fold:2d} | Test: {test_subj} | Pred {pred_count:.2f} / GT {gt_count:.0f} (Diff {diff:+.2f})\n"
            f"k_hat={k_hat:.2f} | phase_entropy={entropy:.3f}",
            fontsize=34, pad=10
        )
        ax.set_xlabel("Time (sec)", fontweight='bold', fontsize=24)

    plt.tight_layout(rect=[0, 0, 1, 0.985])
    plt.subplots_adjust(hspace=0.5)
    plt.show()


# ---------------------------------------------------------------------
# 7) Main (A1: Unseen Activity only, WITH windowing)
# ---------------------------------------------------------------------
def build_label_tuples_from_table(subjects, act_id, count_table):
    """Collect ``(subject, act_id, gt_count)`` tuples for one activity.

    count_table: dict such as ``{"subject1": {6: 21, 7: 20}, ...}``.
    Subjects missing from the table, or lacking an entry for ``act_id``, are
    skipped. Counts are returned as floats, in the order of ``subjects``.
    """
    return [
        (subj, act_id, float(count_table[subj][act_id]))
        for subj in subjects
        if subj in count_table and act_id in count_table[subj]
    ]


def main():
    """Run the A1 (unseen-activity) experiment end to end.

    Loads the mHealth logs, trains a KAutoCountModel on TRAIN_ACT_ID for all
    subjects (optionally on sliding windows), then evaluates on TEST_ACT_ID
    trial by trial, printing per-trial results and the mean/std of the
    absolute count error. Returns early, after printing a message, if the
    data or labels are missing.
    """
    CONFIG = {
        "seed": 42,
        "data_dir": "/content/drive/MyDrive/Colab Notebooks/HAR_data/MHEALTHDATASET",

        "COLUMN_NAMES": [
            'acc_chest_x', 'acc_chest_y', 'acc_chest_z',
            'ecg_1', 'ecg_2',
            'acc_ankle_x', 'acc_ankle_y', 'acc_ankle_z',
            'gyro_ankle_x', 'gyro_ankle_y', 'gyro_ankle_z',
            'mag_ankle_x', 'mag_ankle_y', 'mag_ankle_z',
            'acc_arm_x', 'acc_arm_y', 'acc_arm_z',
            'gyro_arm_x', 'gyro_arm_y', 'gyro_arm_z',
            'mag_arm_x', 'mag_arm_y', 'mag_arm_z',
            'activity_id'
        ],
        "TARGET_ACTIVITIES_MAP": {
            8: 'Knees bending',
            12: 'Jump front & back'
        },

        # Train and test must use the same number of channels C so that one
        # model can be evaluated on both activities.
        "ACT_FEATURE_MAP": {
             8: ['acc_chest_x', 'acc_chest_y', 'acc_chest_z',
                'acc_arm_x', 'acc_arm_y', 'acc_arm_z',
                'acc_ankle_x', 'acc_ankle_y', 'acc_ankle_z'],
             12: ['acc_chest_x', 'acc_chest_y', 'acc_chest_z',
                'acc_arm_x', 'acc_arm_y', 'acc_arm_z',
                'acc_ankle_x', 'acc_ankle_y', 'acc_ankle_z']
        },

        # Training Params
        "epochs": 100,
        "lr": 5e-4,
        "batch_size": 64,
        "fs": 50,

        # Model
        "hidden_dim": 128,
        "latent_dim": 16,
        "K_max": 6,

        # Loss Weights
        "lambda_recon": 1.0,
        "lambda_smooth": 0.05,
        "lambda_phase_ent": 0.01,
        "lambda_effk": 0.0005,

        # Time-warp consistency (the only newly added loss)
        "lambda_warp": 0.2,      # original note recommends starting at 1e-2 to 5e-2; this run uses 0.2
        "warp_min": 0.5,          # factor < 1: time_warp_batch resamples to round(T * w), i.e. a shorter sequence
        "warp_max": 1.5,          # factor > 1: a longer (stretched) sequence
        "warp_detach_ref": True,  # detach the un-warped count so only the warped branch is pulled towards it (stability)

        # softmax temperature of the phase head (strength of phase competition)
        "tau": 1.0,

        # A1 setting
        "TRAIN_ACT_ID": 8,
        "TEST_ACT_ID": 12,

        # -------------------------
        # Windowing (added)
        # -------------------------
        "USE_WINDOWING": True,
        "WIN_SEC": 8.0,
        "STRIDE_SEC": 4.0,
        "DROP_LAST": True,

        # Must be a dict: {subject: {activity_id: ground-truth rep count}}
        "COUNT_TABLE": {
            "subject1":  {12: 20, 8: 20},
            "subject2":  {12: 22, 8: 21},
            "subject3":  {12: 21, 8: 21},
            "subject4":  {12: 21, 8: 19},
            "subject5":  {12: 20, 8: 20},
            "subject6":  {12: 21, 8: 20},
            "subject7":  {12: 19, 8: 21},
            "subject8":  {12: 20, 8: 21},
            "subject9":  {12: 20, 8: 21},
            "subject10": {12: 20, 8: 21},
        },
    }

    set_strict_seed(CONFIG["seed"])
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")

    full_data = load_mhealth_dataset(
        CONFIG["data_dir"],
        CONFIG["TARGET_ACTIVITIES_MAP"],
        CONFIG["COLUMN_NAMES"]
    )
    # load_mhealth_dataset already printed a warning if nothing was found.
    if not full_data:
        return

    subjects = [f"subject{i}" for i in range(1, 11)]

    # -------------------------
    # A1 split (NO LOSO)
    # Train: all subjects x TRAIN_ACT_ID
    # Test : all subjects x TEST_ACT_ID (unseen activity)
    # -------------------------
    train_labels = build_label_tuples_from_table(subjects, CONFIG["TRAIN_ACT_ID"], CONFIG["COUNT_TABLE"])
    test_labels  = build_label_tuples_from_table(subjects, CONFIG["TEST_ACT_ID"],  CONFIG["COUNT_TABLE"])

    if len(train_labels) == 0:
        print("[Error] No train labels. Check COUNT_TABLE / TRAIN_ACT_ID.")
        return
    if len(test_labels) == 0:
        print("[Error] No test labels. Check COUNT_TABLE / TEST_ACT_ID.")
        return

    train_data = prepare_trial_list(train_labels, full_data, CONFIG["TARGET_ACTIVITIES_MAP"], CONFIG["ACT_FEATURE_MAP"])
    test_data  = prepare_trial_list(test_labels,  full_data, CONFIG["TARGET_ACTIVITIES_MAP"], CONFIG["ACT_FEATURE_MAP"])

    if len(train_data) == 0:
        print("[Error] train_data empty. (missing files or ACT_FEATURE_MAP mismatch)")
        return
    if len(test_data) == 0:
        print("[Error] test_data empty. (missing files or ACT_FEATURE_MAP mismatch)")
        return

    # -------------------------
    # Apply windowing (training data only)
    # -------------------------
    if CONFIG.get("USE_WINDOWING", False):
        train_windows = trial_list_to_windows(
            train_data,
            fs=CONFIG["fs"],
            win_sec=CONFIG["WIN_SEC"],
            stride_sec=CONFIG["STRIDE_SEC"],
            drop_last=CONFIG["DROP_LAST"],
        )
        print(f"[Windowing] train trials={len(train_data)} -> train windows={len(train_windows)}")
        train_data_for_loader = train_windows
    else:
        train_data_for_loader = train_data

    train_loader = DataLoader(
        TrialDataset(train_data_for_loader),
        batch_size=CONFIG["batch_size"],
        shuffle=True,
        collate_fn=collate_variable_length,
        num_workers=0
    )

    # Trial data is stored as (T, C), so axis 1 is the channel count.
    input_ch = train_data[0]['data'].shape[1]
    model = KAutoCountModel(
        input_ch=input_ch,
        hidden_dim=CONFIG["hidden_dim"],
        latent_dim=CONFIG["latent_dim"],
        K_max=CONFIG["K_max"]
    ).to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=CONFIG["lr"])
    # Halve the learning rate every 30 epochs.
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.5)

    print("\n" + "-"*80)
    print(f" >>> A1 Train: all subjects x act{CONFIG['TRAIN_ACT_ID']}  |  Test: all subjects x act{CONFIG['TEST_ACT_ID']} (unseen activity)")
    print("-"*80)

    # ---- Train ----
    # Per-epoch statistics are computed but discarded here.
    for epoch in range(CONFIG["epochs"]):
        _ = train_one_epoch(model, train_loader, optimizer, CONFIG, device)
        scheduler.step()

    # ---- Test (per-subject) ----
    model.eval()
    maes = []
    viz_cache = []

    for idx, item in enumerate(test_data):
        x_np = item["data"]  # (T,C)
        gt_count = float(item["count"])

        # Predicted count for the test trial, computed from sliding windows.
        # Note this is used regardless of USE_WINDOWING.
        pred_count, win_rates = predict_count_by_windowing(
            model,
            x_np=x_np,
            fs=CONFIG["fs"],
            win_sec=CONFIG["WIN_SEC"],
            stride_sec=CONFIG["STRIDE_SEC"],
            device=device,
            tau=CONFIG["tau"],
            batch_size=CONFIG["batch_size"]
        )

        mae = abs(pred_count - gt_count)
        maes.append(mae)

        # For visualization: one forward pass over the whole trial (no windowing).
        with torch.no_grad():
            x_tensor = torch.tensor(x_np, dtype=torch.float32).transpose(0, 1).unsqueeze(0).to(device)
            rate_hat_full, _, _, aux = model(x_tensor, mask=None, tau=CONFIG["tau"])

            phase_p = aux["phase_p"].squeeze(0).detach().cpu().numpy()
            k_hat = float(aux["k_hat"].item())
            ent = compute_phase_entropy_mean(phase_p)
            rep_rate = aux["rep_rate_t"].squeeze(0).detach().cpu().numpy()
            T = rep_rate.shape[0]
            t = np.arange(T) / float(CONFIG["fs"])

        viz_cache.append({
            "fold": idx + 1,   # just the running index of the test trial
            "test_subj": item["meta"],  # subject+act name
            "t": t,
            "rep_rate": rep_rate,
            "gt": gt_count,
            "pred": float(pred_count),   # windowed prediction
            "diff": float(pred_count - gt_count),
            "k_hat": k_hat,
            "entropy": ent,
            "phase_p": phase_p,
        })

        print(
            f"[Test {idx+1:02d}] {item['meta']} | Pred(win)={pred_count:.2f} / GT={gt_count:.2f} | "
            f"MAE={mae:.2f} | k_hat={k_hat:.2f} | ent={ent:.3f} | win_rate_mean={win_rates.mean():.3f}"
        )

    print("-"*80)
    print(f" >>> A1 Final MAE mean: {np.mean(maes):.3f}")
    print(f" >>> A1 Final MAE std : {np.std(maes):.3f}")
    print("-"*80)

    # Visualization (disabled; uncomment to use)
    # plot_folds_test_subplot(
    #     viz_cache,
    #     fs=CONFIG["fs"],
    #     title=f"A1 TEST visualization | Train act{CONFIG['TRAIN_ACT_ID']} -> Test act{CONFIG['TEST_ACT_ID']}"
    # )

    # for item in viz_cache:
    #     plot_phase_heatmap_and_dominant(
    #         item["phase_p"],
    #         fs=CONFIG["fs"],
    #         title=f"{item['test_subj']} | k_hat={item['k_hat']:.2f} | ent={item['entropy']:.3f}",
    #         max_T=2000
    #     )


# Run the experiment only when this code is executed directly, not when imported.
if __name__ == "__main__":
    main()


Device: cuda
Loading 10 subjects from /content/drive/MyDrive/Colab Notebooks/HAR_data/MHEALTHDATASET...
[Windowing] train trials=10 -> train windows=131

--------------------------------------------------------------------------------
 >>> A1 Train: all subjects x act8  |  Test: all subjects x act12 (unseen activity)
--------------------------------------------------------------------------------
[Test 01] subject1_Jump front & back | Pred(win)=23.14 / GT=20.00 | MAE=3.14 | k_hat=1.72 | ent=0.627 | win_rate_mean=1.076
[Test 02] subject2_Jump front & back | Pred(win)=24.20 / GT=22.00 | MAE=2.20 | k_hat=1.51 | ent=0.542 | win_rate_mean=1.182
[Test 03] subject3_Jump front & back | Pred(win)=27.83 / GT=21.00 | MAE=6.83 | k_hat=1.51 | ent=0.499 | win_rate_mean=1.359
[Test 04] subject4_Jump front & back | Pred(win)=28.02 / GT=21.00 | MAE=7.02 | k_hat=1.47 | ent=0.474 | win_rate_mean=1.368
[Test 05] subject5_Jump front & back | Pred(win)=22.33 / GT=20.00 | MAE=2.33 | k_hat=1.59 | ent=0.643 | 

In [8]:
# =========================
# Count-only K-auto (Multi-event) version  (NO manual Pair/lag/overlap/balance)
#
# 핵심 아이디어
# - Micro-event Rate 예측: 모델은 하나의 통합된 속도가 아니라, K_max개의 서로 다른 '작은 단위 동작'의 속도 흐름(r_k(t))을 예측
# - 샘플마다 "rep당 micro-event 개수" k_hat(>=1)을 스스로 추정
# - 우리가 주는 감독은 오직 rep count(=20) -> rep rate만 맞추게
#
# 1) rate head를 amp(t) * softmax(phase) 형태로 바꿔서 K-stream 간 "경쟁"이 생기게 함
# 2) k_hat을 따로 head로 예측하지 않고, phase 사용 분포로부터 effK(=effective K)로 정의 (자동 K)
# 3) phase sparsity/exclusivity를 위해 (a) phase entropy loss, (b) effK usage loss 추가
#
# ✅ Added (ONLY): Time-warp consistency loss
#   - 입력을 시간축으로 stretch/compress (time-warp)해도 "count"는 동일해야 한다는 제약
#   - 모델 출력은 rate이므로, pred_count = rate_hat * duration 을 이용해 count 일관성을 강제
# =========================

import os
import glob
import random
import numpy as np
import pandas as pd
import seaborn as sns

import matplotlib.pyplot as plt
from scipy.ndimage import gaussian_filter1d

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader


# ---------------------------------------------------------------------
# 1) Strict Seeding
# ---------------------------------------------------------------------
def set_strict_seed(seed: int):
    """Seed Python, NumPy and PyTorch RNGs and put cuDNN in deterministic mode.

    Also sets the ``PYTHONHASHSEED`` environment variable to ``str(seed)``.
    CUDA generators are seeded only when CUDA is available.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)

    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    if torch.cuda.is_available():
        for cuda_seed_fn in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
            cuda_seed_fn(seed)

    # Disable autotuned kernel selection and request deterministic cuDNN kernels.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


# ---------------------------------------------------------------------
# 2) Data Loading
# ---------------------------------------------------------------------
def load_mhealth_dataset(data_dir, target_activities_map, column_names):
    """Load every ``mHealth_subject*.log`` in ``data_dir``, split by activity.

    Args:
        data_dir: directory containing the tab-separated log files.
        target_activities_map: ``{activity_id: activity_name}``; only these
            activities are kept.
        column_names: names for the leading columns of each log; must include
            ``'activity_id'``. Extra columns in the file are dropped.

    Returns:
        ``{subject_key: {activity_name: DataFrame}}`` where each DataFrame
        holds the rows of that activity without the ``activity_id`` column.
        ``subject_key`` is ``"subject<N>"`` built from the digits in the file
        name, or the bare file stem when the name has no usable digits.
        Returns ``{}`` (after printing a warning) if no log files are found.

    Loading is best-effort: a file that fails to parse is reported with a
    printed message and skipped.
    """
    file_list = sorted(glob.glob(os.path.join(data_dir, "mHealth_subject*.log")))

    if not file_list:
        print(f"[Warning] No mHealth logs found in {data_dir}")
        return {}

    print(f"Loading {len(file_list)} subjects from {data_dir}...")

    full_dataset = {}
    for file_path in file_list:
        file_name = os.path.basename(file_path)
        subj_part = file_name.split('.')[0]
        try:
            subj_id_num = int(''.join(filter(str.isdigit, subj_part)))
            subj_key = f"subject{subj_id_num}"
        except ValueError:
            # No usable digits in the file name: fall back to the stem.
            # (Was a bare ``except:``, which also swallowed KeyboardInterrupt.)
            subj_key = subj_part

        try:
            df = pd.read_csv(file_path, sep="\t", header=None)
            df = df.iloc[:, :len(column_names)]
            df.columns = column_names

            subj_data = {}
            for label_code, activity_name in target_activities_map.items():
                activity_df = df[df['activity_id'] == label_code].copy()
                if not activity_df.empty:
                    subj_data[activity_name] = activity_df.drop(columns=['activity_id'])

            full_dataset[subj_key] = subj_data
        except Exception as e:
            # Deliberately broad: one unreadable log must not abort the load.
            print(f"Error loading {file_name}: {e}")

    return full_dataset


def prepare_trial_list(label_config, full_data, target_map, feature_map):
    """Build z-scored trial dicts for each ``(subject, act_id, gt_count)`` label.

    Args:
        label_config: iterable of ``(subj, act_id, gt_count)`` tuples.
        full_data: ``{subject: {activity_name: DataFrame}}``.
        target_map: ``{act_id: activity_name}``.
        feature_map: ``{act_id: [column names]}`` selecting the channels.

    Returns:
        List of dicts with ``'data'`` (float32 array (T, C), standardized per
        channel over the trial), ``'count'`` (float) and ``'meta'``
        (``"<subj>_<activity_name>"``). Labels without matching data are
        skipped with a printed ``[Skip]`` message.
    """
    trials = []
    for subj, act_id, gt_count in label_config:
        act_name = target_map.get(act_id)
        feats = feature_map.get(act_id)

        if subj not in full_data or act_name not in full_data[subj]:
            print(f"[Skip] Missing data for {subj} - {act_name}")
            continue

        raw_np = full_data[subj][act_name][feats].values.astype(np.float32)

        # Per-channel z-score (mean 0, std 1); the small offset avoids
        # division by zero on constant channels.
        mu = raw_np.mean(axis=0)
        sigma = raw_np.std(axis=0) + 1e-6

        trials.append({
            'data': (raw_np - mu) / sigma,
            'count': float(gt_count),
            'meta': f"{subj}_{act_name}"
        })

    return trials


# ---------------------------------------------------------------------
# 2.5) Windowing (added)
# ---------------------------------------------------------------------
def trial_list_to_windows(trial_list, fs, win_sec=4.0, stride_sec=2.0, drop_last=True):
    """Cut each trial into sliding windows labelled with a proportional count.

    trial_list: output of prepare_trial_list(); each item has 'data' (T, C),
    'count' and 'meta'.

    A window's label comes from the trial-level average rate:
        rate_trial   = count_total / total_duration
        count_window = rate_trial * window_duration

    A trial shorter than one window becomes a single window holding the
    whole, unpadded trial. With ``drop_last=False`` one extra partial window,
    starting one stride after the last full window, is added when it still
    lies inside the trial.

    Returns a list of dicts with keys 'data', 'count', 'meta', 'parent_meta',
    'parent_T', 'win_start', 'win_end'.
    """
    win_len = int(round(win_sec * fs))
    stride = int(round(stride_sec * fs))
    assert win_len > 0 and stride > 0

    def make_window(data, meta, parent_T, rate_trial, st, ed):
        # The label scales the trial rate by this window's own duration.
        return {
            "data": data,
            "count": rate_trial * ((ed - st) / float(fs)),
            "meta": f"{meta}__win[{st}:{ed}]",
            "parent_meta": meta,
            "parent_T": parent_T,
            "win_start": st,
            "win_end": ed,
        }

    windows = []
    for item in trial_list:
        x, meta = item["data"], item["meta"]  # x: (T, C)
        T = x.shape[0]

        # Average rate over the whole trial, in reps/s.
        rate_trial = float(item["count"]) / max(T / float(fs), 1e-6)

        if T < win_len:
            # Too short for a full window: keep the trial as one variable-length window.
            windows.append(make_window(x, meta, T, rate_trial, 0, T))
            continue

        starts = range(0, T - win_len + 1, stride)
        windows.extend(
            make_window(x[st:st + win_len], meta, T, rate_trial, st, st + win_len)
            for st in starts
        )

        if not drop_last:
            tail_st = starts[-1] + stride
            if tail_st < T:
                windows.append(make_window(x[tail_st:T], meta, T, rate_trial, tail_st, T))

    return windows


def predict_count_by_windowing(model, x_np, fs, win_sec, stride_sec, device, tau=1.0, batch_size=64):
    """
    Estimate the repetition count of one trial by averaging per-window rates.

    The signal is cut into fixed-length sliding windows, the model predicts one
    rate per window, and the mean rate is multiplied by the full trial duration.
    If the signal is not longer than one window it is fed to the model whole.

    Args:
        model: callable returning a 4-tuple whose first element is the
            per-sample rate tensor; called as ``model(x, mask=None, tau=tau)``.
        x_np: array of shape (T, C), already normalized.
        fs: sampling rate used to convert samples to seconds.
        win_sec: window length in seconds.
        stride_sec: hop between consecutive windows in seconds.
        device: torch device the inputs are moved to.
        tau: temperature forwarded to the model.
        batch_size: number of windows per forward pass.

    Returns:
        (pred_count, window_rates): a float and a 1-D float array with one
        rate per evaluated window.

    Raises:
        ValueError: if the signal spans several windows and the stride rounds
            to a non-positive number of samples.

    Note:
        The model is switched to eval mode and left in eval mode.
    """
    win_len = int(round(win_sec * fs))
    stride = int(round(stride_sec * fs))
    T = x_np.shape[0]
    total_dur = T / float(fs)

    # Put the model in inference mode for BOTH branches; previously only the
    # multi-window branch did this, so short signals were scored in whatever
    # mode the caller had left the model in.
    model.eval()

    if T <= win_len:
        x_tensor = torch.tensor(x_np, dtype=torch.float32).transpose(0, 1).unsqueeze(0).to(device)  # (1,C,T)
        with torch.no_grad():
            rate_hat, _, _, _ = model(x_tensor, mask=None, tau=tau)
        rate = float(rate_hat.item())
        return float(rate * total_dur), np.array([rate], dtype=np.float32)

    if stride <= 0:
        raise ValueError(f"stride must be a positive number of samples, got {stride}")

    starts = range(0, T - win_len + 1, stride)
    windows = np.stack([x_np[st:st + win_len] for st in starts], axis=0)  # (N, win_len, C)
    xw = torch.tensor(windows, dtype=torch.float32).permute(0, 2, 1).to(device)  # (N, C, win_len)

    rates = []
    with torch.no_grad():
        for i in range(0, xw.shape[0], batch_size):
            xb = xw[i:i + batch_size]
            r_hat, _, _, _ = model(xb, mask=None, tau=tau)  # (B,)
            rates.append(r_hat.detach().cpu().numpy())

    rates = np.concatenate(rates, axis=0)  # (N,)
    # Trailing samples not covered by a full window are accounted for by
    # scaling the mean window rate with the duration of the whole trial.
    pred_count = float(rates.mean()) * total_dur
    return float(pred_count), rates


class TrialDataset(Dataset):
    """Dataset over a list of trial/window dicts with 'data', 'count' and 'meta' keys."""

    def __init__(self, trial_list):
        self.trials = trial_list

    def __len__(self):
        return len(self.trials)

    def __getitem__(self, idx):
        """Return (signal as float32 tensor transposed to (C, T), count tensor, meta)."""
        record = self.trials[idx]
        signal = torch.tensor(record['data'], dtype=torch.float32)
        label = torch.tensor(record['count'], dtype=torch.float32)
        # Stored layout is (T, C); the model consumes channels-first.
        return signal.transpose(0, 1), label, record['meta']


def collate_variable_length(batch):
    """
    Collate (data, count, meta) samples of differing length into one batch.

    Each ``data`` is (C, T_i). Shorter samples are right-padded with zeros up to
    the longest T in the batch, and a 0/1 mask marks the valid time steps.

    Returns a dict with keys "data" (B, C, T_max), "mask" (B, T_max),
    "count" (B,), "length" (B,) as float32, and "meta" (list).
    """
    seq_lens = [sample[0].shape[1] for sample in batch]
    longest = max(seq_lens)
    n_ch = batch[0][0].shape[0]

    data_rows, mask_rows = [], []
    for (series, _, _), n_valid in zip(batch, seq_lens):
        n_pad = longest - n_valid
        valid = torch.ones(n_valid)
        if n_pad > 0:
            # Zero-pad the signal and mark the padded tail as invalid.
            series = torch.cat([series, torch.zeros(n_ch, n_pad)], dim=1)
            valid = torch.cat([valid, torch.zeros(n_pad)], dim=0)
        data_rows.append(series)
        mask_rows.append(valid)

    return {
        "data": torch.stack(data_rows),                            # (B, C, T_max)
        "mask": torch.stack(mask_rows),                            # (B, T_max)
        "count": torch.stack([sample[1] for sample in batch]),     # (B,)
        "length": torch.tensor(seq_lens, dtype=torch.float32),     # (B,)
        "meta": [sample[2] for sample in batch],
    }


# ---------------------------------------------------------------------
# 3) Model
# ---------------------------------------------------------------------
class ManifoldEncoder(nn.Module):
    """Three-layer 1-D conv encoder mapping (B, C, T) signals to (B, T, D) latents."""

    def __init__(self, input_ch, hidden_dim=128, latent_dim=16):
        super().__init__()
        # kernel 5 / padding 2 keeps the time length unchanged; the final
        # 1x1 conv projects each time step to the latent dimension.
        layers = [
            nn.Conv1d(input_ch, hidden_dim, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv1d(hidden_dim, hidden_dim, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv1d(hidden_dim, latent_dim, kernel_size=1),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """x: (B, C, T) -> latent sequence (B, T, D)."""
        return self.net(x).transpose(1, 2)


class ManifoldDecoder(nn.Module):
    """Three-layer 1-D conv decoder mapping (B, T, D) latents back to (B, C, T) signals."""

    def __init__(self, latent_dim, hidden_dim, out_ch):
        super().__init__()
        layers = [
            nn.Conv1d(latent_dim, hidden_dim, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv1d(hidden_dim, hidden_dim, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv1d(hidden_dim, out_ch, kernel_size=1),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, z):
        """z: (B, T, D) -> reconstruction (B, C, T)."""
        # Conv1d expects channels-first, so swap time and feature axes first.
        return self.net(z.transpose(1, 2))


class MultiRateHead(nn.Module):
    """
    Per-time-step head producing a non-negative amplitude and a K-way phase distribution.

    The MLP emits 1 + K_max values per step: index 0 is the amplitude logit,
    the remaining K_max values are phase logits.
    """

    def __init__(self, latent_dim=16, hidden=64, K_max=6):
        super().__init__()
        self.K_max = K_max
        n_out = 1 + K_max  # [amp_logit | phase_logits...]
        self.net = nn.Sequential(
            nn.Linear(latent_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, n_out),
        )

    def forward(self, z, tau=1.0):
        """
        z: (B, T, D); tau: softmax temperature for the phase logits.

        Returns (amp (B, T), phase (B, T, K), phase_logits (B, T, K)).
        """
        raw = self.net(z)                                  # (B, T, 1+K)
        amp_logit, phase_logits = raw[..., 0], raw[..., 1:]
        amp = F.softplus(amp_logit)                        # >= 0
        phase = F.softmax(phase_logits / tau, dim=-1)      # sums to 1 over K
        return amp, phase, phase_logits


class KAutoCountModel(nn.Module):
    """
    Autoencoder with a multi-stream rate head for repetition counting.

    - The rate head gives an amplitude amp(t) and a phase distribution over
      K_max streams; stream rates are r_k(t) = amp(t) * phase_k(t).
    - k_hat is not predicted by a separate head: it is the effective number of
      streams, 1 / sum_k(p_bar_k^2), where p_bar is the time-averaged phase.
    - rep_rate(t) = amp(t) / k_hat (amp equals the sum of the stream rates).

    The ``k_hidden`` constructor argument is accepted but not used.
    """

    def __init__(self, input_ch, hidden_dim=128, latent_dim=16, K_max=6, k_hidden=64):
        super().__init__()
        self.encoder = ManifoldEncoder(input_ch, hidden_dim, latent_dim)
        self.decoder = ManifoldDecoder(latent_dim, hidden_dim, input_ch)
        self.rate_head = MultiRateHead(latent_dim, hidden=hidden_dim, K_max=K_max)

        self._init_weights()

    def _init_weights(self):
        """Kaiming-init every Conv1d/Linear, zero the biases, then bias the amplitude logit to -2."""
        for layer in self.modules():
            if not isinstance(layer, (nn.Conv1d, nn.Linear)):
                continue
            nn.init.kaiming_normal_(layer.weight, mode='fan_out', nonlinearity='relu')
            if layer.bias is not None:
                nn.init.constant_(layer.bias, 0)

        with torch.no_grad():
            head_bias = self.rate_head.net[-1].bias
            head_bias.zero_()
            head_bias[0].fill_(-2.0)  # only the amplitude logit starts at -2

    @staticmethod
    def _masked_mean_time(x, mask=None, eps=1e-6):
        """Mean over the time axis (dim 1) of a (B,T) or (B,T,K) tensor, ignoring masked-out steps."""
        if mask is None:
            return x.mean(dim=1)
        if x.dim() not in (2, 3):
            raise ValueError(f"Unsupported dim for masked mean: {x.dim()}")
        weights = mask.to(dtype=x.dtype, device=x.device)          # (B,T)
        if x.dim() == 3:
            weights = weights.unsqueeze(-1)                        # (B,T,1)
        return (x * weights).sum(dim=1) / (weights.sum(dim=1) + eps)

    def forward(self, x, mask=None, tau=1.0):
        """
        x: (B,C,T), mask: (B,T) with 1 on valid steps (or None).

        Returns (avg_rep_rate (B,), z (B,T,D), x_hat (B,C,T), aux dict).
        """
        z = self.encoder(x)              # (B,T,D)
        x_hat = self.decoder(z)          # (B,C,T)

        amp_t, phase_p, phase_logits = self.rate_head(z, tau=tau)
        rates_k_t = amp_t.unsqueeze(-1) * phase_p

        # The phase sums to 1 over K, so the summed stream rate is the amplitude.
        micro_rate_t = amp_t

        # Effective number of streams from the time-averaged phase usage.
        usage = self._masked_mean_time(phase_p, mask)              # (B,K)
        k_hat = 1.0 / (usage.pow(2).sum(dim=1) + 1e-6)             # (B,)

        rep_rate_t = micro_rate_t / (k_hat.unsqueeze(1) + 1e-6)    # (B,T)

        if mask is None:
            avg_rep_rate = rep_rate_t.mean(dim=1)
        else:
            # Zero the padded steps, then average over valid steps only.
            rep_rate_t = rep_rate_t * mask
            avg_rep_rate = (rep_rate_t * mask).sum(dim=1) / (mask.sum(dim=1) + 1e-6)

        aux = {
            "rates_k_t": rates_k_t,          # (B,T,K)
            "phase_p": phase_p,              # (B,T,K)
            "phase_logits": phase_logits,    # (B,T,K)
            "micro_rate_t": micro_rate_t,    # (B,T)
            "rep_rate_t": rep_rate_t,        # (B,T)
            "k_hat": k_hat,                  # (B,)
        }
        return avg_rep_rate, z, x_hat, aux


# ---------------------------------------------------------------------
# 4) Loss utils
# ---------------------------------------------------------------------
def masked_recon_mse(x_hat, x, mask, eps=1e-6):
    """
    Mean squared reconstruction error over valid time steps only.

    x_hat, x: (B,C,T); mask: (B,T) with 1 on valid steps. The squared error is
    summed over valid positions and divided by (number of valid steps * C).
    """
    valid = mask.to(dtype=x.dtype, device=x.device)
    sq_err = (x_hat - x).pow(2) * valid.unsqueeze(1)   # broadcast mask over channels
    n_valid = (valid.sum() * x.shape[1]) + eps
    return sq_err.sum() / n_valid


def temporal_smoothness(v, mask=None, eps=1e-6):
    """
    L1 penalty on the first difference of v along time.

    v: (B,T). With a mask (B,T), only differences between two valid
    neighbouring steps contribute.
    """
    step = (v[:, 1:] - v[:, :-1]).abs()                 # (B,T-1)
    if mask is None:
        return step.mean()
    # A difference is valid only if both of its endpoints are valid.
    pair_ok = (mask[:, 1:] * mask[:, :-1]).to(dtype=step.dtype, device=step.device)
    return (step * pair_ok).sum() / (pair_ok.sum() + eps)


def phase_entropy_loss(phase_p, mask=None, eps=1e-8):
    """
    Mean per-step entropy of the phase distribution (lower = closer to one-hot).

    phase_p: (B,T,K). With a mask (B,T), the average is taken over valid steps.
    """
    log_p = torch.log(phase_p + eps)
    entropy_t = -(phase_p * log_p).sum(dim=-1)          # (B,T)
    if mask is None:
        return entropy_t.mean()
    return (entropy_t * mask).sum() / (mask.sum() + eps)


def effK_usage_loss(phase_p, mask=None, eps=1e-6):
    """
    Effective number of phases used per sample, as a sparsity penalty.

    phase_p: (B,T,K). The phase is averaged over (valid) time to p_bar (B,K),
    then effK = 1 / (sum_k p_bar_k^2 + eps).

    Returns (batch-mean effK, detached per-sample effK of shape (B,)).
    """
    if mask is not None:
        weights = mask.to(dtype=phase_p.dtype, device=phase_p.device).unsqueeze(-1)  # (B,T,1)
        usage = (phase_p * weights).sum(dim=1) / (weights.sum(dim=1) + eps)
    else:
        usage = phase_p.mean(dim=1)                                                  # (B,K)

    eff_k = 1.0 / (usage.pow(2).sum(dim=1) + eps)
    return eff_k.mean(), eff_k.detach()


def time_warp_batch(x, mask, length, warp_min=0.7, warp_max=1.3, min_len=8):
    """
    Resample each sample of a padded batch to a randomly scaled length.

    For every sample, a factor w is drawn uniformly from [warp_min, warp_max)
    and its valid prefix is linearly interpolated to
    max(round(valid_len * w), min_len) steps.

    Args:
        x: (B,C,Tmax) padded signals.
        mask: (B,Tmax); only its dtype is used, for the returned mask.
        length: (B,) valid length of each sample.

    Returns:
        xw: (B,C,Tw_max) warped signals, zero-padded.
        mw: (B,Tw_max) mask of valid warped steps.
        lw: (B,) warped valid lengths, same dtype as ``length``.
    """
    n_batch, n_ch, _ = x.shape
    dev, x_dtype = x.device, x.dtype

    warped = []
    for i in range(n_batch):
        n_valid = max(int(length[i].item()), 1)

        # One independent tempo factor per sample.
        factor = float(torch.empty(1, device=dev).uniform_(warp_min, warp_max).item())
        n_out = max(int(round(n_valid * factor)), min_len)

        segment = F.interpolate(
            x[i:i + 1, :, :n_valid], size=n_out, mode="linear", align_corners=False
        )                                   # (1,C,n_out)
        warped.append(segment.squeeze(0))   # (C,n_out)

    out_lens = [segment.shape[1] for segment in warped]
    longest = max(out_lens, default=0)

    # Re-pad every warped sample to the longest warped length in the batch.
    xw = torch.zeros((n_batch, n_ch, longest), device=dev, dtype=x_dtype)
    mw = torch.zeros((n_batch, longest), device=dev, dtype=mask.dtype)
    for i, segment in enumerate(warped):
        xw[i, :, :out_lens[i]] = segment
        mw[i, :out_lens[i]] = 1

    lw = torch.tensor([float(n) for n in out_lens], device=dev, dtype=length.dtype)
    return xw, mw, lw


# ---------------------------------------------------------------------
# 5) Train  (ONLY MODIFIED PART: remove shift, add time-warp consistency)
# ---------------------------------------------------------------------
def train_one_epoch(model, loader, optimizer, config, device):
    """
    Run one optimization pass over ``loader`` and return per-batch-averaged stats.

    The total loss is
        rate MSE + lambda_recon * reconstruction + lambda_smooth * smoothness
        + lambda_phase_ent * phase entropy + lambda_effk * effective-K
        + lambda_warp * time-warp count consistency.

    Args:
        model: called as ``model(x, mask, tau=tau)``; must return
            (rate_hat, z, x_hat, aux) with aux["rep_rate_t"] and aux["phase_p"].
        loader: iterable of batch dicts with "data", "mask", "count", "length".
        optimizer: torch optimizer for ``model``.
        config: dict; "fs" is required, the loss weights, "tau" and the
            warp settings are optional.
        device: device the batch tensors are moved to.

    Returns:
        dict mapping each stat name to its mean over batches. If the loader
        yields no batches, every stat is 0.0.
    """
    model.train()
    stats = {k: 0.0 for k in [
        'loss', 'loss_rate', 'loss_recon', 'loss_smooth', 'loss_phase_ent', 'loss_effk', 'loss_warp',
        'mae_count'
    ]}

    fs = config["fs"]
    tau = config.get("tau", 1.0)

    lam_recon = config.get("lambda_recon", 1.0)
    lam_smooth = config.get("lambda_smooth", 0.05)
    lam_phase_ent = config.get("lambda_phase_ent", 0.01)
    lam_effk = config.get("lambda_effk", 0.005)

    # Time-warp consistency weight and warp-factor range.
    lam_warp = config.get("lambda_warp", 0.0)
    warp_min = config.get("warp_min", 0.7)
    warp_max = config.get("warp_max", 1.3)
    warp_detach_ref = bool(config.get("warp_detach_ref", True))

    for batch in loader:
        x = batch["data"].to(device)         # (B,C,T)
        mask = batch["mask"].to(device)      # (B,T)
        y_count = batch["count"].to(device)  # (B,)
        length = batch["length"].to(device)  # (B,)

        duration = torch.clamp(length / fs, min=1e-6)  # sec
        y_rate = y_count / duration                    # reps/s

        optimizer.zero_grad()

        rate_hat, z, x_hat, aux = model(x, mask, tau=tau)

        # (1) match the average repetition rate
        loss_rate = F.mse_loss(rate_hat, y_rate)

        # (2) reconstruction on valid steps
        loss_recon = masked_recon_mse(x_hat, x, mask)

        # (3) smoothness of the per-step repetition rate
        loss_smooth = temporal_smoothness(aux["rep_rate_t"], mask)

        # (4) per-step phase exclusivity (entropy)
        loss_phase_ent = phase_entropy_loss(aux["phase_p"], mask)

        # (5) overall effective-K usage
        loss_effk, _ = effK_usage_loss(aux["phase_p"], mask)

        # (6) time-warp consistency: pred_count = rate_hat * duration should
        #     not change when the input is resampled in time.
        if lam_warp > 0.0:
            xw, mw, lw = time_warp_batch(
                x=x, mask=mask, length=length,
                warp_min=warp_min, warp_max=warp_max, min_len=8
            )
            dur_w = torch.clamp(lw / fs, min=1e-6)

            rate_hat_w, _, _, _ = model(xw, mw, tau=tau)

            pred_count = rate_hat * duration
            pred_count_w = rate_hat_w * dur_w

            # With a detached reference, only the warped branch is pulled
            # toward the count predicted on the original input.
            if warp_detach_ref:
                pred_count_ref = pred_count.detach()
            else:
                pred_count_ref = pred_count

            loss_warp = F.l1_loss(pred_count_w, pred_count_ref)
        else:
            loss_warp = torch.zeros((), device=device)

        loss = (loss_rate
                + lam_recon * loss_recon
                + lam_smooth * loss_smooth
                + lam_phase_ent * loss_phase_ent
                + lam_effk * loss_effk
                + lam_warp * loss_warp)

        loss.backward()
        optimizer.step()

        # MAE on count: predicted rate * duration = predicted count
        count_hat = rate_hat * duration
        stats['loss'] += loss.item()
        stats['loss_rate'] += loss_rate.item()
        stats['loss_recon'] += loss_recon.item()
        stats['loss_smooth'] += loss_smooth.item()
        stats['loss_phase_ent'] += loss_phase_ent.item()
        stats['loss_effk'] += loss_effk.item()
        stats['loss_warp'] += float(loss_warp.item())
        stats['mae_count'] += torch.abs(count_hat - y_count).mean().item()

    # Guard against an empty loader, which previously raised ZeroDivisionError.
    n = max(len(loader), 1)
    return {k: v / n for k, v in stats.items()}


# ---------------------------------------------------------------------
# 6) Visualization helpers (subject-wise subplot)
# ---------------------------------------------------------------------
def _smooth_1d(y, sigma=2.0):
    """
    Gaussian-smooth a 1-D series (cast to float32) to tame noise for plotting.
    """
    samples = np.asarray(y, dtype=np.float32)
    smoothed = gaussian_filter1d(samples, sigma=sigma)
    return smoothed


def compute_phase_entropy_mean(phase_p_np, eps=1e-8):
    """
    Time-averaged entropy of a phase distribution.

    phase_p_np: (T,K) array of per-step probabilities.
    Returns a Python float.
    """
    probs = np.asarray(phase_p_np, dtype=np.float32)
    per_step = -np.sum(probs * np.log(probs + eps), axis=1)   # (T,)
    return float(np.mean(per_step))


def downsample_time_axis(arr, max_T=2000):
    """
    Subsample the first axis of ``arr`` to at most ``max_T`` entries for plotting.

    arr: (T, ...) array.
    Returns (arr_ds, idx), where idx holds the original time indices kept.
    Arrays with T <= max_T are returned unchanged with idx = arange(T).
    """
    n_steps = arr.shape[0]
    if n_steps > max_T:
        # Evenly spaced positions from the first to the last sample, truncated to ints.
        keep = np.linspace(0, n_steps - 1, max_T).astype(int)
        return arr[keep], keep
    return arr, np.arange(n_steps)


def plot_phase_heatmap_and_dominant(
    phase_p_np,
    fs,
    title="phase_p heatmap + dominant phase",
    max_T=2000
):
    """
    Show a phase-probability heatmap with the dominant-phase timeline underneath.

    phase_p_np: (T, K) array of per-step phase probabilities.
    fs: sampling rate, used to label the time axis in seconds.
    max_T: maximum number of time steps drawn (longer inputs are subsampled).
    """
    phase_p_np = np.asarray(phase_p_np, dtype=np.float32)
    assert phase_p_np.ndim == 2, f"phase_p_np must be (T,K), got {phase_p_np.shape}"

    # Subsample long sequences so the image stays readable.
    phase_ds, idx = downsample_time_axis(phase_p_np, max_T=max_T)  # (T',K)
    n_phase = phase_ds.shape[1]
    t_sec = idx / float(fs)
    t_first, t_last = t_sec[0], t_sec[-1]

    # Index of the most probable phase at every drawn step.
    dominant = np.argmax(phase_ds, axis=1)  # (T',)

    fig = plt.figure(figsize=(30, 10))
    grid = fig.add_gridspec(2, 1, height_ratios=[4, 1], hspace=0.25)

    # Top panel: probability heatmap, phases on the y axis.
    ax_heat = fig.add_subplot(grid[0, 0])
    image = ax_heat.imshow(
        phase_ds.T,                 # (K, T')
        aspect="auto",
        origin="lower",
        interpolation="nearest",
        extent=[t_first, t_last, 0, n_phase],
    )
    ax_heat.set_title(title, fontsize=24, pad=10)
    ax_heat.set_ylabel("Phase k", fontsize=18)
    ax_heat.set_xlabel("Time (sec)", fontsize=18)
    colorbar = fig.colorbar(image, ax=ax_heat, fraction=0.015, pad=0.01)
    colorbar.set_label("phase_p(t,k)", fontsize=14)

    # Bottom panel: one-row image of the dominant phase, sharing the time axis.
    ax_dom = fig.add_subplot(grid[1, 0], sharex=ax_heat)
    ax_dom.imshow(
        dominant[None, :],          # (1, T')
        aspect="auto",
        origin="lower",
        interpolation="nearest",
        extent=[t_first, t_last, 0, 1],
    )
    ax_dom.set_yticks([])
    ax_dom.set_ylabel("dominant", fontsize=14)
    ax_dom.set_xlabel("Time (sec)", fontsize=18)

    plt.tight_layout()
    plt.show()


def plot_folds_test_subplot(viz_cache, fs, title="Fold-wise TEST visualization (only test_subj)"):
    """
    Draw one stacked subplot per cached test result.

    viz_cache: list of dicts, each with keys
        'fold', 'test_subj', 't', 'rep_rate', 'gt', 'pred', 'diff', 'k_hat', 'entropy'.
    fs: sampling rate used to turn the per-step rate into a cumulative count.

    Each subplot shows the smoothed rate (left axis) and the cumulative count
    against the ground-truth count (right axis). Prints a notice and returns
    if viz_cache is None or empty.
    """
    if viz_cache is None or len(viz_cache) == 0:
        print("[plot_folds_test_subplot] viz_cache is empty")
        return

    sns.set_theme(style="whitegrid", context="notebook", font_scale=2.0)
    palette = sns.color_palette("muted")
    c_rate, c_count = palette[0], palette[1]

    n = len(viz_cache)
    fig, axes = plt.subplots(n, 1, figsize=(36, 9 * n), sharex=False)
    # plt.subplots returns a bare Axes for a single row; normalize to a list.
    axes_list = [axes] if n == 1 else list(axes)

    fig.suptitle(title, fontsize=40, y=0.995)

    for ax, item in zip(axes_list, viz_cache):
        t = item["t"]
        rep_rate = item["rep_rate"]
        gt_count = item["gt"]
        pred_count = item["pred"]
        diff = item["diff"]
        k_hat = item["k_hat"]
        entropy = item["entropy"]
        test_subj = item["test_subj"]
        fold = item["fold"]

        rate_smooth = _smooth_1d(rep_rate, sigma=2.0)
        # Integrate the unsmoothed rate over time to get a running count.
        running_count = np.cumsum(rep_rate) / fs

        # Left axis: repetition rate
        ax.plot(t, rate_smooth, color=c_rate, linewidth=2.5, alpha=0.9)
        ax.fill_between(t, rate_smooth, color=c_rate, alpha=0.15)
        ax.set_ylabel("Rep Rate (reps/s)", color=c_rate, fontweight='bold', fontsize=24)
        ax.grid(True, linestyle='--', alpha=0.5)
        ax.tick_params(axis='both', which='major', labelsize=20)

        # Right axis: cumulative count with the ground truth as a dotted line
        ax_count = ax.twinx()
        ax_count.plot(t, running_count, color=c_count, linewidth=3.5, alpha=1.0)
        ax_count.axhline(gt_count, linestyle=":", alpha=0.7)
        ax_count.set_ylabel("Count", color=c_count, fontweight='bold', fontsize=24)
        ax_count.tick_params(axis='y', labelcolor=c_count, labelsize=20)
        ax_count.grid(False)

        ax.set_title(
            f"Fold {fold:2d} | Test: {test_subj} | Pred {pred_count:.2f} / GT {gt_count:.0f} (Diff {diff:+.2f})\n"
            f"k_hat={k_hat:.2f} | phase_entropy={entropy:.3f}",
            fontsize=34, pad=10
        )
        ax.set_xlabel("Time (sec)", fontweight='bold', fontsize=24)

    plt.tight_layout(rect=[0, 0, 1, 0.985])
    plt.subplots_adjust(hspace=0.5)
    plt.show()


# ---------------------------------------------------------------------
# 7) Main (A1: Unseen Activity only, WITH windowing)
# ---------------------------------------------------------------------
def build_label_tuples_from_table(subjects, act_id, count_table):
    """
    Collect (subject, act_id, gt_count) tuples for one activity.

    count_table: dict like { "subject1": {6: 21, 7: 20}, ... }.
    Subjects missing from the table, or lacking an entry for ``act_id``,
    are skipped. Counts are returned as floats, in the order of ``subjects``.
    """
    return [
        (subj, act_id, float(count_table[subj][act_id]))
        for subj in subjects
        if subj in count_table and act_id in count_table[subj]
    ]


def main():
    """
    Run the A1 (unseen-activity) experiment end to end.

    Loads the mHealth logs, trains KAutoCountModel on every subject's
    TRAIN_ACT_ID trial (optionally cut into windows), then evaluates on every
    subject's TEST_ACT_ID trial with sliding-window inference and prints the
    per-trial error plus the mean/std of the absolute error.
    Returns early (after printing a message) if data or labels are missing.
    """
    CONFIG = {
        "seed": 42,
        "data_dir": "/content/drive/MyDrive/Colab Notebooks/HAR_data/MHEALTHDATASET",

        "COLUMN_NAMES": [
            'acc_chest_x', 'acc_chest_y', 'acc_chest_z',
            'ecg_1', 'ecg_2',
            'acc_ankle_x', 'acc_ankle_y', 'acc_ankle_z',
            'gyro_ankle_x', 'gyro_ankle_y', 'gyro_ankle_z',
            'mag_ankle_x', 'mag_ankle_y', 'mag_ankle_z',
            'acc_arm_x', 'acc_arm_y', 'acc_arm_z',
            'gyro_arm_x', 'gyro_arm_y', 'gyro_arm_z',
            'mag_arm_x', 'mag_arm_y', 'mag_arm_z',
            'activity_id'
        ],
        "TARGET_ACTIVITIES_MAP": {
            8: 'Knees bending',
            12: 'Jump front & back'
        },

        # Train and test must use the same number of channels C so that one
        # model can be evaluated on both activities.
        "ACT_FEATURE_MAP": {
             8: ['acc_chest_x', 'acc_chest_y', 'acc_chest_z',
                'acc_arm_x', 'acc_arm_y', 'acc_arm_z',
                'acc_ankle_x', 'acc_ankle_y', 'acc_ankle_z'],
             12: ['acc_chest_x', 'acc_chest_y', 'acc_chest_z',
                'acc_arm_x', 'acc_arm_y', 'acc_arm_z',
                'acc_ankle_x', 'acc_ankle_y', 'acc_ankle_z']
        },

        # Training Params
        "epochs": 100,
        "lr": 5e-4,
        "batch_size": 64,
        "fs": 50,

        # Model
        "hidden_dim": 128,
        "latent_dim": 16,
        "K_max": 6,

        # Loss Weights
        "lambda_recon": 1.0,
        "lambda_smooth": 0.05,
        "lambda_phase_ent": 0.01,
        "lambda_effk": 0.0005,

        # Time-warp consistency
        "lambda_warp": 0.2,      # original note recommends starting at 1e-2 to 5e-2
        "warp_min": 0.5,          # lower bound of the length factor (< 1 shortens the sequence)
        "warp_max": 1.5,          # upper bound of the length factor (> 1 lengthens the sequence)
        "warp_detach_ref": True,  # only the warped branch is pulled toward the original count (more stable)

        # temperature (strength of the competition between phases)
        "tau": 1.0,

        # A1 setting
        "TRAIN_ACT_ID": 12,
        "TEST_ACT_ID": 8,

        # -------------------------
        # Windowing
        # -------------------------
        "USE_WINDOWING": True,
        "WIN_SEC": 8.0,
        "STRIDE_SEC": 4.0,
        "DROP_LAST": True,

        # Must be a dict: subject -> {activity_id: ground-truth count}
        "COUNT_TABLE": {
            "subject1":  {12: 20, 8: 20},
            "subject2":  {12: 22, 8: 21},
            "subject3":  {12: 21, 8: 21},
            "subject4":  {12: 21, 8: 19},
            "subject5":  {12: 20, 8: 20},
            "subject6":  {12: 21, 8: 20},
            "subject7":  {12: 19, 8: 21},
            "subject8":  {12: 20, 8: 21},
            "subject9":  {12: 20, 8: 21},
            "subject10": {12: 20, 8: 21},
        },
    }

    set_strict_seed(CONFIG["seed"])
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")

    full_data = load_mhealth_dataset(
        CONFIG["data_dir"],
        CONFIG["TARGET_ACTIVITIES_MAP"],
        CONFIG["COLUMN_NAMES"]
    )
    if not full_data:
        return

    subjects = [f"subject{i}" for i in range(1, 11)]

    # -------------------------
    # A1 split (NO LOSO)
    # Train: all subjects x TRAIN_ACT_ID
    # Test : all subjects x TEST_ACT_ID (unseen activity)
    # -------------------------
    train_labels = build_label_tuples_from_table(subjects, CONFIG["TRAIN_ACT_ID"], CONFIG["COUNT_TABLE"])
    test_labels  = build_label_tuples_from_table(subjects, CONFIG["TEST_ACT_ID"],  CONFIG["COUNT_TABLE"])

    if len(train_labels) == 0:
        print("[Error] No train labels. Check COUNT_TABLE / TRAIN_ACT_ID.")
        return
    if len(test_labels) == 0:
        print("[Error] No test labels. Check COUNT_TABLE / TEST_ACT_ID.")
        return

    train_data = prepare_trial_list(train_labels, full_data, CONFIG["TARGET_ACTIVITIES_MAP"], CONFIG["ACT_FEATURE_MAP"])
    test_data  = prepare_trial_list(test_labels,  full_data, CONFIG["TARGET_ACTIVITIES_MAP"], CONFIG["ACT_FEATURE_MAP"])

    if len(train_data) == 0:
        print("[Error] train_data empty. (missing files or ACT_FEATURE_MAP mismatch)")
        return
    if len(test_data) == 0:
        print("[Error] test_data empty. (missing files or ACT_FEATURE_MAP mismatch)")
        return

    # -------------------------
    # Apply windowing (training set only)
    # -------------------------
    if CONFIG.get("USE_WINDOWING", False):
        train_windows = trial_list_to_windows(
            train_data,
            fs=CONFIG["fs"],
            win_sec=CONFIG["WIN_SEC"],
            stride_sec=CONFIG["STRIDE_SEC"],
            drop_last=CONFIG["DROP_LAST"],
        )
        print(f"[Windowing] train trials={len(train_data)} -> train windows={len(train_windows)}")
        train_data_for_loader = train_windows
    else:
        train_data_for_loader = train_data

    train_loader = DataLoader(
        TrialDataset(train_data_for_loader),
        batch_size=CONFIG["batch_size"],
        shuffle=True,
        collate_fn=collate_variable_length,
        num_workers=0
    )

    # Trial arrays are (T, C), so axis 1 is the channel count.
    input_ch = train_data[0]['data'].shape[1]
    model = KAutoCountModel(
        input_ch=input_ch,
        hidden_dim=CONFIG["hidden_dim"],
        latent_dim=CONFIG["latent_dim"],
        K_max=CONFIG["K_max"]
    ).to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=CONFIG["lr"])
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.5)

    print("\n" + "-"*80)
    print(f" >>> A1 Train: all subjects x act{CONFIG['TRAIN_ACT_ID']}  |  Test: all subjects x act{CONFIG['TEST_ACT_ID']} (unseen activity)")
    print("-"*80)

    # ---- Train ----
    # Per-epoch stats are discarded; only the final test metrics are reported.
    for epoch in range(CONFIG["epochs"]):
        _ = train_one_epoch(model, train_loader, optimizer, CONFIG, device)
        scheduler.step()

    # ---- Test (per-subject) ----
    model.eval()
    maes = []
    viz_cache = []

    for idx, item in enumerate(test_data):
        x_np = item["data"]  # (T,C)
        gt_count = float(item["count"])

        # Predicted count from sliding-window inference (this is the reported prediction)
        pred_count, win_rates = predict_count_by_windowing(
            model,
            x_np=x_np,
            fs=CONFIG["fs"],
            win_sec=CONFIG["WIN_SEC"],
            stride_sec=CONFIG["STRIDE_SEC"],
            device=device,
            tau=CONFIG["tau"],
            batch_size=CONFIG["batch_size"]
        )

        mae = abs(pred_count - gt_count)
        maes.append(mae)

        # For visualization only: one forward pass over the full, unwindowed trial
        with torch.no_grad():
            x_tensor = torch.tensor(x_np, dtype=torch.float32).transpose(0, 1).unsqueeze(0).to(device)
            rate_hat_full, _, _, aux = model(x_tensor, mask=None, tau=CONFIG["tau"])

            phase_p = aux["phase_p"].squeeze(0).detach().cpu().numpy()
            k_hat = float(aux["k_hat"].item())
            ent = compute_phase_entropy_mean(phase_p)
            rep_rate = aux["rep_rate_t"].squeeze(0).detach().cpu().numpy()
            T = rep_rate.shape[0]
            t = np.arange(T) / float(CONFIG["fs"])

        viz_cache.append({
            "fold": idx + 1,   # just a running index, not a real CV fold
            "test_subj": item["meta"],  # subject+act name
            "t": t,
            "rep_rate": rep_rate,
            "gt": gt_count,
            "pred": float(pred_count),   # windowed prediction
            "diff": float(pred_count - gt_count),
            "k_hat": k_hat,
            "entropy": ent,
            "phase_p": phase_p,
        })

        print(
            f"[Test {idx+1:02d}] {item['meta']} | Pred(win)={pred_count:.2f} / GT={gt_count:.2f} | "
            f"MAE={mae:.2f} | k_hat={k_hat:.2f} | ent={ent:.3f} | win_rate_mean={win_rates.mean():.3f}"
        )

    print("-"*80)
    print(f" >>> A1 Final MAE mean: {np.mean(maes):.3f}")
    print(f" >>> A1 Final MAE std : {np.std(maes):.3f}")
    print("-"*80)

    # Visualization (optional; uncomment to enable)
    # plot_folds_test_subplot(
    #     viz_cache,
    #     fs=CONFIG["fs"],
    #     title=f"A1 TEST visualization | Train act{CONFIG['TRAIN_ACT_ID']} -> Test act{CONFIG['TEST_ACT_ID']}"
    # )

    # for item in viz_cache:
    #     plot_phase_heatmap_and_dominant(
    #         item["phase_p"],
    #         fs=CONFIG["fs"],
    #         title=f"{item['test_subj']} | k_hat={item['k_hat']:.2f} | ent={item['entropy']:.3f}",
    #         max_T=2000
    #     )


# Run the experiment only when this code is executed directly, not when imported.
if __name__ == "__main__":
    main()


Device: cuda
Loading 10 subjects from /content/drive/MyDrive/Colab Notebooks/HAR_data/MHEALTHDATASET...
[Windowing] train trials=10 -> train windows=40

--------------------------------------------------------------------------------
 >>> A1 Train: all subjects x act12  |  Test: all subjects x act8 (unseen activity)
--------------------------------------------------------------------------------
[Test 01] subject1_Knees bending | Pred(win)=73.08 / GT=20.00 | MAE=53.08 | k_hat=2.63 | ent=0.908 | win_rate_mean=1.081
[Test 02] subject2_Knees bending | Pred(win)=60.46 / GT=21.00 | MAE=39.46 | k_hat=2.67 | ent=0.976 | win_rate_mean=0.881
[Test 03] subject3_Knees bending | Pred(win)=60.96 / GT=21.00 | MAE=39.96 | k_hat=2.89 | ent=0.989 | win_rate_mean=0.960
[Test 04] subject4_Knees bending | Pred(win)=56.52 / GT=19.00 | MAE=37.52 | k_hat=2.96 | ent=0.976 | win_rate_mean=0.905
[Test 05] subject5_Knees bending | Pred(win)=45.62 / GT=20.00 | MAE=25.62 | k_hat=3.33 | ent=1.004 | win_rate_mean=0.

In [9]:
# =========================
# Count-only K-auto (Multi-event) version  (NO manual Pair/lag/overlap/balance)
#
# 핵심 아이디어
# - Micro-event Rate 예측: 모델은 하나의 통합된 속도가 아니라, K_max개의 서로 다른 '작은 단위 동작'의 속도 흐름(r_k(t))을 예측
# - 샘플마다 "rep당 micro-event 개수" k_hat(>=1)을 스스로 추정
# - 우리가 주는 감독은 오직 rep count(=20) -> rep rate만 맞추게
#
# 1) rate head를 amp(t) * softmax(phase) 형태로 바꿔서 K-stream 간 "경쟁"이 생기게 함
# 2) k_hat을 따로 head로 예측하지 않고, phase 사용 분포로부터 effK(=effective K)로 정의 (자동 K)
# 3) phase sparsity/exclusivity를 위해 (a) phase entropy loss, (b) effK usage loss 추가
#
# ✅ Added (ONLY): Time-warp consistency loss
#   - 입력을 시간축으로 stretch/compress (time-warp)해도 "count"는 동일해야 한다는 제약
#   - 모델 출력은 rate이므로, pred_count = rate_hat * duration 을 이용해 count 일관성을 강제
# =========================

import os
import glob
import random
import numpy as np
import pandas as pd
import seaborn as sns

import matplotlib.pyplot as plt
from scipy.ndimage import gaussian_filter1d

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader


# ---------------------------------------------------------------------
# 1) Strict Seeding
# ---------------------------------------------------------------------
def set_strict_seed(seed: int):
    """
    Seed Python, NumPy and PyTorch RNGs and request deterministic cuDNN kernels.

    Also exports PYTHONHASHSEED with the same value.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)

    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    # Disable cuDNN autotuning so kernel selection does not vary between runs.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


# ---------------------------------------------------------------------
# 2) Data Loading
# ---------------------------------------------------------------------
def load_mhealth_dataset(data_dir, target_activities_map, column_names):
    """
    Load every ``mHealth_subject*.log`` file in ``data_dir`` and split it by activity.

    Args:
        data_dir: directory containing the tab-separated log files.
        target_activities_map: dict mapping activity label code -> activity name;
            only these activities are kept.
        column_names: names for the leading columns of each log; must include
            'activity_id'. Extra trailing columns in the file are dropped.

    Returns:
        dict mapping subject key -> {activity name: DataFrame without
        'activity_id'}. The subject key is "subject<N>" when the file stem
        contains digits, otherwise the file stem itself. Returns {} when no
        log file is found. Files that fail to load are reported and skipped.
    """
    file_list = sorted(glob.glob(os.path.join(data_dir, "mHealth_subject*.log")))

    if not file_list:
        print(f"[Warning] No mHealth logs found in {data_dir}")
        return {}

    print(f"Loading {len(file_list)} subjects from {data_dir}...")

    full_dataset = {}
    for file_path in file_list:
        file_name = os.path.basename(file_path)
        subj_part = file_name.split('.')[0]
        try:
            subj_id_num = int(''.join(filter(str.isdigit, subj_part)))
            subj_key = f"subject{subj_id_num}"
        except ValueError:
            # No parsable digits in the stem: fall back to the stem as the key.
            # (Was a bare ``except:``, which also swallowed KeyboardInterrupt.)
            subj_key = subj_part

        # Best effort per file: one unreadable log must not abort the whole load.
        try:
            df = pd.read_csv(file_path, sep="\t", header=None)
            df = df.iloc[:, :len(column_names)]
            df.columns = column_names

            subj_data = {}
            for label_code, activity_name in target_activities_map.items():
                activity_df = df[df['activity_id'] == label_code]
                if not activity_df.empty:
                    # drop() returns a new frame, so no explicit copy is needed.
                    subj_data[activity_name] = activity_df.drop(columns=['activity_id'])

            full_dataset[subj_key] = subj_data
        except Exception as e:
            print(f"Error loading {file_name}: {e}")

    return full_dataset


def prepare_trial_list(label_config, full_data, target_map, feature_map):
    """Build one z-scored trial per ``(subject, activity, count)`` label.

    Args:
        label_config: Iterable of ``(subject_key, activity_id, gt_count)``.
        full_data: ``{subject_key: {activity_name: DataFrame}}``.
        target_map: Mapping ``activity_id -> activity name``.
        feature_map: Mapping ``activity_id -> list of column names`` to use.

    Returns:
        List of dicts with keys ``'data'`` (float32 array of shape
        ``(T, len(features))``, standardised per column over the trial),
        ``'count'`` (float) and ``'meta'`` (``"<subject>_<activity name>"``).
        Labels whose data or feature list is missing are reported on stdout
        and skipped.
    """
    trial_list = []
    for subj, act_id, gt_count in label_config:
        act_name = target_map.get(act_id)
        feats = feature_map.get(act_id)

        if subj not in full_data or act_name not in full_data[subj]:
            print(f"[Skip] Missing data for {subj} - {act_name}")
            continue
        if feats is None:
            # Without this guard a missing feature_map entry would index the
            # DataFrame with None and abort the whole run.
            print(f"[Skip] No feature list for activity {act_id} ({subj} - {act_name})")
            continue

        raw_np = full_data[subj][act_name][feats].values.astype(np.float32)

        # Per-column z-score (mean 0, std 1); the epsilon guards constant columns.
        mean = raw_np.mean(axis=0)
        std = raw_np.std(axis=0) + 1e-6
        norm_np = (raw_np - mean) / std

        trial_list.append({
            'data': norm_np,
            'count': float(gt_count),
            'meta': f"{subj}_{act_name}"
        })

    return trial_list


# ---------------------------------------------------------------------
# 2.5) Windowing (added)
# ---------------------------------------------------------------------
def trial_list_to_windows(trial_list, fs, win_sec=4.0, stride_sec=2.0, drop_last=True):
    """Cut trials into fixed-length windows labelled with a proportional count.

    Each window's count label is derived from the trial-level average rate:
    ``rate_trial = count_total / total_duration`` and
    ``count_window = rate_trial * window_duration``.

    Args:
        trial_list: Output of ``prepare_trial_list`` (items with ``'data'`` of
            shape ``(T, C)``, ``'count'`` and ``'meta'``).
        fs: Sampling rate in Hz.
        win_sec: Window length in seconds.
        stride_sec: Hop between window starts in seconds.
        drop_last: When False, one extra (shorter) window covering the samples
            from ``last_start + stride`` to the end of the trial is appended.

    Returns:
        List of dicts with keys ``data``, ``count``, ``meta``, ``parent_meta``,
        ``parent_T``, ``win_start`` and ``win_end``. A trial shorter than one
        window is emitted unpadded as a single window.
    """
    win_len = int(round(win_sec * fs))
    stride = int(round(stride_sec * fs))
    assert win_len > 0 and stride > 0

    def _window(chunk, rate, name, n_total, lo, hi):
        # One window record; its count is the trial rate times its own duration.
        return {
            "data": chunk,
            "count": rate * ((hi - lo) / float(fs)),
            "meta": f"{name}__win[{lo}:{hi}]",
            "parent_meta": name,
            "parent_T": n_total,
            "win_start": lo,
            "win_end": hi,
        }

    windows = []
    for trial in trial_list:
        series = trial["data"]           # (T, C)
        n_total = series.shape[0]
        total_count = float(trial["count"])
        name = trial["meta"]

        # Average reps/s over the whole trial.
        rate = total_count / max(n_total / float(fs), 1e-6)

        if n_total < win_len:
            # Too short for a full window: keep the whole trial as one window.
            windows.append(_window(series, rate, name, n_total, 0, n_total))
            continue

        starts = range(0, n_total - win_len + 1, stride)
        windows.extend(
            _window(series[lo:lo + win_len], rate, name, n_total, lo, lo + win_len)
            for lo in starts
        )

        if drop_last:
            continue

        tail_start = starts[-1] + stride
        if tail_start < n_total:
            windows.append(
                _window(series[tail_start:n_total], rate, name, n_total, tail_start, n_total)
            )

    return windows


def predict_count_by_windowing(model, x_np, fs, win_sec, stride_sec, device, tau=1.0, batch_size=64):
    """Predict the repetition count of one trial from sliding-window rates.

    The trial is cut into windows of ``win_sec`` seconds with a hop of
    ``stride_sec`` seconds, the model predicts one rate per window, and the
    count is ``mean(window rates) * trial duration``. A trial no longer than
    one window is passed to the model whole.

    Args:
        model: Callable as ``model(x, mask=None, tau=tau)`` returning a 4-tuple
            whose first element is the per-sample rate of shape ``(B,)``.
        x_np: Normalised trial of shape ``(T, C)``.
        fs: Sampling rate in Hz.
        win_sec: Window length in seconds.
        stride_sec: Hop between windows in seconds.
        device: Device the input tensors are moved to.
        tau: Temperature forwarded to the model.
        batch_size: Number of windows per forward pass.

    Returns:
        ``(pred_count, window_rates)``: a float and a 1-D NumPy array of the
        per-window rates.

    Note:
        The model is switched to eval mode and left in that mode.
    """
    win_len = int(round(win_sec * fs))
    stride = int(round(stride_sec * fs))
    T = x_np.shape[0]
    total_dur = T / float(fs)

    # Enter eval mode before branching so the short-trial path is evaluated in
    # the same mode as the windowed path (previously only the latter was).
    model.eval()

    if T <= win_len:
        x_tensor = torch.tensor(x_np, dtype=torch.float32).transpose(0, 1).unsqueeze(0).to(device)  # (1,C,T)
        with torch.no_grad():
            rate_hat, _, _, _ = model(x_tensor, mask=None, tau=tau)
        rate = float(rate_hat.item())
        return float(rate * total_dur), np.array([rate], dtype=np.float32)

    starts = range(0, T - win_len + 1, stride)
    windows = np.stack([x_np[st:st + win_len] for st in starts], axis=0)  # (N, win_len, C)
    xw = torch.tensor(windows, dtype=torch.float32).permute(0, 2, 1).to(device)  # (N, C, win_len)

    rates = []
    with torch.no_grad():
        for i in range(0, xw.shape[0], batch_size):
            r_hat, _, _, _ = model(xw[i:i + batch_size], mask=None, tau=tau)  # (B,)
            rates.append(r_hat.detach().cpu().numpy())

    rates = np.concatenate(rates, axis=0)  # (N,)
    pred_count = float(rates.mean()) * total_dur
    return float(pred_count), rates


class TrialDataset(Dataset):
    """Dataset wrapper around a list of trial/window dicts.

    Each item is expected to provide ``'data'`` (array-like of shape
    ``(T, C)``), ``'count'`` and ``'meta'``.
    """

    def __init__(self, trial_list):
        self.trials = trial_list

    def __len__(self):
        return len(self.trials)

    def __getitem__(self, idx):
        """Return ``(data, count, meta)`` with data as a float32 ``(C, T)`` tensor."""
        record = self.trials[idx]
        signal = torch.tensor(record['data'], dtype=torch.float32)
        label = torch.tensor(record['count'], dtype=torch.float32)
        # Stored time-major (T, C); the model consumes channel-major (C, T).
        return signal.transpose(0, 1), label, record['meta']


def collate_variable_length(batch):
    """Collate ``(data, count, meta)`` samples of differing length into one batch.

    Every ``data`` tensor of shape ``(C, T)`` is right-padded with zeros to the
    longest ``T`` in the batch, and a 0/1 mask marks the valid time steps.

    Returns:
        Dict with ``data`` ``(B, C, T_max)``, ``mask`` ``(B, T_max)``,
        ``count`` ``(B,)``, ``length`` ``(B,)`` (float32 valid lengths) and
        ``meta`` (list).
    """
    lengths = [sample[0].shape[1] for sample in batch]
    max_len = max(lengths)
    C = batch[0][0].shape[0]

    padded_data, masks = [], []
    for (data, _count, _meta), T in zip(batch, lengths):
        pad_size = max_len - T
        if pad_size > 0:
            data = torch.cat([data, torch.zeros(C, pad_size)], dim=1)
        padded_data.append(data)
        # Ones over the valid prefix, zeros over the padding.
        masks.append(torch.cat([torch.ones(T), torch.zeros(pad_size)], dim=0))

    counts = [sample[1] for sample in batch]
    metas = [sample[2] for sample in batch]

    return {
        "data": torch.stack(padded_data),         # (B, C, T_max)
        "mask": torch.stack(masks),               # (B, T_max)
        "count": torch.stack(counts),             # (B,)
        "length": torch.tensor(lengths, dtype=torch.float32),  # (B,)
        "meta": metas
    }


# ---------------------------------------------------------------------
# 3) Model
# ---------------------------------------------------------------------
class ManifoldEncoder(nn.Module):
    """1-D convolutional encoder mapping ``(B, C, T)`` input to ``(B, T, D)`` latents.

    Two length-preserving kernel-5 convolutions with ReLU are followed by a
    kernel-1 projection to ``latent_dim`` channels.
    """

    def __init__(self, input_ch, hidden_dim=128, latent_dim=16):
        super().__init__()
        layers = [
            nn.Conv1d(input_ch, hidden_dim, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv1d(hidden_dim, hidden_dim, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv1d(hidden_dim, latent_dim, kernel_size=1),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Encode ``x`` of shape ``(B, C, T)`` and return time-major ``(B, T, D)``."""
        return self.net(x).transpose(1, 2)


class ManifoldDecoder(nn.Module):
    """1-D convolutional decoder mapping ``(B, T, D)`` latents back to ``(B, C, T)``.

    Two length-preserving kernel-5 convolutions with ReLU are followed by a
    kernel-1 projection to ``out_ch`` channels.
    """

    def __init__(self, latent_dim, hidden_dim, out_ch):
        super().__init__()
        layers = [
            nn.Conv1d(latent_dim, hidden_dim, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv1d(hidden_dim, hidden_dim, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv1d(hidden_dim, out_ch, kernel_size=1),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, z):
        """Decode time-major latents ``z`` ``(B, T, D)`` into ``(B, C, T)``."""
        channel_major = z.transpose(1, 2)   # (B, D, T)
        return self.net(channel_major)


class MultiRateHead(nn.Module):
    """Per-time-step head producing an amplitude and a K-way phase distribution.

    A two-layer MLP emits ``1 + K_max`` values per time step: the first is an
    amplitude logit, the remaining ``K_max`` are phase logits.
    """

    def __init__(self, latent_dim=16, hidden=64, K_max=6):
        super().__init__()
        self.K_max = K_max
        layers = [
            nn.Linear(latent_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 1 + K_max),  # [amp_logit | phase_logits...]
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, z, tau=1.0):
        """Return ``(amp, phase, phase_logits)`` for latents ``z`` of shape ``(B, T, D)``.

        ``amp`` is ``(B, T)`` and non-negative (softplus); ``phase`` is
        ``(B, T, K)`` and sums to 1 over K (softmax of logits divided by ``tau``);
        ``phase_logits`` are the raw ``(B, T, K)`` logits.
        """
        raw = self.net(z)                      # (B,T,1+K)
        amp_logit, phase_logits = raw[..., 0], raw[..., 1:]
        amp = F.softplus(amp_logit)
        phase = F.softmax(phase_logits / tau, dim=-1)
        return amp, phase, phase_logits


class KAutoCountModel(nn.Module):
    """Encoder/decoder with a multi-phase rate head that outputs a repetition rate.

    - The rate head yields an amplitude ``amp(t)`` and a phase distribution over
      ``K_max`` streams; the per-stream micro-event rates are ``amp(t) * phase``.
    - ``k_hat`` (effective number of phases, in ``[1, K]``) is derived from the
      time-averaged phase usage rather than predicted by a separate head.
    - ``rep_rate(t) = amp(t) / k_hat`` (sum of all micro-event rates divided by
      the micro-events per repetition).
    """

    def __init__(self, input_ch, hidden_dim=128, latent_dim=16, K_max=6, k_hidden=64):
        # NOTE: ``k_hidden`` is accepted but not used anywhere in this class.
        super().__init__()
        self.encoder = ManifoldEncoder(input_ch, hidden_dim, latent_dim)
        self.decoder = ManifoldDecoder(latent_dim, hidden_dim, input_ch)
        self.rate_head = MultiRateHead(latent_dim, hidden=hidden_dim, K_max=K_max)

        self._init_weights()

    def _init_weights(self):
        """Kaiming-initialise all conv/linear layers, then bias the amplitude logit."""
        for layer in self.modules():
            if not isinstance(layer, (nn.Conv1d, nn.Linear)):
                continue
            nn.init.kaiming_normal_(layer.weight, mode='fan_out', nonlinearity='relu')
            if layer.bias is not None:
                nn.init.constant_(layer.bias, 0)

        with torch.no_grad():
            head_bias = self.rate_head.net[-1].bias
            head_bias.zero_()
            head_bias[0].fill_(-2.0)  # only the amplitude logit starts at -2

    @staticmethod
    def _masked_mean_time(x, mask=None, eps=1e-6):
        """Average ``x`` (``(B,T)`` or ``(B,T,K)``) over time, honouring ``mask`` ``(B,T)``."""
        if mask is None:
            return x.mean(dim=1)
        if x.dim() not in (2, 3):
            raise ValueError(f"Unsupported dim for masked mean: {x.dim()}")

        weights = mask.to(dtype=x.dtype, device=x.device)    # (B,T)
        if x.dim() == 3:
            weights = weights.unsqueeze(-1)                  # (B,T,1)
        return (x * weights).sum(dim=1) / (weights.sum(dim=1) + eps)

    def forward(self, x, mask=None, tau=1.0):
        """Run the model.

        Args:
            x: Input of shape ``(B, C, T)``.
            mask: Optional ``(B, T)`` validity mask (1 = valid, 0 = padding).
            tau: Softmax temperature for the phase distribution.

        Returns:
            ``(avg_rep_rate, z, x_hat, aux)`` where ``avg_rep_rate`` is ``(B,)``,
            ``z`` the latents ``(B, T, D)``, ``x_hat`` the reconstruction
            ``(B, C, T)`` and ``aux`` a dict of intermediate tensors.
        """
        z = self.encoder(x)              # (B,T,D)
        x_hat = self.decoder(z)          # (B,C,T)

        amp_t, phase_p, phase_logits = self.rate_head(z, tau=tau)
        rates_k_t = amp_t.unsqueeze(-1) * phase_p

        # The phases sum to 1, so the total micro-event rate is the amplitude.
        micro_rate_t = amp_t

        # Effective number of phases from the time-averaged usage.
        usage = self._masked_mean_time(phase_p, mask)           # (B,K)
        k_hat = 1.0 / (usage.pow(2).sum(dim=1) + 1e-6)          # (B,) in [1,K]

        rep_rate_t = micro_rate_t / (k_hat.unsqueeze(1) + 1e-6)    # (B,T)

        if mask is None:
            avg_rep_rate = rep_rate_t.mean(dim=1)
        else:
            # Zero the padded steps, then take the masked time average.
            rep_rate_t = rep_rate_t * mask
            avg_rep_rate = (rep_rate_t * mask).sum(dim=1) / (mask.sum(dim=1) + 1e-6)

        aux = {
            "rates_k_t": rates_k_t,          # (B,T,K)
            "phase_p": phase_p,              # (B,T,K)
            "phase_logits": phase_logits,    # (B,T,K)
            "micro_rate_t": micro_rate_t,    # (B,T)
            "rep_rate_t": rep_rate_t,        # (B,T)
            "k_hat": k_hat,                  # (B,)
        }
        return avg_rep_rate, z, x_hat, aux


# ---------------------------------------------------------------------
# 4) Loss utils
# ---------------------------------------------------------------------
def masked_recon_mse(x_hat, x, mask, eps=1e-6):
    """Mean squared reconstruction error over the valid (unmasked) time steps.

    Args:
        x_hat: Reconstruction, shape ``(B, C, T)``.
        x: Target, shape ``(B, C, T)``.
        mask: Validity mask, shape ``(B, T)``.
        eps: Added to the denominator to avoid division by zero.

    Returns:
        Scalar tensor: summed squared error divided by
        ``(number of valid steps) * C``.
    """
    weights = mask.to(dtype=x.dtype, device=x.device)
    sq_err = (x_hat - x).pow(2) * weights.unsqueeze(1)   # mask broadcast over channels
    n_valid = weights.sum() * x.shape[1] + eps
    return sq_err.sum() / n_valid


def temporal_smoothness(v, mask=None, eps=1e-6):
    """L1 smoothness penalty on the first temporal difference of ``v``.

    Args:
        v: Tensor of shape ``(B, T)``.
        mask: Optional ``(B, T)`` validity mask; a difference counts only when
            both neighbouring steps are valid.
        eps: Added to the masked denominator to avoid division by zero.

    Returns:
        Scalar tensor with the mean absolute difference between consecutive
        valid steps. Sequences with fewer than two steps yield 0 on both the
        masked and the unmasked path.
    """
    dv = torch.abs(v[:, 1:] - v[:, :-1])  # (B,T-1)
    if mask is None:
        if dv.numel() == 0:
            # mean() of an empty tensor is NaN; an empty sum is a clean zero
            # that matches what the masked branch returns for T < 2.
            return dv.sum()
        return dv.mean()
    m = mask[:, 1:] * mask[:, :-1]
    m = m.to(dtype=dv.dtype, device=dv.device)
    return (dv * m).sum() / (m.sum() + eps)


def phase_entropy_loss(phase_p, mask=None, eps=1e-8):
    """Mean per-time-step entropy of the phase distribution.

    Minimising it pushes the phase at each time step towards one-hot
    (time-wise exclusivity).

    Args:
        phase_p: Phase probabilities, shape ``(B, T, K)``.
        mask: Optional ``(B, T)`` validity mask.
        eps: Stabiliser used inside the log and in the masked denominator.

    Returns:
        Scalar tensor: entropy averaged over all (valid) time steps.
    """
    log_p = torch.log(phase_p + eps)
    ent = -(phase_p * log_p).sum(dim=-1)  # (B,T)
    if mask is not None:
        return (ent * mask).sum() / (mask.sum() + eps)
    return ent.mean()


def effK_usage_loss(phase_p, mask=None, eps=1e-6):
    """Effective number of phases used over time (overall usage sparsity).

    ``effK = 1 / sum_k(p_bar_k ** 2)`` where ``p_bar`` is the time-averaged
    phase usage; it lies in ``[1, K]``, and minimising it favours few phases.

    Args:
        phase_p: Phase probabilities, shape ``(B, T, K)``.
        mask: Optional ``(B, T)`` validity mask for the time average.
        eps: Stabiliser for both denominators.

    Returns:
        ``(loss, effK)``: the batch-mean effK (differentiable) and the detached
        per-sample values of shape ``(B,)``.
    """
    if mask is not None:
        weights = mask.to(dtype=phase_p.dtype, device=phase_p.device).unsqueeze(-1)  # (B,T,1)
        usage = (phase_p * weights).sum(dim=1) / (weights.sum(dim=1) + eps)
    else:
        usage = phase_p.mean(dim=1)  # (B,K)

    eff_k = 1.0 / (usage.pow(2).sum(dim=1) + eps)
    return eff_k.mean(), eff_k.detach()


def time_warp_batch(x, mask, length, warp_min=0.7, warp_max=1.3, min_len=8):
    """Resample every sample of a padded batch along time by a random factor.

    For each sample a factor ``w`` is drawn uniformly from
    ``[warp_min, warp_max]`` and its valid prefix is linearly interpolated to
    ``max(round(T * w), min_len)`` steps, so ``w < 1`` shortens and ``w > 1``
    lengthens the sequence.

    Args:
        x: Padded batch, shape ``(B, C, Tmax)``.
        mask: ``(B, Tmax)`` mask; only its dtype is used for the output mask.
        length: ``(B,)`` valid lengths (float or int).
        warp_min: Lower bound of the resampling factor.
        warp_max: Upper bound of the resampling factor.
        min_len: Minimum warped length.

    Returns:
        ``(xw, mw, lw)``: warped batch ``(B, C, Tw_max)`` zero-padded on the
        right, its validity mask ``(B, Tw_max)`` and the warped lengths
        ``(B,)`` in ``length``'s dtype.
    """
    B, C, _ = x.shape
    device, dtype = x.device, x.dtype

    # Pass 1: draw one factor per sample (same draw order as a single loop)
    # and derive the source/target lengths.
    src_lens, dst_lens = [], []
    for b in range(B):
        n_src = max(int(length[b].item()), 1)
        factor = float(torch.empty(1, device=device).uniform_(warp_min, warp_max).item())
        src_lens.append(n_src)
        dst_lens.append(max(int(round(n_src * factor)), min_len))

    Tw_max = max([0] + dst_lens)

    # Pass 2: resample each valid prefix straight into the padded outputs.
    xw = torch.zeros((B, C, Tw_max), device=device, dtype=dtype)
    mw = torch.zeros((B, Tw_max), device=device, dtype=mask.dtype)
    for b, (n_src, n_dst) in enumerate(zip(src_lens, dst_lens)):
        warped = F.interpolate(
            x[b:b + 1, :, :n_src], size=n_dst, mode="linear", align_corners=False
        )  # (1,C,Tw)
        xw[b, :, :n_dst] = warped.squeeze(0)
        mw[b, :n_dst] = 1

    lw = torch.tensor([float(n) for n in dst_lens], device=device, dtype=length.dtype)
    return xw, mw, lw


# ---------------------------------------------------------------------
# 5) Train  (ONLY MODIFIED PART: remove shift, add time-warp consistency)
# ---------------------------------------------------------------------
def train_one_epoch(model, loader, optimizer, config, device):
    """Train ``model`` for one pass over ``loader`` and return averaged statistics.

    The total loss is the rate MSE plus weighted reconstruction, smoothness,
    phase-entropy, effective-K and (optional) time-warp consistency terms.

    Args:
        model: Model called as ``model(x, mask, tau=tau)`` returning
            ``(rate_hat, z, x_hat, aux)``.
        loader: Iterable of batches with keys ``data``, ``mask``, ``count`` and
            ``length`` that also supports ``len()``.
        optimizer: Optimizer stepping the model parameters.
        config: Dict providing ``fs`` and optionally ``tau``, ``lambda_recon``,
            ``lambda_smooth``, ``lambda_phase_ent``, ``lambda_effk``,
            ``lambda_warp``, ``warp_min``, ``warp_max`` and ``warp_detach_ref``.
        device: Device the batch tensors are moved to.

    Returns:
        Dict of per-batch averages for every loss term and ``mae_count``. All
        values are 0.0 when the loader yields no batch.
    """
    model.train()
    stats = {k: 0.0 for k in [
        'loss', 'loss_rate', 'loss_recon', 'loss_smooth', 'loss_phase_ent', 'loss_effk', 'loss_warp',
        'mae_count'
    ]}

    fs = config["fs"]
    tau = config.get("tau", 1.0)

    lam_recon = config.get("lambda_recon", 1.0)
    lam_smooth = config.get("lambda_smooth", 0.05)
    lam_phase_ent = config.get("lambda_phase_ent", 0.01)
    lam_effk = config.get("lambda_effk", 0.005)

    # Time-warp consistency weight and factor range.
    lam_warp = config.get("lambda_warp", 0.0)
    warp_min = config.get("warp_min", 0.7)
    warp_max = config.get("warp_max", 1.3)
    warp_detach_ref = bool(config.get("warp_detach_ref", True))

    for batch in loader:
        x = batch["data"].to(device)         # (B,C,T)
        mask = batch["mask"].to(device)      # (B,T)
        y_count = batch["count"].to(device)  # (B,)
        length = batch["length"].to(device)  # (B,)

        duration = torch.clamp(length / fs, min=1e-6)  # sec
        y_rate = y_count / duration                    # reps/s

        optimizer.zero_grad()

        rate_hat, z, x_hat, aux = model(x, mask, tau=tau)

        # (1) match the average repetition rate (MSE)
        loss_rate = F.mse_loss(rate_hat, y_rate)

        # (2) reconstruction
        loss_recon = masked_recon_mse(x_hat, x, mask)

        # (3) smoothness of the per-step repetition rate
        loss_smooth = temporal_smoothness(aux["rep_rate_t"], mask)

        # (4) phase exclusivity (entropy)
        loss_phase_ent = phase_entropy_loss(aux["phase_p"], mask)

        # (5) effective-K usage (overall)
        loss_effk, _ = effK_usage_loss(aux["phase_p"], mask)

        # (6) time-warp consistency at count level:
        #     pred_count = rate_hat * duration should not change under a time-warp
        if lam_warp > 0.0:
            xw, mw, lw = time_warp_batch(
                x=x, mask=mask, length=length,
                warp_min=warp_min, warp_max=warp_max, min_len=8
            )
            dur_w = torch.clamp(lw / fs, min=1e-6)

            rate_hat_w, _, _, _ = model(xw, mw, tau=tau)

            pred_count = rate_hat * duration
            pred_count_w = rate_hat_w * dur_w

            # With a detached reference only the warped branch receives
            # gradient from this term.
            pred_count_ref = pred_count.detach() if warp_detach_ref else pred_count

            loss_warp = F.l1_loss(pred_count_w, pred_count_ref)
        else:
            loss_warp = torch.zeros((), device=device)

        loss = (loss_rate
                + lam_recon * loss_recon
                + lam_smooth * loss_smooth
                + lam_phase_ent * loss_phase_ent
                + lam_effk * loss_effk
                + lam_warp * loss_warp)

        loss.backward()
        optimizer.step()

        # MAE on count: predicted rate * duration = predicted count
        count_hat = rate_hat * duration
        stats['loss'] += loss.item()
        stats['loss_rate'] += loss_rate.item()
        stats['loss_recon'] += loss_recon.item()
        stats['loss_smooth'] += loss_smooth.item()
        stats['loss_phase_ent'] += loss_phase_ent.item()
        stats['loss_effk'] += loss_effk.item()
        stats['loss_warp'] += float(loss_warp.item())
        stats['mae_count'] += torch.abs(count_hat - y_count).mean().item()

    # Guard against an empty loader, which would otherwise divide by zero.
    n = max(len(loader), 1)
    return {k: v / n for k, v in stats.items()}


# ---------------------------------------------------------------------
# 6) Visualization helpers (subject-wise subplot)
# ---------------------------------------------------------------------
def _smooth_1d(y, sigma=2.0):
    """Gaussian-smooth a 1-D signal so noisy curves plot more cleanly.

    The input is converted to a float32 array before filtering.
    """
    signal = np.asarray(y, dtype=np.float32)
    smoothed = gaussian_filter1d(signal, sigma=sigma)
    return smoothed


def compute_phase_entropy_mean(phase_p_np, eps=1e-8):
    """Return the time-averaged entropy of a ``(T, K)`` phase-probability array.

    Args:
        phase_p_np: Array-like of shape ``(T, K)``; converted to float32.
        eps: Added inside the log for numerical stability.

    Returns:
        Python float: mean over time of ``-sum_k p * log(p + eps)``.
    """
    probs = np.asarray(phase_p_np, dtype=np.float32)
    per_step = -np.sum(probs * np.log(probs + eps), axis=1)  # (T,)
    return float(np.mean(per_step))


def downsample_time_axis(arr, max_T=2000):
    """Subsample the first axis of ``arr`` to at most ``max_T`` entries.

    Very long sequences are hard to render, so ``max_T`` evenly spaced indices
    (first and last included) are picked when ``arr`` is longer than that.

    Args:
        arr: NumPy array of shape ``(T, ...)``.
        max_T: Maximum number of time steps to keep.

    Returns:
        ``(arr_ds, idx)``: the (possibly) subsampled array and the original
        time indices that were kept.
    """
    n_steps = arr.shape[0]
    if n_steps > max_T:
        keep = np.linspace(0, n_steps - 1, max_T).astype(int)
        return arr[keep], keep
    return arr, np.arange(n_steps)


def plot_phase_heatmap_and_dominant(
    phase_p_np,
    fs,
    title="phase_p heatmap + dominant phase",
    max_T=2000
):
    """
    Show a heatmap of the phase probabilities with the dominant phase below it.

    phase_p_np: (T, K) numpy array
    fs: sampling rate, used to convert sample indices to seconds
    title: title of the heatmap panel
    max_T: maximum number of time steps drawn (longer inputs are downsampled)

    Displays the figure with plt.show(); nothing is returned.
    """
    phase_p_np = np.asarray(phase_p_np, dtype=np.float32)
    assert phase_p_np.ndim == 2, f"phase_p_np must be (T,K), got {phase_p_np.shape}"

    # (1) downsample for visualization
    phase_ds, idx = downsample_time_axis(phase_p_np, max_T=max_T)  # (T',K)
    Tds, K = phase_ds.shape
    t_sec = idx / float(fs)

    # (2) dominant phase (argmax)
    dom = np.argmax(phase_ds, axis=1)  # (T',)

    # (3) plot: tall heatmap on top, thin dominant-phase strip below
    fig = plt.figure(figsize=(30, 10))
    gs = fig.add_gridspec(2, 1, height_ratios=[4, 1], hspace=0.25)

    # --- Heatmap (top) ---
    ax0 = fig.add_subplot(gs[0, 0])
    im = ax0.imshow(
        phase_ds.T,                 # (K, T')
        aspect="auto",
        origin="lower",
        interpolation="nearest",
        extent=[t_sec[0], t_sec[-1], 0, K]   # x=time, y=phase index range
    )
    ax0.set_title(title, fontsize=24, pad=10)
    ax0.set_ylabel("Phase k", fontsize=18)
    ax0.set_xlabel("Time (sec)", fontsize=18)
    cbar = fig.colorbar(im, ax=ax0, fraction=0.015, pad=0.01)
    cbar.set_label("phase_p(t,k)", fontsize=14)

    # --- Dominant phase timeline (bottom) ---
    ax1 = fig.add_subplot(gs[1, 0], sharex=ax0)
    ax1.imshow(
        dom[None, :],               # (1, T')
        aspect="auto",
        origin="lower",
        interpolation="nearest",
        extent=[t_sec[0], t_sec[-1], 0, 1]
    )
    ax1.set_yticks([])
    ax1.set_ylabel("dominant", fontsize=14)
    ax1.set_xlabel("Time (sec)", fontsize=18)

    plt.tight_layout()
    plt.show()


def plot_folds_test_subplot(viz_cache, fs, title="Fold-wise TEST visualization (only test_subj)"):
    """
    Draw one subplot per cached test result: smoothed rep rate plus cumulative count.

    viz_cache: list of dict
      each dict contains:
        - 'fold', 'test_subj', 't', 'rep_rate', 'gt', 'pred', 'diff', 'k_hat', 'entropy'
    fs: sampling rate; the cumulative count is cumsum(rep_rate) / fs
    title: figure-level title

    Prints a message and returns early when viz_cache is None or empty;
    otherwise displays the figure with plt.show(). Nothing is returned.
    """
    if viz_cache is None or len(viz_cache) == 0:
        print("[plot_folds_test_subplot] viz_cache is empty")
        return

    sns.set_theme(style="whitegrid", context="notebook", font_scale=2.0)
    colors = sns.color_palette("muted")
    c_rate = colors[0]
    c_count = colors[1]

    n = len(viz_cache)
    fig, axes = plt.subplots(n, 1, figsize=(36, 9 * n), sharex=False)
    # plt.subplots returns a bare Axes for n == 1; normalise to a flat array.
    if n == 1:
        axes = [axes]
    axes = np.array(axes).flatten()

    fig.suptitle(title, fontsize=40, y=0.995)

    for i, item in enumerate(viz_cache):
        ax = axes[i]

        t = item["t"]
        rep_rate = item["rep_rate"]
        gt_count = item["gt"]
        pred_count = item["pred"]
        diff = item["diff"]
        k_hat = item["k_hat"]
        entropy = item["entropy"]
        test_subj = item["test_subj"]
        fold = item["fold"]

        # Smoothed curve is for display only; the cumulative count integrates
        # the raw (unsmoothed) rate.
        rep_s = _smooth_1d(rep_rate, sigma=2.0)
        cum = np.cumsum(rep_rate) / fs

        # Left axis: rep_rate
        ax.plot(t, rep_s, color=c_rate, linewidth=2.5, alpha=0.9)
        ax.fill_between(t, rep_s, color=c_rate, alpha=0.15)
        ax.set_ylabel("Rep Rate (reps/s)", color=c_rate, fontweight='bold', fontsize=24)
        ax.grid(True, linestyle='--', alpha=0.5)
        ax.tick_params(axis='both', which='major', labelsize=20)

        # Right axis: cumulative count, with the ground-truth count as a dotted line
        ax2 = ax.twinx()
        ax2.plot(t, cum, color=c_count, linewidth=3.5, alpha=1.0)
        ax2.axhline(gt_count, linestyle=":", alpha=0.7)
        ax2.set_ylabel("Count", color=c_count, fontweight='bold', fontsize=24)
        ax2.tick_params(axis='y', labelcolor=c_count, labelsize=20)
        ax2.grid(False)

        ax.set_title(
            f"Fold {fold:2d} | Test: {test_subj} | Pred {pred_count:.2f} / GT {gt_count:.0f} (Diff {diff:+.2f})\n"
            f"k_hat={k_hat:.2f} | phase_entropy={entropy:.3f}",
            fontsize=34, pad=10
        )
        ax.set_xlabel("Time (sec)", fontweight='bold', fontsize=24)

    plt.tight_layout(rect=[0, 0, 1, 0.985])
    plt.subplots_adjust(hspace=0.5)
    plt.show()


# ---------------------------------------------------------------------
# 7) Main (A1: Unseen Activity only, WITH windowing)
# ---------------------------------------------------------------------
def build_label_tuples_from_table(subjects, act_id, count_table):
    """Collect ``(subject, act_id, gt_count)`` labels for one activity.

    Args:
        subjects: Subject keys to look up, in output order.
        act_id: Activity id whose counts are wanted.
        count_table: Dict like ``{"subject1": {6: 21, 7: 20}, ...}``.

    Returns:
        List of ``(subj, act_id, float(count))`` for every subject that has an
        entry for ``act_id``; subjects or activities missing from the table are
        silently skipped.
    """
    return [
        (subj, act_id, float(count_table[subj][act_id]))
        for subj in subjects
        if subj in count_table and act_id in count_table[subj]
    ]


def main():
    """Run the A1 (unseen-activity) experiment end to end.

    Loads the mHealth logs, trains KAutoCountModel on windows of TRAIN_ACT_ID
    from all subjects, then evaluates per subject on TEST_ACT_ID using
    sliding-window count prediction and prints the per-trial and aggregate MAE.
    Returns early (after printing a message) when data or labels are missing.
    """
    CONFIG = {
        "seed": 42,
        "data_dir": "/content/drive/MyDrive/Colab Notebooks/HAR_data/MHEALTHDATASET",

        "COLUMN_NAMES": [
            'acc_chest_x', 'acc_chest_y', 'acc_chest_z',
            'ecg_1', 'ecg_2',
            'acc_ankle_x', 'acc_ankle_y', 'acc_ankle_z',
            'gyro_ankle_x', 'gyro_ankle_y', 'gyro_ankle_z',
            'mag_ankle_x', 'mag_ankle_y', 'mag_ankle_z',
            'acc_arm_x', 'acc_arm_y', 'acc_arm_z',
            'gyro_arm_x', 'gyro_arm_y', 'gyro_arm_z',
            'mag_arm_x', 'mag_arm_y', 'mag_arm_z',
            'activity_id'
        ],
        "TARGET_ACTIVITIES_MAP": {
            6: 'Waist bends forward',
            12: 'Jump front & back'
        },

        # Train and test must use the same number of channels C so that one
        # model can be evaluated on both activities.
        "ACT_FEATURE_MAP": {
             6: ['acc_chest_x', 'acc_chest_y', 'acc_chest_z',
                'acc_arm_x', 'acc_arm_y', 'acc_arm_z',
                'acc_ankle_x', 'acc_ankle_y', 'acc_ankle_z'],
             12: ['acc_chest_x', 'acc_chest_y', 'acc_chest_z',
                'acc_arm_x', 'acc_arm_y', 'acc_arm_z',
                'acc_ankle_x', 'acc_ankle_y', 'acc_ankle_z']
        },

        # Training Params
        "epochs": 100,
        "lr": 5e-4,
        "batch_size": 64,
        "fs": 50,

        # Model
        "hidden_dim": 128,
        "latent_dim": 16,
        "K_max": 6,

        # Loss Weights
        "lambda_recon": 1.0,
        "lambda_smooth": 0.05,
        "lambda_phase_ent": 0.01,
        "lambda_effk": 0.0005,

        # Time-warp consistency (the only newly added loss)
        "lambda_warp": 0.2,      # a starting value of 1e-2 to 5e-2 is suggested
        "warp_min": 0.5,          # lower bound of the resampling factor (time_warp_batch: factor < 1 shortens the sequence)
        "warp_max": 1.5,          # upper bound of the resampling factor (factor > 1 lengthens the sequence)
        "warp_detach_ref": True,  # detach the reference so only the warped prediction is pulled toward the original count (more stable)

        # temperature (strength of the competition between phases)
        "tau": 1.0,

        # A1 setting
        "TRAIN_ACT_ID": 6,
        "TEST_ACT_ID": 12,

        # -------------------------
        # Windowing (added)
        # -------------------------
        "USE_WINDOWING": True,
        "WIN_SEC": 8.0,
        "STRIDE_SEC": 4.0,
        "DROP_LAST": True,

        # Must be a dict: subject key -> {activity_id: ground-truth count}
        "COUNT_TABLE": {
            "subject1":  {12: 20, 6: 21},
            "subject2":  {12: 22, 6: 19},
            "subject3":  {12: 21, 6: 21},
            "subject4":  {12: 21, 6: 20},
            "subject5":  {12: 20, 6: 20},
            "subject6":  {12: 21, 6: 20},
            "subject7":  {12: 19, 6: 20},
            "subject8":  {12: 20, 6: 21},
            "subject9":  {12: 20, 6: 21},
            "subject10": {12: 20, 6: 20},
        },
    }

    set_strict_seed(CONFIG["seed"])
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")

    full_data = load_mhealth_dataset(
        CONFIG["data_dir"],
        CONFIG["TARGET_ACTIVITIES_MAP"],
        CONFIG["COLUMN_NAMES"]
    )
    if not full_data:
        return

    subjects = [f"subject{i}" for i in range(1, 11)]

    # -------------------------
    # A1 split (NO LOSO)
    # Train: all subjects x TRAIN_ACT_ID
    # Test : all subjects x TEST_ACT_ID (unseen activity)
    # -------------------------
    train_labels = build_label_tuples_from_table(subjects, CONFIG["TRAIN_ACT_ID"], CONFIG["COUNT_TABLE"])
    test_labels  = build_label_tuples_from_table(subjects, CONFIG["TEST_ACT_ID"],  CONFIG["COUNT_TABLE"])

    if len(train_labels) == 0:
        print("[Error] No train labels. Check COUNT_TABLE / TRAIN_ACT_ID.")
        return
    if len(test_labels) == 0:
        print("[Error] No test labels. Check COUNT_TABLE / TEST_ACT_ID.")
        return

    train_data = prepare_trial_list(train_labels, full_data, CONFIG["TARGET_ACTIVITIES_MAP"], CONFIG["ACT_FEATURE_MAP"])
    test_data  = prepare_trial_list(test_labels,  full_data, CONFIG["TARGET_ACTIVITIES_MAP"], CONFIG["ACT_FEATURE_MAP"])

    if len(train_data) == 0:
        print("[Error] train_data empty. (missing files or ACT_FEATURE_MAP mismatch)")
        return
    if len(test_data) == 0:
        print("[Error] test_data empty. (missing files or ACT_FEATURE_MAP mismatch)")
        return

    # -------------------------
    # Apply windowing (training data only)
    # -------------------------
    if CONFIG.get("USE_WINDOWING", False):
        train_windows = trial_list_to_windows(
            train_data,
            fs=CONFIG["fs"],
            win_sec=CONFIG["WIN_SEC"],
            stride_sec=CONFIG["STRIDE_SEC"],
            drop_last=CONFIG["DROP_LAST"],
        )
        print(f"[Windowing] train trials={len(train_data)} -> train windows={len(train_windows)}")
        train_data_for_loader = train_windows
    else:
        train_data_for_loader = train_data

    train_loader = DataLoader(
        TrialDataset(train_data_for_loader),
        batch_size=CONFIG["batch_size"],
        shuffle=True,
        collate_fn=collate_variable_length,
        num_workers=0
    )

    # Channel count comes from the (T, C) layout of the prepared trials.
    input_ch = train_data[0]['data'].shape[1]
    model = KAutoCountModel(
        input_ch=input_ch,
        hidden_dim=CONFIG["hidden_dim"],
        latent_dim=CONFIG["latent_dim"],
        K_max=CONFIG["K_max"]
    ).to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=CONFIG["lr"])
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.5)

    print("\n" + "-"*80)
    print(f" >>> A1 Train: all subjects x act{CONFIG['TRAIN_ACT_ID']}  |  Test: all subjects x act{CONFIG['TEST_ACT_ID']} (unseen activity)")
    print("-"*80)

    # ---- Train ----
    # The per-epoch statistics are computed but intentionally discarded here.
    for epoch in range(CONFIG["epochs"]):
        _ = train_one_epoch(model, train_loader, optimizer, CONFIG, device)
        scheduler.step()

    # ---- Test (per-subject) ----
    model.eval()
    maes = []
    viz_cache = []

    for idx, item in enumerate(test_data):
        x_np = item["data"]  # (T,C)
        gt_count = float(item["count"])

        # Predicted count from sliding windows (test). Note this is used
        # regardless of USE_WINDOWING.
        pred_count, win_rates = predict_count_by_windowing(
            model,
            x_np=x_np,
            fs=CONFIG["fs"],
            win_sec=CONFIG["WIN_SEC"],
            stride_sec=CONFIG["STRIDE_SEC"],
            device=device,
            tau=CONFIG["tau"],
            batch_size=CONFIG["batch_size"]
        )

        mae = abs(pred_count - gt_count)
        maes.append(mae)

        # For visualization only: a single forward pass over the full trial.
        with torch.no_grad():
            x_tensor = torch.tensor(x_np, dtype=torch.float32).transpose(0, 1).unsqueeze(0).to(device)
            rate_hat_full, _, _, aux = model(x_tensor, mask=None, tau=CONFIG["tau"])

            phase_p = aux["phase_p"].squeeze(0).detach().cpu().numpy()
            k_hat = float(aux["k_hat"].item())
            ent = compute_phase_entropy_mean(phase_p)
            rep_rate = aux["rep_rate_t"].squeeze(0).detach().cpu().numpy()
            T = rep_rate.shape[0]
            t = np.arange(T) / float(CONFIG["fs"])

        viz_cache.append({
            "fold": idx + 1,   # just a running index
            "test_subj": item["meta"],  # subject+act name
            "t": t,
            "rep_rate": rep_rate,
            "gt": gt_count,
            "pred": float(pred_count),   # windowed prediction
            "diff": float(pred_count - gt_count),
            "k_hat": k_hat,
            "entropy": ent,
            "phase_p": phase_p,
        })

        print(
            f"[Test {idx+1:02d}] {item['meta']} | Pred(win)={pred_count:.2f} / GT={gt_count:.2f} | "
            f"MAE={mae:.2f} | k_hat={k_hat:.2f} | ent={ent:.3f} | win_rate_mean={win_rates.mean():.3f}"
        )

    print("-"*80)
    print(f" >>> A1 Final MAE mean: {np.mean(maes):.3f}")
    print(f" >>> A1 Final MAE std : {np.std(maes):.3f}")
    print("-"*80)

    # Visualization (left commented out; enable if wanted)
    # plot_folds_test_subplot(
    #     viz_cache,
    #     fs=CONFIG["fs"],
    #     title=f"A1 TEST visualization | Train act{CONFIG['TRAIN_ACT_ID']} -> Test act{CONFIG['TEST_ACT_ID']}"
    # )

    # for item in viz_cache:
    #     plot_phase_heatmap_and_dominant(
    #         item["phase_p"],
    #         fs=CONFIG["fs"],
    #         title=f"{item['test_subj']} | k_hat={item['k_hat']:.2f} | ent={item['entropy']:.3f}",
    #         max_T=2000
    #     )

# Run the experiment only when this code is executed directly, not on import.
if __name__ == "__main__":
    main()


Device: cuda
Loading 10 subjects from /content/drive/MyDrive/Colab Notebooks/HAR_data/MHEALTHDATASET...
[Windowing] train trials=10 -> train windows=127

--------------------------------------------------------------------------------
 >>> A1 Train: all subjects x act6  |  Test: all subjects x act12 (unseen activity)
--------------------------------------------------------------------------------
[Test 01] subject1_Jump front & back | Pred(win)=12.31 / GT=20.00 | MAE=7.69 | k_hat=2.11 | ent=0.830 | win_rate_mean=0.573
[Test 02] subject2_Jump front & back | Pred(win)=14.96 / GT=22.00 | MAE=7.04 | k_hat=1.79 | ent=0.720 | win_rate_mean=0.730
[Test 03] subject3_Jump front & back | Pred(win)=17.13 / GT=21.00 | MAE=3.87 | k_hat=1.68 | ent=0.665 | win_rate_mean=0.837
[Test 04] subject4_Jump front & back | Pred(win)=16.84 / GT=21.00 | MAE=4.16 | k_hat=1.64 | ent=0.647 | win_rate_mean=0.822
[Test 05] subject5_Jump front & back | Pred(win)=13.98 / GT=20.00 | MAE=6.02 | k_hat=1.90 | ent=0.782 | 

In [10]:
# =========================
# Count-only K-auto (Multi-event) version  (NO manual Pair/lag/overlap/balance)
#
# 핵심 아이디어
# - Micro-event Rate 예측: 모델은 하나의 통합된 속도가 아니라, K_max개의 서로 다른 '작은 단위 동작'의 속도 흐름(r_k(t))을 예측
# - 샘플마다 "rep당 micro-event 개수" k_hat(>=1)을 스스로 추정
# - 우리가 주는 감독은 오직 rep count(=20) -> rep rate만 맞추게
#
# 1) rate head를 amp(t) * softmax(phase) 형태로 바꿔서 K-stream 간 "경쟁"이 생기게 함
# 2) k_hat을 따로 head로 예측하지 않고, phase 사용 분포로부터 effK(=effective K)로 정의 (자동 K)
# 3) phase sparsity/exclusivity를 위해 (a) phase entropy loss, (b) effK usage loss 추가
#
# ✅ Added (ONLY): Time-warp consistency loss
#   - 입력을 시간축으로 stretch/compress (time-warp)해도 "count"는 동일해야 한다는 제약
#   - 모델 출력은 rate이므로, pred_count = rate_hat * duration 을 이용해 count 일관성을 강제
# =========================

import os
import glob
import random
import numpy as np
import pandas as pd
import seaborn as sns

import matplotlib.pyplot as plt
from scipy.ndimage import gaussian_filter1d

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader


# ---------------------------------------------------------------------
# 1) Strict Seeding
# ---------------------------------------------------------------------
def set_strict_seed(seed: int):
    """Seed every RNG used by this script and force deterministic cuDNN.

    Seeds Python's ``random``, NumPy and PyTorch (CPU, plus CUDA when it is
    available), exports ``PYTHONHASHSEED``, and configures cuDNN for
    deterministic execution with benchmarking (auto-tuning) disabled.

    Args:
        seed: Integer seed shared by all generators.
    """
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)

    for seed_fn in (np.random.seed, torch.manual_seed):
        seed_fn(seed)

    cuda_ready = torch.cuda.is_available()
    if cuda_ready:
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    # Trade speed for run-to-run reproducibility of cuDNN kernels.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


# ---------------------------------------------------------------------
# 2) Data Loading
# ---------------------------------------------------------------------
def load_mhealth_dataset(data_dir, target_activities_map, column_names):
    """Load mHealth subject logs and split each one into per-activity frames.

    Args:
        data_dir: Directory searched for ``mHealth_subject*.log`` files.
        target_activities_map: Mapping ``activity_id value -> activity name``;
            only rows whose ``activity_id`` matches a key are kept.
        column_names: Names assigned to the leading columns of every log
            (extra columns are dropped). Must contain ``'activity_id'``.

    Returns:
        ``{subject_key: {activity_name: DataFrame}}`` where each DataFrame
        holds the rows of that activity without the ``activity_id`` column.
        An empty dict is returned when no log file is found. Files that fail
        to load are reported on stdout and skipped (best effort).
    """
    file_list = sorted(glob.glob(os.path.join(data_dir, "mHealth_subject*.log")))

    if not file_list:
        print(f"[Warning] No mHealth logs found in {data_dir}")
        return {}

    print(f"Loading {len(file_list)} subjects from {data_dir}...")

    full_dataset = {}
    for file_path in file_list:
        file_name = os.path.basename(file_path)
        subj_part = file_name.split('.')[0]
        try:
            subj_id_num = int(''.join(filter(str.isdigit, subj_part)))
            subj_key = f"subject{subj_id_num}"
        except ValueError:
            # No digits in the file stem -> int('') fails; fall back to the stem.
            subj_key = subj_part

        try:
            df = pd.read_csv(file_path, sep="\t", header=None)
            df = df.iloc[:, :len(column_names)]
            df.columns = column_names

            subj_data = {}
            for label_code, activity_name in target_activities_map.items():
                activity_df = df[df['activity_id'] == label_code].copy()
                if not activity_df.empty:
                    subj_data[activity_name] = activity_df.drop(columns=['activity_id'])

            full_dataset[subj_key] = subj_data
        except Exception as e:
            # Deliberate best effort: one unreadable log must not abort the rest.
            print(f"Error loading {file_name}: {e}")

    return full_dataset


def prepare_trial_list(label_config, full_data, target_map, feature_map):
    """Build per-trial records with z-scored features and count labels.

    Args:
        label_config: Iterable of ``(subject, activity_id, gt_count)`` tuples.
        full_data: ``{subject: {activity_name: DataFrame}}``.
        target_map: ``activity_id -> activity_name``.
        feature_map: ``activity_id -> list of feature column names``.

    Returns:
        List of dicts with keys ``'data'`` (float32 array, one row per sample,
        standardized per column over the trial), ``'count'`` (float) and
        ``'meta'`` (``"<subject>_<activity_name>"``). Entries whose subject or
        activity is missing from ``full_data`` are reported and skipped.
    """
    trials = []
    for subj, act_id, gt_count in label_config:
        act_name = target_map.get(act_id)
        feats = feature_map.get(act_id)

        if subj not in full_data or act_name not in full_data[subj]:
            print(f"[Skip] Missing data for {subj} - {act_name}")
            continue

        raw_np = full_data[subj][act_name][feats].values.astype(np.float32)

        # Per-trial, per-channel standardization (mean 0, std 1); the small
        # constant keeps constant channels from dividing by zero.
        centered = raw_np - raw_np.mean(axis=0)
        scale = raw_np.std(axis=0) + 1e-6

        trials.append({
            'data': centered / scale,
            'count': float(gt_count),
            'meta': f"{subj}_{act_name}"
        })

    return trials


# ---------------------------------------------------------------------
# 2.5) Windowing (added)
# ---------------------------------------------------------------------
def trial_list_to_windows(trial_list, fs, win_sec=4.0, stride_sec=2.0, drop_last=True):
    """Slice each trial into fixed-length windows with proportional count labels.

    The window label is derived from the trial-level average rate:
        rate_trial   = count_total / total_duration
        count_window = rate_trial * window_duration

    Args:
        trial_list: Output of ``prepare_trial_list`` (items with ``'data'``
            of shape (T, C), ``'count'`` and ``'meta'``).
        fs: Sampling rate in Hz.
        win_sec: Window length in seconds.
        stride_sec: Hop between window starts in seconds.
        drop_last: If False, an extra shorter window covering the samples
            from the next stride position to the end of the trial is added.

    Returns:
        List of window dicts with keys ``data``, ``count``, ``meta``,
        ``parent_meta``, ``parent_T``, ``win_start`` and ``win_end``. A trial
        shorter than one window is emitted unpadded as a single window.
    """
    win_len = int(round(win_sec * fs))
    stride = int(round(stride_sec * fs))
    assert win_len > 0 and stride > 0

    fs_f = float(fs)
    windows = []

    for item in trial_list:
        x = item["data"]           # (T, C)
        meta = item["meta"]
        T = x.shape[0]

        # Average reps/s over the whole trial.
        rate_trial = float(item["count"]) / max(T / fs_f, 1e-6)

        too_short = T < win_len
        if too_short:
            # Keep the whole trial as one variable-length window (no padding).
            spans = [(0, T)]
        else:
            spans = [(st, st + win_len) for st in range(0, T - win_len + 1, stride)]
            if not drop_last:
                tail_start = spans[-1][0] + stride
                if tail_start < T:
                    spans.append((tail_start, T))

        for st, ed in spans:
            windows.append({
                "data": x if too_short else x[st:ed],
                "count": rate_trial * ((ed - st) / fs_f),
                "meta": f"{meta}__win[{st}:{ed}]",
                "parent_meta": meta,
                "parent_T": T,
                "win_start": st,
                "win_end": ed,
            })

    return windows


def predict_count_by_windowing(model, x_np, fs, win_sec, stride_sec, device, tau=1.0, batch_size=64):
    """Predict a trial's repetition count from sliding-window rate estimates.

    The model is put into eval mode, the trial is cut into windows of
    ``win_sec`` seconds every ``stride_sec`` seconds, the per-window rates
    are averaged and multiplied by the full trial duration.

    Args:
        model: Callable module returning ``(rate, _, _, _)`` with ``rate`` of
            shape (B,) for an input of shape (B, C, T).
        x_np: Normalized trial signal, array of shape (T, C).
        fs: Sampling rate in Hz.
        win_sec: Window length in seconds.
        stride_sec: Hop between windows in seconds.
        device: Torch device the windows are moved to.
        tau: Temperature forwarded to the model.
        batch_size: Number of windows per forward pass.

    Returns:
        ``(pred_count, window_rates)``: predicted count as float and the
        per-window rates as a 1-D numpy array. A trial no longer than one
        window is evaluated in a single forward pass.
    """
    win_len = int(round(win_sec * fs))
    stride = int(round(stride_sec * fs))
    T = x_np.shape[0]
    total_dur = T / float(fs)

    # Inference must not depend on the mode the caller left the model in;
    # previously only the multi-window path switched to eval mode.
    model.eval()

    if T <= win_len:
        x_tensor = torch.tensor(x_np, dtype=torch.float32).transpose(0, 1).unsqueeze(0).to(device)  # (1,C,T)
        with torch.no_grad():
            rate_hat, _, _, _ = model(x_tensor, mask=None, tau=tau)
        rate_value = float(rate_hat.item())
        return float(rate_value * total_dur), np.array([rate_value], dtype=np.float32)

    starts = range(0, T - win_len + 1, stride)
    windows = np.stack([x_np[st:st + win_len] for st in starts], axis=0)  # (N, win_len, C)

    xw = torch.tensor(windows, dtype=torch.float32).permute(0, 2, 1).to(device)  # (N, C, win_len)

    rates = []
    with torch.no_grad():
        for i in range(0, xw.shape[0], batch_size):
            r_hat, _, _, _ = model(xw[i:i + batch_size], mask=None, tau=tau)  # (B,)
            rates.append(r_hat.detach().cpu().numpy())

    rates = np.concatenate(rates, axis=0)  # (N,)
    # Mean window rate extrapolated over the whole trial (including any tail
    # samples that no window covers).
    pred_count = float(rates.mean()) * total_dur
    return float(pred_count), rates


class TrialDataset(Dataset):
    """Dataset over trial/window dicts with ``'data'``, ``'count'`` and ``'meta'``."""

    def __init__(self, trial_list):
        self.trials = trial_list

    def __len__(self):
        return len(self.trials)

    def __getitem__(self, idx):
        """Return ``(signal, count, meta)`` for the record at ``idx``.

        ``signal`` is the record's ``'data'`` as a float32 tensor with its
        first two axes swapped ((T, C) -> (C, T)); ``count`` is a float32
        scalar tensor.
        """
        record = self.trials[idx]
        signal = torch.transpose(
            torch.tensor(record['data'], dtype=torch.float32), 0, 1
        )  # (C, T)
        target = torch.tensor(record['count'], dtype=torch.float32)
        return signal, target, record['meta']


def collate_variable_length(batch):
    """Zero-pad variable-length samples to a common length and build masks.

    Args:
        batch: List of ``(data, count, meta)`` with ``data`` of shape (C, T_i).

    Returns:
        Dict with ``data`` (B, C, T_max), ``mask`` (B, T_max; 1 on valid
        steps, 0 on padding), ``count`` (B,), ``length`` (B,) float32 valid
        lengths, and ``meta`` (list of the samples' meta values).
    """
    lengths = [sample[0].shape[1] for sample in batch]
    max_len = max(lengths)
    C = batch[0][0].shape[0]

    padded_data, masks, counts, metas = [], [], [], []
    for (data, count, meta), T in zip(batch, lengths):
        gap = max_len - T

        # Validity mask: ones over the real samples, zeros over the padding.
        mask = torch.zeros(max_len)
        mask[:T] = 1.0

        if gap > 0:
            data = torch.cat([data, torch.zeros(C, gap)], dim=1)

        padded_data.append(data)
        masks.append(mask)
        counts.append(count)
        metas.append(meta)

    return {
        "data": torch.stack(padded_data),         # (B, C, T_max)
        "mask": torch.stack(masks),               # (B, T_max)
        "count": torch.stack(counts),             # (B,)
        "length": torch.tensor(lengths, dtype=torch.float32),  # (B,)
        "meta": metas
    }


# ---------------------------------------------------------------------
# 3) Model
# ---------------------------------------------------------------------
class ManifoldEncoder(nn.Module):
    """Per-timestep 1-D convolutional encoder: (B, C, T) -> (B, T, latent_dim).

    Two length-preserving kernel-5 convolutions with ReLU, followed by a
    1x1 projection to ``latent_dim`` channels.
    """

    def __init__(self, input_ch, hidden_dim=128, latent_dim=16):
        super().__init__()
        layers = [
            nn.Conv1d(input_ch, hidden_dim, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv1d(hidden_dim, hidden_dim, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv1d(hidden_dim, latent_dim, kernel_size=1),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Encode ``x`` of shape (B, C, T) and return latents of shape (B, T, D)."""
        latent_ct = self.net(x)              # (B, D, T)
        return latent_ct.transpose(1, 2)     # (B, T, D)


class ManifoldDecoder(nn.Module):
    """Convolutional decoder mirroring the encoder: (B, T, D) -> (B, out_ch, T)."""

    def __init__(self, latent_dim, hidden_dim, out_ch):
        super().__init__()
        layers = [
            nn.Conv1d(latent_dim, hidden_dim, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv1d(hidden_dim, hidden_dim, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv1d(hidden_dim, out_ch, kernel_size=1),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, z):
        """Decode latents ``z`` of shape (B, T, D) into a signal of shape (B, C, T)."""
        channels_first = z.transpose(1, 2)   # (B, D, T)
        return self.net(channels_first)      # (B, C, T)


class MultiRateHead(nn.Module):
    """MLP head producing a non-negative amplitude and a K-way phase distribution.

    The last linear layer emits ``1 + K_max`` values per timestep:
    ``[amp_logit | phase_logits...]``.
    """

    def __init__(self, latent_dim=16, hidden=64, K_max=6):
        super().__init__()
        self.K_max = K_max
        self.net = nn.Sequential(
            nn.Linear(latent_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 1 + K_max),
        )

    def forward(self, z, tau=1.0):
        """Map latents ``z`` (B, T, D) to ``(amp, phase, phase_logits)``.

        Returns:
            amp: (B, T), softplus of the first output (>= 0).
            phase: (B, T, K), softmax of ``phase_logits / tau`` over K.
            phase_logits: (B, T, K), raw logits.
        """
        raw = self.net(z)                     # (B, T, 1+K)
        amp_logit = raw[..., 0]
        phase_logits = raw[..., 1:]           # (B, T, K)

        amp = F.softplus(amp_logit)           # total micro-event intensity
        phase = F.softmax(phase_logits / tau, dim=-1)
        return amp, phase, phase_logits


class KAutoCountModel(nn.Module):
    """Count model with K_max competing micro-event streams and an automatic K.

    - The rate head yields an amplitude ``amp(t)`` and a phase distribution
      over K_max streams; the per-stream rates are ``amp(t) * phase(t, k)``.
    - ``k_hat`` is not a separate head: it is the effective number of phases
      ``1 / sum_k p_bar_k^2`` of the time-averaged phase usage ``p_bar``.
    - The repetition rate is ``amp(t) / k_hat``.

    ``k_hidden`` is accepted by the constructor but not used.
    """

    def __init__(self, input_ch, hidden_dim=128, latent_dim=16, K_max=6, k_hidden=64):
        super().__init__()
        self.encoder = ManifoldEncoder(input_ch, hidden_dim, latent_dim)
        self.decoder = ManifoldDecoder(latent_dim, hidden_dim, input_ch)
        self.rate_head = MultiRateHead(latent_dim, hidden=hidden_dim, K_max=K_max)

        self._init_weights()

    def _init_weights(self):
        """Kaiming-init every conv/linear layer, then bias the amplitude logit."""
        for module in self.modules():
            if not isinstance(module, (nn.Conv1d, nn.Linear)):
                continue
            nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)

        with torch.no_grad():
            head_bias = self.rate_head.net[-1].bias
            head_bias.zero_()
            # Only the amplitude logit starts negative, so softplus(amp) begins small.
            head_bias[0].fill_(-2.0)

    @staticmethod
    def _masked_mean_time(x, mask=None, eps=1e-6):
        """Average ``x`` over the time axis (dim 1), optionally weighted by ``mask``.

        Args:
            x: Tensor of shape (B, T) or (B, T, K).
            mask: Optional (B, T) validity mask; ``None`` means a plain mean.
            eps: Added to the mask sum to avoid division by zero.

        Raises:
            ValueError: If a mask is given and ``x`` is neither 2-D nor 3-D.
        """
        if mask is None:
            return x.mean(dim=1)

        if x.dim() not in (2, 3):
            raise ValueError(f"Unsupported dim for masked mean: {x.dim()}")

        weights = mask.to(dtype=x.dtype, device=x.device)      # (B, T)
        if x.dim() == 3:
            weights = weights.unsqueeze(-1)                    # (B, T, 1)
        return (x * weights).sum(dim=1) / (weights.sum(dim=1) + eps)

    def forward(self, x, mask=None, tau=1.0):
        """Run the model on ``x`` (B, C, T) with an optional ``mask`` (B, T).

        Returns:
            avg_rep_rate: (B,) time-averaged repetition rate.
            z: (B, T, D) latents.
            x_hat: (B, C, T) reconstruction.
            aux: dict with ``rates_k_t``, ``phase_p``, ``phase_logits``,
                ``micro_rate_t``, ``rep_rate_t`` and ``k_hat``.
        """
        z = self.encoder(x)              # (B, T, D)
        x_hat = self.decoder(z)          # (B, C, T)

        amp_t, phase_p, phase_logits = self.rate_head(z, tau=tau)
        rates_k_t = amp_t.unsqueeze(-1) * phase_p

        # The phases sum to one, so the sum over streams equals the amplitude.
        micro_rate_t = amp_t

        # Effective number of phases used over time, in [1, K].
        usage = self._masked_mean_time(phase_p, mask)           # (B, K)
        k_hat = 1.0 / (usage.pow(2).sum(dim=1) + 1e-6)          # (B,)

        rep_rate_t = micro_rate_t / (k_hat.unsqueeze(1) + 1e-6)    # (B, T)

        if mask is None:
            avg_rep_rate = rep_rate_t.mean(dim=1)
        else:
            rep_rate_t = rep_rate_t * mask
            avg_rep_rate = (rep_rate_t * mask).sum(dim=1) / (mask.sum(dim=1) + 1e-6)

        aux = {
            "rates_k_t": rates_k_t,          # (B, T, K)
            "phase_p": phase_p,              # (B, T, K)
            "phase_logits": phase_logits,    # (B, T, K)
            "micro_rate_t": micro_rate_t,    # (B, T)
            "rep_rate_t": rep_rate_t,        # (B, T)
            "k_hat": k_hat,                  # (B,)
        }
        return avg_rep_rate, z, x_hat, aux


# ---------------------------------------------------------------------
# 4) Loss utils
# ---------------------------------------------------------------------
def masked_recon_mse(x_hat, x, mask, eps=1e-6):
    """Mean squared reconstruction error over valid (unmasked) timesteps.

    Args:
        x_hat: Reconstruction, (B, C, T).
        x: Target signal, (B, C, T).
        mask: (B, T) validity mask, broadcast across channels.
        eps: Added to the normalizer to avoid division by zero.

    Returns:
        Scalar tensor: masked squared error summed and divided by
        ``mask.sum() * C + eps``.
    """
    weights = mask.to(dtype=x.dtype, device=x.device)
    sq_err = (x_hat - x).pow(2) * weights.unsqueeze(1)   # (B, C, T)
    n_valid = weights.sum() * x.shape[1] + eps           # valid steps * channels
    return sq_err.sum() / n_valid


def temporal_smoothness(v, mask=None, eps=1e-6):
    """L1 penalty on the first temporal difference of ``v``.

    Args:
        v: Tensor of shape (B, T).
        mask: Optional (B, T) validity mask; a difference only counts when
            both neighbouring steps are valid.
        eps: Added to the normalizer in the masked case.

    Returns:
        Scalar tensor with the mean absolute first difference. When there is
        no difference to take (T < 2, or an empty batch) the unmasked path
        returns 0 instead of the NaN that a mean over an empty tensor gives,
        matching the masked path.
    """
    dv = torch.abs(v[:, 1:] - v[:, :-1])  # (B, T-1)
    if mask is None:
        if dv.numel() == 0:
            # sum() of an empty tensor is 0 and keeps dtype/device/graph.
            return dv.sum()
        return dv.mean()
    m = mask[:, 1:] * mask[:, :-1]
    m = m.to(dtype=dv.dtype, device=dv.device)
    return (dv * m).sum() / (m.sum() + eps)


def phase_entropy_loss(phase_p, mask=None, eps=1e-8):
    """Mean per-timestep entropy of the phase distribution.

    Minimizing it pushes each timestep's phase vector towards one-hot
    (time-wise exclusivity).

    Args:
        phase_p: (B, T, K) probabilities over K phases.
        mask: Optional (B, T) validity mask.
        eps: Stabilizer inside the log and in the masked normalizer.

    Returns:
        Scalar tensor: entropy averaged over all steps, or over valid steps
        when ``mask`` is given.
    """
    step_entropy = -(phase_p * torch.log(phase_p + eps)).sum(dim=-1)  # (B, T)
    if mask is not None:
        return (step_entropy * mask).sum() / (mask.sum() + eps)
    return step_entropy.mean()


def effK_usage_loss(phase_p, mask=None, eps=1e-6):
    """Effective number of phases used over time (overall usage sparsity).

    ``effK = 1 / sum_k p_bar_k^2`` where ``p_bar`` is the (masked) time
    average of ``phase_p``; it lies in [1, K] up to ``eps``.

    Args:
        phase_p: (B, T, K) phase probabilities.
        mask: Optional (B, T) validity mask.
        eps: Stabilizer for both divisions.

    Returns:
        ``(effK.mean(), effK.detach())``: the scalar loss and the detached
        per-sample values of shape (B,).
    """
    if mask is not None:
        weights = mask.to(dtype=phase_p.dtype, device=phase_p.device).unsqueeze(-1)  # (B, T, 1)
        usage = (phase_p * weights).sum(dim=1) / (weights.sum(dim=1) + eps)
    else:
        usage = phase_p.mean(dim=1)  # (B, K)

    concentration = usage.pow(2).sum(dim=1) + eps
    effK = 1.0 / concentration
    return effK.mean(), effK.detach()


def time_warp_batch(x, mask, length, warp_min=0.7, warp_max=1.3, min_len=8):
    """Resample each sample's valid segment to a randomly scaled length.

    For every sample a factor ``w ~ U(warp_min, warp_max)`` is drawn and the
    first ``length[b]`` steps are linearly interpolated to
    ``max(round(length[b] * w), min_len)`` steps (w < 1 shortens, w > 1
    lengthens the sequence). Results are zero-padded to a common length.

    Args:
        x: (B, C, Tmax) padded batch.
        mask: (B, Tmax) mask; only its dtype is used for the output mask.
        length: (B,) valid lengths (float or int tensor).
        warp_min: Lower bound of the warp factor.
        warp_max: Upper bound of the warp factor.
        min_len: Minimum warped length.

    Returns:
        xw: (B, C, Tw_max) padded warped batch.
        mw: (B, Tw_max) mask of the warped batch.
        lw: (B,) warped valid lengths, same dtype as ``length``.
    """
    B, C, _ = x.shape
    device, dtype = x.device, x.dtype

    # Warp every sample independently (per-sample tempo change).
    warped = []
    for b in range(B):
        Tb = max(int(length[b].item()), 1)

        w = float(torch.empty(1, device=device).uniform_(warp_min, warp_max).item())
        Tw = max(int(round(Tb * w)), min_len)

        segment = x[b:b+1, :, :Tb]  # (1, C, Tb), padding stripped
        resampled = F.interpolate(segment, size=Tw, mode="linear", align_corners=False)
        warped.append(resampled.squeeze(0))  # (C, Tw)

    warped_lens = [seg.shape[1] for seg in warped]
    Tw_max = max(warped_lens, default=0)

    # Re-pad to the longest warped sample.
    xw = torch.zeros((B, C, Tw_max), device=device, dtype=dtype)
    mw = torch.zeros((B, Tw_max), device=device, dtype=mask.dtype)
    for b, seg in enumerate(warped):
        Tw = warped_lens[b]
        xw[b, :, :Tw] = seg
        mw[b, :Tw] = torch.ones((Tw,), device=device, dtype=mask.dtype)

    lw = torch.tensor([float(n) for n in warped_lens], device=device, dtype=length.dtype)
    return xw, mw, lw


# ---------------------------------------------------------------------
# 5) Train  (ONLY MODIFIED PART: remove shift, add time-warp consistency)
# ---------------------------------------------------------------------
def train_one_epoch(model, loader, optimizer, config, device):
    """Train ``model`` for one pass over ``loader`` and return averaged stats.

    The total loss is
        rate MSE + lambda_recon * recon + lambda_smooth * smoothness
        + lambda_phase_ent * phase entropy + lambda_effk * effective-K
        + lambda_warp * time-warp count consistency.

    Args:
        model: Module returning ``(rate_hat, z, x_hat, aux)``.
        loader: Iterable of batches with keys ``data``, ``mask``, ``count``
            and ``length`` (as produced by ``collate_variable_length``).
        optimizer: Optimizer stepping ``model``'s parameters.
        config: Dict with ``fs`` and optional ``tau``, ``lambda_*``,
            ``warp_min``, ``warp_max`` and ``warp_detach_ref`` entries.
        device: Device the batch tensors are moved to.

    Returns:
        Dict of per-batch averages for every loss term and ``mae_count``.
        An empty loader yields all zeros instead of a ZeroDivisionError.
    """
    model.train()
    stats = {k: 0.0 for k in [
        'loss', 'loss_rate', 'loss_recon', 'loss_smooth', 'loss_phase_ent', 'loss_effk', 'loss_warp',
        'mae_count'
    ]}

    fs = config["fs"]
    tau = config.get("tau", 1.0)

    lam_recon = config.get("lambda_recon", 1.0)
    lam_smooth = config.get("lambda_smooth", 0.05)
    lam_phase_ent = config.get("lambda_phase_ent", 0.01)
    lam_effk = config.get("lambda_effk", 0.005)

    # Time-warp consistency: weight, warp-factor range, and whether the
    # un-warped prediction is treated as a fixed target.
    lam_warp = config.get("lambda_warp", 0.0)
    warp_min = config.get("warp_min", 0.7)
    warp_max = config.get("warp_max", 1.3)
    warp_detach_ref = bool(config.get("warp_detach_ref", True))

    for batch in loader:
        x = batch["data"].to(device)         # (B,C,T)
        mask = batch["mask"].to(device)      # (B,T)
        y_count = batch["count"].to(device)  # (B,)
        length = batch["length"].to(device)  # (B,)

        duration = torch.clamp(length / fs, min=1e-6)  # sec
        y_rate = y_count / duration                    # reps/s

        optimizer.zero_grad()

        rate_hat, z, x_hat, aux = model(x, mask, tau=tau)

        # (1) match the average repetition rate (MSE)
        loss_rate = F.mse_loss(rate_hat, y_rate)

        # (2) reconstruction
        loss_recon = masked_recon_mse(x_hat, x, mask)

        # (3) smoothness of the per-timestep repetition rate
        loss_smooth = temporal_smoothness(aux["rep_rate_t"], mask)

        # (4) phase exclusivity (entropy)
        loss_phase_ent = phase_entropy_loss(aux["phase_p"], mask)

        # (5) effective-K usage (overall)
        loss_effk, _ = effK_usage_loss(aux["phase_p"], mask)

        # (6) time-warp consistency at count level:
        #     pred_count = rate_hat * duration must not change under time-warp
        if lam_warp > 0.0:
            xw, mw, lw = time_warp_batch(
                x=x, mask=mask, length=length,
                warp_min=warp_min, warp_max=warp_max, min_len=8
            )
            dur_w = torch.clamp(lw / fs, min=1e-6)

            rate_hat_w, _, _, _ = model(xw, mw, tau=tau)

            pred_count = rate_hat * duration
            pred_count_w = rate_hat_w * dur_w

            # With detach, only the warped branch is pulled towards the
            # original prediction.
            if warp_detach_ref:
                pred_count_ref = pred_count.detach()
            else:
                pred_count_ref = pred_count

            loss_warp = F.l1_loss(pred_count_w, pred_count_ref)
        else:
            loss_warp = torch.zeros((), device=device)

        loss = (loss_rate
                + lam_recon * loss_recon
                + lam_smooth * loss_smooth
                + lam_phase_ent * loss_phase_ent
                + lam_effk * loss_effk
                + lam_warp * loss_warp)

        loss.backward()
        optimizer.step()

        # MAE on count: predicted rate * duration = predicted count
        count_hat = rate_hat * duration
        stats['loss'] += loss.item()
        stats['loss_rate'] += loss_rate.item()
        stats['loss_recon'] += loss_recon.item()
        stats['loss_smooth'] += loss_smooth.item()
        stats['loss_phase_ent'] += loss_phase_ent.item()
        stats['loss_effk'] += loss_effk.item()
        stats['loss_warp'] += float(loss_warp.item())
        stats['mae_count'] += torch.abs(count_hat - y_count).mean().item()

    # Guard against an empty loader, which previously divided by zero.
    n = max(len(loader), 1)
    return {k: v / n for k, v in stats.items()}


# ---------------------------------------------------------------------
# 6) Visualization helpers (subject-wise subplot)
# ---------------------------------------------------------------------
def _smooth_1d(y, sigma=2.0):
    """
    지저분한 노이즈를 다듬어서 시각화
    """
    y = np.asarray(y, dtype=np.float32)
    return gaussian_filter1d(y, sigma=sigma)


def compute_phase_entropy_mean(phase_p_np, eps=1e-8):
    """Return the time-averaged entropy of a (T, K) phase-probability array.

    Args:
        phase_p_np: Array-like of shape (T, K); each row is a distribution.
        eps: Added inside the log to avoid ``log(0)``.

    Returns:
        Python float: mean over T of ``-sum_k p * log(p + eps)``.
    """
    probs = np.asarray(phase_p_np, dtype=np.float32)
    per_step = -np.sum(probs * np.log(probs + eps), axis=1)  # (T,)
    return float(per_step.mean())


def downsample_time_axis(arr, max_T=2000):
    """Thin the first axis of ``arr`` to at most ``max_T`` evenly spaced samples.

    Very long sequences make the plots unreadable, so they are subsampled.

    Args:
        arr: Numpy array of shape (T, ...).
        max_T: Maximum number of time steps to keep.

    Returns:
        ``(arr_ds, idx)``: the (possibly subsampled) array and the original
        time indices that were kept. ``arr`` is returned unchanged when
        ``T <= max_T``.
    """
    T = arr.shape[0]
    if T > max_T:
        keep = np.linspace(0, T - 1, max_T).astype(int)
        return arr[keep], keep
    return arr, np.arange(T)


def plot_phase_heatmap_and_dominant(
    phase_p_np,
    fs,
    title="phase_p heatmap + dominant phase",
    max_T=2000
):
    """Plot a phase-probability heatmap with the dominant phase underneath.

    Args:
        phase_p_np: Array-like of shape (T, K) with phase probabilities.
        fs: Sampling rate in Hz, used to convert indices to seconds.
        title: Title of the heatmap panel.
        max_T: Maximum number of time steps drawn (longer inputs are
            subsampled before plotting).
    """
    phase_p_np = np.asarray(phase_p_np, dtype=np.float32)
    assert phase_p_np.ndim == 2, f"phase_p_np must be (T,K), got {phase_p_np.shape}"

    # Subsample for display and convert the kept indices to seconds.
    phase_ds, idx = downsample_time_axis(phase_p_np, max_T=max_T)  # (T', K)
    K = phase_ds.shape[1]
    t_sec = idx / float(fs)
    t_first, t_last = t_sec[0], t_sec[-1]

    # Index of the most probable phase at every displayed step.
    dom = np.argmax(phase_ds, axis=1)  # (T',)

    image_kw = dict(aspect="auto", origin="lower", interpolation="nearest")

    fig = plt.figure(figsize=(30, 10))
    gs = fig.add_gridspec(2, 1, height_ratios=[4, 1], hspace=0.25)

    # Top panel: heatmap with time on x and phase index on y.
    ax_heat = fig.add_subplot(gs[0, 0])
    im = ax_heat.imshow(
        phase_ds.T,                 # (K, T')
        extent=[t_first, t_last, 0, K],
        **image_kw
    )
    ax_heat.set_title(title, fontsize=24, pad=10)
    ax_heat.set_ylabel("Phase k", fontsize=18)
    ax_heat.set_xlabel("Time (sec)", fontsize=18)
    cbar = fig.colorbar(im, ax=ax_heat, fraction=0.015, pad=0.01)
    cbar.set_label("phase_p(t,k)", fontsize=14)

    # Bottom panel: dominant-phase timeline sharing the time axis.
    ax_dom = fig.add_subplot(gs[1, 0], sharex=ax_heat)
    ax_dom.imshow(
        dom[None, :],               # (1, T')
        extent=[t_first, t_last, 0, 1],
        **image_kw
    )
    ax_dom.set_yticks([])
    ax_dom.set_ylabel("dominant", fontsize=14)
    ax_dom.set_xlabel("Time (sec)", fontsize=18)

    plt.tight_layout()
    plt.show()


def plot_folds_test_subplot(viz_cache, fs, title="Fold-wise TEST visualization (only test_subj)"):
    """Draw one subplot per test item: smoothed rep rate and cumulative count.

    Args:
        viz_cache: List of dicts with keys ``'fold'``, ``'test_subj'``,
            ``'t'``, ``'rep_rate'``, ``'gt'``, ``'pred'``, ``'diff'``,
            ``'k_hat'`` and ``'entropy'``.
        fs: Sampling rate in Hz, used to integrate the rate into a count.
        title: Figure-level title.
    """
    if viz_cache is None or len(viz_cache) == 0:
        print("[plot_folds_test_subplot] viz_cache is empty")
        return

    sns.set_theme(style="whitegrid", context="notebook", font_scale=2.0)
    palette = sns.color_palette("muted")
    c_rate, c_count = palette[0], palette[1]

    n = len(viz_cache)
    fig, axes = plt.subplots(n, 1, figsize=(36, 9 * n), sharex=False)
    if n == 1:
        # A single subplot comes back as a bare Axes, not an array.
        axes = [axes]
    axes = np.array(axes).flatten()

    fig.suptitle(title, fontsize=40, y=0.995)

    for ax, item in zip(axes, viz_cache):
        t = item["t"]
        rep_rate = item["rep_rate"]
        gt_count = item["gt"]

        rep_s = _smooth_1d(rep_rate, sigma=2.0)
        # Integrate the raw (unsmoothed) rate over time to get a running count.
        cum = np.cumsum(rep_rate) / fs

        # Left axis: smoothed repetition rate.
        ax.plot(t, rep_s, color=c_rate, linewidth=2.5, alpha=0.9)
        ax.fill_between(t, rep_s, color=c_rate, alpha=0.15)
        ax.set_ylabel("Rep Rate (reps/s)", color=c_rate, fontweight='bold', fontsize=24)
        ax.grid(True, linestyle='--', alpha=0.5)
        ax.tick_params(axis='both', which='major', labelsize=20)

        # Right axis: cumulative count with the ground truth as a dotted line.
        ax_count = ax.twinx()
        ax_count.plot(t, cum, color=c_count, linewidth=3.5, alpha=1.0)
        ax_count.axhline(gt_count, linestyle=":", alpha=0.7)
        ax_count.set_ylabel("Count", color=c_count, fontweight='bold', fontsize=24)
        ax_count.tick_params(axis='y', labelcolor=c_count, labelsize=20)
        ax_count.grid(False)

        fold = item["fold"]
        test_subj = item["test_subj"]
        pred_count = item["pred"]
        diff = item["diff"]
        k_hat = item["k_hat"]
        entropy = item["entropy"]
        ax.set_title(
            f"Fold {fold:2d} | Test: {test_subj} | Pred {pred_count:.2f} / GT {gt_count:.0f} (Diff {diff:+.2f})\n"
            f"k_hat={k_hat:.2f} | phase_entropy={entropy:.3f}",
            fontsize=34, pad=10
        )
        ax.set_xlabel("Time (sec)", fontweight='bold', fontsize=24)

    plt.tight_layout(rect=[0, 0, 1, 0.985])
    plt.subplots_adjust(hspace=0.5)
    plt.show()


# ---------------------------------------------------------------------
# 7) Main (A1: Unseen Activity only, WITH windowing)
# ---------------------------------------------------------------------
def build_label_tuples_from_table(subjects, act_id, count_table):
    """Collect ``(subject, act_id, gt_count)`` labels for one activity.

    Args:
        subjects: Iterable of subject keys, processed in order.
        act_id: Activity id to look up.
        count_table: Dict like ``{"subject1": {6: 21, 7: 20}, ...}``.

    Returns:
        List of ``(subj, act_id, float(count))`` for every subject that has
        an entry for ``act_id``; other subjects are silently skipped.
    """
    return [
        (subj, act_id, float(count_table[subj][act_id]))
        for subj in subjects
        if subj in count_table and act_id in count_table[subj]
    ]


def main():
    """Run the A1 (unseen-activity) experiment end to end.

    Loads the mHealth logs, trains ``KAutoCountModel`` on windows of every
    subject's ``TRAIN_ACT_ID`` trial, then predicts the repetition count of
    every subject's ``TEST_ACT_ID`` trial via sliding windows and prints the
    per-trial results plus the mean/std of the absolute count error.
    Returns early (printing a message where one is defined) when data or
    labels are missing.
    """
    CONFIG = {
        "seed": 42,
        "data_dir": "/content/drive/MyDrive/Colab Notebooks/HAR_data/MHEALTHDATASET",

        "COLUMN_NAMES": [
            'acc_chest_x', 'acc_chest_y', 'acc_chest_z',
            'ecg_1', 'ecg_2',
            'acc_ankle_x', 'acc_ankle_y', 'acc_ankle_z',
            'gyro_ankle_x', 'gyro_ankle_y', 'gyro_ankle_z',
            'mag_ankle_x', 'mag_ankle_y', 'mag_ankle_z',
            'acc_arm_x', 'acc_arm_y', 'acc_arm_z',
            'gyro_arm_x', 'gyro_arm_y', 'gyro_arm_z',
            'mag_arm_x', 'mag_arm_y', 'mag_arm_z',
            'activity_id'
        ],
        "TARGET_ACTIVITIES_MAP": {
            6: 'Waist bends forward',
            12: 'Jump front & back'
        },

        # Train and test activities must use the same number of channels C
        # so that one model can be evaluated on both.
        "ACT_FEATURE_MAP": {
             6: ['acc_chest_x', 'acc_chest_y', 'acc_chest_z',
                'acc_arm_x', 'acc_arm_y', 'acc_arm_z',
                'acc_ankle_x', 'acc_ankle_y', 'acc_ankle_z'],
             12: ['acc_chest_x', 'acc_chest_y', 'acc_chest_z',
                'acc_arm_x', 'acc_arm_y', 'acc_arm_z',
                'acc_ankle_x', 'acc_ankle_y', 'acc_ankle_z']
        },

        # Training Params
        "epochs": 100,
        "lr": 5e-4,
        "batch_size": 64,
        "fs": 50,

        # Model
        "hidden_dim": 128,
        "latent_dim": 16,
        "K_max": 6,

        # Loss Weights
        "lambda_recon": 1.0,
        "lambda_smooth": 0.05,
        "lambda_phase_ent": 0.01,
        "lambda_effk": 0.0005,

        # Time-warp consistency (the only addition in this version)
        "lambda_warp": 0.2,      # NOTE: the original note recommended starting at 1e-2~5e-2; this value is higher
        "warp_min": 0.5,          # factor < 1: time_warp_batch resamples to fewer samples (compress)
        "warp_max": 1.5,          # factor > 1: time_warp_batch resamples to more samples (stretch)
        "warp_detach_ref": True,  # only the warped branch is fitted to the original count (more stable)

        # temperature (strength of the phase competition)
        "tau": 1.0,

        # A1 setting
        "TRAIN_ACT_ID": 12,
        "TEST_ACT_ID": 6,

        # -------------------------
        # Windowing
        # -------------------------
        "USE_WINDOWING": True,
        "WIN_SEC": 8.0,
        "STRIDE_SEC": 4.0,
        "DROP_LAST": True,

        # Ground-truth repetition counts per subject, keyed by activity id
        # (must be a dict of dicts).
        "COUNT_TABLE": {
            "subject1":  {12: 20, 6: 21},
            "subject2":  {12: 22, 6: 19},
            "subject3":  {12: 21, 6: 21},
            "subject4":  {12: 21, 6: 20},
            "subject5":  {12: 20, 6: 20},
            "subject6":  {12: 21, 6: 20},
            "subject7":  {12: 19, 6: 20},
            "subject8":  {12: 20, 6: 21},
            "subject9":  {12: 20, 6: 21},
            "subject10": {12: 20, 6: 20},
        },
    }

    set_strict_seed(CONFIG["seed"])
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")

    full_data = load_mhealth_dataset(
        CONFIG["data_dir"],
        CONFIG["TARGET_ACTIVITIES_MAP"],
        CONFIG["COLUMN_NAMES"]
    )
    if not full_data:
        return

    subjects = [f"subject{i}" for i in range(1, 11)]

    # -------------------------
    # A1 split (NO LOSO)
    # Train: all subjects x TRAIN_ACT_ID
    # Test : all subjects x TEST_ACT_ID (unseen activity)
    # -------------------------
    train_labels = build_label_tuples_from_table(subjects, CONFIG["TRAIN_ACT_ID"], CONFIG["COUNT_TABLE"])
    test_labels  = build_label_tuples_from_table(subjects, CONFIG["TEST_ACT_ID"],  CONFIG["COUNT_TABLE"])

    if len(train_labels) == 0:
        print("[Error] No train labels. Check COUNT_TABLE / TRAIN_ACT_ID.")
        return
    if len(test_labels) == 0:
        print("[Error] No test labels. Check COUNT_TABLE / TEST_ACT_ID.")
        return

    train_data = prepare_trial_list(train_labels, full_data, CONFIG["TARGET_ACTIVITIES_MAP"], CONFIG["ACT_FEATURE_MAP"])
    test_data  = prepare_trial_list(test_labels,  full_data, CONFIG["TARGET_ACTIVITIES_MAP"], CONFIG["ACT_FEATURE_MAP"])

    if len(train_data) == 0:
        print("[Error] train_data empty. (missing files or ACT_FEATURE_MAP mismatch)")
        return
    if len(test_data) == 0:
        print("[Error] test_data empty. (missing files or ACT_FEATURE_MAP mismatch)")
        return

    # -------------------------
    # Apply windowing to the training trials only; test trials are
    # windowed inside predict_count_by_windowing below.
    # -------------------------
    if CONFIG.get("USE_WINDOWING", False):
        train_windows = trial_list_to_windows(
            train_data,
            fs=CONFIG["fs"],
            win_sec=CONFIG["WIN_SEC"],
            stride_sec=CONFIG["STRIDE_SEC"],
            drop_last=CONFIG["DROP_LAST"],
        )
        print(f"[Windowing] train trials={len(train_data)} -> train windows={len(train_windows)}")
        train_data_for_loader = train_windows
    else:
        train_data_for_loader = train_data

    train_loader = DataLoader(
        TrialDataset(train_data_for_loader),
        batch_size=CONFIG["batch_size"],
        shuffle=True,
        collate_fn=collate_variable_length,
        num_workers=0
    )

    # Channel count is read from the first training trial (data is (T, C)).
    input_ch = train_data[0]['data'].shape[1]
    model = KAutoCountModel(
        input_ch=input_ch,
        hidden_dim=CONFIG["hidden_dim"],
        latent_dim=CONFIG["latent_dim"],
        K_max=CONFIG["K_max"]
    ).to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=CONFIG["lr"])
    # Halve the learning rate every 30 epochs.
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.5)

    print("\n" + "-"*80)
    print(f" >>> A1 Train: all subjects x act{CONFIG['TRAIN_ACT_ID']}  |  Test: all subjects x act{CONFIG['TEST_ACT_ID']} (unseen activity)")
    print("-"*80)

    # ---- Train ----
    for epoch in range(CONFIG["epochs"]):
        # Per-epoch stats are computed but intentionally discarded here.
        _ = train_one_epoch(model, train_loader, optimizer, CONFIG, device)
        scheduler.step()

    # ---- Test (per-subject) ----
    model.eval()
    maes = []
    viz_cache = []

    for idx, item in enumerate(test_data):
        x_np = item["data"]  # (T,C)
        gt_count = float(item["count"])

        # Predicted count from sliding-window inference (this is the reported metric).
        pred_count, win_rates = predict_count_by_windowing(
            model,
            x_np=x_np,
            fs=CONFIG["fs"],
            win_sec=CONFIG["WIN_SEC"],
            stride_sec=CONFIG["STRIDE_SEC"],
            device=device,
            tau=CONFIG["tau"],
            batch_size=CONFIG["batch_size"]
        )

        mae = abs(pred_count - gt_count)
        maes.append(mae)

        # For visualization/diagnostics only: one forward pass over the full trial.
        with torch.no_grad():
            x_tensor = torch.tensor(x_np, dtype=torch.float32).transpose(0, 1).unsqueeze(0).to(device)
            rate_hat_full, _, _, aux = model(x_tensor, mask=None, tau=CONFIG["tau"])

            phase_p = aux["phase_p"].squeeze(0).detach().cpu().numpy()
            k_hat = float(aux["k_hat"].item())
            ent = compute_phase_entropy_mean(phase_p)
            rep_rate = aux["rep_rate_t"].squeeze(0).detach().cpu().numpy()
            T = rep_rate.shape[0]
            t = np.arange(T) / float(CONFIG["fs"])

        viz_cache.append({
            "fold": idx + 1,   # just a running index (there are no real folds)
            "test_subj": item["meta"],  # subject+act name
            "t": t,
            "rep_rate": rep_rate,
            "gt": gt_count,
            "pred": float(pred_count),   # windowed prediction
            "diff": float(pred_count - gt_count),
            "k_hat": k_hat,
            "entropy": ent,
            "phase_p": phase_p,
        })

        print(
            f"[Test {idx+1:02d}] {item['meta']} | Pred(win)={pred_count:.2f} / GT={gt_count:.2f} | "
            f"MAE={mae:.2f} | k_hat={k_hat:.2f} | ent={ent:.3f} | win_rate_mean={win_rates.mean():.3f}"
        )

    print("-"*80)
    print(f" >>> A1 Final MAE mean: {np.mean(maes):.3f}")
    print(f" >>> A1 Final MAE std : {np.std(maes):.3f}")
    print("-"*80)

    # Optional visualization (currently disabled)
    # plot_folds_test_subplot(
    #     viz_cache,
    #     fs=CONFIG["fs"],
    #     title=f"A1 TEST visualization | Train act{CONFIG['TRAIN_ACT_ID']} -> Test act{CONFIG['TEST_ACT_ID']}"
    # )

    # for item in viz_cache:
    #     plot_phase_heatmap_and_dominant(
    #         item["phase_p"],
    #         fs=CONFIG["fs"],
    #         title=f"{item['test_subj']} | k_hat={item['k_hat']:.2f} | ent={item['entropy']:.3f}",
    #         max_T=2000
    #     )


# Script entry point: run the experiment only when this cell/file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    main()


Device: cuda
Loading 10 subjects from /content/drive/MyDrive/Colab Notebooks/HAR_data/MHEALTHDATASET...
[Windowing] train trials=10 -> train windows=40

--------------------------------------------------------------------------------
 >>> A1 Train: all subjects x act12  |  Test: all subjects x act6 (unseen activity)
--------------------------------------------------------------------------------
[Test 01] subject1_Waist bends forward | Pred(win)=100.28 / GT=21.00 | MAE=79.28 | k_hat=1.79 | ent=0.644 | win_rate_mean=1.632
[Test 02] subject2_Waist bends forward | Pred(win)=56.16 / GT=19.00 | MAE=37.16 | k_hat=2.48 | ent=0.895 | win_rate_mean=0.885
[Test 03] subject3_Waist bends forward | Pred(win)=65.94 / GT=21.00 | MAE=44.94 | k_hat=2.32 | ent=0.829 | win_rate_mean=1.022
[Test 04] subject4_Waist bends forward | Pred(win)=74.95 / GT=20.00 | MAE=54.95 | k_hat=2.33 | ent=0.861 | win_rate_mean=1.126
[Test 05] subject5_Waist bends forward | Pred(win)=60.89 / GT=20.00 | MAE=40.89 | k_hat=2.46

In [11]:
# =========================
# Count-only K-auto (Multi-event) version  (NO manual Pair/lag/overlap/balance)
#
# 핵심 아이디어
# - Micro-event Rate 예측: 모델은 하나의 통합된 속도가 아니라, K_max개의 서로 다른 '작은 단위 동작'의 속도 흐름(r_k(t))을 예측
# - 샘플마다 "rep당 micro-event 개수" k_hat(>=1)을 스스로 추정
# - 우리가 주는 감독은 오직 rep count(=20) -> rep rate만 맞추게
#
# 1) rate head를 amp(t) * softmax(phase) 형태로 바꿔서 K-stream 간 "경쟁"이 생기게 함
# 2) k_hat을 따로 head로 예측하지 않고, phase 사용 분포로부터 effK(=effective K)로 정의 (자동 K)
# 3) phase sparsity/exclusivity를 위해 (a) phase entropy loss, (b) effK usage loss 추가
#
# ✅ Added (ONLY): Time-warp consistency loss
#   - 입력을 시간축으로 stretch/compress (time-warp)해도 "count"는 동일해야 한다는 제약
#   - 모델 출력은 rate이므로, pred_count = rate_hat * duration 을 이용해 count 일관성을 강제
# =========================

import os
import glob
import random
import numpy as np
import pandas as pd
import seaborn as sns

import matplotlib.pyplot as plt
from scipy.ndimage import gaussian_filter1d

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader


# ---------------------------------------------------------------------
# 1) Strict Seeding
# ---------------------------------------------------------------------
def set_strict_seed(seed: int):
    """Seed the Python, NumPy and PyTorch generators and pin the cuDNN flags.

    Args:
        seed: Integer seed applied to every generator touched here.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)

    # CPU-side generators.
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    # CUDA generators (current device, then all devices) when a GPU is present.
    if torch.cuda.is_available():
        for cuda_seeder in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
            cuda_seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


# ---------------------------------------------------------------------
# 2) Data Loading
# ---------------------------------------------------------------------
def load_mhealth_dataset(data_dir, target_activities_map, column_names):
    """Load per-subject mHealth logs and split each one by activity.

    Args:
        data_dir: Directory searched for ``mHealth_subject*.log`` files.
        target_activities_map: Mapping ``activity_id -> activity name``; only
            these activities are kept.
        column_names: Names assigned to the leading columns of every log
            (extra columns are dropped). Must contain ``'activity_id'``.

    Returns:
        ``{subject_key: {activity_name: DataFrame}}`` where each frame holds
        the rows of that activity without the ``activity_id`` column.
        An empty dict is returned when no log file is found. A file that
        fails to load is reported on stdout and left out of the result.
    """
    full_dataset = {}
    file_list = sorted(glob.glob(os.path.join(data_dir, "mHealth_subject*.log")))

    if not file_list:
        print(f"[Warning] No mHealth logs found in {data_dir}")
        return {}

    print(f"Loading {len(file_list)} subjects from {data_dir}...")

    for file_path in file_list:
        file_name = os.path.basename(file_path)
        subj_part = file_name.split('.')[0]

        # Normalise e.g. "mHealth_subject03" -> "subject3". int() raises
        # ValueError when the stem has no usable digits; fall back to the
        # raw stem then. (Was a bare ``except:``, which also swallowed
        # KeyboardInterrupt / SystemExit.)
        try:
            subj_id_num = int(''.join(filter(str.isdigit, subj_part)))
            subj_key = f"subject{subj_id_num}"
        except ValueError:
            subj_key = subj_part

        # Best-effort per file: one unreadable log must not abort the others.
        try:
            df = pd.read_csv(file_path, sep="\t", header=None)
            df = df.iloc[:, :len(column_names)]
            df.columns = column_names

            subj_data = {}
            for label_code, activity_name in target_activities_map.items():
                activity_df = df[df['activity_id'] == label_code].copy()
                if not activity_df.empty:
                    subj_data[activity_name] = activity_df.drop(columns=['activity_id'])

            full_dataset[subj_key] = subj_data
        except Exception as e:
            print(f"Error loading {file_name}: {e}")

    return full_dataset


def prepare_trial_list(label_config, full_data, target_map, feature_map):
    """Build the list of z-scored trials used for training / evaluation.

    Args:
        label_config: Iterable of ``(subject, activity_id, gt_count)`` tuples.
        full_data: ``{subject: {activity_name: DataFrame}}`` as produced by
            ``load_mhealth_dataset``.
        target_map: Mapping ``activity_id -> activity_name``.
        feature_map: Mapping ``activity_id -> list of column names`` to keep.

    Returns:
        List of dicts with keys ``'data'`` (float32 array of shape (T, C),
        standardised per channel over the trial), ``'count'`` (float) and
        ``'meta'`` (``"<subject>_<activity_name>"``). Entries whose data or
        feature list is unavailable are reported on stdout and skipped.
    """
    trial_list = []
    for subj, act_id, gt_count in label_config:
        act_name = target_map.get(act_id)
        feats = feature_map.get(act_id)

        if subj not in full_data or act_name not in full_data[subj]:
            print(f"[Skip] Missing data for {subj} - {act_name}")
            continue

        # Without this guard a missing feature_map entry indexes the frame
        # with None and raises KeyError instead of skipping the trial.
        if feats is None:
            print(f"[Skip] No feature list for activity {act_id} ({subj} - {act_name})")
            continue

        raw_df = full_data[subj][act_name][feats]
        raw_np = raw_df.values.astype(np.float32)

        # Per-channel z-score over the whole trial (mean 0, std 1); the small
        # constant keeps constant channels from dividing by zero.
        mean = raw_np.mean(axis=0)
        std = raw_np.std(axis=0) + 1e-6
        norm_np = (raw_np - mean) / std

        trial_list.append({
            'data': norm_np,
            'count': float(gt_count),
            'meta': f"{subj}_{act_name}"
        })

    return trial_list


# ---------------------------------------------------------------------
# 2.5) Windowing (added)
# ---------------------------------------------------------------------
def _window_record(samples, lo, hi, rate_trial, fs, parent_meta, parent_T):
    """Build one window dict; its count is the trial rate times the window duration."""
    return {
        "data": samples,
        "count": rate_trial * ((hi - lo) / float(fs)),
        "meta": f"{parent_meta}__win[{lo}:{hi}]",
        "parent_meta": parent_meta,
        "parent_T": parent_T,
        "win_start": lo,
        "win_end": hi,
    }


def trial_list_to_windows(trial_list, fs, win_sec=4.0, stride_sec=2.0, drop_last=True):
    """Cut every trial into fixed-length sliding windows with rate-derived labels.

    Each window is labelled from the trial-level average rate:
      rate_trial   = count_total / total_duration
      count_window = rate_trial * window_duration

    Args:
        trial_list: Output of ``prepare_trial_list`` (items with 'data' (T, C),
            'count' and 'meta').
        fs: Sampling rate in Hz.
        win_sec: Window length in seconds.
        stride_sec: Hop between consecutive window starts in seconds.
        drop_last: When False, an extra shorter window running from the next
            stride position to the end of the trial is appended.

    Returns:
        List of window dicts with keys 'data', 'count', 'meta', 'parent_meta',
        'parent_T', 'win_start', 'win_end'.
    """
    win_len = int(round(win_sec * fs))
    stride = int(round(stride_sec * fs))
    assert win_len > 0 and stride > 0

    windows = []
    for trial in trial_list:
        signal = trial["data"]           # (T, C)
        n_samples = signal.shape[0]
        parent = trial["meta"]

        # Average reps/s over the whole trial.
        rate_trial = float(trial["count"]) / max(n_samples / float(fs), 1e-6)

        if n_samples < win_len:
            # Too short for one full window: keep the trial as a single
            # variable-length window (no padding).
            windows.append(
                _window_record(signal, 0, n_samples, rate_trial, fs, parent, n_samples)
            )
            continue

        starts = list(range(0, n_samples - win_len + 1, stride))
        windows.extend(
            _window_record(signal[lo:lo + win_len], lo, lo + win_len,
                           rate_trial, fs, parent, n_samples)
            for lo in starts
        )

        if drop_last:
            continue

        tail_lo = starts[-1] + stride
        if tail_lo < n_samples:
            windows.append(
                _window_record(signal[tail_lo:n_samples], tail_lo, n_samples,
                               rate_trial, fs, parent, n_samples)
            )

    return windows


def predict_count_by_windowing(model, x_np, fs, win_sec, stride_sec, device, tau=1.0, batch_size=64):
    """Predict a trial's repetition count from the mean per-window rate.

    The trial is cut into fixed-length sliding windows, the model predicts a
    rate for each window, and the count is ``mean(window rates) * duration``.
    A trial no longer than one window is fed to the model whole.

    Args:
        model: Callable module returning a 4-tuple whose first item is the
            per-sample rate of shape (B,).
        x_np: Normalised signal, array of shape (T, C).
        fs: Sampling rate in Hz.
        win_sec: Window length in seconds.
        stride_sec: Hop between window starts in seconds.
        device: Torch device the inputs are moved to.
        tau: Temperature forwarded to the model.
        batch_size: Number of windows per forward pass.

    Returns:
        ``(pred_count, window_rates)``: a float and a 1-D numpy array of the
        per-window rates.

    Side effects:
        The model is switched to eval mode and left there.
    """
    win_len = int(round(win_sec * fs))
    stride = int(round(stride_sec * fs))
    T = x_np.shape[0]
    total_dur = T / float(fs)

    # Eval mode must be set on BOTH paths; previously the short-trial path
    # ran in whatever mode the caller happened to leave the model in.
    model.eval()

    if T <= win_len:
        x_tensor = torch.tensor(x_np, dtype=torch.float32).transpose(0, 1).unsqueeze(0).to(device)  # (1,C,T)
        with torch.no_grad():
            rate_hat, _, _, _ = model(x_tensor, mask=None, tau=tau)
        rate = float(rate_hat.item())
        return float(rate * total_dur), np.array([rate], dtype=np.float32)

    starts = list(range(0, T - win_len + 1, stride))
    windows = np.stack([x_np[st:st + win_len] for st in starts], axis=0)  # (N, win_len, C)

    xw = torch.tensor(windows, dtype=torch.float32).permute(0, 2, 1).to(device)  # (N, C, win_len)

    rates = []
    with torch.no_grad():
        for i in range(0, xw.shape[0], batch_size):
            xb = xw[i:i + batch_size]
            r_hat, _, _, _ = model(xb, mask=None, tau=tau)  # (B,)
            rates.append(r_hat.detach().cpu().numpy())

    rates = np.concatenate(rates, axis=0)  # (N,)
    rate_mean = float(rates.mean())
    pred_count = rate_mean * total_dur
    return float(pred_count), rates


class TrialDataset(Dataset):
    """Map-style dataset over trial/window dicts with 'data', 'count' and 'meta' keys."""

    def __init__(self, trial_list):
        self.trials = trial_list

    def __len__(self):
        return len(self.trials)

    def __getitem__(self, idx):
        """Return ``(signal, count, meta)`` with the signal transposed from (T, C) to (C, T)."""
        record = self.trials[idx]
        signal_tc = torch.tensor(record['data'], dtype=torch.float32)
        label = torch.tensor(record['count'], dtype=torch.float32)
        return signal_tc.transpose(0, 1), label, record['meta']


def collate_variable_length(batch):
    """Zero-pad variable-length samples to the longest one and build validity masks.

    Args:
        batch: List of ``(data, count, meta)`` tuples with ``data`` of shape (C, T_i).

    Returns:
        Dict with 'data' (B, C, T_max), 'mask' (B, T_max; 1 on real samples,
        0 on padding), 'count' (B,), 'length' (B,) float tensor of the
        unpadded lengths, and 'meta' (list).
    """
    seq_lens = [sample[0].shape[1] for sample in batch]
    longest = max(seq_lens)
    n_ch = batch[0][0].shape[0]

    padded_data, masks, counts, metas = [], [], [], []
    for (data, count, meta), T in zip(batch, seq_lens):
        deficit = longest - T
        valid = torch.ones(T)
        if deficit > 0:
            # Right-pad the signal with zeros and flag the padding as invalid.
            data = torch.cat([data, torch.zeros(n_ch, deficit)], dim=1)
            valid = torch.cat([valid, torch.zeros(deficit)], dim=0)

        padded_data.append(data)
        masks.append(valid)
        counts.append(count)
        metas.append(meta)

    return {
        "data": torch.stack(padded_data),         # (B, C, T_max)
        "mask": torch.stack(masks),               # (B, T_max)
        "count": torch.stack(counts),             # (B,)
        "length": torch.tensor(seq_lens, dtype=torch.float32),  # (B,)
        "meta": metas
    }


# ---------------------------------------------------------------------
# 3) Model
# ---------------------------------------------------------------------
class ManifoldEncoder(nn.Module):
    """Per-timestep 1-D conv encoder: (B, C, T) signal -> (B, T, latent_dim) latents."""

    def __init__(self, input_ch, hidden_dim=128, latent_dim=16):
        super().__init__()
        # Two length-preserving k=5 convs followed by a 1x1 projection to the latent size.
        stages = [
            nn.Conv1d(input_ch, hidden_dim, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv1d(hidden_dim, hidden_dim, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv1d(hidden_dim, latent_dim, kernel_size=1),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Encode ``x`` of shape (B, C, T) and return time-major latents (B, T, D)."""
        return self.net(x).transpose(1, 2)


class ManifoldDecoder(nn.Module):
    """1-D conv decoder: (B, T, latent_dim) latents -> (B, out_ch, T) reconstruction."""

    def __init__(self, latent_dim, hidden_dim, out_ch):
        super().__init__()
        # Same layout as the encoder: two k=5 convs, then a 1x1 projection to out_ch.
        stages = [
            nn.Conv1d(latent_dim, hidden_dim, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv1d(hidden_dim, hidden_dim, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv1d(hidden_dim, out_ch, kernel_size=1),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, z):
        """Decode time-major latents ``z`` (B, T, D) into a channel-major signal (B, C, T)."""
        channel_major = z.transpose(1, 2)   # (B, D, T)
        return self.net(channel_major)


class MultiRateHead(nn.Module):
    """MLP head producing a non-negative amplitude and a K-way phase distribution per timestep."""

    def __init__(self, latent_dim=16, hidden=64, K_max=6):
        super().__init__()
        self.K_max = K_max
        # Output layout along the last axis: [amp_logit | K_max phase logits].
        self.net = nn.Sequential(
            nn.Linear(latent_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 1 + K_max),
        )

    def forward(self, z, tau=1.0):
        """Map latents ``z`` (B, T, D) to ``(amp, phase, phase_logits)``.

        Returns:
            amp: (B, T), softplus of the first output channel (>= 0).
            phase: (B, T, K), softmax over the remaining channels at temperature ``tau``.
            phase_logits: (B, T, K), the raw (untempered) phase logits.
        """
        raw = self.net(z)                       # (B, T, 1 + K)
        amp_logit = raw[..., 0]
        phase_logits = raw[..., 1:]
        amp = F.softplus(amp_logit)
        phase = torch.softmax(phase_logits / tau, dim=-1)
        return amp, phase, phase_logits


class KAutoCountModel(nn.Module):
    """Count-only repetition-rate model with an automatically derived K.

    - The rate head emits a total micro-event intensity amp(t) and a K_max-way
      phase distribution; the per-stream rates are r_k(t) = amp(t) * phase_k(t).
    - k_hat (effective number of micro-events per rep) has no head of its own:
      it is the inverse sum of squares of the time-averaged phase usage.
    - rep_rate(t) = sum_k r_k(t) / k_hat = amp(t) / k_hat.
    """

    def __init__(self, input_ch, hidden_dim=128, latent_dim=16, K_max=6, k_hidden=64):
        # NOTE: ``k_hidden`` is accepted but not used anywhere in this class.
        super().__init__()
        self.encoder = ManifoldEncoder(input_ch, hidden_dim, latent_dim)
        self.decoder = ManifoldDecoder(latent_dim, hidden_dim, input_ch)
        self.rate_head = MultiRateHead(latent_dim, hidden=hidden_dim, K_max=K_max)

        self._init_weights()

    def _init_weights(self):
        """Kaiming-init every conv/linear layer, then bias the amplitude logit to -2."""
        for module in self.modules():
            if not isinstance(module, (nn.Conv1d, nn.Linear)):
                continue
            nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)

        with torch.no_grad():
            head_bias = self.rate_head.net[-1].bias
            head_bias.zero_()
            head_bias[0].fill_(-2.0)  # only the amp logit gets the -2 offset

    @staticmethod
    def _masked_mean_time(x, mask=None, eps=1e-6):
        """Mean over the time axis (dim 1) of ``x`` (B,T) or (B,T,K), ignoring masked-out steps."""
        if mask is None:
            return x.mean(dim=1)
        if x.dim() not in (2, 3):
            raise ValueError(f"Unsupported dim for masked mean: {x.dim()}")

        weights = mask.to(dtype=x.dtype, device=x.device)      # (B,T)
        if x.dim() == 3:
            weights = weights.unsqueeze(-1)                    # (B,T,1)
        return (x * weights).sum(dim=1) / (weights.sum(dim=1) + eps)

    def forward(self, x, mask=None, tau=1.0):
        """Run the model on ``x`` (B,C,T) with an optional validity ``mask`` (B,T).

        Returns:
            avg_rep_rate: (B,) time-averaged repetition rate.
            z: (B,T,D) latents.
            x_hat: (B,C,T) reconstruction.
            aux: dict with 'rates_k_t', 'phase_p', 'phase_logits',
                'micro_rate_t', 'rep_rate_t' and 'k_hat'.
        """
        z = self.encoder(x)              # (B,T,D)
        x_hat = self.decoder(z)          # (B,C,T)

        amp_t, phase_p, phase_logits = self.rate_head(z, tau=tau)
        rates_k_t = amp_t.unsqueeze(-1) * phase_p

        # phase sums to 1 over K, so the sum of all micro-event rates is amp itself
        micro_rate_t = amp_t

        # effective K from the time-averaged phase usage
        usage = self._masked_mean_time(phase_p, mask)           # (B,K)
        k_hat = 1.0 / (usage.pow(2).sum(dim=1) + 1e-6)          # (B,)

        rep_rate_t = micro_rate_t / (k_hat.unsqueeze(1) + 1e-6)    # (B,T)

        if mask is None:
            avg_rep_rate = rep_rate_t.mean(dim=1)
        else:
            # zero the padded steps, then average over the valid ones only
            rep_rate_t = rep_rate_t * mask
            avg_rep_rate = (rep_rate_t * mask).sum(dim=1) / (mask.sum(dim=1) + 1e-6)

        aux = {
            "rates_k_t": rates_k_t,          # (B,T,K)
            "phase_p": phase_p,              # (B,T,K)
            "phase_logits": phase_logits,    # (B,T,K)
            "micro_rate_t": micro_rate_t,    # (B,T)
            "rep_rate_t": rep_rate_t,        # (B,T)
            "k_hat": k_hat,                  # (B,)
        }
        return avg_rep_rate, z, x_hat, aux


# ---------------------------------------------------------------------
# 4) Loss utils
# ---------------------------------------------------------------------
def masked_recon_mse(x_hat, x, mask, eps=1e-6):
    """Mean squared reconstruction error over valid timesteps and all channels.

    Args:
        x_hat: Reconstruction, (B, C, T).
        x: Target signal, (B, C, T).
        mask: Validity mask, (B, T); padded steps contribute nothing.
        eps: Added to the denominator.

    Returns:
        Scalar tensor: sum of masked squared errors / (number of valid steps * C + eps).
    """
    valid = mask.to(dtype=x.dtype, device=x.device)
    sq_err = (x_hat - x).pow(2) * valid.unsqueeze(1)     # broadcast mask over channels
    n_terms = valid.sum() * x.shape[1] + eps
    return sq_err.sum() / n_terms


def temporal_smoothness(v, mask=None, eps=1e-6):
    """L1 penalty on the first temporal difference of ``v`` (B, T).

    With a mask, only differences whose two endpoints are both valid are
    averaged; without one, the plain mean over all differences is returned.
    """
    step = (v[:, 1:] - v[:, :-1]).abs()          # (B, T-1)
    if mask is None:
        return step.mean()

    pair_valid = (mask[:, 1:] * mask[:, :-1]).to(dtype=step.dtype, device=step.device)
    return (step * pair_valid).sum() / (pair_valid.sum() + eps)


def phase_entropy_loss(phase_p, mask=None, eps=1e-8):
    """Mean per-timestep entropy of the phase distribution ``phase_p`` (B, T, K).

    Minimising it pushes the phase at each timestep towards one-hot
    (time-wise exclusivity). With a mask, only valid steps are averaged.
    """
    log_p = torch.log(phase_p + eps)
    ent_t = -(phase_p * log_p).sum(dim=-1)       # (B, T)
    if mask is None:
        return ent_t.mean()
    return (ent_t * mask).sum() / (mask.sum() + eps)


def effK_usage_loss(phase_p, mask=None, eps=1e-6):
    """Overall usage sparsity: effective K of the time-averaged phase usage.

    effK = 1 / sum_k(p_bar_k^2), where p_bar is the (masked) time mean of
    ``phase_p`` (B, T, K).

    Returns:
        ``(effK.mean(), effK.detach())``: the scalar loss and the per-sample
        values (B,) cut from the graph.
    """
    if mask is not None:
        weights = mask.to(dtype=phase_p.dtype, device=phase_p.device).unsqueeze(-1)  # (B,T,1)
        usage = (phase_p * weights).sum(dim=1) / (weights.sum(dim=1) + eps)
    else:
        usage = phase_p.mean(dim=1)  # (B,K)

    eff_k = 1.0 / (usage.pow(2).sum(dim=1) + eps)
    return eff_k.mean(), eff_k.detach()


def time_warp_batch(x, mask, length, warp_min=0.7, warp_max=1.3, min_len=8):
    """Resample each sample's valid part to a randomly scaled length and re-pad.

    For every sample a factor w ~ U(warp_min, warp_max) is drawn and the first
    ``length[b]`` steps are linearly interpolated to ``max(round(length[b] * w), min_len)``
    steps.

    Args:
        x: (B, C, Tmax) padded batch.
        mask: (B, Tmax); only its dtype is used, for the returned mask.
        length: (B,) valid lengths (float or int tensor).
        warp_min, warp_max: Range of the length scale factor.
        min_len: Lower bound on the warped length.

    Returns:
        xw: (B, C, Tw_max) zero-padded warped batch.
        mw: (B, Tw_max) mask, 1 on warped samples and 0 on padding.
        lw: (B,) warped valid lengths, same dtype as ``length``.
    """
    B, C, _ = x.shape
    device, dtype = x.device, x.dtype

    # Pass 1: one random tempo factor per sample, resample its valid prefix.
    warped = []
    for b in range(B):
        Tb = max(int(length[b].item()), 1)
        factor = float(torch.empty(1, device=device).uniform_(warp_min, warp_max).item())
        Tw = max(int(round(Tb * factor)), min_len)
        segment = F.interpolate(
            x[b:b+1, :, :Tb], size=Tw, mode="linear", align_corners=False
        )                                          # (1, C, Tw)
        warped.append(segment.squeeze(0))          # (C, Tw)

    warped_lens = [seg.shape[1] for seg in warped]
    Tw_max = max(warped_lens, default=0)

    # Pass 2: copy into zero-initialised buffers padded to the longest warped sample.
    xw = torch.zeros((B, C, Tw_max), device=device, dtype=dtype)
    mw = torch.zeros((B, Tw_max), device=device, dtype=mask.dtype)
    for b, (seg, Tw) in enumerate(zip(warped, warped_lens)):
        xw[b, :, :Tw] = seg
        mw[b, :Tw] = 1

    lw = torch.tensor([float(n) for n in warped_lens], device=device, dtype=length.dtype)
    return xw, mw, lw


# ---------------------------------------------------------------------
# 5) Train  (ONLY MODIFIED PART: remove shift, add time-warp consistency)
# ---------------------------------------------------------------------
def train_one_epoch(model, loader, optimizer, config, device):
    """Run one optimisation pass over ``loader`` and return per-batch averaged statistics.

    The total loss is
        rate MSE + lambda_recon * recon + lambda_smooth * smoothness
        + lambda_phase_ent * phase entropy + lambda_effk * effK
        + lambda_warp * time-warp count consistency.

    Args:
        model: Module called as ``model(x, mask, tau=...)`` returning
            ``(rate_hat, z, x_hat, aux)``.
        loader: Iterable of batches with keys 'data', 'mask', 'count', 'length'.
        optimizer: Optimizer stepped once per batch.
        config: Dict providing 'fs' and, optionally, 'tau', the lambda_* weights
            and the warp_* settings (defaults are the ``config.get`` fallbacks below).
        device: Device the batch tensors are moved to.

    Returns:
        Dict of the loss terms and the count MAE, each averaged over the
        number of batches (``len(loader)``).
    """
    model.train()
    stats = {k: 0.0 for k in [
        'loss', 'loss_rate', 'loss_recon', 'loss_smooth', 'loss_phase_ent', 'loss_effk', 'loss_warp',
        'mae_count'
    ]}

    fs = config["fs"]
    tau = config.get("tau", 1.0)

    lam_recon = config.get("lambda_recon", 1.0)
    lam_smooth = config.get("lambda_smooth", 0.05)
    lam_phase_ent = config.get("lambda_phase_ent", 0.01)
    lam_effk = config.get("lambda_effk", 0.005)

    # Time-warp consistency: loss weight, warp-factor range, and whether the
    # un-warped prediction is detached when used as the reference.
    lam_warp = config.get("lambda_warp", 0.0)
    warp_min = config.get("warp_min", 0.7)
    warp_max = config.get("warp_max", 1.3)
    warp_detach_ref = bool(config.get("warp_detach_ref", True))

    for batch in loader:
        x = batch["data"].to(device)         # (B,C,T)
        mask = batch["mask"].to(device)      # (B,T)
        y_count = batch["count"].to(device)  # (B,)
        length = batch["length"].to(device)  # (B,)

        duration = torch.clamp(length / fs, min=1e-6)  # sec
        y_rate = y_count / duration                    # reps/s

        optimizer.zero_grad()

        rate_hat, z, x_hat, aux = model(x, mask, tau=tau)

        # (1) fit the repetition rate (MSE)
        loss_rate = F.mse_loss(rate_hat, y_rate)

        # (2) recon
        loss_recon = masked_recon_mse(x_hat, x, mask)

        # (3) smoothness (on rep_rate_t)
        loss_smooth = temporal_smoothness(aux["rep_rate_t"], mask)

        # (4) phase exclusivity (entropy)
        loss_phase_ent = phase_entropy_loss(aux["phase_p"], mask)

        # (5) effective-K usage (overall)
        loss_effk, _ = effK_usage_loss(aux["phase_p"], mask)

        # (6) Time-warp consistency loss (count-level):
        #     pred_count = rate_hat * duration must stay the same after time-warping
        if lam_warp > 0.0:
            xw, mw, lw = time_warp_batch(
                x=x, mask=mask, length=length,
                warp_min=warp_min, warp_max=warp_max, min_len=8
            )
            dur_w = torch.clamp(lw / fs, min=1e-6)

            # second forward pass, on the warped batch
            rate_hat_w, _, _, _ = model(xw, mw, tau=tau)

            pred_count = rate_hat * duration
            pred_count_w = rate_hat_w * dur_w

            # with a detached reference, gradients flow through the warped branch only
            if warp_detach_ref:
                pred_count_ref = pred_count.detach()
            else:
                pred_count_ref = pred_count

            loss_warp = F.l1_loss(pred_count_w, pred_count_ref)
        else:
            loss_warp = torch.zeros((), device=device)

        loss = (loss_rate
                + lam_recon * loss_recon
                + lam_smooth * loss_smooth
                + lam_phase_ent * loss_phase_ent
                + lam_effk * loss_effk
                + lam_warp * loss_warp)

        loss.backward()
        optimizer.step()

        # MAE on count
        count_hat = rate_hat * duration  # predicted rate * duration = predicted count
        stats['loss'] += loss.item()
        stats['loss_rate'] += loss_rate.item()
        stats['loss_recon'] += loss_recon.item()
        stats['loss_smooth'] += loss_smooth.item()
        stats['loss_phase_ent'] += loss_phase_ent.item()
        stats['loss_effk'] += loss_effk.item()
        stats['loss_warp'] += float(loss_warp.item())
        stats['mae_count'] += torch.abs(count_hat - y_count).mean().item()

    # NOTE(review): an empty loader makes n == 0 and this division raises ZeroDivisionError.
    n = len(loader)
    return {k: v / n for k, v in stats.items()}


# ---------------------------------------------------------------------
# 6) Visualization helpers (subject-wise subplot)
# ---------------------------------------------------------------------
def _smooth_1d(y, sigma=2.0):
    """Gaussian-smooth a noisy 1-D series for plotting; the input is cast to float32 first."""
    samples = np.asarray(y, dtype=np.float32)
    smoothed = gaussian_filter1d(samples, sigma=sigma)
    return smoothed


def compute_phase_entropy_mean(phase_p_np, eps=1e-8):
    """Return the time-averaged entropy (Python float) of a (T, K) phase distribution."""
    probs = np.asarray(phase_p_np, dtype=np.float32)
    log_probs = np.log(probs + eps)
    entropy_per_step = -(probs * log_probs).sum(axis=1)  # (T,)
    return float(entropy_per_step.mean())


def downsample_time_axis(arr, max_T=2000):
    """Thin ``arr`` (T, ...) along axis 0 to at most ``max_T`` evenly spaced rows for plotting.

    Returns:
        ``(arr_ds, idx)``: the (possibly thinned) array and the original time
        indices of the rows that were kept. Arrays with T <= max_T are
        returned unchanged.
    """
    n_steps = arr.shape[0]
    if n_steps > max_T:
        picks = np.linspace(0, n_steps - 1, max_T).astype(int)
        return arr[picks], picks
    return arr, np.arange(n_steps)


def plot_phase_heatmap_and_dominant(
    phase_p_np,
    fs,
    title="phase_p heatmap + dominant phase",
    max_T=2000
):
    """Plot a phase-probability heatmap with a dominant-phase strip underneath.

    Args:
        phase_p_np: (T, K) array of phase probabilities.
        fs: Sampling rate, used to convert sample indices to seconds.
        title: Title of the heatmap panel.
        max_T: Maximum number of time steps drawn; longer inputs are downsampled.
    """
    phase_p_np = np.asarray(phase_p_np, dtype=np.float32)
    assert phase_p_np.ndim == 2, f"phase_p_np must be (T,K), got {phase_p_np.shape}"

    # (1) downsample for visualization
    phase_ds, idx = downsample_time_axis(phase_p_np, max_T=max_T)  # (T',K)
    Tds, K = phase_ds.shape
    t_sec = idx / float(fs)

    # (2) dominant phase (argmax)
    dom = np.argmax(phase_ds, axis=1)  # (T',)

    # (3) plot
    fig = plt.figure(figsize=(30, 10))
    gs = fig.add_gridspec(2, 1, height_ratios=[4, 1], hspace=0.25)

    # --- Heatmap (top) ---
    ax0 = fig.add_subplot(gs[0, 0])
    im = ax0.imshow(
        phase_ds.T,                 # (K, T')
        aspect="auto",
        origin="lower",
        interpolation="nearest",
        extent=[t_sec[0], t_sec[-1], 0, K]   # x=time, y=phase index range
    )
    ax0.set_title(title, fontsize=24, pad=10)
    ax0.set_ylabel("Phase k", fontsize=18)
    ax0.set_xlabel("Time (sec)", fontsize=18)
    cbar = fig.colorbar(im, ax=ax0, fraction=0.015, pad=0.01)
    cbar.set_label("phase_p(t,k)", fontsize=14)

    # --- Dominant phase timeline (bottom) ---
    ax1 = fig.add_subplot(gs[1, 0], sharex=ax0)
    ax1.imshow(
        dom[None, :],               # (1, T')
        aspect="auto",
        origin="lower",
        interpolation="nearest",
        extent=[t_sec[0], t_sec[-1], 0, 1]
    )
    ax1.set_yticks([])
    ax1.set_ylabel("dominant", fontsize=14)
    ax1.set_xlabel("Time (sec)", fontsize=18)

    plt.tight_layout()
    plt.show()


def plot_folds_test_subplot(viz_cache, fs, title="Fold-wise TEST visualization (only test_subj)"):
    """Draw one subplot per cached test item: smoothed rep rate plus cumulative count.

    Args:
        viz_cache: List of dicts, each containing
            'fold', 'test_subj', 't', 'rep_rate', 'gt', 'pred', 'diff', 'k_hat', 'entropy'.
        fs: Sampling rate; the cumulative count is ``cumsum(rep_rate) / fs``.
        title: Figure-level title.

    Prints a message and returns without plotting when ``viz_cache`` is empty or None.
    """
    if viz_cache is None or len(viz_cache) == 0:
        print("[plot_folds_test_subplot] viz_cache is empty")
        return

    sns.set_theme(style="whitegrid", context="notebook", font_scale=2.0)
    colors = sns.color_palette("muted")
    c_rate = colors[0]
    c_count = colors[1]

    n = len(viz_cache)
    fig, axes = plt.subplots(n, 1, figsize=(36, 9 * n), sharex=False)
    # plt.subplots returns a bare Axes for n == 1; normalise to a flat array either way
    if n == 1:
        axes = [axes]
    axes = np.array(axes).flatten()

    fig.suptitle(title, fontsize=40, y=0.995)

    for i, item in enumerate(viz_cache):
        ax = axes[i]

        t = item["t"]
        rep_rate = item["rep_rate"]
        gt_count = item["gt"]
        pred_count = item["pred"]
        diff = item["diff"]
        k_hat = item["k_hat"]
        entropy = item["entropy"]
        test_subj = item["test_subj"]
        fold = item["fold"]

        # the rate curve is smoothed for display; the cumulative count uses the raw rate
        rep_s = _smooth_1d(rep_rate, sigma=2.0)
        cum = np.cumsum(rep_rate) / fs

        # left axis: rep_rate
        ax.plot(t, rep_s, color=c_rate, linewidth=2.5, alpha=0.9)
        ax.fill_between(t, rep_s, color=c_rate, alpha=0.15)
        ax.set_ylabel("Rep Rate (reps/s)", color=c_rate, fontweight='bold', fontsize=24)
        ax.grid(True, linestyle='--', alpha=0.5)
        ax.tick_params(axis='both', which='major', labelsize=20)

        # right axis: cumulative count, with the ground-truth count as a dotted line
        ax2 = ax.twinx()
        ax2.plot(t, cum, color=c_count, linewidth=3.5, alpha=1.0)
        ax2.axhline(gt_count, linestyle=":", alpha=0.7)
        ax2.set_ylabel("Count", color=c_count, fontweight='bold', fontsize=24)
        ax2.tick_params(axis='y', labelcolor=c_count, labelsize=20)
        ax2.grid(False)

        ax.set_title(
            f"Fold {fold:2d} | Test: {test_subj} | Pred {pred_count:.2f} / GT {gt_count:.0f} (Diff {diff:+.2f})\n"
            f"k_hat={k_hat:.2f} | phase_entropy={entropy:.3f}",
            fontsize=34, pad=10
        )
        ax.set_xlabel("Time (sec)", fontweight='bold', fontsize=24)

    plt.tight_layout(rect=[0, 0, 1, 0.985])
    plt.subplots_adjust(hspace=0.5)
    plt.show()


# ---------------------------------------------------------------------
# 7) Main (A1: Unseen Activity only, WITH windowing)
# ---------------------------------------------------------------------
def build_label_tuples_from_table(subjects, act_id, count_table):
    """Collect ``(subject, act_id, gt_count)`` tuples for one activity.

    Args:
        subjects: Subject keys to look up, in the order the result should follow.
        act_id: Activity id to select.
        count_table: Dict like ``{"subject1": {6: 21, 7: 20}, ...}``.

    Returns:
        List of ``(subj, act_id, float(count))``; subjects missing from the
        table, or lacking an entry for ``act_id``, are left out.
    """
    return [
        (subj, act_id, float(count_table[subj][act_id]))
        for subj in subjects
        if subj in count_table and act_id in count_table[subj]
    ]


def main():
    """Run the A1 (unseen-activity) experiment end to end.

    Trains on every subject's TRAIN_ACT_ID trials (cut into sliding windows
    when USE_WINDOWING is set), then evaluates each subject's TEST_ACT_ID
    trial with windowed inference and prints the per-trial and aggregate
    count MAE. Returns early, after printing a message, when data or labels
    are missing.
    """
    CONFIG = {
        "seed": 42,
        "data_dir": "/content/drive/MyDrive/Colab Notebooks/HAR_data/MHEALTHDATASET",

        "COLUMN_NAMES": [
            'acc_chest_x', 'acc_chest_y', 'acc_chest_z',
            'ecg_1', 'ecg_2',
            'acc_ankle_x', 'acc_ankle_y', 'acc_ankle_z',
            'gyro_ankle_x', 'gyro_ankle_y', 'gyro_ankle_z',
            'mag_ankle_x', 'mag_ankle_y', 'mag_ankle_z',
            'acc_arm_x', 'acc_arm_y', 'acc_arm_z',
            'gyro_arm_x', 'gyro_arm_y', 'gyro_arm_z',
            'mag_arm_x', 'mag_arm_y', 'mag_arm_z',
            'activity_id'
        ],
        "TARGET_ACTIVITIES_MAP": {
            7: 'Frontal elevation of arms',
            12: 'Jump front & back'
        },

        # Train and test must use the same number of channels C so that a
        # single model can be evaluated on both activities.
        "ACT_FEATURE_MAP": {
             7: ['acc_chest_x', 'acc_chest_y', 'acc_chest_z',
                'acc_arm_x', 'acc_arm_y', 'acc_arm_z',
                'acc_ankle_x', 'acc_ankle_y', 'acc_ankle_z'],
             12: ['acc_chest_x', 'acc_chest_y', 'acc_chest_z',
                'acc_arm_x', 'acc_arm_y', 'acc_arm_z',
                'acc_ankle_x', 'acc_ankle_y', 'acc_ankle_z']
        },

        # Training Params
        "epochs": 100,
        "lr": 5e-4,
        "batch_size": 64,
        "fs": 50,

        # Model
        "hidden_dim": 128,
        "latent_dim": 16,
        "K_max": 6,

        # Loss Weights
        "lambda_recon": 1.0,
        "lambda_smooth": 0.05,
        "lambda_phase_ent": 0.01,
        "lambda_effk": 0.0005,

        # Time-warp consistency
        "lambda_warp": 0.2,      # NOTE: 1e-2~5e-2 was the suggested starting range; 0.2 is above it
        "warp_min": 0.5,          # lower bound of the length scale factor (factor < 1 shortens the sequence)
        "warp_max": 1.5,          # upper bound of the length scale factor (factor > 1 lengthens the sequence)
        "warp_detach_ref": True,  # only the warped branch is pulled towards the original count (more stable)

        # temperature (strength of the phase competition)
        "tau": 1.0,

        # A1 setting
        "TRAIN_ACT_ID": 7,
        "TEST_ACT_ID": 12,

        # -------------------------
        # Windowing
        # -------------------------
        "USE_WINDOWING": True,
        "WIN_SEC": 8.0,
        "STRIDE_SEC": 4.0,
        "DROP_LAST": True,

        # Must be a dict: subject -> {activity_id: ground-truth rep count}
        "COUNT_TABLE": {
            "subject1":  {12: 20, 7: 20},
            "subject2":  {12: 22, 7: 20},
            "subject3":  {12: 21, 7: 20},
            "subject4":  {12: 21, 7: 20},
            "subject5":  {12: 20, 7: 20},
            "subject6":  {12: 21, 7: 20},
            "subject7":  {12: 19, 7: 20},
            "subject8":  {12: 20, 7: 19},
            "subject9":  {12: 20, 7: 19},
            "subject10": {12: 20, 7: 20},
        },
    }

    set_strict_seed(CONFIG["seed"])
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")

    full_data = load_mhealth_dataset(
        CONFIG["data_dir"],
        CONFIG["TARGET_ACTIVITIES_MAP"],
        CONFIG["COLUMN_NAMES"]
    )
    if not full_data:
        return

    subjects = [f"subject{i}" for i in range(1, 11)]

    # -------------------------
    # A1 split (NO LOSO)
    # Train: all subjects x TRAIN_ACT_ID
    # Test : all subjects x TEST_ACT_ID (unseen activity)
    # -------------------------
    train_labels = build_label_tuples_from_table(subjects, CONFIG["TRAIN_ACT_ID"], CONFIG["COUNT_TABLE"])
    test_labels  = build_label_tuples_from_table(subjects, CONFIG["TEST_ACT_ID"],  CONFIG["COUNT_TABLE"])

    if len(train_labels) == 0:
        print("[Error] No train labels. Check COUNT_TABLE / TRAIN_ACT_ID.")
        return
    if len(test_labels) == 0:
        print("[Error] No test labels. Check COUNT_TABLE / TEST_ACT_ID.")
        return

    train_data = prepare_trial_list(train_labels, full_data, CONFIG["TARGET_ACTIVITIES_MAP"], CONFIG["ACT_FEATURE_MAP"])
    test_data  = prepare_trial_list(test_labels,  full_data, CONFIG["TARGET_ACTIVITIES_MAP"], CONFIG["ACT_FEATURE_MAP"])

    if len(train_data) == 0:
        print("[Error] train_data empty. (missing files or ACT_FEATURE_MAP mismatch)")
        return
    if len(test_data) == 0:
        print("[Error] test_data empty. (missing files or ACT_FEATURE_MAP mismatch)")
        return

    # -------------------------
    # Apply windowing (training set only)
    # -------------------------
    if CONFIG.get("USE_WINDOWING", False):
        train_windows = trial_list_to_windows(
            train_data,
            fs=CONFIG["fs"],
            win_sec=CONFIG["WIN_SEC"],
            stride_sec=CONFIG["STRIDE_SEC"],
            drop_last=CONFIG["DROP_LAST"],
        )
        print(f"[Windowing] train trials={len(train_data)} -> train windows={len(train_windows)}")
        train_data_for_loader = train_windows
    else:
        train_data_for_loader = train_data

    train_loader = DataLoader(
        TrialDataset(train_data_for_loader),
        batch_size=CONFIG["batch_size"],
        shuffle=True,
        collate_fn=collate_variable_length,
        num_workers=0
    )

    # channel count C is read off the first training trial, whose data is (T, C)
    input_ch = train_data[0]['data'].shape[1]
    model = KAutoCountModel(
        input_ch=input_ch,
        hidden_dim=CONFIG["hidden_dim"],
        latent_dim=CONFIG["latent_dim"],
        K_max=CONFIG["K_max"]
    ).to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=CONFIG["lr"])
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.5)

    print("\n" + "-"*80)
    print(f" >>> A1 Train: all subjects x act{CONFIG['TRAIN_ACT_ID']}  |  Test: all subjects x act{CONFIG['TEST_ACT_ID']} (unseen activity)")
    print("-"*80)

    # ---- Train ----
    # per-epoch statistics returned by train_one_epoch are discarded
    for epoch in range(CONFIG["epochs"]):
        _ = train_one_epoch(model, train_loader, optimizer, CONFIG, device)
        scheduler.step()

    # ---- Test (per-subject) ----
    model.eval()
    maes = []
    viz_cache = []

    for idx, item in enumerate(test_data):
        x_np = item["data"]  # (T,C)
        gt_count = float(item["count"])

        # Test-time count comes from windowed inference; this runs regardless of USE_WINDOWING
        pred_count, win_rates = predict_count_by_windowing(
            model,
            x_np=x_np,
            fs=CONFIG["fs"],
            win_sec=CONFIG["WIN_SEC"],
            stride_sec=CONFIG["STRIDE_SEC"],
            device=device,
            tau=CONFIG["tau"],
            batch_size=CONFIG["batch_size"]
        )

        mae = abs(pred_count - gt_count)
        maes.append(mae)

        # For visualization: one forward pass over the full trial.
        # rate_hat_full is not used below; only the aux outputs are read.
        with torch.no_grad():
            x_tensor = torch.tensor(x_np, dtype=torch.float32).transpose(0, 1).unsqueeze(0).to(device)
            rate_hat_full, _, _, aux = model(x_tensor, mask=None, tau=CONFIG["tau"])

            phase_p = aux["phase_p"].squeeze(0).detach().cpu().numpy()
            k_hat = float(aux["k_hat"].item())
            ent = compute_phase_entropy_mean(phase_p)
            rep_rate = aux["rep_rate_t"].squeeze(0).detach().cpu().numpy()
            T = rep_rate.shape[0]
            t = np.arange(T) / float(CONFIG["fs"])

        viz_cache.append({
            "fold": idx + 1,   # just the running index (there are no folds in A1)
            "test_subj": item["meta"],  # subject+act name
            "t": t,
            "rep_rate": rep_rate,
            "gt": gt_count,
            "pred": float(pred_count),   # windowed prediction
            "diff": float(pred_count - gt_count),
            "k_hat": k_hat,
            "entropy": ent,
            "phase_p": phase_p,
        })

        print(
            f"[Test {idx+1:02d}] {item['meta']} | Pred(win)={pred_count:.2f} / GT={gt_count:.2f} | "
            f"MAE={mae:.2f} | k_hat={k_hat:.2f} | ent={ent:.3f} | win_rate_mean={win_rates.mean():.3f}"
        )

    print("-"*80)
    print(f" >>> A1 Final MAE mean: {np.mean(maes):.3f}")
    print(f" >>> A1 Final MAE std : {np.std(maes):.3f}")
    print("-"*80)

    # Visualization (optional; uncomment to enable)
    # plot_folds_test_subplot(
    #     viz_cache,
    #     fs=CONFIG["fs"],
    #     title=f"A1 TEST visualization | Train act{CONFIG['TRAIN_ACT_ID']} -> Test act{CONFIG['TEST_ACT_ID']}"
    # )

    # for item in viz_cache:
    #     plot_phase_heatmap_and_dominant(
    #         item["phase_p"],
    #         fs=CONFIG["fs"],
    #         title=f"{item['test_subj']} | k_hat={item['k_hat']:.2f} | ent={item['entropy']:.3f}",
    #         max_T=2000
    #     )


# Entry point: run the experiment only when this cell/file is executed directly.
if __name__ == "__main__":
    main()


Device: cuda
Loading 10 subjects from /content/drive/MyDrive/Colab Notebooks/HAR_data/MHEALTHDATASET...
[Windowing] train trials=10 -> train windows=132

--------------------------------------------------------------------------------
 >>> A1 Train: all subjects x act7  |  Test: all subjects x act12 (unseen activity)
--------------------------------------------------------------------------------
[Test 01] subject1_Jump front & back | Pred(win)=7.45 / GT=20.00 | MAE=12.55 | k_hat=1.93 | ent=0.743 | win_rate_mean=0.347
[Test 02] subject2_Jump front & back | Pred(win)=8.88 / GT=22.00 | MAE=13.12 | k_hat=1.66 | ent=0.608 | win_rate_mean=0.434
[Test 03] subject3_Jump front & back | Pred(win)=10.63 / GT=21.00 | MAE=10.37 | k_hat=1.45 | ent=0.513 | win_rate_mean=0.519
[Test 04] subject4_Jump front & back | Pred(win)=9.70 / GT=21.00 | MAE=11.30 | k_hat=1.47 | ent=0.511 | win_rate_mean=0.474
[Test 05] subject5_Jump front & back | Pred(win)=6.06 / GT=20.00 | MAE=13.94 | k_hat=2.02 | ent=0.801 |

In [12]:
# =========================
# Count-only K-auto (Multi-event) version  (NO manual Pair/lag/overlap/balance)
#
# 핵심 아이디어
# - Micro-event Rate 예측: 모델은 하나의 통합된 속도가 아니라, K_max개의 서로 다른 '작은 단위 동작'의 속도 흐름(r_k(t))을 예측
# - 샘플마다 "rep당 micro-event 개수" k_hat(>=1)을 스스로 추정
# - 우리가 주는 감독은 오직 rep count(=20) -> rep rate만 맞추게
#
# 1) rate head를 amp(t) * softmax(phase) 형태로 바꿔서 K-stream 간 "경쟁"이 생기게 함
# 2) k_hat을 따로 head로 예측하지 않고, phase 사용 분포로부터 effK(=effective K)로 정의 (자동 K)
# 3) phase sparsity/exclusivity를 위해 (a) phase entropy loss, (b) effK usage loss 추가
#
# ✅ Added (ONLY): Time-warp consistency loss
#   - 입력을 시간축으로 stretch/compress (time-warp)해도 "count"는 동일해야 한다는 제약
#   - 모델 출력은 rate이므로, pred_count = rate_hat * duration 을 이용해 count 일관성을 강제
# =========================

import os
import glob
import random
import numpy as np
import pandas as pd
import seaborn as sns

import matplotlib.pyplot as plt
from scipy.ndimage import gaussian_filter1d

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader


# ---------------------------------------------------------------------
# 1) Strict Seeding
# ---------------------------------------------------------------------
def set_strict_seed(seed: int):
    """Seed every RNG used here and switch cuDNN to deterministic mode.

    Seeds Python's ``random``, NumPy and PyTorch (CPU, plus every CUDA device
    when CUDA is available) and exports ``PYTHONHASHSEED`` to the environment.

    Args:
        seed: Seed value shared by all generators.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)

    # Same seed for each library-level generator.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    # Disable cuDNN autotuning and request deterministic kernels.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


# ---------------------------------------------------------------------
# 2) Data Loading
# ---------------------------------------------------------------------
def load_mhealth_dataset(data_dir, target_activities_map, column_names):
    """Load mHealth subject logs and split each one by target activity.

    Args:
        data_dir: Directory containing ``mHealth_subject*.log`` files
            (read as tab-separated text without a header row).
        target_activities_map: Mapping ``{activity_id: activity_name}`` of the
            activities to extract.
        column_names: Names assigned to the first ``len(column_names)``
            columns of each log; must include ``'activity_id'``.

    Returns:
        ``{subject_key: {activity_name: DataFrame}}``. Each DataFrame holds
        the rows of that activity with the ``activity_id`` column dropped.
        Activities without any rows are omitted. An empty dict is returned
        when no log file matches.

    Notes:
        A file that fails to load is reported on stdout and skipped
        (best-effort loading); it does not abort the remaining files.
    """
    file_list = sorted(glob.glob(os.path.join(data_dir, "mHealth_subject*.log")))

    if not file_list:
        print(f"[Warning] No mHealth logs found in {data_dir}")
        return {}

    print(f"Loading {len(file_list)} subjects from {data_dir}...")

    full_dataset = {}
    for file_path in file_list:
        file_name = os.path.basename(file_path)
        subj_part = file_name.split('.')[0]

        # Normalise the key to "subject<N>" using the digits found in the file
        # stem. Only a failed int() conversion (no usable digits) falls back to
        # the raw stem; a bare `except` would also hide KeyboardInterrupt etc.
        try:
            subj_id_num = int(''.join(filter(str.isdigit, subj_part)))
            subj_key = f"subject{subj_id_num}"
        except ValueError:
            subj_key = subj_part

        try:
            df = pd.read_csv(file_path, sep="\t", header=None)
            df = df.iloc[:, :len(column_names)]
            df.columns = column_names

            subj_data = {}
            for label_code, activity_name in target_activities_map.items():
                activity_df = df[df['activity_id'] == label_code].copy()
                if not activity_df.empty:
                    subj_data[activity_name] = activity_df.drop(columns=['activity_id'])

            full_dataset[subj_key] = subj_data
        except Exception as e:
            # Deliberate best-effort: report the broken file and continue.
            print(f"Error loading {file_name}: {e}")

    return full_dataset


def prepare_trial_list(label_config, full_data, target_map, feature_map):
    """Build per-trial records with z-score normalised signals.

    Args:
        label_config: Iterable of ``(subject_key, activity_id, gt_count)``.
        full_data: ``{subject_key: {activity_name: DataFrame}}``.
        target_map: ``{activity_id: activity_name}``.
        feature_map: ``{activity_id: [column names to select]}``.

    Returns:
        List of dicts with keys ``'data'`` (float32 array, rows = time steps,
        columns = selected features, standardised per column over the trial),
        ``'count'`` (float) and ``'meta'`` (``"<subject>_<activity>"``).
        Entries whose subject or activity is missing are reported and skipped.
    """
    trials = []
    for subj, act_id, gt_count in label_config:
        act_name = target_map.get(act_id)
        feats = feature_map.get(act_id)

        if subj not in full_data or act_name not in full_data[subj]:
            print(f"[Skip] Missing data for {subj} - {act_name}")
            continue

        signal = full_data[subj][act_name][feats].values.astype(np.float32)

        # Per-column standardisation; the small constant avoids division by
        # zero on constant channels.
        mu = signal.mean(axis=0)
        sigma = signal.std(axis=0) + 1e-6

        trials.append({
            'data': (signal - mu) / sigma,
            'count': float(gt_count),
            'meta': f"{subj}_{act_name}"
        })

    return trials


# ---------------------------------------------------------------------
# 2.5) Windowing (added)
# ---------------------------------------------------------------------
def trial_list_to_windows(trial_list, fs, win_sec=4.0, stride_sec=2.0, drop_last=True):
    """
    Cut every trial into sliding windows labelled from the trial-level rate.

    trial_list: output of prepare_trial_list() (items with 'data' (T,C), 'count', 'meta')
    fs: sampling rate used to convert seconds to samples
    win_sec / stride_sec: window length and hop, in seconds
    drop_last: when False, an extra shorter window covering the samples from
               (last start + stride) to the end of the trial is appended

    Window labels are derived from the average rate of the whole trial:
      rate_trial   = count_total / total_duration
      count_window = rate_trial * window_duration

    A trial shorter than one window is emitted as a single variable-length
    window (no padding).

    return: list of dicts with keys 'data', 'count', 'meta', 'parent_meta',
            'parent_T', 'win_start', 'win_end'
    """
    win_len = int(round(win_sec * fs))
    stride = int(round(stride_sec * fs))
    assert win_len > 0 and stride > 0

    def _entry(segment, rate, parent, parent_len, st, ed):
        # One window record; its label is rate * window duration in seconds.
        return {
            "data": segment,
            "count": rate * ((ed - st) / float(fs)),
            "meta": f"{parent}__win[{st}:{ed}]",
            "parent_meta": parent,
            "parent_T": parent_len,
            "win_start": st,
            "win_end": ed,
        }

    windows = []
    for trial in trial_list:
        x = trial["data"]           # (T, C)
        T = x.shape[0]
        total_count = float(trial["count"])
        meta = trial["meta"]

        # Average reps/s over the whole trial.
        rate_trial = total_count / max(T / float(fs), 1e-6)

        if T < win_len:
            windows.append(_entry(x, rate_trial, meta, T, 0, T))
            continue

        starts = range(0, T - win_len + 1, stride)
        windows.extend(
            _entry(x[st:st + win_len], rate_trial, meta, T, st, st + win_len)
            for st in starts
        )

        if drop_last:
            continue

        # Optional trailing window, possibly shorter than win_len.
        tail_st = starts[-1] + stride
        if tail_st < T:
            windows.append(_entry(x[tail_st:T], rate_trial, meta, T, tail_st, T))

    return windows


def predict_count_by_windowing(model, x_np, fs, win_sec, stride_sec, device, tau=1.0, batch_size=64):
    """
    Predict the repetition count of one trial by averaging window-level rates.

    model: callable returning (rate_hat, ...) as a 4-tuple for input (B,C,T)
    x_np: (T, C) normalized signal
    fs: sampling rate
    win_sec / stride_sec: window length and hop, in seconds
    device: device the input windows are moved to
    tau: softmax temperature forwarded to the model
    batch_size: number of windows per forward pass

    Only full windows are fed to the model; the count is the mean window rate
    multiplied by the duration of the whole trial. A trial no longer than one
    window is passed through the model in a single call.

    The model is switched to eval mode and is left in eval mode on return.

    return: pred_count (float), window_rates (np.ndarray)
    """
    win_len = int(round(win_sec * fs))
    stride = int(round(stride_sec * fs))
    T = x_np.shape[0]
    total_dur = T / float(fs)

    # Inference runs in eval mode on BOTH paths; previously the short-trial
    # path skipped this and used whatever mode the model happened to be in.
    model.eval()

    if T <= win_len:
        x_tensor = torch.tensor(x_np, dtype=torch.float32).transpose(0, 1).unsqueeze(0).to(device)  # (1,C,T)
        with torch.no_grad():
            rate_hat, _, _, _ = model(x_tensor, mask=None, tau=tau)
        rate = float(rate_hat.item())
        return float(rate * total_dur), np.array([rate], dtype=np.float32)

    starts = list(range(0, T - win_len + 1, stride))
    windows = np.stack([x_np[st:st + win_len] for st in starts], axis=0)  # (N, win_len, C)

    xw = torch.tensor(windows, dtype=torch.float32).permute(0, 2, 1).to(device)  # (N, C, win_len)

    rates = []
    with torch.no_grad():
        for i in range(0, xw.shape[0], batch_size):
            xb = xw[i:i + batch_size]
            r_hat, _, _, _ = model(xb, mask=None, tau=tau)  # (B,)
            rates.append(r_hat.detach().cpu().numpy())

    rates = np.concatenate(rates, axis=0)  # (N,)
    rate_mean = float(rates.mean())
    pred_count = rate_mean * total_dur
    return float(pred_count), rates


class TrialDataset(Dataset):
    """Dataset over a list of trial/window dicts ('data', 'count', 'meta').

    Each item is returned as ``(data, count, meta)`` where ``data`` is the
    stored array converted to a float32 tensor with its two axes swapped
    (stored as (T, C), returned as (C, T)) and ``count`` is a float32 scalar
    tensor.
    """

    def __init__(self, trial_list):
        self.trials = trial_list

    def __len__(self):
        return len(self.trials)

    def __getitem__(self, idx):
        record = self.trials[idx]
        signal = torch.tensor(record['data'], dtype=torch.float32)
        label = torch.tensor(record['count'], dtype=torch.float32)
        # (T, C) -> (C, T)
        return signal.transpose(0, 1), label, record['meta']


def collate_variable_length(batch):
    """Collate variable-length (C, T) samples into one zero-padded batch.

    Args:
        batch: Sequence of ``(data, count, meta)`` with ``data`` of shape (C, T).

    Returns:
        Dict with
          "data":   (B, C, T_max) samples right-padded with zeros,
          "mask":   (B, T_max) 1 for valid steps, 0 for padding,
          "count":  (B,) stacked labels,
          "length": (B,) valid lengths as float32,
          "meta":   list of the per-sample meta values.
    """
    t_max = max(sample[0].shape[1] for sample in batch)
    n_ch = batch[0][0].shape[0]

    data_rows, mask_rows, labels, metas, valid_lens = [], [], [], [], []
    for signal, label, meta in batch:
        t_valid = signal.shape[1]
        n_pad = t_max - t_valid

        valid = torch.ones(t_valid)
        if n_pad > 0:
            # Right-pad both the signal and its validity mask with zeros.
            signal = torch.cat([signal, torch.zeros(n_ch, n_pad)], dim=1)
            valid = torch.cat([valid, torch.zeros(n_pad)], dim=0)

        valid_lens.append(t_valid)
        data_rows.append(signal)
        mask_rows.append(valid)
        labels.append(label)
        metas.append(meta)

    return {
        "data": torch.stack(data_rows),           # (B, C, T_max)
        "mask": torch.stack(mask_rows),           # (B, T_max)
        "count": torch.stack(labels),             # (B,)
        "length": torch.tensor(valid_lens, dtype=torch.float32),  # (B,)
        "meta": metas
    }


# ---------------------------------------------------------------------
# 3) Model
# ---------------------------------------------------------------------
class ManifoldEncoder(nn.Module):
    """1-D convolutional encoder producing one latent vector per time step.

    Two kernel-5 convolutions (padding 2, so the time length is kept) with
    ReLU, followed by a kernel-1 projection to ``latent_dim`` channels.
    """

    def __init__(self, input_ch, hidden_dim=128, latent_dim=16):
        super().__init__()
        layers = [
            nn.Conv1d(input_ch, hidden_dim, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv1d(hidden_dim, hidden_dim, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv1d(hidden_dim, latent_dim, kernel_size=1),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        # x: (B, C, T) -> conv stack gives (B, D, T) -> returned as (B, T, D)
        return self.net(x).transpose(1, 2)


class ManifoldDecoder(nn.Module):
    """1-D convolutional decoder mapping per-step latents back to channels.

    Mirrors the encoder layout: two kernel-5 convolutions (padding 2) with
    ReLU, then a kernel-1 projection to ``out_ch`` channels.
    """

    def __init__(self, latent_dim, hidden_dim, out_ch):
        super().__init__()
        layers = [
            nn.Conv1d(latent_dim, hidden_dim, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv1d(hidden_dim, hidden_dim, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv1d(hidden_dim, out_ch, kernel_size=1),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, z):
        # z: (B, T, D) -> channels-first (B, D, T) -> reconstruction (B, C, T)
        return self.net(z.transpose(1, 2))


class MultiRateHead(nn.Module):
    """Per-time-step head producing an amplitude and a K-way phase distribution.

    The final linear layer emits ``1 + K_max`` values per step: index 0 is the
    amplitude logit, the remaining ``K_max`` are phase logits.
    """

    def __init__(self, latent_dim=16, hidden=64, K_max=6):
        super().__init__()
        self.K_max = K_max
        layers = [
            nn.Linear(latent_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 1 + K_max),  # [amp_logit | phase_logits...]
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, z, tau=1.0):
        """
        z: (B,T,D), tau: softmax temperature applied to the phase logits
        return: amp (B,T) non-negative, phase (B,T,K) summing to 1 over K,
                phase_logits (B,T,K) raw logits
        """
        raw = self.net(z)                               # (B,T,1+K)
        amp_logit, phase_logits = raw[..., 0], raw[..., 1:]
        amp = F.softplus(amp_logit)                     # total micro intensity, >= 0
        phase = F.softmax(phase_logits / tau, dim=-1)   # competition across K streams
        return amp, phase, phase_logits


class KAutoCountModel(nn.Module):
    """
    Count-only model that infers the number of micro-events per rep by itself.

    - the rate head gives a total micro-event intensity amp(t) and a softmax
      phase distribution over K_max streams: r_k(t) = amp(t) * phase_k(t)
    - k_hat (per sample) is the effective number of phases in use,
      1 / sum_k(p_bar_k^2), with p_bar the time-averaged phase distribution
    - rep_rate(t) = amp(t) / k_hat
    - an encoder/decoder pair also reconstructs the input from the latent

    Note: the k_hidden argument is accepted but not used by this class.
    """
    def __init__(self, input_ch, hidden_dim=128, latent_dim=16, K_max=6, k_hidden=64):
        super().__init__()
        self.encoder = ManifoldEncoder(input_ch, hidden_dim, latent_dim)
        self.decoder = ManifoldDecoder(latent_dim, hidden_dim, input_ch)
        self.rate_head = MultiRateHead(latent_dim, hidden=hidden_dim, K_max=K_max)

        self._init_weights()

    def _init_weights(self):
        """Kaiming-init every conv/linear layer, then bias the amplitude logit."""
        for layer in self.modules():
            if not isinstance(layer, (nn.Conv1d, nn.Linear)):
                continue
            nn.init.kaiming_normal_(layer.weight, mode='fan_out', nonlinearity='relu')
            if layer.bias is not None:
                nn.init.constant_(layer.bias, 0)

        with torch.no_grad():
            head_bias = self.rate_head.net[-1].bias
            head_bias.zero_()
            head_bias[0].fill_(-2.0)  # only the amp logit starts at -2

    @staticmethod
    def _masked_mean_time(x, mask=None, eps=1e-6):
        """Mean over the time axis (dim 1), counting only steps where mask is set.

        x: (B,T) or (B,T,K); mask: (B,T) or None (plain mean over dim 1).
        Raises ValueError when a mask is given and x is neither 2-D nor 3-D.
        """
        if mask is None:
            return x.mean(dim=1)
        if x.dim() not in (2, 3):
            raise ValueError(f"Unsupported dim for masked mean: {x.dim()}")

        weights = mask.to(dtype=x.dtype, device=x.device)   # (B,T)
        if x.dim() == 3:
            weights = weights.unsqueeze(-1)                  # (B,T,1)
        return (x * weights).sum(dim=1) / (weights.sum(dim=1) + eps)

    def forward(self, x, mask=None, tau=1.0):
        """
        x: (B,C,T), mask: (B,T) or None, tau: phase softmax temperature
        return:
          avg_rep_rate: (B,) time-averaged rep rate
          z:            (B,T,D) latent sequence
          x_hat:        (B,C,T) reconstruction
          aux:          dict of intermediate tensors (see keys below)
        """
        z = self.encoder(x)              # (B,T,D)
        x_hat = self.decoder(z)          # (B,C,T)

        amp_t, phase_p, phase_logits = self.rate_head(z, tau=tau)
        rates_k_t = amp_t.unsqueeze(-1) * phase_p

        # The phases sum to 1, so the summed micro-event rate is just amp(t).
        micro_rate_t = amp_t

        # Effective number of phases from the time-averaged usage.
        p_bar = self._masked_mean_time(phase_p, mask)           # (B,K)
        k_hat = 1.0 / (p_bar.pow(2).sum(dim=1) + 1e-6)          # (B,) roughly in [1,K]

        # Rep rate = micro-event rate / micro-events per rep.
        rep_rate_t = micro_rate_t / (k_hat.unsqueeze(1) + 1e-6)    # (B,T)

        if mask is None:
            avg_rep_rate = rep_rate_t.mean(dim=1)
        else:
            # Zero out padded steps, then average over the valid steps only.
            rep_rate_t = rep_rate_t * mask
            avg_rep_rate = (rep_rate_t * mask).sum(dim=1) / (mask.sum(dim=1) + 1e-6)

        aux = {
            "rates_k_t": rates_k_t,          # (B,T,K)
            "phase_p": phase_p,              # (B,T,K)
            "phase_logits": phase_logits,    # (B,T,K)
            "micro_rate_t": micro_rate_t,    # (B,T)
            "rep_rate_t": rep_rate_t,        # (B,T)
            "k_hat": k_hat,                  # (B,)
        }
        return avg_rep_rate, z, x_hat, aux


# ---------------------------------------------------------------------
# 4) Loss utils
# ---------------------------------------------------------------------
def masked_recon_mse(x_hat, x, mask, eps=1e-6):
    """Mean squared reconstruction error over valid (unmasked) time steps.

    x_hat, x: (B,C,T); mask: (B,T) with 1 on valid steps.
    The squared error is summed over valid steps and divided by
    (number of valid steps * C) + eps.
    """
    weights = mask.to(dtype=x.dtype, device=x.device)
    masked_sq_err = (x_hat - x).pow(2) * weights.unsqueeze(1)   # mask broadcast over C
    n_valid = weights.sum() * x.shape[1] + eps
    return masked_sq_err.sum() / n_valid


def temporal_smoothness(v, mask=None, eps=1e-6):
    """
    L1 penalty on the first difference of v along time.

    v: (B,T); mask: (B,T) or None. With a mask, a difference is counted only
    when both neighbouring steps are valid.
    """
    step = (v[:, 1:] - v[:, :-1]).abs()  # (B,T-1)
    if mask is None:
        return step.mean()

    pair_valid = (mask[:, 1:] * mask[:, :-1]).to(dtype=step.dtype, device=step.device)
    return (step * pair_valid).sum() / (pair_valid.sum() + eps)


def phase_entropy_loss(phase_p, mask=None, eps=1e-8):
    """
    Time-wise exclusivity term: mean entropy of the phase distribution, so
    minimising it pushes each step towards a one-hot phase.

    phase_p: (B,T,K); mask: (B,T) or None (average over all steps).
    """
    log_p = (phase_p + eps).log()
    ent_t = -(phase_p * log_p).sum(dim=-1)  # (B,T)
    if mask is None:
        return ent_t.mean()
    return (ent_t * mask).sum() / (mask.sum() + eps)


def effK_usage_loss(phase_p, mask=None, eps=1e-6):
    """
    Overall usage sparsity: effective number of phases of the time-averaged
    phase usage, effK = 1 / sum(p_bar^2), roughly in [1,K].

    phase_p: (B,T,K); mask: (B,T) or None.
    return: (batch-mean effK, detached per-sample effK of shape (B,))
    """
    if mask is not None:
        weights = mask.to(dtype=phase_p.dtype, device=phase_p.device).unsqueeze(-1)  # (B,T,1)
        usage = (phase_p * weights).sum(dim=1) / (weights.sum(dim=1) + eps)
    else:
        usage = phase_p.mean(dim=1)  # (B,K)

    eff_k = 1.0 / (usage.pow(2).sum(dim=1) + eps)
    return eff_k.mean(), eff_k.detach()


def time_warp_batch(x, mask, length, warp_min=0.7, warp_max=1.3, min_len=8):
    """
    Randomly stretch/compress every sample of a padded batch along time.

    x: (B,C,Tmax) padded
    mask: (B,Tmax)  (only its dtype is used for the returned mask)
    length: (B,) float or int, valid length of each sample
    warp_min / warp_max: range of the per-sample uniform warp factor
    min_len: lower bound on the warped length
    return:
      xw: (B,C,Tw_max) zero-padded, linearly resampled samples
      mw: (B,Tw_max) 1 on valid warped steps, 0 on padding
      lw: (B,) warped valid lengths, same dtype as `length`
    """
    B, C, _ = x.shape
    device, dtype = x.device, x.dtype

    # Resample the valid part of each sample with its own tempo factor.
    warped = []
    for b in range(B):
        Tb = max(int(length[b].item()), 1)
        w = float(torch.empty(1, device=device).uniform_(warp_min, warp_max).item())
        Tw = max(int(round(Tb * w)), min_len)

        seg = F.interpolate(x[b:b + 1, :, :Tb], size=Tw, mode="linear", align_corners=False)  # (1,C,Tw)
        warped.append(seg.squeeze(0))  # (C,Tw)

    warped_lens = [seg.shape[1] for seg in warped]
    Tw_max = max(warped_lens, default=0)

    # Re-pad everything to the longest warped length.
    xw = torch.zeros((B, C, Tw_max), device=device, dtype=dtype)
    mw = torch.zeros((B, Tw_max), device=device, dtype=mask.dtype)
    for b, (seg, Tw) in enumerate(zip(warped, warped_lens)):
        xw[b, :, :Tw] = seg
        mw[b, :Tw] = 1

    lw = torch.tensor([float(t) for t in warped_lens], device=device, dtype=length.dtype)
    return xw, mw, lw


# ---------------------------------------------------------------------
# 5) Train  (ONLY MODIFIED PART: remove shift, add time-warp consistency)
# ---------------------------------------------------------------------
def train_one_epoch(model, loader, optimizer, config, device):
    """Run one training epoch and return per-batch averaged statistics.

    Args:
        model: Called as ``model(x, mask, tau=...)`` and expected to return
            ``(rate_hat, z, x_hat, aux)`` with ``aux["rep_rate_t"]`` and
            ``aux["phase_p"]``.
        loader: Iterable of batches shaped like the output of
            ``collate_variable_length`` ("data", "mask", "count", "length").
        optimizer: Optimizer stepped once per batch.
        config: Dict with required key ``"fs"`` and optional keys ``"tau"``
            (1.0), ``"lambda_recon"`` (1.0), ``"lambda_smooth"`` (0.05),
            ``"lambda_phase_ent"`` (0.01), ``"lambda_effk"`` (0.005),
            ``"lambda_warp"`` (0.0), ``"warp_min"`` (0.7), ``"warp_max"``
            (1.3) and ``"warp_detach_ref"`` (True).
        device: Device the batch tensors are moved to.

    Returns:
        Dict of floats with keys 'loss', 'loss_rate', 'loss_recon',
        'loss_smooth', 'loss_phase_ent', 'loss_effk', 'loss_warp' and
        'mae_count', each averaged over the number of batches. All values are
        0.0 when the loader yields no batches.
    """
    model.train()
    stats = {k: 0.0 for k in [
        'loss', 'loss_rate', 'loss_recon', 'loss_smooth', 'loss_phase_ent', 'loss_effk', 'loss_warp',
        'mae_count'
    ]}

    fs = config["fs"]
    tau = config.get("tau", 1.0)

    lam_recon = config.get("lambda_recon", 1.0)
    lam_smooth = config.get("lambda_smooth", 0.05)
    lam_phase_ent = config.get("lambda_phase_ent", 0.01)
    lam_effk = config.get("lambda_effk", 0.005)

    # Time-warp consistency weight & warp-factor range
    lam_warp = config.get("lambda_warp", 0.0)
    warp_min = config.get("warp_min", 0.7)
    warp_max = config.get("warp_max", 1.3)
    warp_detach_ref = bool(config.get("warp_detach_ref", True))

    for batch in loader:
        x = batch["data"].to(device)         # (B,C,T)
        mask = batch["mask"].to(device)      # (B,T)
        y_count = batch["count"].to(device)  # (B,)
        length = batch["length"].to(device)  # (B,)

        duration = torch.clamp(length / fs, min=1e-6)  # sec
        y_rate = y_count / duration                    # reps/s

        optimizer.zero_grad()

        rate_hat, z, x_hat, aux = model(x, mask, tau=tau)

        # (1) match the average rep rate (MSE)
        loss_rate = F.mse_loss(rate_hat, y_rate)

        # (2) reconstruction
        loss_recon = masked_recon_mse(x_hat, x, mask)

        # (3) smoothness of rep_rate_t
        loss_smooth = temporal_smoothness(aux["rep_rate_t"], mask)

        # (4) per-step phase exclusivity (entropy)
        loss_phase_ent = phase_entropy_loss(aux["phase_p"], mask)

        # (5) overall effective-K usage
        loss_effk, _ = effK_usage_loss(aux["phase_p"], mask)

        # (6) time-warp consistency (count level):
        #     pred_count = rate_hat * duration should not change after warping
        if lam_warp > 0.0:
            xw, mw, lw = time_warp_batch(
                x=x, mask=mask, length=length,
                warp_min=warp_min, warp_max=warp_max, min_len=8
            )
            dur_w = torch.clamp(lw / fs, min=1e-6)

            rate_hat_w, _, _, _ = model(xw, mw, tau=tau)

            pred_count = rate_hat * duration
            pred_count_w = rate_hat_w * dur_w

            # With a detached reference only the warped branch is pulled
            # towards the original prediction.
            if warp_detach_ref:
                pred_count_ref = pred_count.detach()
            else:
                pred_count_ref = pred_count

            loss_warp = F.l1_loss(pred_count_w, pred_count_ref)
        else:
            loss_warp = torch.zeros((), device=device)

        loss = (loss_rate
                + lam_recon * loss_recon
                + lam_smooth * loss_smooth
                + lam_phase_ent * loss_phase_ent
                + lam_effk * loss_effk
                + lam_warp * loss_warp)

        loss.backward()
        optimizer.step()

        # MAE on count: predicted rate * duration = predicted count
        count_hat = rate_hat.detach() * duration
        stats['loss'] += loss.item()
        stats['loss_rate'] += loss_rate.item()
        stats['loss_recon'] += loss_recon.item()
        stats['loss_smooth'] += loss_smooth.item()
        stats['loss_phase_ent'] += loss_phase_ent.item()
        stats['loss_effk'] += loss_effk.item()
        stats['loss_warp'] += float(loss_warp.item())
        stats['mae_count'] += torch.abs(count_hat - y_count).mean().item()

    # An empty loader would otherwise divide by zero; report zeros instead.
    n_batches = max(len(loader), 1)
    return {k: v / n_batches for k, v in stats.items()}


# ---------------------------------------------------------------------
# 6) Visualization helpers (subject-wise subplot)
# ---------------------------------------------------------------------
def _smooth_1d(y, sigma=2.0):
    """
    지저분한 노이즈를 다듬어서 시각화
    """
    y = np.asarray(y, dtype=np.float32)
    return gaussian_filter1d(y, sigma=sigma)


def compute_phase_entropy_mean(phase_p_np, eps=1e-8):
    """
    phase_p_np: (T,K) array of per-step phase probabilities
    return: entropy of each step averaged over time (python float)
    """
    probs = np.asarray(phase_p_np, dtype=np.float32)
    per_step = -np.sum(probs * np.log(probs + eps), axis=1)  # (T,)
    return float(np.mean(per_step))


def downsample_time_axis(arr, max_T=2000):
    """
    arr: (T, ...) numpy array
    Long sequences make the plots unreadable, so when T exceeds max_T the
    first axis is subsampled to max_T evenly spaced indices.
    return: arr_ds, idx (indices into the original time axis)
    """
    n_steps = arr.shape[0]
    if n_steps > max_T:
        keep = np.linspace(0, n_steps - 1, max_T).astype(int)
        return arr[keep], keep
    return arr, np.arange(n_steps)


def plot_phase_heatmap_and_dominant(
    phase_p_np,
    fs,
    title="phase_p heatmap + dominant phase",
    max_T=2000
):
    """
    Show a two-row figure: a heatmap of the phase probabilities over time
    (top) and a strip with the index of the dominant phase at each step
    (bottom). The figure is displayed with plt.show(); nothing is returned.

    phase_p_np: (T, K) numpy array
    fs: sampling rate, used to convert sample indices to seconds
    title: title of the heatmap axis
    max_T: maximum number of time steps kept for plotting (longer inputs are
           subsampled with downsample_time_axis)
    """
    phase_p_np = np.asarray(phase_p_np, dtype=np.float32)
    assert phase_p_np.ndim == 2, f"phase_p_np must be (T,K), got {phase_p_np.shape}"

    # (1) downsample for visualization
    phase_ds, idx = downsample_time_axis(phase_p_np, max_T=max_T)  # (T',K)
    Tds, K = phase_ds.shape
    # idx holds original sample indices, so the time axis stays in real seconds
    t_sec = idx / float(fs)

    # (2) dominant phase (argmax)
    dom = np.argmax(phase_ds, axis=1)  # (T',)

    # (3) plot
    fig = plt.figure(figsize=(30, 10))
    gs = fig.add_gridspec(2, 1, height_ratios=[4, 1], hspace=0.25)

    # --- Heatmap (top) ---
    ax0 = fig.add_subplot(gs[0, 0])
    im = ax0.imshow(
        phase_ds.T,                 # (K, T')
        aspect="auto",
        origin="lower",
        interpolation="nearest",
        extent=[t_sec[0], t_sec[-1], 0, K]   # x=time, y=phase index range
    )
    ax0.set_title(title, fontsize=24, pad=10)
    ax0.set_ylabel("Phase k", fontsize=18)
    ax0.set_xlabel("Time (sec)", fontsize=18)
    cbar = fig.colorbar(im, ax=ax0, fraction=0.015, pad=0.01)
    cbar.set_label("phase_p(t,k)", fontsize=14)

    # --- Dominant phase timeline (bottom) ---
    ax1 = fig.add_subplot(gs[1, 0], sharex=ax0)
    ax1.imshow(
        dom[None, :],               # (1, T')
        aspect="auto",
        origin="lower",
        interpolation="nearest",
        extent=[t_sec[0], t_sec[-1], 0, 1]
    )
    ax1.set_yticks([])
    ax1.set_ylabel("dominant", fontsize=14)
    ax1.set_xlabel("Time (sec)", fontsize=18)

    plt.tight_layout()
    plt.show()


def plot_folds_test_subplot(viz_cache, fs, title="Fold-wise TEST visualization (only test_subj)"):
    """
    Draw one subplot per cached test result: the smoothed rep rate on the left
    axis and the cumulative count (with the ground-truth count as a dotted
    horizontal line) on the right axis. The figure is shown with plt.show().

    viz_cache: list of dict
      each dict contains:
        - 'fold', 'test_subj', 't', 'rep_rate', 'gt', 'pred', 'diff', 'k_hat', 'entropy'
    fs: sampling rate, used to turn the cumulative sum of rep_rate into a count
    title: figure-level title

    Prints a message and returns early when viz_cache is None or empty.
    Note: sns.set_theme() changes the global plotting style.
    """
    if viz_cache is None or len(viz_cache) == 0:
        print("[plot_folds_test_subplot] viz_cache is empty")
        return

    sns.set_theme(style="whitegrid", context="notebook", font_scale=2.0)
    colors = sns.color_palette("muted")
    c_rate = colors[0]
    c_count = colors[1]

    n = len(viz_cache)
    fig, axes = plt.subplots(n, 1, figsize=(36, 9 * n), sharex=False)
    # plt.subplots returns a bare Axes for n == 1; normalise to a flat array
    if n == 1:
        axes = [axes]
    axes = np.array(axes).flatten()

    fig.suptitle(title, fontsize=40, y=0.995)

    for i, item in enumerate(viz_cache):
        ax = axes[i]

        t = item["t"]
        rep_rate = item["rep_rate"]
        gt_count = item["gt"]
        pred_count = item["pred"]
        diff = item["diff"]
        k_hat = item["k_hat"]
        entropy = item["entropy"]
        test_subj = item["test_subj"]
        fold = item["fold"]

        # Smoothed curve is for display only; the cumulative count below
        # integrates the unsmoothed rate (sum of rate / fs).
        rep_s = _smooth_1d(rep_rate, sigma=2.0)
        cum = np.cumsum(rep_rate) / fs

        # Left axis: rep rate
        ax.plot(t, rep_s, color=c_rate, linewidth=2.5, alpha=0.9)
        ax.fill_between(t, rep_s, color=c_rate, alpha=0.15)
        ax.set_ylabel("Rep Rate (reps/s)", color=c_rate, fontweight='bold', fontsize=24)
        ax.grid(True, linestyle='--', alpha=0.5)
        ax.tick_params(axis='both', which='major', labelsize=20)

        # Right axis: cumulative count
        ax2 = ax.twinx()
        ax2.plot(t, cum, color=c_count, linewidth=3.5, alpha=1.0)
        ax2.axhline(gt_count, linestyle=":", alpha=0.7)
        ax2.set_ylabel("Count", color=c_count, fontweight='bold', fontsize=24)
        ax2.tick_params(axis='y', labelcolor=c_count, labelsize=20)
        ax2.grid(False)

        ax.set_title(
            f"Fold {fold:2d} | Test: {test_subj} | Pred {pred_count:.2f} / GT {gt_count:.0f} (Diff {diff:+.2f})\n"
            f"k_hat={k_hat:.2f} | phase_entropy={entropy:.3f}",
            fontsize=34, pad=10
        )
        ax.set_xlabel("Time (sec)", fontweight='bold', fontsize=24)

    plt.tight_layout(rect=[0, 0, 1, 0.985])
    plt.subplots_adjust(hspace=0.5)
    plt.show()


# ---------------------------------------------------------------------
# 7) Main (A1: Unseen Activity only, WITH windowing)
# ---------------------------------------------------------------------
def build_label_tuples_from_table(subjects, act_id, count_table):
    """
    Collect (subject, activity, ground-truth count) labels for one activity.

    count_table: dict like { "subject1": {6:21, 7:20}, ... }
    returns: list of (subj, act_id, gt_count) with gt_count as float; subjects
             missing from the table, or lacking an entry for act_id, are skipped
    """
    return [
        (subj, act_id, float(count_table[subj][act_id]))
        for subj in subjects
        if subj in count_table and act_id in count_table[subj]
    ]


def main() -> None:
    """
    Run the A1 (unseen-activity) experiment end to end.

    Steps, in order:
      1. Build CONFIG, seed the RNGs, and pick the device.
      2. Load the mHealth logs; return early if nothing was loaded.
      3. Build label tuples for TRAIN_ACT_ID and TEST_ACT_ID over
         subject1..subject10 from COUNT_TABLE, then turn them into trial lists.
      4. Optionally cut the training trials into windows.
      5. Train KAutoCountModel for CONFIG["epochs"] epochs (Adam + StepLR).
      6. For each test trial, predict the count via windowed inference,
         run one full-sequence forward pass to collect diagnostics,
         and print per-trial results plus the MAE mean/std.

    Returns None. Any missing-data condition prints a message and returns
    early instead of raising.
    """
    CONFIG = {
        "seed": 42,
        "data_dir": "/content/drive/MyDrive/Colab Notebooks/HAR_data/MHEALTHDATASET",

        "COLUMN_NAMES": [
            'acc_chest_x', 'acc_chest_y', 'acc_chest_z',
            'ecg_1', 'ecg_2',
            'acc_ankle_x', 'acc_ankle_y', 'acc_ankle_z',
            'gyro_ankle_x', 'gyro_ankle_y', 'gyro_ankle_z',
            'mag_ankle_x', 'mag_ankle_y', 'mag_ankle_z',
            'acc_arm_x', 'acc_arm_y', 'acc_arm_z',
            'gyro_arm_x', 'gyro_arm_y', 'gyro_arm_z',
            'mag_arm_x', 'mag_arm_y', 'mag_arm_z',
            'activity_id'
        ],
        "TARGET_ACTIVITIES_MAP": {
            7: 'Frontal elevation of arms',
            12: 'Jump front & back'
        },

        # Train and test must keep the same channel count C so that the
        # same model can be evaluated on both activities.
        "ACT_FEATURE_MAP": {
             7: ['acc_chest_x', 'acc_chest_y', 'acc_chest_z',
                'acc_arm_x', 'acc_arm_y', 'acc_arm_z',
                'acc_ankle_x', 'acc_ankle_y', 'acc_ankle_z'],
             12: ['acc_chest_x', 'acc_chest_y', 'acc_chest_z',
                'acc_arm_x', 'acc_arm_y', 'acc_arm_z',
                'acc_ankle_x', 'acc_ankle_y', 'acc_ankle_z']
        },

        # Training Params
        "epochs": 100,
        "lr": 5e-4,
        "batch_size": 64,
        "fs": 50,

        # Model
        "hidden_dim": 128,
        "latent_dim": 16,
        "K_max": 6,

        # Loss Weights
        "lambda_recon": 1.0,
        "lambda_smooth": 0.05,
        "lambda_phase_ent": 0.01,
        "lambda_effk": 0.0005,

        # Added: Time-warp consistency (the only new component)
        "lambda_warp": 0.2,      # original note recommends starting at 1e-2~5e-2; NOTE(review): 0.2 is above that range
        "warp_min": 0.5,          # slower (stretch)
        "warp_max": 1.5,          # faster (compress)
        "warp_detach_ref": True,  # intended: only the warped branch is pulled toward the original count (stability)

        # temperature (strength of the phase competition)
        "tau": 1.0,

        # A1 setting
        "TRAIN_ACT_ID": 12,
        "TEST_ACT_ID": 7,

        # -------------------------
        # Windowing (added)
        # -------------------------
        "USE_WINDOWING": True,
        "WIN_SEC": 8.0,
        "STRIDE_SEC": 4.0,
        "DROP_LAST": True,

        # Must be in dict form: subject -> {activity_id: rep count}
        "COUNT_TABLE": {
            "subject1":  {12: 20, 7: 20},
            "subject2":  {12: 22, 7: 20},
            "subject3":  {12: 21, 7: 20},
            "subject4":  {12: 21, 7: 20},
            "subject5":  {12: 20, 7: 20},
            "subject6":  {12: 21, 7: 20},
            "subject7":  {12: 19, 7: 20},
            "subject8":  {12: 20, 7: 19},
            "subject9":  {12: 20, 7: 19},
            "subject10": {12: 20, 7: 20},
        },
    }

    set_strict_seed(CONFIG["seed"])
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")

    full_data = load_mhealth_dataset(
        CONFIG["data_dir"],
        CONFIG["TARGET_ACTIVITIES_MAP"],
        CONFIG["COLUMN_NAMES"]
    )
    # Nothing loaded: stop here without printing anything further.
    if not full_data:
        return

    subjects = [f"subject{i}" for i in range(1, 11)]

    # -------------------------
    # A1 split (NO LOSO)
    # Train: all subjects x TRAIN_ACT_ID
    # Test : all subjects x TEST_ACT_ID (unseen activity)
    # -------------------------
    train_labels = build_label_tuples_from_table(subjects, CONFIG["TRAIN_ACT_ID"], CONFIG["COUNT_TABLE"])
    test_labels  = build_label_tuples_from_table(subjects, CONFIG["TEST_ACT_ID"],  CONFIG["COUNT_TABLE"])

    if len(train_labels) == 0:
        print("[Error] No train labels. Check COUNT_TABLE / TRAIN_ACT_ID.")
        return
    if len(test_labels) == 0:
        print("[Error] No test labels. Check COUNT_TABLE / TEST_ACT_ID.")
        return

    train_data = prepare_trial_list(train_labels, full_data, CONFIG["TARGET_ACTIVITIES_MAP"], CONFIG["ACT_FEATURE_MAP"])
    test_data  = prepare_trial_list(test_labels,  full_data, CONFIG["TARGET_ACTIVITIES_MAP"], CONFIG["ACT_FEATURE_MAP"])

    if len(train_data) == 0:
        print("[Error] train_data empty. (missing files or ACT_FEATURE_MAP mismatch)")
        return
    if len(test_data) == 0:
        print("[Error] test_data empty. (missing files or ACT_FEATURE_MAP mismatch)")
        return

    # -------------------------
    # Apply windowing (to the training trials only; the test trials are
    # windowed later, inside predict_count_by_windowing)
    # -------------------------
    if CONFIG.get("USE_WINDOWING", False):
        train_windows = trial_list_to_windows(
            train_data,
            fs=CONFIG["fs"],
            win_sec=CONFIG["WIN_SEC"],
            stride_sec=CONFIG["STRIDE_SEC"],
            drop_last=CONFIG["DROP_LAST"],
        )
        print(f"[Windowing] train trials={len(train_data)} -> train windows={len(train_windows)}")
        train_data_for_loader = train_windows
    else:
        train_data_for_loader = train_data

    train_loader = DataLoader(
        TrialDataset(train_data_for_loader),
        batch_size=CONFIG["batch_size"],
        shuffle=True,
        collate_fn=collate_variable_length,
        num_workers=0
    )

    # Channel count is read from the un-windowed trial array (second axis).
    input_ch = train_data[0]['data'].shape[1]
    model = KAutoCountModel(
        input_ch=input_ch,
        hidden_dim=CONFIG["hidden_dim"],
        latent_dim=CONFIG["latent_dim"],
        K_max=CONFIG["K_max"]
    ).to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=CONFIG["lr"])
    # Halve the learning rate every 30 epochs.
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.5)

    print("\n" + "-"*80)
    print(f" >>> A1 Train: all subjects x act{CONFIG['TRAIN_ACT_ID']}  |  Test: all subjects x act{CONFIG['TEST_ACT_ID']} (unseen activity)")
    print("-"*80)

    # ---- Train ----
    # The per-epoch return value of train_one_epoch is discarded.
    for epoch in range(CONFIG["epochs"]):
        _ = train_one_epoch(model, train_loader, optimizer, CONFIG, device)
        scheduler.step()

    # ---- Test (per-subject) ----
    model.eval()
    maes = []       # absolute count error per test trial
    viz_cache = []  # per-trial diagnostics for the (currently disabled) plots below

    for idx, item in enumerate(test_data):
        x_np = item["data"]  # (T,C)
        gt_count = float(item["count"])

        # Compute pred_count via windowed inference (test)
        pred_count, win_rates = predict_count_by_windowing(
            model,
            x_np=x_np,
            fs=CONFIG["fs"],
            win_sec=CONFIG["WIN_SEC"],
            stride_sec=CONFIG["STRIDE_SEC"],
            device=device,
            tau=CONFIG["tau"],
            batch_size=CONFIG["batch_size"]
        )

        mae = abs(pred_count - gt_count)
        maes.append(mae)

        # For visualization (a single full-sequence forward pass, as before).
        # Only the aux outputs are used; the reported count comes from the
        # windowed prediction above.
        with torch.no_grad():
            # (T,C) -> (1,C,T) before feeding the model
            x_tensor = torch.tensor(x_np, dtype=torch.float32).transpose(0, 1).unsqueeze(0).to(device)
            rate_hat_full, _, _, aux = model(x_tensor, mask=None, tau=CONFIG["tau"])

            phase_p = aux["phase_p"].squeeze(0).detach().cpu().numpy()
            k_hat = float(aux["k_hat"].item())
            ent = compute_phase_entropy_mean(phase_p)
            rep_rate = aux["rep_rate_t"].squeeze(0).detach().cpu().numpy()
            T = rep_rate.shape[0]
            t = np.arange(T) / float(CONFIG["fs"])  # time axis in seconds

        viz_cache.append({
            "fold": idx + 1,   # just the 1-based trial index
            "test_subj": item["meta"],  # subject+act name
            "t": t,
            "rep_rate": rep_rate,
            "gt": gt_count,
            "pred": float(pred_count),   # windowed prediction
            "diff": float(pred_count - gt_count),
            "k_hat": k_hat,
            "entropy": ent,
            "phase_p": phase_p,
        })

        print(
            f"[Test {idx+1:02d}] {item['meta']} | Pred(win)={pred_count:.2f} / GT={gt_count:.2f} | "
            f"MAE={mae:.2f} | k_hat={k_hat:.2f} | ent={ent:.3f} | win_rate_mean={win_rates.mean():.3f}"
        )

    print("-"*80)
    print(f" >>> A1 Final MAE mean: {np.mean(maes):.3f}")
    print(f" >>> A1 Final MAE std : {np.std(maes):.3f}")
    print("-"*80)

    # Visualization (keep if desired) -- currently disabled
    # plot_folds_test_subplot(
    #     viz_cache,
    #     fs=CONFIG["fs"],
    #     title=f"A1 TEST visualization | Train act{CONFIG['TRAIN_ACT_ID']} -> Test act{CONFIG['TEST_ACT_ID']}"
    # )

    # for item in viz_cache:
    #     plot_phase_heatmap_and_dominant(
    #         item["phase_p"],
    #         fs=CONFIG["fs"],
    #         title=f"{item['test_subj']} | k_hat={item['k_hat']:.2f} | ent={item['entropy']:.3f}",
    #         max_T=2000
    #     )


# Script entry point: run the experiment only when executed directly, not on import.
if __name__ == "__main__":
    main()


Device: cuda
Loading 10 subjects from /content/drive/MyDrive/Colab Notebooks/HAR_data/MHEALTHDATASET...
[Windowing] train trials=10 -> train windows=40

--------------------------------------------------------------------------------
 >>> A1 Train: all subjects x act12  |  Test: all subjects x act7 (unseen activity)
--------------------------------------------------------------------------------
[Test 01] subject1_Frontal elevation of arms | Pred(win)=58.06 / GT=20.00 | MAE=38.06 | k_hat=2.77 | ent=0.763 | win_rate_mean=0.945
[Test 02] subject2_Frontal elevation of arms | Pred(win)=52.44 / GT=20.00 | MAE=32.44 | k_hat=3.59 | ent=0.812 | win_rate_mean=0.788
[Test 03] subject3_Frontal elevation of arms | Pred(win)=66.63 / GT=20.00 | MAE=46.63 | k_hat=2.79 | ent=0.755 | win_rate_mean=0.986
[Test 04] subject4_Frontal elevation of arms | Pred(win)=68.65 / GT=20.00 | MAE=48.65 | k_hat=2.64 | ent=0.756 | win_rate_mean=1.047
[Test 05] subject5_Frontal elevation of arms | Pred(win)=57.71 / GT=2