In [None]:
# ============================================================
# LTE-V (.mat via HDF5) - Spec + Siamese (FAST GPU PIPELINE)
#
# 目标设置（按你的要求）：
# 1) Train online augmentation（每样本随机）：
#    - Multipath: 指数 PDP 的 TDL，多径 RMS delay spread ~ Uniform[5,300] ns
#    - Doppler: v ~ Uniform[0,120] km/h -> fd=(v/c)*fc
#    - AWGN: SNR ~ Uniform[-40,20] dB
# 2) Test：
#    - Doppler 固定 120 km/h
#    - AWGN SNR sweep: 20,15,...,-40 dB
#    - 默认不额外加 multipath（如要测试也加：TEST_USE_MULTIPATH=True）
# 3) 模型与损失：
#    - 输入：log|STFT| 2D spectrogram（resize 到 64x64）
#    - Siamese 两分支共享权重
#    - Loss: L = NT-Xent(tau=0.05) + CE
# 4) 训练策略：
#    - Adam lr=3e-4
#    - ReduceLROnPlateau(val_loss, patience=10, factor=0.5)
#    - val loss 30 epoch 不降 -> early stop
# 5) 数据划分：
#    - 优先文件级 stratify split；不满足则回退样本级 stratify
#
# 加速点：
# - Dataset 只输出 IQ，不在 __getitem__ 中 STFT
# - 增强 + STFT 全部 GPU batch 计算
# - AMP 混合精度（NT-Xent 强制 fp32，避免溢出）
# ============================================================

import os
import glob
import csv
import random
import numpy as np
from datetime import datetime
from collections import defaultdict, Counter

import h5py
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import KFold, train_test_split

# ----------------------------
# 0) 全局配置
# ----------------------------
SEED = 42

def seed_everything(seed=SEED):
    """Seed every RNG the script relies on.

    Covers Python's ``random`` module, NumPy's global RNG, the PyTorch CPU generator
    and the generators of all CUDA devices, in that order.
    """
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for seeder in seeders:
        seeder(seed)

seed_everything(SEED)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Let cuDNN autotune conv algorithms; the spectrogram inputs are resized to a fixed size.
torch.backends.cudnn.benchmark = True

# Allow TF32 math for matmul/conv on GPUs that support it (faster, slightly less precise).
# Older PyTorch builds may not expose these attributes, which surfaces as AttributeError;
# anything else is a real error and should not be silenced.
try:
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
except AttributeError:
    pass

try:
    torch.set_float32_matmul_precision("high")
except AttributeError:
    pass

# Mixed precision is only enabled when running on CUDA.
USE_AMP = (DEVICE.type == "cuda")

# Root directory of the LTE-V dataset (.mat files read via h5py)
DATA_PATH = "E:/rf_datasets_IQ/"
RECURSIVE_GLOB = True  # also search sub-directories for *.mat

# RF parameters (kept consistent with the XFR LTE-V script)
FS = 5e6   # sample rate [Hz]; converts delays and Doppler phase to per-sample units
FC = 5.9e9 # carrier frequency [Hz]; Doppler shift is fd = (v / c) * FC

# Training parameters
BATCH_SIZE = 256          # Adjust to GPU memory; the NT-Xent similarity matrix is (2B)^2, so very large values may OOM
LR = 3e-4
WEIGHT_DECAY = 0.0
MAX_EPOCHS = 200
N_SPLITS = 5  # number of K-fold splits of the training portion

LR_PATIENCE = 10  # ReduceLROnPlateau patience (epochs without val-loss improvement)
LR_FACTOR = 0.5   # ReduceLROnPlateau multiplicative LR decay
ES_PATIENCE = 30  # early-stop after this many epochs without val-loss improvement

# Siamese / contrastive learning
TAU = 0.05        # NT-Xent temperature
LAMBDA_CL = 1.0   # weight of the contrastive (NT-Xent) loss
LAMBDA_CE = 1.0   # weight of the cross-entropy classification loss

# Training-time augmentation switches and ranges (each drawn uniformly per sample)
AUG_USE_MULTIPATH = True
AUG_USE_DOPPLER   = True
AUG_USE_AWGN      = True

RMS_DS_NS_RANGE = (5.0, 300.0)       # RMS delay spread range [ns]
TRAIN_V_KMH_RANGE = (0.0, 120.0)     # speed range [km/h] for the Doppler shift
TRAIN_SNR_DB_RANGE = (-40.0, 20.0)   # AWGN SNR range [dB]

# Test-time augmentation
TEST_V_KMH_FIXED = 120.0                  # fixed Doppler speed [km/h] at test time
TEST_USE_MULTIPATH = False                # optionally add multipath at test time too
TEST_RMS_DS_NS_RANGE = RMS_DS_NS_RANGE    # delay-spread range used if TEST_USE_MULTIPATH
TEST_SNR_LIST = list(range(20, -45, -5))  # SNR sweep: 20, 15, ..., -40 dB

# Spectrogram
SPEC_NFFT = 128
SPEC_WIN  = 128
SPEC_HOP  = 16
SPEC_SIZE = 64  # spectrogram is bilinearly resized to SPEC_SIZE x SPEC_SIZE

# Number of taps in the multipath channel model
MAX_TAPS = 16

# Loading control
MAX_SAMPLES_PER_FILE = None  # e.g. 2000; None keeps every frame

# DataLoader workers (on Windows, start with 0)
NUM_WORKERS_TRAIN = 0
NUM_WORKERS_EVAL  = 0

SAVE_ROOT = "./training_results"
os.makedirs(SAVE_ROOT, exist_ok=True)
SCRIPT_NAME = "LTEV_SpecSiamese_FASTGPU_SNRsweep_Doppler120"

RETURN_CM = False  # set True to also compute confusion matrices (slower, uses more memory)


# ----------------------------
# 1) 读取 LTE-V HDF5(.mat)
# ----------------------------
def decode_txid(txid_arr: np.ndarray) -> str:
    """Decode an array of numeric character codes (e.g. a stored char array) into a string.

    Zero codes (padding) are skipped. Elements that cannot be converted into a valid
    Unicode code point (NaN, negative, out-of-range values) are skipped as well; this is
    deliberate best-effort decoding.

    Args:
        txid_arr: array-like of character codes, any shape (flattened internally).

    Returns:
        The decoded, whitespace-stripped string, or ``"UNKNOWN"`` if nothing decodable remains.
    """
    chars = []
    for c in np.asarray(txid_arr).flatten():
        # int() itself can fail (e.g. NaN), so it must be inside the try as well.
        try:
            code = int(c)
            if code == 0:
                continue
            chars.append(chr(code))
        except (TypeError, ValueError, OverflowError):
            continue
    s = "".join(chars).strip()
    return s if s else "UNKNOWN"

def load_dmrs_complex_from_file(h5f: h5py.File) -> np.ndarray:
    """Read ``rfDataset/dmrs`` from an open HDF5 file as a 2-D complex64 array.

    Accepted layouts of ``dmrs``:
      * a compound dataset with ``real`` / ``imag`` fields,
      * a group containing ``real`` and ``imag`` datasets,
      * any other dataset whose loaded value is a structured array with those fields.

    If the array looks like (L, N) (see heuristic below) it is transposed to (N, L).

    Raises:
        RuntimeError: unsupported layout, or the combined array is not 2-D.
    """
    dmrs_obj = h5f["rfDataset"]["dmrs"]

    is_compound_ds = isinstance(dmrs_obj, h5py.Dataset) and dmrs_obj.dtype.fields is not None
    if is_compound_ds or isinstance(dmrs_obj, h5py.Group):
        # Both layouts expose the two parts via item access.
        re_part = dmrs_obj["real"][()]
        im_part = dmrs_obj["imag"][()]
    else:
        loaded = dmrs_obj[()]
        if not (hasattr(loaded, "dtype") and loaded.dtype.fields is not None):
            raise RuntimeError("无法解析 dmrs：既不是 compound dataset 也不是 group(real/imag).")
        re_part = loaded["real"]
        im_part = loaded["imag"]

    dmrs_complex = np.asarray(re_part + 1j * im_part)
    if dmrs_complex.ndim != 2:
        raise RuntimeError(f"dmrs_complex 维度异常: {dmrs_complex.shape}")

    # Orientation heuristic: a short first axis (<= 2048) that is also shorter than the
    # second axis is taken to be the sequence length, so the array is transposed to (N, L).
    # NOTE(review): a file holding fewer frames than the DMRS length is not transposed — confirm.
    n_rows, n_cols = dmrs_complex.shape
    if n_rows <= 2048 and n_cols > n_rows:
        dmrs_complex = dmrs_complex.T

    return dmrs_complex.astype(np.complex64)

def load_ltev_dataset(data_path: str, recursive: bool = True, max_samples_per_file=None):
    """Load every LTE-V .mat (HDF5) file under ``data_path``.

    Each readable file contributes one DMRS matrix of shape (N, L) (complex64) and one
    txID label string. Files are visited in sorted path order: ``glob`` returns paths in
    arbitrary, filesystem-dependent order, and downstream seeded splits work on file
    indices, so an unsorted order would make splits irreproducible across machines.
    Files that fail to parse, or contain zero frames, are skipped with a warning.
    Finally every matrix is trimmed or zero-padded along the length axis to the most
    common length so frames from different files can be batched together.

    Args:
        data_path: root directory to search.
        recursive: also search sub-directories.
        max_samples_per_file: optional cap on frames kept per file (``None`` keeps all).

    Returns:
        (signals_by_file, label_str_by_file, file_paths, mode_L)

    Raises:
        RuntimeError: if no .mat file is found, or none could be loaded.
    """
    if recursive:
        pattern = os.path.join(data_path, "**", "*.mat")
    else:
        pattern = os.path.join(data_path, "*.mat")
    mat_files = sorted(glob.glob(pattern, recursive=recursive))

    if len(mat_files) == 0:
        raise RuntimeError(f"未找到 .mat 文件：{data_path}")

    signals_by_file = []
    label_str_by_file = []
    file_paths = []

    print(f"[INFO] Found {len(mat_files)} .mat files")

    for fp in mat_files:
        try:
            with h5py.File(fp, "r") as f:
                txid_arr = np.asarray(f["rfDataset"]["txID"][()])
                tx_str = decode_txid(txid_arr)

                dmrs_complex = load_dmrs_complex_from_file(f)
                if max_samples_per_file is not None:
                    dmrs_complex = dmrs_complex[: int(max_samples_per_file)]

                if dmrs_complex.shape[0] == 0:
                    print(f"[WARN] Skip file with zero frames: {fp}")
                    continue

                signals_by_file.append(dmrs_complex)  # (N,L) complex64
                label_str_by_file.append(tx_str)
                file_paths.append(fp)
        except Exception as e:
            # Deliberate best effort: one malformed file must not abort the whole load.
            print(f"[WARN] Skip file due to read/parse error: {fp} | {repr(e)}")

    if len(signals_by_file) == 0:
        raise RuntimeError("所有文件都读取失败或为空，请检查数据结构。")

    # Statistics
    cnt = Counter(label_str_by_file)
    print("[INFO] txID classes:", len(cnt))
    for k, v in sorted(cnt.items(), key=lambda x: (-x[1], x[0])):
        print(f"  {k}: {v} files")

    Ls = [arr.shape[1] for arr in signals_by_file]
    mode_L = Counter(Ls).most_common(1)[0][0]
    print(f"[INFO] DMRS length stats: min={min(Ls)}, max={max(Ls)}, mode={mode_L}")

    # Batched training needs one common length: trim / zero-pad every file to mode_L.
    n_adjusted = 0
    for i, x in enumerate(signals_by_file):
        cur_len = x.shape[1]
        if cur_len == mode_L:
            continue
        n_adjusted += 1
        if cur_len > mode_L:
            # astype copies, so the untrimmed array can be freed.
            signals_by_file[i] = x[:, :mode_L].astype(np.complex64)
        else:
            pad = mode_L - cur_len
            signals_by_file[i] = np.pad(x, ((0, 0), (0, pad)), mode="constant").astype(np.complex64)
    if n_adjusted:
        print(f"[WARN] {n_adjusted} file(s) trimmed/zero-padded to L={mode_L}")

    return signals_by_file, label_str_by_file, file_paths, mode_L


# ----------------------------
# 2) split：优先文件级 stratify，失败回退样本级
# ----------------------------
def build_label_mapping(label_str_by_file):
    """Assign each distinct label string a contiguous integer index.

    Labels are sorted lexicographically, so index ``i`` names ``label_list[i]``.

    Returns:
        (label_list, label_to_idx)
    """
    label_list = sorted(set(label_str_by_file))
    label_to_idx = dict(zip(label_list, range(len(label_list))))
    return label_list, label_to_idx

def build_all_sample_list(signals_by_file):
    """Enumerate every (file_index, frame_index) pair, file by file, frames in order."""
    return [
        (file_idx, frame_idx)
        for file_idx, frames in enumerate(signals_by_file)
        for frame_idx in range(frames.shape[0])
    ]

def split_train_test(signals_by_file, label_idx_by_file, test_size=0.25):
    """Split frames into train/test, preferring a file-level stratified split.

    When every class owns at least two files, a stratified split over *files* is tried
    (all frames of a file end up on the same side). If that precondition fails, or
    sklearn rejects the split, a stratified split over individual frames is used instead.

    Returns:
        ("file", train_file_idx, test_file_idx, train_samples, test_samples) or
        ("sample", None, None, train_samples, test_samples); samples are
        (file_index, frame_index) tuples.
    """
    labels = np.asarray(label_idx_by_file, dtype=np.int64)

    def _expand(file_ids):
        # Every (file, frame) pair of the given files, keeping the given file order.
        pairs = []
        for fid in file_ids:
            fid = int(fid)
            pairs.extend((fid, si) for si in range(signals_by_file[fid].shape[0]))
        return pairs

    files_per_class = Counter(labels.tolist())
    if all(count >= 2 for count in files_per_class.values()):
        try:
            tr_files, te_files = train_test_split(
                np.arange(len(signals_by_file)), test_size=test_size, stratify=labels, random_state=SEED
            )
            tr_files = np.array(tr_files, dtype=np.int64)
            te_files = np.array(te_files, dtype=np.int64)

            train_samples = _expand(tr_files)
            test_samples = _expand(te_files)

            print("[INFO] Split mode: FILE-LEVEL")
            print(f"[INFO] Train files={len(tr_files)}, Test files={len(te_files)}")
            return "file", tr_files, te_files, train_samples, test_samples
        except Exception as e:
            print(f"[WARN] File-level stratified split failed, fallback to sample-level. Reason: {repr(e)}")

    all_samples = build_all_sample_list(signals_by_file)
    frame_labels = np.array([label_idx_by_file[fi] for fi, _ in all_samples], dtype=np.int64)
    tr_pos, te_pos = train_test_split(
        np.arange(len(all_samples)), test_size=test_size, stratify=frame_labels, random_state=SEED
    )
    train_samples = [all_samples[p] for p in tr_pos]
    test_samples = [all_samples[p] for p in te_pos]

    print("[INFO] Split mode: SAMPLE-LEVEL (fallback)")
    print(f"[INFO] Train samples={len(train_samples)}, Test samples={len(test_samples)}")
    return "sample", None, None, train_samples, test_samples


# ----------------------------
# 3) Dataset：只返回 IQ（用于 GPU batch 增强+STFT）
# ----------------------------
class LTEVSiameseIQDataset(Dataset):
    """Siamese training dataset yielding ``(iq1, iq2, label)``.

    ``iq1`` is the frame at the requested index. ``iq2`` is a positive partner with the
    same label: taken from a different file of that class when one exists in this split,
    otherwise from a different frame of the same file (the same frame only if it is the
    file's sole in-split frame).

    Partners are drawn ONLY from frames listed in ``sample_list``. Previously the partner
    frame index was drawn from all frames of a file, which, under a frame-level split,
    leaked validation/test frames into training.

    Both IQ outputs are float32 arrays of shape (L, 2) holding (real, imag).
    """

    def __init__(self, signals_by_file, label_idx_by_file, sample_list):
        self.signals_by_file = signals_by_file
        self.label_idx_by_file = label_idx_by_file
        self.sample_list = sample_list

        # In-split frame indices of each file (deduplicated, sorted for determinism).
        frames_by_file = defaultdict(list)
        for fi, si in sample_list:
            frames_by_file[fi].append(si)
        self.file_to_frames = {fi: sorted(set(frames)) for fi, frames in frames_by_file.items()}

        # label -> files present in this split; file -> number of in-split frames.
        self.label_to_files = defaultdict(list)
        self.file_to_count = {}
        for fi in sorted(self.file_to_frames):
            self.label_to_files[int(label_idx_by_file[fi])].append(fi)
            self.file_to_count[fi] = len(self.file_to_frames[fi])

    def __len__(self):
        return len(self.sample_list)

    def __getitem__(self, idx):
        fi1, si1 = self.sample_list[idx]
        lab = int(self.label_idx_by_file[fi1])
        sig1 = self.signals_by_file[fi1][si1]  # complex (L,)

        # Prefer a partner from another file of the same class (fi1 is always in `files`).
        files = self.label_to_files[lab]
        fi2 = fi1
        if len(files) >= 2:
            while fi2 == fi1:
                fi2 = random.choice(files)

        # Pick the partner frame among in-split frames only; avoid si1 when possible.
        frames = self.file_to_frames[fi2]
        if fi2 == fi1 and len(frames) >= 2:
            si2 = si1
            while si2 == si1:
                si2 = random.choice(frames)
        else:
            si2 = random.choice(frames)

        sig2 = self.signals_by_file[fi2][si2]

        # IQ as float32 (L, 2)
        iq1 = np.stack([sig1.real, sig1.imag], axis=-1).astype(np.float32)
        iq2 = np.stack([sig2.real, sig2.imag], axis=-1).astype(np.float32)
        return iq1, iq2, np.int64(lab)

class LTEVSingleIQDataset(Dataset):
    """Validation/test dataset yielding ``(iq, label)`` for a single frame.

    ``iq`` is float32 with shape (L, 2) = (real, imag). Augmentation and the STFT are
    applied later on the device inside the evaluation function.
    """

    def __init__(self, signals_by_file, label_idx_by_file, sample_list):
        self.signals_by_file = signals_by_file
        self.label_idx_by_file = label_idx_by_file
        self.sample_list = sample_list

    def __len__(self):
        return len(self.sample_list)

    def __getitem__(self, idx):
        file_idx, frame_idx = self.sample_list[idx]
        label = int(self.label_idx_by_file[file_idx])
        frame = self.signals_by_file[file_idx][frame_idx]
        iq = np.stack((frame.real, frame.imag), axis=-1).astype(np.float32)
        return iq, np.int64(label)


# ----------------------------
# 4) GPU batch 增强：Multipath / Doppler / AWGN
# ----------------------------
def _to_complex(iq_b: torch.Tensor) -> torch.Tensor:
    return iq_b[..., 0].to(torch.float32) + 1j * iq_b[..., 1].to(torch.float32)

def _from_complex(sig: torch.Tensor) -> torch.Tensor:
    return torch.stack([sig.real, sig.imag], dim=-1)

def batch_normalize_power(iq_b: torch.Tensor) -> torch.Tensor:
    """Scale each IQ sequence of the batch to unit mean power.

    Args:
        iq_b: (B, L, 2) real/imag tensor.

    Returns:
        Same-shape tensor where mean(real^2 + imag^2) over L is 1 for every sequence
        (a 1e-12 floor keeps all-zero inputs finite).
    """
    inst_power = iq_b[..., 0] ** 2 + iq_b[..., 1] ** 2          # (B, L)
    mean_power = inst_power.mean(dim=1, keepdim=True) + 1e-12    # (B, 1)
    return iq_b * torch.rsqrt(mean_power)[..., None]             # broadcast as (B, 1, 1)

def batch_apply_doppler(iq_b: torch.Tensor, v_kmh: torch.Tensor) -> torch.Tensor:
    """Apply a per-sequence Doppler frequency shift.

    fd = (v / c) * FC with v converted from km/h to m/s and c = 3e8 m/s. Sequence ``b`` is
    multiplied by exp(j * 2*pi * fd[b] * n / FS) for n = 0 .. L-1 (zero initial phase).

    Args:
        iq_b: (B, L, 2) real/imag tensor.
        v_kmh: (B,) speeds in km/h.

    Returns:
        (B, L, 2) frequency-shifted signal.
    """
    _, seq_len, _ = iq_b.shape
    speed_of_light = 3e8
    v_ms = v_kmh / 3.6
    fd = (v_ms / speed_of_light) * FC  # (B,)

    sample_idx = torch.arange(seq_len, device=iq_b.device, dtype=torch.float32)[None, :]  # (1, L)
    rotator = torch.exp(1j * 2.0 * np.pi * fd[:, None].to(torch.float32) * sample_idx / FS)  # (B, L)
    return _from_complex(_to_complex(iq_b) * rotator)

def _grouped_conv1d_real(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    # x: (B,1,L), w: (B,1,K) -> grouped conv -> (B,L+K-1)
    B, _, L = x.shape
    _, _, K = w.shape
    x2 = x.permute(1, 0, 2).contiguous()      # (1,B,L)
    y2 = F.conv1d(x2, w, padding=K - 1, groups=B)  # (1,B,L+K-1)
    return y2.squeeze(0)  # (B, L+K-1)

def batch_apply_multipath(iq_b: torch.Tensor, rms_ns: torch.Tensor) -> torch.Tensor:
    """Pass each IQ sequence through its own random tapped-delay-line channel.

    Tap powers follow an exponential profile p[k] ∝ exp(-k / tau), k = 0 .. MAX_TAPS-1,
    where tau is the requested RMS delay spread in samples (rms_ns * 1e-9 * FS, floored
    at 1e-3). Each tap gets a complex Gaussian gain whose real and imaginary parts have
    variance p[k] / 2. Each realization's tap vector is then rescaled to unit total energy.
    The output is truncated to the input length L.

    Args:
        iq_b: (B, L, 2) real/imag tensor.
        rms_ns: (B,) RMS delay spread per sequence, in nanoseconds.

    Returns:
        (B, L, 2) faded signal.

    NOTE(review): the complex convolution is built from four real convolutions via
    ``_grouped_conv1d_real``. Tap k only acts at delay k if that helper implements a true
    convolution (``F.conv1d`` alone is a cross-correlation) — confirm.
    """
    B, L, _ = iq_b.shape
    device = iq_b.device

    rms_s = rms_ns * 1e-9
    rms_samples = (rms_s * FS).clamp(min=1e-3)  # (B,) delay spread in samples; floor avoids /0
    k = torch.arange(MAX_TAPS, device=device, dtype=torch.float32).unsqueeze(0)  # (1,K)
    # Exponential power-delay profile, normalized so each row sums to 1.
    p = torch.exp(-k / rms_samples.unsqueeze(1))  # (B,K)
    p = p / (p.sum(dim=1, keepdim=True) + 1e-12)

    # Complex Gaussian tap gains; real and imaginary parts each carry half of p[k].
    hr = torch.randn(B, MAX_TAPS, device=device) * torch.sqrt(p / 2.0)
    hi = torch.randn(B, MAX_TAPS, device=device) * torch.sqrt(p / 2.0)

    # Rescale every realization to unit total tap energy.
    hpow = (hr**2 + hi**2).sum(dim=1, keepdim=True) + 1e-12
    norm = torch.rsqrt(hpow)
    hr = hr * norm
    hi = hi * norm

    xr = iq_b[..., 0]
    xi = iq_b[..., 1]

    # Add a channel axis: (B, 1, L) signals and (B, 1, K) kernels for the grouped conv.
    xr_ = xr.unsqueeze(1)
    xi_ = xi.unsqueeze(1)
    hr_ = hr.unsqueeze(1)
    hi_ = hi.unsqueeze(1)

    xr_hr = _grouped_conv1d_real(xr_, hr_)
    xi_hi = _grouped_conv1d_real(xi_, hi_)
    xr_hi = _grouped_conv1d_real(xr_, hi_)
    xi_hr = _grouped_conv1d_real(xi_, hr_)

    # (xr + j*xi) * (hr + j*hi) = (xr*hr - xi*hi) + j*(xr*hi + xi*hr)
    yr = xr_hr - xi_hi
    yi = xr_hi + xi_hr

    # Keep the first L outputs (drop the convolution tail).
    yr = yr[:, :L]
    yi = yi[:, :L]
    return torch.stack([yr, yi], dim=-1)

def batch_add_awgn(iq_b: torch.Tensor, snr_db: torch.Tensor) -> torch.Tensor:
    """Add circular complex white Gaussian noise at a per-sequence SNR.

    Noise power for sequence b is mean(|x_b|^2) / 10^(snr_db[b] / 10), split equally
    between the real and imaginary components.

    Args:
        iq_b: (B, L, 2) real/imag tensor.
        snr_db: (B,) target SNR in dB.

    Returns:
        (B, L, 2) noisy signal.
    """
    batch, seq_len, _ = iq_b.shape
    dev = iq_b.device
    clean = _to_complex(iq_b)
    sig_power = (clean.real**2 + clean.imag**2).mean(dim=1) + 1e-12
    noise_power = sig_power / (10.0 ** (snr_db / 10.0))
    sigma = torch.sqrt(noise_power / 2.0)[:, None]  # (B, 1)
    noise_re = torch.randn(batch, seq_len, device=dev)
    noise_im = torch.randn(batch, seq_len, device=dev)
    noisy = clean + sigma * (noise_re + 1j * noise_im)
    return _from_complex(noisy)

def augment_train_batch(iq_b: torch.Tensor) -> torch.Tensor:
    """Training-time channel augmentation of a (B, L, 2) batch, drawn per sequence.

    Order: unit-power normalization, then each enabled stage in turn:
    multipath with RMS delay spread ~ U(RMS_DS_NS_RANGE) ns, Doppler with speed
    ~ U(TRAIN_V_KMH_RANGE) km/h, AWGN with SNR ~ U(TRAIN_SNR_DB_RANGE) dB.
    """
    iq_b = batch_normalize_power(iq_b)
    n_seq = iq_b.shape[0]
    dev = iq_b.device

    def draw(bounds):
        # One value per sequence, uniform in [lo, hi).
        lo, hi = bounds
        return (hi - lo) * torch.rand(n_seq, device=dev) + lo

    stages = (
        (AUG_USE_MULTIPATH, batch_apply_multipath, RMS_DS_NS_RANGE),
        (AUG_USE_DOPPLER, batch_apply_doppler, TRAIN_V_KMH_RANGE),
        (AUG_USE_AWGN, batch_add_awgn, TRAIN_SNR_DB_RANGE),
    )
    for enabled, apply_fn, bounds in stages:
        if enabled:
            iq_b = apply_fn(iq_b, draw(bounds))
    return iq_b

def augment_test_batch(iq_b: torch.Tensor, snr_db: float) -> torch.Tensor:
    """Test-time channel for a (B, L, 2) batch.

    Unit-power normalization; optional multipath (RMS delay spread ~ U(TEST_RMS_DS_NS_RANGE)
    when TEST_USE_MULTIPATH); then a fixed Doppler at TEST_V_KMH_FIXED km/h and AWGN at
    exactly ``snr_db`` dB for every sequence.
    """
    out = batch_normalize_power(iq_b)
    batch = out.shape[0]
    dev = out.device

    if TEST_USE_MULTIPATH:
        lo, hi = TEST_RMS_DS_NS_RANGE
        out = batch_apply_multipath(out, (hi - lo) * torch.rand(batch, device=dev) + lo)

    speeds = torch.full((batch,), float(TEST_V_KMH_FIXED), device=dev)
    out = batch_apply_doppler(out, speeds)

    snrs = torch.full((batch,), float(snr_db), device=dev)
    return batch_add_awgn(out, snrs)


# ----------------------------
# 5) GPU batch STFT -> logmag -> resize
# ----------------------------
_WINDOW_CACHE = {}

def get_hann_window(device: torch.device):
    """Return a periodic Hann window of length SPEC_WIN on ``device``, built once and cached.

    The cache key is (device type, device index, SPEC_WIN).
    """
    key = (device.type, device.index, SPEC_WIN)
    window = _WINDOW_CACHE.get(key)
    if window is None:
        window = torch.hann_window(SPEC_WIN, periodic=True, device=device)
        _WINDOW_CACHE[key] = window
    return window

def iq_to_logspec_batch(iq_b: torch.Tensor) -> torch.Tensor:
    """Convert a (B, L, 2) IQ batch into normalized log-magnitude spectrograms.

    Pipeline: complex STFT (periodic Hann window, SPEC_NFFT / SPEC_HOP / SPEC_WIN,
    center=True) -> log(|S| + 1e-12) -> per-sequence z-score over frequency and time
    (non-finite results replaced by 0) -> bilinear resize to SPEC_SIZE x SPEC_SIZE.

    Returns:
        (B, 1, SPEC_SIZE, SPEC_SIZE) tensor.
    """
    spec = torch.stft(
        _to_complex(iq_b).to(torch.complex64),
        n_fft=SPEC_NFFT,
        hop_length=SPEC_HOP,
        win_length=SPEC_WIN,
        window=get_hann_window(iq_b.device),
        center=True,
        return_complex=True,
    )  # (B, F, T)

    log_mag = torch.log(spec.abs() + 1e-12)
    axes = (1, 2)
    centered = log_mag - log_mag.mean(dim=axes, keepdim=True)
    scaled = centered / (log_mag.std(dim=axes, keepdim=True) + 1e-6)
    scaled = torch.nan_to_num(scaled, nan=0.0, posinf=0.0, neginf=0.0)

    return F.interpolate(
        scaled[:, None], size=(SPEC_SIZE, SPEC_SIZE), mode="bilinear", align_corners=False
    )


# ----------------------------
# 6) 模型
# ----------------------------
class BasicBlock(nn.Module):
    """Residual block: two 3x3 conv + BN layers with a skip connection and ReLU.

    The skip path is a 1x1 strided conv + BN whenever the stride or channel count changes,
    otherwise the identity (``self.down is None``).
    """

    def __init__(self, in_ch, out_ch, stride=1):
        super().__init__()
        # Construction order is kept so parameter init and state_dict keys are unchanged.
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_ch)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_ch)
        self.relu = nn.ReLU(inplace=True)

        needs_projection = stride != 1 or in_ch != out_ch
        self.down = (
            nn.Sequential(
                nn.Conv2d(in_ch, out_ch, 1, stride=stride, bias=False),
                nn.BatchNorm2d(out_ch),
            )
            if needs_projection
            else None
        )

    def forward(self, x):
        hidden = self.relu(self.bn1(self.conv1(x)))
        residual = self.bn2(self.conv2(hidden))
        shortcut = x if self.down is None else self.down(x)
        return self.relu(residual + shortcut)

class SpecFeatureNet(nn.Module):
    """Shared-weight Siamese encoder plus classifier for (B, 1, H, W) spectrograms.

    Stem: 7x7 stride-2 conv -> BN -> ReLU. Then four residual blocks (32, 32, 64, 64
    channels, stride 1) and global average pooling. An MLP 64 -> 512 -> 256 produces the
    embedding ``z``, and a linear head 256 -> ``num_classes`` produces the logits.
    """

    def __init__(self, num_classes: int):
        super().__init__()
        # Attribute names/order are kept so checkpoints and parameter init are unchanged.
        self.conv1 = nn.Conv2d(1, 32, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(32)
        self.relu = nn.ReLU(inplace=True)

        self.b1 = BasicBlock(32, 32, stride=1)
        self.b2 = BasicBlock(32, 32, stride=1)
        self.b3 = BasicBlock(32, 64, stride=1)
        self.b4 = BasicBlock(64, 64, stride=1)

        self.gap = nn.AdaptiveAvgPool2d((1, 1))
        self.fc1 = nn.Linear(64, 512)
        self.fc2 = nn.Linear(512, 256)
        self.cls = nn.Linear(256, num_classes)

    def forward_once(self, x):
        """Encode one batch; returns (embedding (B, 256), logits (B, num_classes))."""
        h = self.relu(self.bn1(self.conv1(x)))
        for block in (self.b1, self.b2, self.b3, self.b4):
            h = block(h)
        h = torch.flatten(self.gap(h), start_dim=1)  # (B, 64)
        z = self.fc2(F.relu(self.fc1(h)))            # (B, 256)
        return z, self.cls(z)

    def forward(self, x1, x2=None):
        """Return (z1, p1) for one input, or (z1, p1, z2, p2) when ``x2`` is also given."""
        first = self.forward_once(x1)
        if x2 is None:
            return first
        return first + self.forward_once(x2)


# ----------------------------
# 7) 损失与评估（NT-Xent：强制 fp32，避免 AMP 溢出）
# ----------------------------
def nt_xent_loss(z1: torch.Tensor, z2: torch.Tensor, tau: float = TAU) -> torch.Tensor:
    """NT-Xent (normalized temperature-scaled cross-entropy) contrastive loss.

    Row i of ``z1`` and row i of ``z2`` (both (N, D)) form a positive pair; the other
    2N-2 embeddings act as negatives. Embeddings are L2-normalized, their dot products
    divided by ``tau``, and self-similarities masked out with the float32 minimum.
    The result is the mean over all 2N anchors of -log softmax at the positive.

    CUDA autocast is disabled and everything is cast to float32, so the mask value and
    the 1/tau scaling cannot overflow in fp16 under AMP.
    """
    with torch.cuda.amp.autocast(enabled=False):
        n = z1.size(0)
        emb = F.normalize(torch.cat([z1.float(), z2.float()], dim=0), dim=1)  # (2N, D)

        logits = (emb @ emb.T) / float(tau)
        logits.fill_diagonal_(torch.finfo(logits.dtype).min)  # exclude each anchor itself

        log_prob = logits - torch.logsumexp(logits, dim=1, keepdim=True)
        anchors = torch.arange(2 * n, device=emb.device)
        partners = (anchors + n) % (2 * n)  # pairs i <-> i + N
        per_anchor = -log_prob[anchors, partners]
        return per_anchor.mean()

@torch.no_grad()
def eval_single_iq(model: SpecFeatureNet, loader: DataLoader, num_classes: int, mode: str, snr_db: float = None):
    """Evaluate ``model`` on a loader of ``(iq, label)`` batches.

    With ``mode == "test"`` each batch first goes through ``augment_test_batch`` at
    ``snr_db``; any other mode (e.g. "val") uses the frames as loaded. Batches are turned
    into spectrograms and classified with the CE head.

    Returns:
        (mean_loss, accuracy_percent, confusion_matrix_or_None).
        ``mean_loss`` is the per-sample mean cross-entropy. Batch losses are summed and
        divided by the sample count, so a smaller last batch is not over-weighted. The
        confusion matrix is only built when RETURN_CM is True.

    Raises:
        ValueError: if ``mode == "test"`` and ``snr_db`` is None.
    """
    if mode == "test" and snr_db is None:
        raise ValueError("eval_single_iq(mode='test') requires snr_db")

    model.eval()
    ce = nn.CrossEntropyLoss(reduction="sum")  # summed; normalized by sample count at the end

    total, correct = 0, 0
    loss_sum = 0.0
    all_y, all_p = [], []

    for iq, y in loader:
        iq = iq.to(DEVICE, non_blocking=True)
        y = y.to(DEVICE, non_blocking=True)

        if mode == "test":
            iq = augment_test_batch(iq, snr_db=float(snr_db))

        spec = iq_to_logspec_batch(iq)
        _, logits = model(spec, None)

        loss_sum += ce(logits, y).item()
        pred = torch.argmax(logits, dim=1)
        correct += (pred == y).sum().item()
        total += y.size(0)

        if RETURN_CM:
            all_y.append(y.cpu().numpy())
            all_p.append(pred.cpu().numpy())

    acc = 100.0 * correct / max(total, 1)
    cm = None
    if RETURN_CM and total > 0:
        from sklearn.metrics import confusion_matrix
        cm = confusion_matrix(
            np.concatenate(all_y), np.concatenate(all_p), labels=list(range(num_classes))
        )
    return loss_sum / max(total, 1), acc, cm


# ----------------------------
# 8) 主训练：KFold + Test SNR sweep
# ----------------------------
def train_kfold_ltev_spec_siamese_fast(data_path: str):
    """Run the full experiment: load, hold out a test split, K-fold train, sweep test SNR.

    Steps:
      1. Load all LTE-V files and map txID strings to class indices.
      2. Split train/test (file-level when possible, otherwise frame-level).
      3. Write the run configuration to ``<save_folder>/config.txt``.
      4. For each of N_SPLITS folds of the training portion, train a fresh model with a
         fresh GradScaler (``_run_one_fold_fast``). Each model is evaluated on the test
         split at every SNR in TEST_SNR_LIST.
      5. Print mean ± std accuracy per SNR over folds and save it to ``test_snr_sweep.csv``.

    Returns:
        Path of the timestamped results folder.
    """
    signals_by_file, label_str_by_file, file_paths, L = load_ltev_dataset(
        data_path, recursive=RECURSIVE_GLOB, max_samples_per_file=MAX_SAMPLES_PER_FILE
    )

    label_list, label_to_idx = build_label_mapping(label_str_by_file)
    label_idx_by_file = [label_to_idx[s] for s in label_str_by_file]
    num_classes = len(label_list)

    mode, train_files, test_files, train_samples, test_samples = split_train_test(
        signals_by_file, label_idx_by_file, test_size=0.25
    )

    timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    save_dir = f"{timestamp}_{SCRIPT_NAME}"
    save_folder = os.path.join(SAVE_ROOT, save_dir)
    os.makedirs(save_folder, exist_ok=True)

    fd_test = (TEST_V_KMH_FIXED / 3.6) / 3e8 * FC

    with open(os.path.join(save_folder, "config.txt"), "w", encoding="utf-8") as f:
        f.write(f"DEVICE={DEVICE}\nAMP={USE_AMP}\n")
        f.write(f"DATA_PATH={data_path}\nRECURSIVE_GLOB={RECURSIVE_GLOB}\n")
        f.write(f"MAX_SAMPLES_PER_FILE={MAX_SAMPLES_PER_FILE}\n")
        f.write(f"DMRS_LEN(L)={L}\n")
        f.write(f"num_classes={num_classes}\nsplit_mode={mode}\n")
        f.write(f"BATCH_SIZE={BATCH_SIZE}, LR={LR}, WEIGHT_DECAY={WEIGHT_DECAY}\n")
        f.write(f"MAX_EPOCHS={MAX_EPOCHS}, N_SPLITS={N_SPLITS}\n")
        f.write(f"LR_PATIENCE={LR_PATIENCE}, LR_FACTOR={LR_FACTOR}, ES_PATIENCE={ES_PATIENCE}\n")
        f.write(f"TAU={TAU}, LAMBDA_CL={LAMBDA_CL}, LAMBDA_CE={LAMBDA_CE}\n")
        f.write(f"FS={FS}, FC={FC}\n")
        f.write(f"AUG_USE_MULTIPATH={AUG_USE_MULTIPATH}, RMS_DS_NS_RANGE={RMS_DS_NS_RANGE}, MAX_TAPS={MAX_TAPS}\n")
        f.write(f"AUG_USE_DOPPLER={AUG_USE_DOPPLER}, TRAIN_V_KMH_RANGE={TRAIN_V_KMH_RANGE}\n")
        f.write(f"AUG_USE_AWGN={AUG_USE_AWGN}, TRAIN_SNR_DB_RANGE={TRAIN_SNR_DB_RANGE}\n")
        f.write(f"TEST_V_KMH_FIXED={TEST_V_KMH_FIXED}, fd_test={fd_test}\n")
        f.write(f"TEST_USE_MULTIPATH={TEST_USE_MULTIPATH}, TEST_SNR_LIST={TEST_SNR_LIST}\n")
        f.write(f"SPEC_NFFT={SPEC_NFFT}, SPEC_WIN={SPEC_WIN}, SPEC_HOP={SPEC_HOP}, SPEC_SIZE={SPEC_SIZE}\n")
        f.write(f"workers(train/eval)={NUM_WORKERS_TRAIN}/{NUM_WORKERS_EVAL}\n")

    print(f"[INFO] DEVICE={DEVICE} | AMP={USE_AMP}")
    print(f"[INFO] Classes={num_classes}, TrainSamples={len(train_samples)}, TestSamples={len(test_samples)}, L={L}")
    print(f"[INFO] Train: multipath={AUG_USE_MULTIPATH}, doppler={AUG_USE_DOPPLER}, awgn={AUG_USE_AWGN}")
    print(f"[INFO] Train SNR~U{TRAIN_SNR_DB_RANGE}, Train v~U{TRAIN_V_KMH_RANGE}")
    print(f"[INFO] Test: v={TEST_V_KMH_FIXED} km/h (fd={fd_test:.2f} Hz), SNR sweep={TEST_SNR_LIST}, TEST_USE_MULTIPATH={TEST_USE_MULTIPATH}")
    print(f"[INFO] SaveFolder: {save_folder}")

    snr_to_accs = {snr: [] for snr in TEST_SNR_LIST}

    for fold, tr_samples_fold, va_samples_fold in _iter_fold_samples(
        mode, train_files, train_samples, signals_by_file
    ):
        # A fresh GradScaler per fold: its loss scale / growth tracker must not carry over
        # from the previous fold, otherwise folds are not independent runs.
        scaler = torch.cuda.amp.GradScaler(enabled=USE_AMP)
        _run_one_fold_fast(
            fold, save_folder, num_classes,
            signals_by_file, label_idx_by_file,
            tr_samples_fold, va_samples_fold, test_samples,
            snr_to_accs, scaler
        )

    print("\n========== Overall Test SNR Sweep (mean±std over folds) ==========")
    rows = []
    for snr in TEST_SNR_LIST:
        arr = np.array(snr_to_accs[snr], dtype=np.float64)
        mean = float(arr.mean()) if arr.size else 0.0
        std = float(arr.std()) if arr.size else 0.0
        print(f"SNR {snr:>3} dB | Acc {mean:.2f} ± {std:.2f}")
        rows.append([snr, mean, std] + snr_to_accs[snr])

    csv_path = os.path.join(save_folder, "test_snr_sweep.csv")
    with open(csv_path, "w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        header = ["snr_db", "acc_mean", "acc_std"] + [f"fold{i}" for i in range(1, N_SPLITS + 1)]
        writer.writerow(header)
        writer.writerows(rows)

    print(f"\n[INFO] All saved in: {save_folder}")
    print(f"[INFO] SNR sweep CSV: {csv_path}")
    return save_folder


def _iter_fold_samples(mode, train_files, train_samples, signals_by_file):
    """Yield ``(fold, train_samples_fold, val_samples_fold)`` for N_SPLITS shuffled K-folds.

    In "file" mode the folds are formed over training *files* (all frames of a file stay
    in the same fold); otherwise over individual training frames.
    """
    kf = KFold(n_splits=N_SPLITS, shuffle=True, random_state=SEED)
    if mode == "file":
        files = np.asarray(train_files, dtype=np.int64)
        for fold, (tr_idx, va_idx) in enumerate(kf.split(files), 1):
            tr_fold = [(fi, si) for fi in files[tr_idx].tolist() for si in range(signals_by_file[fi].shape[0])]
            va_fold = [(fi, si) for fi in files[va_idx].tolist() for si in range(signals_by_file[fi].shape[0])]
            yield fold, tr_fold, va_fold
    else:
        for fold, (tr_idx, va_idx) in enumerate(kf.split(np.arange(len(train_samples))), 1):
            yield fold, [train_samples[i] for i in tr_idx], [train_samples[i] for i in va_idx]


def _run_one_fold_fast(
    fold: int,
    save_folder: str,
    num_classes: int,
    signals_by_file,
    label_idx_by_file,
    tr_samples_fold,
    va_samples_fold,
    test_samples,
    snr_to_accs: dict,
    scaler: torch.cuda.amp.GradScaler
):
    """Train one fold, restore the best-val-loss weights, save them, run the test SNR sweep.

    Each training batch of Siamese pairs is concatenated to 2B sequences, augmented and
    converted to spectrograms in one pass on DEVICE, then split back into the two
    branches. Loss = LAMBDA_CL * NT-Xent + LAMBDA_CE * mean CE of both branches, optimized
    with Adam through ``scaler`` (AMP when enabled). ReduceLROnPlateau and early stopping
    both monitor the validation loss.

    Side effects: writes ``model_fold{fold}.pth`` into ``save_folder`` and appends one test
    accuracy per SNR to ``snr_to_accs[snr]``.
    """
    print(f"\n========== Fold {fold}/{N_SPLITS} ==========")

    tr_ds = LTEVSiameseIQDataset(signals_by_file, label_idx_by_file, tr_samples_fold)
    va_ds = LTEVSingleIQDataset(signals_by_file, label_idx_by_file, va_samples_fold)
    te_ds = LTEVSingleIQDataset(signals_by_file, label_idx_by_file, test_samples)

    # Pinned host memory only helps host->GPU copies; on CPU-only runs it is useless.
    pin = DEVICE.type == "cuda"
    # drop_last keeps every contrastive batch full, but with fewer training samples than
    # BATCH_SIZE it would yield zero batches and the model would never be updated.
    drop_last = len(tr_ds) >= BATCH_SIZE
    if not drop_last:
        print(f"[WARN] Fold {fold}: {len(tr_ds)} training samples < BATCH_SIZE={BATCH_SIZE}; "
              f"keeping the partial batch so the model is still trained.")

    tr_loader = DataLoader(
        tr_ds, batch_size=BATCH_SIZE, shuffle=True, drop_last=drop_last,
        num_workers=NUM_WORKERS_TRAIN, pin_memory=pin
    )
    va_loader = DataLoader(
        va_ds, batch_size=BATCH_SIZE, shuffle=False,
        num_workers=NUM_WORKERS_EVAL, pin_memory=pin
    )
    te_loader = DataLoader(
        te_ds, batch_size=BATCH_SIZE, shuffle=False,
        num_workers=NUM_WORKERS_EVAL, pin_memory=pin
    )

    model = SpecFeatureNet(num_classes=num_classes).to(DEVICE)
    opt = torch.optim.Adam(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        opt, mode="min", factor=LR_FACTOR, patience=LR_PATIENCE
    )
    ce = nn.CrossEntropyLoss()

    best_val_loss = float("inf")
    best_state = None
    es_count = 0

    for epoch in range(1, MAX_EPOCHS + 1):
        model.train()
        loss_sum, nb = 0.0, 0

        for iq1, iq2, y in tr_loader:
            iq1 = iq1.to(DEVICE, non_blocking=True)
            iq2 = iq2.to(DEVICE, non_blocking=True)
            y = y.to(DEVICE, non_blocking=True)

            # Concatenate to 2B so augmentation and STFT run once for both branches.
            iq_cat = torch.cat([iq1, iq2], dim=0)   # (2B,L,2)
            iq_cat = augment_train_batch(iq_cat)    # (2B,L,2)
            spec_cat = iq_to_logspec_batch(iq_cat)  # (2B,1,S,S)
            spec1, spec2 = spec_cat.chunk(2, dim=0)

            opt.zero_grad(set_to_none=True)
            with torch.cuda.amp.autocast(enabled=USE_AMP):
                z1, p1, z2, p2 = model(spec1, spec2)
                loss_cl = nt_xent_loss(z1, z2, tau=TAU)  # computed in fp32 internally
                loss_ce = 0.5 * (ce(p1, y) + ce(p2, y))
                loss = LAMBDA_CL * loss_cl + LAMBDA_CE * loss_ce

            scaler.scale(loss).backward()
            scaler.step(opt)
            scaler.update()

            loss_sum += float(loss.item())
            nb += 1

        train_loss = loss_sum / max(nb, 1)
        val_loss, val_acc, _ = eval_single_iq(model, va_loader, num_classes, mode="val")

        prev_lr = opt.param_groups[0]["lr"]
        scheduler.step(val_loss)
        cur_lr = opt.param_groups[0]["lr"]
        if cur_lr < prev_lr:
            print(f"[LR DROP] {prev_lr:.2e} -> {cur_lr:.2e} (val_loss={val_loss:.4f})")

        print(f"Epoch {epoch:03d} | LR={cur_lr:.2e} | TrainLoss={train_loss:.4f} | ValLoss={val_loss:.4f} | ValAcc={val_acc:.2f}%")

        if val_loss < best_val_loss - 1e-6:
            best_val_loss = val_loss
            best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
            es_count = 0
        else:
            es_count += 1
            if es_count >= ES_PATIENCE:
                print("[INFO] Early stopping triggered.")
                break

    if best_state is not None:
        model.load_state_dict(best_state)

    torch.save(model.state_dict(), os.path.join(save_folder, f"model_fold{fold}.pth"))

    # Test SNR sweep: reuse one te_loader; augmentation at each SNR happens inside eval.
    model.eval()
    fold_snr_acc = {}
    for snr in TEST_SNR_LIST:
        _, test_acc, _ = eval_single_iq(model, te_loader, num_classes, mode="test", snr_db=float(snr))
        snr_to_accs[snr].append(test_acc)
        fold_snr_acc[snr] = test_acc

    print("[FOLD TEST] Acc@SNR:", {snr: f"{fold_snr_acc[snr]:.2f}%" for snr in TEST_SNR_LIST})


# ----------------------------
# 9) main
# ----------------------------
if __name__ == "__main__":
    # Entry point: K-fold training plus the test SNR sweep on DATA_PATH.
    train_kfold_ltev_spec_siamese_fast(DATA_PATH)


In [None]:
# ============================================================
# LTE-V XFR (PER-FILE sequential blocks) + Block-level split (no leakage)
# + Spec Siamese (NT-Xent) + CE head + FAST GPU STFT + AMP
#
# XFR block construction YOU WANT:
#   - For each FILE:
#       take sequential group_size=m samples -> (m, L, 2)
#       drop remainder
#       XFR "flip"/transpose -> (L, m, 2)  (same as your reference)
#
# Split:
#   - block-level train/val/test (no block leakage)
#   - optional STRICT balanced test blocks per class
#
# Positive pair:
#   - "same_block": two different XFR-samples (two rows) within SAME block
#   - "simclr": same XFR-sample, two random augmented views
# ============================================================

import os
import glob
import csv
import time
import random
import numpy as np
from datetime import datetime
from collections import Counter, defaultdict
from contextlib import nullcontext

import h5py
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from sklearn.model_selection import train_test_split, StratifiedKFold

# ----------------------------
# 0) Config
# ----------------------------
SEED = 42
def seed_everything(seed=SEED):
    """Make runs repeatable by seeding Python, NumPy and PyTorch (CPU and all CUDA devices)."""
    for seed_fn in (random.seed, np.random.seed):
        seed_fn(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

seed_everything(SEED)

# Use the default CUDA device when one is available, otherwise fall back to CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Let cuDNN benchmark and pick conv algorithms for the (fixed-size) spectrogram
# inputs; faster, but may select non-deterministic kernels.
torch.backends.cudnn.benchmark = True
# Best-effort speed-ups: enable TF32 for matmul/conv; silently ignored if this
# PyTorch build lacks the attributes.
try:
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
except Exception:
    pass
try:
    torch.set_float32_matmul_precision("high")
except Exception:
    pass

# Mixed precision is only used on CUDA.
USE_AMP = (DEVICE.type == "cuda")

# ---- paths ----
DATA_PATH = "E:/rf_datasets/"   # <-- change this
RECURSIVE_GLOB = True  # also search sub-directories of DATA_PATH for .mat files

# ---- XFR (block) ----
GROUP_SIZE = 288     # m: consecutive samples per XFR block (m = L = 288 for a fair comparison)
TEST_SIZE = 0.25     # fraction of (per-class) blocks reserved for the test set
STRICT_TEST_BALANCE = True  # force exactly the same number of test blocks for every class

# ---- training ----
BATCH_SIZE = 256
LR = 3e-4
WEIGHT_DECAY = 0.0
MAX_EPOCHS = 200
N_SPLITS = 5

LR_PATIENCE = 10   # ReduceLROnPlateau patience (epochs without val-loss improvement)
LR_FACTOR = 0.5    # LR multiplier applied on plateau
ES_PATIENCE = 30   # early stop after this many epochs without val-loss improvement

# Positive pair mode: "same_block" or "simclr"
POS_PAIR_MODE = "same_block"

# ---- Contrastive loss ----
TAU = 0.05        # NT-Xent temperature
LAMBDA_CL = 1.0   # weight of the contrastive term
LAMBDA_CE = 1.0  # set 0.0 to disable CE

# ---- augmentation ----
AUG_USE_MULTIPATH = True
AUG_USE_DOPPLER   = True
AUG_USE_AWGN      = True

FS = 5e6      # sample rate in Hz (Doppler phase step and ns -> samples conversion)
FC = 5.9e9    # carrier frequency in Hz (Doppler: fd = v / c * FC)
RMS_DS_NS_RANGE = (5.0, 300.0)   # RMS delay spread range in ns, drawn uniformly per example
MAX_TAPS = 16                    # length of the multipath tap kernel

# Equal bounds => every training example uses exactly 120 km/h.
# NOTE(review): the first pipeline's header describes v ~ U[0, 120] km/h; confirm
# that a fixed training speed is intended here.
TRAIN_V_KMH_RANGE = (120.0, 120.0)
TRAIN_SNR_DB_RANGE = (-40.0, 20.0)

TEST_V_KMH_FIXED = 120.0
TEST_USE_MULTIPATH = False
TEST_RMS_DS_NS_RANGE = RMS_DS_NS_RANGE
TEST_SNR_LIST = list(range(20, -45, -5))  # 20, 15, ..., -40 dB

# ---- spectrogram ----
SPEC_NFFT = 128   # FFT size
SPEC_WIN  = 128   # Hann window length
SPEC_HOP  = 16    # hop between STFT frames
SPEC_SIZE = 64    # spectrogram is resized to SPEC_SIZE x SPEC_SIZE

# ---- dataloader ----
NUM_WORKERS_TRAIN = 0
NUM_WORKERS_EVAL  = 0

# ---- saving ----
SAVE_ROOT = "./training_results"
os.makedirs(SAVE_ROOT, exist_ok=True)  # created at import time
SCRIPT_NAME = "LTEV_XFR_PerFileSeqBlocks_SpecSiamese"

# ----------------------------
# AMP compatibility (no device_type bug)
# ----------------------------
def amp_autocast(enabled: bool = True, device_type=None):
    """Return an autocast context manager for ``device_type`` (default: ``DEVICE.type``).

    ``enabled=True`` turns mixed precision on. ``enabled=False`` returns an
    autocast context with ``enabled=False`` rather than a no-op. This matters
    when it is nested inside an enabled region (e.g. ``nt_xent_loss`` wraps its
    math in ``amp_autocast(enabled=False)`` while the training step runs under
    ``amp_autocast(enabled=USE_AMP)``): only a disabled autocast actually
    switches autocast off, so ops inside run in their input precision. At top
    level a disabled autocast behaves like a no-op.

    Falls back to ``torch.cuda.amp.autocast`` on old PyTorch builds without
    ``torch.amp.autocast``, and to ``nullcontext()`` when no autocast API applies.

    Args:
        enabled: whether mixed precision should be active inside the context.
        device_type: autocast device type ("cuda", "cpu", ...); ``None`` uses DEVICE.type.

    Raises:
        RuntimeError: if ``enabled`` is True and autocast does not support the device type.
    """
    dev = DEVICE.type if device_type is None else str(device_type)
    flag = bool(enabled)
    if hasattr(torch, "amp") and hasattr(torch.amp, "autocast"):
        try:
            return torch.amp.autocast(device_type=dev, enabled=flag)
        except TypeError:
            # some older signatures only accept the device type positionally
            return torch.amp.autocast(dev, enabled=flag)
        except RuntimeError:
            # unsupported autocast device: nothing to disable, but an explicit
            # request to enable must still fail loudly
            if flag:
                raise
            return nullcontext()
    if dev == "cuda":
        return torch.cuda.amp.autocast(enabled=flag)
    return nullcontext()

def make_grad_scaler(enabled: bool = True):
    """Build a CUDA GradScaler for AMP training, or return None when AMP is off.

    Returns None if ``enabled`` is False or DEVICE is not CUDA. Otherwise it
    tries the ``torch.amp.GradScaler`` constructor signatures (newest first)
    and falls back to the legacy ``torch.cuda.amp.GradScaler``.
    """
    if not enabled or DEVICE.type != "cuda":
        return None
    if hasattr(torch, "amp") and hasattr(torch.amp, "GradScaler"):
        try:
            return torch.amp.GradScaler("cuda", enabled=True)
        except TypeError:
            pass
        try:
            return torch.amp.GradScaler(enabled=True)
        except TypeError:
            pass
    return torch.cuda.amp.GradScaler(enabled=True)

# ----------------------------
# Logging helpers
# ----------------------------
class CSVLogger:
    """Minimal CSV logger: writes ``header`` on construction, appends rows on demand.

    The file at ``path`` is created (or truncated) when the logger is built, so
    each instance starts a fresh log. Every ``log_row`` call re-opens the file in
    append mode, so rows already written stay on disk if the run crashes.
    """
    def __init__(self, path, header):
        self.path = path
        self.header = header
        self._init_file()

    def _init_file(self):
        parent = os.path.dirname(self.path)
        # os.makedirs("") raises FileNotFoundError, so only create a real parent
        # directory (a bare file name such as "log.csv" needs none).
        if parent:
            os.makedirs(parent, exist_ok=True)
        with open(self.path, "w", newline="", encoding="utf-8") as f:
            csv.writer(f).writerow(self.header)

    def log_row(self, row):
        """Append one row (an iterable of values) to the CSV file."""
        with open(self.path, "a", newline="", encoding="utf-8") as f:
            csv.writer(f).writerow(row)

def write_line(path, line):
    """Append ``line`` to the text file at ``path``, trailing whitespace replaced by one newline."""
    with open(path, "a", encoding="utf-8") as handle:
        handle.write(f"{line.rstrip()}\n")

# ----------------------------
# 1) HDF5 reader
# ----------------------------
def decode_txid(txid_arr: np.ndarray) -> str:
    """Decode an array of character codes into a transmitter-ID string.

    Zero codes are skipped, codes that ``chr`` rejects are dropped, and
    surrounding whitespace is stripped. Returns "UNKNOWN" if nothing remains.
    """
    decoded = []
    for code in (int(v) for v in txid_arr.flatten()):
        if code == 0:
            continue
        try:
            decoded.append(chr(code))
        except Exception:
            pass
    text = "".join(decoded).strip()
    return text or "UNKNOWN"

def load_dmrs_complex_from_file(h5f: h5py.File) -> np.ndarray:
    """Read ``rfDataset/dmrs`` from an open HDF5 file as a 2-D complex64 array.

    Accepted layouts for ``dmrs``:
      * a compound dataset with "real"/"imag" fields,
      * a group holding "real" and "imag" datasets,
      * any other object whose ``[()]`` value is a structured array with those fields.

    If the first axis is <= 2048 and shorter than the second, the array is
    transposed. NOTE(review): this is a shape heuristic for files stored as
    (L, N); a file with fewer than L rows would be transposed wrongly.

    Raises:
        RuntimeError: if ``dmrs`` cannot be parsed or is not 2-D.
    """
    node = h5f["rfDataset"]["dmrs"]

    compound_dataset = isinstance(node, h5py.Dataset) and node.dtype.fields is not None
    if compound_dataset or isinstance(node, h5py.Group):
        re_part = node["real"][()]
        im_part = node["imag"][()]
    else:
        raw = node[()]
        if not (hasattr(raw, "dtype") and raw.dtype.fields is not None):
            raise RuntimeError("Cannot parse dmrs (not compound dataset nor group real/imag).")
        re_part = raw["real"]
        im_part = raw["imag"]

    cplx = np.asarray(re_part + 1j * im_part)
    if cplx.ndim != 2:
        raise RuntimeError(f"dmrs_complex dim error: {cplx.shape}")

    rows, cols = cplx.shape
    if rows <= 2048 and cols > rows:
        cplx = cplx.T

    return cplx.astype(np.complex64)

def power_normalize_rows(x: np.ndarray, eps=1e-12) -> np.ndarray:
    """Scale each row of ``x`` to unit mean power; ``eps`` guards all-zero rows."""
    row_rms = np.sqrt(np.mean(np.abs(x) ** 2, axis=1, keepdims=True))
    return x / (row_rms + eps)

def load_ltev_files_as_iq(data_path: str, recursive: bool = True, max_samples_per_file=None):
    """Load every LTE-V ``.mat`` (HDF5) file under ``data_path`` as a per-file IQ array.

    For each file, the txID is decoded and the DMRS rows are read, optionally
    truncated to ``max_samples_per_file``, power-normalized per row and split
    into real/imag channels. Files that fail to read or parse are skipped with a
    warning; files with zero rows are skipped silently. Finally, every array is
    cropped or zero-padded along the DMRS-length axis to the most common length.

    Args:
        data_path: root directory to search.
        recursive: also search sub-directories ("**") when True.
        max_samples_per_file: optional cap on the number of rows read per file.

    Returns:
        (X_files, y_files, file_paths, mode_L):
          X_files    list of float32 arrays, each (N_i, mode_L, 2)
          y_files    decoded txID string for each loaded file
          file_paths path of each loaded file (aligned with X_files / y_files)
          mode_L     common DMRS length after unification

    Raises:
        RuntimeError: if no .mat file is found or none could be loaded.
    """
    if recursive:
        pattern = os.path.join(data_path, "**", "*.mat")
    else:
        pattern = os.path.join(data_path, "*.mat")
    # glob() yields files in filesystem order, which varies across OSes/disks.
    # Sorting fixes file order -> block order -> the seeded split, so runs are
    # reproducible across machines.
    mat_files = sorted(glob.glob(pattern, recursive=bool(recursive)))

    if not mat_files:
        raise RuntimeError(f"No .mat files found: {data_path}")

    X_files = []
    y_files = []
    file_paths = []

    print(f"[INFO] Found {len(mat_files)} .mat files")
    for fp in mat_files:
        try:
            with h5py.File(fp, "r") as f:
                rfDataset = f["rfDataset"]
                txid_arr = np.asarray(rfDataset["txID"][()])
                tx_str = decode_txid(txid_arr)

                dmrs_complex = load_dmrs_complex_from_file(f)
                if max_samples_per_file is not None:
                    dmrs_complex = dmrs_complex[: int(max_samples_per_file)]

                if dmrs_complex.shape[0] == 0:
                    continue

                dmrs_complex = power_normalize_rows(dmrs_complex)

                iq = np.stack([dmrs_complex.real, dmrs_complex.imag], axis=-1).astype(np.float32)  # (N,L,2)
                X_files.append(iq)
                y_files.append(tx_str)
                file_paths.append(fp)
        except Exception as e:
            # Deliberate best-effort: one corrupt file should not abort the whole load.
            print(f"[WARN] Skip file due to read/parse error: {fp} | {repr(e)}")

    if not X_files:
        raise RuntimeError("All files failed to read or are empty.")

    cnt = Counter(y_files)
    print("[INFO] txID classes:", len(cnt))
    for k, v in sorted(cnt.items(), key=lambda x: (-x[1], x[0])):
        print(f"  {k}: {v} files")

    Ls = [arr.shape[1] for arr in X_files]
    mode_L = Counter(Ls).most_common(1)[0][0]
    print(f"[INFO] DMRS length stats: min={min(Ls)}, max={max(Ls)}, mode={mode_L}")

    # Unify the DMRS length to mode_L: crop longer rows, zero-pad shorter ones.
    for i, x in enumerate(X_files):  # x: (N, L, 2)
        L = x.shape[1]
        if L > mode_L:
            X_files[i] = x[:, :mode_L, :].astype(np.float32)
        elif L < mode_L:
            pad = mode_L - L
            X_files[i] = np.pad(x, ((0, 0), (0, pad), (0, 0)), mode="constant").astype(np.float32)

    return X_files, y_files, file_paths, mode_L

# ----------------------------
# 2) XFR blocks: PER-FILE sequential grouping (NO cross-file)
# ----------------------------
def build_xfr_blocks_per_file_sequential(X_files, y_files, group_size: int):
    """Cut each file into consecutive non-overlapping groups and XFR-transpose them.

    Within a file, rows [b*m, (b+1)*m) form block b (m = ``group_size``);
    leftover rows are dropped and blocks never mix files. Each (m, L, 2) group
    is transposed to (L, m, 2).

    Returns:
      X_blocks: float32 array (num_blocks, L, m, 2)
      y_blocks: int64 array (num_blocks,), class index of the source file
      label_to_idx: dict txID string -> class index (sorted txID order)
      block_meta: list of (file_index, block_in_file) per block

    Raises:
      RuntimeError: if no file has at least ``group_size`` rows.
    """
    label_to_idx = {name: k for k, name in enumerate(sorted(set(y_files)))}

    per_file_blocks = []
    labels = []
    block_meta = []

    for file_idx, arr in enumerate(X_files):
        n_rows, n_len, n_ch = arr.shape
        n_blocks = n_rows // group_size
        if n_blocks <= 0:
            continue
        label = int(label_to_idx[y_files[file_idx]])
        # (nb*m, L, C) -> (nb, m, L, C) -> XFR flip -> (nb, L, m, C)
        grouped = arr[: n_blocks * group_size].reshape(n_blocks, group_size, n_len, n_ch)
        per_file_blocks.append(grouped.transpose(0, 2, 1, 3).astype(np.float32))
        labels.extend([label] * n_blocks)
        block_meta.extend((file_idx, b) for b in range(n_blocks))

    if not per_file_blocks:
        raise RuntimeError("No blocks generated. Check GROUP_SIZE and data.")

    X_blocks = np.concatenate(per_file_blocks, axis=0)  # (B, L, m, 2), C-contiguous
    y_blocks = np.array(labels, dtype=np.int64)

    print(f"[INFO] XFR blocks(per-file) generated: num_blocks={X_blocks.shape[0]}, "
          f"L={X_blocks.shape[1]}, m={X_blocks.shape[2]}")
    print("[INFO] Block counts per class:", dict(sorted(Counter(y_blocks.tolist()).items())))
    return X_blocks, y_blocks, label_to_idx, block_meta

# ----------------------------
# 3) Block split (optional strict balanced test)
# ----------------------------
def balanced_block_split(y_blocks: np.ndarray, test_size=0.25, seed=SEED, strict_balance=True):
    """Split block indices into (kept, trainval, test).

    Non-strict mode: a stratified ``train_test_split`` over all blocks; ``kept``
    is every block index.

    Strict mode: each class is shuffled (seeded) and trimmed to the smallest
    per-class count Bmin. The first ``test_k = clamp(floor(Bmin*test_size), 1, Bmin-1)``
    blocks of each class go to test and the rest to trainval, so every class
    has exactly ``test_k`` test blocks.

    Raises:
        RuntimeError: in strict mode, if some class has fewer than 2 blocks.
    """
    rng = np.random.default_rng(seed)
    by_class = defaultdict(list)
    for pos, lab in enumerate(y_blocks.tolist()):
        by_class[int(lab)].append(pos)

    labels = sorted(by_class)
    per_class = {lab: len(by_class[lab]) for lab in labels}
    print("[INFO] blocks per class (raw):", dict(sorted(per_class.items())))

    if not strict_balance:
        all_idx = np.arange(len(y_blocks), dtype=np.int64)
        tr, te = train_test_split(all_idx, test_size=test_size, stratify=y_blocks, random_state=seed)
        return all_idx, tr, te

    b_min = min(per_class.values())
    if b_min < 2:
        raise RuntimeError(f"Not enough blocks per class to split: Bmin={b_min}")

    # keep at least one block on each side of the split
    n_test = max(1, min(b_min - 1, int(np.floor(b_min * test_size))))

    kept_list, trainval_list, test_list = [], [], []
    for lab in labels:
        order = np.array(by_class[lab], dtype=np.int64)
        rng.shuffle(order)
        order = order[:b_min]  # trim each class to Bmin
        kept_list += order.tolist()
        test_list += order[:n_test].tolist()
        trainval_list += order[n_test:].tolist()

    kept = np.array(kept_list, dtype=np.int64)
    trainval = np.array(trainval_list, dtype=np.int64)
    test = np.array(test_list, dtype=np.int64)

    print(f"[INFO] STRICT balance: Bmin={b_min}, test_k={n_test}")
    print("[INFO] blocks/class trainval:", dict(sorted(Counter(y_blocks[trainval].tolist()).items())))
    print("[INFO] blocks/class test   :", dict(sorted(Counter(y_blocks[test].tolist()).items())))
    return kept, trainval, test

def check_block_overlap(train_idx, val_idx, test_idx):
    """Raise RuntimeError if any block index appears in more than one of the three splits."""
    def _as_set(ix):
        return set(ix.tolist()) if hasattr(ix, "tolist") else set(ix)

    tr, va, te = (_as_set(ix) for ix in (train_idx, val_idx, test_idx))
    if tr & va or tr & te or va & te:
        raise RuntimeError("[ERROR] Block overlap detected among train/val/test.")
    print("[INFO] Block overlap check passed.")

# ----------------------------
# 4) Datasets (XFR blocks)
# ----------------------------
class XFRTrainPairDataset(Dataset):
    """Paired training view over XFR blocks.

    ``X_blocks`` has shape (num_blocks, L, m, 2). Item ``idx`` maps to a
    (block_id, row) pair over the selected ``block_indices``; the anchor is
    ``X_blocks[block_id, row]`` with shape (m, 2).

    Positive partner:
      * "same_block": a different, uniformly random row of the same block
        (the same row only when L == 1);
      * anything else (e.g. "simclr"): a copy of the anchor, to be augmented twice.

    Returns (anchor float32, partner float32, label np.int64).
    """
    def __init__(self, X_blocks, y_blocks, block_indices, pos_pair_mode="same_block"):
        self.X = X_blocks
        self.y = y_blocks
        self.block_indices = np.array(block_indices, dtype=np.int64)
        self.L = int(X_blocks.shape[1])
        self.pos_mode = pos_pair_mode
        self.sample_list = [
            (int(block), int(row))
            for block in self.block_indices.tolist()
            for row in range(self.L)
        ]

    def __len__(self):
        return len(self.sample_list)

    def __getitem__(self, idx):
        block, row_a = self.sample_list[idx]
        label = int(self.y[block])
        anchor = self.X[block, row_a]  # (m,2)

        if self.pos_mode != "same_block":
            partner = anchor.copy()
        else:
            row_b = random.randrange(self.L)
            while row_b == row_a and self.L > 1:
                row_b = random.randrange(self.L)
            partner = self.X[block, row_b]  # (m,2)

        return anchor.astype(np.float32), partner.astype(np.float32), np.int64(label)

class XFRSingleDataset(Dataset):
    """Single-view dataset over XFR rows of the selected blocks.

    Item ``idx`` maps to (block_id, row) and yields
    (``X_blocks[block_id, row]`` as float32, label as np.int64).
    """
    def __init__(self, X_blocks, y_blocks, block_indices):
        self.X = X_blocks
        self.y = y_blocks
        self.block_indices = np.array(block_indices, dtype=np.int64)
        self.L = int(X_blocks.shape[1])
        self.sample_list = [
            (int(block), int(row))
            for block in self.block_indices.tolist()
            for row in range(self.L)
        ]

    def __len__(self):
        return len(self.sample_list)

    def __getitem__(self, idx):
        block, row = self.sample_list[idx]
        label = int(self.y[block])
        return self.X[block, row].astype(np.float32), np.int64(label)

# ----------------------------
# 5) GPU augmentations
# ----------------------------
def _to_complex(iq_b: torch.Tensor) -> torch.Tensor:
    return iq_b[..., 0].to(torch.float32) + 1j * iq_b[..., 1].to(torch.float32)

def _from_complex(sig: torch.Tensor) -> torch.Tensor:
    return torch.stack([sig.real, sig.imag], dim=-1)

def batch_normalize_power(iq_b: torch.Tensor) -> torch.Tensor:
    """Scale each sequence in a (B, L, 2) IQ batch to unit mean power over axis 1.

    A 1e-12 floor keeps all-zero sequences finite.
    """
    re_ch, im_ch = iq_b[..., 0], iq_b[..., 1]
    avg_pow = (re_ch ** 2 + im_ch ** 2).mean(dim=1, keepdim=True) + 1e-12
    return iq_b * torch.rsqrt(avg_pow).unsqueeze(-1)

def batch_apply_doppler(iq_b: torch.Tensor, v_kmh: torch.Tensor) -> torch.Tensor:
    """Apply a per-example Doppler frequency shift to a (B, L, 2) IQ batch.

    For each example, fd = (v / c) * FC with v converted from km/h to m/s and
    c = 3e8 m/s. Sample n is multiplied by exp(j*2*pi*fd*n/FS): a pure
    frequency offset with zero initial phase.
    """
    _, n_len, _ = iq_b.shape
    speed_mps = v_kmh / 3.6
    f_dopp = (speed_mps / 3e8) * FC
    t_idx = torch.arange(n_len, device=iq_b.device, dtype=torch.float32)[None, :]
    rotator = torch.exp(1j * 2.0 * np.pi * f_dopp.unsqueeze(1).to(torch.float32) * t_idx / FS)
    return _from_complex(_to_complex(iq_b) * rotator)

def _grouped_conv1d_real(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    B, _, L = x.shape
    x2 = x.permute(1, 0, 2).contiguous()
    y2 = F.conv1d(x2, w, padding=w.shape[-1] - 1, groups=B)
    return y2.squeeze(0)

def batch_apply_multipath(iq_b: torch.Tensor, rms_ns: torch.Tensor, max_taps: int = MAX_TAPS) -> torch.Tensor:
    """Filter a (B, L, 2) IQ batch through random per-example multipath channels.

    For each example, an exponential power profile p[k] ~ exp(-k / tau) over
    ``max_taps`` taps is built. tau is the RMS delay spread converted from ns
    to samples (clamped to >= 1e-3), and the profile is normalized to sum 1.
    Complex Gaussian tap gains with variance p[k] are drawn and rescaled to
    unit total energy. The channel is applied with ``_grouped_conv1d_real`` as
    four real convolutions, and the output is truncated to the input length L.
    """
    n_batch, n_len, _ = iq_b.shape
    dev = iq_b.device

    tau = (rms_ns * 1e-9 * FS).clamp(min=1e-3)

    taps = torch.arange(max_taps, device=dev, dtype=torch.float32)[None, :]
    pdp = torch.exp(-taps / tau[:, None])
    pdp = pdp / (pdp.sum(dim=1, keepdim=True) + 1e-12)

    tap_std = torch.sqrt(pdp / 2.0)
    h_re = torch.randn(n_batch, max_taps, device=dev) * tap_std
    h_im = torch.randn(n_batch, max_taps, device=dev) * tap_std

    # normalize each channel realization to unit energy
    inv_norm = torch.rsqrt((h_re**2 + h_im**2).sum(dim=1, keepdim=True) + 1e-12)
    h_re = h_re * inv_norm
    h_im = h_im * inv_norm

    x_re = iq_b[..., 0].unsqueeze(1)
    x_im = iq_b[..., 1].unsqueeze(1)
    k_re = h_re.unsqueeze(1)
    k_im = h_im.unsqueeze(1)

    # (a + jb) * (c + jd) = (ac - bd) + j(ad + bc)
    re_re = _grouped_conv1d_real(x_re, k_re)
    im_im = _grouped_conv1d_real(x_im, k_im)
    re_im = _grouped_conv1d_real(x_re, k_im)
    im_re = _grouped_conv1d_real(x_im, k_re)

    out_re = (re_re - im_im)[:, :n_len]
    out_im = (re_im + im_re)[:, :n_len]
    return torch.stack([out_re, out_im], dim=-1)

def batch_add_awgn(iq_b: torch.Tensor, snr_db: torch.Tensor) -> torch.Tensor:
    """Add complex white Gaussian noise to a (B, L, 2) IQ batch at a per-example SNR.

    Noise power is the measured mean signal power of each example divided by
    10**(snr_db/10), split equally between the I and Q components.
    """
    n_batch, n_len, _ = iq_b.shape
    clean = _to_complex(iq_b)
    sig_pow = (clean.real**2 + clean.imag**2).mean(dim=1) + 1e-12
    noise_pow = sig_pow / (10.0 ** (snr_db / 10.0))
    sigma = torch.sqrt(noise_pow / 2.0).unsqueeze(1)
    dev = iq_b.device
    w_re = torch.randn(n_batch, n_len, device=dev)
    w_im = torch.randn(n_batch, n_len, device=dev)
    return _from_complex(clean + sigma * (w_re + 1j * w_im))

def augment_train_batch(iq_b: torch.Tensor) -> torch.Tensor:
    """Randomly corrupt a training batch with the enabled channel impairments.

    Stages, each drawing independent parameters per example:
      1. unit-power normalization (always);
      2. multipath, RMS delay spread ~ U[RMS_DS_NS_RANGE] ns (if AUG_USE_MULTIPATH);
      3. Doppler, speed ~ U[TRAIN_V_KMH_RANGE] km/h, constant when both bounds
         are equal (if AUG_USE_DOPPLER);
      4. AWGN, SNR ~ U[TRAIN_SNR_DB_RANGE] dB (if AUG_USE_AWGN).
    Random draws happen in this stage order.
    """
    out = batch_normalize_power(iq_b)
    n_batch = out.shape[0]
    dev = out.device

    if AUG_USE_MULTIPATH:
        span = RMS_DS_NS_RANGE[1] - RMS_DS_NS_RANGE[0]
        rms = span * torch.rand(n_batch, device=dev) + RMS_DS_NS_RANGE[0]
        out = batch_apply_multipath(out, rms, max_taps=MAX_TAPS)

    if AUG_USE_DOPPLER:
        v_lo, v_hi = TRAIN_V_KMH_RANGE
        if abs(v_hi - v_lo) < 1e-12:
            speeds = torch.full((n_batch,), float(v_lo), device=dev)
        else:
            speeds = (v_hi - v_lo) * torch.rand(n_batch, device=dev) + v_lo
        out = batch_apply_doppler(out, speeds)

    if AUG_USE_AWGN:
        s_lo, s_hi = TRAIN_SNR_DB_RANGE
        snrs = (s_hi - s_lo) * torch.rand(n_batch, device=dev) + s_lo
        out = batch_add_awgn(out, snrs)

    return out

def augment_test_batch(iq_b: torch.Tensor, snr_db: float) -> torch.Tensor:
    """Apply the test-time channel: power normalization, optional multipath, fixed Doppler, AWGN.

    Multipath (RMS delay spread ~ U[TEST_RMS_DS_NS_RANGE] ns) is applied only
    when TEST_USE_MULTIPATH is set. Doppler always uses TEST_V_KMH_FIXED, and
    every example gets AWGN at exactly ``snr_db``.
    """
    out = batch_normalize_power(iq_b)
    n_batch = out.shape[0]
    dev = out.device

    if TEST_USE_MULTIPATH:
        span = TEST_RMS_DS_NS_RANGE[1] - TEST_RMS_DS_NS_RANGE[0]
        rms = span * torch.rand(n_batch, device=dev) + TEST_RMS_DS_NS_RANGE[0]
        out = batch_apply_multipath(out, rms, max_taps=MAX_TAPS)

    speeds = torch.full((n_batch,), float(TEST_V_KMH_FIXED), device=dev)
    out = batch_apply_doppler(out, speeds)

    snrs = torch.full((n_batch,), float(snr_db), device=dev)
    return batch_add_awgn(out, snrs)

# ----------------------------
# 6) GPU STFT
# ----------------------------
_WINDOW_CACHE = {}
def get_hann_window(device: torch.device):
    """Return a periodic Hann window of length SPEC_WIN on ``device``, cached per device/length."""
    cache_key = (device.type, device.index, SPEC_WIN)
    window = _WINDOW_CACHE.get(cache_key)
    if window is None:
        window = torch.hann_window(SPEC_WIN, periodic=True, device=device)
        _WINDOW_CACHE[cache_key] = window
    return window

def iq_to_logspec_batch(iq_b: torch.Tensor) -> torch.Tensor:
    """Turn a (B, L, 2) IQ batch into standardized log-magnitude spectrograms.

    Runs a Hann-windowed STFT (SPEC_NFFT / SPEC_WIN / SPEC_HOP, centered) on the
    complex signal. Complex input gives a two-sided spectrum by default. The
    log-magnitude is standardized per example (zero mean, unit std over
    frequency and time), non-finite values are zeroed, and the result is
    bilinearly resized to (B, 1, SPEC_SIZE, SPEC_SIZE).
    """
    spec = torch.stft(
        _to_complex(iq_b).to(torch.complex64),
        n_fft=SPEC_NFFT,
        hop_length=SPEC_HOP,
        win_length=SPEC_WIN,
        window=get_hann_window(iq_b.device),
        center=True,
        return_complex=True,
    )  # (B, F, T)

    log_mag = torch.log(torch.abs(spec) + 1e-12)

    centre = log_mag.mean(dim=(1, 2), keepdim=True)
    spread = log_mag.std(dim=(1, 2), keepdim=True) + 1e-6
    log_mag = (log_mag - centre) / spread
    try:
        log_mag = torch.nan_to_num(log_mag, nan=0.0, posinf=0.0, neginf=0.0)
    except Exception:
        log_mag[~torch.isfinite(log_mag)] = 0.0

    return F.interpolate(
        log_mag.unsqueeze(1),  # (B, 1, F, T)
        size=(SPEC_SIZE, SPEC_SIZE),
        mode="bilinear",
        align_corners=False,
    )

# ----------------------------
# 7) Model
# ----------------------------
class BasicBlock2D(nn.Module):
    """ResNet-style basic block: two 3x3 conv+BN layers with a residual shortcut.

    The shortcut is a 1x1 conv + BN projection when ``stride != 1`` or the
    channel count changes, and the identity otherwise. The output is
    ReLU(main + shortcut).
    """
    def __init__(self, in_ch, out_ch, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, stride=stride, padding=1, bias=False)
        self.bn1   = nn.BatchNorm2d(out_ch)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, stride=1, padding=1, bias=False)
        self.bn2   = nn.BatchNorm2d(out_ch)
        self.relu  = nn.ReLU(inplace=True)

        needs_projection = stride != 1 or in_ch != out_ch
        self.down = (
            nn.Sequential(
                nn.Conv2d(in_ch, out_ch, 1, stride=stride, bias=False),
                nn.BatchNorm2d(out_ch),
            )
            if needs_projection
            else None
        )

    def forward(self, x):
        main = self.bn2(self.conv2(self.relu(self.bn1(self.conv1(x)))))
        shortcut = x if self.down is None else self.down(x)
        return self.relu(main + shortcut)

class SpecFeatureNet(nn.Module):
    """Small ResNet encoder over 1-channel spectrograms with an embedding and a classifier head.

    Stem: 7x7 stride-2 conv + BN + ReLU (32 channels). Then four stride-1
    residual blocks (32->32->32->64->64), global average pooling, and
    FC 64->512 (ReLU) -> FC 512->256, which gives the embedding ``z``.
    A linear layer maps ``z`` to class logits.

    ``forward(x1)`` returns (z1, logits1). ``forward(x1, x2)`` runs both branches
    with shared weights and returns (z1, logits1, z2, logits2).
    """
    def __init__(self, num_classes: int):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1   = nn.BatchNorm2d(32)
        self.relu  = nn.ReLU(inplace=True)

        self.b1 = BasicBlock2D(32, 32, stride=1)
        self.b2 = BasicBlock2D(32, 32, stride=1)
        self.b3 = BasicBlock2D(32, 64, stride=1)
        self.b4 = BasicBlock2D(64, 64, stride=1)

        self.gap = nn.AdaptiveAvgPool2d((1, 1))
        self.fc1 = nn.Linear(64, 512)
        self.fc2 = nn.Linear(512, 256)
        self.cls = nn.Linear(256, num_classes)

    def forward_once(self, x):
        h = self.relu(self.bn1(self.conv1(x)))
        for block in (self.b1, self.b2, self.b3, self.b4):
            h = block(h)
        h = torch.flatten(self.gap(h), 1)  # (B, 64)
        emb = self.fc2(F.relu(self.fc1(h)))
        return emb, self.cls(emb)

    def forward(self, x1, x2=None):
        z1, p1 = self.forward_once(x1)
        if x2 is None:
            return z1, p1
        z2, p2 = self.forward_once(x2)
        return z1, p1, z2, p2

# ----------------------------
# 8) Loss + Eval
# ----------------------------
def nt_xent_loss(z1: torch.Tensor, z2: torch.Tensor, tau: float = TAU) -> torch.Tensor:
    """NT-Xent (normalized temperature-scaled cross-entropy) contrastive loss.

    Rows i of ``z1`` and ``z2`` are positives. Every other row of the
    concatenated 2N batch is a negative. Embeddings are cast to float32 and
    L2-normalized. Cosine similarities are divided by ``tau``, and
    self-similarity is masked with the dtype's minimum value. The loss is the
    mean negative log-probability of the positive partner over all 2N anchors.

    NOTE(review): the body runs inside ``amp_autocast(enabled=False)``; whether
    that really disables an enclosing autocast region depends on amp_autocast's
    implementation. Confirm that the similarity matmul runs in fp32.
    """
    with amp_autocast(enabled=False):
        n = z1.size(0)
        feats = F.normalize(torch.cat([z1.float(), z2.float()], dim=0), dim=1)
        logits = (feats @ feats.T) / float(tau)
        # an anchor must never be scored against itself
        logits.fill_diagonal_(torch.finfo(logits.dtype).min)
        rows = torch.arange(2 * n, device=feats.device)
        partners = (rows + n) % (2 * n)
        log_prob = logits - torch.logsumexp(logits, dim=1, keepdim=True)
        return (-log_prob[rows, partners]).mean()

@torch.no_grad()
def eval_single_iq(model: SpecFeatureNet, loader: DataLoader, num_classes: int, mode: str, snr_db: float = None):
    """Evaluate ``model`` on single-view (iq, y) batches; return (mean batch CE loss, accuracy %).

    With ``mode == "test"`` each batch first goes through ``augment_test_batch``
    at ``snr_db``. Any other mode (e.g. "val") evaluates the un-augmented IQ.
    ``num_classes`` is accepted for interface symmetry but not used.
    """
    model.eval()
    criterion = nn.CrossEntropyLoss()
    n_seen = n_hit = 0
    running_loss, n_batches = 0.0, 0

    for iq, y in loader:
        iq = iq.to(DEVICE, non_blocking=True)
        y = y.to(DEVICE, non_blocking=True)

        if mode == "test":
            iq = augment_test_batch(iq, snr_db=float(snr_db))

        _, logits = model(iq_to_logspec_batch(iq), None)
        running_loss += float(criterion(logits, y).item())
        n_batches += 1
        n_hit += (logits.argmax(dim=1) == y).sum().item()
        n_seen += y.size(0)

    return running_loss / max(n_batches, 1), 100.0 * n_hit / max(n_seen, 1)

# ----------------------------
# 9) Train one fold
# ----------------------------
def run_one_fold(
    fold: int,
    save_folder: str,
    results_txt: str,
    num_classes: int,
    X_blocks: np.ndarray,
    y_blocks: np.ndarray,
    tr_blocks: np.ndarray,
    va_blocks: np.ndarray,
    test_loader: DataLoader,
    snr_to_accs: dict,
    scaler
):
    """Train, validate and SNR-sweep-test one cross-validation fold.

    Builds a paired training set from ``tr_blocks`` and a single-view validation
    set from ``va_blocks``. Trains a fresh SpecFeatureNet with
    LAMBDA_CL * NT-Xent + LAMBDA_CE * CE under Adam + ReduceLROnPlateau, with
    early stopping on validation loss. Restores the best-validation-loss
    weights, then evaluates ``test_loader`` at every SNR in TEST_SNR_LIST.

    Args:
        fold: fold number, used in log lines and file names.
        save_folder: directory receiving the per-fold CSV logs and checkpoints.
        results_txt: text file every epoch/test summary line is appended to.
        num_classes: number of classifier outputs.
        X_blocks / y_blocks: all XFR blocks and their class labels.
        tr_blocks / va_blocks: block indices used for training / validation.
        test_loader: DataLoader of (iq, y) batches; test-time augmentation is
            applied per SNR inside eval_single_iq.
        snr_to_accs: mapping snr -> list. This fold's test accuracy is appended
            to ``snr_to_accs[snr]`` for every snr (mutated in place).
        scaler: GradScaler or None. When None, a plain backward()/step() is used.
            NOTE(review): if the caller reuses one scaler for all folds, its
            loss-scale state carries over between folds; confirm this is intended.

    Side effects:
        Writes fold{fold}_trainlog.csv, fold{fold}_test_snr.csv and
        model_fold{fold}.pth (last-epoch weights) into ``save_folder``, plus
        best_model_fold{fold}.pth when some epoch improved the validation loss.
    """
    write_line(results_txt, f"\n========== Fold {fold}/{N_SPLITS} ==========")

    # Training items are (view1, view2, label) pairs; validation items are single views.
    tr_ds = XFRTrainPairDataset(X_blocks, y_blocks, tr_blocks, pos_pair_mode=POS_PAIR_MODE)
    va_ds = XFRSingleDataset(X_blocks, y_blocks, va_blocks)

    # drop_last: every contrastive batch has exactly BATCH_SIZE pairs.
    tr_loader = DataLoader(
        tr_ds, batch_size=BATCH_SIZE, shuffle=True, drop_last=True,
        num_workers=NUM_WORKERS_TRAIN, pin_memory=(DEVICE.type == "cuda")
    )
    va_loader = DataLoader(
        va_ds, batch_size=BATCH_SIZE, shuffle=False,
        num_workers=NUM_WORKERS_EVAL, pin_memory=(DEVICE.type == "cuda")
    )

    model = SpecFeatureNet(num_classes=num_classes).to(DEVICE)
    opt = torch.optim.Adam(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        opt, mode="min", factor=LR_FACTOR, patience=LR_PATIENCE
    )
    ce = nn.CrossEntropyLoss()

    best_val_loss = float("inf")
    best_state = None
    es_count = 0
    best_epoch = 0

    fold_log_path = os.path.join(save_folder, f"fold{fold}_trainlog.csv")
    fold_logger = CSVLogger(
        fold_log_path,
        header=["epoch", "lr", "train_loss", "val_loss", "val_acc", "best_val_loss", "es_count", "epoch_time_sec"]
    )

    for epoch in range(1, MAX_EPOCHS + 1):
        t0 = time.time()
        model.train()
        loss_sum, nb = 0.0, 0

        for iq1_np, iq2_np, y_np in tr_loader:
            iq1 = iq1_np.to(DEVICE, non_blocking=True)  # (B,m,2)
            iq2 = iq2_np.to(DEVICE, non_blocking=True)
            y   = y_np.to(DEVICE, non_blocking=True)

            # simclr: two independent augmentations of the same row.
            # same_block: two rows of one block, augmented together in one batched call.
            if POS_PAIR_MODE == "simclr":
                v1 = augment_train_batch(iq1)
                v2 = augment_train_batch(iq1)
            else:
                cat = torch.cat([iq1, iq2], dim=0)
                cat = augment_train_batch(cat)
                v1, v2 = cat.chunk(2, dim=0)

            # Spectrograms are computed outside autocast.
            spec1 = iq_to_logspec_batch(v1)
            spec2 = iq_to_logspec_batch(v2)

            opt.zero_grad(set_to_none=True)

            with amp_autocast(enabled=USE_AMP):
                z1, p1, z2, p2 = model(spec1, spec2)
                # NOTE(review): every non-partner row in the batch is an NT-Xent negative,
                # including rows from the same class (or block) that shuffling may co-locate.
                loss_cl = nt_xent_loss(z1, z2, tau=TAU)
                if LAMBDA_CE > 0.0:
                    loss_ce = 0.5 * (ce(p1, y) + ce(p2, y))
                else:
                    loss_ce = torch.zeros((), device=DEVICE)
                loss = LAMBDA_CL * loss_cl + LAMBDA_CE * loss_ce

            if scaler is not None:
                scaler.scale(loss).backward()
                scaler.step(opt)
                scaler.update()
            else:
                loss.backward()
                opt.step()

            loss_sum += float(loss.item())
            nb += 1

        train_loss = loss_sum / max(nb, 1)
        # Validation uses un-augmented IQ (eval_single_iq only augments in mode="test").
        val_loss, val_acc = eval_single_iq(model, va_loader, num_classes=num_classes, mode="val")

        prev_lr = opt.param_groups[0]["lr"]  # currently unused
        scheduler.step(val_loss)
        cur_lr = opt.param_groups[0]["lr"]

        epoch_time = time.time() - t0
        # BestValLoss / ES logged here are the values *before* this epoch's improvement check below.
        msg = (f"Epoch {epoch:03d} | LR={cur_lr:.2e} | "
               f"TrainLoss={train_loss:.4f} | ValLoss={val_loss:.4f} | ValAcc={val_acc:.2f}% | "
               f"BestValLoss={best_val_loss:.4f} | ES={es_count}/{ES_PATIENCE}")
        print(msg)
        write_line(results_txt, msg)
        fold_logger.log_row([epoch, cur_lr, train_loss, val_loss, val_acc, best_val_loss, es_count, epoch_time])

        # An improvement must beat the best loss by more than 1e-6.
        if val_loss < best_val_loss - 1e-6:
            best_val_loss = val_loss
            # Snapshot weights on CPU so later epochs cannot overwrite them.
            best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
            best_epoch = epoch
            es_count = 0
        else:
            es_count += 1
            if es_count >= ES_PATIENCE:
                msg = "[INFO] Early stopping triggered."
                print(msg)
                write_line(results_txt, msg)
                break

    # Save the last-epoch weights first, then restore and save the best-val-loss weights.
    torch.save(model.state_dict(), os.path.join(save_folder, f"model_fold{fold}.pth"))
    if best_state is not None:
        model.load_state_dict(best_state)
        torch.save(model.state_dict(), os.path.join(save_folder, f"best_model_fold{fold}.pth"))

    write_line(results_txt, f"[FOLD {fold}] BestEpoch={best_epoch}, BestValLoss={best_val_loss:.6f}")

    # Test sweep
    fold_test_csv = os.path.join(save_folder, f"fold{fold}_test_snr.csv")
    fold_test_logger = CSVLogger(fold_test_csv, header=["snr_db", "test_loss", "test_acc"])

    fold_snr_acc = {}
    for snr in TEST_SNR_LIST:
        test_loss, test_acc = eval_single_iq(model, test_loader, num_classes=num_classes, mode="test", snr_db=float(snr))
        snr_to_accs[snr].append(test_acc)
        fold_snr_acc[snr] = test_acc
        fold_test_logger.log_row([snr, test_loss, test_acc])

    msg = "[FOLD TEST] " + ", ".join([f"{snr}:{fold_snr_acc[snr]:.2f}%" for snr in TEST_SNR_LIST])
    print(msg)
    write_line(results_txt, msg)

# ----------------------------
# 10) Main pipeline
# ----------------------------
def _fold_acc_mean_std(accs):
    """Return ``(mean, std)`` of a list of per-fold accuracies.

    ``std`` is numpy's default population standard deviation (ddof=0), matching
    what previous runs of this script reported. An empty list yields
    ``(0.0, 0.0)`` so the summary never fails for an SNR point with no results.
    """
    arr = np.asarray(accs, dtype=np.float64)
    if arr.size == 0:
        return 0.0, 0.0
    return float(arr.mean()), float(arr.std())


def _write_run_config(save_folder, data_path, dmrs_len, x_blocks_shape, num_classes):
    """Dump the data stats and module-level hyper-parameters to ``config.txt``."""
    # Doppler shift at the fixed test speed: fd = (v / c) * fc with v converted km/h -> m/s.
    fd_test = (TEST_V_KMH_FIXED / 3.6) / 3e8 * FC
    with open(os.path.join(save_folder, "config.txt"), "w", encoding="utf-8") as f:
        f.write(f"DEVICE={DEVICE}\nAMP={USE_AMP}\n")
        f.write(f"DATA_PATH={data_path}\nRECURSIVE_GLOB={RECURSIVE_GLOB}\n")
        f.write(f"DMRS_LEN(L)={dmrs_len}\n")
        f.write(f"GROUP_SIZE(m)={GROUP_SIZE}\n")
        f.write(f"X_blocks shape={x_blocks_shape} (num_blocks, L, m, 2)\n")
        f.write(f"num_classes={num_classes}\n")
        f.write(f"BLOCK_SPLIT=True | TEST_SIZE={TEST_SIZE} | STRICT_TEST_BALANCE={STRICT_TEST_BALANCE}\n")
        f.write(f"POS_PAIR_MODE={POS_PAIR_MODE}\n")
        f.write(f"BATCH_SIZE={BATCH_SIZE}, LR={LR}, WEIGHT_DECAY={WEIGHT_DECAY}\n")
        f.write(f"MAX_EPOCHS={MAX_EPOCHS}, N_SPLITS={N_SPLITS}\n")
        f.write(f"LR_PATIENCE={LR_PATIENCE}, LR_FACTOR={LR_FACTOR}, ES_PATIENCE={ES_PATIENCE}\n")
        f.write(f"TAU={TAU}, LAMBDA_CL={LAMBDA_CL}, LAMBDA_CE={LAMBDA_CE}\n")
        f.write(f"FS={FS}, FC={FC}\n")
        f.write(f"AUG_USE_MULTIPATH={AUG_USE_MULTIPATH}, RMS_DS_NS_RANGE={RMS_DS_NS_RANGE}, MAX_TAPS={MAX_TAPS}\n")
        f.write(f"AUG_USE_DOPPLER={AUG_USE_DOPPLER}, TRAIN_V_KMH_RANGE={TRAIN_V_KMH_RANGE}\n")
        f.write(f"AUG_USE_AWGN={AUG_USE_AWGN}, TRAIN_SNR_DB_RANGE={TRAIN_SNR_DB_RANGE}\n")
        f.write(f"TEST_V_KMH_FIXED={TEST_V_KMH_FIXED}, fd_test={fd_test}\n")
        f.write(f"TEST_USE_MULTIPATH={TEST_USE_MULTIPATH}, TEST_SNR_LIST={TEST_SNR_LIST}\n")
        f.write(f"SPEC_NFFT={SPEC_NFFT}, SPEC_WIN={SPEC_WIN}, SPEC_HOP={SPEC_HOP}, SPEC_SIZE={SPEC_SIZE}\n")


def train_kfold_ltev_xfr_per_file_seq(data_path: str):
    """Run the full block-level k-fold training + test-SNR-sweep pipeline.

    Steps:
      1. Load the per-file IQ data under ``data_path`` and group it into XFR
         blocks of ``GROUP_SIZE`` (warns when ``GROUP_SIZE`` differs from the
         DMRS length ``L`` reported by the loader).
      2. Split blocks into a train/val pool and a held-out test set with
         ``balanced_block_split``. When ``STRICT_TEST_BALANCE`` is set, only
         the kept blocks survive and indices are remapped to the trimmed arrays.
      3. Create a timestamped output folder under ``SAVE_ROOT`` and write
         ``config.txt`` and ``results.txt`` there.
      4. Run ``StratifiedKFold`` over the train/val blocks. For each fold,
         ``run_one_fold`` trains a model and appends one test accuracy per SNR
         in ``TEST_SNR_LIST`` to ``snr_to_accs``.
      5. Write the per-SNR mean/std (population std over folds) to
         ``test_snr_sweep.csv`` and ``results.txt``.

    Args:
        data_path: Root directory of the dataset files.

    Returns:
        Path of the output folder that holds every artefact of this run.
    """
    # Only KFold / train_test_split are imported from sklearn.model_selection at
    # the top of the file; import StratifiedKFold explicitly so this function
    # does not depend on another cell having done it.
    from sklearn.model_selection import StratifiedKFold

    X_files, y_files, file_paths, L = load_ltev_files_as_iq(data_path, recursive=RECURSIVE_GLOB)

    # Optional fairness: enforce m == L if you want exactly your setting
    if GROUP_SIZE != L:
        print(f"[WARN] You set GROUP_SIZE(m)={GROUP_SIZE} but DMRS length L={L}. "
              f"After XFR transpose, each sample length will be m={GROUP_SIZE}, not L. "
              f"If you want m=L, set GROUP_SIZE={L}.")

    X_blocks, y_blocks, label_to_idx, block_meta = build_xfr_blocks_per_file_sequential(
        X_files, y_files, group_size=GROUP_SIZE
    )
    num_classes = len(label_to_idx)

    kept_idx, trainval_idx, test_idx = balanced_block_split(
        y_blocks, test_size=TEST_SIZE, seed=SEED, strict_balance=STRICT_TEST_BALANCE
    )

    # If strict balance: keep only kept blocks (trim each class to Bmin) and
    # translate the split indices from the original to the trimmed positions.
    if STRICT_TEST_BALANCE:
        X_blocks = X_blocks[kept_idx]
        y_blocks = y_blocks[kept_idx]
        block_meta = [block_meta[i] for i in kept_idx.tolist()]

        old_to_new = {int(old): i for i, old in enumerate(kept_idx.tolist())}
        trainval_idx = np.array([old_to_new[int(i)] for i in trainval_idx.tolist()], dtype=np.int64)
        test_idx = np.array([old_to_new[int(i)] for i in test_idx.tolist()], dtype=np.int64)

    # Save folder
    timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    save_dir = f"{timestamp}_{SCRIPT_NAME}_m{GROUP_SIZE}_TestBal{int(STRICT_TEST_BALANCE)}_{POS_PAIR_MODE}"
    save_folder = os.path.join(SAVE_ROOT, save_dir)
    os.makedirs(save_folder, exist_ok=True)
    results_txt = os.path.join(save_folder, "results.txt")

    _write_run_config(save_folder, data_path, L, X_blocks.shape, num_classes)

    print(f"[INFO] DEVICE={DEVICE} | AMP={USE_AMP}")
    print(f"[INFO] X_blocks shape={X_blocks.shape} (num_blocks, L, m, 2)")
    print(f"[INFO] Classes={num_classes}, TrainValBlocks={len(trainval_idx)}, TestBlocks={len(test_idx)}")
    print(f"[INFO] SaveFolder: {save_folder}")

    write_line(results_txt, f"DEVICE={DEVICE}, AMP={USE_AMP}")
    write_line(results_txt, f"X_blocks shape={X_blocks.shape} (num_blocks, L, m, 2)")
    write_line(results_txt, f"TrainValBlocks={len(trainval_idx)}, TestBlocks={len(test_idx)}")
    write_line(results_txt, f"POS_PAIR_MODE={POS_PAIR_MODE}")
    write_line(results_txt, "-" * 80)

    # Fixed test loader, shared by every fold (same held-out blocks each time).
    test_ds = XFRSingleDataset(X_blocks, y_blocks, test_idx)
    test_loader = DataLoader(
        test_ds, batch_size=BATCH_SIZE, shuffle=False,
        num_workers=NUM_WORKERS_EVAL, pin_memory=(DEVICE.type == "cuda")
    )

    # KFold on trainval blocks (stratified by block label)
    y_trainval = y_blocks[trainval_idx]
    skf = StratifiedKFold(n_splits=N_SPLITS, shuffle=True, random_state=SEED)

    check_block_overlap(trainval_idx, [], test_idx)

    # run_one_fold appends one accuracy per SNR per fold into these lists.
    snr_to_accs = {snr: [] for snr in TEST_SNR_LIST}

    for fold, (tr_pos, va_pos) in enumerate(skf.split(trainval_idx, y_trainval), 1):
        # skf yields positions into trainval_idx; map them back to block indices.
        tr_blocks = trainval_idx[tr_pos]
        va_blocks = trainval_idx[va_pos]
        check_block_overlap(tr_blocks, va_blocks, test_idx)

        c_tr = Counter(y_blocks[tr_blocks].tolist())
        c_va = Counter(y_blocks[va_blocks].tolist())
        print(f"\n========== Fold {fold}/{N_SPLITS} ==========")
        print(f"[FOLD {fold}] train_blocks/class:", dict(sorted(c_tr.items())))
        print(f"[FOLD {fold}] val_blocks/class  :", dict(sorted(c_va.items())))

        # Fresh GradScaler per fold: its adaptive loss-scale state would otherwise
        # carry over from the previous fold's model and couple the folds.
        scaler = make_grad_scaler(enabled=USE_AMP)

        run_one_fold(
            fold=fold,
            save_folder=save_folder,
            results_txt=results_txt,
            num_classes=num_classes,
            X_blocks=X_blocks,
            y_blocks=y_blocks,
            tr_blocks=tr_blocks,
            va_blocks=va_blocks,
            test_loader=test_loader,
            snr_to_accs=snr_to_accs,
            scaler=scaler
        )

    # Summarize sweep: compute mean/std once and reuse for CSV and results.txt.
    snr_stats = {snr: _fold_acc_mean_std(snr_to_accs[snr]) for snr in TEST_SNR_LIST}

    rows = []
    for snr in TEST_SNR_LIST:
        mean, std = snr_stats[snr]
        rows.append([snr, mean, std] + list(snr_to_accs[snr]))

    csv_path = os.path.join(save_folder, "test_snr_sweep.csv")
    with open(csv_path, "w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        header = ["snr_db", "acc_mean", "acc_std"] + [f"fold{i}" for i in range(1, N_SPLITS + 1)]
        writer.writerow(header)
        writer.writerows(rows)

    write_line(results_txt, "\n========== Overall Test SNR Sweep (mean±std over folds) ==========")
    for snr in TEST_SNR_LIST:
        mean, std = snr_stats[snr]
        write_line(results_txt, f"SNR {snr:>3} dB | Acc {mean:.2f} ± {std:.2f}")

    print(f"\n[INFO] All saved in: {save_folder}")
    print(f"[INFO] SNR sweep CSV: {csv_path}")
    return save_folder

# ----------------------------
# 11) main
# ----------------------------
if __name__ == "__main__":
    # Script entry point: run the whole k-fold pipeline on the configured dataset root.
    train_kfold_ltev_xfr_per_file_seq(data_path=DATA_PATH)
