# In [3]:
# ==========================================================
# LTE-V Cross-Domain XFR + Ablations + Backbone Plug-and-Play
# Fixed SNR = 20 dB, one run outputs all results
#
# Methods:
#   1) XFR_BASE           : cross-domain block -> transpose (L, m, 2)
#   2) SHUFFLE_IN_BLOCK   : shuffle m (frame axis) within cross-domain block, then transpose
#   3) CONCAT_M_FRAMES    : concatenate m frames along time -> (1, m*L, 2)
#   4) MEAN_OVER_FRAMES   : mean over m frames pointwise -> (1, L, 2)
#
# Backbones:
#   - ResNet18_1D
#   - CNN1D
#   - TCN1D
#
# Outputs per (method, backbone):
#   - row-level + block-level accuracy (val/test)
#   - epochs used, train epoch time range
#   - params, model file size
#   - inference time per row sample + per block decision
# ==========================================================

import os
import glob
import time
import json
import random
import hashlib
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from sklearn.model_selection import KFold, train_test_split
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
from datetime import datetime
from tqdm import tqdm
import h5py

# ================= Configuration =================
# Folder containing the raw IQ .mat files -- point this at your dataset.
data_path = "E:/rf_datasets_IQ_raw/"

# RF / channel settings
fs = 5e6            # sample rate in Hz (used as the time base for the Doppler rotation)
fc = 5.9e9          # carrier frequency in Hz
v_kmh = 120         # speed in km/h used to derive the Doppler shift
apply_doppler = True
apply_awgn = True

# Cross-domain block settings
group_size = 256    # m: number of frames combined into one block
SNR_FIXED = 20      # AWGN SNR in dB

# Training hyper-parameters
base_batch_size = 64
num_epochs = 300
learning_rate = 1e-4
weight_decay = 1e-3
dropout = 0.5
patience = 5
n_splits = 5
test_size = 0.25
seed = 42

# Rows drawn per training block per epoch. Only meaningful when sample_len > 1;
# CONCAT / MEAN blocks have a single row, so it collapses to 1 there.
train_rows_per_block = 256

# Root directory for all run outputs
out_root = os.path.join(os.getcwd(), "training_results")

# (display name, loader method_key) pairs
METHODS = [
    ("XFR_BASE", "XFR_BASE"),
    ("SHUFFLE_IN_BLOCK", "SHUFFLE_IN_BLOCK"),
    ("CONCAT_M_FRAMES", "CONCAT"),
    ("MEAN_OVER_FRAMES", "MEAN"),
]

# (display name, build_model backbone_key) pairs
BACKBONES = [
    ("ResNet18_1D", "resnet"),
    ("CNN1D", "cnn"),
    ("TCN1D", "tcn"),
]


# ================= 通用工具 =================
def seed_everything(seed_: int = 42):
    """Seed Python, NumPy and PyTorch (CPU and every CUDA device), and force deterministic cuDNN.

    Args:
        seed_: seed value passed to every RNG.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed_)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True   # pick deterministic kernels
    cudnn.benchmark = False      # disable autotuning, which may choose different kernels per run

def stable_int_hash(s: str) -> int:
    """Return a run-independent 32-bit integer derived from *s*.

    The value is the first 8 hex digits (i.e. the first 4 bytes, big-endian) of
    the MD5 digest of the UTF-8 encoding. Unlike the built-in ``hash``, it is
    unaffected by per-process hash randomisation.
    """
    digest = hashlib.md5(s.encode("utf-8")).digest()
    return int.from_bytes(digest[:4], "big")

def auto_batch_size_for_T(T: int, base: int = 64) -> int:
    """Shrink the batch size for long sequences to avoid running out of memory.

    CONCAT inputs have T = m * L, which can be very large. The first threshold
    that T reaches selects a divisor for *base*; the result is never below 1.
    """
    # (minimum T, divisor) checked from the largest threshold down
    thresholds = ((50000, 32), (20000, 16), (10000, 8), (5000, 4), (2500, 2))
    for min_len, divisor in thresholds:
        if T >= min_len:
            return max(1, base // divisor)
    return base

def count_params(model: nn.Module) -> int:
    """Return the total number of parameter elements in *model* (trainable or frozen)."""
    total = 0
    for param in model.parameters():
        total += param.numel()
    return total

def sizeof_state_dict_bytes_est(model: nn.Module) -> int:
    """Estimate weight storage in bytes, assuming every parameter element is FP32 (4 bytes).

    Only parameters are counted: buffers (e.g. BatchNorm running statistics) and
    optimizer state are not included.
    """
    bytes_per_fp32 = 4
    return bytes_per_fp32 * count_params(model)

def cuda_sync_if_needed(device):
    """Wait for queued CUDA work to finish when *device* is a CUDA device; otherwise do nothing.

    Needed before reading a wall-clock timer, because CUDA kernels launch asynchronously.
    """
    if device.type != "cuda":
        return
    torch.cuda.synchronize()

def benchmark_inference(model: nn.Module, x: torch.Tensor, device, warmup=20, iters=100):
    """
    Measure the forward-pass latency of *model* on one fixed input batch.

    The model is put in eval mode (it stays there afterwards) and *x* is moved to
    *device* once. Then ``warmup`` untimed passes run, followed by ``iters``
    timed passes. Each timed pass waits for CUDA to finish before the clock is
    read, so GPU timings cover the complete kernel execution.

    x: (B, T, 2) tensor on CPU or GPU.
    Returns: (mean_ms, p50_ms, p95_ms) over the timed passes.
    """
    model.eval()
    inp = x.to(device)
    with torch.no_grad():
        for _ in range(warmup):
            model(inp)
        cuda_sync_if_needed(device)

        elapsed_ms = np.empty(iters, dtype=np.float64)
        for k in range(iters):
            start = time.perf_counter()
            model(inp)
            cuda_sync_if_needed(device)
            elapsed_ms[k] = (time.perf_counter() - start) * 1000.0

    mean_ms = float(elapsed_ms.mean())
    p50_ms = float(np.percentile(elapsed_ms, 50))
    p95_ms = float(np.percentile(elapsed_ms, 95))
    return mean_ms, p50_ms, p95_ms

def plot_confusion_matrix(cm, save_path, title="Confusion Matrix"):
    """
    Render the integer matrix *cm* as an annotated heatmap and save it as an image.

    Rows are labelled "Reference" and columns "Predicted". The figure this
    function creates is always closed, even if drawing or saving raises, so
    repeated calls in a long run do not accumulate open figures.

    Args:
        cm: 2-D array of integer counts (annotated with fmt="d").
        save_path: output file path passed to ``savefig``.
        title: plot title.
    """
    fig = plt.figure(figsize=(8, 6))
    try:
        # These rcParams are set globally (SimHei for CJK glyphs, ASCII minus sign),
        # so they stay in effect for later plots as well.
        plt.rcParams["font.sans-serif"] = ["SimHei"]
        plt.rcParams["axes.unicode_minus"] = False
        sns.heatmap(cm, annot=True, fmt="d", cmap="Blues")
        plt.title(title)
        plt.ylabel("Reference")
        plt.xlabel("Predicted")
        plt.tight_layout()
        fig.savefig(save_path, dpi=200)
    finally:
        # Close exactly this figure, not just whatever figure happens to be current.
        plt.close(fig)

def moving_average(x, w=5):
    """Simple moving average of *x* using only complete windows ('valid' convolution).

    The window is clamped to the range [1, len(x)]. Empty input yields an empty
    array, and the output has ``len(x) - window + 1`` elements.
    """
    arr = np.array(x, dtype=np.float64)
    n = len(arr)
    if n == 0:
        return np.array([])
    win = min(max(int(w), 1), n)
    kernel = np.full(win, 1.0)
    return np.convolve(arr, kernel, mode="valid") / win

def plot_training_curves(fold_results, save_folder, window=5):
    """
    Plot smoothed per-fold train/validation loss curves to ``<save_folder>/loss_curves.png``.

    Each curve is smoothed with ``moving_average`` ('valid' mode). Every smoothed
    value is drawn at the 1-based epoch that ends its window, so the x-axis shows
    real epoch numbers instead of indices shifted by the window length. The
    figure is closed even if saving fails.

    Args:
        fold_results: iterable of dicts with "train_loss" and "val_loss" lists (one value per epoch).
        save_folder: existing directory that receives the PNG.
        window: moving-average window (default 5, matching moving_average's default).
    """
    def _smoothed(values):
        # Returns (epoch_numbers, smoothed_values), aligned to the end of each window.
        smooth = moving_average(values, w=window)
        n = len(values)
        epochs = np.arange(n - len(smooth) + 1, n + 1)
        return epochs, smooth

    fig = plt.figure(figsize=(12, 5))
    try:
        for i, res in enumerate(fold_results):
            ep_tr, tr = _smoothed(res["train_loss"])
            plt.plot(ep_tr, tr, label=f"Fold{i+1} Train Loss")
            ep_va, va = _smoothed(res["val_loss"])
            plt.plot(ep_va, va, label=f"Fold{i+1} Val Loss")
        plt.xlabel("Epoch")
        plt.ylabel("Loss")
        plt.title("Loss Curves (moving avg)")
        plt.legend()
        plt.grid(True)
        plt.tight_layout()
        fig.savefig(os.path.join(save_folder, "loss_curves.png"), dpi=200)
    finally:
        plt.close(fig)


# ================= 信道处理 =================
def compute_doppler_shift(v_kmh_, fc_hz):
    """Return the Doppler shift fd = (v / c) * fc in Hz.

    Args:
        v_kmh_: speed in km/h (converted to m/s by dividing by 3.6).
        fc_hz: carrier frequency in Hz.
    """
    speed_of_light = 3e8
    speed_ms = v_kmh_ / 3.6
    return (speed_ms / speed_of_light) * fc_hz

def apply_doppler_shift(signal, fd, fs_):
    """Apply a constant frequency offset *fd* (Hz) to *signal* sampled at *fs_* (Hz).

    Along the last axis, sample k is multiplied by exp(j * 2*pi * fd * k / fs_),
    with k counted from 0.
    """
    n_samples = signal.shape[-1]
    t = np.arange(n_samples, dtype=np.float64) / fs_
    rotator = np.exp(1j * 2 * np.pi * fd * t)
    return signal * rotator

def add_awgn(signal, snr_db):
    """Add complex white Gaussian noise to *signal* at *snr_db* dB SNR.

    The noise power is mean(|signal|^2) / 10^(snr_db/10), split equally between
    the real and imaginary parts. Noise is drawn from NumPy's global RNG: first
    the real part, then the imaginary part.
    """
    linear_snr = 10 ** (snr_db / 10)
    noise_power = np.mean(np.abs(signal) ** 2) / linear_snr
    shape = signal.shape
    noise_i = np.random.randn(*shape)
    noise_q = np.random.randn(*shape)
    noise = np.sqrt(noise_power / 2) * (noise_i + 1j * noise_q)
    return signal + noise

def power_normalize(sig):
    """Scale *sig* to unit mean power; the 1e-12 term prevents division by zero for an all-zero input."""
    rms = np.sqrt(np.mean(np.abs(sig) ** 2))
    return sig / (rms + 1e-12)


# ================= H5 读取兼容 =================
def read_dmrs_complex(rfDataset):
    """
    Return the complex DMRS samples stored under ``rfDataset["dmrs"]``.

    Two on-disk layouts are accepted:
      * an ``h5py.Dataset`` with a compound dtype that has ``real`` and ``imag`` fields;
      * any other (group-like) object holding separate ``real`` and ``imag`` arrays.

    Raises:
        RuntimeError: ``dmrs`` is a Dataset without the real/imag compound fields.
    """
    dmrs_obj = rfDataset["dmrs"]
    if not isinstance(dmrs_obj, h5py.Dataset):
        # Group layout: read the real part, then the imaginary part.
        return dmrs_obj["real"][:] + 1j * dmrs_obj["imag"][:]

    arr = dmrs_obj[:]
    field_names = getattr(arr.dtype, "names", None)
    if field_names is not None and "real" in field_names and "imag" in field_names:
        return arr["real"] + 1j * arr["imag"]
    raise RuntimeError("dmrs dataset 不是预期的 compound(real/imag) 格式。")

def read_txid_str(rfDataset):
    """Decode ``rfDataset["txID"]`` (an array of character codes, presumably MATLAB uint16 chars) into a str.

    The array is flattened and zero codes (padding) are dropped.
    """
    codes = rfDataset["txID"][:].flatten()
    return "".join(chr(int(code)) for code in codes if int(code) != 0)


# ================= Cross-domain block 构造 =================
# Representation transforms understood by load_and_preprocess_cross_domain_blocks.
_VALID_METHOD_KEYS = ("XFR_BASE", "SHUFFLE_IN_BLOCK", "CONCAT", "MEAN")


def _rows_per_file(group_size_, num_files):
    """Split group_size_ rows over num_files files as evenly as possible (first files take the remainder)."""
    base, extra = divmod(group_size_, num_files)
    return [base + 1 if j < extra else base for j in range(num_files)]


def _transform_block(big_block_frames, method_key, seed_, tx_id, g):
    """Apply the representation transform selected by method_key to one (m, L, 2) frame block."""
    if method_key == "XFR_BASE":
        # (m,L,2) -> (L,m,2)
        return np.transpose(big_block_frames, (1, 0, 2))
    if method_key == "SHUFFLE_IN_BLOCK":
        # Permute the m (frame) axis with an RNG seeded per (tx, block), then transpose as in XFR.
        rng = np.random.default_rng(seed_ + stable_int_hash(tx_id) + 10007 * (g + 1))
        perm = rng.permutation(big_block_frames.shape[0])
        return np.transpose(big_block_frames[perm], (1, 0, 2))  # (L,m,2)
    if method_key == "CONCAT":
        # Concatenate the m frames along time: (m,L,2) -> (1,m*L,2)
        return big_block_frames.reshape(-1, 2)[None, :, :]
    if method_key == "MEAN":
        # Pointwise mean over the m frames: (m,L,2) -> (1,L,2)
        return np.mean(big_block_frames, axis=0, keepdims=True)
    raise ValueError(f"Unknown method_key: {method_key}")


def load_and_preprocess_cross_domain_blocks(
    mat_folder,
    group_size_=256,
    apply_doppler_=False,
    target_velocity_kmh_=120,
    apply_awgn_=False,
    snr_db_=20,
    fs_=5e6,
    fc_=5.9e9,
    method_key="XFR_BASE",   # XFR_BASE | SHUFFLE_IN_BLOCK | CONCAT | MEAN
    seed_=42
):
    """
    Build labelled cross-domain blocks from every ``*.mat`` (HDF5) file in *mat_folder*.

    Reading and preprocessing:
      * Each file supplies ``rfDataset/dmrs`` (complex frames, treated as (N, L);
        TODO confirm the orientation of the MATLAB export) and ``rfDataset/txID``
        (the transmitter label).
      * Every frame is power-normalised. A Doppler rotation and AWGN (from
        NumPy's global RNG) are then applied if the corresponding flags are set.
      * Files that lack these fields, have an empty txID, or fail to read are
        skipped and counted. Read errors are reported.

    Block construction:
      * All files of one TX are treated as different domains.
      * Block g takes the g-th contiguous slice of rows from every file of that
        TX, so the block holds exactly ``group_size_`` frames. The rows are
        split as evenly as possible, with the first files supplying any
        remainder.
      * The transform selected by *method_key* is then applied.
      * All frames are assumed to share the same length L.

    Returns:
      X_blocks: float32 array (num_blocks, sample_len, time_len, 2)
      y_blocks: int64 array (num_blocks,)
      label_to_idx: dict mapping tx_id -> class index (tx_ids in sorted order)
      meta: dict of bookkeeping values

    Raises:
      ValueError: unknown *method_key* (checked before any file is read).
      RuntimeError: no .mat files, no readable TX, or no block could be built.
    """
    if method_key not in _VALID_METHOD_KEYS:
        # Fail fast instead of discovering a typo only after every file was read.
        raise ValueError(f"Unknown method_key: {method_key}")

    mat_files = sorted(glob.glob(os.path.join(mat_folder, "*.mat")))
    if len(mat_files) == 0:
        raise RuntimeError(f"未找到 .mat 文件: {mat_folder}")

    fd = compute_doppler_shift(target_velocity_kmh_, fc_)
    print(f"[INFO] Cross-domain loader | files={len(mat_files)} | fd={fd:.2f} Hz | method={method_key}")

    # Read the processed frames of each file and group them by tx:
    # tx_to_signals[tx] = list of np.ndarray, each of shape (num_frames_file, L, 2)
    tx_to_signals = {}
    skipped = 0
    for file in tqdm(mat_files, desc="读取数据"):
        try:
            with h5py.File(file, "r") as f:
                if "rfDataset" not in f:
                    skipped += 1
                    continue
                rfDataset = f["rfDataset"]
                if ("dmrs" not in rfDataset) or ("txID" not in rfDataset):
                    skipped += 1
                    continue

                tx_id = read_txid_str(rfDataset)
                if tx_id == "":
                    skipped += 1
                    continue

                dmrs_complex = read_dmrs_complex(rfDataset)  # (N, L)
        except Exception as exc:
            # Best-effort loading: skip the file, but report why instead of hiding the error.
            skipped += 1
            tqdm.write(f"[WARN] skip {file}: {type(exc).__name__}: {exc}")
            continue

        processed_signals = []
        for i in range(dmrs_complex.shape[0]):
            sig = dmrs_complex[i, :]
            sig = power_normalize(sig)
            if apply_doppler_:
                sig = apply_doppler_shift(sig, fd, fs_)
            if apply_awgn_:
                sig = add_awgn(sig, snr_db_)
            iq = np.stack((sig.real, sig.imag), axis=-1)  # (L,2)
            processed_signals.append(iq)

        processed_signals = np.asarray(processed_signals, dtype=np.float32)  # (N,L,2)
        tx_to_signals.setdefault(tx_id, []).append(processed_signals)

    if len(tx_to_signals) == 0:
        raise RuntimeError("未能从数据集中读取到任何 tx。")

    label_list = sorted(list(tx_to_signals.keys()))
    label_to_idx = {lab: i for i, lab in enumerate(label_list)}

    X_blocks_list = []
    y_blocks_list = []

    # Build the blocks of each tx
    for tx_id in label_list:
        files_signals = tx_to_signals[tx_id]
        num_files = len(files_signals)
        if num_files == 0:
            continue
        if group_size_ // num_files <= 0:
            print(f"[WARN] TX={tx_id}: num_files={num_files} > group_size={group_size_}, samples_per_file=0, skip")
            continue

        # Rows each file contributes to one block. The counts sum to exactly group_size_,
        # so every TX yields equally sized blocks even when group_size_ % num_files != 0
        # (otherwise np.stack below would fail on mismatched block shapes).
        rows_per_file = _rows_per_file(group_size_, num_files)
        max_groups = min(arr.shape[0] // n for arr, n in zip(files_signals, rows_per_file))
        if max_groups == 0:
            min_samples = min(arr.shape[0] for arr in files_signals)
            print(f"[WARN] TX={tx_id}: min_samples={min_samples}, samples_per_file={max(rows_per_file)}, max_groups=0, skip")
            continue

        for g in range(max_groups):
            pieces = [arr[g * n:(g + 1) * n] for arr, n in zip(files_signals, rows_per_file)]
            big_block_frames = np.concatenate(pieces, axis=0)  # (group_size_, L, 2)
            X_blocks_list.append(_transform_block(big_block_frames, method_key, seed_, tx_id, g))
            y_blocks_list.append(label_to_idx[tx_id])

    if len(X_blocks_list) == 0:
        raise RuntimeError("没有生成任何 block（检查 group_size / 文件数 / 样本数）。")

    X_blocks = np.stack(X_blocks_list, axis=0).astype(np.float32)
    y_blocks = np.array(y_blocks_list, dtype=np.int64)

    meta = {
        "method_key": method_key,
        "num_files_total": len(mat_files),
        "skipped_files": skipped,
        "num_classes": len(label_to_idx),
        "num_blocks": int(X_blocks.shape[0]),
        "sample_len": int(X_blocks.shape[1]),
        "time_len": int(X_blocks.shape[2]),
        "group_size": int(group_size_),
        "samples_per_file_rule": "group_size split evenly over num_files_per_tx (remainder rows to the first files)",
    }

    print(f"[INFO] X_blocks: {X_blocks.shape} (B, sample_len, time_len, 2) | classes={len(label_to_idx)}")
    return X_blocks, y_blocks, label_to_idx, meta


# ================= Row/Block 评估 =================
def evaluate_rowlevel(model, dataloader, device, num_classes):
    """
    Row-level evaluation: every sample in *dataloader* is classified on its own.

    The model is switched to eval mode and run without gradients.

    Returns:
        (accuracy_percent, confusion_matrix). The accuracy is 0.0 for an empty
        loader; the confusion matrix is indexed over classes 0..num_classes-1.
    """
    model.eval()
    n_correct = 0
    n_seen = 0
    y_true, y_pred = [], []
    with torch.no_grad():
        for inputs, targets in dataloader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            batch_pred = model(inputs).argmax(dim=1)
            n_correct += (batch_pred == targets).sum().item()
            n_seen += targets.size(0)
            y_true += targets.cpu().tolist()
            y_pred += batch_pred.cpu().tolist()
    accuracy = 100.0 * n_correct / n_seen if n_seen else 0.0
    cm = confusion_matrix(y_true, y_pred, labels=list(range(num_classes)))
    return accuracy, cm

def evaluate_blocklevel(model, X_blocks, y_blocks, device, num_classes, rows_batch=512):
    """
    Block-level evaluation: one decision per block.

    All sample_len rows of a block are passed through the model in chunks of at
    most *rows_batch*. The raw logits are averaged over the rows, and the argmax
    of that mean is the block prediction.

    Returns:
        (accuracy_percent, confusion_matrix) over blocks.
    """
    model.eval()
    block_preds, block_labels = [], []
    with torch.no_grad():
        for i, rows in enumerate(X_blocks):
            xb = torch.tensor(rows, dtype=torch.float32, device=device)  # (S, T, 2)
            label = int(y_blocks[i])
            chunk_logits = [
                model(xb[start:start + rows_batch])
                for start in range(0, xb.shape[0], rows_batch)
            ]
            mean_logits = torch.cat(chunk_logits, dim=0).mean(dim=0)  # (C,)
            block_preds.append(int(torch.argmax(mean_logits).item()))
            block_labels.append(label)
    hit_rate = np.mean(np.array(block_preds) == np.array(block_labels)) if len(block_labels) else 0.0
    acc = 100.0 * hit_rate
    cm = confusion_matrix(block_labels, block_preds, labels=list(range(num_classes)))
    return acc, cm


# ================= 训练集每 epoch 每 block 抽 K 行 =================
def sample_rows_per_block_epoch(X_train_blocks, y_train_blocks, K, rng):
    """
    Draw K distinct rows (without replacement) from every training block for one epoch.

    X_train_blocks: (B, sample_len, time_len, 2)
    Returns: X_train_ep: (B*K, time_len, 2), y_train_ep: (B*K,)

    If a block has a single row, K is forced to 1. If K is None or K >= sample_len,
    every row is returned in its original order. Otherwise ``rng.choice`` is
    called once per block, in block order.
    """
    n_blocks, n_rows, n_time, n_chan = X_train_blocks.shape
    if n_rows == 1:
        K = 1
    if K is None or K >= n_rows:
        return X_train_blocks.reshape(-1, n_time, n_chan), np.repeat(y_train_blocks, n_rows)

    chosen = np.array(
        [rng.choice(n_rows, size=K, replace=False) for _ in range(n_blocks)],
        dtype=np.int64,
    ).reshape(n_blocks, K)
    picked = X_train_blocks[np.arange(n_blocks)[:, None], chosen]  # (B, K, T, C)
    return picked.reshape(-1, n_time, n_chan), np.repeat(y_train_blocks, K)


# ================= Backbones =================
class BasicBlock1D(nn.Module):
    """ResNet "basic" residual block for (B, C, T) signals: conv3-BN-ReLU-Dropout-conv3-BN, plus the shortcut, then ReLU.

    The shortcut is the identity unless a *downsample* module is given. The
    caller must supply one whenever stride or channel count changes the shape.
    """
    expansion = 1

    def __init__(self, in_planes, planes, stride=1, downsample=None, dropout=0.0):
        super().__init__()
        # Attribute names and creation order are part of the checkpoint contract
        # (state_dict keys) and of the weight-initialisation RNG sequence.
        self.conv1 = nn.Conv1d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm1d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.dropout = nn.Dropout(p=dropout)
        self.conv2 = nn.Conv1d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm1d(planes)
        self.downsample = downsample

    def forward(self, x):
        shortcut = x if self.downsample is None else self.downsample(x)
        h = self.dropout(self.relu(self.bn1(self.conv1(x))))
        h = self.bn2(self.conv2(h))
        h += shortcut
        return self.relu(h)

class ResNet18_1D(nn.Module):
    """1-D ResNet-18 style classifier for (B, T, 2) IQ input.

    The stem is conv7/2, BN, ReLU and maxpool3/2. It is followed by four stages
    of two BasicBlock1D each (64, 128, 256 and 512 channels; stages 2-4 use
    stride 2), global average pooling, and a linear head.
    """
    def __init__(self, num_classes=10, in_planes=64, dropout=0.0):
        super().__init__()
        self.in_planes = in_planes  # running channel count consumed by _make_layer
        self.conv1 = nn.Conv1d(2, in_planes, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm1d(in_planes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(64, 2, stride=1, dropout=dropout)
        self.layer2 = self._make_layer(128, 2, stride=2, dropout=dropout)
        self.layer3 = self._make_layer(256, 2, stride=2, dropout=dropout)
        self.layer4 = self._make_layer(512, 2, stride=2, dropout=dropout)
        self.avgpool = nn.AdaptiveAvgPool1d(1)
        self.fc = nn.Linear(512, num_classes)

    def _make_layer(self, planes, blocks, stride, dropout):
        """Build one stage; add a 1x1 conv + BN projection when the shape changes."""
        needs_projection = stride != 1 or self.in_planes != planes
        downsample = (
            nn.Sequential(
                nn.Conv1d(self.in_planes, planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm1d(planes),
            )
            if needs_projection
            else None
        )
        layers = [BasicBlock1D(self.in_planes, planes, stride, downsample, dropout)]
        self.in_planes = planes
        layers += [BasicBlock1D(planes, planes, dropout=dropout) for _ in range(1, blocks)]
        return nn.Sequential(*layers)

    def forward(self, x):
        h = x.transpose(1, 2)  # (B,T,2)->(B,2,T)
        h = self.maxpool(self.relu(self.bn1(self.conv1(h))))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            h = stage(h)
        return self.fc(torch.flatten(self.avgpool(h), 1))

class CNN1D(nn.Module):
    """
    Lightweight 1-D CNN for (B, T, 2) IQ input.

    A conv7/2 + BN + ReLU + maxpool stem is followed by three strided
    Conv-BN-ReLU stages (width -> 2w -> 4w -> 8w channels), dropout, global
    average pooling and a linear head.
    """
    def __init__(self, num_classes, width=64, dropout=0.5):
        super().__init__()
        layers = [
            nn.Conv1d(2, width, kernel_size=7, stride=2, padding=3, bias=False),
            nn.BatchNorm1d(width),
            nn.ReLU(inplace=True),
            nn.MaxPool1d(3, stride=2, padding=1),
        ]
        # (input width multiplier, output width multiplier, kernel, padding); all stride 2.
        # The layer order matches the net.<index> keys of existing checkpoints.
        for mult_in, mult_out, k, pad in ((1, 2, 5, 2), (2, 4, 3, 1), (4, 8, 3, 1)):
            layers += [
                nn.Conv1d(width * mult_in, width * mult_out, kernel_size=k, stride=2, padding=pad, bias=False),
                nn.BatchNorm1d(width * mult_out),
                nn.ReLU(inplace=True),
            ]
        layers.append(nn.Dropout(p=dropout))
        self.net = nn.Sequential(*layers)
        self.gap = nn.AdaptiveAvgPool1d(1)
        self.fc = nn.Linear(width * 8, num_classes)

    def forward(self, x):
        feats = self.net(x.transpose(1, 2))  # (B,T,2)->(B,2,T)
        return self.fc(torch.flatten(self.gap(feats), 1))

class TCNBlock(nn.Module):
    """Residual block of two dilated kernel-3 convolutions with padding == dilation.

    The padding keeps the sequence length unchanged. Each convolution is
    followed by BN; the first also gets ReLU and dropout. The output is
    ReLU(x + branch).
    """
    def __init__(self, channels, dilation, dropout=0.0):
        super().__init__()
        conv_kwargs = dict(kernel_size=3, dilation=dilation, padding=dilation, bias=False)
        self.conv1 = nn.Conv1d(channels, channels, **conv_kwargs)
        self.bn1 = nn.BatchNorm1d(channels)
        self.relu = nn.ReLU(inplace=True)
        self.drop = nn.Dropout(p=dropout)
        self.conv2 = nn.Conv1d(channels, channels, **conv_kwargs)
        self.bn2 = nn.BatchNorm1d(channels)

    def forward(self, x):
        branch = self.drop(self.relu(self.bn1(self.conv1(x))))
        branch = self.bn2(self.conv2(branch))
        return self.relu(x + branch)

class TCN1D(nn.Module):
    """
    Dilated residual conv network for (B, T, 2) IQ input.

    A conv7/2 + BN + ReLU stem maps the input to *channels*. It is followed by
    *levels* TCNBlocks with dilations 1, 2, 4, ... 2**(levels-1), global average
    pooling and a linear head. The convolutions use symmetric padding, so they
    are not causal.
    """
    def __init__(self, num_classes, channels=128, levels=6, dropout=0.5):
        super().__init__()
        self.stem = nn.Sequential(
            nn.Conv1d(2, channels, kernel_size=7, stride=2, padding=3, bias=False),
            nn.BatchNorm1d(channels),
            nn.ReLU(inplace=True),
        )
        self.tcn = nn.Sequential(
            *[TCNBlock(channels, dilation=2 ** level, dropout=dropout) for level in range(levels)]
        )
        self.gap = nn.AdaptiveAvgPool1d(1)
        self.fc = nn.Linear(channels, num_classes)

    def forward(self, x):
        feats = self.tcn(self.stem(x.transpose(1, 2)))  # (B,T,2)->(B,2,T)
        return self.fc(torch.flatten(self.gap(feats), 1))

def build_model(backbone_key, num_classes, dropout_=0.5):
    """Instantiate the backbone selected by *backbone_key* ("resnet", "cnn" or "tcn").

    Raises:
        ValueError: for any other key.
    """
    factories = (
        ("resnet", lambda: ResNet18_1D(num_classes=num_classes, in_planes=64, dropout=dropout_)),
        ("cnn", lambda: CNN1D(num_classes=num_classes, width=64, dropout=dropout_)),
        ("tcn", lambda: TCN1D(num_classes=num_classes, channels=128, levels=6, dropout=dropout_)),
    )
    for key, factory in factories:
        if backbone_key == key:
            return factory()
    raise ValueError(f"Unknown backbone: {backbone_key}")


# ================= 主训练（单个 method+backbone） =================
def train_one_combo(method_name, method_key, backbone_name, backbone_key, save_folder):
    """
    Train and evaluate one (representation method, backbone) combination.

    Pipeline:
      1. Build cross-domain blocks for ``method_key``, using the module-level
         ``data_path`` and channel settings.
      2. Split the blocks into train/test sets, stratified by block label
         (``test_size``, ``seed``).
      3. Run ``n_splits``-fold KFold over the training blocks. For each fold:
         * train a fresh ``build_model(backbone_key)`` on rows re-sampled from
           every training block each epoch;
         * early-stop on row-level validation accuracy and restore the best
           weights;
         * evaluate row- and block-level accuracy on the fold's validation
           blocks and on the shared test blocks;
         * save the weights and confusion-matrix plots, and run an inference
           latency benchmark.

    Writes into ``save_folder`` (which must already exist): results.txt,
    metrics.csv, best_model_fold{k}.pth, confmat_*.png, loss_curves.png and
    fold_summaries.json.

    Returns:
      (fold_summaries, meta): one summary dict per fold, and the loader's meta dict.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"[INFO] device={device} | method={method_name} | backbone={backbone_name}")

    # 1) Build the cross-domain blocks for this method.
    # NOTE(review): the dataset is rebuilt on every call and add_awgn draws from NumPy's
    # global RNG. Backbones of the same method therefore presumably see different noise
    # realizations, which depend on the RNG state left by earlier combos. Build blocks once
    # per method if identical inputs across backbones are intended.
    X_blocks, y_blocks, label_to_idx, meta = load_and_preprocess_cross_domain_blocks(
        data_path,
        group_size_=group_size,
        apply_doppler_=apply_doppler,
        target_velocity_kmh_=v_kmh,
        apply_awgn_=apply_awgn,
        snr_db_=SNR_FIXED,
        fs_=fs,
        fc_=fc,
        method_key=method_key,
        seed_=seed,
    )
    num_classes = len(label_to_idx)

    # 2) Block-level train/test split, stratified by y_blocks (whole blocks never straddle the split).
    idx_all = np.arange(X_blocks.shape[0])
    train_idx, test_idx = train_test_split(
        idx_all, test_size=test_size, stratify=y_blocks, random_state=seed
    )

    X_train_blocks_all = X_blocks[train_idx]
    y_train_blocks_all = y_blocks[train_idx]
    X_test_blocks = X_blocks[test_idx]
    y_test_blocks = y_blocks[test_idx]

    # Row-level test loader: every row of every test block is one sample labelled with its block's class.
    S_test = X_test_blocks.shape[1]
    T_test = X_test_blocks.shape[2]
    bs_test = auto_batch_size_for_T(int(T_test), base=base_batch_size)

    X_test = X_test_blocks.reshape(-1, T_test, 2)
    y_test = np.repeat(y_test_blocks, S_test)
    test_loader = DataLoader(
        TensorDataset(torch.tensor(X_test, dtype=torch.float32), torch.tensor(y_test, dtype=torch.long)),
        batch_size=bs_test, shuffle=False
    )

    # Log files (overwritten at the start of each call, then appended to per epoch / fold).
    results_txt = os.path.join(save_folder, "results.txt")
    metrics_csv = os.path.join(save_folder, "metrics.csv")
    with open(results_txt, "w", encoding="utf-8") as f:
        f.write(json.dumps({
            # NOTE(review): despite the key name, this records the save folder's basename
            # (e.g. "<method>__<backbone>" when called from main), not a timestamp.
            "timestamp": os.path.basename(save_folder),
            "method": method_name,
            "backbone": backbone_name,
            "SNR_dB": SNR_FIXED,
            "group_size": group_size,
            "meta": meta
        }, ensure_ascii=False, indent=2) + "\n\n")

    with open(metrics_csv, "w", encoding="utf-8") as f:
        f.write("method,backbone,fold,epoch,train_loss,train_acc,val_loss,val_acc_row,train_time_s,lr\n")

    # 3) KFold over the training blocks.
    # NOTE(review): plain (non-stratified) KFold, so a validation fold may under-represent a class.
    kfold = KFold(n_splits=n_splits, shuffle=True, random_state=seed)

    fold_summaries = []
    fold_results = []

    for fold, (tr_sub, va_sub) in enumerate(kfold.split(X_train_blocks_all)):
        print(f"\n====== {method_name} | {backbone_name} | Fold {fold+1}/{n_splits} ======")

        X_tr_blocks = X_train_blocks_all[tr_sub]
        y_tr_blocks = y_train_blocks_all[tr_sub]
        X_va_blocks = X_train_blocks_all[va_sub]
        y_va_blocks = y_train_blocks_all[va_sub]

        # Row-level validation loader (all rows of each validation block).
        S_va = X_va_blocks.shape[1]
        T_va = X_va_blocks.shape[2]
        bs_val = auto_batch_size_for_T(int(T_va), base=base_batch_size)

        X_val = X_va_blocks.reshape(-1, T_va, 2)
        y_val = np.repeat(y_va_blocks, S_va)
        val_loader = DataLoader(
            TensorDataset(torch.tensor(X_val, dtype=torch.float32), torch.tensor(y_val, dtype=torch.long)),
            batch_size=bs_val, shuffle=False
        )

        # Fresh model / optimizer / scheduler for every fold.
        model = build_model(backbone_key, num_classes=num_classes, dropout_=dropout).to(device)
        params = count_params(model)
        params_bytes_est = sizeof_state_dict_bytes_est(model)

        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
        # Halve the learning rate every 10 epochs.
        scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)

        # best_val_acc starts at -1.0, so the first epoch (acc >= 0) always records weights.
        best_val_acc = -1.0
        patience_counter = 0
        best_wts = None

        train_losses, val_losses = [], []
        train_times = []

        # Epoch loop
        for epoch in range(num_epochs):
            # ======== build this epoch's train set by row sampling =========
            # The RNG is seeded per (fold, epoch), so row selection is reproducible.
            rng = np.random.default_rng(seed + 100000*(fold+1) + (epoch+1))
            X_tr_ep, y_tr_ep = sample_rows_per_block_epoch(X_tr_blocks, y_tr_blocks, train_rows_per_block, rng)

            T_tr = int(X_tr_ep.shape[1])
            bs_tr = auto_batch_size_for_T(T_tr, base=base_batch_size)
            train_loader = DataLoader(
                TensorDataset(torch.tensor(X_tr_ep, dtype=torch.float32), torch.tensor(y_tr_ep, dtype=torch.long)),
                batch_size=bs_tr, shuffle=True
            )

            # ======== train epoch (only the training pass is timed) =========
            # t0 is taken after sampling / DataLoader construction; t1 after a CUDA sync.
            model.train()
            t0 = time.perf_counter()

            running_loss = 0.0
            correct, total = 0, 0
            for xb, yb in train_loader:
                xb = xb.to(device)
                yb = yb.to(device)
                optimizer.zero_grad()
                logits = model(xb)
                loss = criterion(logits, yb)
                loss.backward()
                optimizer.step()

                running_loss += loss.item()
                pred = torch.argmax(logits, dim=1)
                total += yb.size(0)
                correct += (pred == yb).sum().item()

            cuda_sync_if_needed(device)
            t1 = time.perf_counter()
            train_time = float(t1 - t0)
            train_times.append(train_time)

            # Mean of per-batch losses; accuracy is computed with the model in train mode (dropout active).
            train_loss = running_loss / max(1, len(train_loader))
            train_acc = 100.0 * correct / max(1, total)
            train_losses.append(train_loss)

            # ======== validation (row-level) =========
            model.eval()
            vloss_sum = 0.0
            vcorrect, vtotal = 0, 0
            with torch.no_grad():
                for xb, yb in val_loader:
                    xb = xb.to(device)
                    yb = yb.to(device)
                    logits = model(xb)
                    vloss = criterion(logits, yb)
                    vloss_sum += vloss.item()
                    pred = torch.argmax(logits, dim=1)
                    vtotal += yb.size(0)
                    vcorrect += (pred == yb).sum().item()
            val_loss = vloss_sum / max(1, len(val_loader))
            val_acc_row = 100.0 * vcorrect / max(1, vtotal)
            val_losses.append(val_loss)

            lr_now = optimizer.param_groups[0]["lr"]
            log = (f"Fold {fold+1} Ep {epoch+1} | "
                   f"Train loss {train_loss:.4f} acc {train_acc:.2f}% | "
                   f"Val loss {val_loss:.4f} acc_row {val_acc_row:.2f}% | "
                   f"train_time {train_time:.3f}s | lr {lr_now:.3g}")
            print(log)
            with open(results_txt, "a", encoding="utf-8") as f:
                f.write(log + "\n")
            with open(metrics_csv, "a", encoding="utf-8") as f:
                f.write(f"{method_name},{backbone_name},{fold+1},{epoch+1},"
                        f"{train_loss:.6f},{train_acc:.4f},{val_loss:.6f},{val_acc_row:.4f},"
                        f"{train_time:.6f},{lr_now:.10f}\n")

            # Early stopping on val_acc_row: an improvement must exceed 0.01 percentage points.
            # Best weights are snapshotted to CPU.
            if val_acc_row > best_val_acc + 0.01:
                best_val_acc = val_acc_row
                patience_counter = 0
                best_wts = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
            else:
                patience_counter += 1
                if patience_counter >= patience:
                    msg = f"Early stop at epoch {epoch+1} (best_val_acc_row={best_val_acc:.2f}%)"
                    print(msg)
                    with open(results_txt, "a", encoding="utf-8") as f:
                        f.write(msg + "\n")
                    break

            scheduler.step()

        epochs_used = len(train_losses)

        # Restore the best weights. The fallback (keep the final weights) only triggers if no
        # epoch ran, e.g. num_epochs == 0.
        if best_wts is None:
            best_wts = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
        model.load_state_dict(best_wts, strict=True)
        model.to(device)

        # ======== final evaluation: row + block level =========
        val_acc_row_final, val_cm_row = evaluate_rowlevel(model, val_loader, device, num_classes)
        test_acc_row, test_cm_row = evaluate_rowlevel(model, test_loader, device, num_classes)

        val_acc_block, val_cm_block = evaluate_blocklevel(model, X_va_blocks, y_va_blocks, device, num_classes)
        test_acc_block, test_cm_block = evaluate_blocklevel(model, X_test_blocks, y_test_blocks, device, num_classes)

        # Save the best weights (state_dict only) and record the on-disk size.
        model_path = os.path.join(save_folder, f"best_model_fold{fold+1}.pth")
        torch.save(best_wts, model_path)
        model_file_bytes = int(os.path.getsize(model_path))

        # Confusion matrices
        plot_confusion_matrix(val_cm_row, os.path.join(save_folder, f"confmat_row_val_fold{fold+1}.png"), "Val Row-level")
        plot_confusion_matrix(test_cm_row, os.path.join(save_folder, f"confmat_row_test_fold{fold+1}.png"), "Test Row-level")
        plot_confusion_matrix(val_cm_block, os.path.join(save_folder, f"confmat_block_val_fold{fold+1}.png"), "Val Block-level")
        plot_confusion_matrix(test_cm_block, os.path.join(save_folder, f"confmat_block_test_fold{fold+1}.png"), "Test Block-level")

        # Per-epoch training time range
        t_min = float(np.min(train_times)) if len(train_times) else 0.0
        t_med = float(np.median(train_times)) if len(train_times) else 0.0
        t_max = float(np.max(train_times)) if len(train_times) else 0.0

        # ======== inference benchmark =========
        # Pick one row sample and one block sample from the test set.
        # Row sample: first row of the first test block.
        xb0 = torch.tensor(X_test_blocks[0, 0:1], dtype=torch.float32)          # (1,T,2)
        # Block sample: all rows of the first test block (one batched forward; logit averaging not timed).
        xb_blk = torch.tensor(X_test_blocks[0], dtype=torch.float32)            # (S,T,2)

        row_mean, row_p50, row_p95 = benchmark_inference(model, xb0, device, warmup=20, iters=100)
        blk_mean, blk_p50, blk_p95 = benchmark_inference(model, xb_blk, device, warmup=20, iters=100)

        fold_summary = {
            "method": method_name,
            "backbone": backbone_name,
            "fold": fold + 1,
            "epochs_used": int(epochs_used),
            "train_epoch_time_s_min": t_min,
            "train_epoch_time_s_med": t_med,
            "train_epoch_time_s_max": t_max,
            "val_acc_row": float(val_acc_row_final),
            "test_acc_row": float(test_acc_row),
            "val_acc_block": float(val_acc_block),
            "test_acc_block": float(test_acc_block),
            "params": int(params),
            "param_bytes_est": int(params_bytes_est),
            "model_file_bytes": int(model_file_bytes),
            "infer_row_ms_mean": row_mean,
            "infer_row_ms_p50": row_p50,
            "infer_row_ms_p95": row_p95,
            "infer_block_ms_mean": blk_mean,
            "infer_block_ms_p50": blk_p50,
            "infer_block_ms_p95": blk_p95,
            "time_len_T": int(T_test),
            "sample_len_S_test": int(S_test),
        }
        fold_summaries.append(fold_summary)

        fold_results.append({
            "train_loss": train_losses,
            "val_loss": val_losses,
        })

        msg2 = (f"[FOLD{fold+1}] test_row={test_acc_row:.2f}% test_block={test_acc_block:.2f}% | "
                f"epochs={epochs_used} | train_time(s) min/med/max={t_min:.3f}/{t_med:.3f}/{t_max:.3f} | "
                f"params={params/1e6:.3f}M | model={model_file_bytes/1024/1024:.2f}MB | "
                f"infer_row_mean={row_mean:.3f}ms infer_block_mean={blk_mean:.3f}ms")
        print(msg2)
        with open(results_txt, "a", encoding="utf-8") as f:
            f.write(msg2 + "\n")

    # Loss curves for all folds
    plot_training_curves(fold_results, save_folder)

    # Dump the per-fold summaries
    with open(os.path.join(save_folder, "fold_summaries.json"), "w", encoding="utf-8") as f:
        json.dump(fold_summaries, f, ensure_ascii=False, indent=2)

    return fold_summaries, meta


# ================= Main entry: run every method x backbone combination =================

def _fmt6(value):
    """Format a numeric CSV cell with 6 decimal places."""
    return f"{value:.6f}"


# Single source of truth for global_summary.csv: (record key == column name, cell formatter).
# The header and every data row are generated from this tuple, so column order can never
# drift between them.
_GLOBAL_CSV_COLUMNS = (
    ("method", str),
    ("backbone", str),
    ("fold", str),
    ("epochs_used", str),
    ("train_epoch_time_s_min", _fmt6),
    ("train_epoch_time_s_med", _fmt6),
    ("train_epoch_time_s_max", _fmt6),
    ("val_acc_row", _fmt6),
    ("test_acc_row", _fmt6),
    ("val_acc_block", _fmt6),
    ("test_acc_block", _fmt6),
    ("params", str),
    ("param_bytes_est", str),
    ("model_file_bytes", str),
    ("infer_row_ms_mean", _fmt6),
    ("infer_row_ms_p50", _fmt6),
    ("infer_row_ms_p95", _fmt6),
    ("infer_block_ms_mean", _fmt6),
    ("infer_block_ms_p50", _fmt6),
    ("infer_block_ms_p95", _fmt6),
    ("time_len_T", str),
    ("sample_len_S_test", str),
)


def _global_csv_header():
    """Return the header line (with trailing newline) for global_summary.csv."""
    return ",".join(name for name, _ in _GLOBAL_CSV_COLUMNS) + "\n"


def _global_csv_row(record):
    """Return one CSV line (with trailing newline) for a per-fold summary dict."""
    return ",".join(fmt(record[name]) for name, fmt in _GLOBAL_CSV_COLUMNS) + "\n"


def _mean_std(values):
    """Return (mean, std) of *values* as Python floats.

    Uses NumPy's default population std (ddof=0), matching the previous output.
    """
    arr = np.asarray(values, dtype=np.float64)
    return float(arr.mean()), float(arr.std())


def _aggregate_records(records):
    """Group per-fold summaries by (method, backbone) and compute mean/std per group.

    Groups appear in first-seen order of *records*. Each result dict keeps the
    original key order and adds ``n_folds`` (number of folds in the group).
    """
    groups = {}
    for rec in records:
        groups.setdefault((rec["method"], rec["backbone"]), []).append(rec)

    agg = []
    for (method, backbone), lst in groups.items():
        epochs = _mean_std([x["epochs_used"] for x in lst])
        row_acc = _mean_std([x["test_acc_row"] for x in lst])
        blk_acc = _mean_std([x["test_acc_block"] for x in lst])
        t_med = _mean_std([x["train_epoch_time_s_med"] for x in lst])
        infer_row = _mean_std([x["infer_row_ms_mean"] for x in lst])
        infer_blk = _mean_std([x["infer_block_ms_mean"] for x in lst])
        model_mb = _mean_std([x["model_file_bytes"] / 1024 / 1024 for x in lst])

        agg.append({
            "method": method, "backbone": backbone,
            "epochs_mean": epochs[0], "epochs_std": epochs[1],
            "test_row_mean": row_acc[0], "test_row_std": row_acc[1],
            "test_block_mean": blk_acc[0], "test_block_std": blk_acc[1],
            "train_epoch_time_med_s_mean": t_med[0], "train_epoch_time_med_s_std": t_med[1],
            "infer_row_ms_mean": infer_row[0], "infer_row_ms_std": infer_row[1],
            "infer_block_ms_mean": infer_blk[0], "infer_block_ms_std": infer_blk[1],
            # Parameter count is identical across folds of one combo; take the first.
            "params": int(lst[0]["params"]),
            "model_file_MB_mean": model_mb[0], "model_file_MB_std": model_mb[1],
            "n_folds": len(lst),
        })
    return agg


def _write_aggregate_json(path, records):
    """(Re)write the aggregate summary JSON for all records collected so far."""
    with open(path, "w", encoding="utf-8") as f:
        json.dump(_aggregate_records(records), f, ensure_ascii=False, indent=2)


def main():
    """Train and evaluate every (method, backbone) pair in METHODS x BACKBONES.

    Creates a timestamped run folder under ``out_root`` containing:
      * one sub-folder per combination (filled by ``train_one_combo``),
      * ``global_summary.csv`` with one row per fold, appended after each combo,
      * ``aggregate_summary.json`` with per-combination mean/std, refreshed after
        each combo so results of finished combos survive a crash in a later one.
    """
    seed_everything(seed)
    os.makedirs(out_root, exist_ok=True)

    fd_int = int(compute_doppler_shift(v_kmh, fc))
    ts = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

    root_folder = os.path.join(
        out_root,
        f"{ts}_LTEV_CrossDomainXFR_Ablations_SNR{SNR_FIXED}dB_fd{fd_int}_m{group_size}"
    )
    os.makedirs(root_folder, exist_ok=True)

    global_csv = os.path.join(root_folder, "global_summary.csv")
    aggregate_json = os.path.join(root_folder, "aggregate_summary.json")

    with open(global_csv, "w", encoding="utf-8") as f:
        f.write(_global_csv_header())

    all_records = []

    for method_name, method_key in METHODS:
        for backbone_name, backbone_key in BACKBONES:
            combo_folder = os.path.join(root_folder, f"{method_name}__{backbone_name}")
            os.makedirs(combo_folder, exist_ok=True)

            fold_summaries, _meta = train_one_combo(
                method_name, method_key, backbone_name, backbone_key, combo_folder
            )

            # Append this combo's per-fold rows to the global CSV.
            with open(global_csv, "a", encoding="utf-8") as f:
                f.writelines(_global_csv_row(r) for r in fold_summaries)

            all_records.extend(fold_summaries)

            # Refresh the aggregate now so completed combos are not lost if a
            # later (possibly hours-long) combo raises.
            _write_aggregate_json(aggregate_json, all_records)

    # Final write also guarantees the file exists for an empty METHODS/BACKBONES grid.
    _write_aggregate_json(aggregate_json, all_records)

    print(f"[DONE] root_folder: {root_folder}")
    print(f"[DONE] global_summary.csv: {global_csv}")
    print("[DONE] aggregate_summary.json saved.")


if __name__ == "__main__":
    main()


[INFO] device=cuda | method=XFR_BASE | backbone=ResNet18_1D
[INFO] Cross-domain loader | files=72 | fd=655.56 Hz | method=XFR_BASE


读取数据: 100%|██████████| 72/72 [00:10<00:00,  6.89it/s]


[INFO] X_blocks: (789, 256, 256, 2) (B, sample_len, time_len, 2) | classes=9

Fold 1 Ep 1 | Train loss 0.6236 acc 78.25% | Val loss 0.0317 acc_row 99.15% | train_time 11.729s | lr 0.0001
Fold 1 Ep 2 | Train loss 0.0366 acc 99.08% | Val loss 0.0201 acc_row 99.46% | train_time 11.915s | lr 0.0001
Fold 1 Ep 3 | Train loss 0.0197 acc 99.51% | Val loss 0.0164 acc_row 99.47% | train_time 11.724s | lr 0.0001
Fold 1 Ep 4 | Train loss 0.0141 acc 99.69% | Val loss 0.0145 acc_row 99.58% | train_time 11.518s | lr 0.0001
Fold 1 Ep 5 | Train loss 0.0116 acc 99.78% | Val loss 0.0164 acc_row 99.46% | train_time 11.472s | lr 0.0001
Fold 1 Ep 6 | Train loss 0.0108 acc 99.83% | Val loss 0.0107 acc_row 99.67% | train_time 11.451s | lr 0.0001
Fold 1 Ep 7 | Train loss 0.0100 acc 99.86% | Val loss 0.0082 acc_row 99.73% | train_time 11.663s | lr 0.0001
Fold 1 Ep 8 | Train loss 0.0094 acc 99.87% | Val loss 0.0089 acc_row 99.70% | train_time 11.723s | lr 0.0001
Fold 1 Ep 9 | Train loss 0.0083 acc 99.89% | Val l

读取数据: 100%|██████████| 72/72 [00:10<00:00,  7.12it/s]


[INFO] X_blocks: (789, 256, 256, 2) (B, sample_len, time_len, 2) | classes=9

Fold 1 Ep 1 | Train loss 0.4273 acc 91.50% | Val loss 0.1341 acc_row 97.52% | train_time 4.285s | lr 0.0001
Fold 1 Ep 2 | Train loss 0.0336 acc 99.80% | Val loss 0.0852 acc_row 98.24% | train_time 4.416s | lr 0.0001
Fold 1 Ep 3 | Train loss 0.0150 acc 99.97% | Val loss 0.0626 acc_row 98.67% | train_time 4.244s | lr 0.0001
Fold 1 Ep 4 | Train loss 0.0108 acc 99.99% | Val loss 0.0630 acc_row 98.75% | train_time 4.156s | lr 0.0001
Fold 1 Ep 5 | Train loss 0.0104 acc 99.99% | Val loss 0.0569 acc_row 99.05% | train_time 4.462s | lr 0.0001
Fold 1 Ep 6 | Train loss 0.0097 acc 100.00% | Val loss 0.0477 acc_row 99.45% | train_time 4.267s | lr 0.0001
Fold 1 Ep 7 | Train loss 0.0094 acc 99.99% | Val loss 0.0478 acc_row 99.59% | train_time 4.205s | lr 0.0001
Fold 1 Ep 8 | Train loss 0.0087 acc 100.00% | Val loss 0.0430 acc_row 99.56% | train_time 4.181s | lr 0.0001
Fold 1 Ep 9 | Train loss 0.0085 acc 100.00% | Val loss 0

读取数据: 100%|██████████| 72/72 [00:10<00:00,  7.11it/s]


[INFO] X_blocks: (789, 256, 256, 2) (B, sample_len, time_len, 2) | classes=9

Fold 1 Ep 1 | Train loss 0.2899 acc 91.57% | Val loss 0.0290 acc_row 99.14% | train_time 63.458s | lr 0.0001
Fold 1 Ep 2 | Train loss 0.0181 acc 99.77% | Val loss 0.0148 acc_row 99.55% | train_time 61.653s | lr 0.0001
Fold 1 Ep 3 | Train loss 0.0096 acc 99.90% | Val loss 0.0064 acc_row 99.86% | train_time 63.411s | lr 0.0001
Fold 1 Ep 4 | Train loss 0.0081 acc 99.91% | Val loss 0.0101 acc_row 99.65% | train_time 61.611s | lr 0.0001
Fold 1 Ep 5 | Train loss 0.0073 acc 99.93% | Val loss 0.0092 acc_row 99.69% | train_time 64.342s | lr 0.0001
Fold 1 Ep 6 | Train loss 0.0072 acc 99.94% | Val loss 0.0111 acc_row 99.69% | train_time 61.872s | lr 0.0001
Fold 1 Ep 7 | Train loss 0.0063 acc 99.96% | Val loss 0.0044 acc_row 99.90% | train_time 63.684s | lr 0.0001
Fold 1 Ep 8 | Train loss 0.0060 acc 99.94% | Val loss 0.0025 acc_row 99.97% | train_time 61.923s | lr 0.0001
Fold 1 Ep 9 | Train loss 0.0056 acc 99.96% | Val l

读取数据: 100%|██████████| 72/72 [00:10<00:00,  6.87it/s]


[INFO] X_blocks: (789, 256, 256, 2) (B, sample_len, time_len, 2) | classes=9

Fold 1 Ep 1 | Train loss 2.1604 acc 16.20% | Val loss 2.4719 acc_row 13.85% | train_time 11.553s | lr 0.0001
Fold 1 Ep 2 | Train loss 2.0357 acc 23.56% | Val loss 2.2685 acc_row 19.69% | train_time 11.424s | lr 0.0001
Fold 1 Ep 3 | Train loss 1.8837 acc 31.12% | Val loss 2.3709 acc_row 19.12% | train_time 12.671s | lr 0.0001
Fold 1 Ep 4 | Train loss 1.6596 acc 40.90% | Val loss 2.4984 acc_row 18.21% | train_time 11.973s | lr 0.0001
Fold 1 Ep 5 | Train loss 1.4006 acc 50.84% | Val loss 2.7232 acc_row 17.24% | train_time 11.516s | lr 0.0001
Fold 1 Ep 6 | Train loss 1.1500 acc 60.26% | Val loss 2.9178 acc_row 17.15% | train_time 11.358s | lr 0.0001
Fold 1 Ep 7 | Train loss 0.9339 acc 68.23% | Val loss 3.0693 acc_row 16.59% | train_time 11.435s | lr 0.0001
Early stop at epoch 7 (best_val_acc_row=19.69%)
[FOLD1] test_row=18.63% test_block=29.29% | epochs=7 | train_time(s) min/med/max=11.358/11.516/12.671 | params=

读取数据: 100%|██████████| 72/72 [00:10<00:00,  7.08it/s]


[INFO] X_blocks: (789, 256, 256, 2) (B, sample_len, time_len, 2) | classes=9

Fold 1 Ep 1 | Train loss 1.8943 acc 31.70% | Val loss 2.3167 acc_row 16.73% | train_time 4.176s | lr 0.0001
Fold 1 Ep 2 | Train loss 1.4149 acc 52.07% | Val loss 2.6339 acc_row 16.95% | train_time 4.161s | lr 0.0001
Fold 1 Ep 3 | Train loss 1.0806 acc 64.01% | Val loss 2.8199 acc_row 15.89% | train_time 4.540s | lr 0.0001
Fold 1 Ep 4 | Train loss 0.8628 acc 71.50% | Val loss 3.0259 acc_row 15.89% | train_time 4.255s | lr 0.0001
Fold 1 Ep 5 | Train loss 0.7086 acc 76.86% | Val loss 3.2478 acc_row 15.33% | train_time 4.110s | lr 0.0001
Fold 1 Ep 6 | Train loss 0.5977 acc 80.67% | Val loss 3.3620 acc_row 15.92% | train_time 4.049s | lr 0.0001
Fold 1 Ep 7 | Train loss 0.5104 acc 83.52% | Val loss 3.5353 acc_row 15.25% | train_time 4.052s | lr 0.0001
Early stop at epoch 7 (best_val_acc_row=16.95%)
[FOLD1] test_row=16.93% test_block=27.27% | epochs=7 | train_time(s) min/med/max=4.049/4.161/4.540 | params=0.540M | m

读取数据: 100%|██████████| 72/72 [00:10<00:00,  7.03it/s]


[INFO] X_blocks: (789, 256, 256, 2) (B, sample_len, time_len, 2) | classes=9

Fold 1 Ep 1 | Train loss 2.0528 acc 22.51% | Val loss 2.3158 acc_row 19.79% | train_time 64.468s | lr 0.0001
Fold 1 Ep 2 | Train loss 1.8824 acc 30.74% | Val loss 2.2075 acc_row 24.88% | train_time 61.409s | lr 0.0001
Fold 1 Ep 3 | Train loss 1.7635 acc 35.98% | Val loss 2.5280 acc_row 23.98% | train_time 64.305s | lr 0.0001
Fold 1 Ep 4 | Train loss 1.6454 acc 40.76% | Val loss 2.7801 acc_row 24.01% | train_time 61.343s | lr 0.0001
Fold 1 Ep 5 | Train loss 1.5381 acc 45.22% | Val loss 2.5826 acc_row 24.18% | train_time 63.958s | lr 0.0001
Fold 1 Ep 6 | Train loss 1.4441 acc 48.81% | Val loss 2.8531 acc_row 23.89% | train_time 62.773s | lr 0.0001
Fold 1 Ep 7 | Train loss 1.3614 acc 51.90% | Val loss 2.8838 acc_row 22.76% | train_time 62.222s | lr 0.0001
Early stop at epoch 7 (best_val_acc_row=24.88%)
[FOLD1] test_row=22.15% test_block=35.86% | epochs=7 | train_time(s) min/med/max=61.343/62.773/64.468 | params=

读取数据: 100%|██████████| 72/72 [00:10<00:00,  7.08it/s]


[INFO] X_blocks: (789, 1, 65536, 2) (B, sample_len, time_len, 2) | classes=9

Fold 1 Ep 1 | Train loss 2.0581 acc 22.03% | Val loss 4.1562 acc_row 10.08% | train_time 2.916s | lr 0.0001
Fold 1 Ep 2 | Train loss 1.7054 acc 40.25% | Val loss 1.6691 acc_row 36.97% | train_time 2.803s | lr 0.0001
Fold 1 Ep 3 | Train loss 1.4437 acc 57.84% | Val loss 2.1195 acc_row 26.05% | train_time 2.883s | lr 0.0001
Fold 1 Ep 4 | Train loss 1.2325 acc 67.80% | Val loss 1.0822 acc_row 60.50% | train_time 2.975s | lr 0.0001
Fold 1 Ep 5 | Train loss 1.1558 acc 77.12% | Val loss 1.5402 acc_row 36.97% | train_time 2.877s | lr 0.0001
Fold 1 Ep 6 | Train loss 0.9858 acc 81.36% | Val loss 1.0157 acc_row 53.78% | train_time 2.806s | lr 0.0001
Fold 1 Ep 7 | Train loss 0.8758 acc 84.96% | Val loss 0.8460 acc_row 63.87% | train_time 2.792s | lr 0.0001
Fold 1 Ep 8 | Train loss 0.7448 acc 87.50% | Val loss 0.5775 acc_row 80.67% | train_time 2.788s | lr 0.0001
Fold 1 Ep 9 | Train loss 0.6949 acc 89.83% | Val loss 0.42

读取数据: 100%|██████████| 72/72 [00:10<00:00,  7.10it/s]


[INFO] X_blocks: (789, 1, 65536, 2) (B, sample_len, time_len, 2) | classes=9

Fold 1 Ep 1 | Train loss 2.1411 acc 20.55% | Val loss 1.9072 acc_row 48.74% | train_time 0.728s | lr 0.0001
Fold 1 Ep 2 | Train loss 1.9730 acc 44.07% | Val loss 1.7062 acc_row 63.87% | train_time 0.564s | lr 0.0001
Fold 1 Ep 3 | Train loss 1.8541 acc 51.69% | Val loss 1.5746 acc_row 82.35% | train_time 0.566s | lr 0.0001
Fold 1 Ep 4 | Train loss 1.7788 acc 61.23% | Val loss 1.4828 acc_row 78.99% | train_time 0.569s | lr 0.0001
Fold 1 Ep 5 | Train loss 1.7266 acc 65.04% | Val loss 1.4155 acc_row 75.63% | train_time 0.564s | lr 0.0001
Fold 1 Ep 6 | Train loss 1.6311 acc 74.79% | Val loss 1.3487 acc_row 87.39% | train_time 0.589s | lr 0.0001
Fold 1 Ep 7 | Train loss 1.5645 acc 77.33% | Val loss 1.2723 acc_row 85.71% | train_time 0.580s | lr 0.0001
Fold 1 Ep 8 | Train loss 1.5075 acc 78.18% | Val loss 1.1920 acc_row 95.80% | train_time 0.580s | lr 0.0001
Fold 1 Ep 9 | Train loss 1.4421 acc 83.05% | Val loss 1.14

读取数据: 100%|██████████| 72/72 [00:10<00:00,  7.10it/s]


[INFO] X_blocks: (789, 1, 65536, 2) (B, sample_len, time_len, 2) | classes=9

Fold 1 Ep 1 | Train loss 2.1989 acc 16.31% | Val loss 1.8739 acc_row 23.53% | train_time 7.796s | lr 0.0001
Fold 1 Ep 2 | Train loss 1.8990 acc 35.59% | Val loss 1.5245 acc_row 45.38% | train_time 7.734s | lr 0.0001
Fold 1 Ep 3 | Train loss 1.7054 acc 45.13% | Val loss 1.3002 acc_row 58.82% | train_time 7.706s | lr 0.0001
Fold 1 Ep 4 | Train loss 1.5732 acc 53.81% | Val loss 1.0139 acc_row 79.83% | train_time 7.841s | lr 0.0001
Fold 1 Ep 5 | Train loss 1.4383 acc 62.50% | Val loss 0.8766 acc_row 90.76% | train_time 7.676s | lr 0.0001
Fold 1 Ep 6 | Train loss 1.3551 acc 70.34% | Val loss 0.8084 acc_row 84.03% | train_time 7.703s | lr 0.0001
Fold 1 Ep 7 | Train loss 1.3684 acc 64.41% | Val loss 0.6895 acc_row 96.64% | train_time 7.712s | lr 0.0001
Fold 1 Ep 8 | Train loss 1.2216 acc 74.36% | Val loss 0.6003 acc_row 96.64% | train_time 7.707s | lr 0.0001
Fold 1 Ep 9 | Train loss 1.1482 acc 80.08% | Val loss 0.52

读取数据: 100%|██████████| 72/72 [00:10<00:00,  6.98it/s]


[INFO] X_blocks: (789, 1, 256, 2) (B, sample_len, time_len, 2) | classes=9

Fold 1 Ep 1 | Train loss 2.2312 acc 12.71% | Val loss 2.2434 acc_row 10.08% | train_time 0.077s | lr 0.0001
Fold 1 Ep 2 | Train loss 1.9988 acc 18.43% | Val loss 2.6007 acc_row 10.08% | train_time 0.062s | lr 0.0001
Fold 1 Ep 3 | Train loss 1.9318 acc 23.73% | Val loss 3.1713 acc_row 10.08% | train_time 0.060s | lr 0.0001
Fold 1 Ep 4 | Train loss 1.8851 acc 25.64% | Val loss 3.6523 acc_row 10.08% | train_time 0.058s | lr 0.0001
Fold 1 Ep 5 | Train loss 1.8581 acc 27.75% | Val loss 4.0037 acc_row 10.08% | train_time 0.059s | lr 0.0001
Fold 1 Ep 6 | Train loss 1.8023 acc 31.78% | Val loss 4.1668 acc_row 10.08% | train_time 0.048s | lr 0.0001
Early stop at epoch 6 (best_val_acc_row=10.08%)
[FOLD1] test_row=11.11% test_block=11.11% | epochs=6 | train_time(s) min/med/max=0.048/0.059/0.077 | params=3.849M | model=14.76MB | infer_row_mean=1.509ms infer_block_mean=1.971ms

Fold 2 Ep 1 | Train loss 2.1645 acc 13.95% | V

读取数据: 100%|██████████| 72/72 [00:10<00:00,  7.05it/s]


[INFO] X_blocks: (789, 1, 256, 2) (B, sample_len, time_len, 2) | classes=9

Fold 1 Ep 1 | Train loss 2.1659 acc 16.95% | Val loss 2.2046 acc_row 10.08% | train_time 0.036s | lr 0.0001
Fold 1 Ep 2 | Train loss 1.9695 acc 31.14% | Val loss 2.2568 acc_row 10.08% | train_time 0.034s | lr 0.0001
Fold 1 Ep 3 | Train loss 1.8334 acc 36.23% | Val loss 2.4129 acc_row 10.08% | train_time 0.034s | lr 0.0001
Fold 1 Ep 4 | Train loss 1.7365 acc 39.83% | Val loss 2.6899 acc_row 10.08% | train_time 0.030s | lr 0.0001
Fold 1 Ep 5 | Train loss 1.6339 acc 47.88% | Val loss 2.9370 acc_row 10.08% | train_time 0.030s | lr 0.0001
Fold 1 Ep 6 | Train loss 1.5827 acc 49.58% | Val loss 3.0157 acc_row 10.08% | train_time 0.023s | lr 0.0001
Early stop at epoch 6 (best_val_acc_row=10.08%)
[FOLD1] test_row=11.11% test_block=11.11% | epochs=6 | train_time(s) min/med/max=0.023/0.032/0.036 | params=0.540M | model=2.08MB | infer_row_mean=0.372ms infer_block_mean=0.368ms

Fold 2 Ep 1 | Train loss 2.1504 acc 17.76% | Va

读取数据: 100%|██████████| 72/72 [00:10<00:00,  6.99it/s]


[INFO] X_blocks: (789, 1, 256, 2) (B, sample_len, time_len, 2) | classes=9

Fold 1 Ep 1 | Train loss 2.3456 acc 11.02% | Val loss 2.2257 acc_row 10.08% | train_time 0.457s | lr 0.0001
Fold 1 Ep 2 | Train loss 2.0360 acc 19.92% | Val loss 2.4438 acc_row 10.08% | train_time 0.264s | lr 0.0001
Fold 1 Ep 3 | Train loss 1.9136 acc 28.60% | Val loss 2.9628 acc_row 10.08% | train_time 0.250s | lr 0.0001
Fold 1 Ep 4 | Train loss 1.8325 acc 27.54% | Val loss 3.3904 acc_row 10.08% | train_time 0.250s | lr 0.0001
Fold 1 Ep 5 | Train loss 1.8072 acc 31.99% | Val loss 3.5482 acc_row 10.08% | train_time 0.263s | lr 0.0001
Fold 1 Ep 6 | Train loss 1.6926 acc 38.56% | Val loss 3.5106 acc_row 10.08% | train_time 0.252s | lr 0.0001
Early stop at epoch 6 (best_val_acc_row=10.08%)
[FOLD1] test_row=11.11% test_block=11.11% | epochs=6 | train_time(s) min/med/max=0.250/0.257/0.457 | params=0.596M | model=2.31MB | infer_row_mean=1.072ms infer_block_mean=1.066ms

Fold 2 Ep 1 | Train loss 2.2755 acc 16.07% | Va