# In [None]:
import os
import glob
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset, Subset
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import KFold
import matplotlib.pyplot as plt
import seaborn as sns
from datetime import datetime
from tqdm import tqdm
from torch.nn import TransformerEncoder, TransformerEncoderLayer
import scipy.io as sio
from  data_utilities import *
import h5py  # 用于v7.3格式.mat文件读取
import sys
from sklearn.model_selection import train_test_split


# ================= Configuration =================
# ---- Data location ----
data_path = "E:/rf_datasets/"  # directory that holds every .mat recording

# ---- Signal / channel parameters ----
SNR_dB = 10             # SNR in dB for AWGN; also embedded in the results folder name
fs = 20_000_000.0       # sampling rate, Hz (== 20e6)
fc = 2_400_000_000.0    # carrier frequency, Hz (== 2.4e9)
v = 120                 # target speed, m/s, used to derive the Doppler shift

# ---- Preprocessing switches ----
apply_doppler = False   # rotate each sequence by the Doppler shift derived from v and fc
apply_awgn = False      # add white Gaussian noise at SNR_dB

# ---- Model hyper-parameters ----
raw_input_dim = 2       # features per time step: I and Q
model_dim = 64          # Transformer embedding width
num_heads = 4           # attention heads per encoder layer
num_layers = 1          # number of stacked encoder layers
dropout = 0.1           # dropout probability inside the encoder

# ---- Training hyper-parameters ----
batch_size = 128
num_epochs = 300
learning_rate = 1e-4
patience = 5            # early-stopping patience, in epochs without val-acc improvement
n_splits = 5            # number of K-fold cross-validation folds
weight_decay = 1e-4

# ================= 多普勒和AWGN处理函数 =================
def compute_doppler_shift(v, fc):
    """Return the Doppler frequency shift (Hz) for speed ``v`` (m/s) at carrier ``fc`` (Hz).

    Uses fd = (v / c) * fc with the speed of light rounded to 3e8 m/s.
    """
    speed_of_light = 3e8
    ratio = v / speed_of_light
    return ratio * fc

def apply_doppler_shift(signal, fd, fs):
    """Rotate ``signal`` by a constant frequency offset ``fd`` (Hz).

    Sample times are ``n / fs`` for n along the last axis; the phase ramp
    exp(j*2*pi*fd*t) broadcasts against any leading axes of ``signal``.
    """
    times = np.arange(signal.shape[-1]) / fs
    return signal * np.exp(1j * 2 * np.pi * fd * times)

def add_awgn(signal, snr_db):
    """Return ``signal`` plus complex white Gaussian noise at ``snr_db`` dB.

    The noise power is derived from the mean power of the whole input array
    and split equally between the real and imaginary components. Draws from
    the global NumPy RNG (real part first, then imaginary part).
    """
    snr_linear = 10 ** (snr_db / 10)
    noise_power = np.mean(np.abs(signal) ** 2) / snr_linear
    shape = signal.shape
    real_draw = np.random.randn(*shape)
    imag_draw = np.random.randn(*shape)
    return signal + np.sqrt(noise_power / 2) * (real_draw + 1j * imag_draw)

# ================= 数据加载与预处理 =================
def _read_mat_file(file, fd, apply_doppler, apply_awgn, snr_db, fs):
    """Read one HDF5 (MATLAB v7.3) file and return ``(iq, tx_id)``.

    ``iq`` has shape (num_rows, row_len, 2): the real/imag parts of every row
    of ``rfDataset/dmrs`` after the optional Doppler rotation and AWGN.
    ``tx_id`` is ``rfDataset/txID`` decoded from character codes to ``str``.
    """
    with h5py.File(file, 'r') as f:
        rf_dataset = f['rfDataset']
        # dmrs is read as a compound array with 'real' / 'imag' fields.
        dmrs_struct = rf_dataset['dmrs'][:]
        tx_codes = rf_dataset['txID'][:].flatten()
    dmrs_complex = dmrs_struct['real'] + 1j * dmrs_struct['imag']
    # txID holds character codes; zero entries are padding and are dropped.
    tx_id = ''.join(chr(c) for c in tx_codes if c != 0)

    rows = []
    for sig in dmrs_complex:  # one row (axis 0 entry) at a time
        if apply_doppler:
            sig = apply_doppler_shift(sig, fd, fs)
        if apply_awgn:
            # Noise power is set from this row's own mean power.
            sig = add_awgn(sig, snr_db)
        rows.append(np.stack((sig.real, sig.imag), axis=-1))
    return np.array(rows), tx_id


def _group_class_rows(class_arrays, group_size, label):
    """Build the transposed groups for one class; return [] if no full group fits."""
    num_files = len(class_arrays)
    samples_per_file = group_size // num_files  # rows taken from each file per group
    if samples_per_file == 0:
        print(f"[WARN] 类别 {label} 文件数量过多，导致每文件样本数为0，跳过该类别")
        return []

    group_width = num_files * samples_per_file
    if group_width != group_size:
        # Groups of different classes would then have different widths and
        # could not be concatenated together later.
        print(f"[WARN] 类别 {label}: {num_files} 个文件无法整除 group_size={group_size}，"
              f"每组实际只有 {group_width} 条序列")

    # Number of complete groups is limited by the shortest file of the class.
    max_groups = min(arr.shape[0] for arr in class_arrays) // samples_per_file
    if max_groups == 0:
        print(f"[WARN] 类别 {label} 样本不足，跳过")
        return []

    blocks = []
    for group_i in range(max_groups):
        start = group_i * samples_per_file
        end = start + samples_per_file
        # (group_width, row_len, 2): consecutive rows from every file of the class.
        big_block = np.concatenate([arr[start:end] for arr in class_arrays], axis=0)
        # Swap the first two axes -> (row_len, group_width, 2).
        blocks.append(np.transpose(big_block, (1, 0, 2)))
    return blocks


def load_and_preprocess_with_grouping(
    mat_folder,
    group_size=288,
    apply_doppler=False,
    target_velocity=30,
    apply_awgn=False,
    snr_db=20,
    fs=20e6,
    fc=2.4e9,
):
    """Load every ``*.mat`` file in ``mat_folder`` and regroup it into labelled samples.

    Each file is labelled by its decoded ``txID``; class indices follow the
    sorted ID order. For every class, each of its files contributes
    ``group_size // num_files`` consecutive rows per group. The stacked rows
    are transposed with ``(1, 0, 2)``, so one output sample collects a single
    position along the dmrs row axis across all stacked rows.

    NOTE(review): after the transpose a "sample" spans rows from different
    files of the same class rather than one recording -- confirm intended.

    Args:
        mat_folder: directory searched (non-recursively) for ``*.mat`` files.
        group_size: target number of stacked rows per group.
        apply_doppler / target_velocity / fc / fs: optional Doppler rotation
            with fd = compute_doppler_shift(target_velocity, fc).
        apply_awgn / snr_db: optional AWGN at ``snr_db`` dB.

    Returns:
        X_all: real array (N, group_width, 2) of I/Q values.
        y_all: int64 array (N,), one class index per row of ``X_all``.
        label_to_idx: dict mapping txID string -> class index.

    Raises:
        ValueError: if no class yields at least one complete group.
    """
    mat_files = glob.glob(os.path.join(mat_folder, '*.mat'))
    print(f"共找到 {len(mat_files)} 个 .mat 文件")

    # Computed unconditionally so it can be reported, used only if apply_doppler.
    fd = compute_doppler_shift(target_velocity, fc)
    print(f"目标速度 {target_velocity} m/s，对应多普勒频移 {fd:.2f} Hz")

    X_files, y_files = [], []
    for file in tqdm(mat_files, desc='读取数据'):
        iq, tx_id = _read_mat_file(file, fd, apply_doppler, apply_awgn, snr_db, fs)
        X_files.append(iq)
        y_files.append(tx_id)

    label_list = sorted(set(y_files))
    label_to_idx = {label: i for i, label in enumerate(label_list)}

    X_all_list, y_all_list = [], []
    for label in label_list:
        class_arrays = [x for x, y in zip(X_files, y_files) if y == label]
        for block in _group_class_rows(class_arrays, group_size, label):
            X_all_list.append(block)
            # One label per row actually produced (row_len), not group_size,
            # so X and y stay aligned when the two differ.
            y_all_list.append(np.full(block.shape[0], label_to_idx[label], dtype=np.int64))

    if not X_all_list:
        raise ValueError(
            f"No complete group could be built from {len(mat_files)} .mat file(s) in "
            f"{mat_folder!r} with group_size={group_size}"
        )

    X_all = np.concatenate(X_all_list, axis=0)
    y_all = np.concatenate(y_all_list, axis=0)

    print(f"[INFO] 处理后样本数: {X_all.shape[0]}, 每样本长度: {X_all.shape[1]}")

    return X_all, y_all, label_to_idx

def split_by_rx_id(X_all, y_all, rx_id_all, test_rx_id_str):
    """Hold out every sample whose receiver ID equals ``test_rx_id_str``.

    ``rx_id_all`` is compared element-wise with ``test_rx_id_str``. The same
    boolean mask indexes ``X_all`` and ``y_all``, so all three arrays must
    share their first dimension.

    Returns:
        (X_trainval, y_trainval, X_test, y_test)
    """
    is_test = rx_id_all == test_rx_id_str
    is_trainval = ~is_test

    X_test, y_test = X_all[is_test], y_all[is_test]
    X_trainval, y_trainval = X_all[is_trainval], y_all[is_trainval]

    print(f"[INFO] 按接收机划分测试集 rxID={test_rx_id_str}")
    print(f"训练+验证集样本数: {X_trainval.shape[0]}, 测试集样本数: {X_test.shape[0]}")

    return X_trainval, y_trainval, X_test, y_test


# ================= 模型定义 =================
class SignalTransformer(nn.Module):
    """Transformer-encoder classifier for I/Q sequences.

    Pipeline: per-time-step linear embedding -> ``num_layers`` batch-first
    encoder layers -> output at the last sequence position -> linear head
    producing ``num_classes`` logits.

    NOTE(review): no positional encoding is added, so the encoder receives no
    explicit order information -- confirm this is intended.
    """

    def __init__(self, raw_input_dim, model_dim, num_heads, num_layers, num_classes, dropout=0.4):
        super().__init__()
        # Sub-modules are created in the same order as the parameters are
        # registered (embedding, encoder, head), keeping seeded init stable.
        self.embedding = nn.Linear(raw_input_dim, model_dim)
        self.encoder = TransformerEncoder(
            TransformerEncoderLayer(
                d_model=model_dim,
                nhead=num_heads,
                dropout=dropout,
                batch_first=True,
            ),
            num_layers=num_layers,
        )
        self.fc = nn.Linear(model_dim, num_classes)

    def forward(self, x):
        """Map ``x`` of shape (batch, seq_len, raw_input_dim) to (batch, num_classes) logits."""
        encoded = self.encoder(self.embedding(x))
        last_step = encoded[:, -1, :]
        return self.fc(last_step)

# ================= 训练辅助 =================
def compute_grad_norm(model):
    """Return the global L2 norm over all parameters that currently have a gradient.

    Parameters whose ``.grad`` is ``None`` are ignored; with no gradients at
    all the result is ``0.0``. The value is a plain Python float.
    """
    squared_norms = [
        p.grad.data.norm(2).item() ** 2
        for p in model.parameters()
        if p.grad is not None
    ]
    return sum(squared_norms, 0.0) ** 0.5

def moving_average(x, w=5):
    """Return the simple moving average of the 1-D series ``x`` over windows of ``w``.

    Output element i is ``mean(x[i:i+w])`` (numpy 'valid' convolution), so the
    result has ``len(x) - w + 1`` elements. The window is clamped to
    ``[1, len(x)]``: with the unclamped window, numpy's 'valid' mode silently
    returns ``sum(x) / w`` values for series shorter than ``w`` (e.g. a fold
    that early-stopped quickly). An empty series yields an empty array
    instead of raising.
    """
    values = np.asarray(x, dtype=float)
    if values.size == 0:
        return values
    window = max(1, min(int(w), values.size))
    return np.convolve(values, np.ones(window), 'valid') / window

# ================= 主训练流程 =================
def _train_one_epoch(model, loader, criterion, optimizer, device, desc):
    """Run one optimisation pass over ``loader``.

    Returns ``(mean batch loss, accuracy %, mean per-batch gradient L2 norm)``.
    """
    model.train()
    running_loss, correct, total = 0.0, 0, 0
    batch_grad_norms = []

    with tqdm(loader, desc=desc, unit="batch") as tepoch:
        for n_batches, (inputs, labels) in enumerate(tepoch, start=1):
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()

            # Measured before optimizer.step(), i.e. on this batch's raw gradients.
            grad_norm = compute_grad_norm(model)
            batch_grad_norms.append(grad_norm)

            optimizer.step()

            running_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            # Running mean over the batches seen so far (not the whole loader).
            tepoch.set_postfix(loss=running_loss / n_batches,
                               accuracy=100 * correct / total,
                               grad_norm=grad_norm)

    return running_loss / len(loader), 100 * correct / total, np.mean(batch_grad_norms)


def _validate_one_epoch(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` without gradients; return ``(mean batch loss, accuracy %)``."""
    model.eval()
    running_loss, correct, total = 0.0, 0, 0
    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            running_loss += criterion(outputs, labels).item()
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return running_loss / len(loader), 100 * correct / total


def _snapshot_state(model):
    """Return an independent copy of ``model.state_dict()``.

    ``state_dict()`` returns tensors that share storage with the live
    parameters, so keeping it without cloning would track later updates.
    """
    return {k: v.detach().clone() for k, v in model.state_dict().items()}


def main():
    """Run the full pipeline: load data, K-fold train with early stopping, evaluate, save.

    - Loads and groups the .mat data from ``data_path``.
    - Holds out a stratified 25% test split shared by all folds.
    - Runs ``n_splits``-fold cross-validation on the remainder, training a
      fresh SignalTransformer per fold with early stopping on validation
      accuracy.
    - Writes confusion matrices, best weights, results.txt and training
      curves under ./training_results/<run folder>.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"[INFO] 使用设备: {device}")

    # 1. Load data
    X_all, y_all, label_to_idx = load_and_preprocess_with_grouping(
        data_path,
        group_size=288,
        apply_doppler=apply_doppler,
        target_velocity=v,
        apply_awgn=apply_awgn,
        snr_db=SNR_dB,
        fs=fs,
        fc=fc
    )
    num_classes = len(label_to_idx)

    # 2. Create the results folder
    timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    script_name = "LTE-V_cross_fluently"
    folder_name = f"{timestamp}_{script_name}_SNR{SNR_dB}dB_fd{int(compute_doppler_shift(v, fc))}_classes_{len(label_to_idx)}_Transformer"
    save_folder = os.path.join(os.getcwd(), "training_results", folder_name)
    os.makedirs(save_folder, exist_ok=True)
    results_file = os.path.join(save_folder, "results.txt")

    # 3. Stratified hold-out test split, shared by every fold
    X_trainval, X_test, y_trainval, y_test = train_test_split(
        X_all, y_all, test_size=0.25, stratify=y_all, random_state=42
    )

    # 4. Datasets and loaders
    full_dataset = TensorDataset(torch.tensor(X_trainval, dtype=torch.float32),
                                 torch.tensor(y_trainval, dtype=torch.long))
    test_dataset = TensorDataset(torch.tensor(X_test, dtype=torch.float32),
                                 torch.tensor(y_test, dtype=torch.long))
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, drop_last=False)

    print(f"[INFO] 最终训练集: {X_trainval.shape}, 测试集: {X_test.shape}")

    # 5. K-fold cross-validation over the train+val portion
    kfold = KFold(n_splits=n_splits, shuffle=True, random_state=42)
    indices = np.arange(len(full_dataset))

    fold_results = []
    val_results = []
    final_test_results = []
    avg_grad_norms_per_fold = []

    for fold, (train_idx, val_idx) in enumerate(kfold.split(indices)):
        print(f"\n====== Fold {fold+1}/{n_splits} ======")

        train_loader = DataLoader(Subset(full_dataset, train_idx), batch_size=batch_size,
                                  shuffle=True, drop_last=False)
        val_loader = DataLoader(Subset(full_dataset, val_idx), batch_size=batch_size,
                                shuffle=False, drop_last=False)

        model = SignalTransformer(raw_input_dim, model_dim, num_heads, num_layers,
                                  num_classes, dropout).to(device)
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
        scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)

        train_losses, val_losses = [], []
        train_accuracies, val_accuracies = [], []
        grad_norms = []

        best_val_acc = 0.0
        patience_counter = 0
        best_model_wts = None

        for epoch in range(num_epochs):
            epoch_train_loss, epoch_train_acc, avg_grad_norm = _train_one_epoch(
                model, train_loader, criterion, optimizer, device,
                desc=f"Fold {fold+1} Epoch {epoch+1}/{num_epochs}",
            )
            train_losses.append(epoch_train_loss)
            train_accuracies.append(epoch_train_acc)
            grad_norms.append(avg_grad_norm)

            epoch_val_loss, val_acc = _validate_one_epoch(model, val_loader, criterion, device)
            val_losses.append(epoch_val_loss)
            val_accuracies.append(val_acc)

            print(f"Epoch {epoch+1} 验证集准确率: {val_acc:.2f}%")

            # Early stopping on val_acc. Always snapshot the first epoch so a
            # checkpoint exists even if accuracy never rises above 0.
            if best_model_wts is None or val_acc > best_val_acc:
                best_val_acc = val_acc
                patience_counter = 0
                best_model_wts = _snapshot_state(model)
            else:
                patience_counter += 1
                if patience_counter >= patience:
                    print(f"早停，连续 {patience} 个 epoch 验证集准确率未提升。")
                    break

            scheduler.step()

        # Per-fold curves (consumed by plot_training_curves)
        fold_results.append({
            'train_loss': train_losses,
            'val_loss': val_losses,
            'val_acc': val_accuracies,
            'train_acc': train_accuracies,
            'grad_norms': grad_norms,
        })
        avg_grad_norms_per_fold.append(np.mean(grad_norms))

        # Restore the best-epoch weights before evaluating
        model.load_state_dict(best_model_wts)

        val_acc, val_cm = evaluate_model(model, val_loader, device, num_classes)
        val_results.append(val_acc)
        plot_confusion_matrix(val_cm, save_path=os.path.join(save_folder, f"confusion_matrix_val_fold{fold+1}.png"))

        test_acc, test_cm = evaluate_model(model, test_loader, device, num_classes)
        final_test_results.append(test_acc)
        plot_confusion_matrix(test_cm, save_path=os.path.join(save_folder, f"confusion_matrix_test_fold{fold+1}.png"))

        print(f"Fold {fold+1} 验证集Acc: {val_acc:.2f}% | 测试集Acc: {test_acc:.2f}%")

        torch.save(best_model_wts, os.path.join(save_folder, f"best_model_fold{fold+1}.pth"))

        # Explicit UTF-8: the report contains Chinese text, which the locale
        # default encoding (e.g. cp1252 on Windows) may not be able to encode.
        with open(results_file, "a", encoding="utf-8") as f:
            f.write(f"\n训练结束\n")
            f.write(f"超参数设置:\n")
            f.write(f"  batch_size: {batch_size}\n")
            f.write(f"  learning_rate: {learning_rate}\n")
            f.write(f"  weight_decay: {weight_decay}\n")
            f.write(f"  num_epochs: {num_epochs}\n")
            f.write(f"  model_dim: {model_dim}\n")
            f.write(f"  num_heads: {num_heads}\n")
            f.write(f"  num_layers: {num_layers}\n")
            f.write(f"  dropout: {dropout}\n")
            f.write(f"  patience: {patience}\n")
            f.write(f"最佳验证集Acc: {best_val_acc:.2f}%\n")
            f.write(f"测试集准确率: {test_acc:.2f}%\n")

    # Values from the last fold only
    print(f"最佳验证集Acc={best_val_acc:.2f}%, 测试集Acc={test_acc:.2f}%")

    print("\n====== 所有折训练完成 ======")
    print(f"平均验证集准确率: {np.mean(val_results):.2f}% ± {np.std(val_results):.2f}%")
    print(f"独立测试集平均准确率: {np.mean(final_test_results):.2f}% ± {np.std(final_test_results):.2f}%")

    with open(results_file, "a", encoding="utf-8") as f:
        f.write(f"\n所有折平均验证集准确率: {np.mean(val_results):.2f}% ± {np.std(val_results):.2f}%\n")
        f.write(f"所有折平均测试集准确率: {np.mean(final_test_results):.2f}% ± {np.std(final_test_results):.2f}%\n")

    # Training curves and gradient norms
    plot_training_curves(fold_results, save_folder)
    plot_grad_norms(avg_grad_norms_per_fold, save_folder)



def evaluate_model(model, dataloader, device, num_classes):
    """Score ``model`` on every batch of ``dataloader``.

    Switches the model to eval mode and runs inference without gradients.
    Returns ``(accuracy_percent, confusion_matrix)``. The confusion matrix is
    built over the fixed label set ``0 .. num_classes-1``, so it is always
    (num_classes, num_classes) even when some classes are absent.
    """
    model.eval()
    true_labels = []
    predicted_labels = []
    n_correct = 0
    n_seen = 0
    with torch.no_grad():
        for batch_x, batch_y in dataloader:
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)
            batch_pred = torch.max(model(batch_x), 1).indices
            n_correct += (batch_pred == batch_y).sum().item()
            n_seen += batch_y.size(0)
            true_labels.extend(batch_y.cpu().numpy())
            predicted_labels.extend(batch_pred.cpu().numpy())

    accuracy = 100 * n_correct / n_seen
    cm = confusion_matrix(true_labels, predicted_labels, labels=range(num_classes))
    return accuracy, cm

def plot_training_curves(fold_results, save_folder):
    """Plot smoothed per-fold loss curves and save them as ``loss_curves.png``.

    Each entry of ``fold_results`` must provide a ``'train_loss'`` sequence.
    A ``'val_loss'`` sequence is plotted too when it is present and non-empty.
    Without this check, a missing key raised KeyError only after all folds had
    finished training. The figure is saved into ``save_folder``, shown, and
    then closed so repeated runs do not accumulate open figures.
    """
    plt.figure(figsize=(12, 5))
    for i, res in enumerate(fold_results):
        plt.plot(moving_average(res['train_loss']), label=f'Fold{i+1} Train Loss')
        val_loss = res.get('val_loss')
        if val_loss is not None and len(val_loss) > 0:
            plt.plot(moving_average(val_loss), label=f'Fold{i+1} Val Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('训练和验证Loss曲线')
    plt.legend()
    plt.grid()
    plt.savefig(os.path.join(save_folder, 'loss_curves.png'))
    plt.show()
    plt.close()

def plot_grad_norms(avg_grad_norms, save_folder):
    """Bar-chart each fold's mean gradient norm.

    Bars are placed at x = 1..len(avg_grad_norms). The chart is written to
    ``avg_grad_norms.png`` inside ``save_folder`` and then displayed.
    """
    fold_axis = range(1, len(avg_grad_norms) + 1)
    out_file = os.path.join(save_folder, 'avg_grad_norms.png')

    plt.figure(figsize=(6, 4))
    plt.bar(fold_axis, avg_grad_norms)
    plt.xlabel('Fold')
    plt.ylabel('平均梯度范数')
    plt.title('各Fold平均梯度范数')
    plt.grid()
    plt.savefig(out_file)
    plt.show()

def plot_confusion_matrix(cm, save_path=None):
    """Render ``cm`` as an annotated heatmap, optionally save it, then show it.

    Side effect: updates the global matplotlib rcParams to use the SimHei
    font (so the Chinese labels render) and to draw minus signs correctly;
    these settings persist for later figures.
    """
    plt.figure(figsize=(8, 6))
    plt.rcParams.update({
        'font.sans-serif': ['SimHei'],
        'axes.unicode_minus': False,
    })

    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
    for set_text, text in ((plt.title, '混淆矩阵'), (plt.ylabel, '真实类别'), (plt.xlabel, '预测类别')):
        set_text(text)

    if save_path:
        plt.savefig(save_path)
    plt.show()



# Run the full training pipeline only when executed directly (not on import).
if __name__ == "__main__":
    main()
