In [1]:
# ============================================================
# LTE-V  多域训练 + 单域评估（RAW vs XFR） + SNR Sweep
#
# 训练（multi-domain）：
#   - domain = (rx, v)
#   - 对每个 TX 类别：从所有可用 domain 中“顺序轮询”抽样，拼成 block_size 个 frame
#   - 对 block 做 XFR 转置： (block_size, sample_len, 2) -> (sample_len, block_size, 2)
#   - 展开为 row-samples：每个 row 是一个样本，shape=(block_size,2)
#
# 验证/测试（single-domain, per (rx,v)）两套版本：
#   1) RAW：原始 frame 作为样本，shape=(sample_len,2)
#   2) XFR：在该单域内部拼 block -> 转置 -> 展开 row-samples，shape=(block_size,2)
#
# SNR sweep:
#   - 从“干净（仅功率归一化）”信号缓存出发，对每个 SNR 重新加 AWGN（可选额外 Doppler）
#   - 每个 SNR 独立训练并保存结果与曲线
#
# 你需要保证数据路径下能解析出 rx 与 v（速度）：
#   默认解析策略：从路径组件里找 rx\d+ / v\d+ / \d+kmh 等。
#   若你数据命名不同，修改 parse_domain_from_path() 即可。
# ============================================================

import os
import re
import glob
import gc
import math
import json
import numpy as np
import h5py

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

from sklearn.metrics import confusion_matrix
import matplotlib
matplotlib.use("agg")
import matplotlib.pyplot as plt

from datetime import datetime
from tqdm import tqdm


# =========================
# 全局参数（按需改）
# =========================
DATA_ROOT = r"E:/rf_datasets_IQ_raw/"          # LTE-V 根目录（含 .mat）
SAVE_ROOT = "./training_results"

# SNR sweep
SNR_LIST = list(range(20, -45, -5))           # 20,15,...,-40
APPLY_AWGN = True
APPLY_EXTRA_DOPPLER = True                   # 通常 LTE-V 本身已含速度效应；默认 False
FS = 5e6
FC = 5.9e9

# 多域/单域 XFR 参数
BLOCK_SIZE = 256                              # XFR 的 block_size（每个 row sample 的长度）
Y_TAKE = 1                                    # 轮询时每个 domain 一次取 y 条 frame
MAX_SIG_PER_DOMAIN_LABEL = None               # 可限制每个(domain,label)最多用多少帧（None=不限制）

# 数据集划分（按每个 domain & label 的时序切分）
TRAIN_RATIO = 0.6
VAL_RATIO = 0.2
TEST_RATIO = 0.2
assert abs(TRAIN_RATIO + VAL_RATIO + TEST_RATIO - 1.0) < 1e-6

# 训练超参
BATCH_SIZE = 128
EPOCHS = 200
LR = 1e-4
WEIGHT_DECAY = 1e-3
DROPOUT = 0.5
IN_PLANES = 64
PATIENCE = 8
MIN_DELTA = 0.05                              # val acc 至少提升 MIN_DELTA 才算进步（百分点）

NUM_WORKERS = 2
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 复现
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)
torch.backends.cudnn.benchmark = True


# =========================
# 信号处理：Doppler / AWGN
# =========================
def compute_doppler_shift(v_kmh, fc_hz):
    c = 3e8
    v = v_kmh / 3.6
    return (v / c) * fc_hz

def apply_doppler_shift(sig_complex_1d, fd_hz, fs_hz):
    #ival = np.arange(sig_complex_1d.shape[0], dtype=np.float64)
    t = np.arange(sig_complex_1d.shape[0], dtype=np.float64) / fs_hz
    return sig_complex_1d * np.exp(1j * 2 * np.pi * fd_hz * t)

def add_awgn_unit_power(sig_complex_1d, snr_db, rng: np.random.Generator):
    """
    假设 sig_complex 已做单位功率归一化（E|s|^2 = 1）
    则噪声功率 = 10^(-snr/10)
    complex noise: real/imag 方差 = noise_power/2
    """
    if snr_db is None:
        return sig_complex_1d
    noise_power = 10.0 ** (-snr_db / 10.0)
    noise_std = math.sqrt(noise_power / 2.0)
    noise = noise_std * (rng.standard_normal(sig_complex_1d.shape) + 1j * rng.standard_normal(sig_complex_1d.shape))
    return sig_complex_1d + noise


# =========================
# domain 解析（你数据命名不同就改这里）
# =========================
_rx_pat = re.compile(r"(?:^|[^a-z0-9])rx\s*[_-]?\s*(\d+)", re.IGNORECASE)
_v_pat1 = re.compile(r"(?:^|[^a-z0-9])v(?:el|speed)?\s*[_-]?\s*(\d+)", re.IGNORECASE)
_v_pat2 = re.compile(r"(\d+)\s*kmh", re.IGNORECASE)

def parse_domain_from_path(path):
    """
    从文件路径解析 domain = (rx, v_kmh)
    默认策略：
      - 在路径各级组件中找 rx\d+
      - 找 v\d+ 或 vel\d+ 或 \d+kmh
    解析失败会抛异常（避免把所有数据都归为 unknown 导致“假多域”）
    """
    parts = re.split(r"[\\/]+", path)
    rx_val = None
    v_val = None

    for p in parts:
        m = _rx_pat.search(p.lower())
        if m:
            rx_val = f"rx{int(m.group(1))}"
        m = _v_pat1.search(p.lower())
        if m:
            v_val = int(m.group(1))
        m = _v_pat2.search(p.lower())
        if m and v_val is None:
            v_val = int(m.group(1))

    if rx_val is None or v_val is None:
        raise ValueError(
            f"Cannot parse (rx,v) from path:\n  {path}\n"
            f"Please modify parse_domain_from_path() to match your folder/file naming."
        )
    return (rx_val, v_val)


# =========================
# .mat 读取：dmrs + txID
# =========================
def read_ltev_mat(file_path):
    """
    读取单个 .mat（HDF5 mat）：
      - rfDataset/dmrs: complex (num_frames, sample_len)
      - rfDataset/txID: uint16 char array (label string)
    返回：
      tx_label(str), dmrs_complex(np.complex64 [N,L])
    """
    with h5py.File(file_path, "r") as f:
        if "rfDataset" not in f:
            raise KeyError(f"{file_path}: missing group 'rfDataset'")
        rf = f["rfDataset"]

        if "dmrs" not in rf:
            raise KeyError(f"{file_path}: missing 'rfDataset/dmrs'")
        dmrs = rf["dmrs"]

        # dmrs 可能是复结构体：fields real/imag
        if hasattr(dmrs, "dtype") and dmrs.dtype.fields is not None and ("real" in dmrs.dtype.fields and "imag" in dmrs.dtype.fields):
            dmrs_struct = dmrs[:]
            dmrs_complex = dmrs_struct["real"] + 1j * dmrs_struct["imag"]
        else:
            # 兜底：若已是复数（少见），或是引用
            dmrs_arr = dmrs[:]
            if np.iscomplexobj(dmrs_arr):
                dmrs_complex = dmrs_arr
            else:
                raise ValueError(f"{file_path}: dmrs dtype not supported: {dmrs_arr.dtype}")

        if "txID" not in rf:
            raise KeyError(f"{file_path}: missing 'rfDataset/txID'")
        txID_uint16 = rf["txID"][:].flatten().astype(np.uint16)
        tx_label = "".join(chr(c) for c in txID_uint16 if c != 0).strip()

    dmrs_complex = np.asarray(dmrs_complex, dtype=np.complex64)
    if dmrs_complex.ndim != 2:
        raise ValueError(f"{file_path}: dmrs_complex must be 2D [N,L], got {dmrs_complex.shape}")
    return tx_label, dmrs_complex


# =========================
# 缓存：读取干净（仅功率归一化）的多域数据
# 结构：clean_data[domain][label] = list of np.float32 IQ frames (L,2)
# =========================
def load_clean_multidomain_dataset(data_root):
    mat_files = glob.glob(os.path.join(data_root, "**", "*.mat"), recursive=True)
    if len(mat_files) == 0:
        raise RuntimeError(f"No .mat found under: {data_root}")

    clean_data = {}      # domain -> label -> list[frame IQ]
    label_set = set()
    sample_len_ref = None

    print(f"[INFO] Found {len(mat_files)} .mat files")
    for fp in tqdm(sorted(mat_files), desc="Loading clean (.mat)"):
        domain = parse_domain_from_path(fp)  # (rx, v)
        tx_label, dmrs_complex = read_ltev_mat(fp)  # [N,L] complex

        if sample_len_ref is None:
            sample_len_ref = dmrs_complex.shape[1]
        else:
            if dmrs_complex.shape[1] != sample_len_ref:
                raise ValueError(f"Sample length mismatch: {fp} has L={dmrs_complex.shape[1]}, expected {sample_len_ref}")

        # 每条帧做功率归一化，再转 IQ
        frames_iq = []
        for i in range(dmrs_complex.shape[0]):
            sig = dmrs_complex[i, :]
            sig = sig / (np.sqrt(np.mean(np.abs(sig) ** 2)) + 1e-12)  # unit power
            iq = np.stack([sig.real, sig.imag], axis=-1).astype(np.float32)  # (L,2)
            frames_iq.append(iq)

        clean_data.setdefault(domain, {}).setdefault(tx_label, []).extend(frames_iq)
        label_set.add(tx_label)

    label_list = sorted(list(label_set))
    label_to_idx = {lab: i for i, lab in enumerate(label_list)}

    # 可选：限制每个(domain,label)最多帧数（保序截断）
    if MAX_SIG_PER_DOMAIN_LABEL is not None:
        for d in list(clean_data.keys()):
            for lab in list(clean_data[d].keys()):
                clean_data[d][lab] = clean_data[d][lab][:MAX_SIG_PER_DOMAIN_LABEL]

    print(f"[INFO] Domains: {len(clean_data)}  | Labels: {len(label_list)}  | Sample_len: {sample_len_ref}")
    # 打印部分 domain 例子
    dom_preview = list(clean_data.keys())[:10]
    print(f"[INFO] Domain preview: {dom_preview}")
    return clean_data, label_to_idx, sample_len_ref


# =========================
# 给定“干净 IQ”，按 SNR 生成“带噪 IQ”（并可加额外 Doppler）
# noisy_data[domain][label] = list of IQ frames (L,2)
# =========================
def make_noisy_dataset(clean_data, label_to_idx, snr_db, seed_base=1234,
                       apply_awgn=True, apply_extra_doppler=False):
    rng = np.random.default_rng(seed_base + int(snr_db * 10) if snr_db is not None else seed_base)
    noisy_data = {}

    for domain, lab_dict in clean_data.items():
        rx, v_kmh = domain
        fd = compute_doppler_shift(v_kmh, FC) if apply_extra_doppler else 0.0

        for lab, frames_iq in lab_dict.items():
            out_frames = []
            for iq in frames_iq:
                sigc = iq[:, 0].astype(np.float32) + 1j * iq[:, 1].astype(np.float32)

                # 额外 Doppler（通常关闭）
                if apply_extra_doppler and fd != 0.0:
                    sigc = apply_doppler_shift(sigc, fd, FS)

                # AWGN（SNR sweep）
                if apply_awgn:
                    sigc = add_awgn_unit_power(sigc, snr_db, rng)

                out_iq = np.stack([sigc.real, sigc.imag], axis=-1).astype(np.float32)
                out_frames.append(out_iq)

            noisy_data.setdefault(domain, {})[lab] = out_frames

    return noisy_data


# =========================
# 划分：每个(domain,label)按时序切 train/val/test
# =========================
def split_by_domain_label_ordered(noisy_data, train_ratio, val_ratio, test_ratio):
    train_dict, val_dict, test_dict = {}, {}, {}
    for domain, lab_dict in noisy_data.items():
        for lab, frames in lab_dict.items():
            n = len(frames)
            if n < 3:
                continue
            n_train = int(n * train_ratio)
            n_val = int(n * val_ratio)
            n_test = n - n_train - n_val
            if n_train <= 0 or n_val <= 0 or n_test <= 0:
                # 太少则跳过这个(domain,label)
                continue

            tr = frames[:n_train]
            va = frames[n_train:n_train + n_val]
            te = frames[n_train + n_val:]

            train_dict.setdefault(domain, {})[lab] = tr
            val_dict.setdefault(domain, {})[lab] = va
            test_dict.setdefault(domain, {})[lab] = te
    return train_dict, val_dict, test_dict


# =========================
# XFR：把 frame list -> blocks（转置后的 block）
# block_transposed shape = (sample_len, block_size, 2)
# =========================
def build_xfr_blocks_single_domain(frames_iq_list, block_size, y_take=1):
    """
    单域：按顺序取 frame，凑足 block_size -> 转置 -> 存 block
    y_take 对单域影响不大，保留接口与多域一致
    """
    blocks = []
    acc = []
    i = 0
    while i < len(frames_iq_list):
        take_n = min(y_take, len(frames_iq_list) - i)
        acc.extend(frames_iq_list[i:i+take_n])
        i += take_n

        while len(acc) >= block_size:
            chunk = acc[:block_size]
            acc = acc[block_size:]

            block = np.array(chunk, dtype=np.float32)          # (block_size, sample_len, 2)
            block_t = np.transpose(block, (1, 0, 2))           # (sample_len, block_size, 2)
            blocks.append(block_t)
    return blocks


def build_xfr_blocks_multidomain_for_label(train_dict, label, domains, block_size, y_take=1):
    """
    多域：对某个 label，从多个 domain 轮询抽样，拼成 block_size -> 转置 -> 存 block
    train_dict[domain][label] = list of frames (ordered)
    """
    # 收集各 domain 的该 label 帧列表（保序）
    dom_lists = []
    dom_keys = []
    for d in domains:
        if d in train_dict and label in train_dict[d] and len(train_dict[d][label]) > 0:
            dom_lists.append(list(train_dict[d][label]))  # copy，后面会 pop(0)
            dom_keys.append(d)

    if len(dom_lists) == 0:
        return []

    blocks = []
    acc = []
    ptr = 0
    num_dom = len(dom_lists)

    while any(len(lst) > 0 for lst in dom_lists):
        di = ptr % num_dom
        lst = dom_lists[di]
        if len(lst) > 0:
            take_n = min(y_take, len(lst))
            # 顺序取 y_take
            for _ in range(take_n):
                acc.append(lst.pop(0))
        ptr += 1

        while len(acc) >= block_size:
            chunk = acc[:block_size]
            acc = acc[block_size:]

            block = np.array(chunk, dtype=np.float32)          # (block_size, sample_len, 2)
            block_t = np.transpose(block, (1, 0, 2))           # (sample_len, block_size, 2)
            blocks.append(block_t)

    return blocks


def expand_blocks_to_rows(x_blocks_t, y_blocks):
    """
    x_blocks_t: np.array (num_blocks, sample_len, block_size, 2)
    y_blocks:   np.array (num_blocks,)
    返回：
      X_rows: (num_blocks*sample_len, block_size, 2)
      y_rows: (num_blocks*sample_len,)
    """
    num_blocks, sample_len, block_size, _ = x_blocks_t.shape
    X_rows = x_blocks_t.reshape(-1, block_size, 2)
    y_rows = np.repeat(y_blocks, sample_len)
    return X_rows, y_rows


# =========================
# RAW：把单域 frames list -> (N, sample_len,2)
# =========================
def build_raw_rows_single_domain(domain_dict, label_to_idx):
    X_list, y_list = [], []
    for lab, frames in domain_dict.items():
        if lab not in label_to_idx:
            continue
        yi = label_to_idx[lab]
        for f in frames:
            X_list.append(f)
            y_list.append(yi)
    if len(X_list) == 0:
        return None, None
    X = np.stack(X_list, axis=0).astype(np.float32)  # (N, sample_len,2)
    y = np.array(y_list, dtype=np.int64)
    return X, y


# =========================
# 模型：1D ResNet18（接受可变长度）
# =========================
class BasicBlock1D(nn.Module):
    expansion = 1
    def __init__(self, in_planes, planes, stride=1, downsample=None, dropout=0.0):
        super().__init__()
        self.conv1 = nn.Conv1d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm1d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.dropout = nn.Dropout(p=dropout)
        self.conv2 = nn.Conv1d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm1d(planes)
        self.downsample = downsample

    def forward(self, x):
        identity = x
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.dropout(out)
        out = self.bn2(self.conv2(out))
        if self.downsample is not None:
            identity = self.downsample(x)
        out += identity
        return self.relu(out)

class ResNet18_1D(nn.Module):
    def __init__(self, num_classes, in_planes=64, dropout=0.0):
        super().__init__()
        self.in_planes = in_planes
        self.conv1 = nn.Conv1d(2, in_planes, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm1d(in_planes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(64, 2, stride=1, dropout=dropout)
        self.layer2 = self._make_layer(128, 2, stride=2, dropout=dropout)
        self.layer3 = self._make_layer(256, 2, stride=2, dropout=dropout)
        self.layer4 = self._make_layer(512, 2, stride=2, dropout=dropout)
        self.avgpool = nn.AdaptiveAvgPool1d(1)
        self.fc = nn.Linear(512, num_classes)

    def _make_layer(self, planes, blocks, stride, dropout):
        downsample = None
        if stride != 1 or self.in_planes != planes:
            downsample = nn.Sequential(
                nn.Conv1d(self.in_planes, planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm1d(planes)
            )
        layers = [BasicBlock1D(self.in_planes, planes, stride, downsample, dropout)]
        self.in_planes = planes
        for _ in range(1, blocks):
            layers.append(BasicBlock1D(self.in_planes, planes, dropout=dropout))
        return nn.Sequential(*layers)

    def forward(self, x):
        # x: (B, L, 2) -> (B, 2, L)
        x = x.permute(0, 2, 1)
        x = self.relu(self.bn1(self.conv1(x)))
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x).squeeze(-1)
        return self.fc(x)


# =========================
# 评估（acc + confusion）
# =========================
def evaluate_model(model, loader, device, num_classes):
    model.eval()
    correct, total = 0, 0
    all_y, all_p = [], []
    with torch.no_grad():
        for xb, yb in loader:
            xb, yb = xb.to(device), yb.to(device)
            out = model(xb)
            _, pred = torch.max(out, 1)
            correct += (pred == yb).sum().item()
            total += yb.size(0)
            all_y.extend(yb.cpu().numpy().tolist())
            all_p.extend(pred.cpu().numpy().tolist())
    acc = 100.0 * correct / max(1, total)
    cm = confusion_matrix(all_y, all_p, labels=list(range(num_classes)))
    return acc, cm


def plot_confusion(cm, title, save_path):
    plt.figure(figsize=(7, 6))
    plt.imshow(cm, interpolation="nearest")
    plt.title(title)
    plt.colorbar()
    plt.xlabel("Predicted")
    plt.ylabel("True")
    plt.tight_layout()
    plt.savefig(save_path, dpi=200)
    plt.close()


def plot_curve(x, y, xlabel, ylabel, title, save_path):
    plt.figure(figsize=(7, 5))
    plt.plot(x, y, marker="o")
    plt.grid(True)
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.title(title)
    plt.tight_layout()
    plt.savefig(save_path, dpi=200)
    plt.close()


# =========================
# 构建训练 blocks（多域） + 构建 val/test（单域 RAW & XFR）
# =========================
def build_all_datasets_for_snr(noisy_train, noisy_val, noisy_test, label_to_idx,
                               sample_len, block_size, y_take):
    domains = sorted(list(set(list(noisy_train.keys()) + list(noisy_val.keys()) + list(noisy_test.keys()))))

    # ===== 训练 blocks（多域混合）：按 label 分别轮询 domains 拼 block
    idx_to_label = {i: lab for lab, i in label_to_idx.items()}
    label_list = [idx_to_label[i] for i in range(len(label_to_idx))]

    train_blocks = []
    train_block_labels = []

    for lab in label_list:
        blocks_lab = build_xfr_blocks_multidomain_for_label(
            train_dict=noisy_train,
            label=lab,
            domains=domains,
            block_size=block_size,
            y_take=y_take
        )
        if len(blocks_lab) == 0:
            continue
        yi = label_to_idx[lab]
        train_blocks.extend(blocks_lab)
        train_block_labels.extend([yi] * len(blocks_lab))

    if len(train_blocks) == 0:
        raise RuntimeError("No training blocks were generated. Check BLOCK_SIZE/Y_TAKE and data volume per domain/label.")

    X_train_blocks = np.stack(train_blocks, axis=0).astype(np.float32)          # (nb, sample_len, block_size,2)
    y_train_blocks = np.array(train_block_labels, dtype=np.int64)

    # ===== 单域 val/test：每个 domain 两套版本
    # val_raw[domain] = (X,y) raw frames
    # val_xfr[domain] = (X_rows,y_rows) XFR rows
    val_raw, val_xfr = {}, {}
    test_raw, test_xfr = {}, {}

    for domain in domains:
        # ---- val RAW
        if domain in noisy_val:
            Xv_raw, yv_raw = build_raw_rows_single_domain(noisy_val[domain], label_to_idx)
            if Xv_raw is not None:
                val_raw[domain] = (Xv_raw, yv_raw)

        # ---- test RAW
        if domain in noisy_test:
            Xt_raw, yt_raw = build_raw_rows_single_domain(noisy_test[domain], label_to_idx)
            if Xt_raw is not None:
                test_raw[domain] = (Xt_raw, yt_raw)

        # ---- val XFR (single-domain internal blocks)
        if domain in noisy_val:
            blocks_d, yb_d = [], []
            for lab, frames in noisy_val[domain].items():
                if lab not in label_to_idx:
                    continue
                blocks = build_xfr_blocks_single_domain(frames, block_size=block_size, y_take=y_take)
                if len(blocks) == 0:
                    continue
                yi = label_to_idx[lab]
                blocks_d.extend(blocks)
                yb_d.extend([yi] * len(blocks))

            if len(blocks_d) > 0:
                Xb = np.stack(blocks_d, axis=0).astype(np.float32)
                yb = np.array(yb_d, dtype=np.int64)
                X_rows, y_rows = expand_blocks_to_rows(Xb, yb)
                val_xfr[domain] = (X_rows, y_rows)

        # ---- test XFR (single-domain internal blocks)
        if domain in noisy_test:
            blocks_d, yb_d = [], []
            for lab, frames in noisy_test[domain].items():
                if lab not in label_to_idx:
                    continue
                blocks = build_xfr_blocks_single_domain(frames, block_size=block_size, y_take=y_take)
                if len(blocks) == 0:
                    continue
                yi = label_to_idx[lab]
                blocks_d.extend(blocks)
                yb_d.extend([yi] * len(blocks))

            if len(blocks_d) > 0:
                Xb = np.stack(blocks_d, axis=0).astype(np.float32)
                yb = np.array(yb_d, dtype=np.int64)
                X_rows, y_rows = expand_blocks_to_rows(Xb, yb)
                test_xfr[domain] = (X_rows, y_rows)

    return (X_train_blocks, y_train_blocks), (val_raw, val_xfr), (test_raw, test_xfr), domains


# =========================
# 训练（单模型）：训练用 XFR rows（从多域 blocks 展开）
# 验证：每个 domain 两套版本都评估（RAW/XFR）
# 早停：默认用 “val_xfr 各 domain 平均 acc”
# =========================
def train_one_snr(
        snr_db,
        X_train_blocks, y_train_blocks,
        val_raw, val_xfr,
        test_raw, test_xfr,
        num_classes,
        sample_len,
        block_size,
        save_folder):

    os.makedirs(save_folder, exist_ok=True)
    results_file = os.path.join(save_folder, "results.txt")

    # 展开训练 blocks 为 row-samples（XFR rows）
    X_train_rows, y_train_rows = expand_blocks_to_rows(X_train_blocks, y_train_blocks)
    train_ds = TensorDataset(torch.tensor(X_train_rows, dtype=torch.float32),
                             torch.tensor(y_train_rows, dtype=torch.long))
    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,
                              num_workers=NUM_WORKERS, pin_memory=True)

    # 为 val/test 构造 DataLoader（按 domain）
    def make_loaders(domain_dict, batch_size):
        loaders = {}
        for dom, (X, y) in domain_dict.items():
            ds = TensorDataset(torch.tensor(X, dtype=torch.float32),
                               torch.tensor(y, dtype=torch.long))
            loaders[dom] = DataLoader(ds, batch_size=batch_size, shuffle=False,
                                      num_workers=NUM_WORKERS, pin_memory=True)
        return loaders

    val_raw_loaders = make_loaders(val_raw, BATCH_SIZE)
    val_xfr_loaders = make_loaders(val_xfr, BATCH_SIZE)
    test_raw_loaders = make_loaders(test_raw, BATCH_SIZE)
    test_xfr_loaders = make_loaders(test_xfr, BATCH_SIZE)

    # 记录参数
    with open(results_file, "w", encoding="utf-8") as f:
        f.write("=== LTE-V Multi-domain Train / Single-domain Eval (RAW vs XFR) ===\n")
        f.write(f"Timestamp: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
        f.write(f"SNR_dB: {snr_db}\n")
        f.write(f"APPLY_AWGN: {APPLY_AWGN}, APPLY_EXTRA_DOPPLER: {APPLY_EXTRA_DOPPLER}\n")
        f.write(f"BLOCK_SIZE: {block_size}, Y_TAKE: {Y_TAKE}\n")
        f.write(f"Sample_len (raw): {sample_len}\n")
        f.write(f"Train rows: {len(train_ds)} (from blocks={len(X_train_blocks)})\n")
        f.write(f"Val domains RAW: {len(val_raw_loaders)}, XFR: {len(val_xfr_loaders)}\n")
        f.write(f"Test domains RAW: {len(test_raw_loaders)}, XFR: {len(test_xfr_loaders)}\n")
        f.write(f"Model: ResNet18_1D, in_planes={IN_PLANES}, dropout={DROPOUT}\n")
        f.write(f"Batch={BATCH_SIZE}, Epochs={EPOCHS}, LR={LR}, WD={WEIGHT_DECAY}, Patience={PATIENCE}, MinDelta={MIN_DELTA}\n")
        f.write("===============================================================\n\n")

    # 模型
    model = ResNet18_1D(num_classes=num_classes, in_planes=IN_PLANES, dropout=DROPOUT).to(DEVICE)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)

    best_val = -1.0
    best_wts = None
    patience_cnt = 0

    train_loss_hist = []
    val_xfr_mean_hist = []
    val_raw_mean_hist = []

    # 评估函数：按 domain 求平均 acc（并可保存 cm）
    def eval_domains(loaders, tag, epoch=None):
        if len(loaders) == 0:
            return None, {}
        accs = {}
        cms = {}
        for dom, ld in loaders.items():
            acc, cm = evaluate_model(model, ld, DEVICE, num_classes)
            accs[dom] = acc
            cms[dom] = cm
        mean_acc = float(np.mean(list(accs.values()))) if len(accs) else None
        return mean_acc, (accs, cms)

    for epoch in range(1, EPOCHS + 1):
        model.train()
        running = 0.0
        for xb, yb in train_loader:
            xb, yb = xb.to(DEVICE), yb.to(DEVICE)
            optimizer.zero_grad()
            out = model(xb)
            loss = criterion(out, yb)
            loss.backward()
            optimizer.step()
            running += loss.item()

        train_loss = running / max(1, len(train_loader))
        train_loss_hist.append(train_loss)

        # 验证：两套版本（单域）
        model.eval()
        val_xfr_mean, (val_xfr_accs, _) = eval_domains(val_xfr_loaders, "val_xfr")
        val_raw_mean, (val_raw_accs, _) = eval_domains(val_raw_loaders, "val_raw")

        val_xfr_mean_hist.append(val_xfr_mean if val_xfr_mean is not None else np.nan)
        val_raw_mean_hist.append(val_raw_mean if val_raw_mean is not None else np.nan)

        log = (f"Epoch {epoch:03d}/{EPOCHS} | "
               f"TrainLoss={train_loss:.4f} | "
               f"ValXFR_mean={val_xfr_mean if val_xfr_mean is not None else -1:.2f}% | "
               f"ValRAW_mean={val_raw_mean if val_raw_mean is not None else -1:.2f}%")
        print(log)
        with open(results_file, "a", encoding="utf-8") as f:
            f.write(log + "\n")

        # 早停：用 val_xfr_mean
        score = val_xfr_mean if val_xfr_mean is not None else -1.0
        if score > best_val + MIN_DELTA:
            best_val = score
            best_wts = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
            patience_cnt = 0

            # 保存每次改进的域级结果
            with open(results_file, "a", encoding="utf-8") as f:
                f.write(f"[BEST] epoch={epoch}  ValXFR_mean={best_val:.2f}%\n")
                # 打印各 domain（可选，太长可注释）
                for dom, acc in sorted(val_xfr_accs.items()):
                    f.write(f"    ValXFR {dom}: {acc:.2f}%\n")
        else:
            patience_cnt += 1
            if patience_cnt >= PATIENCE:
                print("[INFO] Early stopping.")
                with open(results_file, "a", encoding="utf-8") as f:
                    f.write("[INFO] Early stopping.\n")
                break

        scheduler.step()

    # 恢复最佳
    if best_wts is not None:
        model.load_state_dict(best_wts)

    # 最终测试：两套版本（单域）
    model.eval()
    test_xfr_mean, (test_xfr_accs, test_xfr_cms) = (None, ({}, {}))
    test_raw_mean, (test_raw_accs, test_raw_cms) = (None, ({}, {}))

    if len(test_xfr_loaders) > 0:
        test_xfr_mean, (test_xfr_accs, test_xfr_cms) = (lambda loaders: (
            float(np.mean([evaluate_model(model, ld, DEVICE, num_classes)[0] for ld in loaders.values()])),
            {dom: evaluate_model(model, ld, DEVICE, num_classes) for dom, ld in loaders.items()}
        ))(test_xfr_loaders)

        # 拆出来
        txfr_accs = {dom: v[0] for dom, v in test_xfr_accs.items()}
        txfr_cms = {dom: v[1] for dom, v in test_xfr_accs.items()}
        test_xfr_accs, test_xfr_cms = txfr_accs, txfr_cms

    if len(test_raw_loaders) > 0:
        test_raw_mean, (test_raw_accs, test_raw_cms) = (lambda loaders: (
            float(np.mean([evaluate_model(model, ld, DEVICE, num_classes)[0] for ld in loaders.values()])),
            {dom: evaluate_model(model, ld, DEVICE, num_classes) for dom, ld in loaders.items()}
        ))(test_raw_loaders)

        traw_accs = {dom: v[0] for dom, v in test_raw_accs.items()}
        traw_cms = {dom: v[1] for dom, v in test_raw_accs.items()}
        test_raw_accs, test_raw_cms = traw_accs, traw_cms

    # 写结果
    with open(results_file, "a", encoding="utf-8") as f:
        f.write("\n=== FINAL TEST RESULTS ===\n")
        f.write(f"TestXFR_mean: {test_xfr_mean if test_xfr_mean is not None else 'N/A'}\n")
        if test_xfr_mean is not None:
            for dom, acc in sorted(test_xfr_accs.items()):
                f.write(f"    TestXFR {dom}: {acc:.2f}%\n")
        f.write(f"TestRAW_mean: {test_raw_mean if test_raw_mean is not None else 'N/A'}\n")
        if test_raw_mean is not None:
            for dom, acc in sorted(test_raw_accs.items()):
                f.write(f"    TestRAW {dom}: {acc:.2f}%\n")

    # 保存模型
    torch.save(model.state_dict(), os.path.join(save_folder, "best_model.pth"))

    # 画训练曲线
    plot_curve(
        x=list(range(1, len(train_loss_hist) + 1)),
        y=train_loss_hist,
        xlabel="Epoch",
        ylabel="Train Loss",
        title=f"SNR={snr_db} Train Loss",
        save_path=os.path.join(save_folder, "train_loss.png")
    )
    plot_curve(
        x=list(range(1, len(val_xfr_mean_hist) + 1)),
        y=val_xfr_mean_hist,
        xlabel="Epoch",
        ylabel="Val XFR Mean Acc (%)",
        title=f"SNR={snr_db} Val XFR Mean Acc",
        save_path=os.path.join(save_folder, "val_xfr_mean_acc.png")
    )
    plot_curve(
        x=list(range(1, len(val_raw_mean_hist) + 1)),
        y=val_raw_mean_hist,
        xlabel="Epoch",
        ylabel="Val RAW Mean Acc (%)",
        title=f"SNR={snr_db} Val RAW Mean Acc",
        save_path=os.path.join(save_folder, "val_raw_mean_acc.png")
    )

    # 保存若干 domain 的混淆矩阵（避免太多文件）
    # 选 acc 最低/最高的各 1 个做示例
    def save_cm_examples(accs, cms, tag):
        if not accs:
            return
        dom_sorted = sorted(accs.items(), key=lambda x: x[1])
        worst_dom = dom_sorted[0][0]
        best_dom = dom_sorted[-1][0]
        plot_confusion(cms[worst_dom], f"{tag} WORST {worst_dom}", os.path.join(save_folder, f"cm_{tag}_worst.png"))
        plot_confusion(cms[best_dom], f"{tag} BEST {best_dom}", os.path.join(save_folder, f"cm_{tag}_best.png"))

    save_cm_examples(test_xfr_accs, test_xfr_cms, "test_xfr")
    save_cm_examples(test_raw_accs, test_raw_cms, "test_raw")

    # 返回测试均值（两套）
    return (test_raw_mean if test_raw_mean is not None else np.nan,
            test_xfr_mean if test_xfr_mean is not None else np.nan)


# =========================
# 主程序：缓存 clean -> 对每个 SNR 构造 noisy -> 构造数据集 -> 训练 -> 汇总曲线
# =========================
def main():
    os.makedirs(SAVE_ROOT, exist_ok=True)
    print("[INFO] DEVICE:", DEVICE)

    # 1) 载入干净多域数据（只做功率归一化，不加噪）
    clean_data, label_to_idx, sample_len = load_clean_multidomain_dataset(DATA_ROOT)
    num_classes = len(label_to_idx)

    # 保存一个 meta
    meta = {
        "data_root": DATA_ROOT,
        "num_domains": len(clean_data),
        "num_classes": num_classes,
        "sample_len_raw": sample_len,
        "block_size_xfr": BLOCK_SIZE,
        "y_take": Y_TAKE,
        "train_ratio": TRAIN_RATIO,
        "val_ratio": VAL_RATIO,
        "test_ratio": TEST_RATIO,
        "apply_awgn": APPLY_AWGN,
        "apply_extra_doppler": APPLY_EXTRA_DOPPLER,
        "fs": FS,
        "fc": FC,
        "snr_list": SNR_LIST,
        "seed": SEED
    }
    with open(os.path.join(SAVE_ROOT, "meta.json"), "w", encoding="utf-8") as f:
        json.dump(meta, f, indent=2, ensure_ascii=False)

    snr_test_raw_means = []
    snr_test_xfr_means = []

    timestamp_all = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

    for snr_db in SNR_LIST:
        print("\n" + "=" * 80)
        print(f"[RUN] SNR = {snr_db} dB")
        print("=" * 80)

        # 2) 生成带噪版本（每个 SNR 单独生成）
        noisy_data = make_noisy_dataset(
            clean_data, label_to_idx, snr_db=snr_db,
            seed_base=777,
            apply_awgn=APPLY_AWGN,
            apply_extra_doppler=APPLY_EXTRA_DOPPLER
        )

        # 3) 划分 train/val/test（每个 domain & label 按时序切）
        noisy_train, noisy_val, noisy_test = split_by_domain_label_ordered(
            noisy_data, TRAIN_RATIO, VAL_RATIO, TEST_RATIO
        )

        # 4) 构造训练 blocks（多域）+ val/test 单域 RAW & XFR
        (X_train_blocks, y_train_blocks), (val_raw, val_xfr), (test_raw, test_xfr), domains = build_all_datasets_for_snr(
            noisy_train, noisy_val, noisy_test,
            label_to_idx=label_to_idx,
            sample_len=sample_len,
            block_size=BLOCK_SIZE,
            y_take=Y_TAKE
        )

        print(f"[INFO] Train blocks: {len(X_train_blocks)}")
        print(f"[INFO] Val domains RAW={len(val_raw)} XFR={len(val_xfr)} | Test domains RAW={len(test_raw)} XFR={len(test_xfr)}")

        # 5) 训练与评估
        run_name = f"{timestamp_all}_LTV_SNR{snr_db}dB_bs{BLOCK_SIZE}_y{Y_TAKE}_cls{num_classes}"
        save_folder = os.path.join(SAVE_ROOT, run_name)
        os.makedirs(save_folder, exist_ok=True)

        test_raw_mean, test_xfr_mean = train_one_snr(
            snr_db=snr_db,
            X_train_blocks=X_train_blocks,
            y_train_blocks=y_train_blocks,
            val_raw=val_raw,
            val_xfr=val_xfr,
            test_raw=test_raw,
            test_xfr=test_xfr,
            num_classes=num_classes,
            sample_len=sample_len,
            block_size=BLOCK_SIZE,
            save_folder=save_folder
        )

        snr_test_raw_means.append(test_raw_mean)
        snr_test_xfr_means.append(test_xfr_mean)

        print(f"[DONE] SNR={snr_db} | TestRAW_mean={test_raw_mean:.2f}% | TestXFR_mean={test_xfr_mean:.2f}% | saved: {save_folder}")

        # 清理
        del noisy_data, noisy_train, noisy_val, noisy_test
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    # 6) 汇总曲线（两条：RAW vs XFR）
    plot_curve(
        x=SNR_LIST,
        y=snr_test_raw_means,
        xlabel="SNR (dB)",
        ylabel="Mean Test Acc (%)",
        title="LTE-V Mean Test Acc vs SNR (Single-domain RAW)",
        save_path=os.path.join(SAVE_ROOT, f"{timestamp_all}_snr_vs_acc_raw.png")
    )
    plot_curve(
        x=SNR_LIST,
        y=snr_test_xfr_means,
        xlabel="SNR (dB)",
        ylabel="Mean Test Acc (%)",
        title="LTE-V Mean Test Acc vs SNR (Single-domain XFR)",
        save_path=os.path.join(SAVE_ROOT, f"{timestamp_all}_snr_vs_acc_xfr.png")
    )

    # 保存汇总
    summary = []
    for snr_db, a_raw, a_xfr in zip(SNR_LIST, snr_test_raw_means, snr_test_xfr_means):
        summary.append({"snr_db": snr_db, "test_raw_mean": a_raw, "test_xfr_mean": a_xfr})
    with open(os.path.join(SAVE_ROOT, f"{timestamp_all}_summary.json"), "w", encoding="utf-8") as f:
        json.dump(summary, f, indent=2, ensure_ascii=False)

    print("\n=== FINAL SUMMARY ===")
    for snr_db, a_raw, a_xfr in zip(SNR_LIST, snr_test_raw_means, snr_test_xfr_means):
        print(f"SNR {snr_db:>3} dB | RAW_mean={a_raw:.2f}% | XFR_mean={a_xfr:.2f}%")
    print("[INFO] Plots & logs saved to:", SAVE_ROOT)


if __name__ == "__main__":
    main()


[INFO] DEVICE: cuda
[INFO] Found 72 .mat files


Loading clean (.mat): 100%|██████████| 72/72 [00:05<00:00, 12.95it/s]


[INFO] Domains: 8  | Labels: 9  | Sample_len: 256
[INFO] Domain preview: [('rx1', 10), ('rx1', 20), ('rx2', 10), ('rx2', 20), ('rx3', 10), ('rx3', 20), ('rx4', 10), ('rx4', 20)]

[RUN] SNR = 20 dB
[INFO] Train blocks: 491
[INFO] Val domains RAW=8 XFR=8 | Test domains RAW=8 XFR=8
Epoch 001/200 | TrainLoss=2.0778 | ValXFR_mean=10.16% | ValRAW_mean=10.73%
Epoch 002/200 | TrainLoss=1.1047 | ValXFR_mean=9.52% | ValRAW_mean=11.35%
Epoch 003/200 | TrainLoss=0.5930 | ValXFR_mean=11.07% | ValRAW_mean=11.90%
Epoch 004/200 | TrainLoss=0.3610 | ValXFR_mean=11.24% | ValRAW_mean=12.17%
Epoch 005/200 | TrainLoss=0.2049 | ValXFR_mean=11.40% | ValRAW_mean=11.87%
Epoch 006/200 | TrainLoss=0.1250 | ValXFR_mean=11.39% | ValRAW_mean=12.26%
Epoch 007/200 | TrainLoss=0.0834 | ValXFR_mean=11.65% | ValRAW_mean=12.05%
Epoch 008/200 | TrainLoss=0.0619 | ValXFR_mean=11.40% | ValRAW_mean=12.30%
Epoch 009/200 | TrainLoss=0.0487 | ValXFR_mean=11.46% | ValRAW_mean=11.68%
Epoch 010/200 | TrainLoss=0.0420 | ValXFR_mean

ValueError: too many values to unpack (expected 2)