In [1]:
import os
import random
import time
from pathlib import Path
from collections import defaultdict

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix


# ---------------------------
# Seed / Device
# ---------------------------
def set_seed(seed=42):
    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    print(f"Seed: {seed}")


set_seed(42)
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", DEVICE)


# ---------------------------
# Dataset
# ---------------------------
class UCIHARDataset(Dataset):
    def __init__(self, data_dir, split="train"):
        self.data_dir = Path(data_dir)
        self.split = split
        self.X, self.y = self._load_data()
        self.X = torch.FloatTensor(self.X)               # (N, C, T)
        self.y = torch.LongTensor(self.y) - 1            # 1..6 -> 0..5

    def _load_data(self):
        split_dir = self.data_dir / self.split
        signal_types = [
            "body_acc_x", "body_acc_y", "body_acc_z",
            "body_gyro_x", "body_gyro_y", "body_gyro_z",
            "total_acc_x", "total_acc_y", "total_acc_z",
        ]
        signals = []
        for st in signal_types:
            fname = split_dir / "Inertial Signals" / f"{st}_{self.split}.txt"
            arr = np.loadtxt(fname)                      # (N, 128)
            signals.append(arr)
        X = np.stack(signals, axis=1)                    # (N, 9, 128)
        y = np.loadtxt(split_dir / f"y_{self.split}.txt", dtype=int)
        return X, y

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]


# ---------------------------
# Gate
# ---------------------------
class ComputeAwareDegreeGate(nn.Module):
    def __init__(self, channels, max_degree=3, gate_hidden_dim=16,
                 temperature_initial=5.0, temperature_min=0.5):
        super().__init__()
        self.max_degree = max_degree
        self.gate = nn.Sequential(
            nn.AdaptiveAvgPool1d(1),
            nn.Flatten(1),
            nn.Linear(channels, gate_hidden_dim),
            nn.GELU(),
            nn.Linear(gate_hidden_dim, max_degree),
        )
        nn.init.zeros_(self.gate[-1].bias)
        if max_degree >= 3:
            try:
                self.gate[-1].bias.data[1] = 0.4
            except Exception:
                pass

        self.register_buffer("temperature", torch.tensor(float(temperature_initial)))
        self.temperature_min = float(temperature_min)

    def set_temperature(self, temperature):
        self.temperature.fill_(max(float(temperature), self.temperature_min))

    def forward(self, x):
        logits = self.gate(x)  # (B, K)
        soft_probs = F.softmax(logits / self.temperature, dim=-1)

        if self.training:
            hard_idx = logits.argmax(dim=-1)
            hard_one_hot = F.one_hot(hard_idx, num_classes=self.max_degree).float()
            degree_selection = hard_one_hot - soft_probs.detach() + soft_probs
        else:
            degree_selection = F.one_hot(logits.argmax(dim=-1), num_classes=self.max_degree).float()

        return degree_selection, logits, soft_probs


# ---------------------------
# PADRe Block
# ---------------------------
class TrueComputeBudgetedPADReBlock_Slim(nn.Module):
    def __init__(self, channels, seq_len, max_degree=3, kernel_size=11, gate_hidden_dim=16,
                 temperature_initial=5.0, temperature_min=0.5):
        super().__init__()
        self.channels = channels
        self.seq_len = seq_len
        self.max_degree = max_degree
        self.kernel_size = kernel_size

        self.degree_gate = ComputeAwareDegreeGate(
            channels,
            max_degree=max_degree,
            gate_hidden_dim=gate_hidden_dim,
            temperature_initial=temperature_initial,
            temperature_min=temperature_min,
        )

        self.channel_mixing = nn.ModuleList([
            nn.Conv1d(channels, channels, kernel_size=1)
            for _ in range(max_degree)
        ])
        self.token_mixing = nn.ModuleList([
            nn.Conv1d(channels, channels, kernel_size=kernel_size,
                      padding=kernel_size // 2, groups=channels)
            for _ in range(max_degree)
        ])

        self.pre_hadamard_channel = nn.ModuleList([
            nn.Conv1d(channels, channels, kernel_size=1)
            for _ in range(max_degree - 1)
        ])
        self.pre_hadamard_token = nn.ModuleList([
            nn.Conv1d(channels, channels, kernel_size=kernel_size,
                      padding=kernel_size // 2, groups=channels)
            for _ in range(max_degree - 1)
        ])

        self.norm = nn.LayerNorm([channels, seq_len])
        self.degree_usage_stats = defaultdict(float)
        self.num_forward_calls = 0

    def _compute_prefix(self, x, selected_degrees):
        B, C, T = x.shape
        max_degree_needed = int(selected_degrees.max().item()) + 1
        max_degree_needed = max(1, min(max_degree_needed, self.max_degree))

        Y = []
        for i in range(max_degree_needed):
            y = self.channel_mixing[i](x)
            y = self.token_mixing[i](y)
            Y.append(y)

        Z = [Y[0]]
        for i in range(1, max_degree_needed):
            z_prev = self.pre_hadamard_channel[i - 1](Z[-1])
            z_prev = self.pre_hadamard_token[i - 1](z_prev)
            z_curr = z_prev * Y[i]
            Z.append(z_curr)

        Z_stack = torch.stack(Z, dim=0)                  # (D, B, C, T)
        idx_b = torch.arange(B, device=x.device)
        output = Z_stack[selected_degrees, idx_b]        # (B, C, T)
        return output

    def forward(self, x, return_gate_info=False):
        degree_selection, logits, soft_probs = self.degree_gate(x)
        selected_degrees = degree_selection.argmax(dim=-1)

        out = self._compute_prefix(x, selected_degrees)
        out = self.norm(out)

        with torch.no_grad():
            for deg in selected_degrees:
                self.degree_usage_stats[int(deg.item())] += 1.0
            self.num_forward_calls += x.shape[0]

        if return_gate_info:
            return out, {
                "degree_selection": degree_selection,
                "soft_probs": soft_probs,
                "logits": logits,
                "compute_cost": (selected_degrees + 1).float().mean().item(),
            }
        return out

    def get_degree_usage(self):
        if self.num_forward_calls == 0:
            return {i: 0.0 for i in range(self.max_degree)}
        return {deg: usage / self.num_forward_calls for deg, usage in self.degree_usage_stats.items()}


# ---------------------------
# Model
# ---------------------------
class ComputeBudgetedPADReHAR_Slim(nn.Module):
    def __init__(self,
                 in_channels=9,
                 seq_len=128,
                 num_classes=6,
                 hidden_dim=48,
                 num_layers=3,
                 max_degree=3,
                 gate_hidden_dim=16,
                 dropout=0.2,
                 temperature_initial=5.0,
                 temperature_min=0.5):
        super().__init__()
        self.in_channels = in_channels
        self.seq_len = seq_len
        self.num_classes = num_classes
        self.num_layers = num_layers
        self.max_degree = max_degree

        self.input_proj = nn.Conv1d(in_channels, hidden_dim, kernel_size=1)

        self.padre_blocks = nn.ModuleList([
            TrueComputeBudgetedPADReBlock_Slim(
                hidden_dim, seq_len,
                max_degree=max_degree,
                kernel_size=11,
                gate_hidden_dim=gate_hidden_dim,
                temperature_initial=temperature_initial,
                temperature_min=temperature_min,
            )
            for _ in range(num_layers)
        ])

        self.ffn = nn.ModuleList([
            nn.Sequential(
                nn.Conv1d(hidden_dim, hidden_dim * 2, kernel_size=1),
                nn.GELU(),
                nn.Dropout(dropout),
                nn.Conv1d(hidden_dim * 2, hidden_dim, kernel_size=1),
                nn.Dropout(dropout),
            )
            for _ in range(num_layers)
        ])

        self.norms1 = nn.ModuleList([nn.LayerNorm([hidden_dim, seq_len]) for _ in range(num_layers)])
        self.norms2 = nn.ModuleList([nn.LayerNorm([hidden_dim, seq_len]) for _ in range(num_layers)])

        self.global_pool = nn.AdaptiveAvgPool1d(1)
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, num_classes),
        )

    def set_temperature(self, temperature):
        for block in self.padre_blocks:
            block.degree_gate.set_temperature(temperature)

    def forward(self, x, return_gate_info=False):
        x = self.input_proj(x)                           # (B, H, T)

        gate_info_list = [] if return_gate_info else None
        total_compute_cost = 0.0

        for padre_block, ffn, norm1, norm2 in zip(self.padre_blocks, self.ffn, self.norms1, self.norms2):
            residual = x
            if return_gate_info:
                x, gate_info = padre_block(x, return_gate_info=True)
                gate_info_list.append(gate_info)
                total_compute_cost += gate_info["compute_cost"]
            else:
                x = padre_block(x)
            x = norm1(x + residual)

            residual = x
            x = ffn(x)
            x = norm2(x + residual)

        feat = self.global_pool(x).squeeze(-1)
        logits = self.classifier(feat)

        if return_gate_info:
            return logits, gate_info_list, total_compute_cost
        return logits


# ---------------------------
# Corruptions
# ---------------------------
def add_gaussian_noise(X, snr_db):
    signal_power = (X ** 2).mean(dim=(1, 2), keepdim=True)
    snr = 10 ** (snr_db / 10.0)
    noise_power = signal_power / snr
    noise = torch.randn_like(X) * torch.sqrt(noise_power)
    return X + noise


def add_bias_drift(X, beta, mode="constant", per_channel=True):
    B, C, T = X.shape
    device = X.device

    if mode == "constant":
        if per_channel:
            b = torch.randn(B, C, 1, device=device) * beta
        else:
            b = torch.randn(B, 1, 1, device=device) * beta
        return X + b

    if mode == "brownian":
        if per_channel:
            eps = torch.randn(B, C, T, device=device) * beta
        else:
            eps = torch.randn(B, 1, T, device=device) * beta
        drift = torch.cumsum(eps, dim=-1)
        return X + drift

    raise ValueError("mode must be one of ['constant','brownian']")


@torch.no_grad()
def expected_degree_from_gateinfo(gate_info_list):
    layer_exp = []
    for gi in gate_info_list:
        sp = gi["soft_probs"]                            # (B,K)
        K = sp.shape[1]
        deg_vals = torch.arange(1, K + 1, device=sp.device).float()
        exp = (sp * deg_vals).sum(dim=-1).mean().item()
        layer_exp.append(exp)
    return layer_exp, float(np.mean(layer_exp))


# ---------------------------
# Robust Evaluator
# ---------------------------
class RobustEvaluator:
    def __init__(self, trainer):
        self.trainer = trainer

    @torch.no_grad()
    def eval_once(self, snr_db=None, bias_beta=None, bias_mode="constant", measure_gate=True):
        model = self.trainer.model
        model.eval()

        all_preds, all_labels = [], []
        layer_sum = None
        layer_count = 0

        for X, y in self.trainer.val_loader:
            X = X.to(self.trainer.device)
            y = y.to(self.trainer.device)

            if snr_db is not None:
                X = add_gaussian_noise(X, snr_db)
            if bias_beta is not None:
                X = add_bias_drift(X, beta=bias_beta, mode=bias_mode, per_channel=True)

            if measure_gate:
                logits, gate_info_list, _ = model(X, return_gate_info=True)
                layer_exp, _ = expected_degree_from_gateinfo(gate_info_list)
                if layer_sum is None:
                    layer_sum = np.zeros(len(layer_exp), dtype=np.float64)
                layer_sum += np.array(layer_exp, dtype=np.float64)
                layer_count += 1
            else:
                logits = model(X)

            preds = logits.argmax(dim=1)
            all_preds.extend(preds.detach().cpu().numpy())
            all_labels.extend(y.detach().cpu().numpy())

        all_labels = np.array(all_labels)
        all_preds = np.array(all_preds)

        acc = accuracy_score(all_labels, all_preds)
        macro_f1 = f1_score(all_labels, all_preds, average="macro")
        per_class_f1 = f1_score(all_labels, all_preds, average=None)
        cm = confusion_matrix(all_labels, all_preds)

        gate_layer_exp = None
        if measure_gate and layer_count > 0:
            gate_layer_exp = (layer_sum / layer_count).tolist()

        return {
            "acc": acc,
            "macro_f1": macro_f1,
            "per_class_f1": per_class_f1,
            "cm": cm,
            "gate_layer_expected_degree": gate_layer_exp,
        }

    def instant_noise_sweep(self, snr_list_db):
        results = []
        for snr in snr_list_db:
            out = self.eval_once(snr_db=snr, measure_gate=True)
            out["snr_db"] = ("inf" if snr is None else snr)
            results.append(out)
        return results

    def linear_snr_drift(self, snr_start=30, snr_end=0, steps=10):
        snrs = np.linspace(snr_start, snr_end, steps).tolist()
        time_series = []
        for t, snr in enumerate(snrs):
            out = self.eval_once(snr_db=float(snr), measure_gate=True)
            time_series.append({
                "t": t,
                "snr_db": float(snr),
                "acc": out["acc"],
                "macro_f1": out["macro_f1"],
                "gate_layer_expected_degree": out["gate_layer_expected_degree"],
            })
        return time_series

    def bias_drift_sweep(self, beta_list, mode="constant"):
        results = []
        for beta in beta_list:
            out = self.eval_once(bias_beta=beta, bias_mode=mode, measure_gate=True)
            out["bias_beta"] = beta
            out["bias_mode"] = mode
            results.append(out)
        return results


# ---------------------------
# Trainer
# ---------------------------
class SimpleTrainer:
    def __init__(self, model, train_loader, val_loader, device="cuda"):
        self.model = model.to(device)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.device = device
        self.criterion = nn.CrossEntropyLoss()
        self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=1e-3, weight_decay=1e-4)
        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self.optimizer, T_max=50, eta_min=1e-5)

    def train_epoch(self):
        self.model.train()
        all_preds, all_labels = [], []
        total_loss = 0.0

        for X, y in self.train_loader:
            X = X.to(self.device)
            y = y.to(self.device)

            self.optimizer.zero_grad()
            logits = self.model(X)
            loss = self.criterion(logits, y)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
            self.optimizer.step()

            preds = logits.argmax(dim=1)
            all_preds.extend(preds.detach().cpu().numpy())
            all_labels.extend(y.detach().cpu().numpy())
            total_loss += loss.item()

        acc = accuracy_score(all_labels, all_preds)
        self.scheduler.step()
        return total_loss / max(1, len(self.train_loader)), acc

    @torch.no_grad()
    def evaluate(self):
        self.model.eval()
        all_preds, all_labels = [], []

        for X, y in self.val_loader:
            X = X.to(self.device)
            y = y.to(self.device)
            logits = self.model(X)
            preds = logits.argmax(dim=1)
            all_preds.extend(preds.detach().cpu().numpy())
            all_labels.extend(y.detach().cpu().numpy())

        acc = accuracy_score(all_labels, all_preds)
        macro_f1 = f1_score(all_labels, all_preds, average="macro")
        return acc, macro_f1


# ---------------------------
# Gate stats (FIXED): degree histogram + classwise expected degree
# ---------------------------
@torch.no_grad()
def collect_gate_stats(
    model,
    loader,
    device,
    num_classes=6,
    max_degree=3,
    snr_db=None,
    bias_beta=None,
    bias_mode="constant",
):
    """
    Returns:
      hist_counts: np.ndarray shape (L, K)  # per-layer degree selection counts
      hist_ratio : np.ndarray shape (L, K)  # per-layer selection ratios
      class_exp  : np.ndarray shape (num_classes, L)  # per-class expected degree (avg over samples)
      class_cnt  : np.ndarray shape (num_classes,)    # sample counts per class

    NOTE (FIX):
      class_cnt must be counted ONCE per sample (per batch), not once per layer.
    """
    model.eval()

    L = len(model.padre_blocks)  # num_layers
    K = max_degree

    hist_counts = np.zeros((L, K), dtype=np.int64)
    class_sum = np.zeros((num_classes, L), dtype=np.float64)
    class_cnt = np.zeros((num_classes,), dtype=np.int64)

    for X, y in loader:
        X = X.to(device)
        y = y.to(device)

        if snr_db is not None:
            X = add_gaussian_noise(X, snr_db)
        if bias_beta is not None:
            X = add_bias_drift(X, beta=bias_beta, mode=bias_mode, per_channel=True)

        _, gate_info_list, _ = model(X, return_gate_info=True)

        # --- FIX: count samples per class ONCE per batch ---
        y_cpu = y.detach().cpu().numpy()
        for c in range(num_classes):
            class_cnt[c] += int((y_cpu == c).sum())

        # per-layer accumulation
        for li, gi in enumerate(gate_info_list):
            deg_sel = gi["degree_selection"].argmax(dim=-1)  # 0..K-1
            bc = torch.bincount(deg_sel, minlength=K).detach().cpu().numpy()
            hist_counts[li] += bc

            sp = gi["soft_probs"]  # (B,K)
            deg_vals = torch.arange(1, K + 1, device=sp.device).float()
            exp_deg = (sp * deg_vals).sum(dim=-1)  # (B,)

            exp_cpu = exp_deg.detach().cpu().numpy()
            for c in range(num_classes):
                mask = (y_cpu == c)
                if mask.any():
                    class_sum[c, li] += exp_cpu[mask].sum()

    hist_ratio = hist_counts / np.maximum(hist_counts.sum(axis=1, keepdims=True), 1)
    class_exp = class_sum / np.maximum(class_cnt[:, None], 1)

    return hist_counts, hist_ratio, class_exp, class_cnt


@torch.no_grad()
def eval_snr_suite_report(
    model,
    loader,
    device,
    snr_list_db,
    num_classes=6,
    max_degree=3,
    class_names=None,
):
    """
    For each SNR:
      - Acc / Macro-F1
      - Hard avg degree (per-layer) and total hard compute (= sum over layers)
      - Per-layer degree histogram (hard selections)

    Notes:
      - degree index: 0..K-1  => degree = 1..K (we report as degree=1..K)
      - "Hard" means argmax(degree_selection), not soft_probs expectation.
    """
    if class_names is None:
        class_names = [f"class_{i}" for i in range(num_classes)]

    model.eval()
    L = len(model.padre_blocks)
    K = max_degree

    print("\n==============================")
    print("[SNR Suite Report]")
    print("==============================")

    for snr_db in snr_list_db:
        # accumulators
        all_preds, all_labels = [], []
        hist_counts = np.zeros((L, K), dtype=np.int64)
        deg_sum = np.zeros((L,), dtype=np.float64)
        sample_count = 0

        for X, y in loader:
            X = X.to(device)
            y = y.to(device)

            # apply noise
            if snr_db is not None:
                X = add_gaussian_noise(X, float(snr_db))

            logits, gate_info_list, _ = model(X, return_gate_info=True)

            preds = logits.argmax(dim=1)
            all_preds.extend(preds.detach().cpu().numpy())
            all_labels.extend(y.detach().cpu().numpy())

            B = X.shape[0]
            sample_count += B

            # per-layer hard degree stats
            for li, gi in enumerate(gate_info_list):
                # hard selection: 0..K-1
                deg_idx = gi["degree_selection"].argmax(dim=-1)  # (B,)
                bc = torch.bincount(deg_idx, minlength=K).detach().cpu().numpy()
                hist_counts[li] += bc

                # hard degree (1..K)
                deg_sum[li] += (deg_idx.float() + 1.0).sum().item()

        # metrics
        all_labels = np.array(all_labels)
        all_preds = np.array(all_preds)
        acc = accuracy_score(all_labels, all_preds)
        macro_f1 = f1_score(all_labels, all_preds, average="macro")

        # hard avg degree
        deg_mean_per_layer = (deg_sum / max(sample_count, 1)).tolist()
        hard_total_compute = float(np.sum(deg_mean_per_layer))  # sum over layers (like total cost)

        # print header
        snr_str = "inf" if snr_db is None else f"{snr_db:g}"
        print("\n--------------------------------")
        print(f"SNR={snr_str} dB")
        print(f"  Acc={acc:.4f} | Macro-F1={macro_f1:.4f}")
        print(f"  HardAvgDegree(per-layer)={[round(x,4) for x in deg_mean_per_layer]}")
        print(f"  HardTotalCompute(sum layers)={hard_total_compute:.4f}")

        # histogram ratios
        hist_ratio = hist_counts / np.maximum(hist_counts.sum(axis=1, keepdims=True), 1)

        print("  Degree histogram per layer (hard)")
        for li in range(L):
            cnt = hist_counts[li].tolist()
            rat = [f"{x:.3f}" for x in hist_ratio[li].tolist()]
            print(f"    Layer {li}: counts={cnt} | ratio={rat} (deg=1..{K})")


def print_gate_stats(title, hist_counts, hist_ratio, class_exp, class_cnt, class_names=None):
    L, K = hist_counts.shape
    if class_names is None:
        class_names = [f"class_{i}" for i in range(class_exp.shape[0])]

    print(f"\n==============================")
    print(f"[Gate Stats] {title}")
    print(f"==============================")

    print("\n(1) Degree histogram per layer")
    for li in range(L):
        cnt = hist_counts[li].tolist()
        rat = [f"{x:.3f}" for x in hist_ratio[li].tolist()]
        print(f"  Layer {li}: counts={cnt} | ratio={rat}  (deg=1..{K})")

    print("\n(2) Class-wise expected degree (per layer)")
    for ci, name in enumerate(class_names):
        if class_cnt[ci] == 0:
            continue
        vals = [f"{v:.3f}" for v in class_exp[ci].tolist()]
        print(f"  {name:>12s} (n={class_cnt[ci]}): {vals}")


# ---------------------------
# Main
# ---------------------------
if __name__ == "__main__":
    DATA_DIR = "/content/drive/MyDrive/Colab Notebooks/HAR/DONE/har_orig_datasets/UCI_HAR"
    BATCH_SIZE = 64
    EPOCHS = 30

    train_ds = UCIHARDataset(DATA_DIR, split="train")
    val_ds = UCIHARDataset(DATA_DIR, split="test")

    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=True)
    val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=True)

    model = ComputeBudgetedPADReHAR_Slim(
        in_channels=9, seq_len=128, num_classes=6,
        hidden_dim=48, num_layers=3, max_degree=3
    ).to(DEVICE)

    trainer = SimpleTrainer(model, train_loader, val_loader, device=DEVICE)

    best_f1 = -1.0
    for ep in range(EPOCHS):
        t0 = time.time()
        train_loss, train_acc = trainer.train_epoch()
        val_acc, val_f1 = trainer.evaluate()
        t1 = time.time()

        print(f"Epoch {ep+1:03d}/{EPOCHS} | Train Loss {train_loss:.4f} Acc {train_acc:.4f} | "
              f"Val Acc {val_acc:.4f} Macro-F1 {val_f1:.4f} | {t1-t0:.1f}s")

        if val_f1 > best_f1:
            best_f1 = val_f1
            torch.save(trainer.model.state_dict(), "best_padre_slim.pt")

    print("Done. Best Val Macro-F1:", best_f1)

    print("\n==============================")
    print("Robustness Experiments Start")
    print("==============================")

    # 1) Load best
    trainer.model.load_state_dict(torch.load("best_padre_slim.pt", map_location=DEVICE))
    trainer.model.eval()

    # 2) SNR list (include 0)
    snr_list = [None, 30, 20, 10, 5, 0]

    # 3) "Hard" report: Acc/Macro-F1 + Hard degree + histogram  ✅(너가 원하는 핵심)
    eval_snr_suite_report(
        trainer.model,
        val_loader,
        DEVICE,
        snr_list_db=snr_list,
        num_classes=6,
        max_degree=3,
        class_names=["WALK", "UP", "DOWN", "SIT", "STAND", "LAY"],
    )

    # 4) Optional: 기존 gate 통계(soft expected degree + classwise expected degree)
    CLASS_NAMES = ["WALK", "UP", "DOWN", "SIT", "STAND", "LAY"]

    hc, hr, ce, cc = collect_gate_stats(
        trainer.model, val_loader, DEVICE,
        num_classes=6, max_degree=3,
        snr_db=None, bias_beta=None
    )
    print_gate_stats("CLEAN (soft exp + classwise)", hc, hr, ce, cc, class_names=CLASS_NAMES)

    hc, hr, ce, cc = collect_gate_stats(
        trainer.model, val_loader, DEVICE,
        num_classes=6, max_degree=3,
        snr_db=10, bias_beta=None
    )
    print_gate_stats("GAUSSIAN NOISE (SNR=10) (soft exp + classwise)", hc, hr, ce, cc, class_names=CLASS_NAMES)

    # 5) RobustEvaluator (표에 넣을 숫자용)  ✅0 포함해서 통일
    evaluator = RobustEvaluator(trainer)

    A = evaluator.instant_noise_sweep(snr_list)
    print("\n[A] Instant noise sweep (SNR) [SOFT expected degree]")
    for r in A:
        print(f"  SNR={r['snr_db']} | Acc={r['acc']:.4f} | Macro-F1={r['macro_f1']:.4f} | "
              f"GateExp(per-layer)={r['gate_layer_expected_degree']}")

    B = evaluator.linear_snr_drift(snr_start=30, snr_end=0, steps=10)
    print("\n[B] Linear SNR drift over time")
    for row in B:
        print(f"  t={row['t']:02d} | SNR={row['snr_db']:.1f} | Acc={row['acc']:.4f} | "
              f"Macro-F1={row['macro_f1']:.4f} | GateExp={row['gate_layer_expected_degree']}")

    beta_list = [0.05, 0.10, 0.20]
    C = evaluator.bias_drift_sweep(beta_list, mode="constant")
    print("\n[C] Bias drift sweep")
    for r in C:
        print(f"  beta={r['bias_beta']:.2f} ({r['bias_mode']}) | Acc={r['acc']:.4f} | "
              f"Macro-F1={r['macro_f1']:.4f} | GateExp={r['gate_layer_expected_degree']}")

Seed: 42
Device: cuda
Epoch 001/30 | Train Loss 0.7401 Acc 0.7724 | Val Acc 0.8979 Macro-F1 0.8980 | 3.5s
Epoch 002/30 | Train Loss 0.2036 Acc 0.9368 | Val Acc 0.9203 Macro-F1 0.9195 | 2.9s
Epoch 003/30 | Train Loss 0.1520 Acc 0.9490 | Val Acc 0.9138 Macro-F1 0.9133 | 3.2s
Epoch 004/30 | Train Loss 0.1347 Acc 0.9471 | Val Acc 0.9169 Macro-F1 0.9166 | 2.5s
Epoch 005/30 | Train Loss 0.1296 Acc 0.9501 | Val Acc 0.9057 Macro-F1 0.9066 | 2.9s
Epoch 006/30 | Train Loss 0.1263 Acc 0.9509 | Val Acc 0.9199 Macro-F1 0.9205 | 2.6s
Epoch 007/30 | Train Loss 0.1147 Acc 0.9532 | Val Acc 0.9304 Macro-F1 0.9304 | 3.3s
Epoch 008/30 | Train Loss 0.1211 Acc 0.9528 | Val Acc 0.9226 Macro-F1 0.9232 | 2.9s
Epoch 009/30 | Train Loss 0.1118 Acc 0.9543 | Val Acc 0.9040 Macro-F1 0.9046 | 2.7s
Epoch 010/30 | Train Loss 0.1122 Acc 0.9535 | Val Acc 0.9125 Macro-F1 0.9136 | 2.6s
Epoch 011/30 | Train Loss 0.1033 Acc 0.9559 | Val Acc 0.9169 Macro-F1 0.9175 | 2.6s
Epoch 012/30 | Train Loss 0.0990 Acc 0.9618 | Val Acc 