In [15]:
# train_mcldnn_classifier_fixed.py
import os
import numpy as np
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import (Input, Conv2D, Conv1D, MaxPooling1D,
                                     Dense, Dropout, BatchNormalization,
                                     Flatten, Reshape, LSTM, concatenate)
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import (EarlyStopping, ReduceLROnPlateau,
                                      ModelCheckpoint)
from tensorflow.keras.losses import SparseCategoricalCrossentropy
from tensorflow.keras.metrics import SparseCategoricalAccuracy
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.utils.class_weight import compute_class_weight
from scipy.signal import hilbert
import matplotlib.pyplot as plt
import seaborn as sns
import matplotlib.font_manager as fm

# ---------- GPU 配置 ----------
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    try:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
    except RuntimeError as e:
        print(e)

# ---------- 中文字体 ----------
def set_chinese_font():
    try:
        font_paths = ['/usr/share/fonts/truetype/wqy/wqy-microhei.ttc',
                      'C:/Windows/Fonts/simhei.ttf',
                      '/System/Library/Fonts/PingFang.ttc']
        for path in font_paths:
            if os.path.exists(path):
                fm.fontManager.addfont(path)
                plt.rcParams['font.family'] = fm.FontProperties(fname=path).get_name()
                plt.rcParams['axes.unicode_minus'] = False
                print(f"✅ 成功设置中文字体: {plt.rcParams['font.family']}")
                return True
        plt.rcParams['font.family'] = ['SimHei', 'Arial Unicode MS']
        plt.rcParams['axes.unicode_minus'] = False
        print("⚠️ 使用系统默认中文字体")
        return True
    except Exception as e:
        print(f"❌ 字体设置失败: {e}")
        return False
set_chinese_font()

# ---------- 干扰类型 ----------
INTERFERENCE_TYPES = [
    "satellite_signal",  # 0
    "single_tone",
    "comb_spectra",
    "sweeping",
    "pulse",
    "frequency_hopping",
    "same_frequency",
    "noise_fm",
    "noise_am",
    "random_combination"
]
NUM_CLASSES = len(INTERFERENCE_TYPES)

# ---------- 加载数据集 ----------
def load_dataset(npz_path):
    data = np.load(npz_path, allow_pickle=True)
    if "interference_type_names" in data:
        type_names = data["interference_type_names"].item() \
            if isinstance(data["interference_type_names"], np.ndarray) \
            else data["interference_type_names"]
    else:
        type_names = {k: k.replace("_", " ").title() for k in INTERFERENCE_TYPES}
        print("⚠️ 未找到干扰类型名称，使用默认值")
    if "type_to_label" in data:
        type2label = data["type_to_label"].item() \
            if isinstance(data["type_to_label"], np.ndarray) \
            else data["type_to_label"]
    else:
        type2label = {name: i for i, name in enumerate(INTERFERENCE_TYPES)}
    return {
        "signals": data["signals"],
        "labels": data["labels"].astype(np.int32),
        "type2label": type2label,
        "label2name": {i: type_names[k] for k, i in type2label.items()},
        "L": int(data["L"])
    }

# ---------- 数据预处理 ----------
def preprocess_data(dataset):
    signals = np.array([StandardScaler().fit_transform(s.reshape(-1, 1)).ravel()
                        for s in dataset["signals"]], dtype=np.float32)
    L = signals.shape[1]
    signals_complex = np.zeros((signals.shape[0], 2, L), dtype=np.float32)
    for i in range(signals.shape[0]):
        analytic = hilbert(signals[i])
        signals_complex[i, 0] = signals[i]          # I
        signals_complex[i, 1] = np.imag(analytic)   # Q

    X1 = signals_complex[:, :, :, np.newaxis]       # (N, 2, L, 1)
    X2 = signals_complex[:, 0, :, np.newaxis]       # (N, L, 1)
    X3 = signals_complex[:, 1, :, np.newaxis]       # (N, L, 1)
    return {
        "X1": X1,
        "X2": X2,
        "X3": X3,
        "y": dataset["labels"],
        "label2name": dataset["label2name"],
        "L": L
    }

# ---------- MCLDNN 模型 ----------
def build_mcldnn(input_shape1=(2, 1024, 1),
                 input_shape2=(1024, 1),
                 input_shape3=(1024, 1),
                 num_classes=10):
    input1 = Input(shape=input_shape1, name='complex_input')  # (None,2,L,1)
    input2 = Input(shape=input_shape2, name='I_input')        # (None,L,1)
    input3 = Input(shape=input_shape3, name='Q_input')        # (None,L,1)

    # 分支 1：2D 卷积
    x1 = Conv2D(50, (2, 8), padding='same', activation='relu')(input1)
    x1 = BatchNormalization()(x1)
    x1 = Reshape((input_shape1[1], 100))(x1)        # (None,L,100)

    # 分支 2/3：1D 卷积
    x2 = Conv1D(50, 8, padding='same', activation='relu')(input2)
    x2 = BatchNormalization()(x2)
    x3 = Conv1D(50, 8, padding='same', activation='relu')(input3)
    x3 = BatchNormalization()(x3)

    # 合并
    x = concatenate([x1, x2, x3], axis=-1)          # (None,L,200)
    x = Conv1D(100, 5, activation='relu')(x)        # (None,L-4,100)

    # 时序
    x = LSTM(128, return_sequences=True)(x)
    x = LSTM(128)(x)

    # 分类头
    x = Dense(128, activation='relu')(x)
    x = Dropout(0.5)(x)
    x = Dense(128, activation='relu')(x)
    x = Dropout(0.5)(x)
    outputs = Dense(num_classes, activation='softmax')(x)

    model = Model(inputs=[input1, input2, input3], outputs=outputs)
    return model

# ---------- 数据增强 ----------
@tf.function
def aug_fn(x1, x2, x3):
    """x1:(...,2,L,1)  x2/x3:(...,L,1)  支持任意前缀维度"""
    x1 = tf.cast(x1, tf.float32)
    x2 = tf.cast(x2, tf.float32)
    x3 = tf.cast(x3, tf.float32)

    # 同步噪声
    if tf.random.uniform([]) > 0.2:
        snr = tf.random.uniform([], 5., 25.)
        noise = tf.random.normal(tf.shape(x2)) * tf.math.reduce_std(x2) * (10.0 ** (-snr / 20.0))
        x2 = x2 + noise
        x3 = x3 + noise

    # 同步移位
    if tf.random.uniform([]) > 0.3:
        shift = tf.random.uniform([], -100, 100, dtype=tf.int32)
        x2 = tf.roll(x2, shift, axis=-2)
        x3 = tf.roll(x3, shift, axis=-2)

    # 同步缩放
    if tf.random.uniform([]) > 0.3:
        scale = tf.random.uniform([], 0.7, 1.3)
        x2 = x2 * scale
        x3 = x3 * scale

    # 重新拼回 (...,2,L,1)
    x1 = tf.stack([tf.squeeze(x2, axis=-1),
                   tf.squeeze(x3, axis=-1)], axis=-2)[..., tf.newaxis]
    return x1, x2, x3

# ---------- 训练 ----------
def train_single_model(data, model_idx=0, epochs=120, batch=128):
    model = build_mcldnn(input_shape1=(2, data["L"], 1),
                         input_shape2=(data["L"], 1),
                         input_shape3=(data["L"], 1),
                         num_classes=len(data["label2name"]))

    # ---- 类别权重（字典） ----
    cls_weights = compute_class_weight('balanced',
                                     classes=np.unique(data['y_train']),
                                     y=data['y_train'])
    class_weight_dict = {i: float(w) for i, w in enumerate(cls_weights)}

    # ---- 数据管道（无 sample_weight） ----
    def create_dataset(X1, X2, X3, y):
        def map_func(x, y):
            x1, x2, x3 = aug_fn(x['complex_input'], x['I_input'], x['Q_input'])
            return {'complex_input': x1, 'I_input': x2, 'Q_input': x3}, y

        ds = tf.data.Dataset.from_tensor_slices(
            ({'complex_input': X1, 'I_input': X2, 'Q_input': X3}, y))
        return ds.map(map_func, num_parallel_calls=tf.data.AUTOTUNE)

    train_ds = create_dataset(data['X1_train'], data['X2_train'], data['X3_train'], data['y_train'])
    train_ds = train_ds.shuffle(10000).batch(batch).prefetch(tf.data.AUTOTUNE)

    val_ds = create_dataset(data['X1_val'], data['X2_val'], data['X3_val'], data['y_val'])
    val_ds = val_ds.batch(batch).prefetch(tf.data.AUTOTUNE)

    # ---- 回调 ----
    os.makedirs("models", exist_ok=True)
    ckpt = f"models/mcldnn_classifier_{model_idx}.keras"
    callbacks = [
        EarlyStopping(monitor='val_sparse_categorical_accuracy', patience=20, restore_best_weights=True),
        ReduceLROnPlateau(monitor='val_sparse_categorical_accuracy', factor=0.5, patience=10),
        ModelCheckpoint(ckpt, save_best_only=True, monitor='val_sparse_categorical_accuracy', save_format="tf")
    ]

    model.compile(optimizer=Adam(1e-3),
                loss=SparseCategoricalCrossentropy(),
                metrics=[SparseCategoricalAccuracy()])

    print(f"\n🔥 训练 MCLDNN 模型 {model_idx + 1}...")
    history = model.fit(train_ds,
                      validation_data=val_ds,
                      epochs=epochs,
                      callbacks=callbacks,
                      class_weight=class_weight_dict,   # ← 关键
                      verbose=1)
    return model

# ---------- 主函数 ----------
def main():
    os.makedirs("models", exist_ok=True)
    print("=" * 80)
    print("🚀 开始训练 MCLDNN 干扰分类模型")
    print("=" * 80)

    # 加载数据
    print("⏳ 加载数据集...")
    dataset = load_dataset("/root/yxun/20250826/dataset/interference_signals_natural_same_freq_1019.npz")
    data = preprocess_data(dataset)

    # 数据集划分
    X1_train, X1_test, y_train, y_test = train_test_split(
        data["X1"], data["y"], test_size=0.3, random_state=42, stratify=data["y"])
    X1_val, X1_test, y_val, y_test = train_test_split(
        X1_test, y_test, test_size=0.5, random_state=42, stratify=y_test)

    # 拆分 I/Q
    X2_train, X3_train = X1_train[:, 0, :, :], X1_train[:, 1, :, :]
    X2_val, X3_val = X1_val[:, 0, :, :], X1_val[:, 1, :, :]
    X2_test, X3_test = X1_test[:, 0, :, :], X1_test[:, 1, :, :]

    train_data = {
        "X1_train": X1_train, "X2_train": X2_train, "X3_train": X3_train,
        "y_train": y_train,
        "X1_val": X1_val, "X2_val": X2_val, "X3_val": X3_val,
        "y_val": y_val,
        "X1_test": X1_test, "X2_test": X2_test, "X3_test": X3_test,
        "y_test": y_test,
        "label2name": data["label2name"],
        "L": data["L"]
    }

    # 训练
    model = train_single_model(train_data, model_idx=0, epochs=120, batch=128)

    # 测试评估
    test_loss, test_acc = model.evaluate(
        {"complex_input": X1_test, "I_input": X2_test, "Q_input": X3_test},
        y_test, verbose=0)
    print(f"\n📊 测试集准确率: {test_acc:.4f}")

    # 保存最终模型
    model.save("models/mcldnn_classifier_final.keras", save_format="tf")
    print("\n✅ 训练完成！模型已保存至 models/mcldnn_classifier_final.keras")

if __name__ == "__main__":
    main()

✅ 成功设置中文字体: ['WenQuanYi Micro Hei']
🚀 开始训练 MCLDNN 干扰分类模型
⏳ 加载数据集...


2025-10-21 20:13:30.786106: I tensorflow/core/common_runtime/executor.cc:1197] [/device:CPU:0] (DEBUG INFO) Executor start aborting (this does not indicate an error and you can ignore this message): INVALID_ARGUMENT: You must feed a value for placeholder tensor 'gradients/split_2_grad/concat/split_2/split_dim' with dtype int32
	 [[{{node gradients/split_2_grad/concat/split_2/split_dim}}]]
2025-10-21 20:13:30.788660: I tensorflow/core/common_runtime/executor.cc:1197] [/device:CPU:0] (DEBUG INFO) Executor start aborting (this does not indicate an error and you can ignore this message): INVALID_ARGUMENT: You must feed a value for placeholder tensor 'gradients/split_grad/concat/split/split_dim' with dtype int32
	 [[{{node gradients/split_grad/concat/split/split_dim}}]]
2025-10-21 20:13:30.789541: I tensorflow/core/common_runtime/executor.cc:1197] [/device:CPU:0] (DEBUG INFO) Executor start aborting (this does not indicate an error and you can ignore this message): INVALID_ARGUMENT: You mus


🔥 训练 MCLDNN 模型 1...
Epoch 1/120


2025-10-21 20:13:32.328606: I tensorflow/core/common_runtime/executor.cc:1197] [/device:CPU:0] (DEBUG INFO) Executor start aborting (this does not indicate an error and you can ignore this message): INVALID_ARGUMENT: You must feed a value for placeholder tensor 'Placeholder/_0' with dtype float and shape [56700,1024,1]
	 [[{{node Placeholder/_0}}]]
2025-10-21 20:13:32.328867: I tensorflow/core/common_runtime/executor.cc:1197] [/device:CPU:0] (DEBUG INFO) Executor start aborting (this does not indicate an error and you can ignore this message): INVALID_ARGUMENT: You must feed a value for placeholder tensor 'Placeholder/_1' with dtype float and shape [56700,1024,1]
	 [[{{node Placeholder/_1}}]]
2025-10-21 20:13:32.555584: I tensorflow/core/common_runtime/executor.cc:1197] [/device:CPU:0] (DEBUG INFO) Executor start aborting (this does not indicate an error and you can ignore this message): INVALID_ARGUMENT: You must feed a value for placeholder tensor 'gradients/split_2_grad/concat/split



2025-10-21 20:14:40.176986: I tensorflow/core/common_runtime/executor.cc:1197] [/device:CPU:0] (DEBUG INFO) Executor start aborting (this does not indicate an error and you can ignore this message): INVALID_ARGUMENT: You must feed a value for placeholder tensor 'Placeholder/_0' with dtype float and shape [12150,1024,1]
	 [[{{node Placeholder/_0}}]]
2025-10-21 20:14:40.177198: I tensorflow/core/common_runtime/executor.cc:1197] [/device:CPU:0] (DEBUG INFO) Executor start aborting (this does not indicate an error and you can ignore this message): INVALID_ARGUMENT: You must feed a value for placeholder tensor 'Placeholder/_3' with dtype int32 and shape [12150]
	 [[{{node Placeholder/_3}}]]
2025-10-21 20:14:40.481266: I tensorflow/core/common_runtime/executor.cc:1197] [/device:CPU:0] (DEBUG INFO) Executor start aborting (this does not indicate an error and you can ignore this message): INVALID_ARGUMENT: You must feed a value for placeholder tensor 'gradients/split_2_grad/concat/split_2/spli

Epoch 2/120
Epoch 3/120
Epoch 4/120
Epoch 5/120
Epoch 6/120
Epoch 7/120
Epoch 8/120
Epoch 9/120
Epoch 10/120
Epoch 11/120
Epoch 12/120
Epoch 13/120
Epoch 14/120
Epoch 15/120
Epoch 16/120
Epoch 17/120
Epoch 18/120
Epoch 19/120
Epoch 20/120
Epoch 21/120
Epoch 22/120
Epoch 23/120
Epoch 24/120
Epoch 25/120
Epoch 26/120
Epoch 27/120
Epoch 28/120
Epoch 29/120
Epoch 30/120
Epoch 31/120
Epoch 32/120
Epoch 33/120
Epoch 34/120
Epoch 35/120
Epoch 36/120
Epoch 37/120
Epoch 38/120
Epoch 39/120
Epoch 40/120
Epoch 41/120
Epoch 42/120
Epoch 43/120
Epoch 44/120
Epoch 45/120
Epoch 46/120
Epoch 47/120
Epoch 48/120
Epoch 49/120
Epoch 50/120
Epoch 51/120
Epoch 52/120
Epoch 53/120
Epoch 54/120
Epoch 55/120
Epoch 56/120
Epoch 57/120
Epoch 58/120
Epoch 59/120
Epoch 60/120
Epoch 61/120
Epoch 62/120
Epoch 63/120
Epoch 64/120
Epoch 65/120
Epoch 66/120
Epoch 67/120
Epoch 68/120
Epoch 69/120
Epoch 70/120
Epoch 71/120
Epoch 72/120
Epoch 73/120
Epoch 74/120
Epoch 75/120
Epoch 76/120
Epoch 77/120
Epoch 78/120
Epoch 7

In [3]:
"""
evaluate_mcldnn.py
MCLDNN 模型评估脚本
"""
import os
import json
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import (confusion_matrix, accuracy_score, precision_score,
                             recall_score, f1_score, mean_absolute_error)
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from matplotlib import font_manager as fm
from scipy.signal import hilbert
import time

# ----------------------
# ✅ 1. 设置中文字体（可选）
# ----------------------
def set_chinese_font():
    try:
        font_paths = ['/usr/share/fonts/truetype/wqy/wqy-microhei.ttc',
                      'C:/Windows/Fonts/simhei.ttf',
                      '/System/Library/Fonts/PingFang.ttc']
        for font_path in font_paths:
            if os.path.exists(font_path):
                fm.fontManager.addfont(font_path)
                plt.rcParams['font.family'] = fm.FontProperties(fname=font_path).get_name()
                plt.rcParams['axes.unicode_minus'] = False
                print(f"✅ 成功设置中文字体: {plt.rcParams['font.family']}")
                return True
        plt.rcParams['font.family'] = ['SimHei', 'Arial Unicode MS']
        plt.rcParams['axes.unicode_minus'] = False
        print("⚠️ 未找到指定字体，使用默认兼容字体")
        return True
    except Exception as e:
        print(f"❌ 字体设置失败: {e}")
        return False
set_chinese_font()

# -------------------------------
# ✅ 2. 加载数据集
# -------------------------------
def load_dataset(npz_path="/root/yxun/20250826/dataset/interference_signals_natural_same_freq_1019.npz"):
    data = np.load(npz_path, allow_pickle=True)
    signals = data["signals"]
    labels = data["labels"].astype(np.int32)
    jnr_vals = data["jnr_values"].astype(np.float32)
    fs = float(data["fs"])
    L = int(data["L"])
    metadata = data["metadata"]
    type2label = data["type_to_label"].item()
    label2type = {v: k for k, v in type2label.items()}  # 从type2label创建label2type
    type2name = data["interference_type_names"].item()
    label2name = {i: type2name[k] for k, i in type2label.items()}
    return {
        "signals": signals,
        "labels": labels,
        "jnr_values": jnr_vals,
        "fs": fs,
        "L": L,
        "metadata": metadata,
        "type2label": type2label,
        "label2type": label2type,  # 添加这个字段
        "label2name": label2name,
        "type2name": type2name
    }

# -------------------------------
# ✅ 3. 数据预处理 (与训练代码保持一致)
# -------------------------------
def preprocess_data(dataset):
    signals = dataset["signals"]
    labels = dataset["labels"]
    jnr_values = dataset["jnr_values"]
    metadata = dataset["metadata"]
    type2label = dataset["type2label"]
    label2type = dataset["label2type"]

    # 标准化
    signals = np.array([StandardScaler().fit_transform(s.reshape(-1, 1)).ravel() for s in signals])

    # 复信号 I/Q 构造
    signals_complex = np.zeros((signals.shape[0], 2, signals.shape[1]))
    for i in range(signals.shape[0]):
        analytic_signal = hilbert(signals[i])
        signals_complex[i, 0, :] = signals[i]            # I: 实部
        signals_complex[i, 1, :] = np.imag(analytic_signal)  # Q: 虚部

    # 模型输入1: [N, 2, 1024, 1]
    X1 = np.expand_dims(signals_complex, axis=-1)  # [N, 2, 1024, 1]

    # 模型输入2, 3: I路 [N, 1024, 1], Q路 [N, 1024, 1]
    X2 = signals_complex[:, 0, :, None]  # I: [N, 1024, 1]
    X3 = signals_complex[:, 1, :, None]  # Q: [N, 1024, 1]

    # 标签处理 (与训练代码保持一致)
    no_interference_key = "satellite_signal"
    det_labels = (labels != list(dataset["type2label"].values())[0]).astype(np.int32)

    y_type = np.array([dataset["type2label"][label2type.get(label, 0)] for label in labels])

    # 参数标签
    param_labels = []
    for m in metadata:
        p = m.get("params", {})
        start = float(p.get("start_time", 0))
        end = float(p.get("end_time", 0))
        strength = float(p.get("jnr_db", 0))
        param_labels.append([start, end, strength])
    y_param = np.array(param_labels, dtype=np.float32)

    return {
        'X1': X1,  # [N, 2, 1024, 1]
        'X2': X2,  # [N, 1024, 1]
        'X3': X3,  # [N, 1024, 1]
        'y_det': det_labels,
        'y_type': y_type,
        'y_param': y_param,
        'jnr_values': jnr_values
    }

# -------------------------------
# ✅ 4. 划分数据 (与训练代码保持一致)
# -------------------------------
def split_data(data):
    X1 = data['X1']
    X2 = data['X2']
    X3 = data['X3']
    y_det = data['y_det']
    y_type = data['y_type']
    y_param = data['y_param']
    jnr_values = data['jnr_values']

    X1_train, X1_tmp, X2_train, X2_tmp, X3_train, X3_tmp, \
    y_det_train, y_det_tmp, y_type_train, y_type_tmp, \
    y_param_train, y_param_tmp, jnr_train, jnr_tmp = train_test_split(
        X1, X2, X3, y_det, y_type, y_param, jnr_values, test_size=0.2, random_state=42
    )

    X1_val, X1_test, X2_val, X2_test, X3_val, X3_test, \
    y_det_val, y_det_test, y_type_val, y_type_test, \
    y_param_val, y_param_test, jnr_val, jnr_test = train_test_split(
        X1_tmp, X2_tmp, X3_tmp, y_det_tmp, y_type_tmp, y_param_tmp, jnr_tmp, test_size=0.5, random_state=42
    )

    return {
        'train': {'X1': X1_train, 'X2': X2_train, 'X3': X3_train, 'y_det': y_det_train, 'y_type': y_type_train, 'y_param': y_param_train, 'jnr': jnr_train},
        'val': {'X1': X1_val, 'X2': X2_val, 'X3': X3_val, 'y_det': y_det_val, 'y_type': y_type_val, 'y_param': y_param_val, 'jnr': jnr_val},
        'test': {'X1': X1_test, 'X2': X2_test, 'X3': X3_test, 'y_det': y_det_test, 'y_type': y_type_test, 'y_param': y_param_test, 'jnr': jnr_test}
    }

# ----------------------
# ✅ 5. 加载模型并预测 (修复版本)
# ----------------------
def load_and_predict(model_path, X1_test, X2_test, X3_test):
    print(f"🔁 加载模型: {model_path}")
    model = tf.keras.models.load_model(model_path)
    start_time = time.time()
    # 修复：根据实际模型输出调整
    predictions = model.predict([X1_test, X2_test, X3_test], verbose=0)
    end_time = time.time()
    prediction_time = end_time - start_time
    print(f"预测时间: {prediction_time:.2f} 秒")
    
    # 检查模型输出
    print(f"模型输出类型: {type(predictions)}")
    if isinstance(predictions, list):
        print(f"模型输出长度: {len(predictions)}")
        for i, pred in enumerate(predictions):
            print(f"  输出 {i} 形状: {pred.shape}")
            # 检查是否有无效值
            print(f"    NaN 数量: {np.sum(np.isnan(pred))}")
            print(f"    Inf 数量: {np.sum(np.isinf(pred))}")
            print(f"    范围: [{np.min(pred[np.isfinite(pred)]) if np.any(np.isfinite(pred)) else 'N/A'}, {np.max(pred[np.isfinite(pred)]) if np.any(np.isfinite(pred)) else 'N/A'}]")
    else:
        print(f"模型输出形状: {predictions.shape}")
        print(f"  NaN 数量: {np.sum(np.isnan(predictions))}")
        print(f"  Inf 数量: {np.sum(np.isinf(predictions))}")
        print(f"  范围: [{np.min(predictions[np.isfinite(predictions)]) if np.any(np.isfinite(predictions)) else 'N/A'}, {np.max(predictions[np.isfinite(predictions)]) if np.any(np.isfinite(predictions)) else 'N/A'}]")
    
    # 根据模型实际输出调整返回值
    if isinstance(predictions, list) and len(predictions) == 2:
        # 如果模型确实返回两个输出（分类和回归）
        cls_pred, reg_pred = predictions[0], predictions[1]
        
        # 处理回归输出中的无效值
        if np.any(np.isnan(reg_pred)) or np.any(np.isinf(reg_pred)):
            print("⚠️  回归输出中发现无效值，将用0替换")
            reg_pred = np.nan_to_num(reg_pred, nan=0.0, posinf=0.0, neginf=0.0)
        
        return cls_pred, reg_pred
    else:
        # 如果模型只返回分类输出
        # 创建一个默认的回归输出（或者根据需要调整）
        dummy_reg = np.zeros((predictions.shape[0], 3))  # 假设3个参数
        return predictions, dummy_reg

# ----------------------
# ✅ 6. 绘制混淆矩阵（支持归一化、x轴旋转）
# ----------------------
def plot_confusion_matrix(cm, labels, title, xlabel, ylabel, filename, dpi=150, rotate_x=False):
    cm_normalized = cm.astype('float') / cm.sum(axis=1, keepdims=True)
    cm_normalized = np.nan_to_num(cm_normalized)

    plt.figure(figsize=(12, 10))
    ax = sns.heatmap(cm_normalized,
                     annot=True,
                     fmt='.2f',
                     cmap='Blues',
                     xticklabels=labels,
                     yticklabels=labels,
                     square=True,
                     annot_kws={"size": 14})

    # ---------- 关键：colorbar 字体 ----------
    cbar = ax.collections[0].colorbar
    cbar.ax.tick_params(labelsize=14)   # 刻度字号
    # ----------------------------------------

    plt.title(title, pad=20, fontsize=18)
    plt.xlabel(xlabel, fontsize=16)
    plt.ylabel(ylabel, fontsize=16)
    plt.xticks(rotation=45 if rotate_x else 0, ha='right' if rotate_x else 'center', fontsize=14)
    plt.yticks(rotation=0, fontsize=14)
    plt.tight_layout()
    plt.savefig(filename, dpi=dpi)
    plt.close()

# ----------------------
# ✅ 7. 主评估函数 (修复版本)
# ----------------------
def evaluate(models_dir="models", npz_path="/root/yxun/20250826/dataset/interference_signals_natural_same_freq_1019.npz"):
    os.makedirs("visualizations", exist_ok=True)
    os.makedirs("reports", exist_ok=True)

    dataset = load_dataset(npz_path)
    data = preprocess_data(dataset)
    splits = split_data(data)

    # 使用测试集数据
    X1_test = splits['test']['X1']
    X2_test = splits['test']['X2']
    X3_test = splits['test']['X3']
    y_det_test = splits['test']['y_det']
    y_type_test = splits['test']['y_type']
    y_param_test = splits['test']['y_param']
    jnr_test = splits['test']['jnr']
    
    # 获取标签名称
    label2name = dataset["label2name"]

    # 修复：使用实际保存的模型文件名
    model_path = os.path.join(models_dir, "mcldnn_classifier_final.keras")

    # 加载模型并预测
    avg_cls, avg_reg = load_and_predict(model_path, X1_test, X2_test, X3_test)
    avg_cls_labels = np.argmax(avg_cls, axis=1)

    # 计算检测标签（0表示无干扰，非0表示有干扰）
    avg_det = (avg_cls_labels != 0).astype(int)

    # 1. 检测混淆矩阵
    cm_det = confusion_matrix(y_det_test, avg_det)
    plot_confusion_matrix(
        cm=cm_det,
        labels=['No Interference', 'Interference'],
        title='MCLDNN Interference Detection Confusion Matrix',
        xlabel='Predicted',
        ylabel='True',
        filename='visualizations/MCLDNN detection_confusion_matrix.png',
        dpi=150,
        rotate_x=False
    )

    # 2. 分类混淆矩阵（x轴旋转45度）
    cm_type = confusion_matrix(y_type_test, avg_cls_labels)
    plot_confusion_matrix(
        cm=cm_type,
        labels=[label2name[i] for i in sorted(label2name.keys())],
        title='MCLDNN Classification Confusion Matrix',
        xlabel='Predicted',
        ylabel='True',
        filename='visualizations/MCLDNN interference_type_confusion_matrix.png',
        dpi=150,
        rotate_x=True
    )

    # 3. JNR vs 准确率
    wanted_jnr = np.arange(-10, 31, 5)
    jnr_acc = []
    for jnr in wanted_jnr:
        mask = jnr_test == jnr
        acc = np.nan if np.sum(mask) == 0 else accuracy_score(y_type_test[mask], avg_cls_labels[mask])
        jnr_acc.append(acc)

    plt.figure(figsize=(8, 5))
    valid_mask = ~np.isnan(jnr_acc)
    plt.plot(wanted_jnr[valid_mask], np.array(jnr_acc)[valid_mask], marker='o', linewidth=2)
    if np.any(~valid_mask):
        plt.scatter(wanted_jnr[~valid_mask], [1.0] * np.sum(~valid_mask), facecolors='none', edgecolors='r', s=60)
    plt.xlabel('JNR (dB)')
    plt.ylabel('Accuracy')
    plt.title('Classification Accuracy vs JNR')
    plt.xticks(wanted_jnr)
    plt.grid(True, linestyle='--', alpha=0.5)
    plt.ylim(0, 1.05)
    plt.tight_layout()
    plt.savefig('visualizations/MCLDNN jnr_vs_accuracy.png', dpi=150)
    plt.close()

    # 4. 指标 & 报告
    det_acc = accuracy_score(y_det_test, avg_det)
    cls_acc = accuracy_score(y_type_test, avg_cls_labels)
    cls_precision = precision_score(y_type_test, avg_cls_labels, average='weighted')
    cls_recall = recall_score(y_type_test, avg_cls_labels, average='weighted')
    cls_f1 = f1_score(y_type_test, avg_cls_labels, average='weighted')
    
    # 只有在回归输出有效时才计算参数MAE
    param_mae = [0, 0, 0]  # 默认值
    if avg_reg is not None and avg_reg.shape[0] == y_param_test.shape[0]:
        # 检查并处理无效值
        print(f"调试信息:")
        print(f"  y_param_test 范围: [{np.min(y_param_test)}, {np.max(y_param_test)}]")
        print(f"  avg_reg 范围: [{np.min(avg_reg)}, {np.max(avg_reg)}]")
        print(f"  y_param_test 中 NaN 数量: {np.sum(np.isnan(y_param_test))}")
        print(f"  y_param_test 中 Inf 数量: {np.sum(np.isinf(y_param_test))}")
        print(f"  avg_reg 中 NaN 数量: {np.sum(np.isnan(avg_reg))}")
        print(f"  avg_reg 中 Inf 数量: {np.sum(np.isinf(avg_reg))}")
        
        # 检查是否存在无穷大或NaN值
        invalid_mask = (np.isnan(y_param_test) | np.isinf(y_param_test) | 
                        np.isnan(avg_reg) | np.isinf(avg_reg))
        
        if np.any(invalid_mask):
            print(f"⚠️  发现 {np.sum(invalid_mask)} 个无效值")
            # 创建清理后的数据
            y_param_test_clean = np.where(invalid_mask, 0, y_param_test)
            avg_reg_clean = np.where(invalid_mask, 0, avg_reg)
        else:
            y_param_test_clean = y_param_test
            avg_reg_clean = avg_reg
        
        # 限制数值范围以适应float32
        float32_max = np.finfo(np.float32).max
        y_param_test_clean = np.clip(y_param_test_clean, -float32_max, float32_max)
        avg_reg_clean = np.clip(avg_reg_clean, -float32_max, float32_max)
        
        # 计算MAE
        try:
            param_mae = mean_absolute_error(y_param_test_clean, avg_reg_clean, multioutput='raw_values')
        except Exception as e:
            print(f"⚠️  计算参数MAE时出错: {e}")
            param_mae = [0, 0, 0]  # 使用默认值

    print("\n" + "="*50)
    print("📊 评估结果")
    print("="*50)
    print(f"检测准确率: {det_acc:.4f}")
    print(f"分类准确率: {cls_acc:.4f}")
    print(f"分类精确率: {cls_precision:.4f}, 召回率: {cls_recall:.4f}, F1: {cls_f1:.4f}")
    print(f"参数 MAE: 起始时间: {param_mae[0]:.6f}s, 结束时间: {param_mae[1]:.6f}s, 强度: {param_mae[2]:.4f}dB")
    print("\nJNR 准确率:")
    for j, acc in zip(wanted_jnr, jnr_acc):
        acc_str = f"{acc:.4f}" if not np.isnan(acc) else "N/A"
        print(f"  {int(j)}dB: {acc_str}")

    report = {
        "detection_accuracy": float(det_acc),
        "classification_accuracy": float(cls_acc),
        "classification_precision": float(cls_precision),
        "classification_recall": float(cls_recall),
        "classification_f1": float(cls_f1),
        "parameter_mae": [float(m) for m in param_mae],
        "jnr_accuracies": {f"{int(j)}dB": float(acc) if not np.isnan(acc) else None for j, acc in zip(wanted_jnr, jnr_acc)}
    }
    with open("reports/MCLDNN_evaluation_report.json", "w", encoding='utf-8') as f:
        json.dump(report, f, indent=4, ensure_ascii=False)

    print("\n✅ 评估完成！混淆矩阵与报告已保存。")

# ----------------------
# ✅ 8. 主函数
# ----------------------
if __name__ == "__main__":
    evaluate(models_dir="models", npz_path="/root/yxun/20250826/dataset/interference_signals_natural_same_freq_1019.npz")

✅ 成功设置中文字体: ['WenQuanYi Micro Hei']
🔁 加载模型: models/mcldnn_classifier_final.keras


2025-10-23 15:14:03.739042: I tensorflow/core/common_runtime/executor.cc:1197] [/device:CPU:0] (DEBUG INFO) Executor start aborting (this does not indicate an error and you can ignore this message): INVALID_ARGUMENT: You must feed a value for placeholder tensor 'gradients/split_2_grad/concat/split_2/split_dim' with dtype int32
	 [[{{node gradients/split_2_grad/concat/split_2/split_dim}}]]
2025-10-23 15:14:03.739950: I tensorflow/core/common_runtime/executor.cc:1197] [/device:CPU:0] (DEBUG INFO) Executor start aborting (this does not indicate an error and you can ignore this message): INVALID_ARGUMENT: You must feed a value for placeholder tensor 'gradients/split_grad/concat/split/split_dim' with dtype int32
	 [[{{node gradients/split_grad/concat/split/split_dim}}]]
2025-10-23 15:14:03.741100: I tensorflow/core/common_runtime/executor.cc:1197] [/device:CPU:0] (DEBUG INFO) Executor start aborting (this does not indicate an error and you can ignore this message): INVALID_ARGUMENT: You mus

预测时间: 10.53 秒
模型输出类型: <class 'numpy.ndarray'>
模型输出形状: (8100, 9)
  NaN 数量: 0
  Inf 数量: 0
  范围: [0.0, 1.0]
调试信息:
  y_param_test 范围: [-inf, 30.0]
  avg_reg 范围: [0.0, 0.0]
  y_param_test 中 NaN 数量: 0
  y_param_test 中 Inf 数量: 906
  avg_reg 中 NaN 数量: 0
  avg_reg 中 Inf 数量: 0
⚠️  发现 906 个无效值

📊 评估结果
检测准确率: 0.9585
分类准确率: 0.8974
分类精确率: 0.9050, 召回率: 0.8974, F1: 0.8978
参数 MAE: 起始时间: 0.126749s, 结束时间: 0.763075s, 强度: 11.9580dB

JNR 准确率:
  -10dB: 0.5164
  -5dB: 0.7746
  0dB: 0.9163
  5dB: 0.9484
  10dB: 0.9664
  15dB: 0.9761
  20dB: 0.9889
  25dB: 0.9838
  30dB: 0.9891

✅ 评估完成！混淆矩阵与报告已保存。
