# Notebook cell export (originally "In [None]:").
# 1-D ResNet training script that sweeps block_size automatically (SNR fixed at 20 dB).
import os
import glob
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset, Subset
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import KFold, train_test_split
import matplotlib.pyplot as plt
import seaborn as sns
from datetime import datetime
from tqdm import tqdm
import h5py
from data_utilities import *

# ================= Parameter settings =================
data_path = "D:/users/zhongyuan/rf_datasets"  # dataset folder scanned for *.mat files
fs = 5e6  # sampling rate used to build the Doppler time axis (presumably Hz)
fc = 5.9e9  # carrier frequency used for the Doppler shift (presumably Hz)
v = 120  # speed in km/h (compute_doppler_shift divides by 3.6 to get m/s)
apply_doppler = True
apply_awgn = True

# Model hyperparameters
batch_size = 64
num_epochs = 300
learning_rate = 1e-4
weight_decay = 1e-3
in_planes = 64  # stem width passed to ResNet18_1D
dropout = 0.5
patience = 5  # epochs without validation-accuracy improvement before early stopping
n_splits = 5  # NOTE(review): not referenced in the visible code; looks like a K-fold leftover

# ================= Doppler and AWGN helpers =================
def compute_doppler_shift(v, fc):
    """Return the Doppler frequency shift for a given speed and carrier.

    Args:
        v: Relative speed in km/h (converted to m/s internally).
        fc: Carrier frequency, in the same unit as the returned shift.

    Returns:
        ``(v_mps / c) * fc`` with ``c = 3e8`` m/s.
    """
    speed_of_light = 3e8
    speed_mps = v / 3.6  # km/h -> m/s
    relative_speed = speed_mps / speed_of_light
    return relative_speed * fc

def apply_doppler_shift(signal, fd, fs):
    """Rotate ``signal`` by a constant frequency offset.

    Each sample ``n`` along the last axis is multiplied by
    ``exp(j * 2 * pi * fd * n / fs)``; the magnitude of every sample is
    therefore left unchanged.

    Args:
        signal: Array whose last axis is the sample axis.
        fd: Frequency offset to apply.
        fs: Sampling rate, in the same unit as ``fd``.

    Returns:
        The frequency-shifted (complex) signal.
    """
    num_samples = signal.shape[-1]
    time_axis = np.arange(num_samples) / fs
    phase = 1j * 2 * np.pi * fd * time_axis
    return signal * np.exp(phase)

def add_awgn(signal, snr_db):
    """Add complex white Gaussian noise to ``signal`` at the requested SNR.

    The noise power is derived from the measured mean power of ``signal``
    and split evenly between the real and imaginary components.  Samples
    are drawn from NumPy's global random state (real part first, then the
    imaginary part).

    Args:
        signal: Input array; its ``shape`` determines the noise shape.
        snr_db: Target signal-to-noise ratio in dB.

    Returns:
        ``signal`` plus the generated complex noise.
    """
    shape = signal.shape
    mean_power = np.mean(np.abs(signal)**2)
    snr_linear = 10**(snr_db/10)
    noise_power = mean_power / snr_linear

    # Draw order matters for reproducibility under a fixed seed.
    in_phase = np.random.randn(*shape)
    quadrature = np.random.randn(*shape)

    scale = np.sqrt(noise_power/2)
    return signal + scale * (in_phase + 1j*quadrature)

# ================= Data loading (stored per block, with each block transposed) =================
def load_and_preprocess_with_grouping(mat_folder, group_size=288, apply_doppler=False,
                                      target_velocity=30, apply_awgn=False, snr_db=20,
                                      fs=5e6, fc=5.9e9):
    """Load DMRS captures from ``*.mat`` files and pack them into per-class blocks.

    Every ``*.mat`` file in ``mat_folder`` is opened with h5py and must
    provide ``rfDataset/dmrs`` (with ``real`` and ``imag`` fields, indexed
    here as ``(num_signals, sample_len)``) and ``rfDataset/txID`` (character
    codes of the transmitter ID; zero entries are dropped).  Each signal is
    power-normalised, optionally Doppler-shifted, optionally corrupted with
    AWGN, and stored as an ``(sample_len, 2)`` I/Q array.

    Signals are then grouped per transmitter ID.  For a class with
    ``num_files`` files, every block takes ``group_size // num_files``
    consecutive signals from each file, so the block holds
    ``(group_size // num_files) * num_files`` signals (equal to
    ``group_size`` only when it is divisible by ``num_files``).  Each block
    is transposed to ``(sample_len, block_signals, 2)`` so that one row
    collects the I/Q values of all signals at the same sample index.

    Args:
        mat_folder: Directory containing the ``*.mat`` files.
        group_size: Requested number of signals per block.
        apply_doppler: Apply a Doppler shift derived from ``target_velocity``.
        target_velocity: Speed in km/h used for the Doppler shift.
        apply_awgn: Add white Gaussian noise at ``snr_db``.
        snr_db: Target SNR in dB for the added noise.
        fs: Sampling rate used for the Doppler time axis.
        fc: Carrier frequency used for the Doppler shift.

    Returns:
        Tuple ``(X_blocks, y_blocks, label_to_idx)`` where ``X_blocks`` has
        shape ``(num_blocks, sample_len, block_signals, 2)``, ``y_blocks`` is
        an ``int64`` array with one class index per block, and
        ``label_to_idx`` maps each transmitter ID to its class index.

    Raises:
        RuntimeError: If no block could be built.
        ValueError: If the blocks do not all share the same shape.
    """
    # glob returns paths in arbitrary order; sorting keeps the block
    # composition (and therefore the seeded split downstream) reproducible.
    mat_files = sorted(glob.glob(os.path.join(mat_folder, '*.mat')))
    print(f"共找到 {len(mat_files)} 个 .mat 文件")
    fd = compute_doppler_shift(target_velocity, fc)
    print(f"目标速度 {target_velocity} km/h，多普勒频移 {fd:.2f} Hz")

    X_files, y_files = [], []
    for file in tqdm(mat_files, desc='读取数据'):
        with h5py.File(file, 'r') as f:
            rfDataset = f['rfDataset']
            dmrs_struct = rfDataset['dmrs'][:]
            dmrs_complex = dmrs_struct['real'] + 1j * dmrs_struct['imag']
            txID_uint16 = rfDataset['txID'][:].flatten()
            tx_id = ''.join(chr(c) for c in txID_uint16 if c != 0)

            processed_signals = []
            for i in range(dmrs_complex.shape[0]):
                sig = dmrs_complex[i, :]
                # Step 1: normalise the raw signal to unit mean power.
                sig = sig / (np.sqrt(np.mean(np.abs(sig)**2)) + 1e-12)

                # Step 2: Doppler shift (changes phase only, not power).
                if apply_doppler:
                    sig = apply_doppler_shift(sig, fd, fs)

                # Step 3: AWGN generated for the requested SNR.
                if apply_awgn:
                    sig = add_awgn(sig, snr_db)

                iq = np.stack((sig.real, sig.imag), axis=-1)  # (sample_len, 2)
                processed_signals.append(iq)
            processed_signals = np.array(processed_signals)  # (num_signals, sample_len, 2)
            X_files.append(processed_signals)
            y_files.append(tx_id)

    label_list = sorted(set(y_files))
    label_to_idx = {label: i for i, label in enumerate(label_list)}
    X_blocks_list = []   # each element is one block (sample_len, block_signals, 2)
    y_blocks_list = []   # each element is the class index of that block

    for label in label_list:
        # Every label comes from y_files, so files_idx is never empty.
        files_idx = [i for i, y in enumerate(y_files) if y == label]
        num_files = len(files_idx)
        samples_per_file = group_size // num_files
        if samples_per_file == 0:
            print(f"[WARN] 类别 {label} 文件数量过多，导致每文件样本数为0，跳过该类别")
            continue
        min_samples = min(X_files[i].shape[0] for i in files_idx)
        max_groups = min_samples // samples_per_file
        if max_groups == 0:
            print(f"[WARN] 类别 {label} 样本不足，跳过")
            continue

        for group_i in range(max_groups):
            start = group_i * samples_per_file
            end = start + samples_per_file
            # One slice of (samples_per_file, sample_len, 2) per file.
            pieces = [X_files[fi][start:end] for fi in files_idx]

            # (samples_per_file * num_files, sample_len, 2)
            big_block = np.concatenate(pieces, axis=0)

            # Transpose so each row gathers all signals at one sample index.
            big_block = np.transpose(big_block, (1, 0, 2))  # (sample_len, block_signals, 2)

            X_blocks_list.append(big_block)
            y_blocks_list.append(label_to_idx[label])

    if len(X_blocks_list) == 0:
        raise RuntimeError("没有生成任何 block，请检查数据/group_size 设置")

    # np.stack would fail with a generic message; explain the likely cause.
    block_shapes = {block.shape for block in X_blocks_list}
    if len(block_shapes) > 1:
        raise ValueError(
            f"Blocks have inconsistent shapes {sorted(block_shapes)}: classes with "
            f"different file counts yield different block sizes unless "
            f"group_size={group_size} is divisible by each file count, and all "
            f"captures must share one sample length"
        )

    X_blocks = np.stack(X_blocks_list, axis=0)  # (num_blocks, sample_len, block_signals, 2)
    y_blocks = np.array(y_blocks_list, dtype=np.int64)  # (num_blocks,)

    print(f"[INFO] 生成 block 数: {X_blocks.shape[0]}, 每 block 样本数: {X_blocks.shape[2]}, 每样本长度: {X_blocks.shape[1]}")
    return X_blocks, y_blocks, label_to_idx


# ================= 1-D ResNet18 (with dropout and a configurable in_planes) =================
class BasicBlock1D(nn.Module):
    """Residual block made of two 3-tap 1-D convolutions.

    Layout: conv -> BN -> ReLU -> dropout -> conv -> BN, followed by the
    skip connection and a final ReLU.

    Args:
        in_planes: Number of input channels.
        planes: Number of output channels.
        stride: Stride of the first convolution.
        downsample: Optional module applied to the input so the skip path
            matches the main path; ``None`` uses the input unchanged.
        dropout: Dropout probability applied after the first activation.
    """

    expansion = 1

    def __init__(self, in_planes, planes, stride=1, downsample=None, dropout=0.0):
        super().__init__()
        shared = dict(kernel_size=3, padding=1, bias=False)
        self.conv1 = nn.Conv1d(in_planes, planes, stride=stride, **shared)
        self.bn1 = nn.BatchNorm1d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.dropout = nn.Dropout(p=dropout)
        self.conv2 = nn.Conv1d(planes, planes, stride=1, **shared)
        self.bn2 = nn.BatchNorm1d(planes)
        self.downsample = downsample

    def forward(self, x):
        """Apply the block to ``x`` of shape ``(batch, in_planes, length)``."""
        # Skip path: projected only when a downsample module was supplied.
        shortcut = x if self.downsample is None else self.downsample(x)

        hidden = self.conv1(x)
        hidden = self.bn1(hidden)
        hidden = self.relu(hidden)
        hidden = self.dropout(hidden)
        hidden = self.conv2(hidden)
        hidden = self.bn2(hidden)

        hidden += shortcut
        return self.relu(hidden)

class ResNet18_1D(nn.Module):
    """ResNet-18 style classifier built from 1-D convolutions.

    A 7-tap stem convolution and max-pool are followed by four stages of
    two ``BasicBlock1D`` each (64, 128, 256 and 512 channels), global
    average pooling and a linear classifier.

    Args:
        num_classes: Size of the output logit vector.
        in_planes: Number of channels produced by the stem convolution.
        dropout: Dropout probability forwarded to every residual block.
    """

    def __init__(self, num_classes=10, in_planes=64, dropout=0.0):
        super().__init__()
        self.in_planes = in_planes
        self.conv1 = nn.Conv1d(2, in_planes, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm1d(in_planes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1)

        # (channels, stride) of the four residual stages, registered in order
        # as layer1..layer4.
        stage_specs = ((64, 1), (128, 2), (256, 2), (512, 2))
        for stage_no, (width, stride) in enumerate(stage_specs, start=1):
            stage = self._make_layer(width, 2, stride=stride, dropout=dropout)
            setattr(self, f"layer{stage_no}", stage)

        self.avgpool = nn.AdaptiveAvgPool1d(1)
        self.fc = nn.Linear(512, num_classes)

    def _make_layer(self, planes, blocks, stride, dropout):
        """Build one stage of ``blocks`` residual blocks with ``planes`` channels.

        Only the first block may change stride or channel count; in that
        case it receives a 1x1 conv + BN projection for its skip path.
        ``self.in_planes`` is updated to ``planes`` for the next stage.
        """
        projection = None
        if stride != 1 or self.in_planes != planes:
            projection = nn.Sequential(
                nn.Conv1d(self.in_planes, planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm1d(planes)
            )
        stage = [BasicBlock1D(self.in_planes, planes, stride, projection, dropout)]
        self.in_planes = planes
        stage.extend(
            BasicBlock1D(self.in_planes, planes, dropout=dropout)
            for _ in range(blocks - 1)
        )
        return nn.Sequential(*stage)

    def forward(self, x):
        """Return class logits for ``x`` of shape ``(batch, length, 2)``."""
        feats = x.permute(0, 2, 1)  # (B, length, 2) -> (B, 2, length)
        feats = self.conv1(feats)
        feats = self.bn1(feats)
        feats = self.relu(feats)
        feats = self.maxpool(feats)
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            feats = stage(feats)
        pooled = self.avgpool(feats).squeeze(-1)
        return self.fc(pooled)


# ================= Helper functions (unchanged from the previous version) =================
def compute_grad_norm(model):
    """Return the global L2 norm of all gradients currently stored on ``model``.

    Parameters without a gradient are ignored; the result is ``0.0`` when
    no parameter has a gradient.
    """
    squared_sum = 0.0
    gradients = (param.grad for param in model.parameters() if param.grad is not None)
    for grad in gradients:
        layer_norm = grad.data.norm(2).item()
        squared_sum += layer_norm ** 2
    return squared_sum ** 0.5

def moving_average(x, w=5):
    """Return the length-``w`` moving average of ``x`` (no padding).

    The window is clamped to at least 1 and at most ``len(x)``, so the
    result has ``len(x) - w + 1`` entries.  An empty input yields an empty
    array.
    """
    values = np.array(x)
    count = len(values)
    if count == 0:
        return np.array([])
    window = w if w > 0 else 1
    window = count if count < window else window
    kernel = np.ones(window)
    return np.convolve(values, kernel, 'valid') / window

def evaluate_model(model, dataloader, device, num_classes):
    """Evaluate ``model`` on ``dataloader`` and return accuracy and confusion matrix.

    The model is switched to eval mode and run without gradient tracking.

    Args:
        model: Network returning per-class scores of shape ``(batch, classes)``.
        dataloader: Iterable of ``(inputs, labels)`` batches.
        device: Device the batches are moved to before the forward pass.
        num_classes: Number of classes; fixes the confusion-matrix size.

    Returns:
        ``(accuracy_percent, confusion_matrix)`` where the matrix covers
        labels ``0 .. num_classes - 1``.
    """
    model.eval()
    n_correct = 0
    n_seen = 0
    y_true, y_pred = [], []
    with torch.no_grad():
        for batch_x, batch_y in dataloader:
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)
            predicted = torch.max(model(batch_x), 1).indices
            n_correct += (predicted == batch_y).sum().item()
            n_seen += batch_y.size(0)
            y_true.extend(batch_y.cpu().numpy())
            y_pred.extend(predicted.cpu().numpy())
    accuracy = 100 * n_correct / n_seen
    matrix = confusion_matrix(y_true, y_pred, labels=range(num_classes))
    return accuracy, matrix

def plot_training_curves(fold_results, save_folder):
    """Plot smoothed train/validation loss curves and save them as a PNG.

    Args:
        fold_results: Sequence of dicts with ``'train_loss'`` and
            ``'val_loss'`` lists, one dict per fold.
        save_folder: Directory that receives ``loss_curves.png``.
    """
    series = (('train_loss', 'Train Loss'), ('val_loss', 'Val Loss'))

    plt.figure(figsize=(12,5))
    for fold_no, history in enumerate(fold_results, start=1):
        for key, caption in series:
            smoothed = moving_average(history[key])
            plt.plot(smoothed, label=f'Fold{fold_no} {caption}')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('训练和验证Loss曲线')
    plt.legend()
    plt.grid()

    target = os.path.join(save_folder, 'loss_curves.png')
    plt.savefig(target)
    plt.close()

def plot_grad_norms(avg_grad_norms, save_folder):
    """Draw a bar chart of the mean gradient norm per fold and save it as a PNG.

    Args:
        avg_grad_norms: One mean gradient norm per fold (folds are numbered
            from 1 on the x axis).
        save_folder: Directory that receives ``avg_grad_norms.png``.
    """
    fold_numbers = range(1, len(avg_grad_norms)+1)

    plt.figure(figsize=(6,4))
    plt.bar(fold_numbers, avg_grad_norms)
    plt.xlabel('Fold')
    plt.ylabel('平均梯度范数')
    plt.title('各Fold平均梯度范数')
    plt.grid()

    target = os.path.join(save_folder, 'avg_grad_norms.png')
    plt.savefig(target)
    plt.close()

def plot_confusion_matrix(cm, save_path=None):
    """Render a confusion matrix as an annotated heatmap.

    Sets the global matplotlib font to SimHei (and disables the Unicode
    minus sign) so that CJK text renders.  The figure is always closed;
    it is written to disk only when ``save_path`` is truthy.

    Args:
        cm: Integer confusion matrix (rows: reference, columns: predicted).
        save_path: Optional output file path.
    """
    plt.figure(figsize=(8,6))
    plt.rcParams.update({
        'font.sans-serif': ['SimHei'],
        'axes.unicode_minus': False,
    })
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
    plt.title('Confusion Matrix')
    plt.ylabel('Reference')
    plt.xlabel('Predicted')
    if not save_path:
        plt.close()
        return
    plt.savefig(save_path)
    plt.close()

def check_block_overlap(train_blocks_idx, val_blocks_idx, test_blocks_idx):
    """Verify that the train, validation and test block indices are disjoint.

    Prints a confirmation when the three index collections share no
    element.

    Raises:
        RuntimeError: If any two collections overlap; the message lists the
            shared indices for each pair.
    """
    train_ids, val_ids, test_ids = (
        set(ids) for ids in (train_blocks_idx, val_blocks_idx, test_blocks_idx)
    )

    shared_train_val = train_ids & val_ids
    shared_train_test = train_ids & test_ids
    shared_val_test = val_ids & test_ids

    if not (shared_train_val or shared_train_test or shared_val_test):
        print("[INFO] Block 重叠检查通过，训练/验证/测试 block 互斥。")
        return

    raise RuntimeError(f"[ERROR] Block 重叠检测失败！"
                       f"\nTrain-Val overlap: {shared_train_val}"
                       f"\nTrain-Test overlap: {shared_train_test}"
                       f"\nVal-Test overlap: {shared_val_test}")

# ================= Main training routine (single train/val/test run, no K-fold) =================
def train_for_snr(SNR_dB, save_folder, results_file, group_size=288):
    """Train a ResNet18_1D at one SNR with early stopping and report test accuracy.

    Steps: load the block-grouped data, split the blocks (not the samples)
    into train/validation/test with stratification, expand every block into
    individual samples, train with Adam + StepLR, keep the weights of the
    epoch with the best validation accuracy, then evaluate those weights on
    the validation and test sets and save plots and the checkpoint.

    Reads the module-level settings ``data_path``, ``apply_doppler``, ``v``,
    ``apply_awgn``, ``fs``, ``fc``, ``batch_size``, ``in_planes``,
    ``dropout``, ``learning_rate``, ``weight_decay``, ``num_epochs`` and
    ``patience``.

    Args:
        SNR_dB: SNR in dB of the noise added while loading the data.
        save_folder: Existing directory for the plots and ``best_model.pth``.
        results_file: Text file the per-epoch log lines are appended to.
        group_size: Requested number of signals per block.

    Returns:
        Test accuracy in percent of the best (by validation accuracy) model.
    """
    import copy  # local import: only needed for the weight snapshots below

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"[INFO] 使用设备: {device}")

    # 1) Load the blocks.
    X_blocks, y_blocks, label_to_idx = load_and_preprocess_with_grouping(
        data_path,
        group_size=group_size,
        apply_doppler=apply_doppler,
        target_velocity=v,
        apply_awgn=apply_awgn,
        snr_db=SNR_dB,
        fs=fs,
        fc=fc
    )
    num_classes = len(label_to_idx)

    num_blocks = X_blocks.shape[0]
    print(f"[INFO] 总 block 数: {num_blocks}")

    # 2) Block-level three-way split: train / val / test.  Splitting whole
    #    blocks keeps samples of one block out of different subsets.
    block_idx = np.arange(num_blocks)
    train_val_idx, test_idx, y_train_val, _ = train_test_split(
        block_idx,
        y_blocks,
        test_size=0.25,
        stratify=y_blocks,
        random_state=42
    )

    train_idx, val_idx, _, _ = train_test_split(
        train_val_idx,
        y_train_val,
        test_size=0.25,
        stratify=y_train_val,
        random_state=42
    )

    check_block_overlap(train_idx, val_idx, test_idx)

    # 3) Expand blocks into samples.
    def expand_blocks(Xb, yb):
        # (blocks, rows, signals, 2) -> (blocks * rows, signals, 2); every
        # row of a block inherits the block label.
        X = Xb.reshape(-1, Xb.shape[2], Xb.shape[3])
        y = np.repeat(yb, Xb.shape[1])
        return X, y

    def make_loader(indices, shuffle):
        X, y = expand_blocks(X_blocks[indices], y_blocks[indices])
        dataset = TensorDataset(torch.tensor(X, dtype=torch.float32),
                                torch.tensor(y, dtype=torch.long))
        return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)

    train_loader = make_loader(train_idx, shuffle=True)
    val_loader = make_loader(val_idx, shuffle=False)
    test_loader = make_loader(test_idx, shuffle=False)

    # 4) Model and optimiser.
    model = ResNet18_1D(
        num_classes=num_classes,
        in_planes=in_planes,
        dropout=dropout
    ).to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(
        model.parameters(),
        lr=learning_rate,
        weight_decay=weight_decay
    )
    scheduler = optim.lr_scheduler.StepLR(
        optimizer,
        step_size=10,
        gamma=0.5
    )

    best_val_acc = 0.0
    patience_counter = 0
    # state_dict() returns references to the live tensors, so a snapshot must
    # be a deep copy.  Starting from the initial weights also guarantees a
    # loadable state if validation accuracy never improves.
    best_model_wts = copy.deepcopy(model.state_dict())

    train_losses, val_losses, grad_norms = [], [], []

    # ================= Epoch loop =================
    for epoch in range(num_epochs):
        model.train()
        running_loss, correct, total = 0.0, 0, 0
        batch_grad_norms = []

        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()

            # Measured after backward and before the optimiser step.
            batch_grad_norms.append(compute_grad_norm(model))
            optimizer.step()

            running_loss += loss.item()
            _, preds = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (preds == labels).sum().item()

        train_loss = running_loss / len(train_loader)
        train_acc = 100 * correct / total
        avg_grad_norm = np.mean(batch_grad_norms)

        train_losses.append(train_loss)
        grad_norms.append(avg_grad_norm)

        # ===== Validation =====
        model.eval()
        val_loss_sum, correct, total = 0.0, 0, 0

        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, labels)

                val_loss_sum += loss.item()
                _, preds = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (preds == labels).sum().item()

        val_loss = val_loss_sum / len(val_loader)
        val_acc = 100 * correct / total
        val_losses.append(val_loss)

        log = (f"Epoch {epoch+1}: "
               f"Train Loss={train_loss:.4f}, Train Acc={train_acc:.2f}%, "
               f"Val Loss={val_loss:.4f}, Val Acc={val_acc:.2f}%, "
               f"Grad Norm={avg_grad_norm:.4f}")
        print(log)
        with open(results_file, "a") as f:
            f.write(log + "\n")

        # Early stopping on validation accuracy (0.01 point minimum gain).
        if val_acc > best_val_acc + 0.01:
            best_val_acc = val_acc
            patience_counter = 0
            best_model_wts = copy.deepcopy(model.state_dict())
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print("[INFO] Early stopping triggered")
                break

        scheduler.step()

    # ================= Evaluation of the best model =================
    model.load_state_dict(best_model_wts)
    # The validation confusion matrix is computed from the restored best
    # weights so that it matches the checkpoint saved below.
    _, val_cm = evaluate_model(model, val_loader, device, num_classes)
    test_acc, test_cm = evaluate_model(model, test_loader, device, num_classes)

    print(f"\n Test Acc: {test_acc:.2f}% ")
    with open(results_file, "a") as f:
        f.write(f"\n Test Acc: {test_acc:.2f}% ")

    plot_confusion_matrix(val_cm, os.path.join(save_folder, "confusion_matrix_val.png"))
    plot_confusion_matrix(test_cm, os.path.join(save_folder, "confusion_matrix_test.png"))
    torch.save(best_model_wts, os.path.join(save_folder, "best_model.pth"))

    plot_training_curves(
        [{'train_loss': train_losses, 'val_loss': val_losses}],
        save_folder
    )
    plot_grad_norms([np.mean(grad_norms)], save_folder)

    return test_acc

# ================= Fixed-SNR wrapper =================
# train_for_snr itself is unchanged; this wrapper only pins the SNR to 20 dB.
def train_for_block_size(block_size, save_folder, results_file):
    """Run one training at a fixed SNR of 20 dB for the given block size.

    Args:
        block_size: Forwarded to ``train_for_snr`` as ``group_size``.
        save_folder: Output directory for plots and the checkpoint.
        results_file: Text file the training log is appended to.

    Returns:
        The test accuracy returned by ``train_for_snr``.
    """
    return train_for_snr(
        20,  # fixed SNR in dB
        save_folder,
        results_file,
        group_size=block_size,
    )

# ================= Training sweep over block_size =================
if __name__ == "__main__":
    # Sweep block sizes 8, 24, 40, ..., 568 (step 16; 576 itself is excluded).
    block_sizes = list(range(8, 576, 16))
    block_accs = []

    # Values that do not change between runs.
    results_root = os.path.join(os.getcwd(), "training_results")
    os.makedirs(results_root, exist_ok=True)  # the summary plot is saved here
    script_name = "LTE-V_XFR_Block"
    num_classes_tag = 9  # hard-coded class count, used only in the folder name
    doppler_tag = int(compute_doppler_shift(v, fc))

    for group_size in block_sizes:
        print(f"\n\n================== 当前实验 block_size={group_size} ==================\n")
        timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
        folder_name = f"{timestamp}_{script_name}_block{group_size}_SNR20dB_fd{doppler_tag}_classes_{num_classes_tag}_ResNet"
        save_folder = os.path.join(results_root, folder_name)
        os.makedirs(save_folder, exist_ok=True)
        results_file = os.path.join(save_folder, "results.txt")
        with open(results_file, "a") as f:
            f.write(f"\n================ block_size={group_size} =================\n")
        test_acc = train_for_block_size(group_size, save_folder, results_file)
        block_accs.append(test_acc)
        print(f"block_size {group_size} → results in: {save_folder}")

    # Plot test accuracy against block size.
    plt.figure(figsize=(8,5))
    plt.plot(block_sizes, block_accs, marker='o', linestyle='-', color='b')
    plt.xlabel("Block Size (m)")
    plt.ylabel("Accuracy of test data (%)")
    plt.title("Block Size vs Accuracy of test data")
    plt.grid(True)
    block_curve_path = os.path.join(results_root, "BlockSize_vs_accuracy.png")
    plt.savefig(block_curve_path)
    plt.show()
    print(f"[INFO] Block Size vs 测试准确率曲线已保存到 {block_curve_path}")