In [None]:
# ResNet 1D 固定 SNR=20dB、contiguous-m 实验脚本（按 block 整体划分训练/测试）
# 目标：
#   - 去掉 kspacing 功能
#   - 固定在 20 dB 下做 contiguous
#   - 循环不同 m，比较准确率变化

import os
import glob
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import train_test_split, StratifiedKFold
import matplotlib.pyplot as plt
import seaborn as sns
from datetime import datetime
from tqdm import tqdm
import h5py

# ================= Configuration =================
# Folder containing the .mat recordings.
data_path = "D:/users/zhongyuan/rf_datasets"
# Sample rate (Hz) and carrier frequency (Hz).
fs, fc = 5e6, 5.9e9
# Speed in km/h used for the synthetic Doppler injection.
v = 120
# Channel impairments applied during preprocessing.
apply_doppler, apply_awgn = True, True

# -------- Fixed experiment settings --------
SNR_FIXED_DB = 20
# Values of m (contiguous frames per block) to sweep: 8, 24, ..., 568. Adjust as needed.
M_LIST = list(range(8, 576 + 1, 16))

# The train/test split is done at block granularity.
test_size, random_state = 0.25, 42

# Model / optimisation hyperparameters.
batch_size = 64
num_epochs = 300
learning_rate, weight_decay = 1e-4, 1e-3
in_planes = 64
dropout = 0.5
patience = 5   # early stopping: epochs without validation improvement
n_splits = 5   # StratifiedKFold folds over the training blocks

# ================= Doppler and AWGN helpers =================
def compute_doppler_shift(v_kmh, fc_hz):
    """Return the Doppler frequency (Hz) for speed ``v_kmh`` (km/h) at carrier ``fc_hz`` (Hz).

    Uses fd = (v / c) * fc with c = 3e8 m/s.
    """
    speed_of_light = 3e8
    speed_mps = v_kmh / 3.6  # km/h -> m/s
    ratio = speed_mps / speed_of_light
    return ratio * fc_hz

def apply_doppler_shift(signal, fd, fs):
    """Rotate ``signal`` by a constant frequency offset ``fd`` (Hz) at sample rate ``fs``.

    The phase ramp exp(j*2*pi*fd*t), with t = n / fs for n = 0..N-1 along the
    last axis, is multiplied element-wise (broadcast over any leading axes).
    Returns a new complex array; the input is not modified.
    """
    num_samples = signal.shape[-1]
    time_axis = np.arange(num_samples) / fs
    rotator = np.exp(1j * 2 * np.pi * fd * time_axis)
    return signal * rotator

def add_awgn(signal, snr_db, rng=None):
    """Add complex white Gaussian noise to ``signal`` at the requested SNR.

    The noise power is derived from the mean power of the *whole* array passed
    in: noise_power = mean(|signal|^2) / 10**(snr_db / 10). It is split equally
    between the real and imaginary parts.

    Args:
        signal: real or complex array of any shape.
        snr_db: target signal-to-noise ratio in dB.
        rng: optional ``numpy.random.Generator`` for reproducible noise. When
            ``None`` (the default), the legacy global ``np.random`` state is
            used, exactly as before.

    Returns:
        A new complex array ``signal + noise`` with the same shape as ``signal``.
    """
    signal_power = np.mean(np.abs(signal) ** 2)
    noise_power = signal_power / (10 ** (snr_db / 10))
    if rng is None:
        # Keep the original draw order (real, then imaginary) on the global RNG.
        noise_real = np.random.randn(*signal.shape)
        noise_imag = np.random.randn(*signal.shape)
    else:
        noise_real = rng.standard_normal(signal.shape)
        noise_imag = rng.standard_normal(signal.shape)
    noise = np.sqrt(noise_power / 2) * (noise_real + 1j * noise_imag)
    return signal + noise

# ================= Data loading (one file -> contiguous m-frame blocks, axes transposed to (sample_len, m)) =================
def _read_mat_file(path):
    """Read one HDF5-based .mat file and return ``(dmrs_complex, tx_id)``.

    ``dmrs_complex`` is built from the ``real``/``imag`` fields of
    ``rfDataset/dmrs``. ``tx_id`` is decoded from the ``rfDataset/txID``
    character codes, with NUL padding skipped.
    """
    with h5py.File(path, "r") as f:
        rf_dataset = f["rfDataset"]
        dmrs_struct = rf_dataset["dmrs"][:]
        # Assumed (num_frames, sample_len) -- TODO confirm against the MATLAB export.
        dmrs_complex = dmrs_struct["real"] + 1j * dmrs_struct["imag"]
        tx_codes = rf_dataset["txID"][:].flatten()
        tx_id = "".join(chr(c) for c in tx_codes if c != 0)
    return dmrs_complex, tx_id


def _preprocess_frames(dmrs_complex, fd, fs, apply_doppler, apply_awgn, snr_db):
    """Convert complex frames (num_frames, sample_len) into float32 IQ (num_frames, sample_len, 2).

    Each frame is normalised to unit mean power, then optionally Doppler-rotated,
    then optionally corrupted with AWGN whose SNR is measured per frame.
    """
    frames = np.asarray(dmrs_complex)
    # Step 1: per-frame unit-power normalisation (epsilon guards all-zero frames).
    frame_rms = np.sqrt(np.mean(np.abs(frames) ** 2, axis=-1, keepdims=True))
    frames = frames / (frame_rms + 1e-12)
    # Step 2: Doppler (phase only); the same ramp restarts at t=0 for every frame.
    if apply_doppler:
        frames = apply_doppler_shift(frames, fd, fs)
    # Step 3: AWGN, frame by frame, so each frame's SNR is relative to its own
    # power and the global RNG is consumed in the same order as a per-frame loop.
    if apply_awgn and frames.shape[0] > 0:
        frames = np.stack([add_awgn(frame, snr_db) for frame in frames])
    iq = np.stack((frames.real, frames.imag), axis=-1)  # (num_frames, sample_len, 2)
    return iq.astype(np.float32)


def _split_contiguous_blocks(iq_frames, m):
    """Cut (num_frames, sample_len, C) into non-overlapping blocks of m consecutive frames.

    Returns (num_frames // m, sample_len, m, C). Trailing frames that do not
    fill a whole block are dropped.
    """
    num_frames, sample_len, channels = iq_frames.shape
    num_blocks = num_frames // m
    blocks = iq_frames[: num_blocks * m].reshape(num_blocks, m, sample_len, channels)
    # (block, frame, sample, C) -> (block, sample, frame, C)
    return blocks.transpose(0, 2, 1, 3)


def load_and_preprocess_blocks_singlefile_contiguous(
    mat_folder,
    m=288,
    apply_doppler=False,
    target_velocity=30,
    apply_awgn=False,
    snr_db=20,
    fs=5e6,
    fc=5.9e9
):
    """Load every ``*.mat`` file in ``mat_folder`` and cut it into contiguous m-frame blocks.

    contiguous-m variant (no k-spacing):
      - each block is m consecutive frames (span = stride = m: no overlap, no sliding window);
      - a block never spans two .mat files;
      - inside a block the axes are transposed, so each sample index holds a length-m sequence over frames.

    Args:
        mat_folder: folder scanned (non-recursively) for ``*.mat`` files.
        m: frames per block; must be >= 1.
        apply_doppler / target_velocity / fc / fs: Doppler injection settings (km/h, Hz, Hz).
        apply_awgn / snr_db: AWGN settings.

    Returns:
        X_blocks: float32 array (num_blocks, sample_len, m, 2); the last axis is (I, Q).
        y_blocks: int64 array (num_blocks,) of class indices.
        label_to_idx: dict tx_id -> class index, in sorted tx_id order over all files read
            (including files too short to yield a block).

    Raises:
        ValueError: if m < 1, or if files disagree on sample_len.
        RuntimeError: if no file has at least m frames.
    """
    if m < 1:
        raise ValueError(f"m must be a positive integer, got {m}")

    mat_files = sorted(glob.glob(os.path.join(mat_folder, "*.mat")))
    print(f"[INFO] 共找到 {len(mat_files)} 个 .mat 文件")

    fd = compute_doppler_shift(target_velocity, fc)
    print(f"[INFO] Doppler 设置：v={target_velocity} km/h, fd={fd:.2f} Hz, apply_doppler={apply_doppler}")
    print(f"[INFO] XFR 设置：contiguous, m={m}")

    # Read every file first so the label mapping covers all transmitters.
    X_files, y_files = [], []
    for file in tqdm(mat_files, desc="读取并预处理文件"):
        dmrs_complex, tx_id = _read_mat_file(file)
        X_files.append(_preprocess_frames(dmrs_complex, fd, fs, apply_doppler, apply_awgn, snr_db))
        y_files.append(tx_id)

    label_to_idx = {label: i for i, label in enumerate(sorted(set(y_files)))}

    X_blocks_list, y_blocks_list = [], []
    n_skipped = 0
    for Xf, tx_id in zip(X_files, y_files):
        blocks = _split_contiguous_blocks(Xf, m)
        if blocks.shape[0] == 0:  # fewer than m frames (including empty files)
            n_skipped += 1
            continue
        X_blocks_list.append(blocks)
        y_blocks_list.append(np.full(blocks.shape[0], label_to_idx[tx_id], dtype=np.int64))

    if n_skipped:
        print(f"[WARN] skipped {n_skipped} file(s) with fewer than m={m} frames")

    if not X_blocks_list:
        raise RuntimeError("没有生成任何 block。请检查 m 是否过大，或文件帧数不足。")

    sample_lens = {blocks.shape[1] for blocks in X_blocks_list}
    if len(sample_lens) > 1:
        raise ValueError(
            f"Files have inconsistent sample_len values {sorted(sample_lens)}; cannot stack blocks."
        )

    X_blocks = np.concatenate(X_blocks_list, axis=0)  # (num_blocks, sample_len, m, 2)
    y_blocks = np.concatenate(y_blocks_list)

    print(f"[INFO] 生成 block 数: {X_blocks.shape[0]}")
    print(f"[INFO] block 形状: (sample_len={X_blocks.shape[1]}, m={X_blocks.shape[2]}, 2)")
    print(f"[INFO] 类别数: {len(label_to_idx)}")
    return X_blocks, y_blocks, label_to_idx

# ================= 1D ResNet-18 (with dropout and configurable stem width) =================
class BasicBlock1D(nn.Module):
    """Two-convolution 1-D residual block.

    Main path: conv3(stride) -> BN -> ReLU -> Dropout -> conv3 -> BN.
    The shortcut is the input itself, or ``downsample(x)`` when a projection
    module is supplied. The output is ReLU(main + shortcut).
    """
    expansion = 1

    def __init__(self, in_planes, planes, stride=1, downsample=None, dropout=0.0):
        super().__init__()
        # Submodule names and creation order are kept fixed: they define the
        # checkpoint state_dict keys and the parameter-initialisation RNG order.
        self.conv1 = nn.Conv1d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm1d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.dropout = nn.Dropout(p=dropout)
        self.conv2 = nn.Conv1d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm1d(planes)
        self.downsample = downsample

    def forward(self, x):
        shortcut = x if self.downsample is None else self.downsample(x)
        hidden = self.dropout(self.relu(self.bn1(self.conv1(x))))
        residual = self.bn2(self.conv2(hidden))
        return self.relu(residual + shortcut)

class ResNet18_1D(nn.Module):
    """ResNet-18-style 1-D classifier over 2-channel (I/Q) sequences.

    Layout: conv7/stride-2 stem (2 -> ``in_planes`` channels) + BN + ReLU +
    maxpool; four stages of two ``BasicBlock1D`` each with 64, 128, 256 and
    512 channels (stages 2-4 start with stride 2); global average pooling and
    a linear head. ``in_planes`` only sets the stem width; stage widths are fixed.
    """

    def __init__(self, num_classes=10, in_planes=64, dropout=0.0):
        super().__init__()
        # Running channel count, consumed and updated by _make_layer.
        self.in_planes = in_planes
        self.conv1 = nn.Conv1d(2, in_planes, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm1d(in_planes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1)
        # The layer1..layer4 attribute names are part of the checkpoint keys.
        self.layer1 = self._make_layer(64, 2, stride=1, dropout=dropout)
        self.layer2 = self._make_layer(128, 2, stride=2, dropout=dropout)
        self.layer3 = self._make_layer(256, 2, stride=2, dropout=dropout)
        self.layer4 = self._make_layer(512, 2, stride=2, dropout=dropout)
        self.avgpool = nn.AdaptiveAvgPool1d(1)
        self.fc = nn.Linear(512, num_classes)

    def _make_layer(self, planes, blocks, stride, dropout):
        """Build one stage of ``blocks`` residual blocks that outputs ``planes`` channels.

        Only the first block may change stride or width. It gets a 1x1 conv + BN
        projection on its shortcut when needed. Updates ``self.in_planes``.
        """
        needs_projection = stride != 1 or self.in_planes != planes
        # The projection is created before the first block, which keeps the init order.
        projection = (
            nn.Sequential(
                nn.Conv1d(self.in_planes, planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm1d(planes),
            )
            if needs_projection
            else None
        )
        stage = [BasicBlock1D(self.in_planes, planes, stride, projection, dropout)]
        self.in_planes = planes
        stage.extend(BasicBlock1D(planes, planes, dropout=dropout) for _ in range(blocks - 1))
        return nn.Sequential(*stage)

    def forward(self, x):
        # x: (B, m, 2) -> channels-first (B, 2, m) for Conv1d.
        h = x.permute(0, 2, 1)
        h = self.maxpool(self.relu(self.bn1(self.conv1(h))))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            h = stage(h)
        pooled = self.avgpool(h).squeeze(-1)  # (B, 512)
        return self.fc(pooled)

# ================= Helpers =================
def compute_grad_norm(model):
    """Return the global L2 norm of all parameter gradients of ``model`` as a Python float.

    Computes sqrt(sum_p ||grad_p||_2^2) over the parameters whose ``.grad`` is set,
    and returns 0.0 when no parameter has a gradient. The per-parameter norms are
    stacked on-device and moved to the host once, instead of one ``.item()`` sync
    per parameter. Assumes all gradients live on the same device.
    """
    per_param = [
        torch.linalg.vector_norm(p.grad.detach())
        for p in model.parameters()
        if p.grad is not None
    ]
    if not per_param:
        return 0.0
    # float64 accumulation on CPU mirrors the previous Python-float summation.
    stacked = torch.stack(per_param).cpu().double()
    return torch.linalg.vector_norm(stacked).item()

def moving_average(x, w=5):
    """Simple moving average of ``x`` with window ``w`` ("valid" convolution mode).

    The window is clamped to [1, len(x)], so short inputs still yield one value
    and ``w <= 0`` means no smoothing. Returns an ndarray of length
    len(x) - window + 1, or an empty array for empty input.
    """
    values = np.array(x)
    n = len(values)
    if n == 0:
        return np.array([])
    window = min(w, n)
    if window < 1:
        window = 1
    kernel = np.ones(window)
    return np.convolve(values, kernel, mode="valid") / window

def evaluate_model(model, dataloader, device, num_classes):
    """Run ``model`` over ``dataloader`` in eval mode and score the argmax predictions.

    Returns:
        (accuracy_percent, cm): accuracy is 0.0 for an empty loader. ``cm`` is the
        sklearn ``confusion_matrix`` over labels 0..num_classes-1 (rows are true
        labels, columns are predictions).
    """
    model.eval()
    y_true, y_pred = [], []
    n_correct = 0
    n_seen = 0
    with torch.no_grad():
        for batch_x, batch_y in dataloader:
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)
            predicted = model(batch_x).argmax(dim=1)
            n_correct += (predicted == batch_y).sum().item()
            n_seen += batch_y.size(0)
            y_true.extend(batch_y.cpu().numpy())
            y_pred.extend(predicted.cpu().numpy())
    accuracy = 100.0 * n_correct / max(1, n_seen)
    cm = confusion_matrix(y_true, y_pred, labels=range(num_classes))
    return accuracy, cm

def plot_training_curves(fold_results, save_folder):
    """Plot smoothed train/val loss curves for every fold and save ``loss_curves.png``.

    ``fold_results`` is a sequence of dicts holding per-epoch "train_loss" and
    "val_loss" lists. Each curve is smoothed with ``moving_average`` (default window).
    NOTE: after "valid"-mode smoothing, x index k covers epochs k..k+window-1.
    """
    plt.figure(figsize=(12, 5))
    for fold_no, history in enumerate(fold_results, start=1):
        for key, tag in (("train_loss", "Train Loss"), ("val_loss", "Val Loss")):
            plt.plot(moving_average(history[key]), label=f"Fold{fold_no} {tag}")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.title("训练和验证Loss曲线")
    plt.legend()
    plt.grid()
    plt.savefig(os.path.join(save_folder, "loss_curves.png"))
    plt.close()

def plot_grad_norms(avg_grad_norms, save_folder):
    """Bar chart with one average gradient norm per fold, saved as ``avg_grad_norms.png``."""
    fold_numbers = list(range(1, len(avg_grad_norms) + 1))
    plt.figure(figsize=(6, 4))
    plt.bar(fold_numbers, avg_grad_norms)
    plt.xlabel("Fold")
    plt.ylabel("平均梯度范数")
    plt.title("各Fold平均梯度范数")
    plt.grid()
    plt.savefig(os.path.join(save_folder, "avg_grad_norms.png"))
    plt.close()

def plot_confusion_matrix(cm, save_path=None):
    """Draw ``cm`` as an annotated integer heatmap and save it to ``save_path`` if given.

    Side effect: sets matplotlib's global sans-serif font to SimHei (for CJK text)
    and disables the unicode minus sign. This persists for later figures. The
    figure is always closed, so nothing is kept when ``save_path`` is falsy.
    """
    plt.figure(figsize=(8, 6))
    plt.rcParams.update({"font.sans-serif": ["SimHei"], "axes.unicode_minus": False})
    sns.heatmap(cm, annot=True, fmt="d", cmap="Blues")
    plt.title("Confusion Matrix")
    plt.ylabel("Reference")
    plt.xlabel("Predicted")
    if save_path:
        plt.savefig(save_path)
    plt.close()

def check_block_overlap(train_blocks_idx, val_blocks_idx, test_blocks_idx):
    """Verify that the train / val / test block indices are pairwise disjoint.

    ``val_blocks_idx`` may be ``None`` (no validation split). Prints a
    confirmation on success. Otherwise raises RuntimeError listing all three
    intersections.
    """
    train_set = set(train_blocks_idx)
    val_set = set() if val_blocks_idx is None else set(val_blocks_idx)
    test_set = set(test_blocks_idx)

    overlap_train_val = train_set & val_set
    overlap_train_test = train_set & test_set
    overlap_val_test = val_set & test_set

    if not (overlap_train_val or overlap_train_test or overlap_val_test):
        print("[INFO] Block 重叠检查通过，训练/验证/测试 block 互斥。")
        return
    raise RuntimeError(
        "[ERROR] Block 重叠检测失败！"
        f"\nTrain-Val overlap: {overlap_train_val}"
        f"\nTrain-Test overlap: {overlap_train_test}"
        f"\nVal-Test overlap: {overlap_val_test}"
    )

# ================= Main training routine (fixed contiguous-m blocks) =================
def _flatten_blocks(X_blocks, y_blocks):
    """Expand blocks (n, sample_len, m, 2) into samples (n * sample_len, m, 2) plus per-sample labels.

    Every block contributes ``sample_len`` samples of length ``m``, each carrying its block's label.
    """
    X = X_blocks.reshape(-1, X_blocks.shape[2], X_blocks.shape[3])
    y = np.repeat(y_blocks, X_blocks.shape[1])
    return X, y


def _make_loader(X, y, shuffle, drop_last=False):
    """Wrap numpy arrays as float32 inputs / int64 labels in a DataLoader using the global batch_size."""
    dataset = TensorDataset(torch.tensor(X, dtype=torch.float32), torch.tensor(y, dtype=torch.long))
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last)


def _append_log(results_file, msg):
    """Append ``msg`` as one line to the UTF-8 results log."""
    with open(results_file, "a", encoding="utf-8") as f:
        f.write(msg + "\n")


def _snapshot_weights(model):
    """Return a detached CPU copy of ``model.state_dict()`` that later training cannot mutate."""
    return {name: tensor.detach().cpu().clone() for name, tensor in model.state_dict().items()}


def _train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one training epoch and return (mean batch loss, accuracy %, mean gradient norm)."""
    model.train()
    running_loss, correct, total = 0.0, 0, 0
    batch_grad_norms = []
    for inputs, labels in loader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        batch_grad_norms.append(compute_grad_norm(model))
        optimizer.step()

        running_loss += loss.item()
        correct += (torch.argmax(outputs, dim=1) == labels).sum().item()
        total += labels.size(0)
    mean_loss = running_loss / max(1, len(loader))
    acc = 100.0 * correct / max(1, total)
    grad_norm = float(np.mean(batch_grad_norms)) if batch_grad_norms else 0.0
    return mean_loss, acc, grad_norm


def _validate_one_epoch(model, loader, criterion, device):
    """Compute mean batch loss and accuracy % on ``loader`` in eval mode without gradients."""
    model.eval()
    running_loss, correct, total = 0.0, 0, 0
    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            running_loss += criterion(outputs, labels).item()
            correct += (torch.argmax(outputs, dim=1) == labels).sum().item()
            total += labels.size(0)
    return running_loss / max(1, len(loader)), 100.0 * correct / max(1, total)


def train_for_fixed_snr_contiguous(SNR_dB, save_folder, results_file, m=288):
    """Train and evaluate ResNet18_1D on contiguous-m blocks at a fixed SNR.

    Pipeline:
      1. Build all blocks with ``load_and_preprocess_blocks_singlefile_contiguous``.
      2. Stratified block-level train/test split (``test_size``, ``random_state``),
         so no block contributes samples to both sides.
      3. StratifiedKFold (``n_splits``) over the training blocks. Each fold:
         - trains a fresh model with early stopping on validation accuracy;
         - restores the best weights;
         - evaluates on the shared test set.
    Per-fold confusion matrices and best weights, plus loss and gradient-norm
    plots, are written to ``save_folder``. Per-epoch logs are appended to ``results_file``.

    Returns:
        Mean test accuracy (%) across folds.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"[INFO] 使用设备: {device}")

    # 1) Build every block once.
    X_blocks, y_blocks, label_to_idx = load_and_preprocess_blocks_singlefile_contiguous(
        data_path,
        m=m,
        apply_doppler=apply_doppler,
        target_velocity=v,
        apply_awgn=apply_awgn,
        snr_db=SNR_dB,
        fs=fs,
        fc=fc
    )
    num_classes = len(label_to_idx)
    num_blocks = X_blocks.shape[0]
    print(f"[INFO] 总 block 数: {num_blocks}")

    # 2) Block-level stratified train/test split.
    train_block_idx, test_block_idx = train_test_split(
        np.arange(num_blocks), test_size=test_size, stratify=y_blocks, random_state=random_state
    )
    check_block_overlap(train_block_idx, None, test_block_idx)

    X_train_blocks = X_blocks[train_block_idx]
    y_train_blocks = y_blocks[train_block_idx]
    X_test, y_test = _flatten_blocks(X_blocks[test_block_idx], y_blocks[test_block_idx])
    test_loader = _make_loader(X_test, y_test, shuffle=False)

    print(f"[INFO] 训练 block 数: {len(train_block_idx)}, 测试 block 数: {len(test_block_idx)}")

    # 3) Block-level StratifiedKFold inside the training blocks.
    kfold = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=random_state)
    fold_results = []
    fold_test_accs = []

    for fold, (train_idx_fold, val_idx_fold) in enumerate(kfold.split(X_train_blocks, y_train_blocks)):
        print(f"\n====== Fold {fold+1}/{n_splits} (m={m}, SNR={SNR_dB}dB) ======")

        X_train, y_train = _flatten_blocks(X_train_blocks[train_idx_fold], y_train_blocks[train_idx_fold])
        # For small m (e.g. 8 or 24), the deepest stages reduce the feature length
        # to 1. BatchNorm in training mode then raises on a batch of a single
        # sample, so drop the trailing batch only when it would have exactly one.
        train_loader = _make_loader(
            X_train, y_train, shuffle=True, drop_last=(len(y_train) % batch_size == 1)
        )
        X_val, y_val = _flatten_blocks(X_train_blocks[val_idx_fold], y_train_blocks[val_idx_fold])
        val_loader = _make_loader(X_val, y_val, shuffle=False)

        model = ResNet18_1D(num_classes=num_classes, in_planes=in_planes, dropout=dropout).to(device)
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
        scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)

        # Start at -inf so the first epoch always produces a checkpoint. Starting
        # at 0.0 left best_model_wts as None (and crashed load_state_dict) when the
        # validation accuracy never exceeded 0.01%.
        best_val_acc = float("-inf")
        best_model_wts = None
        patience_counter = 0
        train_losses, val_losses, grad_norms = [], [], []

        for epoch in range(num_epochs):
            train_loss, train_acc, avg_grad_norm = _train_one_epoch(
                model, train_loader, criterion, optimizer, device
            )
            train_losses.append(train_loss)
            grad_norms.append(avg_grad_norm)

            val_loss, val_acc = _validate_one_epoch(model, val_loader, criterion, device)
            val_losses.append(val_loss)

            log_msg = (f"m={m}, Fold {fold+1}, Epoch {epoch+1}: "
                       f"Train Loss={train_loss:.4f}, Train Acc={train_acc:.2f}%, "
                       f"Val Loss={val_loss:.4f}, Val Acc={val_acc:.2f}%, Grad Norm={avg_grad_norm:.4f}")
            print(log_msg)
            _append_log(results_file, log_msg)

            # Early stopping: an improvement must beat the best by more than 0.01 percentage points.
            if val_acc > best_val_acc + 0.01:
                best_val_acc = val_acc
                patience_counter = 0
                best_model_wts = _snapshot_weights(model)
            else:
                patience_counter += 1
                if patience_counter >= patience:
                    msg = f"早停，连续 {patience} 个 epoch 验证集未提升"
                    print(msg)
                    _append_log(results_file, msg)
                    break

            scheduler.step()

        if best_model_wts is None:  # only reachable when num_epochs == 0
            best_model_wts = _snapshot_weights(model)
        model.load_state_dict(best_model_wts)

        # Evaluate both test and validation with the restored best weights, so the
        # saved validation confusion matrix matches the checkpoint being reported.
        test_acc, test_cm = evaluate_model(model, test_loader, device, num_classes)
        _, val_cm = evaluate_model(model, val_loader, device, num_classes)
        fold_test_accs.append(test_acc)

        plot_confusion_matrix(val_cm, save_path=os.path.join(save_folder, f"confusion_matrix_val_fold{fold+1}.png"))
        plot_confusion_matrix(test_cm, save_path=os.path.join(save_folder, f"confusion_matrix_test_fold{fold+1}.png"))
        torch.save(best_model_wts, os.path.join(save_folder, f"best_model_fold{fold+1}.pth"))

        fold_results.append({"train_loss": train_losses, "val_loss": val_losses, "grad_norms": grad_norms})
        print(f"m={m} | Fold {fold+1} Test Acc={test_acc:.2f}%\n")
        _append_log(results_file, f"m={m} | Fold {fold+1} Test Acc={test_acc:.2f}%")

    plot_training_curves(fold_results, save_folder)
    plot_grad_norms([np.mean(fr["grad_norms"]) for fr in fold_results], save_folder)

    return float(np.mean(fold_test_accs))

# ================= Entry point: fixed SNR, sweep over m =================
if __name__ == "__main__":
    timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    fd_tag = int(compute_doppler_shift(v, fc))
    run_name = f"{timestamp}_LTEV_contiguous_SNR{SNR_FIXED_DB}dB_fd{fd_tag}_ResNet_m_sweep"
    base_folder = os.path.join(os.getcwd(), "training_results", run_name)
    os.makedirs(base_folder, exist_ok=True)

    # Run-level summary file: header first, then one line per m as results arrive.
    summary_path = os.path.join(base_folder, "m_sweep_summary.txt")
    with open(summary_path, "a", encoding="utf-8") as summary:
        summary.write(f"==== m sweep @ SNR={SNR_FIXED_DB} dB | {timestamp} ====\n")
        summary.write(f"M_LIST={M_LIST}\n\n")

    m_accs = []

    for m in M_LIST:
        banner = f"\n\n================== 当前实验：SNR={SNR_FIXED_DB} dB, contiguous, m={m} ==================\n"
        print(banner)

        # One sub-folder per m, holding its log, plots and checkpoints.
        save_folder = os.path.join(base_folder, f"m{m}_SNR{SNR_FIXED_DB}dB")
        os.makedirs(save_folder, exist_ok=True)

        results_file = os.path.join(save_folder, "results.txt")
        with open(results_file, "a", encoding="utf-8") as log:
            log.write(f"\n================ SNR={SNR_FIXED_DB} dB | contiguous | m={m} =================\n")

        mean_test_acc = train_for_fixed_snr_contiguous(SNR_FIXED_DB, save_folder, results_file, m=m)
        m_accs.append(mean_test_acc)

        line = f"m={m} -> mean test acc = {mean_test_acc:.2f}% | results in: {save_folder}"
        print(line)
        with open(summary_path, "a", encoding="utf-8") as summary:
            summary.write(line + "\n")

    # Plot m against mean test accuracy.
    plt.figure(figsize=(8, 5))
    plt.plot(M_LIST, m_accs, marker="o", linestyle="-")
    plt.xlabel("m (contiguous frames per sample)")
    plt.ylabel("Mean Test Accuracy (%)")
    plt.title(f"m vs Accuracy @ SNR={SNR_FIXED_DB} dB (contiguous)")
    plt.grid(True)
    fig_stamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    fig_path = os.path.join(base_folder, f"m_vs_accuracy_{fig_stamp}.png")
    plt.savefig(fig_path)
    plt.show()

    print(f"[INFO] m 扫描完成。汇总：{summary_path}")
    print(f"[INFO] 曲线已保存：{fig_path}")
