In [None]:
# plot_ltev_feature_tsne_sample_level_export_mat_and_figs.py
# 样本级（非block聚合）特征 t-SNE：RAW vs XFR，joint-fit 2D/3D
# 输出：4张独立图（PDF+PNG, 无title, 学术风格） + 4份.mat + 1份总.mat

import os
import re
import glob
import numpy as np
import h5py
import matplotlib
matplotlib.use("agg")
import matplotlib.pyplot as plt

from tqdm import tqdm
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from sklearn.model_selection import train_test_split

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

from scipy.io import savemat


# ===================== 你需要确认/修改的参数 =====================
DATA_PATH = r"E:/rf_datasets/"  # LTE-V .mat 文件夹

EXP_XFR_DIR = r"./training_results/2026-01-07_19-01-27_LTE-V_XFR_SNR20dB_fd655_classes_9_ResNet"
EXP_RAW_DIR = r"./training_results/2025-11-27_11-25-06_LTE_time_SNR20dB_fd960_classes_9"

OUT_DIR = r"./TSNE"

GROUP_SIZE = 864
TEST_SIZE = 0.25
SPLIT_SEED = 42

# 为了公平对比，建议 True：两边用同一个 SNR/fd 注入生成同源 blocks
USE_COMMON_CHANNEL = True
FS = 5e6
APPLY_DOPPLER = True
APPLY_AWGN = True

# 抽样（样本级每类最多点数，避免t-SNE过慢）
MAX_SAMPLES_PER_CLASS_RAW = 1500
MAX_SAMPLES_PER_CLASS_XFR = 1500

BATCH_SIZE = 1024
NUM_WORKERS = 0

# t-SNE
SEED = 42
PCA_DIM = 50
TSNE_PERPLEXITY = 30
TSNE_ITERS = 1500
TSNE_METRIC = "cosine"
# ===============================================================


# ===================== 论文风格绘图配置 =====================
def set_paper_style():
    import matplotlib as mpl
    mpl.rcParams.update({
        "font.family": "serif",
        "font.size": 9,
        "axes.labelsize": 9,
        "xtick.labelsize": 8,
        "ytick.labelsize": 8,
        "axes.linewidth": 0.8,
        "legend.frameon": False,
        "pdf.fonttype": 42,
        "ps.fonttype": 42,
        "figure.facecolor": "white",
        "axes.facecolor": "white",
    })


def save_tsne_2d_single(emb2, y, out_base, xlim=None, ylim=None, point_size=3, alpha=0.85):
    """
    输出单独的 2D 图：out_base.pdf + out_base.png
    """
    set_paper_style()
    num_classes = int(y.max()) + 1
    cmap = plt.cm.get_cmap("tab20", num_classes)

    fig = plt.figure(figsize=(3.4, 3.0), dpi=400)
    ax = fig.add_subplot(1, 1, 1)
    ax.scatter(emb2[:, 0], emb2[:, 1], c=y, cmap=cmap, s=point_size, alpha=alpha, linewidths=0)

    ax.set_xlabel("t-SNE 1")
    ax.set_ylabel("t-SNE 2")

    # 论文常见：去掉上右边框
    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)

    if xlim is not None:
        ax.set_xlim(xlim)
    if ylim is not None:
        ax.set_ylim(ylim)

    plt.tight_layout(pad=0.2)

    os.makedirs(os.path.dirname(out_base), exist_ok=True)
    plt.savefig(out_base + ".pdf", bbox_inches="tight", pad_inches=0.02)
    plt.savefig(out_base + ".png", bbox_inches="tight", pad_inches=0.02, dpi=400)
    plt.close(fig)


def save_tsne_3d_single(emb3, y, out_base, xlim=None, ylim=None, zlim=None,
                       view=(18, -60), point_size=3, alpha=0.85, grid_alpha=0.25):
    """
    输出单独的 3D 图：out_base.pdf + out_base.png
    """
    set_paper_style()
    num_classes = int(y.max()) + 1
    cmap = plt.cm.get_cmap("tab20", num_classes)

    fig = plt.figure(figsize=(3.4, 3.0), dpi=400)
    ax = fig.add_subplot(1, 1, 1, projection="3d")
    ax.scatter(emb3[:, 0], emb3[:, 1], emb3[:, 2], c=y, cmap=cmap, s=point_size, alpha=alpha, linewidths=0)

    ax.set_xlabel("t-SNE 1", labelpad=2)
    ax.set_ylabel("t-SNE 2", labelpad=2)
    ax.set_zlabel("t-SNE 3", labelpad=2)

    ax.view_init(elev=view[0], azim=view[1])
    ax.grid(True, alpha=grid_alpha)

    if xlim is not None:
        ax.set_xlim(xlim)
    if ylim is not None:
        ax.set_ylim(ylim)
    if zlim is not None:
        ax.set_zlim(zlim)

    plt.tight_layout(pad=0.2)

    os.makedirs(os.path.dirname(out_base), exist_ok=True)
    plt.savefig(out_base + ".pdf", bbox_inches="tight", pad_inches=0.02)
    plt.savefig(out_base + ".png", bbox_inches="tight", pad_inches=0.02, dpi=400)
    plt.close(fig)


# ===================== .mat 导出 =====================
def save_embedding_mat_2d(emb2, y, out_path, meta=None):
    """
    保存 2D t-SNE 到 .mat：
      - emb: [N,2]
      - label: [N,1]
      - meta: dict（可选）
    """
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    mdict = {
        "emb": emb2.astype(np.float64),
        "label": y.astype(np.int64).reshape(-1, 1),
    }
    if meta is not None:
        mdict["meta"] = meta
    savemat(out_path, mdict)
    print(f"[OK] Saved mat: {out_path}")


def save_embedding_mat_3d(emb3, y, out_path, meta=None):
    """
    保存 3D t-SNE 到 .mat：
      - emb: [N,3]
      - label: [N,1]
      - meta: dict（可选）
    """
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    mdict = {
        "emb": emb3.astype(np.float64),
        "label": y.astype(np.int64).reshape(-1, 1),
    }
    if meta is not None:
        mdict["meta"] = meta
    savemat(out_path, mdict)
    print(f"[OK] Saved mat: {out_path}")


# ===================== 信号处理 =====================
def apply_doppler_shift(signal_c, fd_hz, fs_hz):
    t = np.arange(signal_c.shape[-1]) / fs_hz
    return signal_c * np.exp(1j * 2 * np.pi * fd_hz * t)


def add_awgn(signal_c, snr_db):
    power = np.mean(np.abs(signal_c) ** 2) + 1e-12
    noise_power = power / (10 ** (snr_db / 10))
    noise = np.sqrt(noise_power / 2) * (np.random.randn(*signal_c.shape) + 1j * np.random.randn(*signal_c.shape))
    return signal_c + noise


def parse_snr_fd_classes(folder_name: str):
    snr = None
    fd = None
    ncls = None
    m = re.search(r"SNR(-?\d+)dB", folder_name)
    if m:
        snr = int(m.group(1))
    m = re.search(r"fd(\d+)", folder_name)
    if m:
        fd = int(m.group(1))
    m = re.search(r"classes[_-](\d+)", folder_name)
    if m:
        ncls = int(m.group(1))
    return snr, fd, ncls


# ===================== LTE-V blocks =====================
def load_blocks_ltev(mat_folder,
                     group_size=864,
                     apply_doppler=True,
                     fd_hz=655,
                     apply_awgn=True,
                     snr_db=20,
                     fs=5e6):
    mat_files = glob.glob(os.path.join(mat_folder, "*.mat"))
    if len(mat_files) == 0:
        raise RuntimeError(f"在 {mat_folder} 未找到 .mat 文件")

    X_files, y_files, label_set = [], [], set()

    print(f"[INFO] Found {len(mat_files)} .mat files")
    for file in tqdm(mat_files, desc="Reading LTE-V .mat"):
        with h5py.File(file, "r") as f:
            rfDataset = f["rfDataset"]
            dmrs = rfDataset["dmrs"][:]
            dmrs_complex = dmrs["real"] + 1j * dmrs["imag"]
            txID_uint16 = rfDataset["txID"][:].flatten()
            tx_id = "".join(chr(c) for c in txID_uint16 if c != 0)

            processed = []
            for i in range(dmrs_complex.shape[0]):
                sig = dmrs_complex[i, :].astype(np.complex64)
                sig = sig / (np.sqrt(np.mean(np.abs(sig) ** 2)) + 1e-12)

                if apply_doppler:
                    sig = apply_doppler_shift(sig, fd_hz, fs)

                if apply_awgn:
                    sig = add_awgn(sig, snr_db)

                iq = np.stack([sig.real, sig.imag], axis=-1).astype(np.float32)
                processed.append(iq)

            processed = np.array(processed, dtype=np.float32)
            X_files.append(processed)
            y_files.append(tx_id)
            label_set.add(tx_id)

    label_list = sorted(list(label_set))
    label_to_idx = {lb: i for i, lb in enumerate(label_list)}

    raw_blocks_list = []
    xfr_blocks_list = []
    y_blocks_list = []

    for label in label_list:
        files_idx = [i for i, y in enumerate(y_files) if y == label]
        num_files = len(files_idx)
        if num_files == 0:
            continue

        samples_per_file = group_size // num_files
        if samples_per_file == 0:
            continue

        min_samples = min([X_files[i].shape[0] for i in files_idx])
        max_groups = min_samples // samples_per_file
        if max_groups == 0:
            continue

        for gi in range(max_groups):
            pieces = []
            for fi in files_idx:
                s = gi * samples_per_file
                e = s + samples_per_file
                pieces.append(X_files[fi][s:e])
            big_block = np.concatenate(pieces, axis=0)  # (G,L,2)

            raw_blocks_list.append(big_block)  # (G,L,2)
            xfr_blocks_list.append(np.transpose(big_block, (1, 0, 2)))  # (L,G,2)
            y_blocks_list.append(label_to_idx[label])

    if len(raw_blocks_list) == 0:
        raise RuntimeError("未生成任何 block，请检查 group_size 或数据组织。")

    raw_blocks = np.stack(raw_blocks_list, axis=0)  # (B,G,L,2)
    xfr_blocks = np.stack(xfr_blocks_list, axis=0)  # (B,L,G,2)
    y_blocks = np.array(y_blocks_list, dtype=np.int64)

    print(f"[INFO] blocks={raw_blocks.shape[0]}, group_size={raw_blocks.shape[1]}, sample_len={raw_blocks.shape[2]}, classes={len(label_list)}")
    return raw_blocks, xfr_blocks, y_blocks


# ===================== ResNet18_1D extract_features =====================
class BasicBlock1D(nn.Module):
    expansion = 1
    def __init__(self, in_planes, planes, stride=1, downsample=None, dropout=0.0):
        super().__init__()
        self.conv1 = nn.Conv1d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm1d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.dropout = nn.Dropout(p=dropout)
        self.conv2 = nn.Conv1d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm1d(planes)
        self.downsample = downsample

    def forward(self, x):
        identity = x
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.dropout(out)
        out = self.bn2(self.conv2(out))
        if self.downsample is not None:
            identity = self.downsample(x)
        out = out + identity
        return self.relu(out)


class ResNet18_1D(nn.Module):
    def __init__(self, num_classes=10, in_planes=64, dropout=0.0):
        super().__init__()
        self.in_planes = in_planes
        self.conv1 = nn.Conv1d(2, in_planes, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm1d(in_planes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(64, 2, stride=1, dropout=dropout)
        self.layer2 = self._make_layer(128, 2, stride=2, dropout=dropout)
        self.layer3 = self._make_layer(256, 2, stride=2, dropout=dropout)
        self.layer4 = self._make_layer(512, 2, stride=2, dropout=dropout)
        self.avgpool = nn.AdaptiveAvgPool1d(1)
        self.fc = nn.Linear(512, num_classes)

    def _make_layer(self, planes, blocks, stride, dropout):
        downsample = None
        if stride != 1 or self.in_planes != planes:
            downsample = nn.Sequential(
                nn.Conv1d(self.in_planes, planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm1d(planes)
            )
        layers = [BasicBlock1D(self.in_planes, planes, stride, downsample, dropout)]
        self.in_planes = planes
        for _ in range(1, blocks):
            layers.append(BasicBlock1D(self.in_planes, planes, dropout=dropout))
        return nn.Sequential(*layers)

    def extract_features(self, x):
        x = x.permute(0, 2, 1)
        x = self.relu(self.bn1(self.conv1(x)))
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x).squeeze(-1)
        return x

    def forward(self, x):
        feat = self.extract_features(x)
        return self.fc(feat)


def strip_module_prefix(state_dict):
    if any(k.startswith("module.") for k in state_dict.keys()):
        return {k.replace("module.", "", 1): v for k, v in state_dict.items()}
    return state_dict


def find_fold_models(exp_dir):
    paths = glob.glob(os.path.join(exp_dir, "best_model_fold*.pth"))
    if len(paths) == 0:
        paths = glob.glob(os.path.join(exp_dir, "model_fold*.pth"))
    if len(paths) == 0:
        raise RuntimeError(f"在 {exp_dir} 未找到 best_model_fold*.pth 或 model_fold*.pth")

    def fold_key(p):
        m = re.search(r"fold(\d+)", os.path.basename(p))
        return int(m.group(1)) if m else 999

    return sorted(paths, key=fold_key)


# ===================== 样本级 dataset =====================
class SimpleDataset(Dataset):
    def __init__(self, X, y):
        self.X = X
        self.y = y

    def __len__(self):
        return self.X.shape[0]

    def __getitem__(self, idx):
        return torch.tensor(self.X[idx], dtype=torch.float32), torch.tensor(self.y[idx], dtype=torch.long)


def stratified_sample_indices(y, max_per_class, seed=42):
    rng = np.random.default_rng(seed)
    keep = []
    for c in np.unique(y):
        idx = np.where(y == c)[0]
        if len(idx) > max_per_class:
            idx = rng.choice(idx, size=max_per_class, replace=False)
        keep.append(idx)
    keep = np.concatenate(keep)
    rng.shuffle(keep)
    return keep


def extract_features_foldmean(model_paths, X, y, device):
    """
    对同一批样本，逐fold提特征后做 fold-mean，返回 (N,512)
    """
    ds = SimpleDataset(X, y)
    loader = DataLoader(ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

    feats_folds = []
    for mp in model_paths:
        state = strip_module_prefix(torch.load(mp, map_location=device))
        in_planes = state["conv1.weight"].shape[0]
        num_classes = state["fc.weight"].shape[0]

        model = ResNet18_1D(num_classes=num_classes, in_planes=in_planes, dropout=0.0).to(device)
        model.load_state_dict(state, strict=True)
        model.eval()

        feats = []
        with torch.no_grad():
            for xb, _ in tqdm(loader, desc=f"Extract {os.path.basename(mp)}", leave=False):
                xb = xb.to(device)
                f = model.extract_features(xb).cpu().numpy()
                feats.append(f)
        feats = np.concatenate(feats, axis=0)  # (N,512)
        feats_folds.append(feats)

    feats_folds = np.stack(feats_folds, axis=0)  # (F,N,512)
    return feats_folds.mean(axis=0)              # (N,512)


def tsne_joint(feat_a, feat_b, n_components, seed=42):
    """
    joint-fit：拼接后一次拟合，保证同一坐标系
    """
    X = np.concatenate([feat_a, feat_b], axis=0).astype(np.float32)

    pca_dim_eff = min(PCA_DIM, X.shape[1])
    X_pca = PCA(n_components=pca_dim_eff, random_state=seed).fit_transform(X)

    # 兼容不同 sklearn 版本：n_iter vs max_iter
    tsne_kwargs = dict(
        n_components=n_components,
        perplexity=TSNE_PERPLEXITY,
        init="pca",
        learning_rate="auto",
        metric=TSNE_METRIC,
        random_state=seed,
        verbose=1
    )
    try:
        tsne = TSNE(**tsne_kwargs, n_iter=TSNE_ITERS)
    except TypeError:
        tsne = TSNE(**tsne_kwargs, max_iter=TSNE_ITERS)

    emb = tsne.fit_transform(X_pca)
    return emb[:feat_a.shape[0]], emb[feat_a.shape[0]:]


def main():
    # 固定随机性（尽量可复现）
    np.random.seed(SEED)
    torch.manual_seed(SEED)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(SEED)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    os.makedirs(OUT_DIR, exist_ok=True)

    xfr_name = os.path.basename(EXP_XFR_DIR)
    raw_name = os.path.basename(EXP_RAW_DIR)

    snr_xfr, fd_xfr, _ = parse_snr_fd_classes(xfr_name)
    snr_raw, fd_raw, _ = parse_snr_fd_classes(raw_name)

    if USE_COMMON_CHANNEL:
        snr_use = snr_xfr if snr_xfr is not None else 20
        fd_use = fd_xfr if fd_xfr is not None else 655
        print(f"[INFO] Common channel: SNR={snr_use} dB, fd={fd_use} Hz")
    else:
        snr_use = snr_xfr or 20
        fd_use = fd_xfr or 655

    # 构造 blocks（同源）
    raw_blocks, xfr_blocks, y_blocks = load_blocks_ltev(
        DATA_PATH,
        group_size=GROUP_SIZE,
        apply_doppler=APPLY_DOPPLER,
        fd_hz=fd_use,
        apply_awgn=APPLY_AWGN,
        snr_db=snr_use,
        fs=FS
    )

    num_blocks = raw_blocks.shape[0]
    sample_len = raw_blocks.shape[2]

    # block split（与你训练一致）
    block_idx = np.arange(num_blocks)
    _, test_idx, _, _ = train_test_split(
        block_idx, y_blocks, test_size=TEST_SIZE, stratify=y_blocks, random_state=SPLIT_SEED
    )
    sel_blocks = test_idx
    y_sel_blocks = y_blocks[sel_blocks]

    # ====== 组装样本级数据 ======
    # RAW samples: (nb*G, L,2)
    raw_sel = raw_blocks[sel_blocks]  # (nb,G,L,2)
    X_raw = raw_sel.reshape(-1, sample_len, 2)
    y_raw = np.repeat(y_sel_blocks, raw_sel.shape[1])

    # XFR samples: (nb*L, G,2)
    xfr_sel = xfr_blocks[sel_blocks]  # (nb,L,G,2)
    X_xfr = xfr_sel.reshape(-1, raw_sel.shape[1], 2)
    y_xfr = np.repeat(y_sel_blocks, xfr_sel.shape[1])

    # ====== 样本级分层抽样（否则太大） ======
    idx_raw = stratified_sample_indices(y_raw, MAX_SAMPLES_PER_CLASS_RAW, seed=SEED)
    idx_xfr = stratified_sample_indices(y_xfr, MAX_SAMPLES_PER_CLASS_XFR, seed=SEED)

    X_raw = X_raw[idx_raw]
    y_raw = y_raw[idx_raw]
    X_xfr = X_xfr[idx_xfr]
    y_xfr = y_xfr[idx_xfr]

    print(f"[INFO] sample-level selected: RAW={len(y_raw)}, XFR={len(y_xfr)}")

    # ====== fold models ======
    raw_models = find_fold_models(EXP_RAW_DIR)
    xfr_models = find_fold_models(EXP_XFR_DIR)

    print(f"[INFO] RAW folds: {len(raw_models)}, XFR folds: {len(xfr_models)}")

    # ====== fold-mean features ======
    feat_raw = extract_features_foldmean(raw_models, X_raw, y_raw, device)
    feat_xfr = extract_features_foldmean(xfr_models, X_xfr, y_xfr, device)

    # ====== joint t-SNE ======
    emb2_raw, emb2_xfr = tsne_joint(feat_raw, feat_xfr, n_components=2, seed=SEED)
    emb3_raw, emb3_xfr = tsne_joint(feat_raw, feat_xfr, n_components=3, seed=SEED)

    # ====== 统一坐标范围（论文对比更规范） ======
    x2_all = np.concatenate([emb2_raw[:, 0], emb2_xfr[:, 0]])
    y2_all = np.concatenate([emb2_raw[:, 1], emb2_xfr[:, 1]])
    xlim2 = (float(x2_all.min()), float(x2_all.max()))
    ylim2 = (float(y2_all.min()), float(y2_all.max()))

    x3_all = np.concatenate([emb3_raw[:, 0], emb3_xfr[:, 0]])
    y3_all = np.concatenate([emb3_raw[:, 1], emb3_xfr[:, 1]])
    z3_all = np.concatenate([emb3_raw[:, 2], emb3_xfr[:, 2]])
    xlim3 = (float(x3_all.min()), float(x3_all.max()))
    ylim3 = (float(y3_all.min()), float(y3_all.max()))
    zlim3 = (float(z3_all.min()), float(z3_all.max()))

    # ====== 输出 4 张独立图（PDF + PNG，无title） ======
    save_tsne_2d_single(emb2_raw, y_raw, os.path.join(OUT_DIR, "RAW_2D"), xlim=xlim2, ylim=ylim2)
    save_tsne_2d_single(emb2_xfr, y_xfr, os.path.join(OUT_DIR, "XFR_2D"), xlim=xlim2, ylim=ylim2)

    # 3D 两张保持同 view 与坐标范围
    view = (18, -60)
    save_tsne_3d_single(emb3_raw, y_raw, os.path.join(OUT_DIR, "RAW_3D"), xlim=xlim3, ylim=ylim3, zlim=zlim3, view=view)
    save_tsne_3d_single(emb3_xfr, y_xfr, os.path.join(OUT_DIR, "XFR_3D"), xlim=xlim3, ylim=ylim3, zlim=zlim3, view=view)

    print("[OK] Saved 4 individual figures (PDF+PNG).")

    # ====== 输出 4 份 .mat（分别对应4张图） ======
    meta = {
        "DATA_PATH": DATA_PATH,
        "EXP_RAW_DIR": EXP_RAW_DIR,
        "EXP_XFR_DIR": EXP_XFR_DIR,
        "SNR_use_dB": snr_use,
        "fd_use_Hz": fd_use,
        "USE_COMMON_CHANNEL": int(USE_COMMON_CHANNEL),
        "GROUP_SIZE": GROUP_SIZE,
        "TEST_SIZE": TEST_SIZE,
        "SPLIT_SEED": SPLIT_SEED,
        "MAX_SAMPLES_PER_CLASS_RAW": MAX_SAMPLES_PER_CLASS_RAW,
        "MAX_SAMPLES_PER_CLASS_XFR": MAX_SAMPLES_PER_CLASS_XFR,
        "PCA_DIM": PCA_DIM,
        "TSNE_PERPLEXITY": TSNE_PERPLEXITY,
        "TSNE_ITERS": TSNE_ITERS,
        "TSNE_METRIC": TSNE_METRIC,
        "SEED": SEED,
    }

    save_embedding_mat_2d(emb2_raw, y_raw, os.path.join(OUT_DIR, "RAW_2D_tsne.mat"), meta=meta)
    save_embedding_mat_2d(emb2_xfr, y_xfr, os.path.join(OUT_DIR, "XFR_2D_tsne.mat"), meta=meta)
    save_embedding_mat_3d(emb3_raw, y_raw, os.path.join(OUT_DIR, "RAW_3D_tsne.mat"), meta=meta)
    save_embedding_mat_3d(emb3_xfr, y_xfr, os.path.join(OUT_DIR, "XFR_3D_tsne.mat"), meta=meta)

    # ====== 额外：输出一个总 .mat（包含四个嵌入和标签） ======
    all_mat_path = os.path.join(OUT_DIR, "LTEV_tsne_all.mat")
    savemat(all_mat_path, {
        "emb2_raw": emb2_raw.astype(np.float64),
        "emb2_xfr": emb2_xfr.astype(np.float64),
        "emb3_raw": emb3_raw.astype(np.float64),
        "emb3_xfr": emb3_xfr.astype(np.float64),
        "y_raw": y_raw.astype(np.int64).reshape(-1, 1),
        "y_xfr": y_xfr.astype(np.int64).reshape(-1, 1),
        "meta": meta
    })
    print(f"[OK] Saved mat: {all_mat_path}")


if __name__ == "__main__":
    main()


In [None]:
# train_wisig_time_SNR20_savepth.py
# ManySig (WiSig) 单个SNR训练：保存 best_model_fold*.pth 供后续 t-SNE 特征提取使用

from joblib import load
import numpy as np
import os
from data_utilities import *
import matplotlib
matplotlib.use("agg")
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
from datetime import datetime

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset, Subset
from sklearn.model_selection import KFold
from sklearn.metrics import confusion_matrix
import seaborn as sns


# ================== 数据加载与划分 ==================
dataset_name = "ManySig"
dataset_path = "../ManySig.pkl/"

compact_dataset = load_compact_pkl_dataset(dataset_path, dataset_name)

print("数据集发射机数量：", len(compact_dataset["tx_list"]), "具体为：", compact_dataset["tx_list"])
print("数据集接收机数量：", len(compact_dataset["rx_list"]), "具体为：", compact_dataset["rx_list"])
print("数据集采集天数：", len(compact_dataset["capture_date_list"]), "具体为：", compact_dataset["capture_date_list"])

tx_list = compact_dataset["tx_list"]
rx_list = compact_dataset["rx_list"]

train_dates = ["2021_03_15"]
test_dates  = ["2021_03_01"]
equalized = 0

X_train, y_train, X_test, y_test = preprocess_dataset_for_classification_cross_date(
    compact_dataset, tx_list, rx_list, train_dates, test_dates, max_sig=None, equalized=equalized
)

print("训练集所选日期：", train_dates, "测试集所选日期：", test_dates)
print("X_train shape:", X_train.shape, "y_train:", y_train.shape)
print("X_test  shape:", X_test.shape,  "y_test :", y_test.shape)


# ================== 信号处理参数 ==================
fs = 20e6
fc = 2.4e9
v_kmh = 120
ADD_NOISE = True
ADD_DOPPLER = True

def compute_doppler_shift(v, fc):
    c = 3e8
    v = v / 3.6
    return (v / c) * fc

fd = compute_doppler_shift(v_kmh, fc)
print(f"[INFO] Doppler fd={fd:.2f} Hz")

def add_doppler_shift(signal, fd, fs):
    num_samples = signal.shape[-1]
    t = np.arange(num_samples) / fs
    return signal * np.exp(1j * 2 * np.pi * fd * t)

def add_complex_awgn(signal, snr_db):
    signal_power = np.mean(np.abs(signal) ** 2) + 1e-12
    snr_linear = 10 ** (snr_db / 10)
    noise_power = signal_power / snr_linear
    noise_std = np.sqrt(noise_power / 2)
    noise = np.random.normal(0, noise_std, signal.shape) + 1j * np.random.normal(0, noise_std, signal.shape)
    return signal + noise

def preprocess_iq_data(data_real_imag, snr_db=None, fd=None, fs=None, add_noise=True, add_doppler=True):
    """
    输入: [N, L, 2] float
    输出: [N, L, 2] float
    流程: 功率归一化 -> Doppler -> AWGN
    """
    if add_noise and snr_db is None:
        raise ValueError("add_noise=True 时必须提供 snr_db")
    if add_doppler and (fd is None or fs is None):
        raise ValueError("add_doppler=True 时必须提供 fd/fs")

    data_complex = data_real_imag[..., 0] + 1j * data_real_imag[..., 1]
    processed = []

    for sig in data_complex:
        sig = sig / (np.sqrt(np.mean(np.abs(sig) ** 2)) + 1e-12)
        cur = sig

        if add_doppler:
            cur = add_doppler_shift(cur, fd, fs)
        if add_noise:
            cur = add_complex_awgn(cur, snr_db)

        processed.append(cur)

    processed = np.array(processed)
    return np.stack([processed.real, processed.imag], axis=-1).astype(np.float32)


# ================== ResNet18_1D 模型 ==================
class BasicBlock1D(nn.Module):
    expansion = 1
    def __init__(self, in_planes, planes, stride=1, downsample=None, dropout=0.0):
        super().__init__()
        self.conv1 = nn.Conv1d(in_planes, planes, 3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm1d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.dropout = nn.Dropout(p=dropout)
        self.conv2 = nn.Conv1d(planes, planes, 3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm1d(planes)
        self.downsample = downsample

    def forward(self, x):
        identity = x
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.dropout(out)
        out = self.bn2(self.conv2(out))
        if self.downsample is not None:
            identity = self.downsample(x)
        out = out + identity
        return self.relu(out)

class ResNet18_1D(nn.Module):
    def __init__(self, num_classes=10, in_planes=64, dropout=0.0):
        super().__init__()
        self.in_planes = in_planes
        self.conv1 = nn.Conv1d(2, in_planes, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm1d(in_planes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(64, 2, stride=1, dropout=dropout)
        self.layer2 = self._make_layer(128, 2, stride=2, dropout=dropout)
        self.layer3 = self._make_layer(256, 2, stride=2, dropout=dropout)
        self.layer4 = self._make_layer(512, 2, stride=2, dropout=dropout)
        self.avgpool = nn.AdaptiveAvgPool1d(1)
        self.fc = nn.Linear(512, num_classes)

    def _make_layer(self, planes, blocks, stride, dropout):
        downsample = None
        if stride != 1 or self.in_planes != planes:
            downsample = nn.Sequential(
                nn.Conv1d(self.in_planes, planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm1d(planes)
            )
        layers = [BasicBlock1D(self.in_planes, planes, stride, downsample, dropout)]
        self.in_planes = planes
        for _ in range(1, blocks):
            layers.append(BasicBlock1D(self.in_planes, planes, dropout=dropout))
        return nn.Sequential(*layers)

    def forward(self, x):
        x = x.permute(0, 2, 1)  # [B,L,2] -> [B,2,L]
        x = self.relu(self.bn1(self.conv1(x)))
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x).squeeze(-1)
        return self.fc(x)


# ================== 训练配置 ==================
SNR_dB = 0                       # 你要补训练的 RAW IQ 就固定 20 dB
batch_size = 64
num_epochs = 200
learning_rate = 1e-4
weight_decay = 1e-3
in_planes = 64
dropout = 0.5
patience = 5
n_splits = 5

num_classes = len(np.unique(y_train))
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"[INFO] device={device}, num_classes={num_classes}")

SAVE_ROOT = "./training_results"
os.makedirs(SAVE_ROOT, exist_ok=True)

timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
script_name = "wisig_time"
folder_name = f"{timestamp}_{script_name}_SNR{SNR_dB}dB_fd{int(fd)}_classes_{num_classes}_ResNet18"
save_folder = os.path.join(SAVE_ROOT, folder_name)
os.makedirs(save_folder, exist_ok=True)

results_file = os.path.join(save_folder, "results.txt")
with open(results_file, "w", encoding="utf-8") as f:
    f.write("=== Experiment Summary ===\n")
    f.write(f"Timestamp: {timestamp}\n")
    f.write(f"SNR: {SNR_dB} dB\n")
    f.write(f"fd: {fd:.4f} Hz\n")
    f.write(f"equalized: {equalized}\n")
    f.write(f"train_dates: {train_dates}\n")
    f.write(f"test_dates : {test_dates}\n")
    f.write(f"num_classes: {num_classes}\n\n")

print(f"[INFO] save_folder={save_folder}")


# ================== 数据处理（固定20dB） ==================
X_train_processed = preprocess_iq_data(
    X_train, snr_db=SNR_dB, fd=fd, fs=fs, add_noise=ADD_NOISE, add_doppler=ADD_DOPPLER
)
X_test_processed = preprocess_iq_data(
    X_test, snr_db=SNR_dB, fd=fd, fs=fs, add_noise=ADD_NOISE, add_doppler=ADD_DOPPLER
)

train_dataset = TensorDataset(torch.tensor(X_train_processed, dtype=torch.float32),
                              torch.tensor(y_train, dtype=torch.long))
test_dataset  = TensorDataset(torch.tensor(X_test_processed, dtype=torch.float32),
                              torch.tensor(y_test, dtype=torch.long))
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

kfold = KFold(n_splits=n_splits, shuffle=True, random_state=42)

def moving_average(x, w=5):
    x = np.array(x, dtype=np.float32)
    if len(x) < w:
        w = max(1, len(x))
    return np.convolve(x, np.ones(w), "valid") / w


# ================== KFold 训练并保存 best_model_fold*.pth ==================
fold_val_accs = []
fold_test_accs = []

for fold, (train_idx, val_idx) in enumerate(kfold.split(train_dataset), start=1):
    print(f"\n=== Fold {fold}/{n_splits} ===")

    train_subset = Subset(train_dataset, train_idx)
    val_subset   = Subset(train_dataset, val_idx)

    train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True, drop_last=True)
    val_loader   = DataLoader(val_subset, batch_size=batch_size, shuffle=False, drop_last=False)

    model = ResNet18_1D(num_classes=num_classes, in_planes=in_planes, dropout=dropout).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)

    best_val_acc = 0.0
    best_state = None
    patience_cnt = 0

    train_losses, val_losses = [], []
    train_accs, val_accs = [], []

    for epoch in range(1, num_epochs + 1):
        # ---- train ----
        model.train()
        running_loss, correct, total = 0.0, 0, 0

        for xb, yb in tqdm(train_loader, desc=f"Fold{fold} Epoch{epoch}/{num_epochs}", leave=False):
            xb, yb = xb.to(device), yb.to(device)
            optimizer.zero_grad()
            out = model(xb)
            loss = criterion(out, yb)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            _, pred = torch.max(out, 1)
            total += yb.size(0)
            correct += (pred == yb).sum().item()

        tr_loss = running_loss / max(1, len(train_loader))
        tr_acc = 100.0 * correct / max(1, total)
        train_losses.append(tr_loss)
        train_accs.append(tr_acc)

        # ---- val ----
        model.eval()
        vloss, vcorrect, vtotal = 0.0, 0, 0
        with torch.no_grad():
            for xb, yb in val_loader:
                xb, yb = xb.to(device), yb.to(device)
                out = model(xb)
                loss = criterion(out, yb)
                vloss += loss.item()
                _, pred = torch.max(out, 1)
                vtotal += yb.size(0)
                vcorrect += (pred == yb).sum().item()

        va_loss = vloss / max(1, len(val_loader))
        va_acc = 100.0 * vcorrect / max(1, vtotal)
        val_losses.append(va_loss)
        val_accs.append(va_acc)

        log = (f"Fold{fold} Epoch{epoch} | "
               f"TrainAcc={tr_acc:.2f}% ValAcc={va_acc:.2f}% | "
               f"TrainLoss={tr_loss:.4f} ValLoss={va_loss:.4f}")
        print(log)
        with open(results_file, "a", encoding="utf-8") as f:
            f.write(log + "\n")

        # ---- early stop on ValAcc (更符合你后续挑 best_model) ----
        if va_acc > best_val_acc + 0.01:
            best_val_acc = va_acc
            best_state = {k: v.detach().cpu() for k, v in model.state_dict().items()}
            patience_cnt = 0
        else:
            patience_cnt += 1
            if patience_cnt >= patience:
                print("[INFO] Early stopping triggered.")
                with open(results_file, "a", encoding="utf-8") as f:
                    f.write("[INFO] Early stopping triggered.\n")
                break

        scheduler.step()

    # 保存 best_model_fold*.pth（你后续 t-SNE 需要）
    if best_state is None:
        best_state = {k: v.detach().cpu() for k, v in model.state_dict().items()}

    best_path = os.path.join(save_folder, f"best_model_fold{fold}.pth")
    torch.save(best_state, best_path)
    print(f"[OK] Saved: {best_path}")

    # 也保存最终模型（可选）
    last_path = os.path.join(save_folder, f"model_fold{fold}.pth")
    torch.save({k: v.detach().cpu() for k, v in model.state_dict().items()}, last_path)

    # 画训练曲线（无须改你原先风格）
    plt.figure()
    plt.plot(train_losses, label="Train Loss")
    plt.plot(val_losses, label="Val Loss")
    plt.plot(moving_average(train_losses), "--", label="Train Loss Smooth")
    plt.plot(moving_average(val_losses), "--", label="Val Loss Smooth")
    plt.xlabel("Epoch"); plt.ylabel("Loss")
    plt.legend(); plt.grid(True)
    plt.savefig(os.path.join(save_folder, f"fold_{fold}_loss_curve.png"))
    plt.close()

    plt.figure()
    plt.plot(train_accs, label="Train Acc")
    plt.plot(val_accs, label="Val Acc")
    plt.xlabel("Epoch"); plt.ylabel("Accuracy (%)")
    plt.legend(); plt.grid(True)
    plt.savefig(os.path.join(save_folder, f"fold_{fold}_acc_curve.png"))
    plt.close()

    # 测试集评估：用 best_state
    model.load_state_dict(best_state, strict=True)
    model.eval()
    test_preds, test_true = [], []
    with torch.no_grad():
        for xb, yb in test_loader:
            xb, yb = xb.to(device), yb.to(device)
            out = model(xb)
            _, pred = torch.max(out, 1)
            test_preds.append(pred.cpu().numpy())
            test_true.append(yb.cpu().numpy())
    test_preds = np.concatenate(test_preds)
    test_true = np.concatenate(test_true)
    test_acc = 100.0 * np.mean(test_preds == test_true)

    fold_val_accs.append(best_val_acc)
    fold_test_accs.append(test_acc)

    print(f"[RESULT] Fold{fold}: BestValAcc={best_val_acc:.2f}% TestAcc={test_acc:.2f}%")
    with open(results_file, "a", encoding="utf-8") as f:
        f.write(f"[RESULT] Fold{fold}: BestValAcc={best_val_acc:.2f}% TestAcc={test_acc:.2f}%\n")

    cm = confusion_matrix(test_true, test_preds)
    plt.figure(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt="d", cmap="Blues")
    plt.xlabel("Predicted"); plt.ylabel("True")
    plt.savefig(os.path.join(save_folder, f"fold_{fold}_confusion_matrix.png"))
    plt.close()

# 汇总
with open(results_file, "a", encoding="utf-8") as f:
    f.write("\n=== Overall Summary ===\n")
    f.write(f"Avg Best Val Acc: {np.mean(fold_val_accs):.2f} ± {np.std(fold_val_accs):.2f}\n")
    f.write(f"Avg Test Acc    : {np.mean(fold_test_accs):.2f} ± {np.std(fold_test_accs):.2f}\n")

print("\n[OK] Training done.")
print(f"[OK] Models saved in: {save_folder}")


In [None]:
# plot_wisig_tsne_train_test_overlay2d_export_mat.py
# WiSig (ManySig.pkl) 训练集 vs 测试集：2D t-SNE（RAW 与 XFR 各一张 overlay）
# 输出：
#   1) RAW_train_test_2D.pdf/.png
#   2) XFR_train_test_2D.pdf/.png
#   3) WISIG_train_test_tsne_2D_all.mat  (包含RAW/XFR train/test嵌入与标签/元信息)

import os
import re
import numpy as np
import matplotlib
matplotlib.use("agg")
import matplotlib.pyplot as plt
import matplotlib.patheffects as pe
from matplotlib.patches import Ellipse
import numpy.linalg as LA

from tqdm import tqdm
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

from scipy.io import savemat

# 你项目里已有
from data_utilities import load_compact_pkl_dataset


# ===================== 参数区（按你实际情况改） =====================
DATASET_NAME = "ManySig"
DATASET_PATH = r"../ManySig.pkl/"

EXP_XFR_DIR = r"./training_results/2025-11-28_19-42-37_wisig_XFR_SNR20dB_fd266_classes_6_ResNet"
EXP_RAW_DIR = r"./training_results/2026-01-23_16-40-16_wisig_time_SNR20dB_fd266_classes_6_ResNet18"

OUT_DIR = r"./TSNE_WISIG_20dB_TRAIN_TEST"

# 数据构造参数（必须与训练一致）
EQUALIZED = 0
MAX_SIG = None
BLOCK_SIZE = 240     # 关键：XFR模型训练用的 input_length（你已确认=240）
Y_PER_RX = 10

TRAIN_DATES = ["2021_03_15"]
TEST_DATES  = ["2021_03_01"]

# 信道注入（与你训练一致；若训练时没开，请置 False）
USE_COMMON_CHANNEL = True
FS = 20e6
FC = 2.4e9
VELOCITY_KMH = 120
APPLY_DOPPLER = True
APPLY_AWGN = True

# 抽样：每个 split 每类最多点数（t-SNE太大很慢）
MAX_SAMPLES_PER_CLASS_RAW_TRAIN = 1200
MAX_SAMPLES_PER_CLASS_RAW_TEST  = 1200
MAX_SAMPLES_PER_CLASS_XFR_TRAIN = 1200
MAX_SAMPLES_PER_CLASS_XFR_TEST  = 1200

# 特征提取 batch
BATCH_SIZE = 512
NUM_WORKERS = 0

# t-SNE
SEED = 42
PCA_DIM = 50
TSNE_PERPLEXITY = 30
TSNE_ITERS = 1500
TSNE_METRIC = "cosine"
# ===============================================================


# ===================== 绘图风格（学术风） =====================
def set_paper_style():
    import matplotlib as mpl
    mpl.rcParams.update({
        "font.family": "serif",
        "font.size": 9,
        "axes.labelsize": 9,
        "xtick.labelsize": 8,
        "ytick.labelsize": 8,
        "axes.linewidth": 0.8,
        "legend.frameon": False,
        "pdf.fonttype": 42,
        "ps.fonttype": 42,
        "figure.facecolor": "white",
        "axes.facecolor": "white",
    })


def _subsample(idx, max_n, rng):
    if (max_n is None) or (len(idx) <= max_n):
        return idx
    return rng.choice(idx, size=max_n, replace=False)

def _add_cov_ellipse(ax, pts2d, edgecolor, linestyle="-", n_std=2.0, lw=0.9, alpha=0.9, zorder=2):
    """
    画2D协方差椭圆：n_std=2 相当于“约95%”范围（对高斯近似）。
    pts2d: (N,2)
    """
    if pts2d is None or len(pts2d) < 10:
        return
    mu = np.mean(pts2d, axis=0)
    cov = np.cov(pts2d.T)
    if not np.all(np.isfinite(cov)):
        return
    vals, vecs = LA.eigh(cov)
    order = vals.argsort()[::-1]
    vals = vals[order]
    vecs = vecs[:, order]
    if np.any(vals <= 1e-12):
        return

    # 椭圆主轴长度
    width, height = 2 * n_std * np.sqrt(vals[0]), 2 * n_std * np.sqrt(vals[1])
    angle = np.degrees(np.arctan2(vecs[1, 0], vecs[0, 0]))

    e = Ellipse(xy=mu, width=width, height=height, angle=angle,
                facecolor="none", edgecolor=edgecolor, linestyle=linestyle,
                linewidth=lw, alpha=alpha)
    e.set_zorder(zorder)
    ax.add_patch(e)

def save_tsne_2d_train_test_overlay(
    emb_train, y_train, emb_test, y_test,
    class_names, out_base,
    xlim=None, ylim=None,
    # ====== 关键可调参数（建议保留默认） ======
    plot_max_per_class_train=400,
    plot_max_per_class_test=400,
    train_marker=".",
    train_size=14,
    train_alpha=0.9,
    test_marker="^",
    test_size=14,
    test_alpha=0.90,
    test_edgecolor="k",
    test_lw=0.25,
    draw_ellipses=True,
    ellipse_nstd=2.0
):
    """
    同一张图：Train vs Test overlay（更抗过绘制版本）
      - Train: 极淡小点（背景密度云）
      - Test : 大三角 + 黑描边（前景）
      - 类别用 TXk 文本标注（默认不再重复画类别legend）
      - 可选协方差椭圆：Train 实线，Test 虚线
    """
    set_paper_style()
    rng = np.random.default_rng(SEED)

    num_classes = int(max(y_train.max(), y_test.max())) + 1
    cmap = plt.cm.get_cmap("tab10", num_classes)

    fig = plt.figure(figsize=(3.6, 3.1), dpi=450)
    ax = fig.add_subplot(1, 1, 1)

    for c in range(num_classes):
        col = cmap(c)

        idx_tr_all = np.where(y_train == c)[0]
        idx_te_all = np.where(y_test == c)[0]

        # 仅用于“画图显示”的二次下采样（t-SNE拟合已完成，不受影响）
        idx_tr = _subsample(idx_tr_all, plot_max_per_class_train, rng)
        idx_te = _subsample(idx_te_all, plot_max_per_class_test, rng)

        # Train：背景云（极淡）
        if len(idx_tr) > 0:
            ax.scatter(
                emb_train[idx_tr, 0], emb_train[idx_tr, 1],
                s=train_size, c=[col], alpha=train_alpha,
                marker=train_marker, linewidths=0,
                zorder=1, rasterized=True
            )

        # Test：前景（三角+描边）
        if len(idx_te) > 0:
            ax.scatter(
                emb_test[idx_te, 0], emb_test[idx_te, 1],
                s=test_size, c=[col], alpha=test_alpha,
                marker=test_marker, linewidths=test_lw,
                edgecolors=test_edgecolor,
                zorder=3, rasterized=True
            )

        # 协方差椭圆（用全量点算，更稳定；显示更清晰）
        if draw_ellipses:
            if len(idx_tr_all) > 15:
                _add_cov_ellipse(ax, emb_train[idx_tr_all], edgecolor=col, linestyle="-",
                                 n_std=ellipse_nstd, lw=0.9, alpha=0.85, zorder=2)
            if len(idx_te_all) > 15:
                _add_cov_ellipse(ax, emb_test[idx_te_all], edgecolor=col, linestyle="--",
                                 n_std=ellipse_nstd, lw=0.9, alpha=0.85, zorder=2)

        # 类标注：用 train+test 的“中位数”更鲁棒（比均值不容易被散点拖偏）
        pts = []
        if len(idx_tr_all) > 0:
            pts.append(emb_train[idx_tr_all])
        if len(idx_te_all) > 0:
            pts.append(emb_test[idx_te_all])
        if len(pts) > 0:
            pts = np.vstack(pts)
            cx, cy = np.median(pts, axis=0)

            txt = ax.text(
                cx, cy, class_names[c],
                fontsize=9, color="black",
                ha="center", va="center", zorder=4
            )
            txt.set_path_effects([pe.withStroke(linewidth=2.2, foreground="white")])

    ax.set_xlabel("t-SNE 1")
    ax.set_ylabel("t-SNE 2")
    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)

    if xlim is not None:
        ax.set_xlim(xlim)
    if ylim is not None:
        ax.set_ylim(ylim)

    # 只保留 Train/Test legend（类别已在图中标注TX1~TX6）
    from matplotlib.lines import Line2D
    split_handles = [
        Line2D([0], [0], marker=train_marker, color="none", markerfacecolor="gray",
               markersize=5, label="date 03_15"),
        Line2D([0], [0], marker=test_marker, color="none", markerfacecolor="gray",
               markeredgecolor="k", markersize=6, label="date 03_01"),
    ]
    ax.legend(handles=split_handles, loc="upper right", fontsize=8, handletextpad=0.4, borderpad=0.2)

    plt.tight_layout(pad=0.2)
    os.makedirs(os.path.dirname(out_base), exist_ok=True)
    plt.savefig(out_base + ".pdf", bbox_inches="tight", pad_inches=0.02)
    plt.savefig(out_base + ".png", bbox_inches="tight", pad_inches=0.02, dpi=450)
    plt.close(fig)



# ===================== .mat 导出 =====================
def save_mat_train_test(out_path, payload: dict):
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    savemat(out_path, payload)
    print(f"[OK] Saved mat: {out_path}")


# ===================== 信道/预处理（与你训练一致） =====================
def compute_doppler_shift(v_kmh, fc_hz):
    c = 3e8
    v = v_kmh / 3.6
    return fc_hz * v / c


def apply_doppler_shift(signal_c, fd_hz, fs_hz):
    t = np.arange(signal_c.shape[-1]) / fs_hz
    return signal_c * np.exp(1j * 2 * np.pi * fd_hz * t)


def add_awgn(signal_c, snr_db):
    power = np.mean(np.abs(signal_c) ** 2) + 1e-12
    noise_power = power / (10 ** (snr_db / 10))
    noise = np.sqrt(noise_power / 2) * (np.random.randn(*signal_c.shape) + 1j * np.random.randn(*signal_c.shape))
    return signal_c + noise


def preprocess_for_pointcloud_cnn(data_real_imag, add_noise=False, snr_db=None,
                                  add_doppler=False, fd_hz=None, fs_hz=FS):
    """
    功率归一化 -> Doppler -> AWGN -> per-sample z-score
    """
    data = data_real_imag.astype(np.float32, copy=True)
    N, L, _ = data.shape
    out = np.empty_like(data, dtype=np.float32)

    for i in range(N):
        iq = data[i]
        sigc = iq[..., 0] + 1j * iq[..., 1]

        sigc = sigc / (np.sqrt(np.mean(np.abs(sigc) ** 2)) + 1e-12)

        if add_doppler and fd_hz is not None:
            sigc = apply_doppler_shift(sigc, fd_hz, fs_hz)

        if add_noise and snr_db is not None:
            sigc = add_awgn(sigc, snr_db)

        iq2 = np.stack([np.real(sigc), np.imag(sigc)], axis=-1).astype(np.float32)

        mu = iq2.mean(axis=0)
        sigma = iq2.std(axis=0)
        sigma[sigma < 1e-8] = 1.0
        out[i] = (iq2 - mu) / sigma

    return out


def parse_snr_fd_classes(folder_name: str):
    snr = None
    fd = None
    ncls = None
    m = re.search(r"SNR(-?\d+)dB", folder_name)
    if m:
        snr = int(m.group(1))
    m = re.search(r"fd(\d+)", folder_name)
    if m:
        fd = int(m.group(1))
    m = re.search(r"classes[_-](\d+)", folder_name)
    if m:
        ncls = int(m.group(1))
    return snr, fd, ncls


# ===================== WiSig blocks（与你 cyclic 逻辑一致） =====================
def load_blocks_wisig_cyclic(compact_dataset, tx_list, dates,
                             max_sig=None, equalized=0, block_size=240, y=10):
    """
    返回：
      raw_blocks: (B, G, 256, 2) 其中 G=block_size
      xfr_blocks: (B, 256, G, 2)
      y_blocks:   (B,)
    """
    raw_blocks, xfr_blocks, y_blocks = [], [], []

    try:
        eq_i = compact_dataset["equalized_list"].index(equalized)
    except ValueError:
        raise RuntimeError(f"equalized={equalized} 不在 equalized_list 中")

    for tx_idx, tx in enumerate(tx_list):
        try:
            tx_i = compact_dataset["tx_list"].index(tx)
        except ValueError:
            continue

        for date in dates:
            if date not in compact_dataset["capture_date_list"]:
                continue
            date_i = compact_dataset["capture_date_list"].index(date)

            rx_signals = []
            for rx_i in range(len(compact_dataset["rx_list"])):
                sig_data = compact_dataset["data"][tx_i][rx_i][date_i][eq_i]
                if max_sig is not None:
                    sig_data = sig_data[:max_sig]
                rx_signals.append(list(sig_data))

            num_rx = len(rx_signals)
            rx_pointer = 0
            accum_block = []

            while any(len(sig_list) > 0 for sig_list in rx_signals):
                rx_idx = rx_pointer % num_rx
                sig_list = rx_signals[rx_idx]

                if len(sig_list) > 0:
                    take_n = min(y, len(sig_list))
                    sampled = [sig_list.pop(0) for _ in range(take_n)]
                    accum_block.extend(sampled)

                rx_pointer += 1

                while len(accum_block) >= block_size:
                    block_chunk = accum_block[:block_size]
                    accum_block = accum_block[block_size:]

                    block_array = np.array(block_chunk)  # (G,256,2)
                    if block_array.ndim != 3 or block_array.shape[-1] != 2:
                        raise RuntimeError(f"Unexpected signal shape: {block_array.shape}")

                    raw_blocks.append(block_array.astype(np.float32))
                    xfr_blocks.append(block_array.transpose(1, 0, 2).astype(np.float32))  # (256,G,2)
                    y_blocks.append(tx_idx)

            accum_block = []

    if len(raw_blocks) == 0:
        raise RuntimeError("未生成任何 block：检查 dates / block_size / y / max_sig 等参数")

    raw_blocks = np.stack(raw_blocks, axis=0)
    xfr_blocks = np.stack(xfr_blocks, axis=0)
    y_blocks = np.array(y_blocks, dtype=np.int64)

    return raw_blocks, xfr_blocks, y_blocks


# ===================== 模型（自动识别 ResNet18_1D / RF1DCNN） =====================
class BasicBlock1D(nn.Module):
    expansion = 1
    def __init__(self, in_planes, planes, stride=1, downsample=None, dropout=0.0):
        super().__init__()
        self.conv1 = nn.Conv1d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm1d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.dropout = nn.Dropout(p=dropout)
        self.conv2 = nn.Conv1d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm1d(planes)
        self.downsample = downsample
    def forward(self, x):
        identity = x
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.dropout(out)
        out = self.bn2(self.conv2(out))
        if self.downsample is not None:
            identity = self.downsample(x)
        out = out + identity
        return self.relu(out)


class ResNet18_1D(nn.Module):
    def __init__(self, num_classes=10, in_planes=64, dropout=0.0):
        super().__init__()
        self.in_planes = in_planes
        self.conv1 = nn.Conv1d(2, in_planes, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm1d(in_planes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(64, 2, stride=1, dropout=dropout)
        self.layer2 = self._make_layer(128, 2, stride=2, dropout=dropout)
        self.layer3 = self._make_layer(256, 2, stride=2, dropout=dropout)
        self.layer4 = self._make_layer(512, 2, stride=2, dropout=dropout)
        self.avgpool = nn.AdaptiveAvgPool1d(1)
        self.fc = nn.Linear(512, num_classes)
    def _make_layer(self, planes, blocks, stride, dropout):
        downsample = None
        if stride != 1 or self.in_planes != planes:
            downsample = nn.Sequential(
                nn.Conv1d(self.in_planes, planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm1d(planes)
            )
        layers = [BasicBlock1D(self.in_planes, planes, stride, downsample, dropout)]
        self.in_planes = planes
        for _ in range(1, blocks):
            layers.append(BasicBlock1D(self.in_planes, planes, dropout=dropout))
        return nn.Sequential(*layers)
    def extract_features(self, x):
        x = x.permute(0, 2, 1)
        x = self.relu(self.bn1(self.conv1(x)))
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x).squeeze(-1)
        return x
    def forward(self, x):
        feat = self.extract_features(x)
        return self.fc(feat)


class ResidualBlock1D(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=5, stride=1):
        super().__init__()
        padding = kernel_size // 2
        self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size, stride, padding, bias=False)
        self.bn1 = nn.BatchNorm1d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv1d(out_channels, out_channels, kernel_size, 1, padding, bias=False)
        self.bn2 = nn.BatchNorm1d(out_channels)
        self.downsample = None
        if in_channels != out_channels or stride != 1:
            self.downsample = nn.Sequential(
                nn.Conv1d(in_channels, out_channels, 1, stride, bias=False),
                nn.BatchNorm1d(out_channels)
            )
    def forward(self, x):
        identity = x
        out = self.conv1(x); out = self.bn1(out); out = self.relu(out)
        out = self.conv2(out); out = self.bn2(out)
        if self.downsample is not None:
            identity = self.downsample(identity)
        out = out + identity
        return self.relu(out)


class RF1DCNN(nn.Module):
    def __init__(self, num_classes, dropout=0.3, input_length=256):
        super().__init__()
        self.layer1 = ResidualBlock1D(2, 32, kernel_size=7); self.pool1 = nn.MaxPool1d(2)
        self.layer2 = ResidualBlock1D(32, 64, kernel_size=5); self.pool2 = nn.MaxPool1d(2)
        self.layer3 = ResidualBlock1D(64, 128, kernel_size=5); self.pool3 = nn.MaxPool1d(2)
        self.layer4 = ResidualBlock1D(128, 256, kernel_size=3); self.pool4 = nn.MaxPool1d(2)

        L = input_length
        for _ in range(4):
            L = L // 2
        self.flattened_length = 256 * L

        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(self.flattened_length, 512),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(512, num_classes)
        )
    def extract_features(self, x):
        x = x.permute(0, 2, 1)
        x = self.layer1(x); x = self.pool1(x)
        x = self.layer2(x); x = self.pool2(x)
        x = self.layer3(x); x = self.pool3(x)
        x = self.layer4(x); x = self.pool4(x)
        x = torch.flatten(x, 1)
        x = self.fc[1](x)
        x = self.fc[2](x)
        return x
    def forward(self, x):
        feat = self.extract_features(x)
        x = self.fc[3](feat)
        x = self.fc[4](x)
        return x


def strip_module_prefix(state_dict):
    if any(k.startswith("module.") for k in state_dict.keys()):
        return {k.replace("module.", "", 1): v for k, v in state_dict.items()}
    return state_dict


def find_fold_models(exp_dir):
    import glob
    paths = glob.glob(os.path.join(exp_dir, "best_model_fold*.pth"))
    if len(paths) == 0:
        paths = glob.glob(os.path.join(exp_dir, "model_fold*.pth"))
    if len(paths) == 0:
        raise RuntimeError(f"在 {exp_dir} 未找到 best_model_fold*.pth 或 model_fold*.pth")

    def fold_key(p):
        m = re.search(r"fold(\d+)", os.path.basename(p))
        return int(m.group(1)) if m else 999
    return sorted(paths, key=fold_key)


def build_model_from_state(state, input_length_hint, device):
    keys = list(state.keys())

    is_resnet18 = any(k.startswith("layer1.0.") for k in keys) and ("fc.weight" in keys)
    is_rf1dcnn = any(k.startswith("layer1.conv1.") for k in keys) and any(k.startswith("fc.1.") for k in keys)

    if is_resnet18:
        in_planes = state["conv1.weight"].shape[0]
        num_classes = state["fc.weight"].shape[0]
        model = ResNet18_1D(num_classes=num_classes, in_planes=in_planes, dropout=0.0).to(device)
        model.load_state_dict(state, strict=True)
        model.eval()
        return model

    if is_rf1dcnn:
        # 关键：RF1DCNN 的 flatten 维度依赖 input_length_hint（=你的BLOCK_SIZE）
        num_classes = state["fc.4.weight"].shape[0]
        model = RF1DCNN(num_classes=num_classes, dropout=0.3, input_length=input_length_hint).to(device)
        model.load_state_dict(state, strict=True)
        model.eval()
        return model

    raise RuntimeError("无法识别模型结构：state_dict keys不匹配 ResNet18_1D / RF1DCNN。")


class SimpleDataset(Dataset):
    def __init__(self, X, y):
        self.X = X
        self.y = y
    def __len__(self):
        return self.X.shape[0]
    def __getitem__(self, idx):
        return torch.tensor(self.X[idx], dtype=torch.float32), torch.tensor(self.y[idx], dtype=torch.long)


def stratified_sample_indices(y, max_per_class, seed=42):
    rng = np.random.default_rng(seed)
    keep = []
    for c in np.unique(y):
        idx = np.where(y == c)[0]
        if len(idx) > max_per_class:
            idx = rng.choice(idx, size=max_per_class, replace=False)
        keep.append(idx)
    keep = np.concatenate(keep)
    rng.shuffle(keep)
    return keep


def extract_features_foldmean(model_paths, X, y, device):
    ds = SimpleDataset(X, y)
    loader = DataLoader(ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

    feats_folds = []
    input_length_hint = X.shape[1]

    for mp in model_paths:
        state = strip_module_prefix(torch.load(mp, map_location=device))
        model = build_model_from_state(state, input_length_hint=input_length_hint, device=device)

        feats = []
        with torch.no_grad():
            for xb, _ in tqdm(loader, desc=f"Extract {os.path.basename(mp)}", leave=False):
                xb = xb.to(device)
                f = model.extract_features(xb).cpu().numpy()
                feats.append(f)
        feats = np.concatenate(feats, axis=0)
        feats_folds.append(feats)

    feats_folds = np.stack(feats_folds, axis=0)
    return feats_folds.mean(axis=0)


def tsne_joint_two_splits(feat_train, feat_test, n_components=2, seed=42):
    """
    对 train+test 拼接 joint-fit，确保同一坐标系
    """
    X = np.concatenate([feat_train, feat_test], axis=0).astype(np.float32)

    pca_dim_eff = min(PCA_DIM, X.shape[1])
    X_pca = PCA(n_components=pca_dim_eff, random_state=seed).fit_transform(X)

    tsne_kwargs = dict(
        n_components=n_components,
        perplexity=TSNE_PERPLEXITY,
        init="pca",
        learning_rate="auto",
        metric=TSNE_METRIC,
        random_state=seed,
        verbose=1
    )
    try:
        tsne = TSNE(**tsne_kwargs, n_iter=TSNE_ITERS)
    except TypeError:
        tsne = TSNE(**tsne_kwargs, max_iter=TSNE_ITERS)

    emb = tsne.fit_transform(X_pca)
    return emb[:feat_train.shape[0]], emb[feat_train.shape[0]:]


def main():
    np.random.seed(SEED)
    torch.manual_seed(SEED)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(SEED)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    os.makedirs(OUT_DIR, exist_ok=True)

    # common channel parameters
    snr_xfr, fd_xfr, _ = parse_snr_fd_classes(os.path.basename(EXP_XFR_DIR))
    if USE_COMMON_CHANNEL:
        snr_use = snr_xfr if snr_xfr is not None else 20
        fd_use = fd_xfr if fd_xfr is not None else int(compute_doppler_shift(VELOCITY_KMH, FC))
    else:
        snr_use = snr_xfr or 20
        fd_use = fd_xfr or int(compute_doppler_shift(VELOCITY_KMH, FC))

    print(f"[INFO] SNR_use={snr_use} dB, fd_use={fd_use} Hz")

    # load dataset
    compact_dataset = load_compact_pkl_dataset(DATASET_PATH, DATASET_NAME)
    tx_list = compact_dataset["tx_list"]
    n_tx = len(tx_list)
    if n_tx != 6:
        print(f"[WARN] 当前tx数量={n_tx}，但你希望标注TX1~TX6。将按实际n_tx生成class_names。")

    class_names = [f"TX{i+1}" for i in range(n_tx)]

    # 1) 构造 train blocks 与 test blocks（按日期划分）
    raw_tr_blk, xfr_tr_blk, y_tr_blk = load_blocks_wisig_cyclic(
        compact_dataset, tx_list, TRAIN_DATES,
        max_sig=MAX_SIG, equalized=EQUALIZED, block_size=BLOCK_SIZE, y=Y_PER_RX
    )
    raw_te_blk, xfr_te_blk, y_te_blk = load_blocks_wisig_cyclic(
        compact_dataset, tx_list, TEST_DATES,
        max_sig=MAX_SIG, equalized=EQUALIZED, block_size=BLOCK_SIZE, y=Y_PER_RX
    )

    # shapes
    Btr, G, L, _ = raw_tr_blk.shape   # RAW: (B,G,256,2)
    Bte, _, _, _ = raw_te_blk.shape
    assert xfr_tr_blk.shape == (Btr, L, G, 2)
    assert xfr_te_blk.shape == (Bte, L, G, 2)

    # 2) 组装样本级：RAW(time-IQ) 与 XFR(cross-IQ)
    # RAW samples: (B*G, 256,2)
    X_raw_tr = raw_tr_blk.reshape(-1, L, 2)
    y_raw_tr = np.repeat(y_tr_blk, G)
    X_raw_te = raw_te_blk.reshape(-1, L, 2)
    y_raw_te = np.repeat(y_te_blk, G)

    # XFR samples: (B*L, G,2) 其中 G=BLOCK_SIZE(240)
    X_xfr_tr = xfr_tr_blk.reshape(-1, G, 2)
    y_xfr_tr = np.repeat(y_tr_blk, L)
    X_xfr_te = xfr_te_blk.reshape(-1, G, 2)
    y_xfr_te = np.repeat(y_te_blk, L)

    # 3) 分层抽样控制规模（分别对 train/test）
    idx = stratified_sample_indices(y_raw_tr, MAX_SAMPLES_PER_CLASS_RAW_TRAIN, seed=SEED)
    X_raw_tr, y_raw_tr = X_raw_tr[idx], y_raw_tr[idx]
    idx = stratified_sample_indices(y_raw_te, MAX_SAMPLES_PER_CLASS_RAW_TEST, seed=SEED)
    X_raw_te, y_raw_te = X_raw_te[idx], y_raw_te[idx]

    idx = stratified_sample_indices(y_xfr_tr, MAX_SAMPLES_PER_CLASS_XFR_TRAIN, seed=SEED)
    X_xfr_tr, y_xfr_tr = X_xfr_tr[idx], y_xfr_tr[idx]
    idx = stratified_sample_indices(y_xfr_te, MAX_SAMPLES_PER_CLASS_XFR_TEST, seed=SEED)
    X_xfr_te, y_xfr_te = X_xfr_te[idx], y_xfr_te[idx]

    print(f"[INFO] RAW train/test samples: {len(y_raw_tr)}/{len(y_raw_te)}")
    print(f"[INFO] XFR train/test samples: {len(y_xfr_tr)}/{len(y_xfr_te)}")

    # 4) 注入信道 + per-sample 标准化（保持与训练一致）
    if APPLY_DOPPLER or APPLY_AWGN:
        # 为可复现：固定随机种子后再处理（AWGN用到了np.random）
        np.random.seed(SEED)
        X_raw_tr = preprocess_for_pointcloud_cnn(X_raw_tr, add_noise=APPLY_AWGN, snr_db=snr_use,
                                                 add_doppler=APPLY_DOPPLER, fd_hz=fd_use, fs_hz=FS)
        X_raw_te = preprocess_for_pointcloud_cnn(X_raw_te, add_noise=APPLY_AWGN, snr_db=snr_use,
                                                 add_doppler=APPLY_DOPPLER, fd_hz=fd_use, fs_hz=FS)

        np.random.seed(SEED + 1)
        X_xfr_tr = preprocess_for_pointcloud_cnn(X_xfr_tr, add_noise=APPLY_AWGN, snr_db=snr_use,
                                                 add_doppler=APPLY_DOPPLER, fd_hz=fd_use, fs_hz=FS)
        X_xfr_te = preprocess_for_pointcloud_cnn(X_xfr_te, add_noise=APPLY_AWGN, snr_db=snr_use,
                                                 add_doppler=APPLY_DOPPLER, fd_hz=fd_use, fs_hz=FS)

    # 5) 加载模型fold
    raw_models = find_fold_models(EXP_RAW_DIR)
    xfr_models = find_fold_models(EXP_XFR_DIR)
    print(f"[INFO] folds: RAW={len(raw_models)} XFR={len(xfr_models)}")

    # 6) 提取特征（fold-mean）
    feat_raw_tr = extract_features_foldmean(raw_models, X_raw_tr, y_raw_tr, device)
    feat_raw_te = extract_features_foldmean(raw_models, X_raw_te, y_raw_te, device)

    feat_xfr_tr = extract_features_foldmean(xfr_models, X_xfr_tr, y_xfr_tr, device)
    feat_xfr_te = extract_features_foldmean(xfr_models, X_xfr_te, y_xfr_te, device)

    # 7) 2D t-SNE：分别对 RAW 与 XFR 做 train+test joint-fit
    emb2_raw_tr, emb2_raw_te = tsne_joint_two_splits(feat_raw_tr, feat_raw_te, n_components=2, seed=SEED)
    emb2_xfr_tr, emb2_xfr_te = tsne_joint_two_splits(feat_xfr_tr, feat_xfr_te, n_components=2, seed=SEED)

    # 统一坐标范围（同一方法内 train/test 共享范围；RAW与XFR不强行共范围）
    def lim2(emb_a, emb_b):
        x_all = np.concatenate([emb_a[:, 0], emb_b[:, 0]])
        y_all = np.concatenate([emb_a[:, 1], emb_b[:, 1]])
        return (float(x_all.min()), float(x_all.max())), (float(y_all.min()), float(y_all.max()))

    xlim_raw, ylim_raw = lim2(emb2_raw_tr, emb2_raw_te)
    xlim_xfr, ylim_xfr = lim2(emb2_xfr_tr, emb2_xfr_te)

    # 8) 输出两张 overlay 图（每张图里 train/test 在一起）
    save_tsne_2d_train_test_overlay(
        emb2_raw_tr, y_raw_tr, emb2_raw_te, y_raw_te,
        class_names=class_names,
        out_base=os.path.join(OUT_DIR, "RAW_train_test_2D"),
        xlim=xlim_raw, ylim=ylim_raw
    )

    save_tsne_2d_train_test_overlay(
        emb2_xfr_tr, y_xfr_tr, emb2_xfr_te, y_xfr_te,
        class_names=class_names,
        out_base=os.path.join(OUT_DIR, "XFR_train_test_2D"),
        xlim=xlim_xfr, ylim=ylim_xfr
    )

    print("[OK] Saved 2 overlay figures (RAW/XFR).")

    # 9) 导出 .mat（MATLAB可直接画/分析）
    meta = {
        "DATASET_NAME": DATASET_NAME,
        "DATASET_PATH": DATASET_PATH,
        "EXP_RAW_DIR": EXP_RAW_DIR,
        "EXP_XFR_DIR": EXP_XFR_DIR,
        "TRAIN_DATES": TRAIN_DATES,
        "TEST_DATES": TEST_DATES,
        "EQUALIZED": int(EQUALIZED),
        "MAX_SIG": -1 if MAX_SIG is None else int(MAX_SIG),
        "BLOCK_SIZE": int(BLOCK_SIZE),
        "Y_PER_RX": int(Y_PER_RX),
        "SNR_use_dB": int(snr_use),
        "fd_use_Hz": int(fd_use),
        "FS": float(FS),
        "FC": float(FC),
        "VELOCITY_KMH": float(VELOCITY_KMH),
        "APPLY_DOPPLER": int(APPLY_DOPPLER),
        "APPLY_AWGN": int(APPLY_AWGN),
        "MAX_SAMPLES_PER_CLASS_RAW_TRAIN": int(MAX_SAMPLES_PER_CLASS_RAW_TRAIN),
        "MAX_SAMPLES_PER_CLASS_RAW_TEST": int(MAX_SAMPLES_PER_CLASS_RAW_TEST),
        "MAX_SAMPLES_PER_CLASS_XFR_TRAIN": int(MAX_SAMPLES_PER_CLASS_XFR_TRAIN),
        "MAX_SAMPLES_PER_CLASS_XFR_TEST": int(MAX_SAMPLES_PER_CLASS_XFR_TEST),
        "PCA_DIM": int(PCA_DIM),
        "TSNE_PERPLEXITY": int(TSNE_PERPLEXITY),
        "TSNE_ITERS": int(TSNE_ITERS),
        "TSNE_METRIC": TSNE_METRIC,
        "SEED": int(SEED),
    }

    out_mat = os.path.join(OUT_DIR, "WISIG_train_test_tsne_2D_all.mat")
    payload = {
        # RAW
        "emb2_raw_train": emb2_raw_tr.astype(np.float64),
        "emb2_raw_test":  emb2_raw_te.astype(np.float64),
        "y_raw_train": y_raw_tr.astype(np.int64).reshape(-1, 1),
        "y_raw_test":  y_raw_te.astype(np.int64).reshape(-1, 1),

        # XFR
        "emb2_xfr_train": emb2_xfr_tr.astype(np.float64),
        "emb2_xfr_test":  emb2_xfr_te.astype(np.float64),
        "y_xfr_train": y_xfr_tr.astype(np.int64).reshape(-1, 1),
        "y_xfr_test":  y_xfr_te.astype(np.int64).reshape(-1, 1),

        # names / meta
        "class_names": np.array(class_names, dtype=object),
        "meta": meta,
    }
    save_mat_train_test(out_mat, payload)


if __name__ == "__main__":
    main()


In [6]:
# plot_ltev_feature_tsne_sample_level_export_mat_and_figs.py
# 样本级（非block聚合）特征 t-SNE：RAW vs XFR，joint-fit 2D/3D
# 输出：4张独立图（PDF+PNG, 无title, WiSig同款学术风格） + 4份.mat + 1份总.mat
#
# 本版修改点（基于你给的脚本“直接改”）：
#   - 不画椭圆
#   - XFR：TXk 直接写在图上（白描边文字），不再单独放类别legend
#   - RAW：不在图中写TXk，改为“左下角类别legend（TX1..TXn）”
#   - 画图可二次下采样（只影响显示，不影响 t-SNE 拟合与 .mat 导出）
#   - 总 .mat 增加 class_names / txid_list / mapping（便于追溯）

import os
import re
import glob
import numpy as np
import h5py
import matplotlib
matplotlib.use("agg")
import matplotlib.pyplot as plt
import matplotlib.patheffects as pe

from tqdm import tqdm
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from sklearn.model_selection import train_test_split

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

from scipy.io import savemat
from matplotlib.lines import Line2D


# ===================== 你需要确认/修改的参数 =====================
DATA_PATH = r"E:/rf_datasets/"  # LTE-V .mat 文件夹

EXP_XFR_DIR = r"./training_results/2026-01-07_19-01-27_LTE-V_XFR_SNR20dB_fd655_classes_9_ResNet"
EXP_RAW_DIR = r"./training_results/2025-11-27_11-25-06_LTE_time_SNR20dB_fd960_classes_9"

OUT_DIR = r"./TSNE"

GROUP_SIZE = 864
TEST_SIZE = 0.25
SPLIT_SEED = 42

# 为了公平对比，建议 True：两边用同一个 SNR/fd 注入生成同源 blocks
USE_COMMON_CHANNEL = True
FS = 5e6
APPLY_DOPPLER = True
APPLY_AWGN = True

# 抽样（样本级每类最多点数，避免t-SNE过慢）
MAX_SAMPLES_PER_CLASS_RAW = 1500
MAX_SAMPLES_PER_CLASS_XFR = 1500

BATCH_SIZE = 1024
NUM_WORKERS = 0

# t-SNE
SEED = 42
PCA_DIM = 50
TSNE_PERPLEXITY = 30
TSNE_ITERS = 1500
TSNE_METRIC = "cosine"

# ===== 绘图（WiSig同款）=====
PLOT_MAX_PER_CLASS = 600    # 仅影响“画图显示”
PLOT_MARKER = "^"
PLOT_POINT_SIZE = 14
PLOT_POINT_ALPHA = 0.90
PLOT_EDGE_LW = 0.25
# ===============================================================


# ===================== 论文风格绘图配置（WiSig模板） =====================
def set_paper_style():
    import matplotlib as mpl
    mpl.rcParams.update({
        "font.family": "serif",
        "font.size": 9,
        "axes.labelsize": 9,
        "xtick.labelsize": 8,
        "ytick.labelsize": 8,
        "axes.linewidth": 0.8,
        "legend.frameon": False,
        "pdf.fonttype": 42,
        "ps.fonttype": 42,
        "figure.facecolor": "white",
        "axes.facecolor": "white",
    })


def _subsample(idx, max_n, rng):
    if (max_n is None) or (len(idx) <= max_n):
        return idx
    return rng.choice(idx, size=max_n, replace=False)


def _build_class_legend_handles(class_names, cmap, marker="^", edge_color="k", edge_lw=0.25, alpha=0.90):
    handles = []
    for i, name in enumerate(class_names):
        col = cmap(i % 10)
        h = Line2D(
            [0], [0],
            marker=marker,
            linestyle="None",
            markerfacecolor=col,
            markeredgecolor=edge_color,
            markeredgewidth=edge_lw,
            alpha=alpha,
            markersize=6,
            label=name
        )
        handles.append(h)
    return handles


def save_tsne_2d_single_wisig_style(
    emb2, y, class_names, out_base,
    xlim=None, ylim=None,
    plot_max_per_class=600,
    marker="^", point_size=14, alpha=0.90,
    edge_lw=0.25, edge_color="k",
    label_mode="text",            # "text"=TXk写在图上；"legend"=不写TXk，改左下角legend
    legend_loc="lower left"
):
    """
    WiSig 风格单图（无椭圆）：
      - tab10 配色（按类循环）
      - 三角形点 + 黑描边
      - label_mode="text"：TXk 文字直接写在图上（白描边）
      - label_mode="legend"：不写文字，左下角放类别legend
    """
    set_paper_style()
    os.makedirs(os.path.dirname(out_base), exist_ok=True)
    rng = np.random.default_rng(SEED)

    num_classes = int(y.max()) + 1
    cmap = plt.cm.get_cmap("tab10", 10)

    fig = plt.figure(figsize=(3.6, 3.1), dpi=450)
    ax = fig.add_subplot(1, 1, 1)

    for c in range(num_classes):
        idx_all = np.where(y == c)[0]
        if idx_all.size == 0:
            continue

        col = cmap(c % 10)
        idx_plot = _subsample(idx_all, plot_max_per_class, rng)

        ax.scatter(
            emb2[idx_plot, 0], emb2[idx_plot, 1],
            s=point_size, c=[col], alpha=alpha,
            marker=marker, linewidths=edge_lw, edgecolors=edge_color,
            zorder=3, rasterized=True
        )

        if label_mode == "text":
            cx, cy = np.median(emb2[idx_all], axis=0)
            txt = ax.text(
                cx, cy, class_names[c],
                fontsize=9, color="black",
                ha="center", va="center", zorder=4
            )
            txt.set_path_effects([pe.withStroke(linewidth=2.2, foreground="white")])

    # RAW：左下角 legend
    if label_mode == "legend":
        handles = _build_class_legend_handles(
            class_names=class_names,
            cmap=cmap,
            marker=marker,
            edge_color=edge_color,
            edge_lw=edge_lw,
            alpha=alpha
        )
        ax.legend(
            handles=handles,
            loc=legend_loc,
            fontsize=7,
            ncol=1,
            handletextpad=0.4,
            borderpad=0.2,
            labelspacing=0.25
        )

    ax.set_xlabel("t-SNE 1")
    ax.set_ylabel("t-SNE 2")
    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)

    if xlim is not None:
        ax.set_xlim(xlim)
    if ylim is not None:
        ax.set_ylim(ylim)

    plt.tight_layout(pad=0.2)
    plt.savefig(out_base + ".pdf", bbox_inches="tight", pad_inches=0.02)
    plt.savefig(out_base + ".png", bbox_inches="tight", pad_inches=0.02, dpi=450)
    plt.close(fig)


def save_tsne_3d_single_wisig_style(
    emb3, y, class_names, out_base,
    xlim=None, ylim=None, zlim=None,
    view=(18, -60),
    plot_max_per_class=600,
    marker="^", point_size=14, alpha=0.90,
    edge_lw=0.25, edge_color="k",
    grid_alpha=0.25,
    label_mode="text",            # "text" or "legend"
    legend_loc="lower left"
):
    """
    WiSig 风格 3D 单图（无椭圆）：
      - tab10 配色
      - 三角形点 + 黑描边
      - label_mode="text"：TXk 写在图上（白描边）
      - label_mode="legend"：不写文字，左下角 legend
    """
    set_paper_style()
    os.makedirs(os.path.dirname(out_base), exist_ok=True)
    rng = np.random.default_rng(SEED)

    num_classes = int(y.max()) + 1
    cmap = plt.cm.get_cmap("tab10", 10)

    fig = plt.figure(figsize=(3.6, 3.1), dpi=450)
    ax = fig.add_subplot(1, 1, 1, projection="3d")

    for c in range(num_classes):
        idx_all = np.where(y == c)[0]
        if idx_all.size == 0:
            continue
        col = cmap(c % 10)
        idx_plot = _subsample(idx_all, plot_max_per_class, rng)

        ax.scatter(
            emb3[idx_plot, 0], emb3[idx_plot, 1], emb3[idx_plot, 2],
            s=point_size, c=[col], alpha=alpha,
            marker=marker, linewidths=edge_lw, edgecolors=edge_color
        )

        if label_mode == "text":
            cx, cy, cz = np.median(emb3[idx_all], axis=0)
            txt = ax.text(
                cx, cy, cz, class_names[c],
                fontsize=9, color="black",
                ha="center", va="center"
            )
            txt.set_path_effects([pe.withStroke(linewidth=2.2, foreground="white")])

    if label_mode == "legend":
        handles = _build_class_legend_handles(
            class_names=class_names,
            cmap=cmap,
            marker=marker,
            edge_color=edge_color,
            edge_lw=edge_lw,
            alpha=alpha
        )
        ax.legend(
            handles=handles,
            loc=legend_loc,
            fontsize=7,
            ncol=1,
            handletextpad=0.4,
            borderpad=0.2,
            labelspacing=0.25
        )

    ax.set_xlabel("t-SNE 1", labelpad=2)
    ax.set_ylabel("t-SNE 2", labelpad=2)
    ax.set_zlabel("t-SNE 3", labelpad=2)

    ax.view_init(elev=view[0], azim=view[1])
    ax.grid(True, alpha=grid_alpha)

    if xlim is not None:
        ax.set_xlim(xlim)
    if ylim is not None:
        ax.set_ylim(ylim)
    if zlim is not None:
        ax.set_zlim(zlim)

    plt.tight_layout(pad=0.2)
    plt.savefig(out_base + ".pdf", bbox_inches="tight", pad_inches=0.02)
    plt.savefig(out_base + ".png", bbox_inches="tight", pad_inches=0.02, dpi=450)
    plt.close(fig)


# ===================== .mat 导出 =====================
def save_embedding_mat_2d(emb2, y, out_path, meta=None):
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    mdict = {"emb": emb2.astype(np.float64), "label": y.astype(np.int64).reshape(-1, 1)}
    if meta is not None:
        mdict["meta"] = meta
    savemat(out_path, mdict)
    print(f"[OK] Saved mat: {out_path}")


def save_embedding_mat_3d(emb3, y, out_path, meta=None):
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    mdict = {"emb": emb3.astype(np.float64), "label": y.astype(np.int64).reshape(-1, 1)}
    if meta is not None:
        mdict["meta"] = meta
    savemat(out_path, mdict)
    print(f"[OK] Saved mat: {out_path}")


# ===================== 信号处理 =====================
def apply_doppler_shift(signal_c, fd_hz, fs_hz):
    t = np.arange(signal_c.shape[-1]) / fs_hz
    return signal_c * np.exp(1j * 2 * np.pi * fd_hz * t)


def add_awgn(signal_c, snr_db):
    power = np.mean(np.abs(signal_c) ** 2) + 1e-12
    noise_power = power / (10 ** (snr_db / 10))
    noise = np.sqrt(noise_power / 2) * (np.random.randn(*signal_c.shape) + 1j * np.random.randn(*signal_c.shape))
    return signal_c + noise


def parse_snr_fd_classes(folder_name: str):
    snr = None
    fd = None
    ncls = None
    m = re.search(r"SNR(-?\d+)dB", folder_name)
    if m:
        snr = int(m.group(1))
    m = re.search(r"fd(\d+)", folder_name)
    if m:
        fd = int(m.group(1))
    m = re.search(r"classes[_-](\d+)", folder_name)
    if m:
        ncls = int(m.group(1))
    return snr, fd, ncls


# ===================== LTE-V blocks =====================
def load_blocks_ltev(mat_folder,
                     group_size=864,
                     apply_doppler=True,
                     fd_hz=655,
                     apply_awgn=True,
                     snr_db=20,
                     fs=5e6):
    mat_files = glob.glob(os.path.join(mat_folder, "*.mat"))
    if len(mat_files) == 0:
        raise RuntimeError(f"在 {mat_folder} 未找到 .mat 文件")

    X_files, y_files, label_set = [], [], set()

    print(f"[INFO] Found {len(mat_files)} .mat files")
    for file in tqdm(mat_files, desc="Reading LTE-V .mat"):
        with h5py.File(file, "r") as f:
            rfDataset = f["rfDataset"]
            dmrs = rfDataset["dmrs"][:]
            dmrs_complex = dmrs["real"] + 1j * dmrs["imag"]
            txID_uint16 = rfDataset["txID"][:].flatten()
            tx_id = "".join(chr(c) for c in txID_uint16 if c != 0)

            processed = []
            for i in range(dmrs_complex.shape[0]):
                sig = dmrs_complex[i, :].astype(np.complex64)
                sig = sig / (np.sqrt(np.mean(np.abs(sig) ** 2)) + 1e-12)

                if apply_doppler:
                    sig = apply_doppler_shift(sig, fd_hz, fs)

                if apply_awgn:
                    sig = add_awgn(sig, snr_db)

                iq = np.stack([sig.real, sig.imag], axis=-1).astype(np.float32)
                processed.append(iq)

            processed = np.array(processed, dtype=np.float32)
            X_files.append(processed)
            y_files.append(tx_id)
            label_set.add(tx_id)

    label_list = sorted(list(label_set))
    label_to_idx = {lb: i for i, lb in enumerate(label_list)}

    raw_blocks_list = []
    xfr_blocks_list = []
    y_blocks_list = []

    for label in label_list:
        files_idx = [i for i, y in enumerate(y_files) if y == label]
        num_files = len(files_idx)
        if num_files == 0:
            continue

        samples_per_file = group_size // num_files
        if samples_per_file == 0:
            continue

        min_samples = min([X_files[i].shape[0] for i in files_idx])
        max_groups = min_samples // samples_per_file
        if max_groups == 0:
            continue

        for gi in range(max_groups):
            pieces = []
            for fi in files_idx:
                s = gi * samples_per_file
                e = s + samples_per_file
                pieces.append(X_files[fi][s:e])
            big_block = np.concatenate(pieces, axis=0)  # (G,L,2)

            raw_blocks_list.append(big_block)  # (G,L,2)
            xfr_blocks_list.append(np.transpose(big_block, (1, 0, 2)))  # (L,G,2)
            y_blocks_list.append(label_to_idx[label])

    if len(raw_blocks_list) == 0:
        raise RuntimeError("未生成任何 block，请检查 group_size 或数据组织。")

    raw_blocks = np.stack(raw_blocks_list, axis=0)  # (B,G,L,2)
    xfr_blocks = np.stack(xfr_blocks_list, axis=0)  # (B,L,G,2)
    y_blocks = np.array(y_blocks_list, dtype=np.int64)

    print(f"[INFO] blocks={raw_blocks.shape[0]}, group_size={raw_blocks.shape[1]}, sample_len={raw_blocks.shape[2]}, classes={len(label_list)}")
    return raw_blocks, xfr_blocks, y_blocks, label_list


# ===================== ResNet18_1D extract_features =====================
class BasicBlock1D(nn.Module):
    expansion = 1
    def __init__(self, in_planes, planes, stride=1, downsample=None, dropout=0.0):
        super().__init__()
        self.conv1 = nn.Conv1d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm1d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.dropout = nn.Dropout(p=dropout)
        self.conv2 = nn.Conv1d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm1d(planes)
        self.downsample = downsample

    def forward(self, x):
        identity = x
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.dropout(out)
        out = self.bn2(self.conv2(out))
        if self.downsample is not None:
            identity = self.downsample(x)
        out = out + identity
        return self.relu(out)


class ResNet18_1D(nn.Module):
    def __init__(self, num_classes=10, in_planes=64, dropout=0.0):
        super().__init__()
        self.in_planes = in_planes
        self.conv1 = nn.Conv1d(2, in_planes, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm1d(in_planes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(64, 2, stride=1, dropout=dropout)
        self.layer2 = self._make_layer(128, 2, stride=2, dropout=dropout)
        self.layer3 = self._make_layer(256, 2, stride=2, dropout=dropout)
        self.layer4 = self._make_layer(512, 2, stride=2, dropout=dropout)
        self.avgpool = nn.AdaptiveAvgPool1d(1)
        self.fc = nn.Linear(512, num_classes)

    def _make_layer(self, planes, blocks, stride, dropout):
        downsample = None
        if stride != 1 or self.in_planes != planes:
            downsample = nn.Sequential(
                nn.Conv1d(self.in_planes, planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm1d(planes)
            )
        layers = [BasicBlock1D(self.in_planes, planes, stride, downsample, dropout)]
        self.in_planes = planes
        for _ in range(1, blocks):
            layers.append(BasicBlock1D(self.in_planes, planes, dropout=dropout))
        return nn.Sequential(*layers)

    def extract_features(self, x):
        x = x.permute(0, 2, 1)
        x = self.relu(self.bn1(self.conv1(x)))
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x).squeeze(-1)
        return x

    def forward(self, x):
        feat = self.extract_features(x)
        return self.fc(feat)


def strip_module_prefix(state_dict):
    if any(k.startswith("module.") for k in state_dict.keys()):
        return {k.replace("module.", "", 1): v for k, v in state_dict.items()}
    return state_dict


def find_fold_models(exp_dir):
    paths = glob.glob(os.path.join(exp_dir, "best_model_fold*.pth"))
    if len(paths) == 0:
        paths = glob.glob(os.path.join(exp_dir, "model_fold*.pth"))
    if len(paths) == 0:
        raise RuntimeError(f"在 {exp_dir} 未找到 best_model_fold*.pth 或 model_fold*.pth")

    def fold_key(p):
        m = re.search(r"fold(\d+)", os.path.basename(p))
        return int(m.group(1)) if m else 999

    return sorted(paths, key=fold_key)


# ===================== 样本级 dataset =====================
class SimpleDataset(Dataset):
    def __init__(self, X, y):
        self.X = X
        self.y = y

    def __len__(self):
        return self.X.shape[0]

    def __getitem__(self, idx):
        return torch.tensor(self.X[idx], dtype=torch.float32), torch.tensor(self.y[idx], dtype=torch.long)


def stratified_sample_indices(y, max_per_class, seed=42):
    rng = np.random.default_rng(seed)
    keep = []
    for c in np.unique(y):
        idx = np.where(y == c)[0]
        if len(idx) > max_per_class:
            idx = rng.choice(idx, size=max_per_class, replace=False)
        keep.append(idx)
    keep = np.concatenate(keep)
    rng.shuffle(keep)
    return keep


def extract_features_foldmean(model_paths, X, y, device):
    ds = SimpleDataset(X, y)
    loader = DataLoader(ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

    feats_folds = []
    for mp in model_paths:
        state = strip_module_prefix(torch.load(mp, map_location=device))
        in_planes = state["conv1.weight"].shape[0]
        num_classes = state["fc.weight"].shape[0]

        model = ResNet18_1D(num_classes=num_classes, in_planes=in_planes, dropout=0.0).to(device)
        model.load_state_dict(state, strict=True)
        model.eval()

        feats = []
        with torch.no_grad():
            for xb, _ in tqdm(loader, desc=f"Extract {os.path.basename(mp)}", leave=False):
                xb = xb.to(device)
                f = model.extract_features(xb).cpu().numpy()
                feats.append(f)
        feats = np.concatenate(feats, axis=0)
        feats_folds.append(feats)

    feats_folds = np.stack(feats_folds, axis=0)
    return feats_folds.mean(axis=0)


def tsne_joint(feat_a, feat_b, n_components, seed=42):
    X = np.concatenate([feat_a, feat_b], axis=0).astype(np.float32)

    pca_dim_eff = min(PCA_DIM, X.shape[1])
    X_pca = PCA(n_components=pca_dim_eff, random_state=seed).fit_transform(X)

    tsne_kwargs = dict(
        n_components=n_components,
        perplexity=TSNE_PERPLEXITY,
        init="pca",
        learning_rate="auto",
        metric=TSNE_METRIC,
        random_state=seed,
        verbose=1
    )
    try:
        tsne = TSNE(**tsne_kwargs, n_iter=TSNE_ITERS)
    except TypeError:
        tsne = TSNE(**tsne_kwargs, max_iter=TSNE_ITERS)

    emb = tsne.fit_transform(X_pca)
    return emb[:feat_a.shape[0]], emb[feat_a.shape[0]:]


def main():
    np.random.seed(SEED)
    torch.manual_seed(SEED)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(SEED)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    os.makedirs(OUT_DIR, exist_ok=True)

    xfr_name = os.path.basename(EXP_XFR_DIR)
    raw_name = os.path.basename(EXP_RAW_DIR)

    snr_xfr, fd_xfr, _ = parse_snr_fd_classes(xfr_name)
    snr_raw, fd_raw, _ = parse_snr_fd_classes(raw_name)

    if USE_COMMON_CHANNEL:
        snr_use = snr_xfr if snr_xfr is not None else 20
        fd_use = fd_xfr if fd_xfr is not None else 655
        print(f"[INFO] Common channel: SNR={snr_use} dB, fd={fd_use} Hz")
    else:
        snr_use = snr_xfr or 20
        fd_use = fd_xfr or 655

    raw_blocks, xfr_blocks, y_blocks, txid_list = load_blocks_ltev(
        DATA_PATH,
        group_size=GROUP_SIZE,
        apply_doppler=APPLY_DOPPLER,
        fd_hz=fd_use,
        apply_awgn=APPLY_AWGN,
        snr_db=snr_use,
        fs=FS
    )

    num_blocks = raw_blocks.shape[0]
    sample_len = raw_blocks.shape[2]

    block_idx = np.arange(num_blocks)
    _, test_idx, _, _ = train_test_split(
        block_idx, y_blocks, test_size=TEST_SIZE, stratify=y_blocks, random_state=SPLIT_SEED
    )
    sel_blocks = test_idx
    y_sel_blocks = y_blocks[sel_blocks]

    raw_sel = raw_blocks[sel_blocks]  # (nb,G,L,2)
    X_raw = raw_sel.reshape(-1, sample_len, 2)
    y_raw = np.repeat(y_sel_blocks, raw_sel.shape[1])

    xfr_sel = xfr_blocks[sel_blocks]  # (nb,L,G,2)
    X_xfr = xfr_sel.reshape(-1, raw_sel.shape[1], 2)
    y_xfr = np.repeat(y_sel_blocks, xfr_sel.shape[1])

    idx_raw = stratified_sample_indices(y_raw, MAX_SAMPLES_PER_CLASS_RAW, seed=SEED)
    idx_xfr = stratified_sample_indices(y_xfr, MAX_SAMPLES_PER_CLASS_XFR, seed=SEED)
    X_raw, y_raw = X_raw[idx_raw], y_raw[idx_raw]
    X_xfr, y_xfr = X_xfr[idx_xfr], y_xfr[idx_xfr]
    print(f"[INFO] sample-level selected: RAW={len(y_raw)}, XFR={len(y_xfr)}")

    raw_models = find_fold_models(EXP_RAW_DIR)
    xfr_models = find_fold_models(EXP_XFR_DIR)
    print(f"[INFO] RAW folds: {len(raw_models)}, XFR folds: {len(xfr_models)}")

    feat_raw = extract_features_foldmean(raw_models, X_raw, y_raw, device)
    feat_xfr = extract_features_foldmean(xfr_models, X_xfr, y_xfr, device)

    emb2_raw, emb2_xfr = tsne_joint(feat_raw, feat_xfr, n_components=2, seed=SEED)
    emb3_raw, emb3_xfr = tsne_joint(feat_raw, feat_xfr, n_components=3, seed=SEED)

    x2_all = np.concatenate([emb2_raw[:, 0], emb2_xfr[:, 0]])
    y2_all = np.concatenate([emb2_raw[:, 1], emb2_xfr[:, 1]])
    xlim2 = (float(x2_all.min()), float(x2_all.max()))
    ylim2 = (float(y2_all.min()), float(y2_all.max()))

    x3_all = np.concatenate([emb3_raw[:, 0], emb3_xfr[:, 0]])
    y3_all = np.concatenate([emb3_raw[:, 1], emb3_xfr[:, 1]])
    z3_all = np.concatenate([emb3_raw[:, 2], emb3_xfr[:, 2]])
    xlim3 = (float(x3_all.min()), float(x3_all.max()))
    ylim3 = (float(y3_all.min()), float(y3_all.max()))
    zlim3 = (float(z3_all.min()), float(z3_all.max()))

    n_cls = int(max(y_raw.max(), y_xfr.max())) + 1
    class_names = [f"TX{i+1}" for i in range(n_cls)]
    mapping = [(class_names[i], txid_list[i] if i < len(txid_list) else "") for i in range(n_cls)]
    print("[INFO] Class mapping (TXk -> tx_id):")
    for a, b in mapping:
        print(f"  {a} -> {b}")

    # ====== 关键：RAW 用左下角 legend；XFR 用图上 TXk 文本 ======
    save_tsne_2d_single_wisig_style(
        emb2_raw, y_raw, class_names,
        os.path.join(OUT_DIR, "RAW_2D"),
        xlim=xlim2, ylim=ylim2,
        plot_max_per_class=PLOT_MAX_PER_CLASS,
        marker=PLOT_MARKER, point_size=PLOT_POINT_SIZE, alpha=PLOT_POINT_ALPHA,
        edge_lw=PLOT_EDGE_LW, edge_color="k",
        label_mode="legend",
        legend_loc="lower right"
    )
    save_tsne_2d_single_wisig_style(
        emb2_xfr, y_xfr, class_names,
        os.path.join(OUT_DIR, "XFR_2D"),
        xlim=xlim2, ylim=ylim2,
        plot_max_per_class=PLOT_MAX_PER_CLASS,
        marker=PLOT_MARKER, point_size=PLOT_POINT_SIZE, alpha=PLOT_POINT_ALPHA,
        edge_lw=PLOT_EDGE_LW, edge_color="k",
        label_mode="text"
    )

    view = (18, -60)
    save_tsne_3d_single_wisig_style(
        emb3_raw, y_raw, class_names,
        os.path.join(OUT_DIR, "RAW_3D"),
        xlim=xlim3, ylim=ylim3, zlim=zlim3, view=view,
        plot_max_per_class=PLOT_MAX_PER_CLASS,
        marker=PLOT_MARKER, point_size=PLOT_POINT_SIZE, alpha=PLOT_POINT_ALPHA,
        edge_lw=PLOT_EDGE_LW, edge_color="k",
        label_mode="legend",
        legend_loc="lower left"
    )
    save_tsne_3d_single_wisig_style(
        emb3_xfr, y_xfr, class_names,
        os.path.join(OUT_DIR, "XFR_3D"),
        xlim=xlim3, ylim=ylim3, zlim=zlim3, view=view,
        plot_max_per_class=PLOT_MAX_PER_CLASS,
        marker=PLOT_MARKER, point_size=PLOT_POINT_SIZE, alpha=PLOT_POINT_ALPHA,
        edge_lw=PLOT_EDGE_LW, edge_color="k",
        label_mode="text"
    )

    print("[OK] Saved 4 figures (RAW=legend lower-left, XFR=text on-plot).")

    meta = {
        "DATA_PATH": DATA_PATH,
        "EXP_RAW_DIR": EXP_RAW_DIR,
        "EXP_XFR_DIR": EXP_XFR_DIR,
        "SNR_use_dB": int(snr_use),
        "fd_use_Hz": int(fd_use),
        "USE_COMMON_CHANNEL": int(USE_COMMON_CHANNEL),
        "GROUP_SIZE": int(GROUP_SIZE),
        "TEST_SIZE": float(TEST_SIZE),
        "SPLIT_SEED": int(SPLIT_SEED),
        "MAX_SAMPLES_PER_CLASS_RAW": int(MAX_SAMPLES_PER_CLASS_RAW),
        "MAX_SAMPLES_PER_CLASS_XFR": int(MAX_SAMPLES_PER_CLASS_XFR),
        "PCA_DIM": int(PCA_DIM),
        "TSNE_PERPLEXITY": int(TSNE_PERPLEXITY),
        "TSNE_ITERS": int(TSNE_ITERS),
        "TSNE_METRIC": TSNE_METRIC,
        "SEED": int(SEED),
        "PLOT_MAX_PER_CLASS": int(PLOT_MAX_PER_CLASS),
        "PLOT_MARKER": PLOT_MARKER,
        "PLOT_POINT_SIZE": float(PLOT_POINT_SIZE),
        "PLOT_POINT_ALPHA": float(PLOT_POINT_ALPHA),
        "PLOT_EDGE_LW": float(PLOT_EDGE_LW),
        "RAW_label_mode": "legend_lower_left",
        "XFR_label_mode": "text_on_plot",
    }

    save_embedding_mat_2d(emb2_raw, y_raw, os.path.join(OUT_DIR, "RAW_2D_tsne.mat"), meta=meta)
    save_embedding_mat_2d(emb2_xfr, y_xfr, os.path.join(OUT_DIR, "XFR_2D_tsne.mat"), meta=meta)
    save_embedding_mat_3d(emb3_raw, y_raw, os.path.join(OUT_DIR, "RAW_3D_tsne.mat"), meta=meta)
    save_embedding_mat_3d(emb3_xfr, y_xfr, os.path.join(OUT_DIR, "XFR_3D_tsne.mat"), meta=meta)

    all_mat_path = os.path.join(OUT_DIR, "LTEV_tsne_all.mat")
    savemat(all_mat_path, {
        "emb2_raw": emb2_raw.astype(np.float64),
        "emb2_xfr": emb2_xfr.astype(np.float64),
        "emb3_raw": emb3_raw.astype(np.float64),
        "emb3_xfr": emb3_xfr.astype(np.float64),
        "y_raw": y_raw.astype(np.int64).reshape(-1, 1),
        "y_xfr": y_xfr.astype(np.int64).reshape(-1, 1),
        "class_names": np.array(class_names, dtype=object),
        "txid_list": np.array(txid_list, dtype=object),
        "class_mapping_TX_to_txid": np.array([f"{a} -> {b}" for a, b in mapping], dtype=object),
        "meta": meta
    })
    print(f"[OK] Saved mat: {all_mat_path}")


if __name__ == "__main__":
    main()


[INFO] Common channel: SNR=20 dB, fd=655 Hz
[INFO] Found 72 .mat files


Reading LTE-V .mat: 100%|██████████| 72/72 [00:10<00:00,  6.78it/s]


[INFO] blocks=243, group_size=864, sample_len=288, classes=9
[INFO] sample-level selected: RAW=13500, XFR=13500
[INFO] RAW folds: 5, XFR folds: 5




[t-SNE] Computing 91 nearest neighbors...
[t-SNE] Indexed 27000 samples in 0.001s...
[t-SNE] Computed neighbors for 27000 samples in 4.288s...
[t-SNE] Computed conditional probabilities for sample 1000 / 27000
[t-SNE] Computed conditional probabilities for sample 2000 / 27000
[t-SNE] Computed conditional probabilities for sample 3000 / 27000
[t-SNE] Computed conditional probabilities for sample 4000 / 27000
[t-SNE] Computed conditional probabilities for sample 5000 / 27000
[t-SNE] Computed conditional probabilities for sample 6000 / 27000
[t-SNE] Computed conditional probabilities for sample 7000 / 27000
[t-SNE] Computed conditional probabilities for sample 8000 / 27000
[t-SNE] Computed conditional probabilities for sample 9000 / 27000
[t-SNE] Computed conditional probabilities for sample 10000 / 27000
[t-SNE] Computed conditional probabilities for sample 11000 / 27000
[t-SNE] Computed conditional probabilities for sample 12000 / 27000
[t-SNE] Computed conditional probabilities for sam



[t-SNE] Computed neighbors for 27000 samples in 4.190s...
[t-SNE] Computed conditional probabilities for sample 1000 / 27000
[t-SNE] Computed conditional probabilities for sample 2000 / 27000
[t-SNE] Computed conditional probabilities for sample 3000 / 27000
[t-SNE] Computed conditional probabilities for sample 4000 / 27000
[t-SNE] Computed conditional probabilities for sample 5000 / 27000
[t-SNE] Computed conditional probabilities for sample 6000 / 27000
[t-SNE] Computed conditional probabilities for sample 7000 / 27000
[t-SNE] Computed conditional probabilities for sample 8000 / 27000
[t-SNE] Computed conditional probabilities for sample 9000 / 27000
[t-SNE] Computed conditional probabilities for sample 10000 / 27000
[t-SNE] Computed conditional probabilities for sample 11000 / 27000
[t-SNE] Computed conditional probabilities for sample 12000 / 27000
[t-SNE] Computed conditional probabilities for sample 13000 / 27000
[t-SNE] Computed conditional probabilities for sample 14000 / 27000

  cmap = plt.cm.get_cmap("tab10", 10)
  cmap = plt.cm.get_cmap("tab10", 10)


[OK] Saved 4 figures (RAW=legend lower-left, XFR=text on-plot).
[OK] Saved mat: ./TSNE\RAW_2D_tsne.mat
[OK] Saved mat: ./TSNE\XFR_2D_tsne.mat
[OK] Saved mat: ./TSNE\RAW_3D_tsne.mat
[OK] Saved mat: ./TSNE\XFR_3D_tsne.mat
[OK] Saved mat: ./TSNE\LTEV_tsne_all.mat
