# In [1]:
# =========================
# Count-only K-auto (Multi-event) + Windowing version
# + ✅ Robustness: Clean-train → Corrupted-test (test only)
#
# Protocol (reflected in code):
# - TRAIN: always clean (trial -> sliding windows, window label from trial-average rate)
# - TEST : keep each trial intact, apply the corruption to the normalized (clean) signal, then evaluate
# - TEST pred: windowing inference (mean of window rates → trial count)
# - clean→clean baseline levels can be skipped via INCLUDE_BASELINE_LEVELS=False
#
# Stressors:
# 1) Noise: Gaussian / Bias drift / Spike
# 2) Sampling rate: fs 50→25→12.5 (anti-aliasing low-pass → decimate)
# 3) Axis dropout cases: Acc z-only / Gyro removed / Acc only / Gyro only (channel count kept: 0-masking)
# 4) Window mismatch: test win_sec ∈ {2,3,4,6,8}, stride=win/2 (50% overlap)
#
# Metrics:
# - MAE, MAPE (+ Worst-case)
# - Degradation curves saved as PNG
# =========================

import os
import glob
import random
import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
from scipy.ndimage import gaussian_filter1d

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

# scipy signal (for LPF)
try:
    from scipy.signal import butter, filtfilt
    _SCIPY_SIGNAL_OK = True
except Exception:
    _SCIPY_SIGNAL_OK = False


# ---------------------------------------------------------------------
# 1) Strict Seeding
# ---------------------------------------------------------------------
def set_strict_seed(seed: int):
    """
    Seed every RNG the script relies on and force deterministic cuDNN kernels.

    Seeds Python's ``random``, NumPy's global RNG and torch (CPU, plus every
    CUDA device when CUDA is available), stores the seed in the
    ``PYTHONHASHSEED`` environment variable, and disables cuDNN autotuning in
    favour of deterministic algorithms.

    Note: ``PYTHONHASHSEED`` only influences interpreters started afterwards;
    it does not change ``str`` hashing inside the current process.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed)
    np.random.seed(seed)

    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


# ---------------------------------------------------------------------
# 1.5) Build ALL_LABELS from COUNT_TABLE (minimal add)
# ---------------------------------------------------------------------
def build_all_labels_from_count_table(count_table, subjects, target_map, feature_map, allow_only_mapped=True):
    """
    Flatten a ground-truth count table into ``(subject, act_id, gt_count)`` tuples.

    Args:
        count_table: ``{act_id: {subject_key: gt_count}}``.
        subjects: subject keys to include; within each activity the output
            follows this order. Subjects absent from an activity are skipped.
        target_map / feature_map: activity-id lookups used for filtering.
        allow_only_mapped: when True, activities missing from either
            ``target_map`` or ``feature_map`` are left out (e.g. act ids
            without a registered name or feature list).

    Returns:
        List of ``(subject, int(act_id), int(gt_count))`` ordered by sorted
        act_id, then by ``subjects`` order.
    """
    def _usable(act_id):
        if not allow_only_mapped:
            return True
        return act_id in target_map and act_id in feature_map

    labels = []
    for act_id in sorted(count_table):
        if not _usable(act_id):
            continue
        per_subject = count_table[act_id]
        labels.extend(
            (subj, int(act_id), int(per_subject[subj]))
            for subj in subjects
            if subj in per_subject
        )
    return labels


# ---------------------------------------------------------------------
# 2) Data Loading
# ---------------------------------------------------------------------
def load_mhealth_dataset(data_dir, target_activities_map, column_names):
    """
    Load mHealth subject logs and split each subject's rows by activity.

    Every ``mHealth_subject*.log`` file in ``data_dir`` is read, in sorted
    file-name order, as a header-less tab-separated table. Only the first
    ``len(column_names)`` columns are kept, and they are renamed to
    ``column_names``, which must include ``'activity_id'``.

    Args:
        data_dir: directory holding the log files.
        target_activities_map: ``{activity_id: activity_name}`` of activities to keep.
        column_names: names assigned to the loaded columns.

    Returns:
        ``{subject_key: {activity_name: DataFrame without 'activity_id'}}``.
        ``subject_key`` is ``"subject<N>"`` built from the digits in the file
        name, or the bare file stem when it has no parsable digits.
        Activities without rows are omitted. Returns ``{}`` when no log file
        is found. A file that fails to load is reported and skipped (best effort).
    """
    file_list = sorted(glob.glob(os.path.join(data_dir, "mHealth_subject*.log")))
    if not file_list:
        print(f"[Warning] No mHealth logs found in {data_dir}")
        return {}

    print(f"Loading {len(file_list)} subjects from {data_dir}...")

    full_dataset = {}
    for file_path in file_list:
        file_name = os.path.basename(file_path)
        subj_part = file_name.split('.')[0]

        # int('') (no digits) raises ValueError -> fall back to the raw stem.
        # Only ValueError is caught so KeyboardInterrupt/SystemExit still propagate.
        digits = ''.join(filter(str.isdigit, subj_part))
        try:
            subj_key = f"subject{int(digits)}"
        except ValueError:
            subj_key = subj_part

        try:
            df = pd.read_csv(file_path, sep="\t", header=None)
            df = df.iloc[:, :len(column_names)]
            df.columns = column_names

            subj_data = {}
            for label_code, activity_name in target_activities_map.items():
                activity_df = df[df['activity_id'] == label_code].copy()
                if not activity_df.empty:
                    subj_data[activity_name] = activity_df.drop(columns=['activity_id'])

            full_dataset[subj_key] = subj_data
        except Exception as e:
            # Deliberate best effort: report the malformed/unreadable file and move on.
            print(f"Error loading {file_name}: {e}")

    return full_dataset


def prepare_trial_list(label_config, full_data, target_map, feature_map):
    """
    Turn ``(subject, act_id, gt_count)`` labels into normalized trial records.

    For each label whose subject/activity data and feature list are present,
    the selected feature columns are z-scored per channel. The trial's own
    clean mean and std are used, with ``1e-6`` added to the std to avoid
    division by zero. Labels with missing data are reported with a
    ``[Skip]`` message and omitted.

    Returns:
        List of dicts:
          'data'  : (T, C) float32 normalized array,
          'count' : total ground-truth count as float,
          'meta'  : ``"<subject>_<activity name>"``.
    """
    trials = []
    for subj, act_id, gt_count in label_config:
        act_name = target_map.get(act_id)
        feats = feature_map.get(act_id)

        has_data = subj in full_data and act_name in full_data[subj]
        if not has_data or feats is None:
            print(f"[Skip] Missing data for {subj} - {act_name}")
            continue

        raw = full_data[subj][act_name][feats].values.astype(np.float32)
        mu = raw.mean(axis=0)
        sigma = raw.std(axis=0) + 1e-6

        trials.append({
            'data': (raw - mu) / sigma,     # (T, C), per-trial z-score
            'count': float(gt_count),       # total reps in the trial
            'meta': f"{subj}_{act_name}",
        })

    return trials


# ---------------------------------------------------------------------
# 2.5) ✅ Windowing (TRAIN)
# ---------------------------------------------------------------------
def trial_list_to_windows(trial_list, fs, win_sec=8.0, stride_sec=4.0, drop_last=True):
    """
    Slice each (clean) training trial into sliding windows with count labels.

    Window labels assume a constant repetition rate inside a trial:
        rate_trial   = count_total / trial_duration      (reps/s)
        count_window = rate_trial * window_duration

    Args:
        trial_list: dicts with 'data' (T, C), 'count' (total reps) and 'meta'.
        fs: sampling rate in Hz.
        win_sec / stride_sec: window length / hop in seconds. Both are rounded
            to whole samples and must be at least one sample.
        drop_last: when False, one extra partial window is appended. It starts
            one stride after the last full window and runs to the trial end.

    A trial shorter than one window yields a single window holding the whole trial.

    Returns:
        List of dicts with 'data', 'count', 'meta' (``"<meta>__win[start:end]"``),
        'parent_meta', 'parent_T', 'win_start', 'win_end'.
    """
    win_len = int(round(win_sec * fs))
    stride = int(round(stride_sec * fs))
    assert win_len > 0 and stride > 0

    fs_f = float(fs)
    full_win_dur = win_len / fs_f
    out = []

    def _emit(segment, start, end, duration, parent, n_samples, rate):
        # One window record; the label scales the trial rate by the window length.
        out.append({
            "data": segment,
            "count": rate * duration,
            "meta": f"{parent}__win[{start}:{end}]",
            "parent_meta": parent,
            "parent_T": n_samples,
            "win_start": start,
            "win_end": end,
        })

    for trial in trial_list:
        x = trial["data"]
        n_samples = x.shape[0]
        parent = trial["meta"]
        rate = float(trial["count"]) / max(n_samples / fs_f, 1e-6)

        if n_samples < win_len:
            _emit(x, 0, n_samples, n_samples / fs_f, parent, n_samples, rate)
            continue

        starts = range(0, n_samples - win_len + 1, stride)
        for st in starts:
            _emit(x[st:st + win_len], st, st + win_len, full_win_dur, parent, n_samples, rate)

        if not drop_last:
            tail_start = starts[-1] + stride
            if tail_start < n_samples:
                _emit(x[tail_start:n_samples], tail_start, n_samples,
                      (n_samples - tail_start) / fs_f, parent, n_samples, rate)

    return out


# ---------------------------------------------------------------------
# 2.6) ✅ Windowing inference (TEST)
# ---------------------------------------------------------------------
def predict_count_by_windowing(model, x_np, fs, win_sec, stride_sec, device, tau=1.0, batch_size=64):
    """
    Estimate a trial's total repetition count by sliding-window inference (TEST).

    The trial is cut into windows of ``win_sec`` seconds with hop
    ``stride_sec`` (only windows fully inside the trial). The model predicts
    one rate (reps/s) per window. The window rates are averaged, and the mean
    rate is multiplied by the *full* trial duration, so samples after the last
    full window are covered by extrapolation. A trial no longer than one
    window gets a single forward pass over the whole trial.

    Args:
        model: called as ``model(x, mask=None, tau=tau)`` on ``(B, C, T)`` input;
            its first output is used as the per-sample rate.
        x_np: (T, C) numpy array, already normalized (and possibly corrupted).
        fs: sampling rate of ``x_np`` in Hz.
        win_sec / stride_sec: window length / hop in seconds. Each is rounded
            to whole samples and clamped to at least one sample.
        device: torch device used for inference.
        tau: phase softmax temperature forwarded to the model.
        batch_size: number of windows per forward pass (clamped to >= 1).

    Returns:
        ``(pred_count, window_rates)``: a float total count and an np.ndarray
        of per-window rates.
    """
    # Clamp so very short windows / strides cannot produce a zero-step range.
    win_len = max(1, int(round(win_sec * fs)))
    stride = max(1, int(round(stride_sec * fs)))
    T = x_np.shape[0]
    total_dur = T / float(fs)

    # Inference mode for BOTH paths (previously only the windowed path set it).
    model.eval()

    # Short trial -> single forward pass over the whole trial.
    if T <= win_len:
        x_tensor = torch.tensor(x_np, dtype=torch.float32).transpose(0, 1).unsqueeze(0).to(device)  # (1,C,T)
        with torch.no_grad():
            rate_hat, _, _, _ = model(x_tensor, mask=None, tau=tau)
        rate = float(rate_hat.item())
        return float(rate * total_dur), np.array([rate], dtype=np.float32)

    starts = range(0, T - win_len + 1, stride)
    windows = np.stack([x_np[st:st + win_len] for st in starts], axis=0)          # (N, win_len, C)
    xw = torch.tensor(windows, dtype=torch.float32).permute(0, 2, 1).to(device)  # (N, C, win_len)

    step = max(1, int(batch_size))
    chunks = []
    with torch.no_grad():
        for i in range(0, xw.shape[0], step):
            r_hat, _, _, _ = model(xw[i:i + step], mask=None, tau=tau)  # (B,)
            chunks.append(r_hat.detach().cpu().numpy())

    rates = np.concatenate(chunks, axis=0)  # (N,)
    pred_count = float(rates.mean()) * total_dur
    return float(pred_count), rates


# ---------------------------------------------------------------------
# 2.8) Dataset / Collate
# ---------------------------------------------------------------------
class TrialDataset(Dataset):
    """
    Map-style dataset over trial/window records.

    Each record is a dict with 'data' (T, C) array-like, 'count' (number) and
    'meta' (identifier string).
    """

    def __init__(self, trial_list):
        self.trials = trial_list

    def __len__(self):
        return len(self.trials)

    def __getitem__(self, idx):
        """Return ``(data (C, T) float32 tensor, scalar float32 count tensor, meta)``."""
        record = self.trials[idx]
        signal = torch.tensor(record['data'], dtype=torch.float32).permute(1, 0)  # (T, C) -> (C, T)
        target = torch.tensor(record['count'], dtype=torch.float32)
        return signal, target, record['meta']


def collate_variable_length(batch):
    """
    Collate ``(data (C, T_i), count, meta)`` samples of varying length.

    Each sample is right-padded with zeros to the longest ``T`` in the batch,
    and a float mask marks the real (1) vs padded (0) time steps.

    Returns a dict:
        'data'  : (B, C, T_max) tensor,
        'mask'  : (B, T_max) float tensor,
        'count' : (B,) stacked counts,
        'length': (B,) float32 original lengths,
        'meta'  : list of meta values.
    """
    datas, counts, metas = zip(*batch)
    lengths = [d.shape[1] for d in datas]
    max_len = max(lengths)
    n_ch = datas[0].shape[0]

    padded, masks = [], []
    for d, n_valid in zip(datas, lengths):
        gap = max_len - n_valid
        padded.append(torch.cat([d, torch.zeros(n_ch, gap)], dim=1) if gap > 0 else d)
        valid = torch.zeros(max_len)
        valid[:n_valid] = 1.0
        masks.append(valid)

    return {
        "data": torch.stack(padded),                            # (B, C, T_max)
        "mask": torch.stack(masks),                             # (B, T_max)
        "count": torch.stack(counts),                           # (B,)
        "length": torch.tensor(lengths, dtype=torch.float32),   # (B,)
        "meta": list(metas),
    }


# ---------------------------------------------------------------------
# 3) Model
# ---------------------------------------------------------------------
class ManifoldEncoder(nn.Module):
    """
    1-D convolutional encoder from sensor channels to a per-time-step latent.

    Input ``(B, C, T)`` -> output ``(B, T, latent_dim)``. Two kernel-5
    convolutions (padding 2) keep ``T`` unchanged, and a final 1x1
    convolution projects to ``latent_dim``.
    """

    def __init__(self, input_ch, hidden_dim=128, latent_dim=16):
        super().__init__()
        layers = [
            nn.Conv1d(input_ch, hidden_dim, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv1d(hidden_dim, hidden_dim, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv1d(hidden_dim, latent_dim, kernel_size=1),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        # Convs are channel-first; move time to dim 1 for the per-step heads.
        return self.net(x).transpose(1, 2)


class ManifoldDecoder(nn.Module):
    """
    1-D convolutional decoder from a latent sequence back to sensor channels.

    Input ``(B, T, latent_dim)`` -> output ``(B, out_ch, T)``. Two kernel-5
    convolutions (padding 2) keep ``T`` unchanged, and a final 1x1
    convolution maps to ``out_ch``.
    """

    def __init__(self, latent_dim, hidden_dim, out_ch):
        super().__init__()
        layers = [
            nn.Conv1d(latent_dim, hidden_dim, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv1d(hidden_dim, hidden_dim, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv1d(hidden_dim, out_ch, kernel_size=1),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, z):
        channels_first = z.transpose(1, 2)  # (B, D, T)
        return self.net(channels_first)     # (B, C, T)


class MultiRateHead(nn.Module):
    """
    Per-time-step head producing an amplitude and a phase distribution.

    A two-layer MLP maps each latent step to ``1 + K_max`` logits:
      - index 0 -> amplitude, passed through softplus (non-negative),
      - indices 1.. -> phase logits, softmax-normalized with temperature ``tau``.
    """

    def __init__(self, latent_dim=16, hidden=64, K_max=6):
        super().__init__()
        self.K_max = K_max
        self.net = nn.Sequential(
            nn.Linear(latent_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, K_max + 1),  # [amp_logit | phase_logits...]
        )

    def forward(self, z, tau=1.0):
        """Return ``(amp (B,T), phase (B,T,K), phase_logits (B,T,K))`` for ``z`` of shape (B,T,D)."""
        logits_all = self.net(z)
        amp_logit, phase_logits = logits_all[..., 0], logits_all[..., 1:]
        amp = F.softplus(amp_logit)
        phase = torch.softmax(phase_logits / tau, dim=-1)
        return amp, phase, phase_logits


class KAutoCountModel(nn.Module):
    """
    Repetition-rate model with an automatically estimated number of sub-events.

    Components:
      encoder  : x (B, C, T) -> z (B, T, latent_dim)
      decoder  : z -> x_hat (B, C, T)   (reconstruction branch)
      rate_head: z -> amp_t (B, T) >= 0 and phase_p (B, T, K_max)

    k_hat is the inverse participation ratio ``1 / sum_k p_bar_k^2`` of the
    (masked) time-mean phase distribution p_bar. It lies roughly in
    [1, K_max]. The per-step micro rate amp_t is divided by k_hat to give a
    repetition rate, which is averaged over the (valid) time steps.

    forward returns ``(avg_rep_rate (B,), z, x_hat, aux)``.
    """

    def __init__(self, input_ch, hidden_dim=128, latent_dim=16, K_max=6):
        super().__init__()
        self.encoder = ManifoldEncoder(input_ch, hidden_dim, latent_dim)
        self.decoder = ManifoldDecoder(latent_dim, hidden_dim, input_ch)
        self.rate_head = MultiRateHead(latent_dim, hidden=hidden_dim, K_max=K_max)
        self._init_weights()

    def _init_weights(self):
        """Kaiming-normal (fan_out, relu) weights, zero biases; amp output bias set to -2."""
        learnable = (nn.Conv1d, nn.Linear)
        for module in self.modules():
            if not isinstance(module, learnable):
                continue
            nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)

        # Start the amplitude small: zero all head output biases, then only the
        # amp bias is set to -2 (softplus(-2) is about 0.13).
        with torch.no_grad():
            out_bias = self.rate_head.net[-1].bias
            out_bias.zero_()
            out_bias[0].fill_(-2.0)

    @staticmethod
    def _masked_mean_time(x, mask=None, eps=1e-6):
        """Mean over dim 1 (time); restricted to mask==1 steps when a (B, T) mask is given."""
        if mask is None:
            return x.mean(dim=1)
        ndim = x.dim()
        if ndim not in (2, 3):
            raise ValueError(f"Unsupported dim for masked mean: {x.dim()}")
        weights = mask.to(dtype=x.dtype, device=x.device)
        if ndim == 3:
            weights = weights.unsqueeze(-1)
        return (x * weights).sum(dim=1) / (weights.sum(dim=1) + eps)

    def forward(self, x, mask=None, tau=1.0):
        z = self.encoder(x)                      # (B, T, D)
        x_hat = self.decoder(z)                  # (B, C, T)
        amp_t, phase_p, phase_logits = self.rate_head(z, tau=tau)

        # Per-phase micro rates; only exposed through aux.
        rates_k_t = amp_t.unsqueeze(-1) * phase_p          # (B, T, K)

        # Effective number of phases from the time-mean phase usage.
        p_bar = self._masked_mean_time(phase_p, mask)      # (B, K)
        k_hat = 1.0 / (p_bar.pow(2).sum(dim=1) + 1e-6)     # (B,)

        rep_rate_t = amp_t / (k_hat.unsqueeze(1) + 1e-6)   # (B, T)
        if mask is None:
            avg_rep_rate = rep_rate_t.mean(dim=1)
        else:
            rep_rate_t = rep_rate_t * mask
            # The mask is applied once more here; for a 0/1 mask that is a no-op.
            avg_rep_rate = (rep_rate_t * mask).sum(dim=1) / (mask.sum(dim=1) + 1e-6)

        aux = dict(
            rates_k_t=rates_k_t,
            phase_p=phase_p,
            phase_logits=phase_logits,
            micro_rate_t=amp_t,
            rep_rate_t=rep_rate_t,
            k_hat=k_hat,
        )
        return avg_rep_rate, z, x_hat, aux


# ---------------------------------------------------------------------
# 4) Loss utils
# ---------------------------------------------------------------------
def masked_recon_mse(x_hat, x, mask, eps=1e-6):
    """
    Mean squared reconstruction error over valid time steps only.

    ``x_hat`` and ``x`` are (B, C, T) and ``mask`` is (B, T). The squared
    error is summed over masked-in positions and divided by
    ``(#valid steps) * C + eps``.
    """
    valid = mask.to(dtype=x.dtype, device=x.device)
    n_channels = x.shape[1]
    sq_err = (x_hat - x).pow(2) * valid.unsqueeze(1)  # broadcast mask over channels
    return sq_err.sum() / (valid.sum() * n_channels + eps)


def temporal_smoothness(v, mask=None, eps=1e-6):
    """
    Mean absolute first difference of ``v`` (B, T) along time.

    With a (B, T) mask, only pairs of consecutive steps that are both valid
    contribute.
    """
    diffs = (v[:, 1:] - v[:, :-1]).abs()  # (B, T-1)
    if mask is not None:
        pair_valid = (mask[:, :-1] * mask[:, 1:]).to(dtype=diffs.dtype, device=diffs.device)
        return (diffs * pair_valid).sum() / (pair_valid.sum() + eps)
    return diffs.mean()


def phase_entropy_loss(phase_p, mask=None, eps=1e-8):
    """
    Mean per-step Shannon entropy (nats) of the phase distribution.

    ``phase_p`` is a distribution over its last dim. With a (B, T) mask, the
    mean is taken over valid steps only.
    """
    log_p = (phase_p + eps).log()
    ent = -(phase_p * log_p).sum(dim=-1)  # (B, T)
    if mask is not None:
        return (ent * mask).sum() / (mask.sum() + eps)
    return ent.mean()


def effK_usage_loss(phase_p, mask=None, eps=1e-6):
    """
    Effective number of used phases, ``1 / sum_k p_bar_k^2``, per sample.

    ``p_bar`` is the time-mean of ``phase_p`` (B, T, K). When a (B, T) mask is
    given, the mean is restricted to valid steps.

    Returns:
        (batch mean as a differentiable scalar, detached per-sample values (B,)).
    """
    if mask is not None:
        w = mask.to(dtype=phase_p.dtype, device=phase_p.device).unsqueeze(-1)  # (B, T, 1)
        p_bar = (phase_p * w).sum(dim=1) / (w.sum(dim=1) + eps)
    else:
        p_bar = phase_p.mean(dim=1)  # (B, K)

    eff_k = 1.0 / (p_bar.pow(2).sum(dim=1) + eps)
    return eff_k.mean(), eff_k.detach()


# ---------------------------------------------------------------------
# 5) Train
# ---------------------------------------------------------------------
def train_one_epoch(model, loader, optimizer, config, device):
    """
    Run one training epoch and return per-batch-averaged statistics.

    The rate target is ``count / duration``, where the duration is
    ``length / config["fs"]`` clamped to >= 1e-6 s. The total loss is
        rate MSE
        + lambda_recon     * masked reconstruction MSE
        + lambda_smooth    * temporal smoothness of rep_rate_t
        + lambda_phase_ent * phase entropy
        + lambda_effk      * effective-K usage
    with defaults 1.0 / 0.05 / 0.01 / 0.005 when missing from ``config``.

    Returns:
        Dict of loss components and 'mae_count' (count-space MAE), each the
        mean over batches. Every batch weighs equally regardless of its size.
    """
    model.train()
    keys = ('loss', 'loss_rate', 'loss_recon', 'loss_smooth', 'loss_phase_ent', 'loss_effk', 'mae_count')
    totals = dict.fromkeys(keys, 0.0)

    fs = config["fs"]
    tau = config.get("tau", 1.0)
    lam_recon = config.get("lambda_recon", 1.0)
    lam_smooth = config.get("lambda_smooth", 0.05)
    lam_phase_ent = config.get("lambda_phase_ent", 0.01)
    lam_effk = config.get("lambda_effk", 0.005)

    for batch in loader:
        x = batch["data"].to(device)            # (B, C, T)
        mask = batch["mask"].to(device)         # (B, T)
        y_count = batch["count"].to(device)     # (B,)
        duration = torch.clamp(batch["length"].to(device) / fs, min=1e-6)  # seconds
        y_rate = y_count / duration             # reps/s

        optimizer.zero_grad()
        rate_hat, _z, x_hat, aux = model(x, mask, tau=tau)

        terms = {
            'loss_rate': F.mse_loss(rate_hat, y_rate),
            'loss_recon': masked_recon_mse(x_hat, x, mask),
            'loss_smooth': temporal_smoothness(aux["rep_rate_t"], mask),
            'loss_phase_ent': phase_entropy_loss(aux["phase_p"], mask),
            'loss_effk': effK_usage_loss(aux["phase_p"], mask)[0],
        }
        loss = (terms['loss_rate']
                + lam_recon * terms['loss_recon']
                + lam_smooth * terms['loss_smooth']
                + lam_phase_ent * terms['loss_phase_ent']
                + lam_effk * terms['loss_effk'])

        loss.backward()
        optimizer.step()

        totals['loss'] += loss.item()
        for name, value in terms.items():
            totals[name] += value.item()
        count_hat = rate_hat * duration
        totals['mae_count'] += torch.abs(count_hat - y_count).mean().item()

    n_batches = max(len(loader), 1)
    return {name: total / n_batches for name, total in totals.items()}


# ---------------------------------------------------------------------
# 6) Robustness: Corruption functions (apply AFTER normalization)
# ---------------------------------------------------------------------
def _rng_from_meta(seed: int, meta: str):
    """
    Return a NumPy Generator that is deterministic for a given ``(seed, meta)`` pair.

    The generator seed is derived from a SHA-256 digest of ``"<seed>|<meta>"``.
    The builtin ``hash()`` must not be used here: ``str`` hashes are salted
    per interpreter process, and setting ``PYTHONHASHSEED`` after start-up
    does not change that salt. With ``hash()``, every run corrupted the test
    signals differently and the robustness numbers were not reproducible.
    """
    import hashlib  # local import keeps this helper self-contained

    digest = hashlib.sha256(f"{seed}|{meta}".encode("utf-8")).digest()
    return np.random.default_rng(int.from_bytes(digest[:8], "little"))


def add_gaussian_noise(x, sigma, rng):
    """Return ``x`` plus i.i.d. N(0, sigma^2) noise drawn from ``rng``, as float32."""
    noise = rng.normal(loc=0.0, scale=sigma, size=x.shape)
    noisy = x + noise
    return noisy.astype(np.float32)


def add_bias_drift(x, fs, amplitude, rng, freq_hz=0.2):
    """
    Add a slow sinusoidal offset shared by all channels.

    ``x`` is (T, C). The drift is ``amplitude * sin(2*pi*freq_hz*t + phi)``,
    where ``t`` is time in seconds and ``phi`` is drawn uniformly from
    [0, 2*pi) using ``rng``. The result is float32.
    """
    n_steps, _ = x.shape
    t_sec = np.arange(n_steps, dtype=np.float32) / float(fs)
    offset = float(rng.uniform(0, 2*np.pi))
    wave = amplitude * np.sin(2*np.pi*freq_hz*t_sec + offset)  # (T,)
    return (x + wave[:, None]).astype(np.float32)


def add_spikes(x, p, A, rng):
    """
    Add impulse artifacts of magnitude ``A`` to a random subset of time steps.

    ``round(p * T)`` distinct rows of the (T, C) signal are picked with
    ``rng``. Each channel in those rows gets ``+A`` or ``-A`` with an
    independent random sign. When the rounded count is <= 0, ``x`` is
    returned unchanged (as float32).
    """
    n_steps, n_ch = x.shape
    n_spike = int(round(p * n_steps))
    if n_spike <= 0:
        return x.astype(np.float32)

    rows = rng.choice(n_steps, size=min(n_spike, n_steps), replace=False)
    impulses = np.zeros(x.shape, dtype=np.float32)
    for row in rows:
        # Draw signs per row, in row order, so the RNG stream is consumed deterministically.
        signs = rng.choice([-1.0, 1.0], size=n_ch).astype(np.float32)
        impulses[row, :] += A * signs
    return (x + impulses).astype(np.float32)


def lowpass_then_decimate(x, fs_orig, fs_new, order=4):
    """
    Downsample a (T, C) signal by an integer factor with anti-aliasing.

    Each channel is zero-phase filtered with a Butterworth low-pass (``filtfilt``,
    given ``order``, cutoff ``0.45 * fs_new``, i.e. 90% of the new Nyquist).
    Then every ``factor``-th sample is kept. Per the evaluation protocol this
    runs on the already-normalized signal.

    Raises:
        RuntimeError: scipy.signal could not be imported.
        ValueError: ``fs_orig / fs_new`` is not an integer, or the factor is < 1.

    Returns:
        float32 array of ``ceil(T / factor)`` rows, or a float32 copy of ``x``
        when the factor is 1.
    """
    if not _SCIPY_SIGNAL_OK:
        raise RuntimeError("scipy.signal not available. Please install SciPy or switch to an FIR fallback.")

    ratio = fs_orig / fs_new
    factor = int(round(ratio))
    if abs(ratio - factor) > 1e-6:
        raise ValueError(f"fs_orig/fs_new must be integer. Got {fs_orig}/{fs_new}={ratio}")
    if factor < 1:
        raise ValueError("Decimation factor must be >=1")
    if factor == 1:
        return x.astype(np.float32)

    # Normalized cutoff (fraction of the ORIGINAL Nyquist), clipped to a valid range.
    cutoff_hz = 0.45 * fs_new
    normalized_cutoff = min(max(cutoff_hz / (fs_orig / 2.0), 1e-4), 0.99)
    b, a = butter(order, normalized_cutoff, btype='low')

    filtered = np.zeros_like(x, dtype=np.float32)
    for ch in range(x.shape[1]):
        # Filter in float64; assignment casts back to the float32 buffer.
        filtered[:, ch] = filtfilt(b, a, x[:, ch].astype(np.float64))

    return filtered[::factor].astype(np.float32)


def apply_axis_dropout_case(x, case_name):
    """
    Simulate missing sensor axes by zeroing channels (channel count is kept).

    Assumes the 15-channel layout
    [chest acc(3), ankle acc(3), ankle gyro(3), arm acc(3), arm gyro(3)].
    For any other channel count, a warning is printed and an unmodified
    float32 copy is returned; in that case ``case_name`` is not checked.

    Cases:
        acc_z_only: zero the acc x/y axes (keep acc z at 2, 5, 11; gyro untouched)
        gyro_drop : zero all gyro channels
        acc_only  : zero all gyro channels (same channels as gyro_drop in this layout)
        gyro_only : zero all acc channels

    Raises:
        ValueError: unknown ``case_name`` (15-channel input only).
    """
    x = x.astype(np.float32).copy()
    _, n_ch = x.shape
    if n_ch != 15:
        print(f"[Warn] Axis dropout assumes C=15, got C={n_ch}. Skip masking.")
        return x

    chest_acc, ankle_acc, ankle_gyro = [0, 1, 2], [3, 4, 5], [6, 7, 8]
    arm_acc, arm_gyro = [9, 10, 11], [12, 13, 14]
    acc_all = chest_acc + ankle_acc + arm_acc
    gyro_all = ankle_gyro + arm_gyro
    acc_z = {2, 5, 11}  # z axis of each acc triad

    zeroed_channels = {
        "acc_z_only": [ch for ch in acc_all if ch not in acc_z],
        "gyro_drop": gyro_all,
        "acc_only": gyro_all,
        "gyro_only": acc_all,
    }
    if case_name not in zeroed_channels:
        raise ValueError(f"Unknown axis dropout case: {case_name}")

    x[:, zeroed_channels[case_name]] = 0.0
    return x


# ---------------------------------------------------------------------
# 7) Robustness evaluation helpers
# ---------------------------------------------------------------------
def compute_errors(pred, gt, eps=1e-6):
    """
    Return ``(absolute error, relative error)`` of a count prediction.

    The relative error is ``|pred - gt| / max(gt, eps)``. It is a fraction,
    not a percentage; ``eps`` guards against gt == 0.
    """
    err = float(abs(pred - gt))
    denom = float(max(gt, eps))
    return err, float(err / denom)


def eval_one_trial_windowing(model, x_np, gt_count, fs_eval, win_sec_eval, device, tau, batch_size):
    """
    Predict one trial's count with windowed inference and score it.

    The hop is always half the window (50% overlap).

    Returns:
        ``(pred_count, absolute_error, relative_error)``.
    """
    pred_count = predict_count_by_windowing(
        model,
        x_np=x_np,
        fs=fs_eval,
        win_sec=win_sec_eval,
        stride_sec=win_sec_eval / 2.0,
        device=device,
        tau=tau,
        batch_size=batch_size,
    )[0]
    abs_err, rel_err = compute_errors(pred_count, gt_count)
    return pred_count, abs_err, rel_err


def save_degradation_plot(df_agg, out_path, title, x_col="level", y_col="mae_mean", yerr_col="mae_std", ylabel=None):
    """
    Save a degradation curve (metric vs. stressor level) as an image file.

    Args:
        df_agg: table with one row per level; must contain ``x_col`` and
            ``y_col``. Error bars come from ``yerr_col`` when that column
            exists (pass ``yerr_col=None`` to disable them).
        out_path: destination passed to ``savefig`` (200 dpi).
        title: figure title.
        x_col / y_col / yerr_col: columns for x values, y values and error bars.
        ylabel: y-axis label. If None, the label is "MAPE" for ``mape*``
            columns, "MAE" for ``mae*`` columns, and ``y_col`` itself
            otherwise. Previously it was always "MAE", which mislabelled
            non-MAE curves.
    """
    x = df_agg[x_col].values
    y = df_agg[y_col].values
    has_yerr = yerr_col is not None and yerr_col in df_agg.columns
    yerr = df_agg[yerr_col].values if has_yerr else None

    if ylabel is None:
        col = str(y_col).lower()
        if col.startswith("mape"):
            ylabel = "MAPE"
        elif col.startswith("mae"):
            ylabel = "MAE"
        else:
            ylabel = str(y_col)

    # Own the figure explicitly and always close it (even if savefig raises),
    # so long sweeps do not accumulate open figures.
    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        if yerr is not None:
            ax.errorbar(x, y, yerr=yerr, marker='o', linestyle='-')
        else:
            ax.plot(x, y, marker='o', linestyle='-')
        ax.set_title(title)
        ax.set_xlabel(x_col)
        ax.set_ylabel(ylabel)
        ax.grid(True, linestyle="--", alpha=0.5)
        fig.tight_layout()
        fig.savefig(out_path, dpi=200)
    finally:
        plt.close(fig)


# ---------------------------------------------------------------------
# 8) Main (LOSO + Robustness sweeps)
# ---------------------------------------------------------------------
def main():
    CONFIG = {
        "seed": 42,
        "data_dir": "/content/drive/MyDrive/Colab Notebooks/HAR_data/MHEALTHDATASET",
        "out_dir": "./robustness_results",
        "INCLUDE_BASELINE_LEVELS": False,  # clean→clean baseline은 이미 있다 했으니 False 권장

        "COLUMN_NAMES": [
            'acc_chest_x', 'acc_chest_y', 'acc_chest_z',
            'ecg_1', 'ecg_2',
            'acc_ankle_x', 'acc_ankle_y', 'acc_ankle_z',
            'gyro_ankle_x', 'gyro_ankle_y', 'gyro_ankle_z',
            'mag_ankle_x', 'mag_ankle_y', 'mag_ankle_z',
            'acc_arm_x', 'acc_arm_y', 'acc_arm_z',
            'gyro_arm_x', 'gyro_arm_y', 'gyro_arm_z',
            'mag_arm_x', 'mag_arm_y', 'mag_arm_z',
            'activity_id'
        ],

        # ✅ You will extend/replace these with your 10 activities and GT counts
        "TARGET_ACTIVITIES_MAP": {
            6:  'Waist bends forward',
            7:  'Frontal elevation of arms',
            8:  'Knees bending',
            10:  'Jogging',
            11: 'Running',
            12: 'Jump front & back',
        },

        # ✅ Using 15 channels (acc+gyro only) as in your current code
        "ACT_FEATURE_MAP": {
            6:  ['acc_chest_x','acc_chest_y','acc_chest_z','acc_ankle_x','acc_ankle_y','acc_ankle_z',
                 'gyro_ankle_x','gyro_ankle_y','gyro_ankle_z','acc_arm_x','acc_arm_y','acc_arm_z',
                 'gyro_arm_x','gyro_arm_y','gyro_arm_z'],
            7:  ['acc_chest_x','acc_chest_y','acc_chest_z','acc_ankle_x','acc_ankle_y','acc_ankle_z',
                 'gyro_ankle_x','gyro_ankle_y','gyro_ankle_z','acc_arm_x','acc_arm_y','acc_arm_z',
                 'gyro_arm_x','gyro_arm_y','gyro_arm_z'],
            8:  ['acc_chest_x','acc_chest_y','acc_chest_z','acc_ankle_x','acc_ankle_y','acc_ankle_z',
                 'gyro_ankle_x','gyro_ankle_y','gyro_ankle_z','acc_arm_x','acc_arm_y','acc_arm_z',
                 'gyro_arm_x','gyro_arm_y','gyro_arm_z'],
            10:  ['acc_chest_x','acc_chest_y','acc_chest_z','acc_ankle_x','acc_ankle_y','acc_ankle_z',
                 'gyro_ankle_x','gyro_ankle_y','gyro_ankle_z','acc_arm_x','acc_arm_y','acc_arm_z',
                 'gyro_arm_x','gyro_arm_y','gyro_arm_z'],
            11: ['acc_chest_x','acc_chest_y','acc_chest_z','acc_ankle_x','acc_ankle_y','acc_ankle_z',
                 'gyro_ankle_x','gyro_ankle_y','gyro_ankle_z','acc_arm_x','acc_arm_y','acc_arm_z',
                 'gyro_arm_x','gyro_arm_y','gyro_arm_z'],
            12: ['acc_chest_x','acc_chest_y','acc_chest_z','acc_ankle_x','acc_ankle_y','acc_ankle_z',
                 'gyro_ankle_x','gyro_ankle_y','gyro_ankle_z','acc_arm_x','acc_arm_y','acc_arm_z',
                 'gyro_arm_x','gyro_arm_y','gyro_arm_z'],
        },

        # Training Params
        "epochs": 100,
        "lr": 5e-4,
        "batch_size": 64,
        "fs": 50,

        # Windowing Params (TRAIN fixed 8s/4s)
        "win_sec_train": 8.0,
        "stride_sec_train": 4.0,
        "drop_last": True,

        # Model
        "hidden_dim": 128,
        "latent_dim": 16,
        "K_max": 6,

        # Loss Weights
        "lambda_recon": 1.0,
        "lambda_smooth": 0.05,
        "lambda_phase_ent": 0.01,
        "lambda_effk": 0.0075,

        # temperature
        "tau": 1.0,

        # Count-only labels (your GT table)
        "COUNT_TABLE": {
            6: {
                "subject1": 21, "subject2": 19, "subject3": 21, "subject4": 20, "subject5": 20,
                "subject6": 20, "subject7": 20, "subject8": 21, "subject9": 21, "subject10": 20,
            },
            7: {
                "subject1": 20, "subject2": 20, "subject3": 20, "subject4": 20, "subject5": 20,
                "subject6": 20, "subject7": 20, "subject8": 19, "subject9": 19, "subject10": 20,
            },
            8: {
                "subject1": 20, "subject2": 21, "subject3": 21, "subject4": 19, "subject5": 20,
                "subject6": 20, "subject7": 21, "subject8": 21, "subject9": 21, "subject10": 21,
            },
            10: {
                "subject1": 157, "subject2": 161, "subject3": 154, "subject4": 154, "subject5": 160,
                "subject6": 156, "subject7": 153, "subject8": 160, "subject9": 166, "subject10": 156,
            },
            11: {
                "subject1": 165, "subject2": 158, "subject3": 174, "subject4": 163, "subject5": 157,
                "subject6": 172, "subject7": 149, "subject8": 166, "subject9": 174, "subject10": 172,
            },
            12: {
                "subject1": 20, "subject2": 22, "subject3": 21, "subject4": 21, "subject5": 20,
                "subject6": 21, "subject7": 19, "subject8": 20, "subject9": 20, "subject10": 20,
            },
        }
    }

    os.makedirs(CONFIG["out_dir"], exist_ok=True)

    set_strict_seed(CONFIG["seed"])
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")

    full_data = load_mhealth_dataset(CONFIG["data_dir"], CONFIG["TARGET_ACTIVITIES_MAP"], CONFIG["COLUMN_NAMES"])
    if not full_data:
        return

    subjects = [f"subject{i}" for i in range(1, 11)]

    # ✅ build ALL_LABELS from COUNT_TABLE (only mapped activities to avoid 10/11 unless you add maps)
    CONFIG["ALL_LABELS"] = build_all_labels_from_count_table(
        CONFIG["COUNT_TABLE"],
        subjects,
        target_map=CONFIG["TARGET_ACTIVITIES_MAP"],
        feature_map=CONFIG["ACT_FEATURE_MAP"],
        allow_only_mapped=True
    )

    # -------------------------
    # Robustness sweep configs
    # -------------------------
    noise_gauss_levels = [0.02, 0.05, 0.1, 0.2]
    noise_drift_levels = [0.05, 0.1, 0.2]
    noise_spike_p_levels = [0.001, 0.005, 0.01]
    spike_A = 3.0

    fs_levels = [50.0, 25.0, 12.5]
    if not CONFIG["INCLUDE_BASELINE_LEVELS"]:
        fs_levels = [25.0, 12.5]  # exclude 50Hz baseline

    window_test_levels = [2.0, 3.0, 4.0, 6.0, 8.0]
    if not CONFIG["INCLUDE_BASELINE_LEVELS"]:
        window_test_levels = [2.0, 3.0, 4.0, 6.0]  # exclude 8s baseline

    axis_cases = ["acc_z_only", "gyro_drop", "acc_only", "gyro_only"]

    # all activities present in labels
    act_ids_in_labels = sorted(list(set([a for (_, a, _) in CONFIG["ALL_LABELS"]])))

    records = []  # row-wise results

    print("\n" + "=" * 90)
    print(">>> Robustness Evaluation: Clean-train → Corrupted-test (test-only)")
    print("=" * 90)

    # -----------------------------------------------------------------
    # For each activity: run LOSO training once per fold, test under all stressors
    # -----------------------------------------------------------------
    for act_id in act_ids_in_labels:
        act_name = CONFIG["TARGET_ACTIVITIES_MAP"].get(act_id, str(act_id))
        print("\n" + "-" * 90)
        print(f"[Activity] {act_id} | {act_name}")
        print("-" * 90)

        labels_act = [x for x in CONFIG["ALL_LABELS"] if x[1] == act_id]
        if len(labels_act) == 0:
            continue

        for fold_idx, test_subj in enumerate(subjects, start=1):
            set_strict_seed(CONFIG["seed"])

            train_labels = [x for x in labels_act if x[0] != test_subj]
            test_labels  = [x for x in labels_act if x[0] == test_subj]

            train_trials = prepare_trial_list(train_labels, full_data, CONFIG["TARGET_ACTIVITIES_MAP"], CONFIG["ACT_FEATURE_MAP"])
            test_trials  = prepare_trial_list(test_labels,  full_data, CONFIG["TARGET_ACTIVITIES_MAP"], CONFIG["ACT_FEATURE_MAP"])

            if not test_trials or not train_trials:
                print(f"[Skip] Fold {fold_idx}: {test_subj} (missing train or test trials)")
                continue

            # TRAIN: clean, fixed 8s window
            train_data = trial_list_to_windows(
                train_trials,
                fs=CONFIG["fs"],
                win_sec=CONFIG["win_sec_train"],
                stride_sec=CONFIG["stride_sec_train"],
                drop_last=CONFIG["drop_last"]
            )

            g = torch.Generator()
            g.manual_seed(CONFIG["seed"])

            train_loader = DataLoader(
                TrialDataset(train_data),
                batch_size=CONFIG["batch_size"],
                shuffle=True,
                collate_fn=collate_variable_length,
                generator=g,
                num_workers=0
            )

            input_ch = train_data[0]['data'].shape[1]
            model = KAutoCountModel(
                input_ch=input_ch,
                hidden_dim=CONFIG["hidden_dim"],
                latent_dim=CONFIG["latent_dim"],
                K_max=CONFIG["K_max"]
            ).to(device)

            optimizer = torch.optim.Adam(model.parameters(), lr=CONFIG["lr"])
            scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.5)

            for epoch in range(CONFIG["epochs"]):
                _ = train_one_epoch(model, train_loader, optimizer, CONFIG, device)
                scheduler.step()

            model.eval()

            # TEST trials (usually 1 per subject per activity)
            for item in test_trials:
                x_clean = item["data"]  # (T,C) normalized
                gt = float(item["count"])
                meta = item["meta"]

                # ---------- Noise: Gaussian ----------
                for sigma in noise_gauss_levels:
                    rng = _rng_from_meta(CONFIG["seed"], f"{meta}|gauss|{sigma}")
                    x_cor = add_gaussian_noise(x_clean, sigma=sigma, rng=rng)

                    pred, mae, mape = eval_one_trial_windowing(
                        model, x_cor, gt,
                        fs_eval=CONFIG["fs"],
                        win_sec_eval=CONFIG["win_sec_train"],  # test win same as train for noise
                        device=device,
                        tau=CONFIG["tau"],
                        batch_size=CONFIG["batch_size"]
                    )

                    records.append({
                        "activity_id": act_id,
                        "activity": act_name,
                        "fold": fold_idx,
                        "test_subj": test_subj,
                        "meta": meta,
                        "stressor": "noise_gaussian",
                        "level": float(sigma),
                        "fs_eval": float(CONFIG["fs"]),
                        "win_sec_eval": float(CONFIG["win_sec_train"]),
                        "gt": gt,
                        "pred": float(pred),
                        "mae": float(mae),
                        "mape": float(mape),
                    })

                # ---------- Noise: Bias drift ----------
                for amp in noise_drift_levels:
                    rng = _rng_from_meta(CONFIG["seed"], f"{meta}|drift|{amp}")
                    x_cor = add_bias_drift(x_clean, fs=CONFIG["fs"], amplitude=amp, rng=rng, freq_hz=0.2)

                    pred, mae, mape = eval_one_trial_windowing(
                        model, x_cor, gt,
                        fs_eval=CONFIG["fs"],
                        win_sec_eval=CONFIG["win_sec_train"],
                        device=device,
                        tau=CONFIG["tau"],
                        batch_size=CONFIG["batch_size"]
                    )

                    records.append({
                        "activity_id": act_id,
                        "activity": act_name,
                        "fold": fold_idx,
                        "test_subj": test_subj,
                        "meta": meta,
                        "stressor": "noise_drift",
                        "level": float(amp),
                        "fs_eval": float(CONFIG["fs"]),
                        "win_sec_eval": float(CONFIG["win_sec_train"]),
                        "gt": gt,
                        "pred": float(pred),
                        "mae": float(mae),
                        "mape": float(mape),
                    })

                # ---------- Noise: Spikes ----------
                for p in noise_spike_p_levels:
                    rng = _rng_from_meta(CONFIG["seed"], f"{meta}|spike|{p}|A{spike_A}")
                    x_cor = add_spikes(x_clean, p=p, A=spike_A, rng=rng)

                    pred, mae, mape = eval_one_trial_windowing(
                        model, x_cor, gt,
                        fs_eval=CONFIG["fs"],
                        win_sec_eval=CONFIG["win_sec_train"],
                        device=device,
                        tau=CONFIG["tau"],
                        batch_size=CONFIG["batch_size"]
                    )

                    records.append({
                        "activity_id": act_id,
                        "activity": act_name,
                        "fold": fold_idx,
                        "test_subj": test_subj,
                        "meta": meta,
                        "stressor": "noise_spike",
                        "level": float(p),
                        "fs_eval": float(CONFIG["fs"]),
                        "win_sec_eval": float(CONFIG["win_sec_train"]),
                        "gt": gt,
                        "pred": float(pred),
                        "mae": float(mae),
                        "mape": float(mape),
                    })

                # ---------- Sampling rate (LPF → decimate) ----------
                for fs_new in fs_levels:
                    _ = _rng_from_meta(CONFIG["seed"], f"{meta}|fs|{fs_new}")  # kept for symmetry
                    x_ds = lowpass_then_decimate(x_clean, fs_orig=float(CONFIG["fs"]), fs_new=float(fs_new))

                    pred, mae, mape = eval_one_trial_windowing(
                        model, x_ds, gt,
                        fs_eval=float(fs_new),
                        win_sec_eval=CONFIG["win_sec_train"],  # 8s in seconds stays 8s
                        device=device,
                        tau=CONFIG["tau"],
                        batch_size=CONFIG["batch_size"]
                    )

                    records.append({
                        "activity_id": act_id,
                        "activity": act_name,
                        "fold": fold_idx,
                        "test_subj": test_subj,
                        "meta": meta,
                        "stressor": "sampling_rate",
                        "level": float(fs_new),
                        "fs_eval": float(fs_new),
                        "win_sec_eval": float(CONFIG["win_sec_train"]),
                        "gt": gt,
                        "pred": float(pred),
                        "mae": float(mae),
                        "mape": float(mape),
                    })

                # ---------- Axis / Channel dropout cases (0-masking) ----------
                for case in axis_cases:
                    x_cor = apply_axis_dropout_case(x_clean, case_name=case)

                    pred, mae, mape = eval_one_trial_windowing(
                        model, x_cor, gt,
                        fs_eval=CONFIG["fs"],
                        win_sec_eval=CONFIG["win_sec_train"],
                        device=device,
                        tau=CONFIG["tau"],
                        batch_size=CONFIG["batch_size"]
                    )

                    records.append({
                        "activity_id": act_id,
                        "activity": act_name,
                        "fold": fold_idx,
                        "test_subj": test_subj,
                        "meta": meta,
                        "stressor": "axis_dropout",
                        "level": case,  # string level for cases
                        "fs_eval": float(CONFIG["fs"]),
                        "win_sec_eval": float(CONFIG["win_sec_train"]),
                        "gt": gt,
                        "pred": float(pred),
                        "mae": float(mae),
                        "mape": float(mape),
                    })

                # ---------- Window length mismatch (stride=win/2, 50% overlap) ----------
                for win_test in window_test_levels:
                    pred, mae, mape = eval_one_trial_windowing(
                        model, x_clean, gt,
                        fs_eval=CONFIG["fs"],
                        win_sec_eval=float(win_test),
                        device=device,
                        tau=CONFIG["tau"],
                        batch_size=CONFIG["batch_size"]
                    )

                    records.append({
                        "activity_id": act_id,
                        "activity": act_name,
                        "fold": fold_idx,
                        "test_subj": test_subj,
                        "meta": meta,
                        "stressor": "window_mismatch",
                        "level": float(win_test),
                        "fs_eval": float(CONFIG["fs"]),
                        "win_sec_eval": float(win_test),
                        "gt": gt,
                        "pred": float(pred),
                        "mae": float(mae),
                        "mape": float(mape),
                    })

            print(f"[Done] Fold {fold_idx:2d} | Test: {test_subj}")

    # -----------------------------------------------------------------
    # Save results + aggregate summaries + plots
    # -----------------------------------------------------------------
    df = pd.DataFrame(records)
    out_csv = os.path.join(CONFIG["out_dir"], "robustness_raw.csv")
    df.to_csv(out_csv, index=False)
    print("\nSaved:", out_csv)

    # Aggregate (numeric levels only)
    # For axis_dropout (string level), handle separately.
    df_num = df.copy()
    df_num["level_num"] = pd.to_numeric(df_num["level"], errors="coerce")

    # Summary table (mean±std across folds/trials)
    agg_rows = []

    for (act_id, stressor), g in df_num.groupby(["activity_id", "stressor"]):
        act_name = g["activity"].iloc[0]

        # numeric levels
        g_num = g[~g["level_num"].isna()].copy()
        if len(g_num) > 0:
            for lvl, gg in g_num.groupby("level_num"):
                agg_rows.append({
                    "activity_id": act_id,
                    "activity": act_name,
                    "stressor": stressor,
                    "level": float(lvl),
                    "mae_mean": float(gg["mae"].mean()),
                    "mae_std": float(gg["mae"].std(ddof=0)),
                    "mape_mean": float(gg["mape"].mean()),
                    "mape_std": float(gg["mape"].std(ddof=0)),
                    "worst_mae": float(gg["mae"].max()),
                    "n": int(len(gg)),
                })

        # string levels (axis_dropout)
        g_str = g[g["level_num"].isna()].copy()
        if len(g_str) > 0:
            for lvl, gg in g_str.groupby("level"):
                agg_rows.append({
                    "activity_id": act_id,
                    "activity": act_name,
                    "stressor": stressor,
                    "level": str(lvl),
                    "mae_mean": float(gg["mae"].mean()),
                    "mae_std": float(gg["mae"].std(ddof=0)),
                    "mape_mean": float(gg["mape"].mean()),
                    "mape_std": float(gg["mape"].std(ddof=0)),
                    "worst_mae": float(gg["mae"].max()),
                    "n": int(len(gg)),
                })

    df_agg = pd.DataFrame(agg_rows)
    out_csv2 = os.path.join(CONFIG["out_dir"], "robustness_summary.csv")
    df_agg.to_csv(out_csv2, index=False)
    print("Saved:", out_csv2)

    # Plots: one per (activity, stressor) for numeric levels
    plot_dir = os.path.join(CONFIG["out_dir"], "plots")
    os.makedirs(plot_dir, exist_ok=True)

    for (act_id, stressor), g in df_agg.groupby(["activity_id", "stressor"]):
        g2 = g.copy()
        g2["level_num"] = pd.to_numeric(g2["level"], errors="coerce")
        g2 = g2[~g2["level_num"].isna()].sort_values("level_num")
        if len(g2) == 0:
            continue

        act_name = g2["activity"].iloc[0]
        title = f"[{act_name}] {stressor} degradation (MAE)"
        out_png = os.path.join(plot_dir, f"act{act_id}_{stressor}.png")
        g2_plot = g2.rename(columns={"level_num": "level"})
        save_degradation_plot(
            g2_plot,
            out_path=out_png,
            title=title,
            x_col="level_num",
            y_col="mae_mean",
            yerr_col="mae_std"
        )

    print(f"Saved plots to: {plot_dir}")
    print("\nDone.")


# Script entry point: call main() only when this module is executed directly
# (`__name__ == "__main__"`, which also holds inside a notebook cell),
# not when it is imported as a module.
if __name__ == "__main__":
    main()


Device: cuda
Loading 10 subjects from /content/drive/MyDrive/Colab Notebooks/HAR_data/MHEALTHDATASET...

>>> Robustness Evaluation: Clean-train → Corrupted-test (test-only)

------------------------------------------------------------------------------------------
[Activity] 6 | Waist bends forward
------------------------------------------------------------------------------------------
[Done] Fold  1 | Test: subject1
[Done] Fold  2 | Test: subject2
[Done] Fold  3 | Test: subject3
[Done] Fold  4 | Test: subject4
[Done] Fold  5 | Test: subject5
[Done] Fold  6 | Test: subject6
[Done] Fold  7 | Test: subject7
[Done] Fold  8 | Test: subject8
[Done] Fold  9 | Test: subject9
[Done] Fold 10 | Test: subject10

------------------------------------------------------------------------------------------
[Activity] 7 | Frontal elevation of arms
------------------------------------------------------------------------------------------
[Done] Fold  1 | Test: subject1
[Done] Fold  2 | Test: subject2


KeyError: 'level_num'