# In [None]:  (Jupyter cell marker left over from notebook export)
# ============================================================
# UCI-HAR + ELK Backbone + Sequential Cross-Attention + Temporal Prototype Attention
# Modified: ModernTCN → ELKBlock with Structural Reparameterization
# ============================================================
import os, math, time, warnings
warnings.filterwarnings("ignore")

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from torch.utils.data import Subset

# ---------------------------
# 0) UCI-HAR Loader (변경 없음)
# ---------------------------
_UCI_CHANNELS = [
    ("Inertial Signals/total_acc_x_",  "txt"),
    ("Inertial Signals/total_acc_y_",  "txt"),
    ("Inertial Signals/total_acc_z_",  "txt"),
    ("Inertial Signals/body_acc_x_",   "txt"),
    ("Inertial Signals/body_acc_y_",   "txt"),
    ("Inertial Signals/body_acc_z_",   "txt"),
    ("Inertial Signals/body_gyro_x_",  "txt"),
    ("Inertial Signals/body_gyro_y_",  "txt"),
    ("Inertial Signals/body_gyro_z_",  "txt"),
]

def _read_txt_matrix(path: str) -> np.ndarray:
    ".txt를 Numpy array로 읽는 헬퍼"
    return np.loadtxt(path)

def _load_ucihar_split(root: str, split: str):
    "train, test 불러오는 함수"
    assert split in ("train", "test")
    split_dir = os.path.join(root, split)
    mats = []
    for base, ext in _UCI_CHANNELS:
        arr = _read_txt_matrix(os.path.join(split_dir, f"{base}{split}.{ext}"))
        mats.append(arr[:, None, :])
    X = np.concatenate(mats, axis=1).astype(np.float32)
    y = np.loadtxt(os.path.join(split_dir, f"y_{split}.txt")).astype(np.int64) - 1
    return X, y

def fit_channel_stats(X: np.ndarray):
    """Compute per-channel mean and standard deviation for standardization.

    Statistics are pooled over the sample axis (0) and the time axis (2), with
    dimensions kept so the results broadcast against ``X``.

    Returns:
        ``(mu, sd)`` as float32 arrays of shape ``(1, X.shape[1], 1)``.
        Any ``sd`` below ``1e-6`` is replaced by ``1.0`` so near-constant
        channels are not blown up by the division.
    """
    pooled_axes = (0, 2)
    mean = X.mean(axis=pooled_axes, keepdims=True)
    std = X.std(axis=pooled_axes, keepdims=True)
    std = np.where(std < 1e-6, 1.0, std)
    return mean.astype(np.float32), std.astype(np.float32)

class UCIHARDataset(Dataset):
    """Map-style dataset over one UCI-HAR split with optional per-channel standardization.

    Each item is ``(x, y)``. ``x`` is a float32 tensor built from ``self.X[idx]``
    and ``y`` is a scalar int64 tensor. When ``stats`` is a ``(mu, sd)`` pair
    (for example the ``(1, C, 1)`` arrays returned by ``fit_channel_stats``),
    the leading axis is squeezed and ``x`` becomes ``(x - mu) / sd``.
    """

    def __init__(self, root: str, split: str, stats=None):
        self.X, self.y = _load_ucihar_split(root, split)
        self.stats = stats

    def set_stats(self, stats):
        """Install externally fitted normalization stats (typically from the training data)."""
        self.stats = stats

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        sample, label = self.X[idx], self.y[idx]
        if self.stats is not None:
            mean, std = (s.squeeze(0) for s in self.stats)
            # (C, T) - (C, 1): broadcast per-channel standardization.
            sample = (sample - mean) / std
        return torch.from_numpy(sample).float(), torch.tensor(label, dtype=torch.long)

# ---------------------------
# 1) ELK Backbone (신규 추가)
# ---------------------------
class ELKBlock(nn.Module):
    """Efficient Large Kernel (ELK) block with structural reparameterization.

    Shape: ``(B, in_channels, T) -> (B, out_channels, T)``. Every depthwise conv
    uses ``kernel // 2`` padding on an odd kernel, so ``T`` is preserved.

    Training form (``deploy=False``): the outputs of five parallel depthwise
    branches, each with its own BatchNorm, are summed:

    1. ``kernel_size`` conv      (``dw_large1`` / ``bn_large1``)
    2. ``kernel_size - 2`` conv  (``dw_large2`` / ``bn_large2``)
    3. 5-tap conv                (``dw_small1`` / ``bn_small1``)
    4. 3-tap conv                (``dw_small2`` / ``bn_small2``)
    5. identity                  (``bn_id`` only)

    The sum goes through GELU and then a 1x1 pointwise conv + BatchNorm that
    maps ``in_channels`` to ``out_channels``.

    Inference form (``deploy=True``): :meth:`reparameterize` folds the five
    branches into one depthwise ``kernel_size`` conv with bias
    (``reparam_conv``). The fusion uses the BatchNorm *running* statistics, so
    the fused block matches the training form evaluated in ``eval()`` mode.

    Args:
        in_channels: Channels of the input and of every depthwise branch.
        out_channels: Channels produced by the pointwise projection.
        kernel_size: Size of the largest depthwise kernel. Must be odd (so all
            branches keep length ``T``) and >= 5 (so the 5-tap branch fits
            inside the fused kernel).
        deploy: Build the already-fused single-conv form directly.

    Raises:
        ValueError: If ``kernel_size`` is even or smaller than 5.
    """

    def __init__(self, in_channels, out_channels, kernel_size=31, deploy=False):
        super().__init__()
        if kernel_size < 5 or kernel_size % 2 == 0:
            # Even kernels make the branches disagree on output length (T+1 vs
            # T). Kernels < 5 would force the 5-tap branch to be cropped (a
            # negative F.pad) when it is folded, silently corrupting the fusion.
            raise ValueError(f"kernel_size must be odd and >= 5, got {kernel_size}")
        self.deploy = deploy  # True: single fused conv; False: five training branches
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size

        # Kernel sizes and "same" paddings of the five branches.
        padding_large1 = kernel_size // 2
        kernel_size_large2 = kernel_size - 2
        padding_large2 = kernel_size_large2 // 2
        kernel_size_small1 = 5
        padding_small1 = kernel_size_small1 // 2
        kernel_size_small2 = 3
        padding_small2 = kernel_size_small2 // 2

        if deploy:
            # Inference form: the single conv the five branches fold into.
            self.reparam_conv = nn.Conv1d(
                in_channels, in_channels, kernel_size,
                padding=padding_large1, groups=in_channels, bias=True
            )
        else:
            # Training form: five parallel depthwise branches (bias-free convs + BN).
            self.dw_large1 = nn.Conv1d(
                in_channels, in_channels, kernel_size,
                padding=padding_large1, groups=in_channels, bias=False
            )
            self.bn_large1 = nn.BatchNorm1d(in_channels)

            self.dw_large2 = nn.Conv1d(
                in_channels, in_channels, kernel_size_large2,
                padding=padding_large2, groups=in_channels, bias=False
            )
            self.bn_large2 = nn.BatchNorm1d(in_channels)

            self.dw_small1 = nn.Conv1d(
                in_channels, in_channels, kernel_size_small1,
                padding=padding_small1, groups=in_channels, bias=False
            )
            self.bn_small1 = nn.BatchNorm1d(in_channels)

            self.dw_small2 = nn.Conv1d(
                in_channels, in_channels, kernel_size_small2,
                padding=padding_small2, groups=in_channels, bias=False
            )
            self.bn_small2 = nn.BatchNorm1d(in_channels)

            self.bn_id = nn.BatchNorm1d(in_channels)

        # Shared by both forms: 1x1 projection applied after the branch sum.
        self.pointwise = nn.Sequential(
            nn.Conv1d(in_channels, out_channels, kernel_size=1, bias=False),
            nn.BatchNorm1d(out_channels),
        )
        self.activation = nn.GELU()

    def forward(self, x):
        if self.deploy:
            x = self.reparam_conv(x)
        else:
            x1 = self.bn_large1(self.dw_large1(x))
            x2 = self.bn_large2(self.dw_large2(x))
            x3 = self.bn_small1(self.dw_small1(x))
            x4 = self.bn_small2(self.dw_small2(x))
            x5 = self.bn_id(x)
            x = x1 + x2 + x3 + x4 + x5

        x = self.activation(x)
        return self.pointwise(x)

    @torch.no_grad()
    def reparameterize(self):
        """Fold the five training branches into ``reparam_conv`` and switch to deploy mode.

        No-op if already in deploy mode. Runs without autograd, so the fused
        parameters are plain values with no graph history. The branch modules
        are deleted afterwards. Equivalence holds for the ``eval()`` behaviour
        of the branches, because BatchNorm running statistics are used.
        """
        if self.deploy:
            return

        def _fuse(conv, bn):
            """Return ``(weight, bias)`` of the depthwise conv equal to ``bn(conv(x))`` in eval mode.

            ``conv=None`` is the identity branch, expressed as a
            ``kernel_size`` kernel with a single 1 at the centre tap.
            """
            if conv is None:
                kernel = torch.zeros(
                    (self.in_channels, 1, self.kernel_size),
                    dtype=bn.weight.dtype, device=bn.weight.device,
                )
                kernel[:, 0, self.kernel_size // 2] = 1.0
            else:
                kernel = conv.weight
            # BN(y) = gamma * (y - mean) / std + beta. The branch convs have no
            # bias, so the fused bias is beta - mean * gamma / std.
            scale = bn.weight / (bn.running_var + bn.eps).sqrt()
            return kernel * scale.reshape(-1, 1, 1), bn.bias - bn.running_mean * scale

        def _pad_to_full(weight):
            """Zero-pad a smaller odd kernel symmetrically to ``kernel_size`` taps."""
            pad = (self.kernel_size - weight.shape[-1]) // 2
            return F.pad(weight, (pad, pad))

        w_l1, b_l1 = _fuse(self.dw_large1, self.bn_large1)
        w_l2, b_l2 = _fuse(self.dw_large2, self.bn_large2)
        w_s1, b_s1 = _fuse(self.dw_small1, self.bn_small1)
        w_s2, b_s2 = _fuse(self.dw_small2, self.bn_small2)
        w_id, b_id = _fuse(None, self.bn_id)

        # Convolution is linear in the kernel, so summing the (aligned) kernels
        # and biases gives a conv equal to the sum of the branch outputs.
        final_w = w_l1 + _pad_to_full(w_l2) + _pad_to_full(w_s1) + _pad_to_full(w_s2) + w_id
        final_b = b_l1 + b_l2 + b_s1 + b_s2 + b_id

        self.reparam_conv = nn.Conv1d(
            self.in_channels, self.in_channels, self.kernel_size,
            padding=self.kernel_size // 2, groups=self.in_channels, bias=True,
        ).to(device=final_w.device, dtype=final_w.dtype)
        self.reparam_conv.weight.copy_(final_w)
        self.reparam_conv.bias.copy_(final_b)

        self.deploy = True

        # The training branches are no longer used; free them.
        for attr in ('dw_large1', 'bn_large1', 'dw_large2', 'bn_large2',
                     'dw_small1', 'bn_small1', 'dw_small2', 'bn_small2', 'bn_id'):
            if hasattr(self, attr):
                delattr(self, attr)

class ELKBackbone(nn.Module):
    """Convolutional feature extractor: a 3-tap stem followed by ``num_layers`` ELK blocks.

    Shape: ``(B, in_channels, T) -> (B, d_model, T)``. The stem lifts the raw
    sensor channels to width ``d_model`` (conv -> BN -> GELU). Each layer is
    then an ``ELKBlock(d_model, d_model)`` followed by dropout.
    """

    def __init__(self, in_channels=9, d_model=128, num_layers=6, kernel_size=31, dropout=0.1):
        super().__init__()
        self.stem = nn.Sequential(
            nn.Conv1d(in_channels, d_model, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm1d(d_model),
            nn.GELU(),
        )
        # Flattened (ELKBlock, Dropout) pairs in one Sequential, so state_dict
        # keys stay "elk_layers.<index>.*" with blocks at even indices.
        self.elk_layers = nn.Sequential(*(
            module
            for _ in range(num_layers)
            for module in (ELKBlock(d_model, d_model, kernel_size=kernel_size), nn.Dropout(dropout))
        ))
        self.out_channels = d_model  # channel count seen by downstream modules

    def forward(self, x):
        return self.elk_layers(self.stem(x))

# ---------------------------
# 2) Sequential Cross-Attention (변경 없음)
# ---------------------------
# LayerNorm의 경량화 버전인 RMSNorm
class RMSNorm(nn.Module):
    """Root-mean-square normalization over the last dimension.

    Each vector is divided by ``sqrt(mean(x**2) + eps)`` and multiplied by a
    learnable per-feature gain ``g`` (initialised to ones). There is no mean
    subtraction and no bias, unlike LayerNorm.
    """

    def __init__(self, d, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.g = nn.Parameter(torch.ones(d))

    def forward(self, x):
        mean_square = x.pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        return x * inv_rms * self.g

# 표준 멀티헤드 (크로스) 어텐션 모듈
class MultiHeadCrossAttention(nn.Module):
    """Multi-head scaled dot-product attention with separate query and key/value inputs.

    ``query`` is ``(B, Tq, d_model)``; ``key`` and ``value`` are ``(B, Tkv, d_model)``.
    Returns ``(out, attn_weights)``: ``out`` is ``(B, Tq, d_model)`` and
    ``attn_weights`` is ``(B, num_heads, Tq, Tkv)`` (after dropout).
    Scores are scaled by ``(head_dim * temperature) ** -0.5``.

    ``mask`` (optional): positions where ``mask == 0`` are excluded. Accepted
    shapes are ``(Tq, Tkv)`` (shared by all samples and heads), ``(B, Tq, Tkv)``
    (shared by all heads), or any 4-D shape that broadcasts to
    ``(B, num_heads, Tq, Tkv)``. A query row that is fully masked yields NaN.

    Raises:
        ValueError: If ``d_model`` is not divisible by ``num_heads``.
    """

    def __init__(self, d_model, num_heads=4, dropout=0.1, temperature=1.0):
        super().__init__()
        if d_model % num_heads != 0:
            raise ValueError(f"d_model ({d_model}) must be divisible by num_heads ({num_heads})")
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        self.temperature = temperature

        self.q_proj = nn.Linear(d_model, d_model, bias=False)
        self.k_proj = nn.Linear(d_model, d_model, bias=False)
        self.v_proj = nn.Linear(d_model, d_model, bias=False)
        self.out_proj = nn.Linear(d_model, d_model, bias=False)

        self.dropout = nn.Dropout(dropout)
        self.scale = (self.head_dim * temperature) ** -0.5

    @staticmethod
    def _broadcast_mask(mask):
        """Reshape ``mask`` so it broadcasts against ``(B, H, Tq, Tkv)`` scores.

        A plain ``unsqueeze(1)`` is only right for 3-D masks: a 2-D ``(Tq, Tkv)``
        mask would become ``(Tq, 1, Tkv)`` and line ``Tq`` up with the batch/head
        axes.
        """
        if mask.dim() == 2:
            return mask.unsqueeze(0).unsqueeze(0)  # (1, 1, Tq, Tkv)
        if mask.dim() == 3:
            return mask.unsqueeze(1)  # (B, 1, Tq, Tkv)
        return mask

    def forward(self, query, key, value, mask=None):
        B, Tq, D = query.shape
        _, Tkv, _ = key.shape

        # Project and split heads: (B, T, D) -> (B, H, T, head_dim).
        Q = self.q_proj(query).view(B, Tq, self.num_heads, self.head_dim).transpose(1, 2)
        K = self.k_proj(key).view(B, Tkv, self.num_heads, self.head_dim).transpose(1, 2)
        V = self.v_proj(value).view(B, Tkv, self.num_heads, self.head_dim).transpose(1, 2)

        scores = torch.matmul(Q, K.transpose(-2, -1)) * self.scale  # (B, H, Tq, Tkv)
        if mask is not None:
            scores = scores.masked_fill(self._broadcast_mask(mask) == 0, float('-inf'))

        attn_weights = F.softmax(scores, dim=-1)
        attn_weights = self.dropout(attn_weights)

        out = torch.matmul(attn_weights, V)  # (B, H, Tq, head_dim)
        # Merge heads: (B, H, Tq, head_dim) -> (B, Tq, D).
        out = out.transpose(1, 2).contiguous().view(B, Tq, D)

        return self.out_proj(out), attn_weights

# 1차 어텐션: 센서/축(Axis) 전문가 토큰과 Cross-Attention
class ImprovedSensorCrossAttention(nn.Module):
    """Stage-1 attention: time steps cross-attend to learned sensor and axis context tokens.

    Two groups of learned tokens are used: ``num_sensors`` sensor tokens and 3
    axis tokens.

    1. Each group is shifted by a linear projection of the time-averaged input.
       The three axis inputs are identical copies of that average.
    2. Each group is refined by its own pre-norm residual self-attention.
    3. Every time step of ``x`` cross-attends to the concatenated tokens
       (pre-norm residual).
    4. A pre-norm residual FFN follows.

    Shape: ``(B, T, d_model) -> (B, T, d_model)``.
    """

    def __init__(self, d_model, num_sensors=9, num_heads=4, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.num_sensors = num_sensors

        # Small-init learned context tokens. Creation order is kept so seeded
        # initialisation and state_dict keys are unchanged.
        self.sensor_tokens = nn.Parameter(torch.randn(1, num_sensors, d_model) * 0.02)
        self.axis_tokens = nn.Parameter(torch.randn(1, 3, d_model) * 0.02)
        self.sensor_relation = nn.Linear(d_model, d_model)

        # Token-group self-attention, then data -> tokens cross-attention.
        self.sensor_sa = nn.MultiheadAttention(d_model, num_heads, dropout=dropout, batch_first=True)
        self.axis_sa = nn.MultiheadAttention(d_model, num_heads, dropout=dropout, batch_first=True)
        self.cross_attn = MultiHeadCrossAttention(d_model, num_heads, dropout)

        self.axis_projection = nn.Linear(d_model, d_model)
        self.norm_s, self.norm_a, self.norm1, self.norm2 = (RMSNorm(d_model) for _ in range(4))

        self.ffn = nn.Sequential(
            nn.Linear(d_model, 4 * d_model), nn.GELU(), nn.Dropout(dropout),
            nn.Linear(4 * d_model, d_model), nn.Dropout(dropout),
        )

    @staticmethod
    def _refine(tokens, norm, self_attn):
        """Apply pre-norm residual self-attention to one group of context tokens."""
        normed = norm(tokens)
        mixed, _ = self_attn(normed, normed, normed)
        return tokens + mixed

    def forward(self, x):
        batch = x.size(0)
        pooled = x.mean(dim=1, keepdim=True)  # (B, 1, D) time-averaged features

        sensor_ctx = self.sensor_tokens.expand(batch, -1, -1) + self.sensor_relation(pooled)
        sensor_ctx = self._refine(sensor_ctx, self.norm_s, self.sensor_sa)

        axis_in = torch.stack([pooled.squeeze(1)] * 3, dim=1)  # (B, 3, D), identical rows
        axis_ctx = self.axis_tokens.expand(batch, -1, -1) + self.axis_projection(axis_in)
        axis_ctx = self._refine(axis_ctx, self.norm_a, self.axis_sa)

        context = torch.cat([sensor_ctx, axis_ctx], dim=1)
        attended, _ = self.cross_attn(self.norm1(x), context, context)
        x = x + attended
        return x + self.ffn(self.norm2(x))

# 2차 어텐션: 시간(Temporal) 축으로 Self-Attention
class TemporalCrossAttention(nn.Module):
    """Stage-2 attention: pre-norm self-attention over time steps, then a pre-norm FFN.

    Shape: ``(B, T, d_model) -> (B, T, d_model)``. With ``causal=True`` each
    time step may attend only to itself and earlier steps.
    """

    def __init__(self, d_model, num_heads=4, dropout=0.1, causal=False):
        super().__init__()
        # The cross-attention module is used as self-attention (query = key = value).
        self.cross_attn = MultiHeadCrossAttention(d_model, num_heads, dropout)
        self.norm1 = RMSNorm(d_model)
        self.norm2 = RMSNorm(d_model)
        self.causal = causal
        self.ffn = nn.Sequential(
            nn.Linear(d_model, 4 * d_model), nn.GELU(), nn.Dropout(dropout),
            nn.Linear(4 * d_model, d_model), nn.Dropout(dropout),
        )

    def _create_causal_mask(self, T, device):
        """Return a lower-triangular keep-mask of shape ``(1, T, T)`` (1 = may attend).

        The leading axis makes this a ``(batch, Tq, Tkv)``-style mask. The
        attention module inserts a head axis into 3-D masks, producing
        ``(1, 1, T, T)``, which broadcasts over batch and heads.
        """
        return torch.tril(torch.ones(T, T, device=device)).unsqueeze(0)

    def forward(self, x):
        T = x.size(1)
        mask = self._create_causal_mask(T, x.device) if self.causal else None
        x_norm = self.norm1(x)
        attn_out, _ = self.cross_attn(x_norm, x_norm, x_norm, mask)
        x = x + attn_out
        return x + self.ffn(self.norm2(x))

# 1차, 2차 어텐션을 순차적으로 실행
class SequentialCrossAttention(nn.Module):
    """Sensor/axis cross-attention followed by non-causal temporal self-attention.

    Shape: ``(B, T, d_model) -> (B, T, d_model)``.
    """

    def __init__(self, d_model, num_sensors=9, num_heads=4, dropout=0.1):
        super().__init__()
        self.sensor_attn = ImprovedSensorCrossAttention(d_model, num_sensors, num_heads, dropout)
        self.temporal_attn = TemporalCrossAttention(d_model, num_heads, dropout, causal=False)

    def forward(self, x):
        for stage in (self.sensor_attn, self.temporal_attn):
            x = stage(x)
        return x

# ---------------------------
# 3) Temporal Prototype Attention (변경 없음) 어텐션으로 정제된 특징 맵을 '프로토타입'과 비교하여 압축합니다.
# ---------------------------
# TPA의 기본 로직
class TemporalPrototypeAttention(nn.Module):
    """Pool a sequence into one vector by letting learned prototypes attend over time.

    Input ``x`` is ``(B, T, dim)``. Steps:

    1. The sequence passes through a local depthwise conv (``seg_kernel`` taps)
       and a pointwise conv, then is projected to keys and values.
    2. Prototype queries attend over time (multi-head).
    3. The per-prototype outputs are reduced with mean + max over prototypes.
    4. The result goes through a small MLP (``fuse``) and ``out_proj``.

    Returns ``(z, aux)``: ``z`` is ``(B, dim)``. ``aux["attn"]`` is the
    pre-dropout prototype-to-time attention ``(B, heads, P, T)``, and
    ``aux["align_peak_mean"]`` is the detached mean of each attention row's
    maximum.

    Raises:
        ValueError: If ``dim`` is not divisible by ``heads``.
    """

    def __init__(self, dim, num_prototypes=16, seg_kernel=9, heads=4, dropout=0.1):
        super().__init__()
        if dim % heads != 0:
            raise ValueError(f"dim ({dim}) must be divisible by heads ({heads})")
        self.dim, self.heads, self.head_dim = dim, heads, dim // heads
        self.num_prototypes = num_prototypes

        # Learned prototype queries ("behaviour templates").
        self.proto = nn.Parameter(torch.randn(num_prototypes, dim) * 0.02)
        # padding="same" keeps length T for any seg_kernel. For odd kernels it
        # equals symmetric (seg_kernel - 1) // 2 padding. Even kernels used to
        # yield T-1 steps and crash in the head split.
        self.dw = nn.Conv1d(dim, dim, kernel_size=seg_kernel, padding="same", groups=dim, bias=False)
        self.pw = nn.Conv1d(dim, dim, kernel_size=1, bias=False)
        self.q_proj, self.k_proj, self.v_proj, self.out_proj = [nn.Linear(dim, dim, bias=False) for _ in range(4)]
        self.fuse = nn.Sequential(nn.Linear(dim, dim), nn.SiLU(), nn.Dropout(dropout), nn.Linear(dim, dim))
        self.dropout = nn.Dropout(dropout)
        self.scale = self.head_dim ** -0.5

    def _tpa_core(self, x, proto):
        """Run prototype attention over ``x`` with ``proto`` as queries.

        ``proto`` is either shared ``(P, D)`` or per-sample ``(B, P, D)``.
        """
        B, T, D = x.shape
        P = proto.size(1) if proto.dim() == 3 else proto.size(0)
        xloc = self.pw(self.dw(x.transpose(1, 2))).transpose(1, 2)  # (B, T, D)
        K, V = self.k_proj(xloc), self.v_proj(xloc)
        Qp = self.q_proj(proto) if proto.dim() == 3 else self.q_proj(proto).unsqueeze(0).expand(B, -1, -1)

        def split_heads(t, is_kv=False):
            # (B, L, D) -> (B, heads, L, head_dim), with L = T for keys/values, P for queries.
            shape = (B, T if is_kv else P, self.heads, D // self.heads)
            return t.view(*shape).transpose(1, 2)

        Qh, Kh, Vh = split_heads(Qp, False), split_heads(K, True), split_heads(V, True)
        scores = torch.matmul(Qh, Kh.transpose(-2, -1)) * self.scale
        attn = F.softmax(scores, dim=-1)  # (B, heads, P, T)
        # Diagnostics come from the clean distribution. Training-time dropout
        # would zero some weights and rescale the rest by 1 / (1 - p).
        aux = {"attn": attn, "align_peak_mean": attn.amax(dim=-1).mean().detach()}
        proto_tokens = torch.matmul(self.dropout(attn), Vh).transpose(1, 2).contiguous().view(B, P, D)
        z = self.fuse(proto_tokens.mean(dim=1) + proto_tokens.max(dim=1).values)
        z = self.out_proj(z)
        return z, aux

    def forward(self, x):
        return self._tpa_core(x, self.proto)


# TPA의 '클래스 조건부' 업그레이드 버전
class ClassConditionalTPA(TemporalPrototypeAttention):
    """Temporal prototype attention with shared plus class-specific prototypes.

    ``self.proto`` has ``p_shared + num_classes * p_class`` rows:

    * rows ``[0, p_shared)`` are shared by all classes;
    * class ``c`` owns rows ``[p_shared + c * p_class, p_shared + (c + 1) * p_class)``.

    Query prototypes per call:

    * training mode with labels ``y``: shared + the prototypes of each sample's
      label;
    * otherwise, with ``logits``: shared + class prototypes mixed by
      ``softmax(logits)``;
    * otherwise: shared prototypes only.
    """

    def __init__(self, dim, num_classes, p_shared=8, p_class=4, **kw):
        super().__init__(dim, num_prototypes=p_shared + num_classes * p_class, **kw)
        self.p_shared, self.p_class, self.num_classes = p_shared, p_class, num_classes

    def _slice_class(self, cls_idx):
        """Return the ``(p_class, D)`` prototypes owned by class ``cls_idx``."""
        base = self.p_shared + cls_idx * self.p_class
        return self.proto[base: base + self.p_class]

    def _class_bank(self):
        """Return all class-specific prototypes as a ``(num_classes, p_class, D)`` view."""
        return self.proto[self.p_shared:].view(self.num_classes, self.p_class, -1)

    def forward(self, x, y=None, logits=None):
        B = x.size(0)
        shared = self.proto[:self.p_shared].unsqueeze(0).expand(B, -1, -1)
        if self.training and y is not None:
            # Gather each sample's class prototypes in one indexing op. This
            # replaces a per-sample Python loop with a host sync (.item()) per
            # sample; values and gradients are identical.
            bank = self._class_bank()
            cls_proto = bank[y.to(device=bank.device, dtype=torch.long)]  # (B, p_class, D)
            proto = torch.cat([shared, cls_proto], dim=1)
        elif logits is None:
            proto = shared
        else:
            pi = logits.softmax(dim=-1)  # (B, num_classes)
            # Probability-weighted mixture: (B, C) x (C, p_class, D) -> (B, p_class, D).
            mixed = torch.einsum('bc,cpd->bpd', pi, self._class_bank())
            proto = torch.cat([shared, mixed], dim=1)

        return self._tpa_core(x, proto)

# ---------------------------
# 4) Loss Function (변경 없음)
# ---------------------------
class ImprovedClsLoss(nn.Module):
    """Classification loss on the final logits plus a weighted auxiliary CE on ``logits_init``.

    The main term depends on the flags:

    * ``label_smoothing > 0`` and ``use_focal``:
      ``-alpha * sum_c (1 - p_c)^gamma * q_c * log p_c`` with smoothed targets
      ``q``, averaged over the batch;
    * ``label_smoothing > 0`` only: label-smoothed cross-entropy;
    * otherwise with ``use_focal``: ``alpha * (1 - p_t)^gamma * CE``;
    * otherwise: plain cross-entropy.

    If ``aux_info`` contains ``"logits_init"``, the term
    ``init_loss_weight * aux_weight_multiplier * CE(logits_init, labels)`` is
    added.

    Returns ``(total_loss, stats)``. ``stats`` holds floats
    ``classification_loss``, ``total_loss`` and ``align_peak_mean`` (0.0 when
    absent from ``aux_info``).
    """

    def __init__(self, use_focal=True, alpha=0.25, gamma=2.0, init_loss_weight=0.4, label_smoothing=0.1):
        super().__init__()
        self.use_focal, self.alpha, self.gamma = use_focal, alpha, gamma
        self.init_loss_weight, self.label_smoothing = init_loss_weight, label_smoothing

    def forward(self, logits, labels, aux_info, aux_weight_multiplier=1.0):
        n_classes = logits.size(-1)
        if self.label_smoothing > 0:
            one_hot = F.one_hot(labels, num_classes=n_classes).float()
            smooth_label = one_hot * (1 - self.label_smoothing) + self.label_smoothing / n_classes
            if self.use_focal:
                log_probs = F.log_softmax(logits, dim=-1)
                pt = torch.exp(log_probs)
                focal_weight = (1 - pt) ** self.gamma
                loss = -(self.alpha * focal_weight * smooth_label * log_probs).sum(dim=-1).mean()
            else:
                loss = -(smooth_label * F.log_softmax(logits, dim=-1)).sum(dim=-1).mean()
        else:
            ce = F.cross_entropy(logits, labels, reduction="none")
            if self.use_focal:
                pt = torch.exp(-ce)
                loss = (self.alpha * (1 - pt) ** self.gamma * ce).mean()
            else:
                loss = ce.mean()

        total_loss = loss
        if "logits_init" in aux_info:
            loss_init = F.cross_entropy(aux_info["logits_init"], labels)
            # aux_weight_multiplier was previously accepted but ignored; it now
            # scales the auxiliary weight (1.0 reproduces the old behaviour).
            total_loss = loss + self.init_loss_weight * aux_weight_multiplier * loss_init

        return total_loss, {
            "classification_loss": float(loss.item()),
            "total_loss": float(total_loss.item()),
            "align_peak_mean": float(aux_info.get("align_peak_mean", 0.)),
        }

# ---------------------------
# 5) Hybrid Model (TCN -> ELK) 모든 부품을 조립하는 최종 모델
# ---------------------------
class ELK_SequentialAttn_TPA(nn.Module):
    """Full classifier: ELK backbone -> optional sequential attention -> TPA -> linear head.

    ``forward(x, labels=None)`` takes ``(B, nvars, T)`` and returns
    ``(logits_final, aux)``. ``logits_final`` is ``(B, num_classes)``. ``aux``
    is the TPA diagnostics dict extended with ``"logits_init"``: the logits of
    an auxiliary head applied to the time-averaged backbone features.

    With ``use_class_conditional`` the TPA receives both ``labels`` and
    ``logits_init`` to choose its prototypes. ``seq_len`` is accepted but
    unused, and ``num_prototypes`` is only used when ``use_class_conditional``
    is False.
    """

    def __init__(self, nvars, seq_len, num_classes,
                 num_elk_layers, elk_kernel_size,
                 d_model, heads, dropout,
                 num_prototypes, seg_kernel, p_shared, p_class,
                 use_class_conditional, use_cross_attention):
        super().__init__()
        self.use_class_conditional = use_class_conditional
        self.use_cross_attention = use_cross_attention

        # Submodules are created in a fixed order so seeded initialisation and
        # state_dict keys stay stable.
        self.backbone = ELKBackbone(
            in_channels=nvars,
            d_model=d_model,
            num_layers=num_elk_layers,
            kernel_size=elk_kernel_size,
            dropout=dropout,
        )
        # 1x1 conv adapting backbone width to the attention width.
        self.proj = nn.Conv1d(self.backbone.out_channels, d_model, kernel_size=1)

        if use_cross_attention:
            self.cross_attn = SequentialCrossAttention(d_model, nvars, num_heads=heads, dropout=dropout)

        self.tpa = (
            ClassConditionalTPA(d_model, num_classes, p_shared, p_class,
                                seg_kernel=seg_kernel, heads=heads, dropout=dropout)
            if use_class_conditional
            else TemporalPrototypeAttention(d_model, num_prototypes, seg_kernel, heads, dropout)
        )

        self.head_init = nn.Linear(d_model, num_classes)   # auxiliary head on pooled backbone features
        self.head_final = nn.Linear(d_model, num_classes)  # main head on the TPA vector

    def forward(self, x, labels=None):
        feats = self.proj(self.backbone(x)).transpose(1, 2)  # (B, T, d_model)
        logits_init = self.head_init(feats.mean(dim=1))

        if self.use_cross_attention:
            feats = self.cross_attn(feats)

        tpa_kwargs = {"y": labels, "logits": logits_init} if self.use_class_conditional else {}
        z, aux = self.tpa(feats, **tpa_kwargs)

        aux['logits_init'] = logits_init
        return self.head_final(z), aux

    def reparameterize(self):
        """Fuse every ELKBlock in the model into its single-conv inference form."""
        elk_blocks = [m for m in self.modules() if isinstance(m, ELKBlock)]
        for block in elk_blocks:
            block.reparameterize()

# ---------------------------
# 6) Train / Eval (변경 없음)
# ---------------------------
def train_epoch(model, loader, criterion, optim, scheduler, device, accumulation_steps=4):
    """Run one training epoch with gradient accumulation.

    The optimizer (and ``scheduler``, if given) steps once per
    ``accumulation_steps`` batches and once more for a trailing partial group,
    i.e. ``ceil(len(loader) / accumulation_steps)`` times per epoch. Gradients
    are clipped to norm 1.0 before each step. Each batch's loss is divided by
    the size of the group it belongs to, so every step uses the mean loss of
    its group.

    The model is called as ``model(x, labels=y)``. The criterion is called as
    ``criterion(logits, y, aux, 1.0)`` and must return ``(loss_tensor, info)``.

    Returns:
        ``(mean per-batch loss, accuracy in percent)`` over the epoch.
    """
    model.train()
    n_batches = len(loader)
    tot, correct, total = 0.0, 0, 0
    optim.zero_grad()
    for i, (x, y) in enumerate(loader):
        x, y = x.to(device), y.to(device)
        logits, aux = model(x, labels=y)
        loss, _ = criterion(logits, y, aux, 1.0)

        # Only the last group can be short. Dividing it by accumulation_steps
        # would shrink that step's gradient.
        group_start = (i // accumulation_steps) * accumulation_steps
        group_size = min(accumulation_steps, n_batches - group_start)
        (loss / group_size).backward()

        if (i + 1) % accumulation_steps == 0 or i == n_batches - 1:
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optim.step()
            if scheduler:
                scheduler.step()
            optim.zero_grad()

        tot += loss.item()
        pred = logits.argmax(1)
        correct += (pred == y).sum().item()
        total += y.size(0)
    return tot / n_batches, 100 * correct / total

@torch.no_grad()
def evaluate(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` in eval mode without autograd.

    The model is called without labels (``labels=None``). Returns
    ``(mean per-batch loss, accuracy in percent)``; every batch's loss counts
    equally regardless of its size.
    """
    model.eval()
    loss_sum, n_correct, n_seen = 0.0, 0, 0
    for inputs, targets in loader:
        inputs = inputs.to(device)
        targets = targets.to(device)
        logits, aux = model(inputs, labels=None)
        batch_loss, _ = criterion(logits, targets, aux, 1.0)
        loss_sum += batch_loss.item()
        n_correct += (logits.argmax(1) == targets).sum().item()
        n_seen += targets.size(0)
    return loss_sum / len(loader), 100 * n_correct / n_seen

# ---------------------------
# 7) Main (모델 및 파라미터 수정)
# ---------------------------
def main():
    """Train the ELK + sequential attention + TPA model on UCI-HAR and report test accuracy.

    Steps:

    1. Load the train/test splits and carve a stratified 20% validation set out
       of the training split.
    2. Standardize every split with statistics fitted on the training portion
       only.
    3. Train with gradient accumulation and a OneCycle LR schedule, keeping the
       checkpoint with the best validation accuracy.
    4. Reload that checkpoint, fold the ELK branches (``reparameterize``) and
       evaluate once on the test split.
    """
    import gc
    torch.cuda.empty_cache(); gc.collect()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print("UCI-HAR Training: ELK Backbone + Sequential Cross-Attention + TPA")
    print(f"Device: {device}")

    uci_root = "/content/drive/MyDrive/Colab Notebooks/UCI-HAR/UCI-HAR"
    ckpt_path = "best_model_elk_valsel.pth"

    if not os.path.exists(uci_root):
        print(f"Error: UCI-HAR data not found at '{uci_root}'")
        return

    train_ds = UCIHARDataset(uci_root, "train")
    test_ds = UCIHARDataset(uci_root, "test")

    train_indices, val_indices = train_test_split(
        np.arange(len(train_ds)), test_size=0.2, random_state=42, stratify=train_ds.y
    )

    # Fit normalization on the training portion only, so no val/test data
    # leaks into preprocessing.
    mu, sd = fit_channel_stats(train_ds.X[train_indices])
    train_ds.set_stats((mu, sd))
    test_ds.set_stats((mu, sd))

    # Validation samples come from the same files and use the same stats, so
    # the already-loaded training dataset is reused instead of re-reading it.
    train_subset = Subset(train_ds, train_indices)
    val_subset = Subset(train_ds, val_indices)

    print(f"\nDataset Split: Train={len(train_subset)}, Val={len(val_subset)}, Test={len(test_ds)}")

    model = ELK_SequentialAttn_TPA(
        nvars=9,
        seq_len=128,
        num_classes=6,
        d_model=128,
        heads=4,
        dropout=0.2,
        # --- ELK Parameters ---
        num_elk_layers=6,
        elk_kernel_size=31,
        # --- TPA Parameters ---
        num_prototypes=8,
        seg_kernel=3,
        p_shared=6,
        p_class=4,
        use_class_conditional=True,
        use_cross_attention=True
    ).to(device)

    print(f"\nModel parameters: {sum(p.numel() for p in model.parameters()):,}")
    print("Architecture: ELK Backbone + Sequential Cross-Attention + TPA")

    criterion = ImprovedClsLoss(use_focal=True, alpha=0.25, gamma=2.0, init_loss_weight=0.4, label_smoothing=0.1)
    optimizer = torch.optim.AdamW(model.parameters(), lr=5e-4, weight_decay=0.05)

    batch_size = 64
    accumulation_steps = 2
    train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True)
    val_loader = DataLoader(val_subset, batch_size=batch_size*2, shuffle=False, num_workers=2, pin_memory=True)
    test_loader = DataLoader(test_ds, batch_size=batch_size*2, shuffle=False, num_workers=2, pin_memory=True)

    max_epochs = 60
    # train_epoch steps once per full accumulation group plus once for a
    # trailing partial group: ceil(len / acc) steps per epoch. OneCycleLR
    # raises if stepped more than total_steps times, so floor division is wrong
    # whenever len(train_loader) is not a multiple of accumulation_steps.
    total_steps = math.ceil(len(train_loader) / accumulation_steps) * max_epochs
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer, max_lr=2e-3, total_steps=total_steps, pct_start=0.2
    )

    # Start below any reachable accuracy so the first epoch always writes a
    # checkpoint, which the final test step loads.
    best_val_acc, best_epoch = -1.0, 0
    print(f"\nStarting training for {max_epochs} epochs...")
    print(f"Effective batch size: {batch_size * accumulation_steps}")

    for epoch in range(1, max_epochs + 1):
        lr = optimizer.param_groups[0]['lr']
        print(f"\nEpoch {epoch}/{max_epochs} | LR: {lr:.6f}")
        tr_loss, tr_acc = train_epoch(model, train_loader, criterion, optimizer, scheduler, device, accumulation_steps)
        val_loss, val_acc = evaluate(model, val_loader, criterion, device)
        print(f"  Train: loss {tr_loss:.4f} | acc {tr_acc:.2f}%")
        print(f"  Val  : loss {val_loss:.4f} | acc {val_acc:.2f}%")

        if val_acc > best_val_acc:
            best_val_acc, best_epoch = val_acc, epoch
            torch.save(model.state_dict(), ckpt_path)
            print(f"  ✓ New best validation accuracy: {best_val_acc:.2f}% (model saved)")

    print(f"\nTraining completed! Best val acc: {best_val_acc:.2f}% at epoch {best_epoch}")

    print("\n" + "="*60)
    print("FINAL TEST SET EVALUATION (Unseen Data)")
    print("→ Reparameterizing ELK blocks for inference speed-up...")
    print("="*60)

    # Load the best checkpoint (onto the current device), then fold ELK branches.
    model.load_state_dict(torch.load(ckpt_path, map_location=device))
    model.reparameterize()

    test_loss, test_acc = evaluate(model, test_loader, criterion, device)
    print(f"\nFinal Results on Test Set:")
    print(f"  - Test Loss    : {test_loss:.4f}")
    print(f"  - Test Accuracy: {test_acc:.2f}%")

if __name__ == "__main__":
    # Seed the NumPy and PyTorch global RNGs (both 42) before training.
    np.random.seed(42)
    torch.manual_seed(42)
    main()