In [13]:
# train_cnn_gru_classifier_keras.py
import os
import numpy as np
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import (Input, Conv1D, MaxPooling1D, 
                                   Dense, Dropout, BatchNormalization,
                                   Reshape, GRU)
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import (EarlyStopping, ReduceLROnPlateau, 
                                      ModelCheckpoint)
from tensorflow.keras.losses import SparseCategoricalCrossentropy
from tensorflow.keras.metrics import SparseCategoricalAccuracy
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.utils.class_weight import compute_class_weight
import matplotlib.pyplot as plt
import seaborn as sns
import matplotlib.font_manager as fm

# ---------- 中文字体设置 ----------
def set_chinese_font():
    """安全设置中文字体"""
    try:
        font_paths = ['/usr/share/fonts/truetype/wqy/wqy-microhei.ttc',
                     'C:/Windows/Fonts/simhei.ttf',
                     '/System/Library/Fonts/PingFang.ttc']
        for path in font_paths:
            if os.path.exists(path):
                fm.fontManager.addfont(path)
                plt.rcParams['font.family'] = fm.FontProperties(fname=path).get_name()
                plt.rcParams['axes.unicode_minus'] = False
                print(f"✅ 成功设置中文字体: {plt.rcParams['font.family']}")
                return True
        plt.rcParams['font.family'] = ['SimHei', 'Arial Unicode MS']
        plt.rcParams['axes.unicode_minus'] = False
        print("⚠️ 使用系统默认中文字体")
        return True
    except Exception as e:
        print(f"❌ 字体设置失败: {e}")
        return False
set_chinese_font()

# ---------- 模型保存配置 ----------
MODEL_FORMAT = "keras"  # 可选: "keras" 或 "h5"
MODEL_VERSION = "v1.0"

def save_model(model, model_dir="models", model_name="cnn_gru_classifier"):
    """统一保存模型为Keras格式"""
    os.makedirs(model_dir, exist_ok=True)
    
    # 生成带时间戳的文件名
    timestamp = time.strftime("%Y%m%d_%H%M%S")
    filename = f"{model_name}_{MODEL_VERSION}_{timestamp}.keras"
    save_path = os.path.join(model_dir, filename)
    
    # 保存模型
    model.save(save_path, save_format="tf")  # 使用TensorFlow SavedModel格式保存为.keras文件
    print(f"✅ 模型已保存为Keras格式: {save_path}")
    
    # 创建latest符号链接（仅限Unix系统）
    latest_path = os.path.join(model_dir, f"{model_name}_latest.keras")
    try:
        if os.path.exists(latest_path):
            os.remove(latest_path)
        os.symlink(filename, latest_path)
        print(f"✅ 创建最新模型符号链接: {latest_path} -> {filename}")
    except Exception as e:
        print(f"⚠️ 无法创建符号链接: {e}")
    
    return save_path

# ---------- 加载数据集 ----------
def load_dataset(npz_path):
    """加载并处理数据集"""
    data = np.load(npz_path, allow_pickle=True)
    
    # 基础数据
    dataset = {
        "signals": data["signals"],
        "labels": data["labels"].astype(np.int32),
        "fs": float(data["fs"]),
        "L": int(data["L"])
    }
    
    # 处理干扰类型名称
    if "interference_type_names" in data:
        type_names = data["interference_type_names"].item() if isinstance(data["interference_type_names"], np.ndarray) else data["interference_type_names"]
    else:
        type_names = {
            "satellite_signal": "Satellite_Signal",
            "single_tone": "Single_Tone",
            "comb_spectra": "Comb_Spectra",
            "sweeping": "Sweeping-LFM",
            "pulse": "Pulse",
            "frequency_hopping": "Frequency_Hopping",
            "noise_fm": "Noise_FM",
            "noise_am": "Noise_AM",
            "random_combination": "Random_Combination"
        }
    
    # 创建标签映射
    if "type_to_label" in data:
        type2label = data["type_to_label"].item() if isinstance(data["type_to_label"], np.ndarray) else data["type_to_label"]
    else:
        type2label = {
            "satellite_signal": 0,
            "single_tone": 1,
            "comb_spectra": 2,
            "sweeping": 3,
            "pulse": 4,
            "frequency_hopping": 5,
            "noise_fm": 6,
            "noise_am": 7,
            "random_combination": 8
        }
    
    label2name = {int(i): str(type_names[k]) for k, i in type2label.items()}
    
    dataset.update({
        "type2label": type2label,
        "label2name": label2name
    })
    return dataset

# ---------- 数据预处理 ----------
def preprocess_data(dataset):
    """预处理数据"""
    # 信号标准化
    signals = np.array([StandardScaler().fit_transform(s.reshape(-1, 1)).ravel() 
                       for s in dataset["signals"]])
    
    # 数据集分割（保持分层抽样）
    X_train, X_test, y_train, y_test = train_test_split(
        signals, dataset["labels"],
        test_size=0.3, random_state=42, 
        stratify=dataset["labels"]
    )
    X_val, X_test, y_val, y_test = train_test_split(
        X_test, y_test,
        test_size=0.5, random_state=42,
        stratify=y_test
    )
    
    return {
        "X_train": X_train, "X_val": X_val, "X_test": X_test,
        "y_train": y_train, "y_val": y_val, "y_test": y_test,
        "label2name": dataset["label2name"],
        "L": dataset["L"]
    }

# ---------- CNN-GRU模型构建 ----------
def build_model(input_shape, num_classes):
    """构建CNN-GRU分类模型"""
    inputs = Input(shape=input_shape)
    x = Reshape((input_shape[0], 1))(inputs)  # [batch, length, 1]
    
    # CNN特征提取
    x = Conv1D(32, 7, activation='relu', padding='same')(x)
    x = BatchNormalization()(x)
    x = MaxPooling1D(2)(x)
    
    x = Conv1D(64, 5, activation='relu', padding='same')(x)
    x = BatchNormalization()(x)
    x = MaxPooling1D(2)(x)
    
    x = Conv1D(128, 3, activation='relu', padding='same')(x)
    x = BatchNormalization()(x)
    x = MaxPooling1D(2)(x)
    
    # GRU时序建模
    x = GRU(64, return_sequences=True)(x)
    x = GRU(32)(x)
    
    # 分类头
    x = Dense(128, activation='relu')(x)
    x = Dropout(0.5)(x)
    outputs = Dense(num_classes, activation='softmax')(x)
    
    model = Model(inputs, outputs)
    model.compile(
        optimizer=Adam(1e-3),
        loss=SparseCategoricalCrossentropy(),
        metrics=[SparseCategoricalAccuracy()]
    )
    return model

# ---------- 数据增强 ----------
@tf.function
def aug_fn(x):
    """信号增强"""
    x = tf.cast(x, tf.float32)
    if tf.random.uniform([]) > 0.2:
        snr = tf.random.uniform([], 5., 25.)
        noise = tf.random.normal(tf.shape(x)) * tf.math.reduce_std(x) * 10.0**(-snr/20.0)
        x = x + noise
    if tf.random.uniform([]) > 0.3:
        shift = tf.random.uniform([], -100, 100, dtype=tf.int32)
        x = tf.roll(x, shift=shift, axis=0)
    if tf.random.uniform([]) > 0.3:
        scale = tf.random.uniform([], 0.7, 1.3)
        x = x * scale
    return x

# ---------- 训练函数 ----------
def train_single_model(data, model_idx=0, epochs=120, batch=128):
    """训练单个模型"""
    model = build_model((data["L"],), len(data["label2name"]))
    
    # 类别权重（处理不平衡）
    cls_weights = compute_class_weight('balanced', 
                                     classes=np.unique(data['y_train']), 
                                     y=data['y_train'])
    sample_weights = np.array([cls_weights[lab] for lab in data['y_train']])
    
    # 数据管道
    train_ds = tf.data.Dataset.from_tensor_slices(
        (data['X_train'], data['y_train'], sample_weights)
    )
    train_ds = train_ds.map(lambda x, y, w: (aug_fn(x), y, w))
    train_ds = train_ds.shuffle(10000).batch(batch).prefetch(tf.data.AUTOTUNE)
    
    val_ds = tf.data.Dataset.from_tensor_slices(
        (data['X_val'], data['y_val'])
    ).batch(batch).prefetch(tf.data.AUTOTUNE)
    
    # 回调函数
    os.makedirs("models", exist_ok=True)
    ckpt = f"models/cnn_gru_classifier_{model_idx}.keras"
    callbacks = [
        EarlyStopping(monitor='val_sparse_categorical_accuracy', patience=20, restore_best_weights=True),
        ReduceLROnPlateau(monitor='val_sparse_categorical_accuracy', factor=0.5, patience=10),
        ModelCheckpoint(ckpt, save_best_only=True, monitor='val_sparse_categorical_accuracy', save_format="tf")
    ]
    
    print(f"\n🔥 训练模型 {model_idx + 1}...")
    history = model.fit(
        train_ds,
        validation_data=val_ds,
        epochs=epochs,
        callbacks=callbacks,
        verbose=1
    )
    return model

# ---------- 主函数 ----------
def main():
    os.makedirs("models", exist_ok=True)
    
    # 加载数据
    print("⏳ 加载数据集...")
    dataset = load_dataset("/root/yxun/20250826/dataset/interference_signals_natural_same_freq_1019.npz")
    data = preprocess_data(dataset)
    
    # 训练模型
    model = train_single_model(data, model_idx=0, epochs=120, batch=128)
    
    # 评估测试集
    test_loss, test_acc = model.evaluate(data['X_test'], data['y_test'], verbose=0)
    print(f"\n📊 测试集准确率: {test_acc:.4f}")
    
    # 统一保存模型
    model_path = save_model(model, model_name="cnn_gru_classifier")
    
    print("\n✅ 训练完成！模型已保存至:", model_path)

if __name__ == "__main__":
    import time
    main()

✅ 成功设置中文字体: ['WenQuanYi Micro Hei']
⏳ 加载数据集...


2025-10-21 18:06:50.864132: I tensorflow/core/common_runtime/executor.cc:1197] [/device:CPU:0] (DEBUG INFO) Executor start aborting (this does not indicate an error and you can ignore this message): INVALID_ARGUMENT: You must feed a value for placeholder tensor 'gradients/split_2_grad/concat/split_2/split_dim' with dtype int32
	 [[{{node gradients/split_2_grad/concat/split_2/split_dim}}]]
2025-10-21 18:06:50.865746: I tensorflow/core/common_runtime/executor.cc:1197] [/device:CPU:0] (DEBUG INFO) Executor start aborting (this does not indicate an error and you can ignore this message): INVALID_ARGUMENT: You must feed a value for placeholder tensor 'gradients/split_grad/concat/split/split_dim' with dtype int32
	 [[{{node gradients/split_grad/concat/split/split_dim}}]]
2025-10-21 18:06:50.866449: I tensorflow/core/common_runtime/executor.cc:1197] [/device:CPU:0] (DEBUG INFO) Executor start aborting (this does not indicate an error and you can ignore this message): INVALID_ARGUMENT: You mus


🔥 训练模型 1...
Epoch 1/120


2025-10-21 18:06:51.594627: I tensorflow/core/common_runtime/executor.cc:1197] [/device:CPU:0] (DEBUG INFO) Executor start aborting (this does not indicate an error and you can ignore this message): INVALID_ARGUMENT: You must feed a value for placeholder tensor 'Placeholder/_2' with dtype double and shape [56700]
	 [[{{node Placeholder/_2}}]]
2025-10-21 18:06:51.594870: I tensorflow/core/common_runtime/executor.cc:1197] [/device:CPU:0] (DEBUG INFO) Executor start aborting (this does not indicate an error and you can ignore this message): INVALID_ARGUMENT: You must feed a value for placeholder tensor 'Placeholder/_0' with dtype double and shape [56700,1024]
	 [[{{node Placeholder/_0}}]]
2025-10-21 18:06:51.768722: I tensorflow/core/common_runtime/executor.cc:1197] [/device:CPU:0] (DEBUG INFO) Executor start aborting (this does not indicate an error and you can ignore this message): INVALID_ARGUMENT: You must feed a value for placeholder tensor 'gradients/split_2_grad/concat/split_2/spli



2025-10-21 18:07:07.269329: I tensorflow/core/common_runtime/executor.cc:1197] [/device:CPU:0] (DEBUG INFO) Executor start aborting (this does not indicate an error and you can ignore this message): INVALID_ARGUMENT: You must feed a value for placeholder tensor 'Placeholder/_1' with dtype int32 and shape [12150]
	 [[{{node Placeholder/_1}}]]
2025-10-21 18:07:07.416623: I tensorflow/core/common_runtime/executor.cc:1197] [/device:CPU:0] (DEBUG INFO) Executor start aborting (this does not indicate an error and you can ignore this message): INVALID_ARGUMENT: You must feed a value for placeholder tensor 'gradients/split_2_grad/concat/split_2/split_dim' with dtype int32
	 [[{{node gradients/split_2_grad/concat/split_2/split_dim}}]]
2025-10-21 18:07:07.417490: I tensorflow/core/common_runtime/executor.cc:1197] [/device:CPU:0] (DEBUG INFO) Executor start aborting (this does not indicate an error and you can ignore this message): INVALID_ARGUMENT: You must feed a value for placeholder tensor 'g

Epoch 2/120
Epoch 3/120
Epoch 4/120
Epoch 5/120
Epoch 6/120
Epoch 7/120
Epoch 8/120
Epoch 9/120
Epoch 10/120
Epoch 11/120
Epoch 12/120
Epoch 13/120
Epoch 14/120
Epoch 15/120
Epoch 16/120
Epoch 17/120
Epoch 18/120
Epoch 19/120
Epoch 20/120
Epoch 21/120
Epoch 22/120
Epoch 23/120
Epoch 24/120
Epoch 25/120
Epoch 26/120
Epoch 27/120
Epoch 28/120
Epoch 29/120
Epoch 30/120
Epoch 31/120
Epoch 32/120
Epoch 33/120
Epoch 34/120
Epoch 35/120
Epoch 36/120
Epoch 37/120
Epoch 38/120
Epoch 39/120
Epoch 40/120
Epoch 41/120
Epoch 42/120
Epoch 43/120
Epoch 44/120
Epoch 45/120
Epoch 46/120
Epoch 47/120
Epoch 48/120
Epoch 49/120
Epoch 50/120
Epoch 51/120
Epoch 52/120
Epoch 53/120
Epoch 54/120
Epoch 55/120
Epoch 56/120
Epoch 57/120
Epoch 58/120
Epoch 59/120
Epoch 60/120
Epoch 61/120
Epoch 62/120
Epoch 63/120
Epoch 64/120
Epoch 65/120
Epoch 66/120
Epoch 67/120
Epoch 68/120
Epoch 69/120
Epoch 70/120
Epoch 71/120
Epoch 72/120
Epoch 73/120
Epoch 74/120
Epoch 75/120
Epoch 76/120
Epoch 77/120
Epoch 78/120
Epoch 7

2025-10-21 18:30:32.534312: I tensorflow/core/common_runtime/executor.cc:1197] [/device:CPU:0] (DEBUG INFO) Executor start aborting (this does not indicate an error and you can ignore this message): INVALID_ARGUMENT: You must feed a value for placeholder tensor 'gradients/split_2_grad/concat/split_2/split_dim' with dtype int32
	 [[{{node gradients/split_2_grad/concat/split_2/split_dim}}]]
2025-10-21 18:30:32.535806: I tensorflow/core/common_runtime/executor.cc:1197] [/device:CPU:0] (DEBUG INFO) Executor start aborting (this does not indicate an error and you can ignore this message): INVALID_ARGUMENT: You must feed a value for placeholder tensor 'gradients/split_grad/concat/split/split_dim' with dtype int32
	 [[{{node gradients/split_grad/concat/split/split_dim}}]]
2025-10-21 18:30:32.536510: I tensorflow/core/common_runtime/executor.cc:1197] [/device:CPU:0] (DEBUG INFO) Executor start aborting (this does not indicate an error and you can ignore this message): INVALID_ARGUMENT: You mus


📊 测试集准确率: 0.8444
✅ 模型已保存为Keras格式: models/cnn_gru_classifier_v1.0_20251021_183035.keras
✅ 创建最新模型符号链接: models/cnn_gru_classifier_latest.keras -> cnn_gru_classifier_v1.0_20251021_183035.keras

✅ 训练完成！模型已保存至: models/cnn_gru_classifier_v1.0_20251021_183035.keras


In [14]:
### evaluate_cnn_gru_classifier.py
import os
import json
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import (accuracy_score, confusion_matrix, 
                            classification_report)
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
import matplotlib.font_manager as fm

# ---------- 中文字体设置 ----------
def set_chinese_font():
    """安全设置中文字体"""
    try:
        font_paths = ['/usr/share/fonts/truetype/wqy/wqy-microhei.ttc',
                     'C:/Windows/Fonts/simhei.ttf',
                     '/System/Library/Fonts/PingFang.ttc']
        for path in font_paths:
            if os.path.exists(path):
                fm.fontManager.addfont(path)
                plt.rcParams['font.family'] = fm.FontProperties(fname=path).get_name()
                plt.rcParams['axes.unicode_minus'] = False
                print(f"✅ 成功设置中文字体: {plt.rcParams['font.family']}")
                return True
        plt.rcParams['font.family'] = ['SimHei', 'Arial Unicode MS']
        plt.rcParams['axes.unicode_minus'] = False
        print("⚠️ 使用系统默认中文字体")
        return True
    except Exception as e:
        print(f"❌ 字体设置失败: {e}")
        return False
set_chinese_font()

# ---------- 加载数据集 ----------
def load_dataset(npz_path):
    """加载并处理数据集"""
    data = np.load(npz_path, allow_pickle=True)
    
    # 基础数据
    dataset = {
        "signals": data["signals"],
        "labels": data["labels"].astype(np.int32),
        "jnr_values": data["jnr_values"].astype(np.float32),
        "fs": float(data["fs"]),
        "L": int(data["L"])
    }
    
    # 处理干扰类型名称
    if "interference_type_names" in data:
        type_names = data["interference_type_names"].item() if isinstance(data["interference_type_names"], np.ndarray) else data["interference_type_names"]
    else:
        type_names = {
            "satellite_signal": "Satellite_Signal",
            "single_tone": "Single_Tone",
            "comb_spectra": "Comb_Spectra",
            "sweeping": "Sweeping-LFM",
            "pulse": "Pulse",
            "frequency_hopping": "Frequency_Hopping",
            "noise_fm": "Noise_FM",
            "noise_am": "Noise_AM",
            "random_combination": "Random_Combination"
        }
        print("⚠️ 未找到干扰类型名称，使用默认值")
    
    # 创建标签映射
    if "type_to_label" in data:
        type2label = data["type_to_label"].item() if isinstance(data["type_to_label"], np.ndarray) else data["type_to_label"]
    else:
        type2label = {
            "satellite_signal": 0,
            "single_tone": 1,
            "comb_spectra": 2,
            "sweeping": 3,
            "pulse": 4,
            "frequency_hopping": 5,
            "noise_fm": 6,
            "noise_am": 7,
            "random_combination": 8
        }
    
    label2name = {int(i): str(type_names[k]) for k, i in type2label.items()}
    
    return {
        "signals": data["signals"],
        "labels": data["labels"].astype(np.int32),
        "jnr_values": data["jnr_values"].astype(np.float32),
        "label2name": label2name,
        "L": int(data["L"])
    }

# ---------- 数据预处理 ----------
def preprocess_data(dataset):
    """预处理数据"""
    # 信号标准化
    signals = np.array([StandardScaler().fit_transform(s.reshape(-1, 1)).ravel() 
                       for s in dataset["signals"]])
    
    # 数据集分割（保持分层抽样）
    X_train, X_test, y_train, y_test, jnr_train, jnr_test = train_test_split(
        signals, dataset["labels"], dataset["jnr_values"],
        test_size=0.3, random_state=42, stratify=dataset["labels"]
    )
    
    return {
        "X_test": X_test,
        "y_test": y_test,
        "jnr_test": jnr_test,
        "label2name": dataset["label2name"]
    }

# ---------- 加载模型 ----------
def load_model(model_path):
    """加载训练好的CNN-GRU模型"""
    if os.path.exists(model_path):
        try:
            model = tf.keras.models.load_model(model_path)
            print(f"✅ 成功加载模型: {model_path}")
            return model
        except Exception as e:
            print(f"❌ 加载模型失败: {e}")
            return None
    print(f"⚠️ 未找到模型: {model_path}")
    return None

# ---------- 计算JNR准确率 ----------
def calculate_jnr_accuracy(y_true, y_pred, jnr_values, jnr_range):
    """计算每个JNR下的分类准确率"""
    jnr_acc = {}
    for jnr in jnr_range:
        mask = jnr_values == jnr
        if np.sum(mask) > 0:
            jnr_acc[jnr] = accuracy_score(y_true[mask], y_pred[mask])
        else:
            jnr_acc[jnr] = np.nan
    return jnr_acc

# ---------------------- 绘制混淆矩阵 ----------------------
def plot_confusion_matrix(cm, labels, title, xlabel, ylabel, filename, dpi=150, rotate_x=False):
    """专业混淆矩阵可视化"""
    cm_normalized = cm.astype('float') / cm.sum(axis=1, keepdims=True)
    cm_normalized = np.nan_to_num(cm_normalized)

    plt.figure(figsize=(12, 10))
    ax = sns.heatmap(cm_normalized,
                     annot=True,
                     fmt='.2f',
                     cmap='Blues',
                     xticklabels=labels,
                     yticklabels=labels,
                     square=True,
                     annot_kws={"size": 14})
    cbar = ax.collections[0].colorbar
    cbar.ax.tick_params(labelsize=14)

    plt.title(title, pad=20, fontsize=18)
    plt.xlabel(xlabel, fontsize=16)
    plt.ylabel(ylabel, fontsize=16)
    plt.xticks(rotation=45 if rotate_x else 0, 
               ha='right' if rotate_x else 'center', 
               fontsize=14)
    plt.yticks(rotation=0, fontsize=14)
    plt.tight_layout()
    plt.savefig(filename, dpi=dpi)
    plt.close()

# ---------- 绘制JNR准确率曲线 ----------
def plot_jnr_accuracy(jnr_acc, filename):
    """绘制JNR-准确率曲线"""
    plt.figure(figsize=(10, 6))
    valid_jnr = [j for j in jnr_acc if not np.isnan(jnr_acc[j])]
    valid_acc = [jnr_acc[j] for j in valid_jnr]
    
    plt.plot(valid_jnr, valid_acc, 'o-', linewidth=2, markersize=8)
    plt.xlabel('JNR (dB)', fontsize=12)
    plt.ylabel('分类准确率', fontsize=12)
    plt.title('不同JNR下的干扰识别准确率', fontsize=14)
    plt.grid(True, linestyle='--', alpha=0.5)
    plt.xticks(list(jnr_acc.keys()))
    plt.ylim(0, 1.05)
    plt.savefig(filename, dpi=300)
    plt.close()

# ---------- 主评估函数 ----------
def evaluate(model_path="models/cnn_gru_classifier_v1.0_20251021_183035.keras", 
             npz_path="/root/yxun/20250826/dataset/interference_signals_natural_same_freq_1019.npz"):
    """主评估流程"""
    os.makedirs("visualizations", exist_ok=True)
    os.makedirs("reports", exist_ok=True)
    
    print("=" * 80)
    print("🚀 CNN-GRU干扰分类模型评估")
    print("=" * 80)

    # 1. 加载数据
    print("⏳ 加载数据集...")
    dataset = load_dataset(npz_path)
    data = preprocess_data(dataset)
    
    # 2. 加载模型
    print("⏳ 加载模型...")
    model = load_model(model_path)
    if not model:
        print("❌ 评估终止：无法加载模型")
        return
    
    # 3. 预测
    print("⏳ 进行预测...")
    y_pred = np.argmax(model.predict(data["X_test"], verbose=1), axis=1)
    
    # 4. 计算JNR准确率
    print("⏳ 计算各JNR下的准确率...")
    jnr_range = np.arange(-10, 31, 5)  # [-10, -5, 0, 5, ..., 30]
    jnr_acc = calculate_jnr_accuracy(data["y_test"], y_pred, data["jnr_test"], jnr_range)
    
    # 5. 生成混淆矩阵
    print("⏳ 生成混淆矩阵...")
    cm = confusion_matrix(data["y_test"], y_pred)
    class_names = [data["label2name"][i] for i in sorted(data["label2name"].keys())]
    plot_confusion_matrix(
        cm=cm,
        labels=class_names,
        title="CNN-GRU Classification Confusion Matrix",
        xlabel="Predicted",
        ylabel="True",
        filename="visualizations/cnn_gru_confusion_matrix.png",
        dpi=300,
        rotate_x=True
    )
    
    # 6. 绘制JNR准确率曲线
    print("⏳ 生成JNR准确率曲线...")
    plot_jnr_accuracy(jnr_acc, "visualizations/cnn_gru_jnr_accuracy.png")
    
    # 7. 保存评估报告
    print("⏳ 生成评估报告...")
    report = {
        "overall_accuracy": float(accuracy_score(data["y_test"], y_pred)),
        "jnr_accuracy": {int(j): float(acc) if not np.isnan(acc) else None 
                         for j, acc in jnr_acc.items()},
        "confusion_matrix": cm.tolist(),
        "class_names": class_names,
        "classification_report": classification_report(
            data["y_test"], y_pred, 
            target_names=class_names,
            output_dict=True
        )
    }
    
    with open("reports/cnn_gru_evaluation_report.json", "w", encoding="utf-8") as f:
        json.dump(report, f, indent=4, ensure_ascii=False)
    
    # 8. 打印结果
    print("\n" + "="*50)
    print("📊 评估结果摘要")
    print("="*50)
    print(f"整体分类准确率: {report['overall_accuracy']:.4f}")
    
    print("\n各JNR下的分类准确率:")
    for jnr in sorted(jnr_acc.keys()):
        acc = jnr_acc[jnr]
        acc_str = f"{acc:.4f}" if not np.isnan(acc) else "N/A"
        print(f"  JNR={jnr}dB: {acc_str}")
    
    print("\n✅ 评估完成！结果已保存至 reports/cnn_gru_evaluation_report.json")

# ---------- 主程序 ----------
if __name__ == "__main__":
    evaluate(
        model_path="models/cnn_gru_classifier_v1.0_20251021_183035.keras",
        npz_path="/root/yxun/20250826/dataset/interference_signals_natural_same_freq_1019.npz"
    )

✅ 成功设置中文字体: ['WenQuanYi Micro Hei']
🚀 CNN-GRU干扰分类模型评估
⏳ 加载数据集...
⏳ 加载模型...


2025-10-21 18:49:41.301936: I tensorflow/core/common_runtime/executor.cc:1197] [/device:CPU:0] (DEBUG INFO) Executor start aborting (this does not indicate an error and you can ignore this message): INVALID_ARGUMENT: You must feed a value for placeholder tensor 'gradients/split_2_grad/concat/split_2/split_dim' with dtype int32
	 [[{{node gradients/split_2_grad/concat/split_2/split_dim}}]]
2025-10-21 18:49:41.303364: I tensorflow/core/common_runtime/executor.cc:1197] [/device:CPU:0] (DEBUG INFO) Executor start aborting (this does not indicate an error and you can ignore this message): INVALID_ARGUMENT: You must feed a value for placeholder tensor 'gradients/split_grad/concat/split/split_dim' with dtype int32
	 [[{{node gradients/split_grad/concat/split/split_dim}}]]
2025-10-21 18:49:41.304083: I tensorflow/core/common_runtime/executor.cc:1197] [/device:CPU:0] (DEBUG INFO) Executor start aborting (this does not indicate an error and you can ignore this message): INVALID_ARGUMENT: You mus

✅ 成功加载模型: models/cnn_gru_classifier_v1.0_20251021_183035.keras
⏳ 进行预测...


2025-10-21 18:49:41.713176: I tensorflow/core/common_runtime/executor.cc:1197] [/device:CPU:0] (DEBUG INFO) Executor start aborting (this does not indicate an error and you can ignore this message): INVALID_ARGUMENT: You must feed a value for placeholder tensor 'gradients/split_2_grad/concat/split_2/split_dim' with dtype int32
	 [[{{node gradients/split_2_grad/concat/split_2/split_dim}}]]
2025-10-21 18:49:41.714115: I tensorflow/core/common_runtime/executor.cc:1197] [/device:CPU:0] (DEBUG INFO) Executor start aborting (this does not indicate an error and you can ignore this message): INVALID_ARGUMENT: You must feed a value for placeholder tensor 'gradients/split_grad/concat/split/split_dim' with dtype int32
	 [[{{node gradients/split_grad/concat/split/split_dim}}]]
2025-10-21 18:49:41.715069: I tensorflow/core/common_runtime/executor.cc:1197] [/device:CPU:0] (DEBUG INFO) Executor start aborting (this does not indicate an error and you can ignore this message): INVALID_ARGUMENT: You mus

⏳ 计算各JNR下的准确率...
⏳ 生成混淆矩阵...
⏳ 生成JNR准确率曲线...
⏳ 生成评估报告...

📊 评估结果摘要
整体分类准确率: 0.8491

各JNR下的分类准确率:
  JNR=-10dB: 0.4628
  JNR=-5dB: 0.6603
  JNR=0dB: 0.8118
  JNR=5dB: 0.8840
  JNR=10dB: 0.9275
  JNR=15dB: 0.9602
  JNR=20dB: 0.9732
  JNR=25dB: 0.9803
  JNR=30dB: 0.9770

✅ 评估完成！结果已保存至 reports/cnn_gru_evaluation_report.json


CNN-GRU整体分类准确率: 0.8491

各JNR下的分类准确率:
  JNR=-10dB: 0.4628
  JNR=-5dB: 0.6603
  JNR=0dB: 0.8118
  JNR=5dB: 0.8840
  JNR=10dB: 0.9275
  JNR=15dB: 0.9602
  JNR=20dB: 0.9732
  JNR=25dB: 0.9803
  JNR=30dB: 0.9770