In [None]:
# ============================================================
# Full sweep code (LOSO per-activity) for count-only K-auto + Windowing
#
# What it does:
# 1) Load mHealth once
# 2) Prebuild normalized trial cache for 6 activities x 10 subjects
# 3) Run 4 independent 1D sweeps in order:
#    lambda_smooth -> lambda_phase_ent -> lambda_recon -> lambda_effk
#    (each sweep includes 0, other lambdas fixed at base best)
# 4) For each (sweep_name, value):
#    - For each activity: run LOSO (10 folds)
#    - Report per-activity mean MAE/MSE/RMSE and overall macro average
#
# Notes:
# - Logging is minimal (per-activity + overall only).
# - This will be slow: (#sweep_values total) * (6 activities) * (10 folds) trainings.
#   Adjust CONFIG["epochs"] / sweep grids if needed.
# ============================================================

import os
import glob
import copy
import random
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader


# ---------------------------------------------------------------------
# 1) Strict Seeding
# ---------------------------------------------------------------------
def set_strict_seed(seed: int):
    """Seed every RNG this script relies on and force deterministic cuDNN.

    Covers Python's ``random``, NumPy, and PyTorch (CPU and, when available,
    every CUDA device). ``PYTHONHASHSEED`` is exported for child processes;
    it does not change string hashing of the already-running interpreter.
    """
    seed_str = str(seed)
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = seed_str
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        for cuda_seeder in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
            cuda_seeder(seed)
    # Trade cuDNN autotuning speed for run-to-run reproducibility.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


# ---------------------------------------------------------------------
# 2) Data Loading
# ---------------------------------------------------------------------
def load_mhealth_dataset(data_dir, target_activities_map, column_names):
    """
    Load every ``mHealth_subject*.log`` file in ``data_dir`` and split it by activity.

    Args:
      data_dir: directory holding the tab-separated, header-less mHealth logs.
      target_activities_map: {activity_id code -> activity name}; only these codes are kept.
      column_names: names assigned to the leading columns (must include 'activity_id');
        any extra trailing columns in the file are dropped.

    Returns:
      full_dataset[subj_key][activity_name] = DataFrame of that activity's rows with
      every named column except 'activity_id'. ``subj_key`` is ``"subject<N>"`` when
      the file stem contains a parseable number, otherwise the bare file stem.
      Returns {} when no files match. Files that fail to load are reported and
      skipped (best-effort).
    """
    full_dataset = {}
    file_list = sorted(glob.glob(os.path.join(data_dir, "mHealth_subject*.log")))

    if not file_list:
        print(f"[Warning] No mHealth logs found in {data_dir}")
        return {}

    print(f"Loading {len(file_list)} subjects from {data_dir}...")

    for file_path in file_list:
        file_name = os.path.basename(file_path)
        subj_part = file_name.split('.')[0]
        digits = ''.join(filter(str.isdigit, subj_part))
        # Only a missing/unparseable number should fall back to the raw stem
        # (int('') or int('²') -> ValueError). A bare `except:` here would also
        # swallow KeyboardInterrupt / SystemExit.
        try:
            subj_key = f"subject{int(digits)}"
        except ValueError:
            subj_key = subj_part

        try:
            df = pd.read_csv(file_path, sep="\t", header=None)
            df = df.iloc[:, :len(column_names)]
            df.columns = column_names

            subj_data = {}
            for label_code, activity_name in target_activities_map.items():
                activity_df = df[df['activity_id'] == label_code].copy()
                if not activity_df.empty:
                    subj_data[activity_name] = activity_df.drop(columns=['activity_id'])

            full_dataset[subj_key] = subj_data
        except Exception as e:
            # Best-effort loading: report the broken file and keep going.
            print(f"Error loading {file_name}: {e}")

    return full_dataset


# ---------------------------------------------------------------------
# 2.5) Windowing (TRAIN only)
# ---------------------------------------------------------------------
def _make_window(x, meta, start, end, rate, fs):
    """Build one window record for samples [start, end) of trial array ``x`` (T, C)."""
    return {
        "data": x[start:end],
        # Pseudo label: trial-average rate times this window's duration.
        "count": rate * ((end - start) / float(fs)),
        "meta": f"{meta}__win[{start}:{end}]",
        "parent_meta": meta,
        "parent_T": x.shape[0],
        "win_start": start,
        "win_end": end,
    }


def trial_list_to_windows(trial_list, fs, win_sec=8.0, stride_sec=4.0, drop_last=True):
    """
    TRAIN: trial -> sliding windows
    window label: trial-average rate * window duration

    Args:
      trial_list: iterable of dicts with 'data' ((T, C) array), 'count' (total reps)
        and 'meta' (str).
      fs: sampling rate in Hz.
      win_sec, stride_sec: window length and hop in seconds (rounded to samples).
      drop_last: when False, also emit one shorter tail window starting one stride
        after the last full window and running to the end of the trial.

    Returns:
      list of window dicts with keys data, count, meta, parent_meta, parent_T,
      win_start, win_end. A trial shorter than one window is emitted whole as a
      single window.

    Raises:
      ValueError: if the window or stride rounds to fewer than one sample.
    """
    win_len = int(round(win_sec * fs))
    stride = int(round(stride_sec * fs))
    # Explicit check instead of `assert`: asserts are stripped under `python -O`,
    # after which a zero window would silently yield empty windows.
    if win_len <= 0 or stride <= 0:
        raise ValueError(
            f"window/stride must be >= 1 sample, got win_len={win_len}, stride={stride} "
            f"(win_sec={win_sec}, stride_sec={stride_sec}, fs={fs})"
        )

    windows = []
    for item in trial_list:
        x = item["data"]  # (T,C)
        T = x.shape[0]
        meta = item["meta"]

        total_dur = max(T / float(fs), 1e-6)
        rate_trial = float(item["count"]) / total_dur  # reps/s

        if T < win_len:
            windows.append(_make_window(x, meta, 0, T, rate_trial, fs))
            continue

        starts = range(0, T - win_len + 1, stride)
        windows.extend(
            _make_window(x, meta, st, st + win_len, rate_trial, fs) for st in starts
        )

        if not drop_last:
            tail_start = starts[-1] + stride
            if tail_start < T:
                windows.append(_make_window(x, meta, tail_start, T, rate_trial, fs))

    return windows


# ---------------------------------------------------------------------
# 2.6) Windowing inference (TEST)
# ---------------------------------------------------------------------
def predict_count_by_windowing(model, x_np, fs, win_sec, stride_sec, device, tau=1.0, batch_size=64):
    """
    TEST: trial -> sliding windows inference -> window rate mean -> total count

    Args:
      model: callable as ``model(x, mask=None, tau=tau)`` on (B, C, T) input,
        returning a 4-tuple whose first element is the per-sample rate (B,).
      x_np: (T, C) array, already normalized.
      fs: sampling rate in Hz; win_sec / stride_sec: window and hop in seconds.
      device: where the input tensors are moved before the forward pass.
      tau: phase softmax temperature forwarded to the model.
      batch_size: number of windows per forward pass.

    Returns:
      float: mean window rate * full trial duration. Trials no longer than one
      window are scored with a single forward pass over the whole trial.
    """
    win_len = int(round(win_sec * fs))
    stride = int(round(stride_sec * fs))
    T = x_np.shape[0]
    total_dur = T / float(fs)

    # Put the model in inference mode on *every* path; previously the
    # short-trial branch returned before model.eval() was reached.
    model.eval()

    with torch.no_grad():
        # short trial -> single forward
        if T <= win_len:
            x_tensor = torch.tensor(x_np, dtype=torch.float32).transpose(0, 1).unsqueeze(0).to(device)  # (1,C,T)
            rate_hat, _, _, _ = model(x_tensor, mask=None, tau=tau)
            return float(rate_hat.item() * total_dur)

        starts = range(0, T - win_len + 1, stride)
        windows = np.stack([x_np[st:st + win_len] for st in starts], axis=0)  # (N, win_len, C)
        xw = torch.tensor(windows, dtype=torch.float32).permute(0, 2, 1).to(device)  # (N, C, win_len)

        rates = []
        for i in range(0, xw.shape[0], batch_size):
            r_hat, _, _, _ = model(xw[i:i + batch_size], mask=None, tau=tau)  # (B,)
            rates.append(r_hat.detach().cpu().numpy())

    rate_mean = float(np.concatenate(rates, axis=0).mean())
    # The windows may not cover the tail, but the mean rate is applied to the full duration.
    return float(rate_mean * total_dur)


# ---------------------------------------------------------------------
# 2.8) Dataset / Collate
# ---------------------------------------------------------------------
class TrialDataset(Dataset):
    """Map-style dataset over trial/window dicts.

    Each item is returned as ``(data, count, meta)``: ``data`` is a float32
    tensor transposed from the stored (T, C) layout to (C, T), ``count`` is a
    float32 scalar tensor and ``meta`` is passed through unchanged.
    """

    def __init__(self, trial_list):
        self.trials = trial_list

    def __len__(self):
        return len(self.trials)

    def __getitem__(self, idx):
        record = self.trials[idx]
        signal = torch.tensor(record['data'], dtype=torch.float32)
        label = torch.tensor(record['count'], dtype=torch.float32)
        return signal.transpose(0, 1), label, record['meta']  # signal as (C, T)


def collate_variable_length(batch):
    """Right-pad a list of ``(data, count, meta)`` samples to a common length.

    Each ``data`` tensor is (C, T_i) and is zero-padded on the time axis up to
    the longest T_i in the batch. ``mask`` is 1.0 on real samples and 0.0 on
    padding; ``length`` holds every original T_i as float32.
    """
    longest = max(sample[0].shape[1] for sample in batch)
    n_channels = batch[0][0].shape[0]

    out_data, out_masks, out_counts, out_metas, out_lengths = [], [], [], [], []
    for sample_data, sample_count, sample_meta in batch:
        n_steps = sample_data.shape[1]
        gap = longest - n_steps
        out_lengths.append(n_steps)

        if gap <= 0:
            out_data.append(sample_data)
            out_masks.append(torch.ones(n_steps))
        else:
            out_data.append(torch.cat([sample_data, torch.zeros(n_channels, gap)], dim=1))
            out_masks.append(torch.cat([torch.ones(n_steps), torch.zeros(gap)], dim=0))

        out_counts.append(sample_count)
        out_metas.append(sample_meta)

    return {
        "data": torch.stack(out_data),                             # (B, C, T_max)
        "mask": torch.stack(out_masks),                            # (B, T_max)
        "count": torch.stack(out_counts),                          # (B,)
        "length": torch.tensor(out_lengths, dtype=torch.float32),  # (B,)
        "meta": out_metas
    }


# ---------------------------------------------------------------------
# 3) Model
# ---------------------------------------------------------------------
class ManifoldEncoder(nn.Module):
    """1D-CNN encoder mapping (B, C, T) signals to per-timestep latents (B, T, D).

    Two 5-tap Conv1d + ReLU layers (padding 2 keeps T unchanged) followed by a
    pointwise (kernel 1) projection to ``latent_dim`` channels.
    """

    def __init__(self, input_ch, hidden_dim=128, latent_dim=16):
        super().__init__()
        layers = [
            nn.Conv1d(input_ch, hidden_dim, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv1d(hidden_dim, hidden_dim, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv1d(hidden_dim, latent_dim, kernel_size=1),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        features = self.net(x)             # channel-first (B, D, T)
        return features.permute(0, 2, 1)   # time-major (B, T, D)


class ManifoldDecoder(nn.Module):
    """1D-CNN decoder mapping per-timestep latents (B, T, D) back to (B, C, T) signals.

    Same layer layout as the encoder: two 5-tap Conv1d + ReLU layers (length
    preserving) and a pointwise projection to ``out_ch`` channels.
    """

    def __init__(self, latent_dim, hidden_dim, out_ch):
        super().__init__()
        layers = [
            nn.Conv1d(latent_dim, hidden_dim, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv1d(hidden_dim, hidden_dim, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv1d(hidden_dim, out_ch, kernel_size=1),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, z):
        channel_first = z.permute(0, 2, 1)  # (B, D, T) as Conv1d expects
        return self.net(channel_first)      # (B, C, T)


class MultiRateHead(nn.Module):
    """Per-timestep MLP producing a non-negative amplitude and a K_max-way phase distribution.

    Output layout of the final Linear: channel 0 is the amplitude logit (mapped
    through softplus), channels 1..K_max are phase logits (softmax over the last
    dim with temperature ``tau``).
    """

    def __init__(self, latent_dim=16, hidden=64, K_max=6):
        super().__init__()
        self.K_max = K_max
        hidden_layer = nn.Linear(latent_dim, hidden)
        output_layer = nn.Linear(hidden, 1 + K_max)  # [amp_logit | phase_logits...]
        self.net = nn.Sequential(hidden_layer, nn.ReLU(), output_layer)

    def forward(self, z, tau=1.0):
        head = self.net(z)                                   # (B,T,1+K)
        amp_logit, phase_logits = head[..., 0], head[..., 1:]
        amp = F.softplus(amp_logit)                          # (B,T) >= 0
        phase = (phase_logits / tau).softmax(dim=-1)         # (B,T,K), rows sum to 1
        return amp, phase, phase_logits


class KAutoCountModel(nn.Module):
    """
    Count-regression model that estimates its own number of phases per repetition.

    Pipeline:
      encoder:   x (B, C, T) -> z (B, T, D)
      decoder:   z -> x_hat (B, C, T)   (reconstruction branch)
      rate_head: z -> amplitude (B, T) >= 0 and phase distribution (B, T, K_max)

    The (masked) time-average phase distribution p_bar yields an effective phase
    count k_hat = 1 / sum_k p_bar_k^2 (the inverse participation ratio, between
    ~1 and K_max). The per-step amplitude is treated as a "micro" rate and divided
    by k_hat to give a per-step repetition rate.
    """

    def __init__(self, input_ch, hidden_dim=128, latent_dim=16, K_max=6):
        super().__init__()
        self.encoder = ManifoldEncoder(input_ch, hidden_dim, latent_dim)
        self.decoder = ManifoldDecoder(latent_dim, hidden_dim, input_ch)
        # The rate head's hidden width reuses hidden_dim.
        self.rate_head = MultiRateHead(latent_dim, hidden=hidden_dim, K_max=K_max)
        self._init_weights()

    def _init_weights(self):
        """Kaiming-normal (fan_out, ReLU gain) weights and zero biases for all Conv1d/Linear.

        The rate head's output bias is then reset: the amplitude-logit bias starts
        at -2.0 (softplus(-2) ~= 0.13, a small initial rate) and the phase-logit
        biases at 0.
        """
        for m in self.modules():
            if isinstance(m, (nn.Conv1d, nn.Linear)):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
        with torch.no_grad():
            b = self.rate_head.net[-1].bias
            b.zero_()
            b[0].fill_(-2.0)

    @staticmethod
    def _masked_mean_time(x, mask=None, eps=1e-6):
        """Mean over the time axis (dim 1) of a (B, T) or (B, T, K) tensor.

        With a (B, T) mask, only masked-in steps contribute; ``eps`` guards
        against an all-zero mask row.
        """
        if mask is None:
            return x.mean(dim=1)
        if x.dim() == 2:
            m = mask.to(dtype=x.dtype, device=x.device)
            return (x * m).sum(dim=1) / (m.sum(dim=1) + eps)
        elif x.dim() == 3:
            m = mask.to(dtype=x.dtype, device=x.device).unsqueeze(-1)
            return (x * m).sum(dim=1) / (m.sum(dim=1) + eps)
        else:
            raise ValueError(f"Unsupported dim for masked mean: {x.dim()}")

    def forward(self, x, mask=None, tau=1.0):
        """
        Args:
          x: (B, C, T) input signals.
          mask: optional (B, T) validity mask (1 = real step, 0 = padding).
          tau: softmax temperature for the phase distribution.

        Returns:
          avg_rep_rate (B,): masked time-average of the per-step repetition rate,
          z (B, T, D), x_hat (B, C, T),
          aux: {"phase_p": (B, T, K), "rep_rate_t": (B, T), zeroed on padding}.
        """
        z = self.encoder(x)              # (B,T,D)
        x_hat = self.decoder(z)          # (B,C,T)

        amp_t, phase_p, _ = self.rate_head(z, tau=tau)  # (B,T), (B,T,K)
        micro_rate_t = amp_t                             # (B,T)

        p_bar = self._masked_mean_time(phase_p, mask)   # (B,K)
        k_hat = 1.0 / (p_bar.pow(2).sum(dim=1) + 1e-6)  # (B,) effective #phases

        rep_rate_t = micro_rate_t / (k_hat.unsqueeze(1) + 1e-6)  # (B,T)
        if mask is not None:
            # Zero padded steps once, so downstream losses never see them.
            rep_rate_t = rep_rate_t * mask

        # Reuse the shared masked-mean helper (same eps) instead of a hand-written
        # copy that multiplied the already-masked rate by the mask a second time.
        avg_rep_rate = self._masked_mean_time(rep_rate_t, mask)

        aux = {
            "phase_p": phase_p,
            "rep_rate_t": rep_rate_t,
        }
        return avg_rep_rate, z, x_hat, aux


# ---------------------------------------------------------------------
# 4) Loss utils
# ---------------------------------------------------------------------
def masked_recon_mse(x_hat, x, mask, eps=1e-6):
    """Mean squared reconstruction error over valid (unpadded) time steps.

    ``x_hat`` and ``x`` are (B, C, T); ``mask`` is (B, T) with 1 on valid steps.
    Squared errors are summed over valid steps and all channels, then divided
    by (number of valid steps * C) + eps.
    """
    weights = mask.to(dtype=x.dtype, device=x.device)
    sq_err = (x_hat - x).pow(2) * weights[:, None, :]  # broadcast mask over channels
    n_valid = weights.sum() * x.shape[1]
    return sq_err.sum() / (n_valid + eps)


def temporal_smoothness(v, mask=None, eps=1e-6):
    """Mean absolute first difference of ``v`` (B, T) along the time axis.

    With a (B, T) mask, only differences whose two endpoints are both valid
    are averaged.
    """
    step_diff = (v[:, 1:] - v[:, :-1]).abs()  # (B, T-1)
    if mask is not None:
        pair_valid = (mask[:, :-1] * mask[:, 1:]).to(dtype=step_diff.dtype, device=step_diff.device)
        return (step_diff * pair_valid).sum() / (pair_valid.sum() + eps)
    return step_diff.mean()


def phase_entropy_loss(phase_p, mask=None, eps=1e-8):
    """Mean per-step Shannon entropy (nats) of the phase distribution.

    Args:
      phase_p: (B, T, K) per-step probabilities.
      mask: optional (B, T) validity mask (1 = valid step).
      eps: added inside the log and to the mask count for numerical safety.

    Minimizing this pushes each step toward a single confident phase.
    """
    ent = -(phase_p * (phase_p + eps).log()).sum(dim=-1)  # (B,T)
    if mask is None:
        return ent.mean()
    # Cast like the sibling masked losses so bool/int masks behave uniformly.
    m = mask.to(dtype=ent.dtype, device=ent.device)
    return (ent * m).sum() / (m.sum() + eps)


def effK_usage_loss(phase_p, mask=None, eps=1e-6):
    """Batch mean of the effective number of phases, 1 / sum_k p_bar_k^2.

    ``p_bar`` is the per-sample time-average of ``phase_p`` (B, T, K), restricted
    to valid steps when a (B, T) mask is given. Penalizing it discourages
    spreading probability mass over many phases.
    """
    if mask is not None:
        w = mask.to(dtype=phase_p.dtype, device=phase_p.device)[..., None]  # (B,T,1)
        p_bar = (phase_p * w).sum(dim=1) / (w.sum(dim=1) + eps)
    else:
        p_bar = phase_p.mean(dim=1)  # (B,K)

    inv_participation = 1.0 / (p_bar.pow(2).sum(dim=1) + eps)
    return inv_participation.mean()


# ---------------------------------------------------------------------
# 5) Train
# ---------------------------------------------------------------------
def train_one_epoch(model, loader, optimizer, config, device):
    """
    Run one optimization pass over ``loader``.

    Each batch's target count is converted to a rate using ``batch['length']``
    (per-sample valid length in samples, as produced by collate_variable_length)
    divided by ``config['fs']``, and compared to the model's average repetition
    rate with MSE. Auxiliary terms, weighted from ``config`` (defaults below):
      lambda_recon     * masked reconstruction MSE
      lambda_smooth    * temporal smoothness of the per-step repetition rate
      lambda_phase_ent * per-step phase entropy
      lambda_effk      * effective number of phases

    Returns:
      float: mean total loss over the epoch's batches (NaN if the loader is
      empty). Callers that ignore the return value are unaffected.
    """
    model.train()

    fs = config["fs"]
    tau = config.get("tau", 1.0)

    lam_recon = config.get("lambda_recon", 1.0)
    lam_smooth = config.get("lambda_smooth", 0.05)
    lam_phase_ent = config.get("lambda_phase_ent", 0.01)
    lam_effk = config.get("lambda_effk", 0.0075)

    loss_sum = 0.0
    n_batches = 0
    for batch in loader:
        x = batch["data"].to(device)         # (B,C,T)
        mask = batch["mask"].to(device)      # (B,T)
        y_count = batch["count"].to(device)  # (B,)
        length = batch["length"].to(device)  # (B,)

        duration = torch.clamp(length / fs, min=1e-6)  # sec
        y_rate = y_count / duration                    # reps/s

        optimizer.zero_grad()

        rate_hat, z, x_hat, aux = model(x, mask, tau=tau)

        loss_rate = F.mse_loss(rate_hat, y_rate)
        loss_recon = masked_recon_mse(x_hat, x, mask)
        loss_smooth = temporal_smoothness(aux["rep_rate_t"], mask)
        loss_phase_ent = phase_entropy_loss(aux["phase_p"], mask)
        loss_effk = effK_usage_loss(aux["phase_p"], mask)

        loss = (loss_rate
                + lam_recon * loss_recon
                + lam_smooth * loss_smooth
                + lam_phase_ent * loss_phase_ent
                + lam_effk * loss_effk)

        loss.backward()
        optimizer.step()

        # Accumulate on-device (detached) to avoid a host sync every batch.
        loss_sum = loss_sum + loss.detach()
        n_batches += 1

    if n_batches == 0:
        return float("nan")
    return float(loss_sum) / n_batches


# ---------------------------------------------------------------------
# 6) Trial cache builder (precompute normalized trials once)
# ---------------------------------------------------------------------
def build_trial_cache(full_data, target_map, feature_map, count_table):
    """
    Precompute one z-normalized trial per (activity, subject) with a count label.

    Each trial is standardized with its own per-channel mean/std (std + 1e-6),
    so no statistics are shared across subjects.

    Returns:
      trials_cache[act_id][subj] = {
        'data': (T,C) float32 normalized,
        'count': float,
        'meta': str
      }

    Raises:
      KeyError: if ``feature_map`` has no entry for an activity.
      ValueError: if ``count_table`` has no entry for an activity.
    Subjects listed in ``count_table`` but missing from ``full_data`` are
    reported and skipped.
    """
    trials_cache = {}
    for act_id, act_name in target_map.items():
        per_subject = trials_cache[act_id] = {}
        feats = feature_map[act_id]

        if act_id not in count_table:
            raise ValueError(f"Missing COUNT_TABLE entry for act_id={act_id}")

        for subj, gt_count in count_table[act_id].items():
            if act_name not in full_data.get(subj, {}):
                print(f"[Skip-cache] Missing data for {subj} - {act_name}")
                continue

            signal = full_data[subj][act_name][feats].values.astype(np.float32)
            mu = signal.mean(axis=0)
            sigma = signal.std(axis=0) + 1e-6

            per_subject[subj] = {
                "data": (signal - mu) / sigma,  # (T,C)
                "count": float(gt_count),       # trial total count
                "meta": f"{subj}_{act_name}",
            }

    return trials_cache


# ---------------------------------------------------------------------
# 7) LOSO runner (single activity)
# ---------------------------------------------------------------------
def eval_single_activity_loso(trials_cache_act, config, device):
    """
    Leave-one-subject-out evaluation for a single activity.

    For each subject, a fresh KAutoCountModel is trained (re-seeded per fold) on
    sliding windows of the other subjects' trials, then the held-out trial's
    total count is predicted with windowed inference.

    Args:
      trials_cache_act: dict subj -> trial_item (data/count/meta)
      config: reads fs, win_sec, stride_sec, drop_last, tau (optional), epochs, lr,
        batch_size, hidden_dim, latent_dim, K_max, seed, and optionally
        lr_step_size (default 30) / lr_gamma (default 0.5) for StepLR.
      device: torch device used for training and inference.

    Returns:
      (mae, mse, rmse) averaged over folds (subjects), or None when there are
      fewer than two subjects (LOSO needs at least one training subject).
    """
    subjects = sorted(trials_cache_act.keys())
    if not subjects:
        return None
    if len(subjects) < 2:
        # Previously crashed with IndexError on train_windows[0].
        print(f"[Skip-LOSO] need >= 2 subjects, got {len(subjects)}: {subjects}")
        return None

    fs = config["fs"]
    win_sec = config["win_sec"]
    stride_sec = config["stride_sec"]
    drop_last = config["drop_last"]
    tau = config.get("tau", 1.0)

    epochs = config["epochs"]
    lr = config["lr"]
    batch_size = config["batch_size"]
    lr_step_size = config.get("lr_step_size", 30)
    lr_gamma = config.get("lr_gamma", 0.5)

    hidden_dim = config["hidden_dim"]
    latent_dim = config["latent_dim"]
    K_max = config["K_max"]

    fold_abs = []
    fold_sq = []

    for test_subj in subjects:
        # Re-seed every fold so each fold's init/shuffle is independent of fold order.
        set_strict_seed(config["seed"])

        train_trials = [trials_cache_act[s] for s in subjects if s != test_subj]
        test_item = trials_cache_act[test_subj]

        # TRAIN: windowing
        train_windows = trial_list_to_windows(
            train_trials,
            fs=fs, win_sec=win_sec, stride_sec=stride_sec, drop_last=drop_last
        )

        g = torch.Generator()
        g.manual_seed(config["seed"])

        train_loader = DataLoader(
            TrialDataset(train_windows),
            batch_size=batch_size,
            shuffle=True,
            collate_fn=collate_variable_length,
            generator=g,
            num_workers=0
        )

        input_ch = train_windows[0]["data"].shape[1]
        model = KAutoCountModel(
            input_ch=input_ch,
            hidden_dim=hidden_dim,
            latent_dim=latent_dim,
            K_max=K_max
        ).to(device)

        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
        scheduler = torch.optim.lr_scheduler.StepLR(
            optimizer, step_size=lr_step_size, gamma=lr_gamma
        )

        for _ in range(epochs):
            train_one_epoch(model, train_loader, optimizer, config, device)
            scheduler.step()

        model.eval()

        # TEST: trial-level count via windowing inference (no gradients needed).
        gt = float(test_item["count"])
        with torch.no_grad():
            pred = predict_count_by_windowing(
                model, x_np=test_item["data"], fs=fs, win_sec=win_sec, stride_sec=stride_sec,
                device=device, tau=tau, batch_size=batch_size
            )

        err = pred - gt
        fold_abs.append(abs(err))
        fold_sq.append(err * err)

        # Drop every reference to this fold's parameters (the optimizer also holds
        # them plus Adam moment buffers) before releasing cached GPU memory.
        del model, optimizer, scheduler, train_loader
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    mae = float(np.mean(fold_abs))
    mse = float(np.mean(fold_sq))
    rmse = float(np.sqrt(mse))
    return mae, mse, rmse


# ---------------------------------------------------------------------
# 8) Sweep helpers
# ---------------------------------------------------------------------
def make_sweep_values(base_val, multipliers=(0.25, 0.5, 1.0, 2.0, 4.0)):
    """
    Build a sorted, de-duplicated sweep grid that always contains 0.

    Returns sorted({0.0} | {base_val * m for m in multipliers}) with every entry
    as a float.
    """
    scale = float(base_val)
    grid = {0.0}
    grid.update(scale * float(m) for m in multipliers)
    return sorted(grid)


def run_one_setting(full_data_cache, base_config, device, sweep_name, sweep_value):
    """
    Evaluate one sweep point across all target activities.

    ``base_config`` is deep-copied (never mutated), ``sweep_name`` is overridden
    with ``float(sweep_value)``, and LOSO is run for every activity in
    ``config["TARGET_ACTIVITIES_MAP"]``.

    Returns:
      config: the deep-copied config that was actually used
      metrics_by_act: dict act_id -> (mae, mse, rmse) for activities with results
      overall_macro: (mae, mse, rmse) unweighted mean over those activities,
        or (nan, nan, nan) when no activity produced results.
    """
    config = copy.deepcopy(base_config)
    config[sweep_name] = float(sweep_value)

    metrics_by_act = {}
    maes, mses, rmses = [], [], []

    for act_id in config["TARGET_ACTIVITIES_MAP"]:
        out = eval_single_activity_loso(full_data_cache[act_id], config, device)
        if out is None:
            continue
        mae, mse, rmse = out
        metrics_by_act[act_id] = (mae, mse, rmse)
        maes.append(mae)
        mses.append(mse)
        rmses.append(rmse)

    if not maes:
        # Explicit NaNs instead of np.mean([]) (which warns and returns NaN).
        nan = float("nan")
        return config, metrics_by_act, (nan, nan, nan)

    overall = (float(np.mean(maes)), float(np.mean(mses)), float(np.mean(rmses)))
    return config, metrics_by_act, overall


# ---------------------------------------------------------------------
# 9) Main
# ---------------------------------------------------------------------
def main():
    """
    Run the full loss-weight sensitivity sweep.

    Loads mHealth once, builds the per-trial normalized cache once, then for each
    loss weight (order: smooth -> phase_ent -> recon -> effk) sweeps its grid
    (always including 0) while the other weights stay at the CONFIG base values,
    printing per-activity LOSO metrics and their macro average.
    """
    # Every activity uses the same 15 channels:
    # chest acc (3) + ankle acc (3) + ankle gyro (3) + arm acc (3) + arm gyro (3).
    shared_features = [
        'acc_chest_x', 'acc_chest_y', 'acc_chest_z',
        'acc_ankle_x', 'acc_ankle_y', 'acc_ankle_z',
        'gyro_ankle_x', 'gyro_ankle_y', 'gyro_ankle_z',
        'acc_arm_x', 'acc_arm_y', 'acc_arm_z',
        'gyro_arm_x', 'gyro_arm_y', 'gyro_arm_z',
    ]

    CONFIG = {
        "seed": 42,
        # TODO: adjust to your path
        "data_dir": "/content/drive/MyDrive/Colab Notebooks/HAR_data/MHEALTHDATASET",

        "COLUMN_NAMES": [
            'acc_chest_x', 'acc_chest_y', 'acc_chest_z',
            'ecg_1', 'ecg_2',
            'acc_ankle_x', 'acc_ankle_y', 'acc_ankle_z',
            'gyro_ankle_x', 'gyro_ankle_y', 'gyro_ankle_z',
            'mag_ankle_x', 'mag_ankle_y', 'mag_ankle_z',
            'acc_arm_x', 'acc_arm_y', 'acc_arm_z',
            'gyro_arm_x', 'gyro_arm_y', 'gyro_arm_z',
            'mag_arm_x', 'mag_arm_y', 'mag_arm_z',
            'activity_id'
        ],

        # 6 activities
        "TARGET_ACTIVITIES_MAP": {
            6:  'Waist bends forward',
            7:  'Frontal elevation of arms',
            8:  'Knees bending',
            10: 'Jogging',
            11: 'Running',
            12: 'Jump front & back'
        },

        # Independent copy of the shared 15-channel list per activity.
        "ACT_FEATURE_MAP": {
            act_id: list(shared_features) for act_id in (6, 7, 8, 10, 11, 12)
        },

        # Training Params
        "epochs": 100,  # 30 gives a quicker (rougher) sweep
        "lr": 5e-4,
        "batch_size": 64,
        "fs": 50,

        # Windowing Params
        "win_sec": 8.0,
        "stride_sec": 4.0,
        "drop_last": True,

        # Model
        "hidden_dim": 128,
        "latent_dim": 16,
        "K_max": 6,

        # Base best loss weights
        "lambda_recon": 1.0,
        "lambda_smooth": 0.05,
        "lambda_phase_ent": 0.01,
        "lambda_effk": 0.0075,

        # temperature
        "tau": 1.0,

        # Count-only labels
        "COUNT_TABLE": {
            6: {
                "subject1": 21, "subject2": 19, "subject3": 21, "subject4": 20, "subject5": 20,
                "subject6": 20, "subject7": 20, "subject8": 21, "subject9": 21, "subject10": 20,
            },
            7: {
                "subject1": 20, "subject2": 20, "subject3": 20, "subject4": 20, "subject5": 20,
                "subject6": 20, "subject7": 20, "subject8": 19, "subject9": 19, "subject10": 20,
            },
            8: {
                "subject1": 20, "subject2": 21, "subject3": 21, "subject4": 19, "subject5": 20,
                "subject6": 20, "subject7": 21, "subject8": 21, "subject9": 21, "subject10": 21,
            },
            10: {
                "subject1": 157, "subject2": 161, "subject3": 154, "subject4": 154, "subject5": 160,
                "subject6": 156, "subject7": 153, "subject8": 160, "subject9": 166, "subject10": 156,
            },
            11: {
                "subject1": 165, "subject2": 158, "subject3": 174, "subject4": 163, "subject5": 157,
                "subject6": 172, "subject7": 149, "subject8": 166, "subject9": 174, "subject10": 172,
            },
            12: {
                "subject1": 20, "subject2": 22, "subject3": 21, "subject4": 21, "subject5": 20,
                "subject6": 21, "subject7": 19, "subject8": 20, "subject9": 20, "subject10": 20,
            },
        }
    }

    set_strict_seed(CONFIG["seed"])
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")

    # 1) load raw data once
    full_data = load_mhealth_dataset(CONFIG["data_dir"], CONFIG["TARGET_ACTIVITIES_MAP"], CONFIG["COLUMN_NAMES"])
    if not full_data:
        print("[Error] full_data is empty. Check path / file naming.")
        return

    # 2) build normalized trial cache once
    trials_cache_all = build_trial_cache(
        full_data=full_data,
        target_map=CONFIG["TARGET_ACTIVITIES_MAP"],
        feature_map=CONFIG["ACT_FEATURE_MAP"],
        count_table=CONFIG["COUNT_TABLE"]
    )

    # 3) sweep plan (order fixed)
    sweep_plan = [
        ("lambda_smooth", CONFIG["lambda_smooth"]),
        ("lambda_phase_ent", CONFIG["lambda_phase_ent"]),
        ("lambda_recon", CONFIG["lambda_recon"]),
        ("lambda_effk", CONFIG["lambda_effk"]),
    ]

    # User-specified sweep grids (each must include 0)
    SWEEP_VALUES = {
        "lambda_recon":     [0, 0.25, 0.5, 0.75, 1.0, 1.5],
        "lambda_smooth":    [0, 0.01, 0.05, 0.1, 0.5, 1.0],
        "lambda_phase_ent": [0, 0.01, 0.05, 0.1, 0.5, 1.0],
        "lambda_effk":      [0, 0.005, 0.0075, 0.01, 0.05, 0.1],
    }

    print("\n" + "=" * 92)
    print("LOSS WEIGHT SENSITIVITY SWEEP (each sweep includes 0; others fixed at base best)")
    print("=" * 92)

    act_ids = list(CONFIG["TARGET_ACTIVITIES_MAP"].keys())
    act_names = CONFIG["TARGET_ACTIVITIES_MAP"]

    for sweep_name, base_val in sweep_plan:
        # Use the user grid; fall back to a multiplicative grid around the base value.
        if sweep_name in SWEEP_VALUES:
            sweep_vals = [float(v) for v in SWEEP_VALUES[sweep_name]]
        else:
            sweep_vals = make_sweep_values(base_val, multipliers=(0.25, 0.5, 1.0, 2.0, 4.0))

        # keep unique + sorted, preserve 0
        sweep_vals = sorted(set(sweep_vals))

        print("\n" + "-" * 92)
        print(f"[SWEEP] {sweep_name}  (base={base_val})  values={sweep_vals}")
        print("-" * 92)

        for v in sweep_vals:
            _, metrics_by_act, overall = run_one_setting(
                full_data_cache=trials_cache_all,
                base_config=CONFIG,
                device=device,
                sweep_name=sweep_name,
                sweep_value=v
            )

            # ---- minimal log (per-activity + overall only) ----
            print(f"\n{sweep_name} = {v}")
            for act_id in act_ids:
                if act_id not in metrics_by_act:
                    continue
                mae, mse, rmse = metrics_by_act[act_id]
                print(f"  Act {act_id:>2d} ({act_names[act_id]}): "
                      f"MAE={mae:.4f} | MSE={mse:.4f} | RMSE={rmse:.4f}")

            o_mae, o_mse, o_rmse = overall
            # Report how many activities were actually averaged (skipped ones excluded).
            print(f"  OVERALL (macro over {len(metrics_by_act)} acts): "
                  f"MAE={o_mae:.4f} | MSE={o_mse:.4f} | RMSE={o_rmse:.4f}")

    print("\n" + "=" * 92)
    print("DONE.")
    print("=" * 92)


if __name__ == "__main__":
    main()


Device: cuda
Loading 10 subjects from /content/drive/MyDrive/Colab Notebooks/HAR_data/MHEALTHDATASET...

LOSS WEIGHT SENSITIVITY SWEEP (each sweep includes 0; others fixed at base best)

--------------------------------------------------------------------------------------------
[SWEEP] lambda_smooth  (base=0.05)  values=[0.0, 0.01, 0.05, 0.1, 0.5, 1.0]
--------------------------------------------------------------------------------------------

lambda_smooth = 0.0
  Act  6 (Waist bends forward): MAE=5.0936 | MSE=33.3969 | RMSE=5.7790
  Act  7 (Frontal elevation of arms): MAE=2.8823 | MSE=11.4221 | RMSE=3.3797
  Act  8 (Knees bending): MAE=4.6755 | MSE=39.0947 | RMSE=6.2526
  Act 10 (Jogging): MAE=8.4237 | MSE=188.6386 | RMSE=13.7346
  Act 11 (Running): MAE=13.9437 | MSE=258.0601 | RMSE=16.0642
  Act 12 (Jump front & back): MAE=1.4359 | MSE=3.0381 | RMSE=1.7430
  OVERALL (macro over 6 acts): MAE=6.0758 | MSE=88.9417 | RMSE=7.8255

lambda_smooth = 0.01
  Act  6 (Waist bends forward): MA