# In [None]:
# ============================================================
# WiSig (ManySig.pkl) - Spec + Siamese (FAST GPU PIPELINE + LOGGING)
#
# 保存内容：
#  - config.txt
#  - results.txt
#  - fold{K}_trainlog.csv              (每 epoch 训练日志)
#  - fold{K}_test_snr.csv              (每 fold 的 SNR sweep 结果)
#  - test_snr_sweep.csv                (跨 folds 汇总 mean/std + fold1..foldN)
#  - model_fold{K}.pth                 (最后/early-stop 时模型)
#  - best_model_fold{K}.pth            (val_loss 最佳)
#
# 训练增强（每样本随机）：
#  - Multipath: 指数 PDP TDL, RMS DS ~ U[5,300] ns
#  - Doppler: v ~ U[0,120] km/h
#  - AWGN: SNR ~ U[-40,20] dB
#
# 测试：
#  - Doppler 固定 120 km/h
#  - AWGN SNR sweep: 20,15,...,-40 dB
#  - 默认不额外加 multipath（TEST_USE_MULTIPATH 控制）
#
# 加速点：
#  - Dataset 只输出 IQ（不在 __getitem__ 做 STFT）
#  - 增强 + STFT 全部 GPU batch
#  - AMP 混合精度（NT-Xent 强制 fp32 防溢出）
# ============================================================

import os
import csv
import time
import random
import numpy as np
from datetime import datetime
from collections import defaultdict

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import KFold
from sklearn.metrics import confusion_matrix

from data_utilities import load_compact_pkl_dataset

# ----------------------------
# 0) 全局配置
# ----------------------------
SEED = 42


def seed_everything(seed=SEED):
    """Seed Python's `random`, NumPy, and PyTorch (CPU plus all CUDA devices) with `seed`."""
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)


seed_everything(SEED)

# Run on CUDA when it is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.backends.cudnn.benchmark = True
# Best-effort TF32 / matmul-precision switches: any failure is deliberately ignored
# (presumably so the script still runs on torch builds lacking these settings).
try:
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
except Exception:
    pass
try:
    torch.set_float32_matmul_precision("high")
except Exception:
    pass

# Mixed precision is only enabled when running on CUDA.
USE_AMP = (DEVICE.type == "cuda")

dataset_name = "ManySig"
dataset_path = "../ManySig.pkl/"  # adjust to your project layout
equalized = 0
max_sig = None  # None: use everything; otherwise only the first max_sig signals of each (tx, rx, date)

train_dates = ["2021_03_15"]
test_dates  = ["2021_03_01"]

# Training hyper-parameters
BATCH_SIZE = 256           # too large may OOM (the NT-Xent similarity matrix grows with batch^2)
LR = 3e-4
WEIGHT_DECAY = 0.0
MAX_EPOCHS = 200
N_SPLITS = 5
LR_PATIENCE = 10
LR_FACTOR = 0.5
ES_PATIENCE = 30

# NT-Xent temperature and the weights of the contrastive / cross-entropy loss terms
TAU = 0.05
LAMBDA_CL = 1.0
LAMBDA_CE = 1.0

# RF parameters (stated to match the WiSig XFR setup): FS is used as the sample rate
# and FC as the carrier frequency in the Doppler / multipath models.
FS = 20e6
FC = 2.4e9

# Training augmentation
AUG_USE_MULTIPATH = True
AUG_USE_DOPPLER   = True
AUG_USE_AWGN      = True
RMS_DS_NS_RANGE = (5.0, 300.0)
TRAIN_V_KMH_RANGE = (0.0, 120.0)
TRAIN_SNR_DB_RANGE = (-40.0, 20.0)

# Test-time augmentation
TEST_V_KMH_FIXED = 120.0
TEST_USE_MULTIPATH = False
TEST_RMS_DS_NS_RANGE = RMS_DS_NS_RANGE
TEST_SNR_LIST = list(range(20, -45, -5))  # 20, 15, ..., -40 dB

# Spectrogram (STFT size, window length, hop, and the square output size after resizing)
SPEC_NFFT = 128
SPEC_WIN  = 128
SPEC_HOP  = 16
SPEC_SIZE = 64

# Number of multipath taps
MAX_TAPS = 16

# DataLoader workers (on Windows, starting with 0 is recommended)
NUM_WORKERS_TRAIN = 0
NUM_WORKERS_EVAL  = 0

SAVE_ROOT = "./training_results"
os.makedirs(SAVE_ROOT, exist_ok=True)
SCRIPT_NAME = "WiSig_SpecSiamese_FASTGPU_SNRsweep_Doppler120"

RETURN_CM = False  # enable to also compute a confusion matrix during evaluation (slower)


# ----------------------------
# 1) 工具：日志写入
# ----------------------------
class CSVLogger:
    """Minimal CSV logger: writes the header on construction, then appends one row per call.

    The file is reopened and closed for every write, so each row is flushed as soon
    as it is logged.

    Args:
        path: Destination CSV file. Missing parent directories are created.
        header: Sequence of column names written as the first row.
    """

    def __init__(self, path, header):
        self.path = path
        self.header = header
        self._init_file()

    def _init_file(self):
        """Create the parent directory if needed and (re)create the file with only the header."""
        parent = os.path.dirname(self.path)
        # A bare filename has an empty dirname, and os.makedirs("") raises
        # FileNotFoundError, so only create the directory when there is one.
        if parent:
            os.makedirs(parent, exist_ok=True)
        with open(self.path, "w", newline="", encoding="utf-8") as f:
            csv.writer(f).writerow(self.header)

    def log_row(self, row):
        """Append a single row (any iterable of cell values) to the CSV file."""
        with open(self.path, "a", newline="", encoding="utf-8") as f:
            csv.writer(f).writerow(row)


def write_line(path, line):
    """Append `line` to the UTF-8 text file at `path`, minus trailing whitespace, plus one newline."""
    text = f"{line.rstrip()}\n"
    with open(path, mode="a", encoding="utf-8") as handle:
        handle.write(text)


# ----------------------------
# 2) WiSig index 构建
# ----------------------------
def build_index_list(compact_dataset, tx_names, dates, equalized=0, max_sig=None):
    """Enumerate every usable signal as a (tx_i, rx_i, date_i, k) tuple.

    Args:
        compact_dataset: Dict with "equalized_list", "tx_list", "rx_list",
            "capture_date_list" and "data", where data[tx][rx][date][eq] is a sequence.
        tx_names: Transmitter names to include; names absent from "tx_list" are skipped.
        dates: Capture dates to include; dates absent from "capture_date_list" are skipped.
        equalized: Value looked up in "equalized_list" (ValueError if missing).
        max_sig: If not None, cap the number of signals taken per (tx, rx, date).

    Returns:
        (index_list, tx_i_to_label), where labels follow the order of the kept tx_names.
    """
    eq_i = compact_dataset["equalized_list"].index(equalized)
    known_tx = compact_dataset["tx_list"]
    known_dates = compact_dataset["capture_date_list"]
    data = compact_dataset["data"]
    n_rx = len(compact_dataset["rx_list"])

    tx_i_list = [known_tx.index(name) for name in tx_names if name in known_tx]
    tx_i_to_label = {tx_i: label for label, tx_i in enumerate(tx_i_list)}
    date_i_list = [known_dates.index(day) for day in dates if day in known_dates]

    index_list = []
    for tx_i in tx_i_list:
        for date_i in date_i_list:
            for rx_i in range(n_rx):
                available = len(data[tx_i][rx_i][date_i][eq_i])
                count = available if max_sig is None else min(available, max_sig)
                index_list.extend((tx_i, rx_i, date_i, k) for k in range(count))
    return index_list, tx_i_to_label


# ----------------------------
# 3) Dataset：只返回 IQ（Siamese/Single）
# ----------------------------
class WiSigSiameseIQDataset(Dataset):
    """Pairs of raw IQ signals from the same transmitter and date, seen by two receivers.

    Each item is (iq1, iq2, label): iq1 is the indexed signal, iq2 comes from a
    different receiver of the same (tx, date) when one exists, otherwise from the
    same receiver. No spectrogram work happens here; arrays are returned as float32.
    """

    def __init__(self, compact_dataset, index_list, tx_i_to_label, equalized=0, max_sig=None):
        self.ds = compact_dataset
        self.index_list = index_list
        self.tx_i_to_label = tx_i_to_label
        self.eq_i = compact_dataset["equalized_list"].index(equalized)
        self.max_sig = max_sig

        # (tx_i, date_i) -> {rx_i: number of usable signals for that receiver}
        # NOTE(review): the count covers all stored signals of the receiver (capped by
        # max_sig), not only those listed in index_list — confirm this is intended when
        # index_list is a cross-validation subset, since partners may then come from
        # samples outside the subset.
        self.len_map = defaultdict(dict)
        for tx_i, rx_i, date_i, _ in index_list:
            per_rx = self.len_map[(tx_i, date_i)]
            if rx_i not in per_rx:
                stored = len(self.ds["data"][tx_i][rx_i][date_i][self.eq_i])
                usable = stored if max_sig is None else min(stored, max_sig)
                if usable > 0:
                    per_rx[rx_i] = usable
        self.rx_choices = {group: list(per_rx.keys()) for group, per_rx in self.len_map.items()}

    def __len__(self):
        return len(self.index_list)

    def _get_iq(self, tx_i, rx_i, date_i, k):
        """Fetch one stored signal as a float32 array (expected layout (L, 2) — TODO confirm)."""
        stored = self.ds["data"][tx_i][rx_i][date_i][self.eq_i][k]
        return np.asarray(stored, dtype=np.float32)

    def __getitem__(self, idx):
        tx_i, rx1_i, date_i, k1 = self.index_list[idx]
        label = self.tx_i_to_label[tx_i]
        group = (tx_i, date_i)
        candidates = self.rx_choices[group]

        # Rejection-sample a second receiver different from the first, when possible.
        rx2_i = rx1_i
        if len(candidates) >= 2:
            while rx2_i == rx1_i:
                rx2_i = random.choice(candidates)

        # Reuse the same signal index on the partner receiver if it exists there.
        n2 = self.len_map[group][rx2_i]
        k2 = random.randrange(n2) if k1 >= n2 else k1

        first = self._get_iq(tx_i, rx1_i, date_i, k1)
        second = self._get_iq(tx_i, rx2_i, date_i, k2)
        return first, second, np.int64(label)

class WiSigSingleIQDataset(Dataset):
    """Single raw IQ signals with their class label, used for validation and testing."""

    def __init__(self, compact_dataset, index_list, tx_i_to_label, equalized=0):
        self.ds = compact_dataset
        self.index_list = index_list
        self.tx_i_to_label = tx_i_to_label
        self.eq_i = compact_dataset["equalized_list"].index(equalized)

    def __len__(self):
        return len(self.index_list)

    def __getitem__(self, idx):
        """Return (iq, label): the stored signal as float32 and its label as np.int64."""
        tx_i, rx_i, date_i, k = self.index_list[idx]
        stored = self.ds["data"][tx_i][rx_i][date_i][self.eq_i][k]
        return np.asarray(stored, dtype=np.float32), np.int64(self.tx_i_to_label[tx_i])


# ----------------------------
# 4) GPU batch 增强：Multipath / Doppler / AWGN
# ----------------------------
def _to_complex(iq_b: torch.Tensor) -> torch.Tensor:
    return iq_b[..., 0].to(torch.float32) + 1j * iq_b[..., 1].to(torch.float32)

def _from_complex(sig: torch.Tensor) -> torch.Tensor:
    return torch.stack([sig.real, sig.imag], dim=-1)

def batch_normalize_power(iq_b: torch.Tensor) -> torch.Tensor:
    """Scale each batch item so its mean I^2 + Q^2 over dim 1 is (approximately) one.

    A 1e-12 floor is added to the power, so an all-zero item stays zero instead of
    producing NaN.
    """
    in_phase, quadrature = iq_b[..., 0], iq_b[..., 1]
    # mean power: (B, 1); the inverse-sqrt scale is broadcast as (B, 1, 1)
    mean_power = (in_phase ** 2 + quadrature ** 2).mean(dim=1, keepdim=True) + 1e-12
    return iq_b * torch.rsqrt(mean_power).unsqueeze(-1)

def batch_apply_doppler(iq_b: torch.Tensor, v_kmh: torch.Tensor) -> torch.Tensor:
    """Apply a per-item constant Doppler frequency shift to a batch of IQ signals.

    Args:
        iq_b: Real tensor unpacked as (B, L, 2), last axis (I, Q).
        v_kmh: Per-item speed in km/h, broadcast against the batch dimension.

    Returns:
        Tensor of the same layout with each item multiplied by
        exp(j * 2*pi * fd * n / FS), where fd = (v / c) * FC and n is the sample index.
    """
    _, num_samples, _ = iq_b.shape
    speed_of_light = 3e8

    speed_mps = v_kmh / 3.6
    doppler_hz = (speed_mps / speed_of_light) * FC  # (B,)

    sample_idx = torch.arange(num_samples, device=iq_b.device, dtype=torch.float32).unsqueeze(0)  # (1, L)
    rotation = torch.exp(1j * 2.0 * np.pi * doppler_hz.unsqueeze(1).to(torch.float32) * sample_idx / FS)  # (B, L)

    shifted = _to_complex(iq_b) * rotation
    return _from_complex(shifted)

def _grouped_conv1d_real(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    # x: (B,1,L), w: (B,1,K)
    B, _, L = x.shape
    _, _, K = w.shape
    x2 = x.permute(1, 0, 2).contiguous()      # (1,B,L)
    y2 = F.conv1d(x2, w, padding=K - 1, groups=B)  # (1,B,L+K-1)
    return y2.squeeze(0)  # (B, L+K-1)

def batch_apply_multipath(iq_b: torch.Tensor, rms_ns: torch.Tensor, max_taps: int = MAX_TAPS) -> torch.Tensor:
    """Pass each IQ signal through its own random tapped-delay-line channel.

    Per item, `max_taps` complex Gaussian taps are drawn with variances following
    exp(-k / rms_samples) (normalised to sum to one), then rescaled so the realised
    taps have unit total energy. The complex filtering is done with four real grouped
    convolutions, and the result is cut back to the input length.

    Args:
        iq_b: Real tensor unpacked as (B, L, 2), last axis (I, Q).
        rms_ns: (B,) RMS delay spread per item, in nanoseconds.
        max_taps: Number of channel taps.

    Returns:
        (B, L, 2) filtered signals.
    """
    B, L, _ = iq_b.shape
    device = iq_b.device

    # Delay spread in samples, floored so the exponent below stays finite.
    rms_samples = ((rms_ns * 1e-9) * FS).clamp(min=1e-3)  # (B,)

    tap_idx = torch.arange(max_taps, device=device, dtype=torch.float32).unsqueeze(0)  # (1, K)
    pdp = torch.exp(-tap_idx / rms_samples.unsqueeze(1))  # (B, K)
    pdp = pdp / (pdp.sum(dim=1, keepdim=True) + 1e-12)

    # Real part is drawn first, then the imaginary part.
    hr = torch.randn(B, max_taps, device=device) * torch.sqrt(pdp / 2.0)
    hi = torch.randn(B, max_taps, device=device) * torch.sqrt(pdp / 2.0)

    # Normalise every realised channel to unit energy.
    inv_norm = torch.rsqrt((hr**2 + hi**2).sum(dim=1, keepdim=True) + 1e-12)
    hr_k = (hr * inv_norm).unsqueeze(1)
    hi_k = (hi * inv_norm).unsqueeze(1)

    xr_in = iq_b[..., 0].unsqueeze(1)
    xi_in = iq_b[..., 1].unsqueeze(1)

    # (xr + j xi) * (hr + j hi) = (xr*hr - xi*hi) + j (xr*hi + xi*hr)
    out_re = _grouped_conv1d_real(xr_in, hr_k) - _grouped_conv1d_real(xi_in, hi_k)
    out_im = _grouped_conv1d_real(xr_in, hi_k) + _grouped_conv1d_real(xi_in, hr_k)

    return torch.stack([out_re[:, :L], out_im[:, :L]], dim=-1)

def batch_add_awgn(iq_b: torch.Tensor, snr_db: torch.Tensor) -> torch.Tensor:
    """Add complex white Gaussian noise at a per-item SNR measured against each item's power.

    Args:
        iq_b: Real tensor unpacked as (B, L, 2), last axis (I, Q).
        snr_db: (B,) target SNR in dB for each item.

    Returns:
        Noisy signals in the same (B, L, 2) layout.
    """
    B, L, _ = iq_b.shape
    dev = iq_b.device
    clean = _to_complex(iq_b)

    sig_power = (clean.real**2 + clean.imag**2).mean(dim=1) + 1e-12
    noise_power = sig_power / (10.0 ** (snr_db / 10.0))
    # Noise power is split evenly between the real and imaginary components.
    per_component_std = torch.sqrt(noise_power / 2.0).unsqueeze(1)

    noise_re = torch.randn(B, L, device=dev)
    noise_im = torch.randn(B, L, device=dev)
    noisy = clean + per_component_std * (noise_re + 1j * noise_im)
    return _from_complex(noisy)

def augment_train_batch(iq_b: torch.Tensor) -> torch.Tensor:
    """Power-normalise a batch, then apply the enabled random training augmentations.

    In order, each gated by its AUG_USE_* switch: multipath (RMS delay spread drawn from
    RMS_DS_NS_RANGE), Doppler (speed drawn from TRAIN_V_KMH_RANGE), AWGN (SNR drawn from
    TRAIN_SNR_DB_RANGE). Every parameter is drawn uniformly and independently per item.
    """
    iq_b = batch_normalize_power(iq_b)
    batch = iq_b.shape[0]

    def _uniform(bounds):
        # One uniform draw per batch item within [bounds[0], bounds[1]).
        return (bounds[1] - bounds[0]) * torch.rand(batch, device=iq_b.device) + bounds[0]

    if AUG_USE_MULTIPATH:
        iq_b = batch_apply_multipath(iq_b, _uniform(RMS_DS_NS_RANGE), max_taps=MAX_TAPS)

    if AUG_USE_DOPPLER:
        iq_b = batch_apply_doppler(iq_b, _uniform(TRAIN_V_KMH_RANGE))

    if AUG_USE_AWGN:
        iq_b = batch_add_awgn(iq_b, _uniform(TRAIN_SNR_DB_RANGE))

    return iq_b

def augment_test_batch(iq_b: torch.Tensor, snr_db: float) -> torch.Tensor:
    """Power-normalise a batch and apply the test-time channel at one fixed SNR.

    Multipath is applied only when TEST_USE_MULTIPATH is set (delay spread drawn uniformly
    from TEST_RMS_DS_NS_RANGE). Doppler at TEST_V_KMH_FIXED and AWGN at `snr_db` are always
    applied, with the same value for every item.
    """
    iq_b = batch_normalize_power(iq_b)
    batch = iq_b.shape[0]
    dev = iq_b.device

    if TEST_USE_MULTIPATH:
        lo, hi = TEST_RMS_DS_NS_RANGE[0], TEST_RMS_DS_NS_RANGE[1]
        delay_spread = (hi - lo) * torch.rand(batch, device=dev) + lo
        iq_b = batch_apply_multipath(iq_b, delay_spread, max_taps=MAX_TAPS)

    speed = torch.full((batch,), float(TEST_V_KMH_FIXED), device=dev)
    iq_b = batch_apply_doppler(iq_b, speed)

    snr = torch.full((batch,), float(snr_db), device=dev)
    return batch_add_awgn(iq_b, snr)


# ----------------------------
# 5) GPU batch STFT -> logmag -> resize
# ----------------------------
_WINDOW_CACHE = {}


def get_hann_window(device: torch.device):
    """Return a periodic Hann window of length SPEC_WIN on `device`, cached per device and length."""
    cache_key = (device.type, device.index, SPEC_WIN)
    window = _WINDOW_CACHE.get(cache_key)
    if window is None:
        window = torch.hann_window(SPEC_WIN, periodic=True, device=device)
        _WINDOW_CACHE[cache_key] = window
    return window

def iq_to_logspec_batch(iq_b: torch.Tensor) -> torch.Tensor:
    """Turn a batch of IQ signals into standardised log-magnitude spectrograms.

    Steps: complex STFT (SPEC_NFFT / SPEC_WIN / SPEC_HOP, centred, cached Hann window),
    log of the magnitude, per-item standardisation over the frequency and time axes,
    replacement of non-finite values by zero, then bilinear resize to
    (SPEC_SIZE, SPEC_SIZE).

    Returns:
        Tensor of shape (B, 1, SPEC_SIZE, SPEC_SIZE).
    """
    complex_sig = _to_complex(iq_b).to(torch.complex64)  # (B, L)
    stft = torch.stft(
        complex_sig,
        n_fft=SPEC_NFFT,
        hop_length=SPEC_HOP,
        win_length=SPEC_WIN,
        window=get_hann_window(iq_b.device),
        center=True,
        return_complex=True,
    )  # (B, F, T)

    # The small offset keeps the log finite for empty bins.
    log_mag = torch.log(torch.abs(stft) + 1e-12)

    freq_time = (1, 2)
    mean = log_mag.mean(dim=freq_time, keepdim=True)
    spread = log_mag.std(dim=freq_time, keepdim=True) + 1e-6
    log_mag = (log_mag - mean) / spread
    try:
        log_mag = torch.nan_to_num(log_mag, nan=0.0, posinf=0.0, neginf=0.0)
    except Exception:
        # Fallback for when nan_to_num cannot be used: zero the non-finite entries by mask.
        log_mag[~torch.isfinite(log_mag)] = 0.0

    image = log_mag.unsqueeze(1)  # (B, 1, F, T)
    return F.interpolate(image, size=(SPEC_SIZE, SPEC_SIZE), mode="bilinear", align_corners=False)


# ----------------------------
# 6) 模型
# ----------------------------
class BasicBlock2D(nn.Module):
    """Residual block: two 3x3 conv + batch-norm layers with a skip connection.

    When the stride is not 1 or the channel count changes, the skip path goes through
    a 1x1 conv + batch-norm (`down`) so both branches have matching shapes.
    """

    def __init__(self, in_ch, out_ch, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_ch)
        self.conv2 = nn.Conv2d(out_ch, out_ch, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_ch)
        self.relu = nn.ReLU(inplace=True)

        needs_projection = (stride != 1) or (in_ch != out_ch)
        self.down = (
            nn.Sequential(
                nn.Conv2d(in_ch, out_ch, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_ch),
            )
            if needs_projection
            else None
        )

    def forward(self, x):
        residual = self.conv1(x)
        residual = self.relu(self.bn1(residual))
        residual = self.bn2(self.conv2(residual))
        shortcut = x if self.down is None else self.down(x)
        return self.relu(residual + shortcut)

class SpecFeatureNet(nn.Module):
    """Spectrogram encoder with a 256-d embedding head and a linear classifier on top.

    A 7x7 stride-2 stem is followed by four residual blocks, global average pooling,
    two fully connected layers producing the embedding `z`, and a classifier over `z`.
    Called with one input it returns (z, logits); with two inputs both are run through
    the same weights and (z1, logits1, z2, logits2) is returned.
    """

    def __init__(self, num_classes: int):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(32)
        self.relu = nn.ReLU(inplace=True)

        self.b1 = BasicBlock2D(32, 32, stride=1)
        self.b2 = BasicBlock2D(32, 32, stride=1)
        self.b3 = BasicBlock2D(32, 64, stride=1)
        self.b4 = BasicBlock2D(64, 64, stride=1)

        self.gap = nn.AdaptiveAvgPool2d((1, 1))
        self.fc1 = nn.Linear(64, 512)
        self.fc2 = nn.Linear(512, 256)   # embedding z
        self.cls = nn.Linear(256, num_classes)

    def forward_once(self, x):
        """Encode one batch of spectrograms; returns (embedding z, class logits)."""
        feat = self.relu(self.bn1(self.conv1(x)))
        for block in (self.b1, self.b2, self.b3, self.b4):
            feat = block(feat)
        pooled = self.gap(feat).squeeze(-1).squeeze(-1)  # (B, 64)
        hidden = F.relu(self.fc1(pooled))
        z = self.fc2(hidden)
        return z, self.cls(z)

    def forward(self, x1, x2=None):
        z1, p1 = self.forward_once(x1)
        if x2 is not None:
            z2, p2 = self.forward_once(x2)
            return z1, p1, z2, p2
        return z1, p1


# ----------------------------
# 7) 损失与评估（NT-Xent 强制 fp32）
# ----------------------------
def nt_xent_loss(z1: torch.Tensor, z2: torch.Tensor, tau: float = TAU) -> torch.Tensor:
    """NT-Xent contrastive loss between two batches of paired embeddings, computed in fp32.

    Row i of `z1` and row i of `z2` form a positive pair; every other row of the
    concatenated 2N batch acts as a negative. Autocast is disabled inside so the
    temperature-scaled similarities are not computed in half precision.

    Args:
        z1, z2: (N, D) embeddings.
        tau: Softmax temperature.

    Returns:
        Scalar mean loss over the 2N anchors.
    """
    with torch.cuda.amp.autocast(enabled=False):
        pair_count = z1.size(0)
        total = 2 * pair_count

        emb = torch.cat([z1.float(), z2.float()], dim=0)
        emb = F.normalize(emb, dim=1)

        sim = (emb @ emb.T) / float(tau)
        # Exclude self-similarity from the softmax denominator.
        sim.fill_diagonal_(torch.finfo(sim.dtype).min)

        anchors = torch.arange(total, device=emb.device)
        # Anchor i pairs with i + N, and i + N pairs back with i.
        partners = (torch.arange(total, device=emb.device) + pair_count) % total

        log_prob = sim - torch.logsumexp(sim, dim=1, keepdim=True)
        per_anchor = -log_prob[anchors, partners]
        return per_anchor.mean()

@torch.no_grad()
def eval_single_iq(model: SpecFeatureNet, loader: DataLoader, num_classes: int, mode: str, snr_db: float = None):
    """Evaluate the classifier head on single IQ samples.

    Args:
        model: Network to evaluate (switched to eval mode here).
        loader: Yields (iq, label) batches.
        num_classes: Number of classes, used for the confusion-matrix label set.
        mode: "test" applies `augment_test_batch` at `snr_db`; any other value
            evaluates the raw IQ.
        snr_db: SNR in dB; must be a number when mode == "test".

    Returns:
        (mean per-batch CE loss, accuracy in percent, confusion matrix or None).
        The confusion matrix is only built when RETURN_CM is set.
    """
    model.eval()
    criterion = nn.CrossEntropyLoss()

    n_seen = 0
    n_correct = 0
    loss_total = 0.0
    n_batches = 0
    label_chunks, pred_chunks = [], []

    for iq, y in loader:
        iq = iq.to(DEVICE, non_blocking=True)
        y = y.to(DEVICE, non_blocking=True)

        if mode == "test":
            iq = augment_test_batch(iq, snr_db=float(snr_db))

        _, logits = model(iq_to_logspec_batch(iq), None)

        loss_total += criterion(logits, y).item()
        n_batches += 1

        pred = torch.argmax(logits, dim=1)
        n_correct += (pred == y).sum().item()
        n_seen += y.size(0)

        if RETURN_CM:
            label_chunks.append(y.detach().cpu().numpy())
            pred_chunks.append(pred.detach().cpu().numpy())

    acc = 100.0 * n_correct / max(n_seen, 1)

    cm = None
    if RETURN_CM and n_seen > 0:
        y_true = np.concatenate(label_chunks) if label_chunks else np.array([])
        y_pred = np.concatenate(pred_chunks) if pred_chunks else np.array([])
        cm = confusion_matrix(y_true, y_pred, labels=list(range(num_classes)))

    return (loss_total / max(n_batches, 1)), acc, cm


# ----------------------------
# 8) KFold 训练 + 测试 SNR sweep（带日志保存）
# ----------------------------
def train_kfold_wisig_fast_with_logging(compact_dataset, tx_names):
    """Run K-fold training on the training dates and an SNR-sweep test per fold, with logging.

    For each fold a fresh model is trained with the NT-Xent + cross-entropy loss on
    GPU-augmented Siamese pairs, validated on clean single samples, and early-stopped
    on validation loss. The best-validation weights are then evaluated on the test
    dates at every SNR in TEST_SNR_LIST. Config, per-epoch logs, per-fold test results,
    model weights and a cross-fold summary are written to a timestamped folder under
    SAVE_ROOT.

    Args:
        compact_dataset: Dataset dict as consumed by `build_index_list`.
        tx_names: Transmitter names that define the classes.

    Returns:
        Path of the folder that holds all saved artefacts.

    Raises:
        RuntimeError: If the train and test dates yield different TX-to-label mappings.
    """
    train_index, tx_i_to_label = build_index_list(compact_dataset, tx_names, train_dates, equalized, max_sig)
    test_index,  tx_i_to_label_test = build_index_list(compact_dataset, tx_names, test_dates,  equalized, max_sig)
    if tx_i_to_label != tx_i_to_label_test:
        raise RuntimeError("Train/Test TX label mapping mismatch.")

    num_classes = len(tx_i_to_label)

    timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    save_dir = f"{timestamp}_{SCRIPT_NAME}"
    save_folder = os.path.join(SAVE_ROOT, save_dir)
    os.makedirs(save_folder, exist_ok=True)

    results_txt = os.path.join(save_folder, "results.txt")

    # config
    fd_test = (TEST_V_KMH_FIXED / 3.6) / 3e8 * FC
    with open(os.path.join(save_folder, "config.txt"), "w", encoding="utf-8") as f:
        f.write(f"DEVICE={DEVICE}\nAMP={USE_AMP}\n")
        f.write(f"dataset_path={dataset_path}\ndataset_name={dataset_name}\n")
        f.write(f"tx_names={tx_names}\n")
        f.write(f"train_dates={train_dates}\n")
        f.write(f"test_dates={test_dates}\n")
        f.write(f"equalized={equalized}, max_sig={max_sig}\n")
        f.write(f"num_classes={num_classes}\n")
        f.write(f"BATCH_SIZE={BATCH_SIZE}, LR={LR}, WEIGHT_DECAY={WEIGHT_DECAY}\n")
        f.write(f"MAX_EPOCHS={MAX_EPOCHS}, N_SPLITS={N_SPLITS}\n")
        f.write(f"LR_PATIENCE={LR_PATIENCE}, LR_FACTOR={LR_FACTOR}, ES_PATIENCE={ES_PATIENCE}\n")
        f.write(f"TAU={TAU}, LAMBDA_CL={LAMBDA_CL}, LAMBDA_CE={LAMBDA_CE}\n")
        f.write(f"FS={FS}, FC={FC}\n")
        f.write(f"AUG_USE_MULTIPATH={AUG_USE_MULTIPATH}, RMS_DS_NS_RANGE={RMS_DS_NS_RANGE}, MAX_TAPS={MAX_TAPS}\n")
        f.write(f"AUG_USE_DOPPLER={AUG_USE_DOPPLER}, TRAIN_V_KMH_RANGE={TRAIN_V_KMH_RANGE}\n")
        f.write(f"AUG_USE_AWGN={AUG_USE_AWGN}, TRAIN_SNR_DB_RANGE={TRAIN_SNR_DB_RANGE}\n")
        f.write(f"TEST_V_KMH_FIXED={TEST_V_KMH_FIXED}, fd_test={fd_test}\n")
        f.write(f"TEST_USE_MULTIPATH={TEST_USE_MULTIPATH}, TEST_SNR_LIST={TEST_SNR_LIST}\n")
        f.write(f"SPEC_NFFT={SPEC_NFFT}, SPEC_WIN={SPEC_WIN}, SPEC_HOP={SPEC_HOP}, SPEC_SIZE={SPEC_SIZE}\n")
        f.write(f"workers(train/eval)={NUM_WORKERS_TRAIN}/{NUM_WORKERS_EVAL}\n")

    print(f"[INFO] Classes={num_classes} | TrainSamples={len(train_index)} | TestSamples={len(test_index)}")
    print(f"[INFO] SaveFolder: {save_folder}")

    write_line(results_txt, f"Classes={num_classes}, TrainSamples={len(train_index)}, TestSamples={len(test_index)}")
    write_line(results_txt, f"TrainDates={train_dates}, TestDates={test_dates}")
    write_line(results_txt, f"Train SNR~U{TRAIN_SNR_DB_RANGE}, v~U{TRAIN_V_KMH_RANGE}, multipath={AUG_USE_MULTIPATH}")
    write_line(results_txt, f"Test v={TEST_V_KMH_FIXED} (fd={fd_test:.2f}Hz), SNR sweep={TEST_SNR_LIST}, TEST_USE_MULTIPATH={TEST_USE_MULTIPATH}")
    write_line(results_txt, f"DEVICE={DEVICE}, AMP={USE_AMP}")
    write_line(results_txt, "-"*80)

    # One test loader is reused for every SNR; the augmentation happens inside eval.
    test_ds = WiSigSingleIQDataset(compact_dataset, test_index, tx_i_to_label, equalized=equalized)
    test_loader = DataLoader(
        test_ds, batch_size=BATCH_SIZE, shuffle=False,
        num_workers=NUM_WORKERS_EVAL, pin_memory=True
    )

    snr_to_accs = {snr: [] for snr in TEST_SNR_LIST}

    indices = np.arange(len(train_index))
    kf = KFold(n_splits=N_SPLITS, shuffle=True, random_state=SEED)

    for fold, (tr_idx, va_idx) in enumerate(kf.split(indices), 1):
        print(f"\n========== Fold {fold}/{N_SPLITS} ==========")
        write_line(results_txt, f"\n========== Fold {fold}/{N_SPLITS} ==========")

        tr_list = [train_index[i] for i in tr_idx]
        va_list = [train_index[i] for i in va_idx]

        tr_ds = WiSigSiameseIQDataset(compact_dataset, tr_list, tx_i_to_label, equalized=equalized, max_sig=max_sig)
        va_ds = WiSigSingleIQDataset(compact_dataset, va_list, tx_i_to_label, equalized=equalized)

        tr_loader = DataLoader(
            tr_ds, batch_size=BATCH_SIZE, shuffle=True, drop_last=True,
            num_workers=NUM_WORKERS_TRAIN, pin_memory=True
        )
        va_loader = DataLoader(
            va_ds, batch_size=BATCH_SIZE, shuffle=False,
            num_workers=NUM_WORKERS_EVAL, pin_memory=True
        )

        model = SpecFeatureNet(num_classes=num_classes).to(DEVICE)
        opt = torch.optim.Adam(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            opt, mode="min", factor=LR_FACTOR, patience=LR_PATIENCE
        )
        ce = nn.CrossEntropyLoss()
        # A fresh scaler per fold, so the loss-scale state of one fold never carries
        # over into the next fold's newly initialised model.
        scaler = torch.cuda.amp.GradScaler(enabled=USE_AMP)

        best_val_loss = float("inf")
        best_state = None
        es_count = 0
        best_epoch = 0

        # fold train log
        fold_log_path = os.path.join(save_folder, f"fold{fold}_trainlog.csv")
        fold_logger = CSVLogger(
            fold_log_path,
            header=["epoch", "lr", "train_loss", "val_loss", "val_acc", "best_val_loss", "es_count", "epoch_time_sec"]
        )

        for epoch in range(1, MAX_EPOCHS + 1):
            t0 = time.time()
            model.train()
            loss_sum, nb = 0.0, 0

            for iq1, iq2, y in tr_loader:
                iq1 = iq1.to(DEVICE, non_blocking=True)
                iq2 = iq2.to(DEVICE, non_blocking=True)
                y   = y.to(DEVICE, non_blocking=True)

                # Augment and transform both views in one pass, then split them again.
                iq_cat = torch.cat([iq1, iq2], dim=0)    # (2B,L,2)
                iq_cat = augment_train_batch(iq_cat)     # (2B,L,2)
                spec_cat = iq_to_logspec_batch(iq_cat)   # (2B,1,S,S)
                spec1, spec2 = spec_cat.chunk(2, dim=0)

                opt.zero_grad(set_to_none=True)
                with torch.cuda.amp.autocast(enabled=USE_AMP):
                    z1, p1, z2, p2 = model(spec1, spec2)
                    loss_cl = nt_xent_loss(z1, z2, tau=TAU)     # fp32 safe
                    loss_ce = 0.5 * (ce(p1, y) + ce(p2, y))
                    loss = LAMBDA_CL * loss_cl + LAMBDA_CE * loss_ce

                scaler.scale(loss).backward()
                scaler.step(opt)
                scaler.update()

                loss_sum += float(loss.item())
                nb += 1

            train_loss = loss_sum / max(nb, 1)
            val_loss, val_acc, _ = eval_single_iq(model, va_loader, num_classes=num_classes, mode="val")

            prev_lr = opt.param_groups[0]["lr"]
            scheduler.step(val_loss)
            cur_lr = opt.param_groups[0]["lr"]
            if cur_lr < prev_lr:
                msg = f"[LR DROP] {prev_lr:.2e} -> {cur_lr:.2e} (val_loss={val_loss:.4f})"
                print(msg)
                write_line(results_txt, msg)

            epoch_time = time.time() - t0

            # Update the best / early-stop state BEFORE logging, so the logged
            # best_val_loss and es_count describe this epoch rather than the previous one.
            improved = val_loss < best_val_loss - 1e-6
            if improved:
                best_val_loss = val_loss
                best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
                best_epoch = epoch
                es_count = 0
            else:
                es_count += 1

            msg = (f"Epoch {epoch:03d} | LR={cur_lr:.2e} | "
                   f"TrainLoss={train_loss:.4f} | ValLoss={val_loss:.4f} | ValAcc={val_acc:.2f}% | "
                   f"BestValLoss={best_val_loss:.4f} | ES={es_count}/{ES_PATIENCE}")
            print(msg)
            write_line(results_txt, msg)

            fold_logger.log_row([epoch, cur_lr, train_loss, val_loss, val_acc, best_val_loss, es_count, epoch_time])

            # Only a non-improving epoch can trigger early stopping.
            if (not improved) and es_count >= ES_PATIENCE:
                msg = "[INFO] Early stopping triggered."
                print(msg)
                write_line(results_txt, msg)
                break

        # Save the last weights, then reload and save the best-validation weights.
        torch.save(model.state_dict(), os.path.join(save_folder, f"model_fold{fold}.pth"))
        if best_state is not None:
            model.load_state_dict(best_state)
            torch.save(model.state_dict(), os.path.join(save_folder, f"best_model_fold{fold}.pth"))

        write_line(results_txt, f"[FOLD {fold}] BestEpoch={best_epoch}, BestValLoss={best_val_loss:.6f}")

        # Per-fold test sweep (saved as fold{K}_test_snr.csv)
        fold_test_csv = os.path.join(save_folder, f"fold{fold}_test_snr.csv")
        fold_test_logger = CSVLogger(fold_test_csv, header=["snr_db", "test_loss", "test_acc"])

        fold_snr_acc = {}
        for snr in TEST_SNR_LIST:
            test_loss, test_acc, _ = eval_single_iq(model, test_loader, num_classes=num_classes, mode="test", snr_db=float(snr))
            snr_to_accs[snr].append(test_acc)
            fold_snr_acc[snr] = test_acc
            fold_test_logger.log_row([snr, test_loss, test_acc])

        msg = "[FOLD TEST] " + ", ".join([f"{snr}:{fold_snr_acc[snr]:.2f}%" for snr in TEST_SNR_LIST])
        print(msg)
        write_line(results_txt, msg)

    # Cross-fold summary of the test sweep
    rows = []
    for snr in TEST_SNR_LIST:
        arr = np.array(snr_to_accs[snr], dtype=np.float64)
        mean = float(arr.mean()) if arr.size else 0.0
        std  = float(arr.std())  if arr.size else 0.0
        rows.append([snr, mean, std] + snr_to_accs[snr])

    csv_path = os.path.join(save_folder, "test_snr_sweep.csv")
    with open(csv_path, "w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        header = ["snr_db", "acc_mean", "acc_std"] + [f"fold{i}" for i in range(1, N_SPLITS+1)]
        writer.writerow(header)
        writer.writerows(rows)

    write_line(results_txt, "\n========== Overall Test SNR Sweep (mean±std over folds) ==========")
    for snr in TEST_SNR_LIST:
        arr = np.array(snr_to_accs[snr], dtype=np.float64)
        mean = float(arr.mean()) if arr.size else 0.0
        std  = float(arr.std())  if arr.size else 0.0
        write_line(results_txt, f"SNR {snr:>3} dB | Acc {mean:.2f} ± {std:.2f}")

    print(f"\n[INFO] All saved in: {save_folder}")
    print(f"[INFO] SNR sweep CSV: {csv_path}")
    return save_folder


# ----------------------------
# 9) main
# ----------------------------
if __name__ == "__main__":
    compact_dataset = load_compact_pkl_dataset(dataset_path, dataset_name)

    # Report how many transmitters, receivers and capture days the dataset holds, with their names.
    for caption, list_key in (
        ("数据集发射机数量：", "tx_list"),
        ("数据集接收机数量：", "rx_list"),
        ("数据集采集天数：", "capture_date_list"),
    ):
        entries = compact_dataset[list_key]
        print(caption, len(entries), "具体为：", entries)

    # Train on every transmitter; a specific subset (e.g. 6 TX) can be passed instead.
    train_kfold_wisig_fast_with_logging(compact_dataset, compact_dataset["tx_list"])


# In [6]:
# ============================================================
# LTE-V (.mat via HDF5) - Spec + SimCLR-style Siamese (FAST GPU PIPELINE + LOGGING)
#
# Key changes vs your previous LTEV SpecSiamese:
#   1) Split mode: FORCED SAMPLE-LEVEL stratified split over all (fi, si)
#   2) CV: StratifiedKFold on train sample indices
#   3) Positive pair (SimCLR): two random augmented views of THE SAME sample
#      (no cross-file / cross-sample positive pairing)
#
# Keeps:
#   - Contrastive learning (NT-Xent)
#   - Classification head CE term (can disable by setting LAMBDA_CE=0.0)
#   - GPU batch augmentation + GPU STFT + AMP
#   - SNR sweep evaluation with Doppler fixed at TEST_V_KMH_FIXED
#
# Saved files:
#   - config.txt
#   - results.txt
#   - fold{K}_trainlog.csv
#   - fold{K}_test_snr.csv
#   - test_snr_sweep.csv
#   - model_fold{K}.pth
#   - best_model_fold{K}.pth
# ============================================================

import os
import glob
import csv
import time
import random
import numpy as np
from datetime import datetime
from collections import Counter
from contextlib import nullcontext

import h5py
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn.metrics import confusion_matrix

# ----------------------------
# 0) Global config
# ----------------------------
SEED = 42


def seed_everything(seed=SEED):
    """Seed Python's `random`, NumPy, and PyTorch (CPU plus all CUDA devices) with `seed`."""
    for apply_seed in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        apply_seed(seed)


seed_everything(SEED)

# Run on CUDA when it is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.backends.cudnn.benchmark = True
# Best-effort TF32 / matmul-precision switches: any failure is deliberately ignored
# (presumably so the script still runs on torch builds lacking these settings).
try:
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
except Exception:
    pass
try:
    torch.set_float32_matmul_precision("high")
except Exception:
    pass

# Mixed precision is only enabled when running on CUDA.
USE_AMP = (DEVICE.type == "cuda")

DATA_PATH = "E:/rf_datasets_IQ_raw/"   # <-- change to your LTE-V folder
RECURSIVE_GLOB = True

# NOTE(review): presumably the sample rate and carrier frequency in Hz, as in the WiSig cell — confirm.
FS = 5e6
FC = 5.9e9

BATCH_SIZE = 256
LR = 3e-4
WEIGHT_DECAY = 0.0
MAX_EPOCHS = 300

N_SPLITS = 5
TEST_SIZE = 0.25

LR_PATIENCE = 10
LR_FACTOR = 0.5
ES_PATIENCE = 30

# Contrastive loss
TAU = 0.05
LAMBDA_CL = 1.0

# CE head (set LAMBDA_CE=0.0 if you want strictly self-supervised training)
LAMBDA_CE = 1.0

# Augmentation switches
AUG_USE_MULTIPATH = True
AUG_USE_DOPPLER   = True
AUG_USE_AWGN      = True

RMS_DS_NS_RANGE = (5.0, 300.0)

# If you want Doppler fixed at 120 for train/val/test, set range to (120,120).
TRAIN_V_KMH_RANGE = (120.0, 120.0)
TRAIN_SNR_DB_RANGE = (-40.0, 20.0)

TEST_V_KMH_FIXED = 120.0
TEST_USE_MULTIPATH = False
TEST_RMS_DS_NS_RANGE = RMS_DS_NS_RANGE
TEST_SNR_LIST = list(range(20, -45, -5))  # 20, 15, ..., -40

# Spectrogram
SPEC_NFFT = 128
SPEC_WIN  = 128
SPEC_HOP  = 16
SPEC_SIZE = 64

# Multipath taps
MAX_TAPS = 16

# Dataset cap (optional)
MAX_SAMPLES_PER_FILE = None

# Dataloader
NUM_WORKERS_TRAIN = 0
NUM_WORKERS_EVAL  = 0

# Logging / saving (the results root directory is created at import time)
SAVE_ROOT = "./training_results"
os.makedirs(SAVE_ROOT, exist_ok=True)
SCRIPT_NAME = "LTEV_SpecSimCLR_FASTGPU_SNRsweep_Doppler120_SampleSplit"

RETURN_CM = False

# ----------------------------
# AMP compatibility (robust across torch versions)
# ----------------------------
from contextlib import nullcontext

def amp_autocast(enabled: bool = True):
    """
    Return an autocast context manager that works across torch versions.

    When `enabled` is false a no-op context is returned. Otherwise
    `torch.amp.autocast` is preferred (keyword form first, positional form if that
    raises TypeError); without it, `torch.cuda.amp.autocast` is used on CUDA and a
    no-op context elsewhere.
    """
    if not enabled:
        return nullcontext()

    amp_module = getattr(torch, "amp", None)
    if amp_module is not None and hasattr(amp_module, "autocast"):
        try:
            # Newer signature: autocast(device_type="cuda", ...)
            return torch.amp.autocast(device_type=DEVICE.type, enabled=True)
        except TypeError:
            # Older signature: autocast("cuda", ...)
            return torch.amp.autocast(DEVICE.type, enabled=True)

    # Legacy fallback: only CUDA has an autocast implementation to offer here.
    return torch.cuda.amp.autocast(enabled=True) if DEVICE.type == "cuda" else nullcontext()

def make_grad_scaler(enabled: bool = True):
    """
    Build a gradient scaler for mixed-precision training, or return None.

    None is returned when scaling is not requested or the global DEVICE is not
    CUDA. Otherwise several torch.amp.GradScaler constructor signatures are
    tried in turn (they differ between torch versions), and
    torch.cuda.amp.GradScaler is used as the last resort.
    """
    if not enabled:
        return None
    if DEVICE.type != "cuda":
        return None

    amp_module = getattr(torch, "amp", None)
    if amp_module is not None and hasattr(amp_module, "GradScaler"):
        # (positional args, keyword args) for each known constructor signature
        candidate_signatures = (
            ((), {"device_type": "cuda", "enabled": True}),  # some versions
            (("cuda",), {"enabled": True}),                  # recommended warning style
            ((), {"enabled": True}),                         # minimal
        )
        for pos_args, kw_args in candidate_signatures:
            try:
                return torch.amp.GradScaler(*pos_args, **kw_args)
            except TypeError:
                # Signature not supported by this torch version; try the next one.
                continue

    return torch.cuda.amp.GradScaler(enabled=True)


# ----------------------------
# 1) Logging helpers
# ----------------------------
class CSVLogger:
    """
    Minimal append-only CSV logger.

    On construction the target file is (re)created and the header row written.
    Every ``log_row`` call reopens the file in append mode, so rows reach disk
    even if the process is interrupted later.
    """

    def __init__(self, path, header):
        """
        Args:
            path: destination CSV file; parent directories are created if needed.
            header: sequence of column names written as the first row.
        """
        self.path = path
        self.header = header
        self._init_file()

    def _init_file(self):
        """Create parent directories (if any) and write the header row."""
        parent_dir = os.path.dirname(self.path)
        # A bare file name has an empty dirname, and os.makedirs("") raises
        # FileNotFoundError, so only create a directory when one is given.
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        with open(self.path, "w", newline="", encoding="utf-8") as f:
            csv.writer(f).writerow(self.header)

    def log_row(self, row):
        """Append one row (a sequence of values) to the CSV file."""
        with open(self.path, "a", newline="", encoding="utf-8") as f:
            csv.writer(f).writerow(row)

def write_line(path, line):
    """Append `line` to the text file at `path`, stripping trailing whitespace and ending it with one newline."""
    with open(path, mode="a", encoding="utf-8") as out:
        trimmed = line.rstrip()
        out.write(f"{trimmed}\n")

# ----------------------------
# 2) LTE-V data reading
# ----------------------------
def decode_txid(txid_arr: np.ndarray) -> str:
    """
    Turn an array of character codes into a string label.

    The array is flattened, zero codes are skipped, codes that `chr` cannot
    convert are ignored, and the result is stripped of surrounding whitespace.
    Returns "UNKNOWN" when nothing printable remains.
    """
    pieces = []
    for raw in txid_arr.flatten():
        code = int(raw)
        if code == 0:
            continue
        try:
            pieces.append(chr(code))
        except (ValueError, OverflowError):
            # chr() rejects codes outside the valid code-point range; skip them.
            continue
    text = "".join(pieces).strip()
    return text or "UNKNOWN"

def load_dmrs_complex_from_file(h5f: h5py.File) -> np.ndarray:
    """
    Read `rfDataset/dmrs` from an open HDF5 file as a 2-D complex64 array.

    Three layouts are accepted: a compound dataset with "real"/"imag" fields,
    a group holding "real" and "imag" members, or any object whose full read
    yields a structured array with those fields.

    Raises:
        RuntimeError: if real/imag parts cannot be located, or the assembled
            array is not 2-D.
    """
    dmrs_obj = h5f["rfDataset"]["dmrs"]

    compound_dataset = isinstance(dmrs_obj, h5py.Dataset) and dmrs_obj.dtype.fields is not None
    if compound_dataset or isinstance(dmrs_obj, h5py.Group):
        real = dmrs_obj["real"][()]
        imag = dmrs_obj["imag"][()]
    else:
        loaded = dmrs_obj[()]
        structured = hasattr(loaded, "dtype") and loaded.dtype.fields is not None
        if not structured:
            raise RuntimeError("Cannot parse dmrs (not compound dataset nor group real/imag).")
        real, imag = loaded["real"], loaded["imag"]

    dmrs_complex = np.asarray(real + 1j * imag)
    if dmrs_complex.ndim != 2:
        raise RuntimeError(f"dmrs_complex dim error: {dmrs_complex.shape}")

    # Orientation heuristic: a short first axis (<= 2048) with a longer second
    # axis is transposed. Presumably this puts one signal per row - TODO confirm
    # against the data files.
    n_rows, n_cols = dmrs_complex.shape
    if n_rows <= 2048 and n_cols > n_rows:
        dmrs_complex = dmrs_complex.T

    return dmrs_complex.astype(np.complex64)

def load_ltev_dataset(data_path: str, recursive: bool = True, max_samples_per_file=None):
    """
    Load every readable ``.mat`` (HDF5) file under `data_path`.

    Args:
        data_path: directory to search.
        recursive: search sub-directories as well when True.
        max_samples_per_file: if not None, keep only the first N rows per file.

    Returns:
        (signals_by_file, label_str_by_file, file_paths, mode_L) where
        signals_by_file[i] is a 2-D complex64 array whose second axis has been
        cut or zero-padded to mode_L (the most common length across files),
        label_str_by_file[i] is the decoded txID string, and file_paths[i] is
        the source path.

    Raises:
        RuntimeError: if no ``.mat`` file is found, or none could be read.

    Files that fail to read/parse are skipped with a warning (best effort).
    """
    if recursive:
        pattern = os.path.join(data_path, "**", "*.mat")
    else:
        pattern = os.path.join(data_path, "*.mat")

    # glob returns matches in arbitrary, platform-dependent order. Sorting makes
    # the file indices - and every seeded split derived from them - reproducible.
    mat_files = sorted(glob.glob(pattern, recursive=recursive))

    if not mat_files:
        raise RuntimeError(f"No .mat files found: {data_path}")

    signals_by_file = []
    label_str_by_file = []
    file_paths = []

    print(f"[INFO] Found {len(mat_files)} .mat files")
    for fp in mat_files:
        try:
            with h5py.File(fp, "r") as f:
                txid_arr = np.asarray(f["rfDataset"]["txID"][()])
                tx_str = decode_txid(txid_arr)

                dmrs_complex = load_dmrs_complex_from_file(f)
                if max_samples_per_file is not None:
                    dmrs_complex = dmrs_complex[: int(max_samples_per_file)]

                # Files with no rows contribute nothing; skip them silently.
                if dmrs_complex.shape[0] == 0:
                    continue

                signals_by_file.append(dmrs_complex)
                label_str_by_file.append(tx_str)
                file_paths.append(fp)
        except Exception as e:
            # Deliberate best effort: one unreadable file must not abort the run.
            print(f"[WARN] Skip file due to read/parse error: {fp} | {repr(e)}")

    if not signals_by_file:
        raise RuntimeError("All files failed to read or are empty.")

    cnt = Counter(label_str_by_file)
    print("[INFO] txID classes:", len(cnt))
    for k, v in sorted(cnt.items(), key=lambda x: (-x[1], x[0])):
        print(f"  {k}: {v} files")

    Ls = [arr.shape[1] for arr in signals_by_file]
    mode_L = Counter(Ls).most_common(1)[0][0]
    print(f"[INFO] DMRS length stats: min={min(Ls)}, max={max(Ls)}, mode={mode_L}")

    # unify length to mode_L: truncate longer signals, zero-pad shorter ones
    for i, x in enumerate(signals_by_file):
        cur_len = x.shape[1]
        if cur_len == mode_L:
            continue
        if cur_len > mode_L:
            signals_by_file[i] = x[:, :mode_L].astype(np.complex64)
        else:
            pad = mode_L - cur_len
            signals_by_file[i] = np.pad(x, ((0, 0), (0, pad)), mode="constant").astype(np.complex64)

    return signals_by_file, label_str_by_file, file_paths, mode_L

# ----------------------------
# 3) Sample-level split + StratifiedKFold
# ----------------------------
def build_all_sample_list(signals_by_file):
    """Enumerate every (file_index, row_index) pair, file by file and row by row."""
    return [
        (file_idx, row_idx)
        for file_idx, block in enumerate(signals_by_file)
        for row_idx in range(block.shape[0])
    ]

def split_train_test_by_samples(signals_by_file, label_idx_by_file, test_size=TEST_SIZE):
    """
    Stratified train/test split over individual samples (not over files).

    Every (file_index, row_index) pair is labelled with its file's class index,
    then train_test_split divides the pairs with stratification and the global
    SEED. Returns (train_samples, test_samples) as lists of such pairs.
    """
    def _class_counts(samples):
        # class index -> number of samples, ordered by class index
        tally = Counter([label_idx_by_file[fi] for (fi, _) in samples])
        return dict(sorted(tally.items()))

    all_samples = build_all_sample_list(signals_by_file)
    y_samples = np.array([label_idx_by_file[fi] for (fi, _) in all_samples], dtype=np.int64)

    train_pos, test_pos = train_test_split(
        np.arange(len(all_samples)),
        test_size=test_size,
        stratify=y_samples,
        random_state=SEED,
    )
    train_samples = [all_samples[pos] for pos in train_pos]
    test_samples = [all_samples[pos] for pos in test_pos]

    print("[INFO] Split mode: SAMPLE-LEVEL (force)")
    print(f"[INFO] Train samples={len(train_samples)}, Test samples={len(test_samples)}")
    print("[INFO] Train class sample counts:", _class_counts(train_samples))
    print("[INFO] Test  class sample counts:", _class_counts(test_samples))

    return train_samples, test_samples

# ----------------------------
# 4) Datasets
# ----------------------------
class LTEVSimCLRIQDataset(Dataset):
    """
    Training dataset yielding raw (iq, label) items.

    The two augmented views forming a SimCLR positive pair are not produced
    here; the training loop creates them from the same iq. The label is
    returned so the optional CE head and evaluation can use it.
    """

    def __init__(self, signals_by_file, label_idx_by_file, sample_list):
        # sample_list holds (file_index, row_index) pairs into signals_by_file
        self.signals_by_file = signals_by_file
        self.label_idx_by_file = label_idx_by_file
        self.sample_list = sample_list

    def __len__(self):
        return len(self.sample_list)

    def __getitem__(self, idx):
        file_idx, row_idx = self.sample_list[idx]
        label = int(self.label_idx_by_file[file_idx])
        waveform = self.signals_by_file[file_idx][row_idx]  # complex (L,)

        # Pack real and imaginary parts side by side: (L,2) float32
        iq = np.empty(waveform.shape + (2,), dtype=np.float32)
        iq[..., 0] = waveform.real
        iq[..., 1] = waveform.imag
        return iq, np.int64(label)

class LTEVSingleIQDataset(Dataset):
    """Dataset yielding one raw (iq, label) item per (file_index, row_index) pair."""

    def __init__(self, signals_by_file, label_idx_by_file, sample_list):
        self.signals_by_file = signals_by_file
        self.label_idx_by_file = label_idx_by_file
        self.sample_list = sample_list

    def __len__(self):
        return len(self.sample_list)

    def __getitem__(self, idx):
        file_idx, row_idx = self.sample_list[idx]
        class_idx = np.int64(int(self.label_idx_by_file[file_idx]))
        row = self.signals_by_file[file_idx][row_idx]
        # Real and imaginary parts become the two columns of a float32 array.
        iq_pairs = np.stack((row.real, row.imag), axis=-1)
        return iq_pairs.astype(np.float32), class_idx

# ----------------------------
# 5) GPU batch augmentation
# ----------------------------
def _to_complex(iq_b: torch.Tensor) -> torch.Tensor:
    return iq_b[..., 0].to(torch.float32) + 1j * iq_b[..., 1].to(torch.float32)

def _from_complex(sig: torch.Tensor) -> torch.Tensor:
    return torch.stack([sig.real, sig.imag], dim=-1)

def batch_normalize_power(iq_b: torch.Tensor) -> torch.Tensor:
    """
    Scale each batch item so its mean power (I^2 + Q^2 averaged over axis 1) is ~1.

    A 1e-12 floor is added to the power before the reciprocal square root, so
    an all-zero item stays all-zero instead of producing NaN.
    """
    i_part = iq_b[..., 0]
    q_part = iq_b[..., 1]
    mean_power = (i_part ** 2 + q_part ** 2).mean(dim=1, keepdim=True) + 1e-12
    gain = torch.rsqrt(mean_power).unsqueeze(-1)  # broadcast over samples and I/Q
    return iq_b * gain

def batch_apply_doppler(iq_b: torch.Tensor, v_kmh: torch.Tensor) -> torch.Tensor:
    """
    Multiply each signal by a per-item complex exponential (constant frequency shift).

    iq_b is unpacked as (B, L, 2); v_kmh holds one speed in km/h per item.
    The shift is fd = (v / 3e8) * FC with v in m/s, and sample n is multiplied
    by exp(j * 2*pi * fd * n / FS). FC and FS are module-level constants
    (presumably carrier frequency and sampling rate in Hz - confirm where defined).
    """
    _, num_samples, _ = iq_b.shape
    light_speed = 3e8

    # km/h -> m/s -> frequency shift per batch item, shape (B,)
    doppler_hz = (v_kmh / 3.6 / light_speed) * FC

    sample_idx = torch.arange(num_samples, device=iq_b.device, dtype=torch.float32).unsqueeze(0)  # (1,L)
    fd_column = doppler_hz.unsqueeze(1).to(torch.float32)  # (B,1)
    rotation = torch.exp(1j * 2.0 * np.pi * fd_column * sample_idx / FS)  # (B,L)

    return _from_complex(_to_complex(iq_b) * rotation)

def _grouped_conv1d_real(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    # x: (B,1,L), w: (B,1,K)
    B, _, L = x.shape
    x2 = x.permute(1, 0, 2).contiguous()  # (1,B,L)
    y2 = F.conv1d(x2, w, padding=w.shape[-1] - 1, groups=B)  # (1,B,L+K-1)
    return y2.squeeze(0)  # (B,L+K-1)

def batch_apply_multipath(iq_b: torch.Tensor, rms_ns: torch.Tensor, max_taps: int = MAX_TAPS) -> torch.Tensor:
    """
    Filter each signal with its own random complex multi-tap channel.

    For every item, tap powers follow exp(-k / rms_samples) (normalised to sum
    to 1), taps are drawn as complex Gaussians with those powers, and the tap
    vector is rescaled to unit energy. The complex filtering is carried out as
    four real filterings through _grouped_conv1d_real, and only the first L
    output samples are kept.

    iq_b is unpacked as (B, L, 2); rms_ns holds one RMS delay spread (ns) per item.
    """
    batch, num_samples, _ = iq_b.shape
    dev = iq_b.device

    # ns -> s -> samples, floored so the exponent below never divides by ~0
    spread_samples = (rms_ns * 1e-9 * FS).clamp(min=1e-3)

    tap_idx = torch.arange(max_taps, device=dev, dtype=torch.float32).unsqueeze(0)
    pdp = torch.exp(-tap_idx / spread_samples.unsqueeze(1))
    pdp = pdp / (pdp.sum(dim=1, keepdim=True) + 1e-12)

    # Real part is drawn first, then imaginary part (keeps the RNG stream order).
    tap_re = torch.randn(batch, max_taps, device=dev) * torch.sqrt(pdp / 2.0)
    tap_im = torch.randn(batch, max_taps, device=dev) * torch.sqrt(pdp / 2.0)

    # Rescale each realisation to unit total tap energy
    inv_norm = torch.rsqrt((tap_re**2 + tap_im**2).sum(dim=1, keepdim=True) + 1e-12)
    tap_re = tap_re * inv_norm
    tap_im = tap_im * inv_norm

    sig_re = iq_b[..., 0].unsqueeze(1)
    sig_im = iq_b[..., 1].unsqueeze(1)
    ker_re = tap_re.unsqueeze(1)
    ker_im = tap_im.unsqueeze(1)

    # (a + jb) * (c + jd) = (ac - bd) + j(ad + bc), applied through the filter helper
    out_re = _grouped_conv1d_real(sig_re, ker_re) - _grouped_conv1d_real(sig_im, ker_im)
    out_im = _grouped_conv1d_real(sig_re, ker_im) + _grouped_conv1d_real(sig_im, ker_re)

    return torch.stack([out_re[:, :num_samples], out_im[:, :num_samples]], dim=-1)

def batch_add_awgn(iq_b: torch.Tensor, snr_db: torch.Tensor) -> torch.Tensor:
    """
    Add complex white Gaussian noise at a per-item SNR.

    The signal power of each item is measured (mean |x|^2 over axis 1, plus a
    1e-12 floor), the noise power is signal power / 10^(snr_db/10), and it is
    split equally between the real and imaginary noise components.
    """
    batch, num_samples, _ = iq_b.shape
    clean = _to_complex(iq_b)

    sig_power = (clean.real**2 + clean.imag**2).mean(dim=1) + 1e-12
    noise_power = sig_power / (10.0 ** (snr_db / 10.0))
    sigma = torch.sqrt(noise_power / 2.0).unsqueeze(1)  # per-component std, (B,1)

    # Real component is drawn before the imaginary one (keeps the RNG stream order).
    noise_re = torch.randn(batch, num_samples, device=iq_b.device)
    noise_im = torch.randn(batch, num_samples, device=iq_b.device)
    noisy = clean + sigma * (noise_re + 1j * noise_im)
    return _from_complex(noisy)

def augment_train_batch(iq_b: torch.Tensor) -> torch.Tensor:
    """
    Produce one randomly augmented view of a training batch.

    After power normalisation, the enabled stages run in this order, each with
    independent per-item random parameters: multipath (RMS delay spread from
    RMS_DS_NS_RANGE), Doppler (speed from TRAIN_V_KMH_RANGE), AWGN (SNR from
    TRAIN_SNR_DB_RANGE). Calling it twice on the same batch gives two views.
    """
    iq_b = batch_normalize_power(iq_b)
    batch = iq_b.shape[0]
    dev = iq_b.device

    def _uniform(low, high):
        # One draw per batch item from U[low, high)
        return (high - low) * torch.rand(batch, device=dev) + low

    if AUG_USE_MULTIPATH:
        rms_ns = _uniform(RMS_DS_NS_RANGE[0], RMS_DS_NS_RANGE[1])
        iq_b = batch_apply_multipath(iq_b, rms_ns, max_taps=MAX_TAPS)

    if AUG_USE_DOPPLER:
        v_low, v_high = TRAIN_V_KMH_RANGE
        if abs(v_high - v_low) < 1e-12:
            # Zero-width range: constant speed, no random draw consumed
            speed = torch.full((batch,), float(v_low), device=dev)
        else:
            speed = _uniform(v_low, v_high)
        iq_b = batch_apply_doppler(iq_b, speed)

    if AUG_USE_AWGN:
        snr_low, snr_high = TRAIN_SNR_DB_RANGE
        iq_b = batch_add_awgn(iq_b, _uniform(snr_low, snr_high))

    return iq_b

def augment_test_batch(iq_b: torch.Tensor, snr_db: float) -> torch.Tensor:
    """
    Apply the test-time channel to a batch.

    Steps: power normalisation, optional random multipath (only when
    TEST_USE_MULTIPATH is set), Doppler at the fixed speed TEST_V_KMH_FIXED,
    then AWGN at the single SNR `snr_db` for every item.
    """
    iq_b = batch_normalize_power(iq_b)
    batch = iq_b.shape[0]
    dev = iq_b.device

    if TEST_USE_MULTIPATH:
        low, high = TEST_RMS_DS_NS_RANGE[0], TEST_RMS_DS_NS_RANGE[1]
        rms_ns = (high - low) * torch.rand(batch, device=dev) + low
        iq_b = batch_apply_multipath(iq_b, rms_ns, max_taps=MAX_TAPS)

    speed = torch.full((batch,), float(TEST_V_KMH_FIXED), device=dev)
    iq_b = batch_apply_doppler(iq_b, speed)

    snr = torch.full((batch,), float(snr_db), device=dev)
    return batch_add_awgn(iq_b, snr)

# ----------------------------
# 6) GPU batch STFT
# ----------------------------
# Hann windows already built, keyed by (device type, device index, window length)
_WINDOW_CACHE = {}
def get_hann_window(device: torch.device):
    """Return the periodic Hann window of length SPEC_WIN on `device`, building it once per key."""
    cache_key = (device.type, device.index, SPEC_WIN)
    window = _WINDOW_CACHE.get(cache_key)
    if window is None:
        window = torch.hann_window(SPEC_WIN, periodic=True, device=device)
        _WINDOW_CACHE[cache_key] = window
    return window

def iq_to_logspec_batch(iq_b: torch.Tensor) -> torch.Tensor:
    """
    Convert a batch of IQ signals into normalised log-magnitude spectrograms.

    Pipeline: complex STFT (SPEC_NFFT / SPEC_HOP / SPEC_WIN, Hann window,
    centred) -> log(|S| + 1e-12) -> per-item z-score over both spectrogram
    axes -> non-finite values set to 0 -> bilinear resize to
    (SPEC_SIZE, SPEC_SIZE). Returns a (B,1,SPEC_SIZE,SPEC_SIZE) tensor.
    """
    complex_sig = _to_complex(iq_b).to(torch.complex64)  # (B,L)
    stft_args = dict(
        n_fft=SPEC_NFFT,
        hop_length=SPEC_HOP,
        win_length=SPEC_WIN,
        window=get_hann_window(iq_b.device),
        center=True,
        return_complex=True,
    )
    spec = torch.stft(complex_sig, **stft_args)  # (B,F,T)

    log_mag = torch.log(torch.abs(spec) + 1e-12)

    # per-sample z-score
    center = log_mag.mean(dim=(1, 2), keepdim=True)
    spread = log_mag.std(dim=(1, 2), keepdim=True) + 1e-6
    log_mag = (log_mag - center) / spread
    try:
        log_mag = torch.nan_to_num(log_mag, nan=0.0, posinf=0.0, neginf=0.0)
    except Exception:
        # Fallback when nan_to_num is unavailable: zero non-finite entries in place
        log_mag[~torch.isfinite(log_mag)] = 0.0

    image = log_mag.unsqueeze(1)  # (B,1,F,T)
    return F.interpolate(image, size=(SPEC_SIZE, SPEC_SIZE), mode="bilinear", align_corners=False)

# ----------------------------
# 7) Model
# ----------------------------
class BasicBlock2D(nn.Module):
    """
    Residual block: two 3x3 conv + batch-norm layers with a skip connection.

    When the stride is not 1 or the channel count changes, the skip path goes
    through a 1x1 conv + batch-norm (`down`) so both branches match in shape.
    """

    def __init__(self, in_ch, out_ch, stride=1):
        super().__init__()
        # Submodule names and creation order are kept stable: they define the
        # state_dict keys and the order in which weights are initialised.
        self.conv1 = nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1   = nn.BatchNorm2d(out_ch)
        self.conv2 = nn.Conv2d(out_ch, out_ch, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2   = nn.BatchNorm2d(out_ch)
        self.relu  = nn.ReLU(inplace=True)

        needs_projection = (stride != 1) or (in_ch != out_ch)
        self.down = None
        if needs_projection:
            self.down = nn.Sequential(
                nn.Conv2d(in_ch, out_ch, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_ch),
            )

    def forward(self, x):
        # Main branch: conv-bn-relu followed by conv-bn
        main = self.conv1(x)
        main = self.relu(self.bn1(main))
        main = self.bn2(self.conv2(main))

        # Skip branch: identity, or the 1x1 projection when shapes differ
        skip = x if self.down is None else self.down(x)
        return self.relu(main + skip)

class SpecFeatureNet(nn.Module):
    """
    Spectrogram encoder with an embedding output and a classification head.

    Layout: 7x7 stride-2 stem -> four BasicBlock2D stages (32, 32, 64, 64
    channels) -> global average pooling -> fc1 (512) -> fc2 (256-d embedding z)
    -> cls (num_classes logits computed from z).
    """

    def __init__(self, num_classes: int):
        super().__init__()
        # Submodule names and creation order are kept stable: they define the
        # state_dict keys and the order in which weights are initialised.
        self.conv1 = nn.Conv2d(1, 32, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1   = nn.BatchNorm2d(32)
        self.relu  = nn.ReLU(inplace=True)

        self.b1 = BasicBlock2D(32, 32, stride=1)
        self.b2 = BasicBlock2D(32, 32, stride=1)
        self.b3 = BasicBlock2D(32, 64, stride=1)
        self.b4 = BasicBlock2D(64, 64, stride=1)

        self.gap = nn.AdaptiveAvgPool2d((1, 1))
        self.fc1 = nn.Linear(64, 512)
        self.fc2 = nn.Linear(512, 256)      # embedding z
        self.cls = nn.Linear(256, num_classes)

    def forward_once(self, x):
        """Encode one batch of spectrograms; returns (embedding z, class logits)."""
        feat = self.relu(self.bn1(self.conv1(x)))
        for stage in (self.b1, self.b2, self.b3, self.b4):
            feat = stage(feat)
        pooled = self.gap(feat).squeeze(-1).squeeze(-1)  # (B,64)
        hidden = F.relu(self.fc1(pooled))
        z = self.fc2(hidden)
        return z, self.cls(z)

    def forward(self, x1, x2=None):
        """
        With one input return (z1, logits1); with two inputs return
        (z1, logits1, z2, logits2), both passed through the same weights.
        """
        first = self.forward_once(x1)
        if x2 is None:
            return first
        second = self.forward_once(x2)
        return first[0], first[1], second[0], second[1]

# ----------------------------
# 8) Loss + Eval
# ----------------------------
def nt_xent_loss(z1: torch.Tensor, z2: torch.Tensor, tau: float = TAU) -> torch.Tensor:
    """
    NT-Xent contrastive loss over a batch of positive pairs (z1[i], z2[i]).

    Both embedding sets are cast to float32, concatenated and L2-normalised.
    Each of the 2N rows is scored against all others (self-similarity masked
    out) with temperature `tau`; the target of row i is its partner at
    (i + N) mod 2N. Returns the mean negative log-probability of the partner.

    The body runs inside amp_autocast(enabled=False).
    """
    with amp_autocast(enabled=False):
        a = z1.float()
        b = z2.float()
        n_pairs = a.size(0)

        emb = F.normalize(torch.cat([a, b], dim=0), dim=1)

        logits = (emb @ emb.T) / float(tau)
        # Exclude each row's similarity with itself from the softmax
        logits.fill_diagonal_(torch.finfo(logits.dtype).min)

        row_idx = torch.arange(2 * n_pairs, device=emb.device)
        partner_idx = (row_idx + n_pairs) % (2 * n_pairs)

        log_prob = logits - torch.logsumexp(logits, dim=1, keepdim=True)
        per_row = -log_prob[row_idx, partner_idx]
        return per_row.mean()

@torch.no_grad()
def eval_single_iq(model: SpecFeatureNet, loader: DataLoader, num_classes: int, mode: str, snr_db: float = None):
    """
    Evaluate the classification head over a loader of (iq, label) batches.

    When mode == "test" every batch first goes through augment_test_batch at
    `snr_db`; for any other mode the raw IQ is used as is.

    Returns:
        (mean of per-batch CE losses, accuracy in percent, confusion matrix or
        None). The confusion matrix is built only when RETURN_CM is set and at
        least one sample was seen.
    """
    model.eval()
    criterion = nn.CrossEntropyLoss()

    n_seen = 0
    n_hit = 0
    loss_total = 0.0
    n_batches = 0
    label_chunks, pred_chunks = [], []

    for iq, y in loader:
        iq = iq.to(DEVICE, non_blocking=True)
        y = y.to(DEVICE, non_blocking=True)

        if mode == "test":
            iq = augment_test_batch(iq, snr_db=float(snr_db))

        spec = iq_to_logspec_batch(iq)
        logits = model(spec, None)[1]

        loss_total += float(criterion(logits, y).item())
        n_batches += 1

        pred = torch.argmax(logits, dim=1)
        n_hit += (pred == y).sum().item()
        n_seen += y.size(0)

        if RETURN_CM:
            label_chunks.append(y.detach().cpu().numpy())
            pred_chunks.append(pred.detach().cpu().numpy())

    acc = 100.0 * n_hit / max(n_seen, 1)

    cm = None
    if RETURN_CM and n_seen > 0:
        y_true = np.concatenate(label_chunks) if label_chunks else np.array([])
        y_pred = np.concatenate(pred_chunks) if pred_chunks else np.array([])
        cm = confusion_matrix(y_true, y_pred, labels=list(range(num_classes)))

    mean_loss = loss_total / max(n_batches, 1)
    return mean_loss, acc, cm

# ----------------------------
# 9) Training (SimCLR positive pairs) + SNR sweep
# ----------------------------
def train_kfold_ltev_simclr_sample_split(data_path: str):
    """
    Run the whole experiment: load data, split, train each CV fold, sweep SNR.

    Steps:
      1. Load all files and map txID strings to class indices (sorted order).
      2. Hold out a stratified sample-level test set.
      3. Create a timestamped run folder; write config.txt and the results.txt header.
      4. For each StratifiedKFold fold of the training samples, train one model
         and evaluate it on the shared test loader over TEST_SNR_LIST.
      5. Write the cross-fold mean/std per SNR to test_snr_sweep.csv and results.txt.

    Returns the path of the run folder.
    """
    signals_by_file, label_str_by_file, file_paths, dmrs_len = load_ltev_dataset(
        data_path, recursive=RECURSIVE_GLOB, max_samples_per_file=MAX_SAMPLES_PER_FILE
    )

    label_list = sorted(list(set(label_str_by_file)))
    label_to_idx = {name: pos for pos, name in enumerate(label_list)}
    label_idx_by_file = [label_to_idx[name] for name in label_str_by_file]
    num_classes = len(label_list)

    # Forced sample-level split
    train_samples, test_samples = split_train_test_by_samples(
        signals_by_file, label_idx_by_file, test_size=TEST_SIZE
    )

    timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    save_folder = os.path.join(SAVE_ROOT, f"{timestamp}_{SCRIPT_NAME}")
    os.makedirs(save_folder, exist_ok=True)

    results_txt = os.path.join(save_folder, "results.txt")
    fd_test = (TEST_V_KMH_FIXED / 3.6) / 3e8 * FC

    # config
    with open(os.path.join(save_folder, "config.txt"), "w", encoding="utf-8") as f:
        config_chunks = [
            f"DEVICE={DEVICE}\nAMP={USE_AMP}\n",
            f"DATA_PATH={data_path}\nRECURSIVE_GLOB={RECURSIVE_GLOB}\n",
            f"MAX_SAMPLES_PER_FILE={MAX_SAMPLES_PER_FILE}\n",
            f"DMRS_LEN(L)={dmrs_len}\n",
            f"num_classes={num_classes}\n",
            "split_mode=sample_level_force\n",
            f"TEST_SIZE={TEST_SIZE}\n",
            f"BATCH_SIZE={BATCH_SIZE}, LR={LR}, WEIGHT_DECAY={WEIGHT_DECAY}\n",
            f"MAX_EPOCHS={MAX_EPOCHS}, N_SPLITS={N_SPLITS}\n",
            f"LR_PATIENCE={LR_PATIENCE}, LR_FACTOR={LR_FACTOR}, ES_PATIENCE={ES_PATIENCE}\n",
            f"TAU={TAU}, LAMBDA_CL={LAMBDA_CL}, LAMBDA_CE={LAMBDA_CE}\n",
            f"FS={FS}, FC={FC}\n",
            f"AUG_USE_MULTIPATH={AUG_USE_MULTIPATH}, RMS_DS_NS_RANGE={RMS_DS_NS_RANGE}, MAX_TAPS={MAX_TAPS}\n",
            f"AUG_USE_DOPPLER={AUG_USE_DOPPLER}, TRAIN_V_KMH_RANGE={TRAIN_V_KMH_RANGE}\n",
            f"AUG_USE_AWGN={AUG_USE_AWGN}, TRAIN_SNR_DB_RANGE={TRAIN_SNR_DB_RANGE}\n",
            f"TEST_V_KMH_FIXED={TEST_V_KMH_FIXED}, fd_test={fd_test}\n",
            f"TEST_USE_MULTIPATH={TEST_USE_MULTIPATH}, TEST_SNR_LIST={TEST_SNR_LIST}\n",
            f"SPEC_NFFT={SPEC_NFFT}, SPEC_WIN={SPEC_WIN}, SPEC_HOP={SPEC_HOP}, SPEC_SIZE={SPEC_SIZE}\n",
            f"workers(train/eval)={NUM_WORKERS_TRAIN}/{NUM_WORKERS_EVAL}\n",
            "CV=StratifiedKFold(on train sample indices)\n",
            "POS_PAIR=SimCLR(two random augmentations of same sample)\n",
        ]
        f.writelines(config_chunks)

    size_summary = (
        f"Classes={num_classes}, TrainSamples={len(train_samples)}, "
        f"TestSamples={len(test_samples)}, L={dmrs_len}"
    )
    print(f"[INFO] DEVICE={DEVICE} | AMP={USE_AMP}")
    print(f"[INFO] {size_summary}")
    print(f"[INFO] SaveFolder: {save_folder}")

    header_lines = (
        f"DEVICE={DEVICE}, AMP={USE_AMP}",
        size_summary,
        f"split_mode=sample_level_force | TEST_SIZE={TEST_SIZE}",
        "POS_PAIR=SimCLR(two random augmentations of same sample)",
        f"Train SNR~U{TRAIN_SNR_DB_RANGE}, v~{TRAIN_V_KMH_RANGE}, multipath={AUG_USE_MULTIPATH}",
        f"Test v={TEST_V_KMH_FIXED} (fd={fd_test:.2f}Hz), SNR sweep={TEST_SNR_LIST}, TEST_USE_MULTIPATH={TEST_USE_MULTIPATH}",
        "-" * 80,
    )
    for text in header_lines:
        write_line(results_txt, text)

    # Per-SNR list of test accuracies, one entry appended per fold
    snr_to_accs = {snr: [] for snr in TEST_SNR_LIST}
    scaler = make_grad_scaler(enabled=USE_AMP)

    # fixed test loader (eval adds test-time augment)
    test_loader = DataLoader(
        LTEVSingleIQDataset(signals_by_file, label_idx_by_file, test_samples),
        batch_size=BATCH_SIZE, shuffle=False,
        num_workers=NUM_WORKERS_EVAL, pin_memory=True
    )

    # StratifiedKFold on train samples
    y_train = np.array([label_idx_by_file[fi] for (fi, _) in train_samples], dtype=np.int64)
    idx_all = np.arange(len(train_samples), dtype=np.int64)
    skf = StratifiedKFold(n_splits=N_SPLITS, shuffle=True, random_state=SEED)

    for fold, (tr_idx, va_idx) in enumerate(skf.split(idx_all, y_train), 1):
        tr_samples_fold = [train_samples[i] for i in tr_idx]
        va_samples_fold = [train_samples[i] for i in va_idx]

        tr_counts = Counter([label_idx_by_file[fi] for (fi, _) in tr_samples_fold])
        va_counts = Counter([label_idx_by_file[fi] for (fi, _) in va_samples_fold])
        print(f"\n========== Fold {fold}/{N_SPLITS} ==========")
        print(f"[FOLD {fold}] train_sample_counts:", dict(sorted(tr_counts.items())))
        print(f"[FOLD {fold}] val_sample_counts  :", dict(sorted(va_counts.items())))

        _run_one_fold_ltev_simclr(
            fold, save_folder, results_txt, num_classes,
            signals_by_file, label_idx_by_file,
            tr_samples_fold, va_samples_fold,
            test_loader, snr_to_accs, scaler
        )

    # summarize sweep: (snr, mean, std, per-fold accuracies), computed once
    sweep_stats = []
    for snr in TEST_SNR_LIST:
        fold_accs = snr_to_accs[snr]
        acc_arr = np.array(fold_accs, dtype=np.float64)
        mean = float(acc_arr.mean()) if acc_arr.size else 0.0
        std = float(acc_arr.std()) if acc_arr.size else 0.0
        sweep_stats.append((snr, mean, std, fold_accs))

    csv_path = os.path.join(save_folder, "test_snr_sweep.csv")
    with open(csv_path, "w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        writer.writerow(["snr_db", "acc_mean", "acc_std"] + [f"fold{i}" for i in range(1, N_SPLITS + 1)])
        writer.writerows([[snr, mean, std, *fold_accs] for snr, mean, std, fold_accs in sweep_stats])

    write_line(results_txt, "\n========== Overall Test SNR Sweep (mean±std over folds) ==========")
    for snr, mean, std, _ in sweep_stats:
        write_line(results_txt, f"SNR {snr:>3} dB | Acc {mean:.2f} ± {std:.2f}")

    print(f"\n[INFO] All saved in: {save_folder}")
    print(f"[INFO] SNR sweep CSV: {csv_path}")
    return save_folder

def _run_one_fold_ltev_simclr(
    fold: int,
    save_folder: str,
    results_txt: str,
    num_classes: int,
    signals_by_file,
    label_idx_by_file,
    tr_samples_fold,
    va_samples_fold,
    test_loader,
    snr_to_accs: dict,
    scaler
):
    """
    Train one cross-validation fold and run its test SNR sweep.

    Args:
        fold: 1-based fold number (used in file names and log lines).
        save_folder: run directory receiving logs and checkpoints.
        results_txt: path of the shared human-readable results file.
        num_classes: size of the classification head.
        signals_by_file, label_idx_by_file: data and per-file class indices.
        tr_samples_fold, va_samples_fold: (file_index, row_index) pairs for
            this fold's training and validation sets.
        test_loader: shared loader over the held-out test samples.
        snr_to_accs: dict SNR -> list; this fold's test accuracy is appended
            to each list (mutated in place).
        scaler: gradient scaler for AMP, or None to train without scaling.

    Side effects: writes fold{K}_trainlog.csv, fold{K}_test_snr.csv,
    model_fold{K}.pth (last weights) and best_model_fold{K}.pth (lowest
    val_loss), and appends to results_txt.
    """
    write_line(results_txt, f"\n========== Fold {fold}/{N_SPLITS} ==========")

    tr_ds = LTEVSimCLRIQDataset(signals_by_file, label_idx_by_file, tr_samples_fold)
    va_ds = LTEVSingleIQDataset(signals_by_file, label_idx_by_file, va_samples_fold)

    # drop_last keeps every training batch at full size
    tr_loader = DataLoader(
        tr_ds, batch_size=BATCH_SIZE, shuffle=True, drop_last=True,
        num_workers=NUM_WORKERS_TRAIN, pin_memory=True
    )
    va_loader = DataLoader(
        va_ds, batch_size=BATCH_SIZE, shuffle=False,
        num_workers=NUM_WORKERS_EVAL, pin_memory=True
    )

    model = SpecFeatureNet(num_classes=num_classes).to(DEVICE)
    opt = torch.optim.Adam(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        opt, mode="min", factor=LR_FACTOR, patience=LR_PATIENCE
    )
    ce = nn.CrossEntropyLoss()

    best_val_loss = float("inf")
    best_state = None
    es_count = 0
    best_epoch = 0

    fold_log_path = os.path.join(save_folder, f"fold{fold}_trainlog.csv")
    fold_logger = CSVLogger(
        fold_log_path,
        header=["epoch", "lr", "train_loss", "val_loss", "val_acc", "best_val_loss", "es_count", "epoch_time_sec"]
    )

    for epoch in range(1, MAX_EPOCHS + 1):
        t0 = time.time()
        model.train()
        loss_sum, nb = 0.0, 0

        for iq, y in tr_loader:
            iq = iq.to(DEVICE, non_blocking=True)  # (B,L,2)
            y  = y.to(DEVICE, non_blocking=True)

            # SimCLR positive pair: two random views of same sample
            v1 = augment_train_batch(iq)
            v2 = augment_train_batch(iq)

            spec1 = iq_to_logspec_batch(v1)
            spec2 = iq_to_logspec_batch(v2)

            opt.zero_grad(set_to_none=True)

            with amp_autocast(enabled=USE_AMP):
                z1, p1, z2, p2 = model(spec1, spec2)
                loss_cl = nt_xent_loss(z1, z2, tau=TAU)

                if LAMBDA_CE > 0.0:
                    # Supervised term averaged over both views
                    loss_ce = 0.5 * (ce(p1, y) + ce(p2, y))
                else:
                    loss_ce = torch.zeros((), device=DEVICE)

                loss = LAMBDA_CL * loss_cl + LAMBDA_CE * loss_ce

            if scaler is not None:
                scaler.scale(loss).backward()
                scaler.step(opt)
                scaler.update()
            else:
                loss.backward()
                opt.step()

            loss_sum += float(loss.item())
            nb += 1

        train_loss = loss_sum / max(nb, 1)
        val_loss, val_acc, _ = eval_single_iq(model, va_loader, num_classes=num_classes, mode="val")

        prev_lr = opt.param_groups[0]["lr"]
        scheduler.step(val_loss)
        cur_lr = opt.param_groups[0]["lr"]
        if cur_lr < prev_lr:
            msg = f"[LR DROP] {prev_lr:.2e} -> {cur_lr:.2e} (val_loss={val_loss:.4f})"
            print(msg)
            write_line(results_txt, msg)

        epoch_time = time.time() - t0

        # Update best-model / early-stopping state BEFORE logging, so the
        # logged BestValLoss and ES counter describe this epoch rather than
        # the previous one (epoch 1 would otherwise log BestValLoss=inf).
        if val_loss < best_val_loss - 1e-6:
            best_val_loss = val_loss
            # Snapshot on CPU so later training steps cannot alter it
            best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
            best_epoch = epoch
            es_count = 0
        else:
            es_count += 1
        stop_now = es_count >= ES_PATIENCE

        msg = (f"Epoch {epoch:03d} | LR={cur_lr:.2e} | "
               f"TrainLoss={train_loss:.4f} | ValLoss={val_loss:.4f} | ValAcc={val_acc:.2f}% | "
               f"BestValLoss={best_val_loss:.4f} | ES={es_count}/{ES_PATIENCE}")
        print(msg)
        write_line(results_txt, msg)

        fold_logger.log_row([epoch, cur_lr, train_loss, val_loss, val_acc, best_val_loss, es_count, epoch_time])

        if stop_now:
            msg = "[INFO] Early stopping triggered."
            print(msg)
            write_line(results_txt, msg)
            break

    # Last-epoch weights first, then restore and save the best-val_loss weights;
    # the test sweep below therefore uses the best weights when available.
    torch.save(model.state_dict(), os.path.join(save_folder, f"model_fold{fold}.pth"))
    if best_state is not None:
        model.load_state_dict(best_state)
        torch.save(model.state_dict(), os.path.join(save_folder, f"best_model_fold{fold}.pth"))

    write_line(results_txt, f"[FOLD {fold}] BestEpoch={best_epoch}, BestValLoss={best_val_loss:.6f}")

    # Test sweep for this fold
    fold_test_csv = os.path.join(save_folder, f"fold{fold}_test_snr.csv")
    fold_test_logger = CSVLogger(fold_test_csv, header=["snr_db", "test_loss", "test_acc"])

    fold_snr_acc = {}
    for snr in TEST_SNR_LIST:
        test_loss, test_acc, _ = eval_single_iq(model, test_loader, num_classes=num_classes, mode="test", snr_db=float(snr))
        snr_to_accs[snr].append(test_acc)
        fold_snr_acc[snr] = test_acc
        fold_test_logger.log_row([snr, test_loss, test_acc])

    msg = "[FOLD TEST] " + ", ".join([f"{snr}:{fold_snr_acc[snr]:.2f}%" for snr in TEST_SNR_LIST])
    print(msg)
    write_line(results_txt, msg)

# ----------------------------
# 10) main
# ----------------------------
if __name__ == "__main__":
    # Script entry point: run the full K-fold training + test SNR sweep on the
    # dataset located at DATA_PATH. The returned run-folder path is not used here.
    train_kfold_ltev_simclr_sample_split(DATA_PATH)


[INFO] Found 72 .mat files
[INFO] txID classes: 9
  001: 8 files
  002: 8 files
  003: 8 files
  004: 8 files
  005: 8 files
  006: 8 files
  007: 8 files
  008: 8 files
  009: 8 files
[INFO] DMRS length stats: min=256, max=256, mode=256
[INFO] Split mode: SAMPLE-LEVEL (force)
[INFO] Train samples=158950, Test samples=52984
[INFO] Train class sample counts: {0: 17698, 1: 17518, 2: 17833, 3: 17671, 4: 17600, 5: 17882, 6: 17465, 7: 17582, 8: 17701}
[INFO] Test  class sample counts: {0: 5899, 1: 5839, 2: 5944, 3: 5891, 4: 5867, 5: 5961, 6: 5822, 7: 5860, 8: 5901}
[INFO] DEVICE=cuda | AMP=True
[INFO] Classes=9, TrainSamples=158950, TestSamples=52984, L=256
[INFO] SaveFolder: ./training_results\2026-01-25_01-46-02_LTEV_SpecSimCLR_FASTGPU_SNRsweep_Doppler120_SampleSplit

[FOLD 1] train_sample_counts: {0: 14158, 1: 14015, 2: 14266, 3: 14137, 4: 14080, 5: 14305, 6: 13972, 7: 14066, 8: 14161}
[FOLD 1] val_sample_counts  : {0: 3540, 1: 3503, 2: 3567, 3: 3534, 4: 3520, 5: 3577, 6: 3493, 7: 3516, 