# In [None]:  (Jupyter notebook cell marker left over from the export)
import os
import glob
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset, Subset
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import KFold
import matplotlib.pyplot as plt
import seaborn as sns
from datetime import datetime
from tqdm import tqdm
from torch.nn import TransformerEncoder, TransformerEncoderLayer
import scipy.io as sio
import h5py  # 用于v7.3格式.mat文件读取
import sys

# ================= Configuration =================
# Data location
data_path = "E:/rf_datasets/"  # folder that holds all of the .mat files

# Signal parameters
SNR_dB = 10           # signal-to-noise ratio in dB, used when AWGN is enabled
fs = 20e6             # sampling rate, Hz
fc = 2.4e9            # carrier frequency, Hz
v = 120               # target speed in m/s (used for the Doppler computation)

# Preprocessing switches
apply_doppler = False
apply_awgn = False

# Model hyperparameters
raw_input_dim = 2       # input feature size (I/Q)
model_dim = 256         # Transformer feature dimension
num_heads = 4           # number of attention heads
num_layers = 2          # number of Transformer layers
dropout = 0.4           # dropout probability

# Training parameters
batch_size = 256
num_epochs = 100
learning_rate = 1e-4
patience = 5            # early-stopping patience (epochs without improvement)
n_splits = 5            # number of folds for K-fold cross-validation
weight_decay = 1e-4

# ================= Doppler and AWGN helpers =================
def compute_doppler_shift(v, fc):
    """Return the Doppler frequency shift ``(v / c) * fc``.

    Args:
        v: Speed of the target. The speed of light is taken as 3e8, so
            ``v`` is expected in m/s.
        fc: Carrier frequency; the result is expressed in the same unit.

    Returns:
        The frequency offset caused by the relative motion.
    """
    speed_of_light = 3e8
    velocity_ratio = v / speed_of_light
    return velocity_ratio * fc

def apply_doppler_shift(signal, fd, fs):
    """Rotate ``signal`` by a complex exponential of frequency ``fd``.

    The time axis is built from the length of the LAST axis of ``signal``,
    so any leading axes are handled by broadcasting.

    Args:
        signal: Array of samples; time runs along the last axis.
        fd: Frequency shift to apply (same unit as ``fs``).
        fs: Sampling rate used to convert sample indices to time.

    Returns:
        ``signal`` multiplied element-wise by ``exp(j * 2 * pi * fd * t)``.
    """
    num_samples = signal.shape[-1]
    time_axis = np.arange(num_samples) / fs
    rotation = np.exp(1j * 2 * np.pi * fd * time_axis)
    return signal * rotation

def add_awgn(signal, snr_db):
    """Add complex white Gaussian noise at the requested SNR.

    The noise power is derived from the mean power of ``signal`` itself and
    split evenly between the real and imaginary components. Random numbers
    come from NumPy's global generator (``np.random.randn``).

    Args:
        signal: Array of (complex) samples.
        snr_db: Desired signal-to-noise ratio in dB.

    Returns:
        ``signal`` plus noise, with the same shape as ``signal``.
    """
    mean_power = np.mean(np.abs(signal) ** 2)
    snr_linear = 10 ** (snr_db / 10)
    noise_power = mean_power / snr_linear

    shape = signal.shape
    # Real part is drawn before the imaginary part; keep this order so a
    # seeded run produces the same noise realisation.
    in_phase = np.random.randn(*shape)
    quadrature = np.random.randn(*shape)

    noise = np.sqrt(noise_power / 2) * (in_phase + 1j * quadrature)
    return signal + noise

# ================= Data loading and preprocessing =================
def load_and_preprocess(mat_folder, apply_doppler=False, target_velocity=30, apply_awgn=False, snr_db=20, fs=20e6, fc=2.4e9):
    """Load every ``*.mat`` file in a folder and build an I/Q dataset.

    Each file is opened with ``h5py`` and must contain a group ``rfDataset``
    with two members: ``dmrs`` (a compound array with ``real``/``imag``
    fields) and ``txID`` (character codes of the transmitter ID). Every row
    of ``dmrs`` becomes one sample labelled with the file's transmitter ID.
    NOTE(review): the original author's notes give ``dmrs`` as (2999, 288)
    per file -- confirm against the actual data.

    Args:
        mat_folder: Directory that is searched (non-recursively) for ``*.mat``.
        apply_doppler: If true, apply a Doppler phase ramp to every signal.
        target_velocity: Speed used to compute the Doppler shift.
        apply_awgn: If true, add white Gaussian noise to every signal.
        snr_db: SNR in dB for the added noise.
        fs: Sampling rate used for the Doppler phase ramp.
        fc: Carrier frequency used to compute the Doppler shift.

    Returns:
        Tuple ``(X_all, y_all, label_to_idx)``:
        ``X_all`` stacks all samples with real/imag split on a new last axis,
        ``y_all`` holds one integer class index per sample, and
        ``label_to_idx`` maps each transmitter ID string to its class index.

    Raises:
        ValueError: If the folder contains no ``*.mat`` file, or if the
            per-file arrays cannot be concatenated (mismatched signal length).
    """
    # glob order is filesystem dependent; sorting keeps the sample order (and
    # therefore any seeded split done by the caller) reproducible.
    mat_files = sorted(glob.glob(os.path.join(mat_folder, '*.mat')))
    print(f"共找到 {len(mat_files)} 个 .mat 文件")
    if not mat_files:
        raise ValueError(f"No .mat files found in {mat_folder!r}")

    X_list = []
    y_list = []
    samples_per_file = []

    fd = compute_doppler_shift(target_velocity, fc)
    print(f"目标速度 {target_velocity} m/s，对应多普勒频移 {fd:.2f} Hz")

    for mat_file in tqdm(mat_files, desc='读取与处理数据'):
        with h5py.File(mat_file, 'r') as f:
            rf_dataset = f['rfDataset']

            # Rebuild the complex matrix from the compound real/imag fields.
            dmrs_struct = rf_dataset['dmrs'][:]
            signals = dmrs_struct['real'] + 1j * dmrs_struct['imag']

            # txID is stored as character codes; zeros are skipped.
            tx_codes = rf_dataset['txID'][:].flatten()
            tx_id = ''.join(chr(int(c)) for c in tx_codes if c != 0)

        if apply_doppler:
            # The phase ramp depends only on the last axis, so one broadcast
            # call is equivalent to shifting each row separately.
            signals = apply_doppler_shift(signals, fd, fs)
        if apply_awgn and signals.shape[0] > 0:
            # Stay per-row so each signal's noise is scaled to its own power.
            signals = np.stack([add_awgn(row, snr_db) for row in signals])

        # Split into real / imaginary channels on a new trailing axis.
        iq = np.stack((signals.real, signals.imag), axis=-1)

        X_list.append(iq)
        y_list.append(tx_id)
        samples_per_file.append(iq.shape[0])

    # Map transmitter IDs to contiguous class indices (sorted for stability).
    unique_labels = sorted(set(y_list))
    label_to_idx = {lab: i for i, lab in enumerate(unique_labels)}
    y_idx = np.array([label_to_idx[lab] for lab in y_list])

    X_all = np.concatenate(X_list, axis=0)
    # Repeat each file's label by THAT file's sample count; using a single
    # count for all files misaligns labels when files differ in size.
    y_all = np.repeat(y_idx, samples_per_file)

    print(f"数据维度: X={X_all.shape}, y={y_all.shape}")
    print(f"类别映射: {label_to_idx}")

    return X_all, y_all, label_to_idx



# ================= Model definition =================
class SignalTransformer(nn.Module):
    """Transformer-encoder classifier for sequences of raw feature vectors.

    Each time step is projected linearly to ``model_dim``, the sequence is
    passed through a stack of ``TransformerEncoderLayer`` modules
    (``batch_first=True``), and the encoder output at the last time step is
    mapped to class logits.

    Args:
        raw_input_dim: Feature size of each input time step.
        model_dim: Width of the Transformer (``d_model``).
        num_heads: Number of attention heads per layer.
        num_layers: Number of stacked encoder layers.
        num_classes: Number of output logits.
        dropout: Dropout probability inside the encoder layers.
    """

    def __init__(self, raw_input_dim, model_dim, num_heads, num_layers, num_classes, dropout=0.4):
        super().__init__()
        # Per-time-step projection from the raw features to the model width.
        self.embedding = nn.Linear(raw_input_dim, model_dim)
        self.encoder = TransformerEncoder(
            TransformerEncoderLayer(
                d_model=model_dim,
                nhead=num_heads,
                dropout=dropout,
                batch_first=True,
            ),
            num_layers=num_layers,
        )
        # Classification head applied to a single time step.
        self.fc = nn.Linear(model_dim, num_classes)

    def forward(self, x):
        """Return class logits for a batch ``x`` laid out as (batch, time, feature)."""
        encoded = self.encoder(self.embedding(x))
        last_step = encoded[:, -1, :]  # keep only the final time step
        return self.fc(last_step)

# ================= Training helpers =================
def compute_grad_norm(model):
    """Return the global L2 norm of all parameter gradients of ``model``.

    Parameters whose ``.grad`` is ``None`` are skipped; if no parameter has
    a gradient the result is ``0.0``.
    """
    squared_sum = 0.0
    for param in model.parameters():
        grad = param.grad
        if grad is None:
            continue
        squared_sum += grad.data.norm(2).item() ** 2
    return squared_sum ** 0.5

def moving_average(x, w=5):
    """Return the simple moving average of ``x`` with window ``w``.

    Only fully covered windows are produced, so for ``len(x) >= w`` the
    result has ``len(x) - w + 1`` entries.

    When the series is shorter than the window, the window is clamped to
    ``len(x)`` and a single value (the plain mean) is returned. Without the
    clamp ``np.convolve(..., 'valid')`` swaps its operands and yields
    ``len(x)`` copies of ``sum(x) / w``, which is not a moving average.

    Args:
        x: 1-D sequence of numbers.
        w: Window length, at least 1.

    Returns:
        1-D ``numpy`` array of averaged values; empty if ``x`` is empty.

    Raises:
        ValueError: If ``w`` is smaller than 1.
    """
    if w < 1:
        raise ValueError("window size w must be at least 1")

    values = np.asarray(x)
    if values.size == 0:
        return np.empty(0, dtype=float)

    window = min(int(w), values.size)
    return np.convolve(values, np.ones(window), 'valid') / window

# ================= Main training pipeline =================
def _snapshot_state_dict(model):
    """Return a deep copy of ``model.state_dict()``.

    ``state_dict()`` returns references to the live parameter tensors, so
    without a copy a stored "best" snapshot keeps changing as training
    continues.
    """
    import copy  # local import keeps this block self-contained

    return copy.deepcopy(model.state_dict())


def _train_one_epoch(model, loader, criterion, optimizer, device, description):
    """Run one optimisation pass; return (mean batch loss, accuracy %, mean grad norm)."""
    model.train()
    loss_sum, n_correct, n_seen = 0.0, 0, 0
    batch_grad_norms = []

    with tqdm(loader, desc=description, unit="batch") as progress:
        for batch_no, (inputs, labels) in enumerate(progress, start=1):
            inputs = inputs.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()

            # Measure the gradient after backward and before the update.
            grad_norm = compute_grad_norm(model)
            batch_grad_norms.append(grad_norm)

            optimizer.step()

            loss_sum += loss.item()
            _, predicted = torch.max(outputs, 1)
            n_seen += labels.size(0)
            n_correct += (predicted == labels).sum().item()

            # Running mean over the batches seen so far (not over the whole
            # epoch length, which would under-report until the last batch).
            progress.set_postfix(loss=loss_sum / batch_no,
                                 accuracy=100 * n_correct / n_seen,
                                 grad_norm=grad_norm)

    if n_seen == 0:
        raise ValueError("training loader produced no batches; "
                         "the training split is smaller than one batch")

    return loss_sum / len(loader), 100 * n_correct / n_seen, np.mean(batch_grad_norms)


def _validate(model, loader, criterion, device):
    """Evaluate without gradient tracking; return (mean per-sample loss, accuracy %)."""
    model.eval()
    loss_sum, n_correct, n_seen = 0.0, 0, 0

    with torch.no_grad():
        for inputs, labels in loader:
            inputs = inputs.to(device)
            labels = labels.to(device)

            outputs = model(inputs)
            n_batch = labels.size(0)
            # criterion returns the batch mean; weight it by the batch size so
            # a smaller final batch does not skew the epoch average.
            loss_sum += criterion(outputs, labels).item() * n_batch
            _, predicted = torch.max(outputs, 1)
            n_seen += n_batch
            n_correct += (predicted == labels).sum().item()

    if n_seen == 0:
        raise ValueError("validation loader produced no samples")

    return loss_sum / n_seen, 100 * n_correct / n_seen


def main():
    """Train and evaluate the Transformer classifier with K-fold cross-validation.

    Steps: load the dataset, create a timestamped result folder under
    ``./training_results``, then for every fold train with early stopping on
    the validation loss, restore the best weights, evaluate them on the
    fold's validation split, save the checkpoint and append a summary to
    ``results.txt``. Finally the loss curves and per-fold gradient norms are
    plotted.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"[INFO] 使用设备: {device}")

    # Load and preprocess the data
    X_all, y_all, label_to_idx = load_and_preprocess(
        data_path,
        apply_doppler=apply_doppler,
        target_velocity=v,
        apply_awgn=apply_awgn,
        snr_db=SNR_dB,
        fs=fs,
        fc=fc
    )
    num_classes = len(label_to_idx)
    fd = compute_doppler_shift(v, fc)

    X_tensor = torch.tensor(X_all, dtype=torch.float32)
    y_tensor = torch.tensor(y_all, dtype=torch.long)
    full_dataset = TensorDataset(X_tensor, y_tensor)

    try:
        code_filename = os.path.basename(__file__)
    except NameError:
        # __file__ is undefined in interactive sessions; fall back to argv[0].
        code_filename = os.path.basename(sys.argv[0])

    code_name = os.path.splitext(code_filename)[0]

    # Create the folder that receives all results of this run
    timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    folder_name = f"{code_name}_{timestamp}_SNR{SNR_dB}dB_fd{int(fd)}_classes{num_classes}_Transformer"
    save_folder = os.path.join(os.getcwd(), "training_results", folder_name)
    os.makedirs(save_folder, exist_ok=True)
    results_file = os.path.join(save_folder, "results.txt")

    # Explicit UTF-8: the report contains non-ASCII text and must not depend
    # on the platform's default encoding.
    with open(results_file, "w", encoding="utf-8") as f:
        f.write(f"=== 实验总结 ===\n")
        f.write(f"时间戳: {timestamp}\n")
        f.write(f"类别数: {num_classes}\n")
        f.write(f"SNR: {SNR_dB} dB\n")
        f.write(f"多普勒频移 fd: {fd:.2f} Hz\n")

    kfold = KFold(n_splits=n_splits, shuffle=True, random_state=42)
    fold_results = []
    test_results = []
    avg_grad_norms_per_fold = []

    for fold, (train_idx, val_idx) in enumerate(kfold.split(full_dataset)):
        print(f"\n====== Fold {fold+1}/{n_splits} ======")

        train_subset = Subset(full_dataset, train_idx)
        val_subset = Subset(full_dataset, val_idx)

        train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True, drop_last=True)
        # Keep the last partial batch: every validation sample must count, and
        # a split smaller than one batch must not yield an empty loader.
        val_loader = DataLoader(val_subset, batch_size=batch_size, shuffle=False, drop_last=False)

        model = SignalTransformer(raw_input_dim, model_dim, num_heads, num_layers, num_classes, dropout).to(device)
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
        scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)

        train_losses, val_losses = [], []
        train_accuracies, val_accuracies = [], []
        grad_norms = []

        best_val_loss = float('inf')
        # Start from the initial weights so a checkpoint always exists, even
        # if the validation loss never improves (e.g. it is NaN).
        best_model_wts = _snapshot_state_dict(model)
        patience_counter = 0

        for epoch in range(num_epochs):
            epoch_train_loss, epoch_train_acc, avg_grad_norm = _train_one_epoch(
                model, train_loader, criterion, optimizer, device,
                f"Epoch {epoch+1}/{num_epochs}",
            )
            train_losses.append(epoch_train_loss)
            train_accuracies.append(epoch_train_acc)
            grad_norms.append(avg_grad_norm)

            print(f"Epoch {epoch+1} 平均梯度范数: {avg_grad_norm:.4f}")

            # Validation phase
            epoch_val_loss, epoch_val_acc = _validate(model, val_loader, criterion, device)
            val_losses.append(epoch_val_loss)
            val_accuracies.append(epoch_val_acc)

            print(f"Epoch {epoch+1} 验证 Loss: {epoch_val_loss:.4f}  Acc: {epoch_val_acc:.2f}%")

            # Early stopping on the validation loss
            if epoch_val_loss < best_val_loss:
                best_val_loss = epoch_val_loss
                patience_counter = 0
                # Copy, do not alias: later optimizer steps must not alter it.
                best_model_wts = _snapshot_state_dict(model)
            else:
                patience_counter += 1
                if patience_counter >= patience:
                    print(f"早停，连续{patience}个epoch验证集loss未下降。")
                    break

            scheduler.step()

        fold_results.append({
            'train_loss': train_losses,
            'val_loss': val_losses,
            'train_acc': train_accuracies,
            'val_acc': val_accuracies,
            'grad_norms': grad_norms
        })
        avg_grad_norms_per_fold.append(np.mean(grad_norms))

        # Evaluate the BEST weights (the ones that get saved), not whatever
        # the last epoch left behind. The validation split doubles as the
        # test set here.
        model.load_state_dict(best_model_wts)
        test_acc, _ = evaluate_model(model, val_loader, device, num_classes)
        test_results.append(test_acc)

        # Save the best weights of this fold
        torch.save(best_model_wts, os.path.join(save_folder, f"best_model_fold{fold+1}.pth"))

        # Append the fold summary to the report
        with open(results_file, "a", encoding="utf-8") as f:
            f.write(f"\nFold {fold+1} 训练结束\n")
            f.write(f"最佳验证集Loss: {best_val_loss:.4f}\n")
            f.write(f"测试集准确率: {test_acc:.2f}%\n")

    print("\n====== 所有折训练完成 ======")
    print(f"平均测试准确率: {np.mean(test_results):.2f}% ± {np.std(test_results):.2f}%")

    with open(results_file, "a", encoding="utf-8") as f:
        f.write(f"\n所有折平均测试准确率: {np.mean(test_results):.2f}% ± {np.std(test_results):.2f}%\n")

    # Plots
    plot_training_curves(fold_results, save_folder)
    plot_grad_norms(avg_grad_norms_per_fold, save_folder)

def evaluate_model(model, dataloader, device, num_classes, save_folder=None):
    """Compute accuracy and the confusion matrix of ``model`` on ``dataloader``.

    The model is switched to eval mode and run without gradient tracking.
    The confusion matrix is also drawn via ``plot_confusion_matrix``.

    Args:
        model: Module returning class logits of shape (batch, num_classes).
        dataloader: Iterable yielding ``(inputs, labels)`` batches.
        device: Device the batches are moved to before the forward pass.
        num_classes: Number of classes; fixes the confusion-matrix size so
            classes absent from this split still get a row/column.
        save_folder: Optional directory forwarded to
            ``plot_confusion_matrix``. With the default ``None`` the matrix
            is only displayed, not saved.

    Returns:
        Tuple ``(accuracy_percent, confusion_matrix)``.

    Raises:
        ValueError: If ``dataloader`` yields no samples.
    """
    model.eval()
    correct, total = 0, 0
    all_labels, all_preds = [], []
    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
            all_labels.extend(labels.cpu().numpy())
            all_preds.extend(preds.cpu().numpy())

    if total == 0:
        # Typically a split smaller than one batch combined with drop_last=True.
        raise ValueError("evaluate_model received a dataloader that yields no samples")

    acc = 100 * correct / total
    cm = confusion_matrix(all_labels, all_preds, labels=list(range(num_classes)))
    print(f"测试集准确率: {acc:.2f}%")
    plot_confusion_matrix(cm, save_folder=save_folder)
    return acc, cm

def plot_training_curves(fold_results, save_folder):
    """Plot smoothed train/validation loss of every fold and save the figure.

    Args:
        fold_results: Sequence of dicts with ``'train_loss'`` and
            ``'val_loss'`` lists (one dict per fold). Each curve is smoothed
            with ``moving_average`` before plotting.
        save_folder: Directory in which ``loss_curves.png`` is written.
    """
    fig, ax = plt.subplots(figsize=(12, 5))

    for fold_no, history in enumerate(fold_results, start=1):
        smoothed_train = moving_average(history['train_loss'])
        smoothed_val = moving_average(history['val_loss'])
        ax.plot(smoothed_train, label=f'Fold{fold_no} Train Loss')
        ax.plot(smoothed_val, label=f'Fold{fold_no} Val Loss')

    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss')
    ax.set_title('训练和验证Loss曲线')
    ax.legend()
    ax.grid()

    output_path = os.path.join(save_folder, 'loss_curves.png')
    fig.savefig(output_path)
    plt.show()

def plot_grad_norms(avg_grad_norms, save_folder):
    """Draw a bar chart of the mean gradient norm per fold and save it.

    Args:
        avg_grad_norms: One value per fold; bars are numbered from 1.
        save_folder: Directory in which ``avg_grad_norms.png`` is written.
    """
    fig, ax = plt.subplots(figsize=(6, 4))

    fold_numbers = range(1, len(avg_grad_norms) + 1)
    ax.bar(fold_numbers, avg_grad_norms)

    ax.set_xlabel('Fold')
    ax.set_ylabel('平均梯度范数')
    ax.set_title('各Fold平均梯度范数')
    ax.grid()

    output_path = os.path.join(save_folder, 'avg_grad_norms.png')
    fig.savefig(output_path)
    plt.show()

def plot_confusion_matrix(cm, save_folder=None):
    """Show a confusion matrix as an annotated heat map, optionally saving it.

    Note that this changes two global matplotlib rcParams (font family list
    and minus-sign handling) so the Chinese labels render.

    Args:
        cm: Square matrix of integer counts (rows are true classes, columns
            are predicted classes, per the axis labels set below).
        save_folder: If truthy, ``confusion_matrix.png`` is written there.
    """
    fig = plt.figure(figsize=(8, 6))

    plt.rcParams.update({
        'font.sans-serif': ['SimHei'],   # font that contains CJK glyphs
        'axes.unicode_minus': False,     # render '-' correctly with that font
    })

    ax = sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
    ax.set_title('混淆矩阵')
    ax.set_ylabel('真实类别')
    ax.set_xlabel('预测类别')

    if save_folder:
        fig.savefig(os.path.join(save_folder, 'confusion_matrix.png'))
    plt.show()


# Run the full training pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
