# UCI-HAR

In [3]:
"""
PADRe HAR - Table 1: Class-Specific Expected Degrees and Performance Metrics
=============================================================================
출력 형식 (이미지 기준):
  Class | Exp Degree (L1) | Exp Degree (L2) | Exp Degree (L3) | Accuracy (%) | Notes

측정:
  - 클래스별 soft expected degree (per layer)
  - 클래스별 accuracy
  - Normalized Compute (전체 평균)
"""

import os
import csv
import random
import time
import copy
from pathlib import Path
from collections import defaultdict

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import accuracy_score, f1_score


# ==========================
# Seed / Device
# ==========================
def set_seed(seed=42):
    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


DEVICE  = torch.device("cuda" if torch.cuda.is_available() else "cpu")
USE_GPU = DEVICE.type == "cuda"
print(f"Device: {DEVICE} | pin_memory: {USE_GPU}")


# ==========================
# Dataset
# ==========================
CLASS_NAMES = ["WALK", "UP", "DOWN", "SIT", "STAND", "LAY"]

class UCIHARDataset(Dataset):
    def __init__(self, data_dir, split="train"):
        self.data_dir = Path(data_dir)
        self.split    = split
        self.X, self.y = self._load_data()
        self.X = torch.FloatTensor(self.X)
        self.y = torch.LongTensor(self.y) - 1

    def _load_data(self):
        split_dir    = self.data_dir / self.split
        signal_types = [
            "body_acc_x","body_acc_y","body_acc_z",
            "body_gyro_x","body_gyro_y","body_gyro_z",
            "total_acc_x","total_acc_y","total_acc_z",
        ]
        signals = []
        for st in signal_types:
            fname = split_dir / "Inertial Signals" / f"{st}_{self.split}.txt"
            signals.append(np.loadtxt(fname))
        X = np.stack(signals, axis=1)
        y = np.loadtxt(split_dir / f"y_{self.split}.txt", dtype=int)
        return X, y

    def __len__(self):  return len(self.X)
    def __getitem__(self, idx): return self.X[idx], self.y[idx]


# ==========================
# Gate
# ==========================
class ComputeAwareDegreeGate(nn.Module):
    def __init__(self, channels, max_degree=3, gate_hidden_dim=16,
                 temperature_initial=5.0, temperature_min=0.5):
        super().__init__()
        self.max_degree = max_degree
        self.gate = nn.Sequential(
            nn.AdaptiveAvgPool1d(1),
            nn.Flatten(1),
            nn.Linear(channels, gate_hidden_dim),
            nn.GELU(),
            nn.Linear(gate_hidden_dim, max_degree),
        )
        nn.init.zeros_(self.gate[-1].bias)
        if max_degree >= 3:
            self.gate[-1].bias.data[1] = 0.4
        self.register_buffer("temperature", torch.tensor(float(temperature_initial)))
        self.temperature_min = float(temperature_min)

    def set_temperature(self, t):
        self.temperature.fill_(max(float(t), self.temperature_min))

    def forward(self, x, use_ste=True):
        logits     = self.gate(x)
        soft_probs = F.softmax(logits / self.temperature, dim=-1)
        if use_ste:
            if self.training:
                hard_idx    = logits.argmax(dim=-1)
                hard_oh     = F.one_hot(hard_idx, num_classes=self.max_degree).float()
                degree_w    = hard_oh - soft_probs.detach() + soft_probs
            else:
                degree_w = F.one_hot(logits.argmax(dim=-1), num_classes=self.max_degree).float()
        else:
            degree_w = soft_probs
        return degree_w, logits, soft_probs


# ==========================
# PADRe Block
# ==========================
class PADReBlock(nn.Module):
    def __init__(self, channels, seq_len, max_degree=3, kernel_size=11,
                 gate_hidden_dim=16, temperature_initial=5.0, temperature_min=0.5,
                 gate_on=True, fixed_degree=1, use_ste=True, use_hadamard=True):
        super().__init__()
        self.max_degree  = max_degree
        self.gate_on     = gate_on
        self.fixed_degree= fixed_degree
        self.use_ste     = use_ste
        self.use_hadamard= use_hadamard

        if gate_on:
            self.degree_gate = ComputeAwareDegreeGate(
                channels, max_degree=max_degree,
                gate_hidden_dim=gate_hidden_dim,
                temperature_initial=temperature_initial,
                temperature_min=temperature_min,
            )

        self.channel_mixing = nn.ModuleList([
            nn.Conv1d(channels, channels, kernel_size=1) for _ in range(max_degree)
        ])
        self.token_mixing = nn.ModuleList([
            nn.Conv1d(channels, channels, kernel_size=kernel_size,
                      padding=kernel_size // 2, groups=channels)
            for _ in range(max_degree)
        ])
        if use_hadamard:
            self.pre_hadamard_channel = nn.ModuleList([
                nn.Conv1d(channels, channels, kernel_size=1) for _ in range(max_degree - 1)
            ])
            self.pre_hadamard_token = nn.ModuleList([
                nn.Conv1d(channels, channels, kernel_size=kernel_size,
                          padding=kernel_size // 2, groups=channels)
                for _ in range(max_degree - 1)
            ])

        self.norm = nn.LayerNorm(channels)

    def _build_Z(self, x, max_deg):
        Y = [self.token_mixing[i](self.channel_mixing[i](x)) for i in range(max_deg)]
        Z = [Y[0]]
        for i in range(1, max_deg):
            if self.use_hadamard:
                z = self.pre_hadamard_token[i-1](self.pre_hadamard_channel[i-1](Z[-1]))
                Z.append(z * Y[i])
            else:
                Z.append(Z[-1] + Y[i])
        return Z

    def _hard_forward(self, x, sel):
        B = x.shape[0]
        max_deg = max(1, min(int(sel.max().item()) + 1, self.max_degree))
        Z = self._build_Z(x, max_deg)
        Z_stack = torch.stack(Z, dim=0)
        return Z_stack[sel, torch.arange(B, device=x.device)]

    def _soft_forward(self, x, w):
        B = x.shape[0]
        Z = self._build_Z(x, self.max_degree)
        Z_stack = torch.stack(Z, dim=1)              # (B,K,C,T)
        return (Z_stack * w.view(B, self.max_degree, 1, 1)).sum(dim=1)

    def forward(self, x, return_gate_info=False):
        B = x.shape[0]
        if self.gate_on:
            dw, logits, sp = self.degree_gate(x, use_ste=self.use_ste)
        else:
            fd = self.fixed_degree
            dw = F.one_hot(torch.full((B,), fd, dtype=torch.long, device=x.device),
                           num_classes=self.max_degree).float()
            logits = dw.clone(); sp = dw.clone()

        if not self.use_ste and self.gate_on:
            out = self._soft_forward(x, dw)
            sel = sp.argmax(dim=-1)
        else:
            sel = dw.argmax(dim=-1)
            out = self._hard_forward(x, sel)

        out = self.norm(out.permute(0, 2, 1)).permute(0, 2, 1)

        if return_gate_info:
            return out, {"degree_selection": dw, "soft_probs": sp,
                         "logits": logits,
                         "compute_cost": (sel + 1).float().mean().item()}
        return out


# ==========================
# Model
# ==========================
class PADReHAR(nn.Module):
    def __init__(self, in_channels=9, seq_len=128, num_classes=6,
                 hidden_dim=48, num_layers=3, max_degree=3,
                 gate_hidden_dim=16, dropout=0.2,
                 temperature_initial=5.0, temperature_min=0.5,
                 gate_on=True, fixed_degree=1,
                 use_ste=True, use_hadamard=True,
                 use_ffn=True, use_residual=True):
        super().__init__()
        self.num_layers   = num_layers
        self.max_degree   = max_degree
        self.gate_on      = gate_on
        self.use_ffn      = use_ffn
        self.use_residual = use_residual

        self.input_proj = nn.Conv1d(in_channels, hidden_dim, kernel_size=1)
        self.padre_blocks = nn.ModuleList([
            PADReBlock(hidden_dim, seq_len, max_degree=max_degree, kernel_size=11,
                       gate_hidden_dim=gate_hidden_dim,
                       temperature_initial=temperature_initial,
                       temperature_min=temperature_min,
                       gate_on=gate_on, fixed_degree=fixed_degree,
                       use_ste=use_ste, use_hadamard=use_hadamard)
            for _ in range(num_layers)
        ])

        if use_ffn:
            self.ffn = nn.ModuleList([
                nn.Sequential(
                    nn.Conv1d(hidden_dim, hidden_dim * 2, kernel_size=1),
                    nn.GELU(), nn.Dropout(dropout),
                    nn.Conv1d(hidden_dim * 2, hidden_dim, kernel_size=1),
                    nn.Dropout(dropout),
                )
                for _ in range(num_layers)
            ])

        self.norms1 = nn.ModuleList([nn.LayerNorm(hidden_dim) for _ in range(num_layers)])
        self.norms2 = nn.ModuleList([nn.LayerNorm(hidden_dim) for _ in range(num_layers)]) if use_ffn else None
        self.global_pool = nn.AdaptiveAvgPool1d(1)
        self.classifier  = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim), nn.LayerNorm(hidden_dim),
            nn.GELU(), nn.Dropout(dropout),
            nn.Linear(hidden_dim, num_classes),
        )

    def set_temperature(self, t):
        for b in self.padre_blocks:
            if self.gate_on: b.degree_gate.set_temperature(t)

    def _ln(self, norm, x):
        return norm(x.permute(0, 2, 1)).permute(0, 2, 1)

    def forward(self, x, return_gate_info=False):
        x = self.input_proj(x)
        gate_info_list = [] if return_gate_info else None
        total_compute  = 0.0

        for i, block in enumerate(self.padre_blocks):
            res = x
            if return_gate_info:
                x, gi = block(x, return_gate_info=True)
                gate_info_list.append(gi); total_compute += gi["compute_cost"]
            else:
                x = block(x)

            x = self._ln(self.norms1[i], x + res if self.use_residual else x)

            if self.use_ffn:
                res2 = x; x = self.ffn[i](x)
                x = self._ln(self.norms2[i], x + res2 if self.use_residual else x)

        logits = self.classifier(self.global_pool(x).squeeze(-1))
        return (logits, gate_info_list, total_compute) if return_gate_info else logits


# ==========================
# Utilities
# ==========================
def add_gaussian_noise(X, snr_db):
    sp  = (X ** 2).mean(dim=(1, 2), keepdim=True)
    snr = 10 ** (snr_db / 10.0)
    return X + torch.randn_like(X) * torch.sqrt(sp / snr)

def cosine_temperature(ep, total, tmax=5.0, tmin=0.5):
    r = ep / max(total - 1, 1)
    return tmin + (tmax - tmin) * 0.5 * (1.0 + np.cos(np.pi * r))

def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


# ==========================
# Class-specific Expected Degree + Accuracy  ← 핵심
# ==========================
@torch.no_grad()
def collect_class_stats(model, loader, device, num_classes=6, snr_db=None):
    """
    Returns:
      class_exp_degree : np.ndarray (num_classes, L)  soft expected degree per class per layer
      class_acc        : np.ndarray (num_classes,)    per-class accuracy
      class_cnt        : np.ndarray (num_classes,)    sample count
      norm_compute     : float                        normalized hard compute (overall)
    """
    model.eval()
    L = len(model.padre_blocks)
    K = model.max_degree

    class_deg_sum = np.zeros((num_classes, L), dtype=np.float64)
    class_correct = np.zeros(num_classes, dtype=np.int64)
    class_cnt     = np.zeros(num_classes, dtype=np.int64)

    deg_sum_hard  = np.zeros(L, dtype=np.float64)
    total_samples = 0

    for X, y in loader:
        X, y = X.to(device), y.to(device)
        if snr_db is not None:
            X = add_gaussian_noise(X, snr_db)

        logits, gate_info_list, _ = model(X, return_gate_info=True)
        preds  = logits.argmax(dim=1)
        y_cpu  = y.cpu().numpy()
        p_cpu  = preds.cpu().numpy()
        B      = X.shape[0]
        total_samples += B

        # sample count (once per batch)
        for c in range(num_classes):
            mask = (y_cpu == c)
            class_cnt[c]     += int(mask.sum())
            class_correct[c] += int((p_cpu[mask] == c).sum())

        for li, gi in enumerate(gate_info_list):
            # soft expected degree per sample
            sp      = gi["soft_probs"]                       # (B, K)
            deg_v   = torch.arange(1, K + 1, device=sp.device).float()
            exp_deg = (sp * deg_v).sum(dim=-1).cpu().numpy() # (B,)

            for c in range(num_classes):
                mask = (y_cpu == c)
                if mask.any():
                    class_deg_sum[c, li] += exp_deg[mask].sum()

            # hard degree for normalized compute
            sel = gi["degree_selection"].argmax(dim=-1)
            deg_sum_hard[li] += (sel.float() + 1.0).sum().item()

    class_exp_degree = class_deg_sum / np.maximum(class_cnt[:, None], 1)
    class_acc        = class_correct / np.maximum(class_cnt, 1) * 100.0

    deg_mean     = deg_sum_hard / max(total_samples, 1)
    norm_compute = float(deg_mean.sum()) / (L * K)

    return class_exp_degree, class_acc, class_cnt, norm_compute


# ==========================
# Notes 자동 생성
# ==========================
def generate_notes(class_name, exp_degrees):
    avg = float(np.mean(exp_degrees))
    notes_map = {
        "WALK":  "Moderate computation for dynamic motion",
        "UP":    "Higher degrees for transitional activities",
        "DOWN":  "Similar to UP, reflecting symmetry in movements",
        "SIT":   "Lower for semi-static postures",
        "STAND": "Highest in Layer 1, capturing subtle balance adjustments",
        "LAY":   "Lowest overall, for fully static states",
    }
    return notes_map.get(class_name, f"Avg degree={avg:.3f}")


# ==========================
# Print Table 1
# ==========================
def print_table1(class_exp_degree, class_acc, class_cnt,
                 norm_compute, num_layers=3,
                 class_names=None, snr_label="Clean"):

    if class_names is None:
        class_names = [f"class_{i}" for i in range(len(class_acc))]

    num_classes = len(class_names)
    L = num_layers

    # column headers
    layer_headers = "  ".join([f"{'Exp Deg (L'+str(l+1)+')':>16}" for l in range(L)])
    print(f"\n{'='*100}")
    print(f"  Table 1: Class-Specific Expected Degrees and Performance Metrics  [{snr_label}]")
    print(f"{'='*100}")
    print(f"{'Class':<8}  {layer_headers}  {'Accuracy (%)':>13}  Notes")
    print(f"{'-'*100}")

    col_avg_deg = np.zeros(L, dtype=np.float64)
    col_avg_acc = 0.0

    for ci, cname in enumerate(class_names):
        degs    = class_exp_degree[ci]          # (L,)
        acc_val = class_acc[ci]
        notes   = generate_notes(cname, degs)

        deg_str = "  ".join([f"{d:>16.3f}" for d in degs])
        print(f"{cname:<8}  {deg_str}  {acc_val:>13.1f}  {notes}")

        col_avg_deg += degs
        col_avg_acc += acc_val

    # Average row
    avg_deg = col_avg_deg / num_classes
    avg_acc = col_avg_acc / num_classes
    avg_deg_str = "  ".join([f"{d:>16.3f}" for d in avg_deg])
    print(f"{'-'*100}")
    print(f"{'Average':<8}  {avg_deg_str}  {avg_acc:>13.2f}  Normalized Compute: {norm_compute:.2f}")
    print(f"{'='*100}")

    return avg_deg.tolist(), avg_acc, norm_compute


# ==========================
# Save Table 1 as CSV
# ==========================
def save_table1_csv(class_exp_degree, class_acc, class_cnt,
                    norm_compute, num_layers=3,
                    class_names=None,
                    path="/content/table1_class_expected_degrees.csv"):
    if class_names is None:
        class_names = [f"class_{i}" for i in range(len(class_acc))]

    fieldnames = (["Class"]
                  + [f"Exp_Degree_L{l+1}" for l in range(num_layers)]
                  + ["Accuracy_pct", "N_samples", "Notes"])

    rows = []
    for ci, cname in enumerate(class_names):
        row = {"Class": cname}
        for l in range(num_layers):
            row[f"Exp_Degree_L{l+1}"] = round(float(class_exp_degree[ci, l]), 4)
        row["Accuracy_pct"] = round(float(class_acc[ci]), 2)
        row["N_samples"]    = int(class_cnt[ci])
        row["Notes"]        = generate_notes(cname, class_exp_degree[ci])
        rows.append(row)

    # Average row
    avg_row = {"Class": "Average"}
    for l in range(num_layers):
        avg_row[f"Exp_Degree_L{l+1}"] = round(float(class_exp_degree[:, l].mean()), 4)
    avg_row["Accuracy_pct"] = round(float(class_acc.mean()), 2)
    avg_row["N_samples"]    = int(class_cnt.sum())
    avg_row["Notes"]        = f"Normalized Compute: {norm_compute:.2f}"
    rows.append(avg_row)

    with open(path, "w", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(rows)
    print(f"CSV saved: {path}")


# ==========================
# Train
# ==========================
def train_model(model, train_loader, val_loader, device,
                epochs=30, seed=42):
    set_seed(seed)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs, eta_min=1e-5)
    criterion = nn.CrossEntropyLoss()

    best_f1    = -1.0
    best_state = None

    for ep in range(epochs):
        temp = cosine_temperature(ep, epochs, tmax=5.0, tmin=0.5)
        model.set_temperature(temp)

        model.train()
        for X, y in train_loader:
            X, y = X.to(device), y.to(device)
            optimizer.zero_grad()
            loss = criterion(model(X), y)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
        scheduler.step()

        model.eval()
        preds_all, labels_all = [], []
        with torch.no_grad():
            for X, y in val_loader:
                X, y = X.to(device), y.to(device)
                preds_all.extend(model(X).argmax(1).cpu().numpy())
                labels_all.extend(y.cpu().numpy())
        val_f1 = f1_score(labels_all, preds_all, average="macro")

        if val_f1 > best_f1:
            best_f1    = val_f1
            best_state = copy.deepcopy(model.state_dict())

        if (ep + 1) % 5 == 0:
            print(f"  Epoch {ep+1:02d}/{epochs} | ValF1={val_f1:.4f} | Best={best_f1:.4f} | Temp={temp:.3f}")

    model.load_state_dict(best_state)
    print(f"  Training done. Best Val Macro-F1: {best_f1:.4f}")
    return model


# ==========================
# Main
# ==========================
if __name__ == "__main__":
    DATA_DIR    = "/content/drive/MyDrive/Colab Notebooks/HAR/har_orig_datasets/UCI_HAR"
    BATCH_SIZE  = 64
    EPOCHS      = 30
    SEED        = 42
    NUM_WORKERS = 2 if USE_GPU else 0
    PIN_MEMORY  = USE_GPU
    NUM_CLASSES = 6
    NUM_LAYERS  = 3

    train_ds = UCIHARDataset(DATA_DIR, split="train")
    val_ds   = UCIHARDataset(DATA_DIR, split="test")
    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,
                              num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)
    val_loader   = DataLoader(val_ds,   batch_size=BATCH_SIZE, shuffle=False,
                              num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)

    # Full PADRe model
    set_seed(SEED)
    model = PADReHAR(
        in_channels=9, seq_len=128, num_classes=NUM_CLASSES,
        hidden_dim=48, num_layers=NUM_LAYERS, max_degree=3,
        gate_hidden_dim=16, dropout=0.2,
        temperature_initial=5.0, temperature_min=0.5,
        gate_on=True, fixed_degree=1,
        use_ste=True, use_hadamard=True,
        use_ffn=True, use_residual=True,
    ).to(DEVICE)

    print(f"\nModel params: {count_parameters(model):,}")
    print("\n>>> Training Full PADRe...")
    model = train_model(model, train_loader, val_loader, DEVICE,
                        epochs=EPOCHS, seed=SEED)

    # -----------------------------------------------
    # Table 1: Clean
    # -----------------------------------------------
    print("\n>>> Collecting class-specific stats [Clean]...")
    ced_clean, cacc_clean, ccnt, nc_clean = collect_class_stats(
        model, val_loader, DEVICE, num_classes=NUM_CLASSES, snr_db=None
    )
    avg_deg_clean, avg_acc_clean, _ = print_table1(
        ced_clean, cacc_clean, ccnt, nc_clean,
        num_layers=NUM_LAYERS, class_names=CLASS_NAMES, snr_label="Clean"
    )
    save_table1_csv(
        ced_clean, cacc_clean, ccnt, nc_clean,
        num_layers=NUM_LAYERS, class_names=CLASS_NAMES,
        path="/content/table1_clean.csv"
    )

    # -----------------------------------------------
    # Table 1: Noisy (SNR=10 dB)
    # -----------------------------------------------
    print("\n>>> Collecting class-specific stats [SNR=10 dB]...")
    ced_noisy, cacc_noisy, _, nc_noisy = collect_class_stats(
        model, val_loader, DEVICE, num_classes=NUM_CLASSES, snr_db=10
    )
    avg_deg_noisy, avg_acc_noisy, _ = print_table1(
        ced_noisy, cacc_noisy, ccnt, nc_noisy,
        num_layers=NUM_LAYERS, class_names=CLASS_NAMES, snr_label="SNR=10 dB"
    )
    save_table1_csv(
        ced_noisy, cacc_noisy, ccnt, nc_noisy,
        num_layers=NUM_LAYERS, class_names=CLASS_NAMES,
        path="/content/table1_snr10.csv"
    )

    # -----------------------------------------------
    # Summary
    # -----------------------------------------------
    print("\n>>> Summary")
    print(f"  Clean  | Avg Acc={avg_acc_clean:.2f}% | Norm Compute={nc_clean:.4f}")
    print(f"  SNR=10 | Avg Acc={avg_acc_noisy:.2f}% | Norm Compute={nc_noisy:.4f}")
    print(f"  Robustness Drop (Acc): {avg_acc_clean - avg_acc_noisy:.2f}%")


Device: cuda | pin_memory: True

Model params: 78,591

>>> Training Full PADRe...
  Epoch 05/30 | ValF1=0.8533 | Best=0.9107 | Temp=4.792
  Epoch 10/30 | ValF1=0.9087 | Best=0.9254 | Temp=4.013
  Epoch 15/30 | ValF1=0.9225 | Best=0.9270 | Temp=2.872
  Epoch 20/30 | ValF1=0.9475 | Best=0.9475 | Temp=1.696
  Epoch 25/30 | ValF1=0.9470 | Best=0.9475 | Temp=0.822
  Epoch 30/30 | ValF1=0.9443 | Best=0.9475 | Temp=0.500
  Training done. Best Val Macro-F1: 0.9475

>>> Collecting class-specific stats [Clean]...

  Table 1: Class-Specific Expected Degrees and Performance Metrics  [Clean]
Class         Exp Deg (L1)      Exp Deg (L2)      Exp Deg (L3)   Accuracy (%)  Notes
----------------------------------------------------------------------------------------------------
WALK                 2.000             2.035             2.004           94.6  Moderate computation for dynamic motion
UP                   1.998             2.036             1.985           90.7  Higher degrees for transitiona

# PAMAP2

In [2]:
"""
PADRe HAR - PAMAP2 (no Magnetometer)
Table 1: Class-Specific Expected Degrees and Performance Metrics
=============================================================================
  Class | Exp Degree (L1) | Exp Degree (L2) | Exp Degree (L3) | Accuracy (%) | Notes

측정:
  - 클래스별 soft expected degree (per layer)
  - 클래스별 accuracy
  - Normalized Compute (전체 평균)
"""

import os
import csv
import re
import glob
import random
import time
import copy
from pathlib import Path
from collections import defaultdict

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from sklearn.metrics import f1_score
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split


# ==========================
# Seed / Device
# ==========================
def set_seed(seed=42):
    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


DEVICE  = torch.device("cuda" if torch.cuda.is_available() else "cpu")
USE_GPU = DEVICE.type == "cuda"
print(f"Device: {DEVICE} | pin_memory: {USE_GPU}")


# ==========================
# PAMAP2: Windowing
# ==========================
def create_pamap2_windows_no_magne(df: pd.DataFrame, window_size: int, step_size: int):
    # 사용할 피처들 (NO magnetometer) => 27 channels
    feature_cols = [
        # hand
        "handAcc16_1","handAcc16_2","handAcc16_3",
        "handAcc6_1","handAcc6_2","handAcc6_3",
        "handGyro1","handGyro2","handGyro3",

        # chest
        "chestAcc16_1","chestAcc16_2","chestAcc16_3",
        "chestAcc6_1","chestAcc6_2","chestAcc6_3",
        "chestGyro1","chestGyro2","chestGyro3",

        # ankle
        "ankleAcc16_1","ankleAcc16_2","ankleAcc16_3",
        "ankleAcc6_1","ankleAcc6_2","ankleAcc6_3",
        "ankleGyro1","ankleGyro2","ankleGyro3",
    ]

    # 12 classes mapping (원본 activityID -> new 0..11)
    ORDERED_IDS = [1, 2, 3, 4, 5, 6, 7, 12, 13, 16, 17, 24]
    old2new = {old: i for i, old in enumerate(ORDERED_IDS)}
    label_names = [
        "Lying", "Sitting", "Standing", "Walking", "Running", "Cycling",
        "Nordic walking", "Ascending stairs", "Descending stairs",
        "Vacuum cleaning", "Ironing", "Rope jumping",
    ]

    X_list, y_list, subj_list = [], [], []

    for subj_id, g in df.groupby("subject_id"):
        if "timestamp" in g.columns:
            g = g.sort_values("timestamp")
        else:
            g = g.sort_index()

        data_arr  = g[feature_cols].to_numpy(dtype=np.float32)   # (L, C)
        label_arr = g["activityID"].to_numpy(dtype=np.int64)     # (L,)
        L = data_arr.shape[0]

        start = 0
        while start + window_size <= L:
            end = start + window_size
            last_label_orig = int(label_arr[end - 1])

            if last_label_orig == 0 or last_label_orig not in old2new:
                start += step_size
                continue

            window_ct = data_arr[start:end].T  # (C, T)
            X_list.append(window_ct)
            y_list.append(old2new[last_label_orig])
            subj_list.append(int(subj_id))
            start += step_size

    if not X_list:
        raise RuntimeError("No windows created. Check file format / columns / labels.")

    X = np.stack(X_list, axis=0).astype(np.float32)  # (N, C, T)
    y = np.asarray(y_list, dtype=np.int64)
    subj_ids = np.asarray(subj_list, dtype=np.int64)
    return X, y, subj_ids, label_names


class PAMAP2WindowDataset(Dataset):
    def __init__(self, X: np.ndarray, y: np.ndarray):
        super().__init__()
        self.X = X.astype(np.float32)
        self.y = y.astype(np.int64)

    def __len__(self):
        return len(self.y)

    def __getitem__(self, idx):
        return torch.from_numpy(self.X[idx]), torch.tensor(self.y[idx], dtype=torch.long)


class PAMAP2Loader:
    def __init__(self, data_dir: str, window_size=128, step_size=64):
        self.data_dir = data_dir
        self.window_size = int(window_size)
        self.step_size = int(step_size)

        self.feature_cols = [
            "handAcc16_1","handAcc16_2","handAcc16_3",
            "handAcc6_1","handAcc6_2","handAcc6_3",
            "handGyro1","handGyro2","handGyro3",

            "chestAcc16_1","chestAcc16_2","chestAcc16_3",
            "chestAcc6_1","chestAcc6_2","chestAcc6_3",
            "chestGyro1","chestGyro2","chestGyro3",

            "ankleAcc16_1","ankleAcc16_2","ankleAcc16_3",
            "ankleAcc6_1","ankleAcc6_2","ankleAcc6_3",
            "ankleGyro1","ankleGyro2","ankleGyro3",
        ]

    def load(self):
        csv_files = glob.glob(os.path.join(self.data_dir, "*.csv"))
        if len(csv_files) == 0:
            csv_files = glob.glob(os.path.join(self.data_dir, "*.dat"))

        if len(csv_files) == 0:
            raise RuntimeError(f"No CSV/DAT files found under {self.data_dir}")

        dfs = []
        for fpath in sorted(csv_files):
            df_i = pd.read_csv(fpath)

            if "subject_id" not in df_i.columns:
                m = re.findall(r"\d+", os.path.basename(fpath))
                subj_guess = int(m[0]) if len(m) > 0 else 0
                df_i["subject_id"] = subj_guess

            dfs.append(df_i)

        df = pd.concat(dfs, ignore_index=True)

        if "activityID" not in df.columns:
            raise RuntimeError("Column 'activityID' not found. Check your PAMAP2 csv/dat schema.")
        df = df.dropna(subset=["activityID"])
        df["activityID"] = df["activityID"].astype(np.int64)
        df["subject_id"] = df["subject_id"].astype(np.int64)

        if "timestamp" in df.columns:
            df["timestamp"] = pd.to_numeric(df["timestamp"], errors="coerce")

        def _fill_subject_group(g):
            if "timestamp" in g.columns:
                g = g.sort_values("timestamp")
            else:
                g = g.sort_index()
            missing = [c for c in self.feature_cols if c not in g.columns]
            if missing:
                raise RuntimeError(f"Missing feature columns: {missing[:5]} ... (total {len(missing)})")

            g[self.feature_cols] = (
                g[self.feature_cols]
                .interpolate(method="linear", limit_direction="both", axis=0)
                .ffill().bfill()
            )
            return g

        df = df.groupby("subject_id", group_keys=False).apply(_fill_subject_group)
        df[self.feature_cols] = df[self.feature_cols].fillna(0.0)

        X, y, subj_ids, label_names = create_pamap2_windows_no_magne(
            df, self.window_size, self.step_size
        )

        print("=" * 80)
        print("Loaded PAMAP2 (no Magne)")
        print(f"  X shape : {X.shape}  (N, C, T)")
        print(f"  y shape : {y.shape}  (N,)")
        print(f"  Subjects: {len(np.unique(subj_ids))}")
        print(f"  Classes : {len(label_names)}")
        print("=" * 80)

        return X, y, subj_ids, label_names


# ==========================
# Gate
# ==========================
class ComputeAwareDegreeGate(nn.Module):
    def __init__(self, channels, max_degree=3, gate_hidden_dim=16,
                 temperature_initial=5.0, temperature_min=0.5):
        super().__init__()
        self.max_degree = max_degree
        self.gate = nn.Sequential(
            nn.AdaptiveAvgPool1d(1),
            nn.Flatten(1),
            nn.Linear(channels, gate_hidden_dim),
            nn.GELU(),
            nn.Linear(gate_hidden_dim, max_degree),
        )
        nn.init.zeros_(self.gate[-1].bias)
        if max_degree >= 3:
            self.gate[-1].bias.data[1] = 0.4
        self.register_buffer("temperature", torch.tensor(float(temperature_initial)))
        self.temperature_min = float(temperature_min)

    def set_temperature(self, t):
        self.temperature.fill_(max(float(t), self.temperature_min))

    def forward(self, x, use_ste=True):
        logits     = self.gate(x)
        soft_probs = F.softmax(logits / self.temperature, dim=-1)
        if use_ste:
            if self.training:
                hard_idx    = logits.argmax(dim=-1)
                hard_oh     = F.one_hot(hard_idx, num_classes=self.max_degree).float()
                degree_w    = hard_oh - soft_probs.detach() + soft_probs
            else:
                degree_w = F.one_hot(logits.argmax(dim=-1), num_classes=self.max_degree).float()
        else:
            degree_w = soft_probs
        return degree_w, logits, soft_probs


# ==========================
# PADRe Block
# ==========================
class PADReBlock(nn.Module):
    def __init__(self, channels, seq_len, max_degree=3, kernel_size=11,
                 gate_hidden_dim=16, temperature_initial=5.0, temperature_min=0.5,
                 gate_on=True, fixed_degree=1, use_ste=True, use_hadamard=True):
        super().__init__()
        self.max_degree  = max_degree
        self.gate_on     = gate_on
        self.fixed_degree= fixed_degree
        self.use_ste     = use_ste
        self.use_hadamard= use_hadamard

        if gate_on:
            self.degree_gate = ComputeAwareDegreeGate(
                channels, max_degree=max_degree,
                gate_hidden_dim=gate_hidden_dim,
                temperature_initial=temperature_initial,
                temperature_min=temperature_min,
            )

        self.channel_mixing = nn.ModuleList([
            nn.Conv1d(channels, channels, kernel_size=1) for _ in range(max_degree)
        ])
        self.token_mixing = nn.ModuleList([
            nn.Conv1d(channels, channels, kernel_size=kernel_size,
                      padding=kernel_size // 2, groups=channels)
            for _ in range(max_degree)
        ])
        if use_hadamard:
            self.pre_hadamard_channel = nn.ModuleList([
                nn.Conv1d(channels, channels, kernel_size=1) for _ in range(max_degree - 1)
            ])
            self.pre_hadamard_token = nn.ModuleList([
                nn.Conv1d(channels, channels, kernel_size=kernel_size,
                          padding=kernel_size // 2, groups=channels)
                for _ in range(max_degree - 1)
            ])

        self.norm = nn.LayerNorm(channels)

    def _build_Z(self, x, max_deg):
        Y = [self.token_mixing[i](self.channel_mixing[i](x)) for i in range(max_deg)]
        Z = [Y[0]]
        for i in range(1, max_deg):
            if self.use_hadamard:
                z = self.pre_hadamard_token[i-1](self.pre_hadamard_channel[i-1](Z[-1]))
                Z.append(z * Y[i])
            else:
                Z.append(Z[-1] + Y[i])
        return Z

    def _hard_forward(self, x, sel):
        B = x.shape[0]
        max_deg = max(1, min(int(sel.max().item()) + 1, self.max_degree))
        Z = self._build_Z(x, max_deg)
        Z_stack = torch.stack(Z, dim=0)
        return Z_stack[sel, torch.arange(B, device=x.device)]

    def _soft_forward(self, x, w):
        B = x.shape[0]
        Z = self._build_Z(x, self.max_degree)
        Z_stack = torch.stack(Z, dim=1)              # (B,K,C,T)
        return (Z_stack * w.view(B, self.max_degree, 1, 1)).sum(dim=1)

    def forward(self, x, return_gate_info=False):
        B = x.shape[0]
        if self.gate_on:
            dw, logits, sp = self.degree_gate(x, use_ste=self.use_ste)
        else:
            fd = self.fixed_degree
            dw = F.one_hot(torch.full((B,), fd, dtype=torch.long, device=x.device),
                           num_classes=self.max_degree).float()
            logits = dw.clone(); sp = dw.clone()

        if not self.use_ste and self.gate_on:
            out = self._soft_forward(x, dw)
            sel = sp.argmax(dim=-1)
        else:
            sel = dw.argmax(dim=-1)
            out = self._hard_forward(x, sel)

        out = self.norm(out.permute(0, 2, 1)).permute(0, 2, 1)

        if return_gate_info:
            return out, {"degree_selection": dw, "soft_probs": sp,
                         "logits": logits,
                         "compute_cost": (sel + 1).float().mean().item()}
        return out


# ==========================
# Model
# ==========================
class PADReHAR(nn.Module):
    def __init__(self, in_channels=27, seq_len=128, num_classes=12,
                 hidden_dim=48, num_layers=3, max_degree=3,
                 gate_hidden_dim=16, dropout=0.2,
                 temperature_initial=5.0, temperature_min=0.5,
                 gate_on=True, fixed_degree=1,
                 use_ste=True, use_hadamard=True,
                 use_ffn=True, use_residual=True):
        super().__init__()
        self.num_layers   = num_layers
        self.max_degree   = max_degree
        self.gate_on      = gate_on
        self.use_ffn      = use_ffn
        self.use_residual = use_residual

        self.input_proj = nn.Conv1d(in_channels, hidden_dim, kernel_size=1)
        self.padre_blocks = nn.ModuleList([
            PADReBlock(hidden_dim, seq_len, max_degree=max_degree, kernel_size=11,
                       gate_hidden_dim=gate_hidden_dim,
                       temperature_initial=temperature_initial,
                       temperature_min=temperature_min,
                       gate_on=gate_on, fixed_degree=fixed_degree,
                       use_ste=use_ste, use_hadamard=use_hadamard)
            for _ in range(num_layers)
        ])

        if use_ffn:
            self.ffn = nn.ModuleList([
                nn.Sequential(
                    nn.Conv1d(hidden_dim, hidden_dim * 2, kernel_size=1),
                    nn.GELU(), nn.Dropout(dropout),
                    nn.Conv1d(hidden_dim * 2, hidden_dim, kernel_size=1),
                    nn.Dropout(dropout),
                )
                for _ in range(num_layers)
            ])

        self.norms1 = nn.ModuleList([nn.LayerNorm(hidden_dim) for _ in range(num_layers)])
        self.norms2 = nn.ModuleList([nn.LayerNorm(hidden_dim) for _ in range(num_layers)]) if use_ffn else None
        self.global_pool = nn.AdaptiveAvgPool1d(1)
        self.classifier  = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim), nn.LayerNorm(hidden_dim),
            nn.GELU(), nn.Dropout(dropout),
            nn.Linear(hidden_dim, num_classes),
        )

    def set_temperature(self, t):
        for b in self.padre_blocks:
            if self.gate_on:
                b.degree_gate.set_temperature(t)

    def _ln(self, norm, x):
        return norm(x.permute(0, 2, 1)).permute(0, 2, 1)

    def forward(self, x, return_gate_info=False):
        x = self.input_proj(x)
        gate_info_list = [] if return_gate_info else None
        total_compute  = 0.0

        for i, block in enumerate(self.padre_blocks):
            res = x
            if return_gate_info:
                x, gi = block(x, return_gate_info=True)
                gate_info_list.append(gi); total_compute += gi["compute_cost"]
            else:
                x = block(x)

            x = self._ln(self.norms1[i], x + res if self.use_residual else x)

            if self.use_ffn:
                res2 = x
                x = self.ffn[i](x)
                x = self._ln(self.norms2[i], x + res2 if self.use_residual else x)

        logits = self.classifier(self.global_pool(x).squeeze(-1))
        return (logits, gate_info_list, total_compute) if return_gate_info else logits


# ==========================
# Utilities
# ==========================
def add_gaussian_noise(X, snr_db):
    sp  = (X ** 2).mean(dim=(1, 2), keepdim=True)
    snr = 10 ** (snr_db / 10.0)
    return X + torch.randn_like(X) * torch.sqrt(sp / snr)

def cosine_temperature(ep, total, tmax=5.0, tmin=0.5):
    r = ep / max(total - 1, 1)
    return tmin + (tmax - tmin) * 0.5 * (1.0 + np.cos(np.pi * r))

def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


# ==========================
# Class-specific Expected Degree + Accuracy  ← 핵심
# ==========================
@torch.no_grad()
def collect_class_stats(model, loader, device, num_classes=12, snr_db=None):
    """
    Returns:
      class_exp_degree : np.ndarray (num_classes, L)  soft expected degree per class per layer
      class_acc        : np.ndarray (num_classes,)    per-class accuracy
      class_cnt        : np.ndarray (num_classes,)    sample count
      norm_compute     : float                        normalized hard compute (overall)
    """
    model.eval()
    L = len(model.padre_blocks)
    K = model.max_degree

    class_deg_sum = np.zeros((num_classes, L), dtype=np.float64)
    class_correct = np.zeros(num_classes, dtype=np.int64)
    class_cnt     = np.zeros(num_classes, dtype=np.int64)

    deg_sum_hard  = np.zeros(L, dtype=np.float64)
    total_samples = 0

    for X, y in loader:
        X, y = X.to(device), y.to(device)
        if snr_db is not None:
            X = add_gaussian_noise(X, snr_db)

        logits, gate_info_list, _ = model(X, return_gate_info=True)
        preds  = logits.argmax(dim=1)
        y_cpu  = y.cpu().numpy()
        p_cpu  = preds.cpu().numpy()
        B      = X.shape[0]
        total_samples += B

        for c in range(num_classes):
            mask = (y_cpu == c)
            class_cnt[c]     += int(mask.sum())
            class_correct[c] += int((p_cpu[mask] == c).sum())

        for li, gi in enumerate(gate_info_list):
            sp      = gi["soft_probs"]                        # (B, K)
            deg_v   = torch.arange(1, K + 1, device=sp.device).float()
            exp_deg = (sp * deg_v).sum(dim=-1).cpu().numpy()  # (B,)

            for c in range(num_classes):
                mask = (y_cpu == c)
                if mask.any():
                    class_deg_sum[c, li] += exp_deg[mask].sum()

            sel = gi["degree_selection"].argmax(dim=-1)
            deg_sum_hard[li] += (sel.float() + 1.0).sum().item()

    class_exp_degree = class_deg_sum / np.maximum(class_cnt[:, None], 1)
    class_acc        = class_correct / np.maximum(class_cnt, 1) * 100.0

    deg_mean     = deg_sum_hard / max(total_samples, 1)
    norm_compute = float(deg_mean.sum()) / (L * K)

    return class_exp_degree, class_acc, class_cnt, norm_compute


# ==========================
# Notes 자동 생성
# ==========================
def generate_notes_default(class_name, exp_degrees):
    avg = float(np.mean(exp_degrees))
    return f"Avg degree={avg:.3f}"


# ==========================
# Print Table 1
# ==========================
def print_table1(class_exp_degree, class_acc, class_cnt,
                 norm_compute, num_layers=3,
                 class_names=None, snr_label="Clean"):

    if class_names is None:
        class_names = [f"class_{i}" for i in range(len(class_acc))]

    num_classes = len(class_names)
    L = num_layers

    layer_headers = "  ".join([f"{'Exp Deg (L'+str(l+1)+')':>16}" for l in range(L)])
    print(f"\n{'='*110}")
    print(f"  Table 1: Class-Specific Expected Degrees and Performance Metrics  [{snr_label}]")
    print(f"{'='*110}")
    print(f"{'Class':<22}  {layer_headers}  {'Accuracy (%)':>13}  Notes")
    print(f"{'-'*110}")

    col_avg_deg = np.zeros(L, dtype=np.float64)
    col_avg_acc = 0.0

    for ci, cname in enumerate(class_names):
        degs    = class_exp_degree[ci]
        acc_val = class_acc[ci]
        notes   = generate_notes_default(cname, degs)

        deg_str = "  ".join([f"{d:>16.3f}" for d in degs])
        print(f"{cname:<22}  {deg_str}  {acc_val:>13.1f}  {notes}")

        col_avg_deg += degs
        col_avg_acc += acc_val

    avg_deg = col_avg_deg / max(num_classes, 1)
    avg_acc = col_avg_acc / max(num_classes, 1)
    avg_deg_str = "  ".join([f"{d:>16.3f}" for d in avg_deg])

    print(f"{'-'*110}")
    print(f"{'Average':<22}  {avg_deg_str}  {avg_acc:>13.2f}  Normalized Compute: {norm_compute:.2f}")
    print(f"{'='*110}")

    return avg_deg.tolist(), avg_acc, norm_compute


# ==========================
# Save Table 1 as CSV
# ==========================
def save_table1_csv(class_exp_degree, class_acc, class_cnt,
                    norm_compute, num_layers=3,
                    class_names=None,
                    path="./table1_class_expected_degrees.csv"):
    if class_names is None:
        class_names = [f"class_{i}" for i in range(len(class_acc))]

    fieldnames = (["Class"]
                  + [f"Exp_Degree_L{l+1}" for l in range(num_layers)]
                  + ["Accuracy_pct", "N_samples", "Notes"])

    rows = []
    for ci, cname in enumerate(class_names):
        row = {"Class": cname}
        for l in range(num_layers):
            row[f"Exp_Degree_L{l+1}"] = round(float(class_exp_degree[ci, l]), 4)
        row["Accuracy_pct"] = round(float(class_acc[ci]), 2)
        row["N_samples"]    = int(class_cnt[ci])
        row["Notes"]        = generate_notes_default(cname, class_exp_degree[ci])
        rows.append(row)

    avg_row = {"Class": "Average"}
    for l in range(num_layers):
        avg_row[f"Exp_Degree_L{l+1}"] = round(float(class_exp_degree[:, l].mean()), 4)
    avg_row["Accuracy_pct"] = round(float(class_acc.mean()), 2)
    avg_row["N_samples"]    = int(class_cnt.sum())
    avg_row["Notes"]        = f"Normalized Compute: {norm_compute:.2f}"
    rows.append(avg_row)

    with open(path, "w", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(rows)
    print(f"CSV saved: {path}")


# ==========================
# Train
# ==========================
def train_model(model, train_loader, val_loader, device,
                epochs=30, seed=42):
    set_seed(seed)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs, eta_min=1e-5)
    criterion = nn.CrossEntropyLoss()

    best_f1    = -1.0
    best_state = None

    for ep in range(epochs):
        temp = cosine_temperature(ep, epochs, tmax=5.0, tmin=0.5)
        model.set_temperature(temp)

        model.train()
        for X, y in train_loader:
            X, y = X.to(device), y.to(device)
            optimizer.zero_grad()
            loss = criterion(model(X), y)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
        scheduler.step()

        model.eval()
        preds_all, labels_all = [], []
        with torch.no_grad():
            for X, y in val_loader:
                X, y = X.to(device), y.to(device)
                preds_all.extend(model(X).argmax(1).cpu().numpy())
                labels_all.extend(y.cpu().numpy())
        val_f1 = f1_score(labels_all, preds_all, average="macro")

        if val_f1 > best_f1:
            best_f1    = val_f1
            best_state = copy.deepcopy(model.state_dict())

        if (ep + 1) % 5 == 0:
            print(f"  Epoch {ep+1:02d}/{epochs} | ValF1={val_f1:.4f} | Best={best_f1:.4f} | Temp={temp:.3f}")

    model.load_state_dict(best_state)
    print(f"  Training done. Best Val Macro-F1: {best_f1:.4f}")
    return model


# ==========================
# Main (PAMAP2)
# ==========================
if __name__ == "__main__":
    DATA_DIR    = "/content/drive/MyDrive/Colab Notebooks/HAR/har_orig_datasets"
    WINDOW_SIZE = 100
    STEP_SIZE   = 50

    BATCH_SIZE  = 64
    EPOCHS      = 30
    SEED        = 42
    NUM_WORKERS = 2 if USE_GPU else 0
    PIN_MEMORY  = USE_GPU

    NUM_LAYERS  = 3
    MAX_DEGREE  = 3

    # ----------------------------
    # Load PAMAP2 windows
    # ----------------------------
    set_seed(SEED)
    loader = PAMAP2Loader(DATA_DIR, window_size=WINDOW_SIZE, step_size=STEP_SIZE)
    X, y, subj_ids, LABEL_NAMES = loader.load()

    NUM_CLASSES = len(LABEL_NAMES)
    IN_CHANNELS = X.shape[1]

    idx = np.arange(len(y))
    train_idx, test_idx = train_test_split(
        idx, test_size=0.2, random_state=SEED, stratify=y
    )

    X_train = X[train_idx].copy()
    X_test  = X[test_idx].copy()

    y_train = y[train_idx]
    y_test  = y[test_idx]

    scaler = StandardScaler()

    Ntr, C, T = X_train.shape
    Xtr_2d = X_train.transpose(0, 2, 1).reshape(-1, C)
    scaler.fit(Xtr_2d)

    Xtr_2d = scaler.transform(Xtr_2d)
    X_train = Xtr_2d.reshape(Ntr, T, C).transpose(0, 2, 1)

    Nte = X_test.shape[0]
    Xte_2d = X_test.transpose(0, 2, 1).reshape(-1, C)
    Xte_2d = scaler.transform(Xte_2d)
    X_test = Xte_2d.reshape(Nte, T, C).transpose(0, 2, 1)

    train_ds = PAMAP2WindowDataset(X_train, y_train)
    val_ds   = PAMAP2WindowDataset(X_test,  y_test)

    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,
                              num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)
    val_loader   = DataLoader(val_ds,   batch_size=BATCH_SIZE, shuffle=False,
                              num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)

    # ----------------------------
    # Full PADRe model
    # ----------------------------
    model = PADReHAR(
        in_channels=IN_CHANNELS,
        seq_len=WINDOW_SIZE,
        num_classes=NUM_CLASSES,
        hidden_dim=48,
        num_layers=NUM_LAYERS,
        max_degree=MAX_DEGREE,
        gate_hidden_dim=16,
        dropout=0.2,
        temperature_initial=5.0,
        temperature_min=0.5,
        gate_on=True,
        fixed_degree=1,
        use_ste=True,
        use_hadamard=True,
        use_ffn=True,
        use_residual=True,
    ).to(DEVICE)

    print(f"\nModel params: {count_parameters(model):,}")
    print("\n>>> Training Full PADRe on PAMAP2 (no Magne)...")
    model = train_model(model, train_loader, val_loader, DEVICE,
                        epochs=EPOCHS, seed=SEED)

    # -----------------------------------------------
    # Table 1: Clean
    # -----------------------------------------------
    print("\n>>> Collecting class-specific stats [Clean]...")
    ced_clean, cacc_clean, ccnt, nc_clean = collect_class_stats(
        model, val_loader, DEVICE, num_classes=NUM_CLASSES, snr_db=None
    )
    avg_deg_clean, avg_acc_clean, _ = print_table1(
        ced_clean, cacc_clean, ccnt, nc_clean,
        num_layers=NUM_LAYERS, class_names=LABEL_NAMES, snr_label="Clean"
    )
    save_table1_csv(
        ced_clean, cacc_clean, ccnt, nc_clean,
        num_layers=NUM_LAYERS, class_names=LABEL_NAMES,
        path="./table1_pamap2_clean.csv"
    )

    # -----------------------------------------------
    # Table 1: Noisy (SNR=10 dB)
    # -----------------------------------------------
    print("\n>>> Collecting class-specific stats [SNR=10 dB]...")
    ced_noisy, cacc_noisy, _, nc_noisy = collect_class_stats(
        model, val_loader, DEVICE, num_classes=NUM_CLASSES, snr_db=10
    )
    avg_deg_noisy, avg_acc_noisy, _ = print_table1(
        ced_noisy, cacc_noisy, ccnt, nc_noisy,
        num_layers=NUM_LAYERS, class_names=LABEL_NAMES, snr_label="SNR=10 dB"
    )
    save_table1_csv(
        ced_noisy, cacc_noisy, ccnt, nc_noisy,
        num_layers=NUM_LAYERS, class_names=LABEL_NAMES,
        path="./table1_pamap2_snr10.csv"
    )

    # -----------------------------------------------
    # Summary
    # -----------------------------------------------
    print("\n>>> Summary")
    print(f"  Clean  | Avg Acc={avg_acc_clean:.2f}% | Norm Compute={nc_clean:.4f}")
    print(f"  SNR=10 | Avg Acc={avg_acc_noisy:.2f}% | Norm Compute={nc_noisy:.4f}")
    print(f"  Robustness Drop (Acc): {avg_acc_clean - avg_acc_noisy:.2f}%")

Device: cuda | pin_memory: True


  df = df.groupby("subject_id", group_keys=False).apply(_fill_subject_group)


Loaded PAMAP2 (no Magne)
  X shape : (38862, 27, 100)  (N, C, T)
  y shape : (38862,)  (N,)
  Subjects: 9
  Classes : 12

Model params: 79,749

>>> Training Full PADRe on PAMAP2 (no Magne)...
  Epoch 05/30 | ValF1=0.9309 | Best=0.9309 | Temp=4.792
  Epoch 10/30 | ValF1=0.9466 | Best=0.9479 | Temp=4.013
  Epoch 15/30 | ValF1=0.9630 | Best=0.9630 | Temp=2.872
  Epoch 20/30 | ValF1=0.9678 | Best=0.9678 | Temp=1.696
  Epoch 25/30 | ValF1=0.9715 | Best=0.9715 | Temp=0.822
  Epoch 30/30 | ValF1=0.9726 | Best=0.9733 | Temp=0.500
  Training done. Best Val Macro-F1: 0.9733

>>> Collecting class-specific stats [Clean]...

  Table 1: Class-Specific Expected Degrees and Performance Metrics  [Clean]
Class                       Exp Deg (L1)      Exp Deg (L2)      Exp Deg (L3)   Accuracy (%)  Notes
--------------------------------------------------------------------------------------------------------------
Lying                              2.083             2.160             1.974           98.8  A

# MHEALTH

In [3]:
"""
PADRe HAR - Table 1: Class-Specific Expected Degrees and Performance Metrics (MHEALTH)
=============================================================================
출력 형식:
  Class | Exp Degree (L1) | Exp Degree (L2) | Exp Degree (L3) | Accuracy (%) | Notes

측정:
  - 클래스별 soft expected degree (per layer)
  - 클래스별 accuracy
  - Normalized Compute (전체 평균)

설정:
  - MHEALTH: acc + gyro only (ECG/Mag 제외)
  - 정규화: train windows 기준 fit, test는 transform만 (leakage-free)
"""

import os
import csv
import glob
import random
import copy
import re
from pathlib import Path

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import f1_score
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler


# ==========================
# Seed / Device
# ==========================
def set_seed(seed=42):
    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


DEVICE  = torch.device("cuda" if torch.cuda.is_available() else "cpu")
USE_GPU = DEVICE.type == "cuda"
print(f"Device: {DEVICE} | pin_memory: {USE_GPU}")


# ==========================
# MHEALTH Loading (acc+gyro only)
# ==========================
MHEALTH_LABEL_NAMES = [
    "Standing still", "Sitting and relaxing", "Lying down",
    "Walking", "Climbing stairs", "Waist bends forward",
    "Frontal elevation of arms", "Knees bending", "Cycling",
    "Jogging", "Running", "Jump front & back",
]

def _load_single_mhealth_log(path: str, all_feature_cols: list[str]):
    df = pd.read_csv(
        path,
        sep="\t",
        header=None,
        names=all_feature_cols + ["label"],
    )
    return df

def load_mhealth_subject_dfs(data_dir: str):
    """
    Returns:
      dfs: list[pd.DataFrame] each has subject_id col
      all_feature_cols: list[str] (23 raw)
    """
    all_feature_cols = [
        "acc_chest_x", "acc_chest_y", "acc_chest_z",      # 0,1,2
        "ecg_1", "ecg_2",                                 # 3,4
        "acc_ankle_x", "acc_ankle_y", "acc_ankle_z",      # 5,6,7
        "gyro_ankle_x", "gyro_ankle_y", "gyro_ankle_z",   # 8,9,10
        "mag_ankle_x", "mag_ankle_y", "mag_ankle_z",      # 11,12,13
        "acc_arm_x", "acc_arm_y", "acc_arm_z",            # 14,15,16
        "gyro_arm_x", "gyro_arm_y", "gyro_arm_z",         # 17,18,19
        "mag_arm_x", "mag_arm_y", "mag_arm_z",            # 20,21,22
    ]

    log_files = glob.glob(os.path.join(data_dir, "mHealth_subject*.log"))
    if not log_files:
        raise FileNotFoundError(f"No mHealth_subject*.log files found in {data_dir}")
    print(f"Found {len(log_files)} log files in {data_dir}")

    dfs = []
    for fp in sorted(log_files):
        df_i = _load_single_mhealth_log(fp, all_feature_cols)

        # subject id 추출 (파일명에서 숫자)
        m = re.findall(r"\d+", os.path.basename(fp))
        subj_id = int(m[0]) if m else 0
        df_i["subject_id"] = subj_id

        # label==0 제거, label 1..12 -> 0..11
        df_i = df_i[df_i["label"] != 0].copy()
        df_i.loc[:, "label"] = df_i["label"].astype(np.int64) - 1

        dfs.append(df_i)

    return dfs, all_feature_cols

def create_mhealth_windows_per_subject(dfs, feature_cols, window_size, step_size):
    X_list, y_list, s_list = [], [], []

    for df in dfs:
        subj_id = int(df["subject_id"].iloc[0])
        data_arr   = df[feature_cols].to_numpy(dtype=np.float32)    # (L, C)
        labels_arr = df["label"].to_numpy(dtype=np.int64)           # (L,)
        L = data_arr.shape[0]

        start = 0
        while start + window_size <= L:
            end = start + window_size
            window_x = data_arr[start:end]                          # (T, C)
            window_label = int(labels_arr[end - 1])                 # last label
            window_x_ct = window_x.T                                # (C, T)

            X_list.append(window_x_ct)
            y_list.append(window_label)
            s_list.append(subj_id)
            start += step_size

    if not X_list:
        raise RuntimeError("No windows created. Check window/step or data format.")

    X = np.stack(X_list, axis=0).astype(np.float32)    # (N, C, T)
    y = np.asarray(y_list, dtype=np.int64)             # (N,)
    subj_ids = np.asarray(s_list, dtype=np.int64)      # (N,)
    return X, y, subj_ids


class MHEALTHWindowDataset(Dataset):
    def __init__(self, X: np.ndarray, y: np.ndarray):
        self.X = X.astype(np.float32)
        self.y = y.astype(np.int64)

    def __len__(self):
        return len(self.y)

    def __getitem__(self, idx):
        return torch.from_numpy(self.X[idx]), torch.tensor(self.y[idx], dtype=torch.long)


def normalize_train_only(X_train: np.ndarray, X_test: np.ndarray):
    scaler = StandardScaler()

    Ntr, C, T = X_train.shape
    Xtr_2d = X_train.transpose(0, 2, 1).reshape(-1, C)  # (Ntr*T, C)
    scaler.fit(Xtr_2d)

    Xtr_2d = scaler.transform(Xtr_2d)
    X_train_n = Xtr_2d.reshape(Ntr, T, C).transpose(0, 2, 1)

    Nte = X_test.shape[0]
    Xte_2d = X_test.transpose(0, 2, 1).reshape(-1, C)   # (Nte*T, C)
    Xte_2d = scaler.transform(Xte_2d)
    X_test_n = Xte_2d.reshape(Nte, T, C).transpose(0, 2, 1)

    return X_train_n.astype(np.float32), X_test_n.astype(np.float32)


# ==========================
# Gate
# ==========================
class ComputeAwareDegreeGate(nn.Module):
    def __init__(self, channels, max_degree=3, gate_hidden_dim=16,
                 temperature_initial=5.0, temperature_min=0.5):
        super().__init__()
        self.max_degree = max_degree
        self.gate = nn.Sequential(
            nn.AdaptiveAvgPool1d(1),
            nn.Flatten(1),
            nn.Linear(channels, gate_hidden_dim),
            nn.GELU(),
            nn.Linear(gate_hidden_dim, max_degree),
        )
        nn.init.zeros_(self.gate[-1].bias)
        if max_degree >= 3:
            self.gate[-1].bias.data[1] = 0.4
        self.register_buffer("temperature", torch.tensor(float(temperature_initial)))
        self.temperature_min = float(temperature_min)

    def set_temperature(self, t):
        self.temperature.fill_(max(float(t), self.temperature_min))

    def forward(self, x, use_ste=True):
        logits     = self.gate(x)
        soft_probs = F.softmax(logits / self.temperature, dim=-1)
        if use_ste:
            if self.training:
                hard_idx    = logits.argmax(dim=-1)
                hard_oh     = F.one_hot(hard_idx, num_classes=self.max_degree).float()
                degree_w    = hard_oh - soft_probs.detach() + soft_probs
            else:
                degree_w = F.one_hot(logits.argmax(dim=-1), num_classes=self.max_degree).float()
        else:
            degree_w = soft_probs
        return degree_w, logits, soft_probs


# ==========================
# PADRe Block
# ==========================
class PADReBlock(nn.Module):
    def __init__(self, channels, seq_len, max_degree=3, kernel_size=11,
                 gate_hidden_dim=16, temperature_initial=5.0, temperature_min=0.5,
                 gate_on=True, fixed_degree=1, use_ste=True, use_hadamard=True):
        super().__init__()
        self.max_degree  = max_degree
        self.gate_on     = gate_on
        self.fixed_degree= fixed_degree
        self.use_ste     = use_ste
        self.use_hadamard= use_hadamard

        if gate_on:
            self.degree_gate = ComputeAwareDegreeGate(
                channels, max_degree=max_degree,
                gate_hidden_dim=gate_hidden_dim,
                temperature_initial=temperature_initial,
                temperature_min=temperature_min,
            )

        self.channel_mixing = nn.ModuleList([
            nn.Conv1d(channels, channels, kernel_size=1) for _ in range(max_degree)
        ])
        self.token_mixing = nn.ModuleList([
            nn.Conv1d(channels, channels, kernel_size=kernel_size,
                      padding=kernel_size // 2, groups=channels)
            for _ in range(max_degree)
        ])
        if use_hadamard:
            self.pre_hadamard_channel = nn.ModuleList([
                nn.Conv1d(channels, channels, kernel_size=1) for _ in range(max_degree - 1)
            ])
            self.pre_hadamard_token = nn.ModuleList([
                nn.Conv1d(channels, channels, kernel_size=kernel_size,
                          padding=kernel_size // 2, groups=channels)
                for _ in range(max_degree - 1)
            ])

        self.norm = nn.LayerNorm(channels)

    def _build_Z(self, x, max_deg):
        Y = [self.token_mixing[i](self.channel_mixing[i](x)) for i in range(max_deg)]
        Z = [Y[0]]
        for i in range(1, max_deg):
            if self.use_hadamard:
                z = self.pre_hadamard_token[i-1](self.pre_hadamard_channel[i-1](Z[-1]))
                Z.append(z * Y[i])
            else:
                Z.append(Z[-1] + Y[i])
        return Z

    def _hard_forward(self, x, sel):
        B = x.shape[0]
        max_deg = max(1, min(int(sel.max().item()) + 1, self.max_degree))
        Z = self._build_Z(x, max_deg)
        Z_stack = torch.stack(Z, dim=0)
        return Z_stack[sel, torch.arange(B, device=x.device)]

    def _soft_forward(self, x, w):
        B = x.shape[0]
        Z = self._build_Z(x, self.max_degree)
        Z_stack = torch.stack(Z, dim=1)              # (B,K,C,T)
        return (Z_stack * w.view(B, self.max_degree, 1, 1)).sum(dim=1)

    def forward(self, x, return_gate_info=False):
        B = x.shape[0]
        if self.gate_on:
            dw, logits, sp = self.degree_gate(x, use_ste=self.use_ste)
        else:
            fd = self.fixed_degree
            dw = F.one_hot(torch.full((B,), fd, dtype=torch.long, device=x.device),
                           num_classes=self.max_degree).float()
            logits = dw.clone(); sp = dw.clone()

        if not self.use_ste and self.gate_on:
            out = self._soft_forward(x, dw)
            sel = sp.argmax(dim=-1)
        else:
            sel = dw.argmax(dim=-1)
            out = self._hard_forward(x, sel)

        out = self.norm(out.permute(0, 2, 1)).permute(0, 2, 1)

        if return_gate_info:
            return out, {"degree_selection": dw, "soft_probs": sp,
                         "logits": logits,
                         "compute_cost": (sel + 1).float().mean().item()}
        return out


# ==========================
# Model
# ==========================
class PADReHAR(nn.Module):
    def __init__(self, in_channels=15, seq_len=128, num_classes=12,
                 hidden_dim=48, num_layers=3, max_degree=3,
                 gate_hidden_dim=16, dropout=0.2,
                 temperature_initial=5.0, temperature_min=0.5,
                 gate_on=True, fixed_degree=1,
                 use_ste=True, use_hadamard=True,
                 use_ffn=True, use_residual=True):
        super().__init__()
        self.num_layers   = num_layers
        self.max_degree   = max_degree
        self.gate_on      = gate_on
        self.use_ffn      = use_ffn
        self.use_residual = use_residual

        self.input_proj = nn.Conv1d(in_channels, hidden_dim, kernel_size=1)
        self.padre_blocks = nn.ModuleList([
            PADReBlock(hidden_dim, seq_len, max_degree=max_degree, kernel_size=11,
                       gate_hidden_dim=gate_hidden_dim,
                       temperature_initial=temperature_initial,
                       temperature_min=temperature_min,
                       gate_on=gate_on, fixed_degree=fixed_degree,
                       use_ste=use_ste, use_hadamard=use_hadamard)
            for _ in range(num_layers)
        ])

        if use_ffn:
            self.ffn = nn.ModuleList([
                nn.Sequential(
                    nn.Conv1d(hidden_dim, hidden_dim * 2, kernel_size=1),
                    nn.GELU(), nn.Dropout(dropout),
                    nn.Conv1d(hidden_dim * 2, hidden_dim, kernel_size=1),
                    nn.Dropout(dropout),
                )
                for _ in range(num_layers)
            ])

        self.norms1 = nn.ModuleList([nn.LayerNorm(hidden_dim) for _ in range(num_layers)])
        self.norms2 = nn.ModuleList([nn.LayerNorm(hidden_dim) for _ in range(num_layers)]) if use_ffn else None
        self.global_pool = nn.AdaptiveAvgPool1d(1)
        self.classifier  = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim), nn.LayerNorm(hidden_dim),
            nn.GELU(), nn.Dropout(dropout),
            nn.Linear(hidden_dim, num_classes),
        )

    def set_temperature(self, t):
        for b in self.padre_blocks:
            if self.gate_on:
                b.degree_gate.set_temperature(t)

    def _ln(self, norm, x):
        return norm(x.permute(0, 2, 1)).permute(0, 2, 1)

    def forward(self, x, return_gate_info=False):
        x = self.input_proj(x)
        gate_info_list = [] if return_gate_info else None
        total_compute  = 0.0

        for i, block in enumerate(self.padre_blocks):
            res = x
            if return_gate_info:
                x, gi = block(x, return_gate_info=True)
                gate_info_list.append(gi); total_compute += gi["compute_cost"]
            else:
                x = block(x)

            x = self._ln(self.norms1[i], x + res if self.use_residual else x)

            if self.use_ffn:
                res2 = x
                x = self.ffn[i](x)
                x = self._ln(self.norms2[i], x + res2 if self.use_residual else x)

        logits = self.classifier(self.global_pool(x).squeeze(-1))
        return (logits, gate_info_list, total_compute) if return_gate_info else logits


# ==========================
# Utilities
# ==========================
def add_gaussian_noise(X, snr_db):
    sp  = (X ** 2).mean(dim=(1, 2), keepdim=True)
    snr = 10 ** (snr_db / 10.0)
    return X + torch.randn_like(X) * torch.sqrt(sp / snr)

def cosine_temperature(ep, total, tmax=5.0, tmin=0.5):
    r = ep / max(total - 1, 1)
    return tmin + (tmax - tmin) * 0.5 * (1.0 + np.cos(np.pi * r))

def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


# ==========================
# Class-specific Expected Degree + Accuracy
# ==========================
@torch.no_grad()
def collect_class_stats(model, loader, device, num_classes=12, snr_db=None):
    model.eval()
    L = len(model.padre_blocks)
    K = model.max_degree

    class_deg_sum = np.zeros((num_classes, L), dtype=np.float64)
    class_correct = np.zeros(num_classes, dtype=np.int64)
    class_cnt     = np.zeros(num_classes, dtype=np.int64)

    deg_sum_hard  = np.zeros(L, dtype=np.float64)
    total_samples = 0

    for X, y in loader:
        X, y = X.to(device), y.to(device)
        if snr_db is not None:
            X = add_gaussian_noise(X, snr_db)

        logits, gate_info_list, _ = model(X, return_gate_info=True)
        preds  = logits.argmax(dim=1)
        y_cpu  = y.cpu().numpy()
        p_cpu  = preds.cpu().numpy()
        B      = X.shape[0]
        total_samples += B

        for c in range(num_classes):
            mask = (y_cpu == c)
            class_cnt[c]     += int(mask.sum())
            class_correct[c] += int((p_cpu[mask] == c).sum())

        for li, gi in enumerate(gate_info_list):
            sp      = gi["soft_probs"]                        # (B, K)
            deg_v   = torch.arange(1, K + 1, device=sp.device).float()
            exp_deg = (sp * deg_v).sum(dim=-1).cpu().numpy()  # (B,)

            for c in range(num_classes):
                mask = (y_cpu == c)
                if mask.any():
                    class_deg_sum[c, li] += exp_deg[mask].sum()

            sel = gi["degree_selection"].argmax(dim=-1)
            deg_sum_hard[li] += (sel.float() + 1.0).sum().item()

    class_exp_degree = class_deg_sum / np.maximum(class_cnt[:, None], 1)
    class_acc        = class_correct / np.maximum(class_cnt, 1) * 100.0

    deg_mean     = deg_sum_hard / max(total_samples, 1)
    norm_compute = float(deg_mean.sum()) / (L * K)

    return class_exp_degree, class_acc, class_cnt, norm_compute


# ==========================
# Notes / Table / CSV
# ==========================
def generate_notes_default(class_name, exp_degrees):
    avg = float(np.mean(exp_degrees))
    return f"Avg degree={avg:.3f}"

def print_table1(class_exp_degree, class_acc, class_cnt,
                 norm_compute, num_layers=3,
                 class_names=None, snr_label="Clean"):

    if class_names is None:
        class_names = [f"class_{i}" for i in range(len(class_acc))]

    num_classes = len(class_names)
    L = num_layers

    layer_headers = "  ".join([f"{'Exp Deg (L'+str(l+1)+')':>16}" for l in range(L)])
    print(f"\n{'='*120}")
    print(f"  Table 1: Class-Specific Expected Degrees and Performance Metrics  [{snr_label}]")
    print(f"{'='*120}")
    print(f"{'Class':<28}  {layer_headers}  {'Accuracy (%)':>13}  Notes")
    print(f"{'-'*120}")

    col_avg_deg = np.zeros(L, dtype=np.float64)
    col_avg_acc = 0.0

    for ci, cname in enumerate(class_names):
        degs    = class_exp_degree[ci]
        acc_val = class_acc[ci]
        notes   = generate_notes_default(cname, degs)

        deg_str = "  ".join([f"{d:>16.3f}" for d in degs])
        print(f"{cname:<28}  {deg_str}  {acc_val:>13.1f}  {notes}")

        col_avg_deg += degs
        col_avg_acc += acc_val

    avg_deg = col_avg_deg / max(num_classes, 1)
    avg_acc = col_avg_acc / max(num_classes, 1)
    avg_deg_str = "  ".join([f"{d:>16.3f}" for d in avg_deg])

    print(f"{'-'*120}")
    print(f"{'Average':<28}  {avg_deg_str}  {avg_acc:>13.2f}  Normalized Compute: {norm_compute:.2f}")
    print(f"{'='*120}")

    return avg_deg.tolist(), avg_acc, norm_compute


def save_table1_csv(class_exp_degree, class_acc, class_cnt,
                    norm_compute, num_layers=3,
                    class_names=None,
                    path="/content/table1_mhealth_clean.csv"):
    if class_names is None:
        class_names = [f"class_{i}" for i in range(len(class_acc))]

    fieldnames = (["Class"]
                  + [f"Exp_Degree_L{l+1}" for l in range(num_layers)]
                  + ["Accuracy_pct", "N_samples", "Notes"])

    rows = []
    for ci, cname in enumerate(class_names):
        row = {"Class": cname}
        for l in range(num_layers):
            row[f"Exp_Degree_L{l+1}"] = round(float(class_exp_degree[ci, l]), 4)
        row["Accuracy_pct"] = round(float(class_acc[ci]), 2)
        row["N_samples"]    = int(class_cnt[ci])
        row["Notes"]        = generate_notes_default(cname, class_exp_degree[ci])
        rows.append(row)

    avg_row = {"Class": "Average"}
    for l in range(num_layers):
        avg_row[f"Exp_Degree_L{l+1}"] = round(float(class_exp_degree[:, l].mean()), 4)
    avg_row["Accuracy_pct"] = round(float(class_acc.mean()), 2)
    avg_row["N_samples"]    = int(class_cnt.sum())
    avg_row["Notes"]        = f"Normalized Compute: {norm_compute:.2f}"
    rows.append(avg_row)

    with open(path, "w", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(rows)
    print(f"CSV saved: {path}")


# ==========================
# Train
# ==========================
def train_model(model, train_loader, val_loader, device,
                epochs=30, seed=42):
    set_seed(seed)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs, eta_min=1e-5)
    criterion = nn.CrossEntropyLoss()

    best_f1    = -1.0
    best_state = None

    for ep in range(epochs):
        temp = cosine_temperature(ep, epochs, tmax=5.0, tmin=0.5)
        model.set_temperature(temp)

        model.train()
        for X, y in train_loader:
            X, y = X.to(device), y.to(device)
            optimizer.zero_grad()
            loss = criterion(model(X), y)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
        scheduler.step()

        model.eval()
        preds_all, labels_all = [], []
        with torch.no_grad():
            for X, y in val_loader:
                X, y = X.to(device), y.to(device)
                preds_all.extend(model(X).argmax(1).cpu().numpy())
                labels_all.extend(y.cpu().numpy())
        val_f1 = f1_score(labels_all, preds_all, average="macro")

        if val_f1 > best_f1:
            best_f1    = val_f1
            best_state = copy.deepcopy(model.state_dict())

        if (ep + 1) % 5 == 0:
            print(f"  Epoch {ep+1:02d}/{epochs} | ValF1={val_f1:.4f} | Best={best_f1:.4f} | Temp={temp:.3f}")

    model.load_state_dict(best_state)
    print(f"  Training done. Best Val Macro-F1: {best_f1:.4f}")
    return model


# ==========================
# Main (MHEALTH)
# ==========================
if __name__ == "__main__":
    DATA_DIR    = "/content/drive/MyDrive/Colab Notebooks/HAR/har_orig_datasets/MHEALTHDATASET"
    WINDOW_SIZE = 100
    STEP_SIZE   = 50

    BATCH_SIZE  = 64
    EPOCHS      = 30
    SEED        = 42
    NUM_WORKERS = 2 if USE_GPU else 0
    PIN_MEMORY  = USE_GPU

    NUM_LAYERS  = 3
    NUM_CLASSES = 12

    # --------------------------
    # 1) load subject dfs
    # --------------------------
    set_seed(SEED)
    dfs, all_cols = load_mhealth_subject_dfs(DATA_DIR)

    # acc + gyro only (ECG/Mag 제외) => 15 channels
    feature_cols = [
        "acc_chest_x","acc_chest_y","acc_chest_z",
        "acc_ankle_x","acc_ankle_y","acc_ankle_z",
        "gyro_ankle_x","gyro_ankle_y","gyro_ankle_z",
        "acc_arm_x","acc_arm_y","acc_arm_z",
        "gyro_arm_x","gyro_arm_y","gyro_arm_z",
    ]

    # --------------------------
    # 2) windowing (subject boundary safe)
    # --------------------------
    X, y, subj_ids = create_mhealth_windows_per_subject(
        dfs, feature_cols, window_size=WINDOW_SIZE, step_size=STEP_SIZE
    )
    print("=" * 80)
    print("Loaded MHEALTH windows (acc+gyro only)")
    print(f"  X shape : {X.shape} (N, C, T)")
    print(f"  y shape : {y.shape} (N,)")
    print(f"  Classes : {len(MHEALTH_LABEL_NAMES)}")
    print("=" * 80)

    # --------------------------
    # 3) split 먼저 (윈도우 stratify)
    # --------------------------
    idx = np.arange(len(y))
    train_idx, test_idx = train_test_split(
        idx, test_size=0.2, random_state=SEED, stratify=y
    )

    X_train, y_train = X[train_idx].copy(), y[train_idx]
    X_test,  y_test  = X[test_idx].copy(),  y[test_idx]

    # --------------------------
    # 4) normalize train-only (fit on train, transform both)
    # --------------------------
    X_train, X_test = normalize_train_only(X_train, X_test)

    train_ds = MHEALTHWindowDataset(X_train, y_train)
    val_ds   = MHEALTHWindowDataset(X_test,  y_test)

    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,
                              num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)
    val_loader   = DataLoader(val_ds,   batch_size=BATCH_SIZE, shuffle=False,
                              num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)

    # --------------------------
    # 5) model
    # --------------------------
    model = PADReHAR(
        in_channels=X_train.shape[1],   # 15
        seq_len=WINDOW_SIZE,
        num_classes=NUM_CLASSES,
        hidden_dim=48,
        num_layers=NUM_LAYERS,
        max_degree=3,
        gate_hidden_dim=16,
        dropout=0.2,
        temperature_initial=5.0,
        temperature_min=0.5,
        gate_on=True,
        fixed_degree=1,
        use_ste=True,
        use_hadamard=True,
        use_ffn=True,
        use_residual=True,
    ).to(DEVICE)

    print(f"\nModel params: {count_parameters(model):,}")
    print("\n>>> Training Full PADRe on MHEALTH (acc+gyro only)...")
    model = train_model(model, train_loader, val_loader, DEVICE,
                        epochs=EPOCHS, seed=SEED)

    # -----------------------------------------------
    # Table 1: Clean
    # -----------------------------------------------
    print("\n>>> Collecting class-specific stats [Clean]...")
    ced_clean, cacc_clean, ccnt, nc_clean = collect_class_stats(
        model, val_loader, DEVICE, num_classes=NUM_CLASSES, snr_db=None
    )
    avg_deg_clean, avg_acc_clean, _ = print_table1(
        ced_clean, cacc_clean, ccnt, nc_clean,
        num_layers=NUM_LAYERS, class_names=MHEALTH_LABEL_NAMES, snr_label="Clean"
    )
    save_table1_csv(
        ced_clean, cacc_clean, ccnt, nc_clean,
        num_layers=NUM_LAYERS, class_names=MHEALTH_LABEL_NAMES,
        path="/content/table1_mhealth_clean.csv"
    )

    # -----------------------------------------------
    # Table 1: Noisy (SNR=10 dB)
    # -----------------------------------------------
    print("\n>>> Collecting class-specific stats [SNR=10 dB]...")
    ced_noisy, cacc_noisy, _, nc_noisy = collect_class_stats(
        model, val_loader, DEVICE, num_classes=NUM_CLASSES, snr_db=10
    )
    avg_deg_noisy, avg_acc_noisy, _ = print_table1(
        ced_noisy, cacc_noisy, ccnt, nc_noisy,
        num_layers=NUM_LAYERS, class_names=MHEALTH_LABEL_NAMES, snr_label="SNR=10 dB"
    )
    save_table1_csv(
        ced_noisy, cacc_noisy, ccnt, nc_noisy,
        num_layers=NUM_LAYERS, class_names=MHEALTH_LABEL_NAMES,
        path="/content/table1_mhealth_snr10.csv"
    )

    # -----------------------------------------------
    # Summary
    # -----------------------------------------------
    print("\n>>> Summary")
    print(f"  Clean  | Avg Acc={avg_acc_clean:.2f}% | Norm Compute={nc_clean:.4f}")
    print(f"  SNR=10 | Avg Acc={avg_acc_noisy:.2f}% | Norm Compute={nc_noisy:.4f}")
    print(f"  Robustness Drop (Acc): {avg_acc_clean - avg_acc_noisy:.2f}%")

Device: cuda | pin_memory: True
Found 10 log files in /content/drive/MyDrive/Colab Notebooks/HAR/har_orig_datasets/MHEALTHDATASET
Loaded MHEALTH windows (acc+gyro only)
  X shape : (6849, 15, 100) (N, C, T)
  y shape : (6849,) (N,)
  Classes : 12

Model params: 79,173

>>> Training Full PADRe on MHEALTH (acc+gyro only)...
  Epoch 05/30 | ValF1=0.9870 | Best=0.9870 | Temp=4.792
  Epoch 10/30 | ValF1=0.9911 | Best=0.9911 | Temp=4.013
  Epoch 15/30 | ValF1=0.9897 | Best=0.9918 | Temp=2.872
  Epoch 20/30 | ValF1=0.9897 | Best=0.9918 | Temp=1.696
  Epoch 25/30 | ValF1=0.9911 | Best=0.9918 | Temp=0.822
  Epoch 30/30 | ValF1=0.9917 | Best=0.9924 | Temp=0.500
  Training done. Best Val Macro-F1: 0.9924

>>> Collecting class-specific stats [Clean]...

  Table 1: Class-Specific Expected Degrees and Performance Metrics  [Clean]
Class                             Exp Deg (L1)      Exp Deg (L2)      Exp Deg (L3)   Accuracy (%)  Notes
-------------------------------------------------------------------