In [1]:
# PLE_final_1019.py
import os
import json
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import (confusion_matrix, accuracy_score, precision_score,
                             recall_score, f1_score, mean_absolute_error)
from sklearn.utils.class_weight import compute_class_weight
from tensorflow.keras.models import Model
from tensorflow.keras.layers import (Input, Conv1D, BatchNormalization, MaxPooling1D,
                                     GlobalAveragePooling1D, Dense, Reshape, Layer)
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ReduceLROnPlateau, ModelCheckpoint
from tensorflow.keras.losses import BinaryCrossentropy, SparseCategoricalCrossentropy, MeanSquaredError
from tensorflow.keras.metrics import BinaryAccuracy, SparseCategoricalAccuracy
import time

# ---------------------- 中文字体 ----------------------
def set_chinese_font():
    try:
        plt.rcParams['font.sans-serif'] = ['SimHei', 'WenQuanYi Micro Hei', 'Arial Unicode MS']
        plt.rcParams['axes.unicode_minus'] = False
        print("✅ 成功设置中文字体")
        return True
    except Exception as e:
        print(f"❌ 字体设置失败: {e}")
        return False

set_chinese_font()

# ---------------------- 自定义 PLE 层 ----------------------
class PLELayer(Layer):
    def __init__(self, num_tasks, num_shared_experts, num_task_experts, expert_dim, **kwargs):
        super(PLELayer, self).__init__(**kwargs)
        self.num_tasks = num_tasks
        self.num_shared_experts = num_shared_experts
        self.num_task_experts = num_task_experts
        self.expert_dim = expert_dim
        self.num_experts = num_shared_experts + num_task_experts
        
    def build(self, input_shape):
        input_dim = input_shape[-1]
        
        # 共享专家网络
        self.shared_experts = [
            Dense(self.expert_dim, activation='relu', name=f'shared_expert_{i}')
            for i in range(self.num_shared_experts)
        ]
        
        # 任务特定专家网络
        self.task_experts = [
            [
                Dense(self.expert_dim, activation='relu', name=f'task_{task}_expert_{i}')
                for i in range(self.num_task_experts)
            ]
            for task in range(self.num_tasks)
        ]
        
        # 门控网络
        self.gates = [
            Dense(self.num_experts, activation='softmax', name=f'gate_{task}')
            for task in range(self.num_tasks)
        ]
        
        super(PLELayer, self).build(input_shape)
        
    def call(self, inputs):
        # 生成共享专家输出
        shared_expert_outputs = [
            expert(inputs) for expert in self.shared_experts
        ]
        
        outputs = []
        for task in range(self.num_tasks):
            # 生成任务特定专家输出
            task_expert_outputs = [
                expert(inputs) for expert in self.task_experts[task]
            ]
            
            # 合并所有专家输出
            all_expert_outputs = shared_expert_outputs + task_expert_outputs  # (num_experts, batch, expert_dim)
            all_experts_concat = tf.stack(all_expert_outputs, axis=1)  # (batch, num_experts, expert_dim)
            
            # 计算门控权重
            gate_weights = self.gates[task](inputs)  # (batch, num_experts)
            gate_weights = tf.expand_dims(gate_weights, axis=-1)  # (batch, num_experts, 1)
            
            # 加权组合专家输出
            weighted_output = tf.reduce_sum(all_experts_concat * gate_weights, axis=1)  # (batch, expert_dim)
            outputs.append(weighted_output)
            
        return outputs

# ---------------------- 加载数据集 ----------------------
def load_dataset(npz_path="/root/yxun/20250826/dataset/interference_signals_natural_same_freq_1019.npz"):
    data = np.load(npz_path, allow_pickle=True)
    signals = data["signals"]
    labels = data["labels"].astype(np.int32)
    jnr_vals = data["jnr_values"].astype(np.float32)
    fs = float(data["fs"])
    L = int(data["L"])
    metadata = data["metadata"]
    type2label = data["type_to_label"].item()
    label2type = {v: k for k, v in type2label.items()}
    type2name = data["interference_type_names"].item()
    label2name = {i: type2name[k] for k, i in type2label.items()}
    
    return {
        "signals": signals,
        "labels": labels,
        "jnr_values": jnr_vals,
        "fs": fs,
        "L": L,
        "metadata": metadata,
        "type2label": type2label,
        "label2type": label2type,
        "type2name": type2name,
        "label2name": label2name
    }

# ---------------------- 数据预处理 ----------------------
def preprocess_data(dataset):
    signals = dataset["signals"]
    labels = dataset["labels"]
    jnr_values = dataset["jnr_values"]
    metadata = dataset["metadata"]
    L = dataset["L"]

    # 1. 标准化信号
    signals = np.array([StandardScaler().fit_transform(s.reshape(-1, 1)).ravel() for s in signals])

    # 2. 构建检测标签（卫星信号为无干扰，其他为有干扰）
    no_key = "satellite_signal"
    det_labels = (labels != dataset["type2label"][no_key]).astype(np.float32)

    # 3. 构建参数标签（与 completely_fixed_model.py 保持一致）
    param_labels = []
    for sig in signals:
        power = sig ** 2
        threshold = 2.0 * np.mean(power)
        above_threshold = power > threshold

        if np.any(above_threshold):
            start_idx = np.argmax(above_threshold)
            end_idx = len(above_threshold) - np.argmax(above_threshold[::-1]) - 1
        else:
            start_idx, end_idx = 0, len(sig) - 1

        start_time = start_idx / dataset["fs"] * 1e3      # ms
        end_time   = end_idx   / dataset["fs"] * 1e3
        jnr_db = 10 * np.log10(np.mean(power) + 1e-12) - dataset.get("noise_power_db", 0)

        param_labels.append([start_time, end_time, jnr_db])
    param_labels = np.array(param_labels, dtype=np.float32)

    # 4. 过滤无效样本 —— 与 completely_fixed_model.py 完全一致
    mask = ~(
        np.any(np.isnan(signals), axis=1) |
        np.any(np.isinf(signals), axis=1) |
        np.isnan(det_labels) | np.isinf(det_labels) |
        np.isnan(labels) | np.isinf(labels) |
        np.any(np.isnan(param_labels), axis=1) | np.any(np.isinf(param_labels), axis=1)
    )

    print(f"🧹 丢弃 {np.sum(~mask)} / {len(mask)} 条无效样本")
    signals, det_labels, labels, param_labels, jnr_values = \
        signals[mask], det_labels[mask], labels[mask], param_labels[mask], jnr_values[mask]

    # 5. 数据集分割（保持原逻辑）
    X_train, X_tmp, y_det_train, y_det_tmp, y_type_train, y_type_tmp, y_param_train, y_param_tmp, jnr_train, jnr_tmp = train_test_split(
        signals, det_labels, labels, param_labels, jnr_values,
        test_size=0.3, random_state=42, stratify=labels
    )

    X_val, X_test, y_det_val, y_det_test, y_type_val, y_type_test, y_param_val, y_param_test, jnr_val, jnr_test = train_test_split(
        X_tmp, y_det_tmp, y_type_tmp, y_param_tmp, jnr_tmp,
        test_size=2/3, random_state=42, stratify=y_type_tmp
    )

    return {
        "X_train": X_train, "X_val": X_val, "X_test": X_test,
        "y_det_train": y_det_train, "y_det_val": y_det_val, "y_det_test": y_det_test,
        "y_type_train": y_type_train, "y_type_val": y_type_val, "y_type_test": y_type_test,
        "y_param_train": y_param_train, "y_param_val": y_param_val, "y_param_test": y_param_test,
        "jnr_values_train": jnr_train, "jnr_values_val": jnr_val, "jnr_values_test": jnr_test,
        "type2label": dataset["type2label"], "label2name": dataset["label2name"],
        "L": L
    }

# ---------------------- PLE 模型结构 ----------------------
def build_ple_model(input_shape, num_classes, num_shared_experts=2, num_task_experts=1, expert_dim=64):
    inputs = Input(shape=input_shape, dtype=tf.float32)
    
    # 共享主干特征提取
    x = Reshape((input_shape[0], 1))(inputs)
    x = Conv1D(64, 7, activation='relu', padding='same')(x)
    x = BatchNormalization()(x)
    x = MaxPooling1D(2)(x)
    x = Conv1D(128, 5, activation='relu', padding='same')(x)
    x = BatchNormalization()(x)
    x = MaxPooling1D(2)(x)
    x = Conv1D(256, 3, activation='relu', padding='same')(x)
    x = BatchNormalization()(x)
    shared_features = GlobalAveragePooling1D()(x)
    
    # PLE层 - 3个任务: 检测、分类、回归
    ple_outputs = PLELayer(
        num_tasks=3,
        num_shared_experts=num_shared_experts,
        num_task_experts=num_task_experts,
        expert_dim=expert_dim
    )(shared_features)
    
    # 任务特定特征处理
    det_feat = Dense(64, activation='relu')(ple_outputs[0])
    cls_feat = Dense(64, activation='relu')(ple_outputs[1])
    reg_feat = Dense(64, activation='relu')(ple_outputs[2])
    
    # 任务输出层
    det_out = Dense(1, activation='sigmoid', name='detection_output')(det_feat)
    cls_out = Dense(num_classes, activation='softmax', name='classification_output')(cls_feat)
    reg_out = Dense(3, activation='sigmoid', name='regression_output')(reg_feat)  # 使用sigmoid确保输出在[0,1]
    
    model = Model(inputs, [det_out, cls_out, reg_out])
    
    # 编译配置
    model.compile(
        optimizer=Adam(1e-3),
        loss={
            'detection_output': BinaryCrossentropy(),
            'classification_output': SparseCategoricalCrossentropy(),
            'regression_output': MeanSquaredError()
        },
        loss_weights={
            'detection_output': 0.8,
            'classification_output': 2.0,
            'regression_output': 0.5
        },
        metrics={
            'detection_output': BinaryAccuracy(),
            'classification_output': SparseCategoricalAccuracy(),
            'regression_output': 'mae'
        }
    )
    
    return model

# ---------------------- 训练单个 PLE 模型 ----------------------
def train_single_model(data, model_idx, epochs=120, batch=128):
    input_shape, num_classes = (data["L"],), len(data["type2label"])
    model = build_ple_model(input_shape, num_classes)
    
    # 计算分类任务的类别权重
    cls_weights = compute_class_weight(
        'balanced',
        classes=np.unique(data['y_type_train']),
        y=data['y_type_train']
    )
    cls_weights_dict = dict(enumerate(cls_weights))
    
    # 创建数据集
    ds_train = tf.data.Dataset.from_tensor_slices((
        data['X_train'],
        {
            'detection_output': data['y_det_train'],
            'classification_output': data['y_type_train'],
            'regression_output': data['y_param_train']
        }
    )).batch(batch).prefetch(tf.data.AUTOTUNE)
    
    ds_val = tf.data.Dataset.from_tensor_slices((
        data['X_val'],
        {
            'detection_output': data['y_det_val'],
            'classification_output': data['y_type_val'],
            'regression_output': data['y_param_val']
        }
    )).batch(batch).prefetch(tf.data.AUTOTUNE)
    
    # 设置回调函数
    ckpt = f"models/ple_model_1019_{model_idx}.keras"
    callbacks = [
        ReduceLROnPlateau(monitor='val_classification_output_sparse_categorical_accuracy',
                         factor=0.5, patience=10, min_lr=1e-5, mode='max', verbose=1),
        ModelCheckpoint(ckpt, save_best_only=True, save_weights_only=False,
                       monitor='val_classification_output_sparse_categorical_accuracy', 
                       mode='max', verbose=1)
    ]
    
    print(f"\n🔥 训练第 {model_idx + 1} 个PLE模型（适配v3归一化数据）...")
    start_time = time.time()
    history = model.fit(
        ds_train,
        validation_data=ds_val,
        epochs=epochs,
        callbacks=callbacks,
        verbose=1
    )
    end_time = time.time()
    training_time = end_time - start_time
    print(f"✅ 模型 {model_idx + 1} 训练完成，耗时 {training_time:.2f} 秒")
    
    return model, training_time

# ---------------------- 训练多个 PLE 模型（Ensemble） ----------------------
def train_ensemble(data, n_models=3, epochs=120, batch=128):
    models = []
    training_times = []
    for i in range(n_models):
        model, training_time = train_single_model(data, i, epochs=epochs, batch=batch)
        models.append(model)
        training_times.append(training_time)
    return models, training_times

# ---------------------- 主函数 ----------------------
def main():
    for d in ["models", "visualizations", "reports"]:
        os.makedirs(d, exist_ok=True)
    
    print("=" * 80)
    print("🚀 开始训练 PLE 多任务模型（检测/分类/回归，渐进分层提取，适配v3归一化数据）")
    print("=" * 80)
    
    # 加载并预处理数据
    dataset = load_dataset()
    data = preprocess_data(dataset)
    
    # 训练 Ensemble 模型
    n_models = 3
    epochs = 120
    batch_size = 128
    
    models, training_times = train_ensemble(data, n_models=n_models, epochs=epochs, batch=batch_size)
    
    # 打印训练时间
    print("\n" + "="*50)
    print("📈 训练时间统计")
    print("="*50)
    for i, t in enumerate(training_times):
        print(f"模型 {i + 1}: {t:.2f} 秒 ({t / 60:.2f} 分钟)")
    print(f"🔥 总训练时间: {sum(training_times):.2f} 秒 ({sum(training_times) / 60:.2f} 分钟)")
    print("="*50)
    
    print("\n✅ 所有PLE模型训练完成！模型已保存至 models/ple_model_1019_{i}.keras")
    print("📌 适配v3归一化数据的训练完成")

if __name__ == "__main__":
    main()

2025-10-23 19:27:15.864215: I tensorflow/core/util/port.cc:110] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-10-23 19:27:15.903730: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 AVX_VNNI AMX_TILE AMX_INT8 AMX_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


✅ 成功设置中文字体
🚀 开始训练 PLE 多任务模型（检测/分类/回归，渐进分层提取，适配v3归一化数据）
🧹 丢弃 0 / 81000 条无效样本


2025-10-23 19:27:38.449544: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1635] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 46590 MB memory:  -> device: 0, name: NVIDIA vGPU-48GB, pci bus id: 0000:16:00.0, compute capability: 8.9



🔥 训练第 1 个PLE模型（适配v3归一化数据）...
Epoch 1/120


2025-10-23 19:27:39.996066: I tensorflow/core/common_runtime/executor.cc:1197] [/device:CPU:0] (DEBUG INFO) Executor start aborting (this does not indicate an error and you can ignore this message): INVALID_ARGUMENT: You must feed a value for placeholder tensor 'Placeholder/_2' with dtype float and shape [56700]
	 [[{{node Placeholder/_2}}]]
2025-10-23 19:27:41.967792: I tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:424] Loaded cuDNN version 8600
2025-10-23 19:27:42.161794: I tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:637] TensorFloat-32 will be used for the matrix multiplication. This will only be logged once.
2025-10-23 19:27:42.201470: I tensorflow/compiler/xla/service/service.cc:169] XLA service 0x5557c05eec60 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
2025-10-23 19:27:42.201512: I tensorflow/compiler/xla/service/service.cc:177]   StreamExecutor device (0): NVIDIA vGPU-48GB, Compute Capability 8.9
2025-10-23 19:



2025-10-23 19:27:51.089646: I tensorflow/core/common_runtime/executor.cc:1197] [/device:CPU:0] (DEBUG INFO) Executor start aborting (this does not indicate an error and you can ignore this message): INVALID_ARGUMENT: You must feed a value for placeholder tensor 'Placeholder/_2' with dtype float and shape [8100]
	 [[{{node Placeholder/_2}}]]



Epoch 1: val_classification_output_sparse_categorical_accuracy improved from -inf to 0.28309, saving model to models/ple_model_1019_0.keras
Epoch 2/120
Epoch 2: val_classification_output_sparse_categorical_accuracy improved from 0.28309 to 0.69321, saving model to models/ple_model_1019_0.keras
Epoch 3/120
Epoch 3: val_classification_output_sparse_categorical_accuracy improved from 0.69321 to 0.72864, saving model to models/ple_model_1019_0.keras
Epoch 4/120
Epoch 4: val_classification_output_sparse_categorical_accuracy improved from 0.72864 to 0.73593, saving model to models/ple_model_1019_0.keras
Epoch 5/120
Epoch 5: val_classification_output_sparse_categorical_accuracy improved from 0.73593 to 0.76568, saving model to models/ple_model_1019_0.keras
Epoch 6/120
Epoch 6: val_classification_output_sparse_categorical_accuracy improved from 0.76568 to 0.77037, saving model to models/ple_model_1019_0.keras
Epoch 7/120
Epoch 7: val_classification_output_sparse_categorical_accuracy improved 

2025-10-23 19:41:33.113644: I tensorflow/core/common_runtime/executor.cc:1197] [/device:CPU:0] (DEBUG INFO) Executor start aborting (this does not indicate an error and you can ignore this message): INVALID_ARGUMENT: You must feed a value for placeholder tensor 'Placeholder/_1' with dtype int32 and shape [56700]
	 [[{{node Placeholder/_1}}]]




2025-10-23 19:41:43.197023: I tensorflow/core/common_runtime/executor.cc:1197] [/device:CPU:0] (DEBUG INFO) Executor start aborting (this does not indicate an error and you can ignore this message): INVALID_ARGUMENT: You must feed a value for placeholder tensor 'Placeholder/_2' with dtype float and shape [8100]
	 [[{{node Placeholder/_2}}]]



Epoch 1: val_classification_output_sparse_categorical_accuracy improved from -inf to 0.25432, saving model to models/ple_model_1019_1.keras
Epoch 2/120
Epoch 2: val_classification_output_sparse_categorical_accuracy improved from 0.25432 to 0.71383, saving model to models/ple_model_1019_1.keras
Epoch 3/120
Epoch 3: val_classification_output_sparse_categorical_accuracy improved from 0.71383 to 0.73568, saving model to models/ple_model_1019_1.keras
Epoch 4/120
Epoch 4: val_classification_output_sparse_categorical_accuracy improved from 0.73568 to 0.74914, saving model to models/ple_model_1019_1.keras
Epoch 5/120
Epoch 5: val_classification_output_sparse_categorical_accuracy improved from 0.74914 to 0.75901, saving model to models/ple_model_1019_1.keras
Epoch 6/120
Epoch 6: val_classification_output_sparse_categorical_accuracy improved from 0.75901 to 0.76914, saving model to models/ple_model_1019_1.keras
Epoch 7/120
Epoch 7: val_classification_output_sparse_categorical_accuracy improved 

2025-10-23 19:55:17.396939: I tensorflow/core/common_runtime/executor.cc:1197] [/device:CPU:0] (DEBUG INFO) Executor start aborting (this does not indicate an error and you can ignore this message): INVALID_ARGUMENT: You must feed a value for placeholder tensor 'Placeholder/_2' with dtype float and shape [56700]
	 [[{{node Placeholder/_2}}]]




2025-10-23 19:55:27.710678: I tensorflow/core/common_runtime/executor.cc:1197] [/device:CPU:0] (DEBUG INFO) Executor start aborting (this does not indicate an error and you can ignore this message): INVALID_ARGUMENT: You must feed a value for placeholder tensor 'Placeholder/_0' with dtype double and shape [8100,1024]
	 [[{{node Placeholder/_0}}]]



Epoch 1: val_classification_output_sparse_categorical_accuracy improved from -inf to 0.26889, saving model to models/ple_model_1019_2.keras
Epoch 2/120
Epoch 2: val_classification_output_sparse_categorical_accuracy improved from 0.26889 to 0.68679, saving model to models/ple_model_1019_2.keras
Epoch 3/120
Epoch 3: val_classification_output_sparse_categorical_accuracy improved from 0.68679 to 0.72716, saving model to models/ple_model_1019_2.keras
Epoch 4/120
Epoch 4: val_classification_output_sparse_categorical_accuracy improved from 0.72716 to 0.73889, saving model to models/ple_model_1019_2.keras
Epoch 5/120
Epoch 5: val_classification_output_sparse_categorical_accuracy improved from 0.73889 to 0.75741, saving model to models/ple_model_1019_2.keras
Epoch 6/120
Epoch 6: val_classification_output_sparse_categorical_accuracy improved from 0.75741 to 0.76259, saving model to models/ple_model_1019_2.keras
Epoch 7/120
Epoch 7: val_classification_output_sparse_categorical_accuracy did not i

In [3]:
# evaluate_ple_fixed_nrmse_columnwise.py
import os
import json
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import (confusion_matrix, accuracy_score, precision_score,
                             recall_score, f1_score, mean_absolute_error)
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from matplotlib import font_manager as fm
import time

# ---------------------- 中文字体 ----------------------
def set_chinese_font():
    try:
        font_paths = ['/usr/share/fonts/truetype/wqy/wqy-microhei.ttc',
                      'C:/Windows/Fonts/simhei.ttf',
                      '/System/Library/Fonts/PingFang.ttc']
        for font_path in font_paths:
            if os.path.exists(font_path):
                fm.fontManager.addfont(font_path)
                plt.rcParams['font.family'] = fm.FontProperties(fname=font_path).get_name()
                plt.rcParams['axes.unicode_minus'] = False
                print(f"✅ 成功设置中文字体: {plt.rcParams['font.family']}")
                return True
        plt.rcParams['font.family'] = ['SimHei', 'Arial Unicode MS']
        plt.rcParams['axes.unicode_minus'] = False
        print("⚠️ 未找到指定字体，使用默认兼容字体")
        return True
    except Exception as e:
        print(f"❌ 字体设置失败: {e}")
        return False
set_chinese_font()

# ---------------------- 修复的PLE层（关键修复） ----------------------
class PLELayer(tf.keras.layers.Layer):
    def __init__(self, num_tasks, num_shared_experts, num_task_experts, expert_dim, name="ple_layer", **kwargs):
        # 修复：正确传递name参数给父类[4](@ref)
        super(PLELayer, self).__init__(name=name, **kwargs)
        self.num_tasks = num_tasks
        self.num_shared_experts = num_shared_experts
        self.num_task_experts = num_task_experts
        self.expert_dim = expert_dim
        
    def build(self, input_shape):
        input_dim = input_shape[-1]
        
        # 共享专家网络
        self.shared_experts = []
        for i in range(self.num_shared_experts):
            self.shared_experts.append(
                tf.keras.layers.Dense(self.expert_dim, activation='relu', name=f'shared_expert_{i}')
            )
        
        # 任务特定专家网络
        self.task_experts = []
        for task in range(self.num_tasks):
            task_expert_list = []
            for i in range(self.num_task_experts):
                task_expert_list.append(
                    tf.keras.layers.Dense(self.expert_dim, activation='relu', name=f'task_{task}_expert_{i}')
                )
            self.task_experts.append(task_expert_list)
        
        # 门控网络
        self.gates = []
        for task in range(self.num_tasks):
            self.gates.append(
                tf.keras.layers.Dense(self.num_shared_experts + self.num_task_experts, 
                                     activation='softmax', name=f'gate_{task}')
            )
        
        super(PLELayer, self).build(input_shape)
        
    def call(self, inputs):
        # 生成共享专家输出
        shared_expert_outputs = [expert(inputs) for expert in self.shared_experts]
        
        outputs = []
        for task in range(self.num_tasks):
            # 生成任务特定专家输出
            task_expert_outputs = [expert(inputs) for expert in self.task_experts[task]]
            
            # 合并所有专家输出
            all_expert_outputs = shared_expert_outputs + task_expert_outputs
            all_experts_concat = tf.stack(all_expert_outputs, axis=1)
            
            # 计算门控权重
            gate_weights = self.gates[task](inputs)
            gate_weights = tf.expand_dims(gate_weights, axis=-1)
            
            # 加权组合专家输出
            weighted_output = tf.reduce_sum(all_experts_concat * gate_weights, axis=1)
            outputs.append(weighted_output)
            
        return outputs
    
    def get_config(self):
        # 修复：实现get_config方法用于序列化[7,8](@ref)
        config = super(PLELayer, self).get_config()
        config.update({
            'num_tasks': self.num_tasks,
            'num_shared_experts': self.num_shared_experts,
            'num_task_experts': self.num_task_experts,
            'expert_dim': self.expert_dim
        })
        return config

# ---------------------- 加载数据集 ----------------------
def load_dataset(npz_path="/root/yxun/20250826/dataset/interference_signals_natural_same_freq_1019.npz"):
    data = np.load(npz_path, allow_pickle=True)
    signals = data["signals"]
    labels = data["labels"].astype(np.int32)
    jnr_values = data["jnr_values"].astype(np.float32)
    fs = float(data["fs"])
    L = int(data["L"])
    noise_power_db = float(data["noise_power_db"])
    type2label = data["type_to_label"].item()

    interference_type_names = {
        "satellite_signal": "Satellite_Signal",
        "single_tone": "Single_Tone",
        "comb_spectra": "Comb_Spectra",
        "sweeping": "Sweeping-LFM",
        "pulse": "Pulse",
        "frequency_hopping": "Frequency_Hopping",
        "noise_fm": "Noise_FM",
        "noise_am": "Noise_AM",
        "random_combination": "Random_Combination"
    }
    label2name = {i: interference_type_names[k] for k, i in type2label.items()}
    return {
        "signals": signals,
        "labels": labels,
        "jnr_values": jnr_values,
        "fs": fs,
        "L": L,
        "noise_power_db": noise_power_db,
        "type2label": type2label,
        "label2name": label2name
    }

# ---------------------- 在线估算标签 ----------------------
def estimate_start_end(signal, fs, threshold_factor=2.0):
    power = signal ** 2
    avg_pow = np.mean(power)
    thresh = threshold_factor * avg_pow
    above = power > thresh
    diff = np.diff(above.astype(int))
    starts = np.where(diff == 1)[0]
    ends = np.where(diff == -1)[0]
    if len(starts) == 0 or len(ends) == 0:
        return 0.0, 0.0
    return float(starts[0] / fs * 1e3), float(ends[-1] / fs * 1e3)

def estimate_jnr(signal, noise_power_db):
    total_power = 10 * np.log10(np.mean(signal ** 2) + 1e-12)
    return float(total_power - noise_power_db)

# ---------------------- 数据预处理 ----------------------
def preprocess_data(dataset):
    signals = dataset["signals"]
    labels = dataset["labels"]
    jnr_values = dataset["jnr_values"]
    L = dataset["L"]
    fs = dataset["fs"]
    noise_power_db = dataset["noise_power_db"]

    signals = np.array([StandardScaler().fit_transform(s.reshape(-1, 1)).ravel() for s in signals])

    no_key = "satellite_signal"
    det_labels = (labels != dataset["type2label"][no_key]).astype(np.float32)

    # 使用在线估算方法
    param_labels = []
    for i, signal in enumerate(signals):
        start_time, end_time = estimate_start_end(signal, fs)
        jnr_db = jnr_values[i]
        param_labels.append([start_time, end_time, jnr_db])
    param_labels = np.array(param_labels, dtype=np.float32)

    X_test = signals
    y_det_test = det_labels
    y_type_test = labels
    y_param_test = param_labels
    jnr_test = jnr_values
    label2name = dataset["label2name"]

    return {
        "X_test": X_test,
        "y_det_test": y_det_test,
        "y_type_test": y_type_test,
        "y_param_test": y_param_test,
        "jnr_values_test": jnr_test,
        "label2name": label2name,
        "L": L
    }

# ---------------------- 按列分别归一化 NRMSE ----------------------
def nrmse_columnwise(y_true, y_pred):
    rmse = np.sqrt(np.mean((y_true - y_pred) ** 2, axis=0))
    y_range = np.max(y_true, axis=0) - np.min(y_true, axis=0)
    return rmse / (y_range + 1e-8)

# ---------------------- 修复的模型加载和预测 ----------------------
def load_and_predict(model_paths, X_test):
    all_det, all_cls, all_reg = [], [], []
    
    # 修复：正确注册自定义对象[6,8](@ref)
    custom_objects = {
        'PLELayer': PLELayer
    }
    
    for path in model_paths:
        print(f"🔁 加载模型: {path}")
        try:
            # 修复：使用custom_objects参数[7,10](@ref)
            model = tf.keras.models.load_model(path, custom_objects=custom_objects, compile=False)
            start_time = time.time()
            det, cls, reg = model.predict(X_test, verbose=0)
            end_time = time.time()
            prediction_time = end_time - start_time
            print(f"✅ 预测时间: {prediction_time:.2f} 秒")
            all_det.append(det)
            all_cls.append(cls)
            all_reg.append(reg)
        except Exception as e:
            print(f"❌ 加载模型失败: {e}")
            continue
    
    if len(all_det) == 0:
        raise ValueError("所有模型加载失败，请检查模型路径和自定义层定义")
    
    # 集成平均
    avg_det = (np.mean(all_det, axis=0) > 0.5).astype(int).ravel()
    avg_cls = np.argmax(np.mean(all_cls, axis=0), axis=1)
    avg_reg = np.mean(all_reg, axis=0)
    
    return avg_det, avg_cls, avg_reg

# ---------------------- 绘制混淆矩阵 ----------------------
def plot_confusion_matrix(cm, labels, title, xlabel, ylabel, filename, dpi=150, rotate_x=False):
    cm_normalized = cm.astype('float') / cm.sum(axis=1, keepdims=True)
    cm_normalized = np.nan_to_num(cm_normalized)

    plt.figure(figsize=(12, 10))
    ax = sns.heatmap(cm_normalized, annot=True, fmt='.2f', cmap='Blues',
                     xticklabels=labels, yticklabels=labels, square=True, annot_kws={"size": 14})
    cbar = ax.collections[0].colorbar
    cbar.ax.tick_params(labelsize=14)

    plt.title(title, pad=20, fontsize=18)
    plt.xlabel(xlabel, fontsize=16)
    plt.ylabel(ylabel, fontsize=16)
    plt.xticks(rotation=45 if rotate_x else 0, ha='right' if rotate_x else 'center', fontsize=14)
    plt.yticks(rotation=0, fontsize=14)
    plt.tight_layout()
    plt.savefig(filename, dpi=dpi)
    plt.close()

# ---------------------- 主评估函数 ----------------------
def evaluate(models_dir="models",
             npz_path="/root/yxun/20250826/dataset/interference_signals_natural_same_freq_1019.npz",
             wanted_jnr=np.arange(-10, 31, 5),
             dpi=150):
    
    os.makedirs("visualizations", exist_ok=True)
    os.makedirs("reports", exist_ok=True)

    print("\n" + "="*60)
    print("📊 开始评估 PLE 多任务模型（修复版）")
    print("="*60)
    
    # 加载数据
    dataset = load_dataset(npz_path)
    data = preprocess_data(dataset)

    X_test = data["X_test"]
    y_det_test = data["y_det_test"]
    y_type_test = data["y_type_test"]
    y_param_test = data["y_param_test"]
    jnr_test = data["jnr_values_test"]
    label2name = data["label2name"]

    # 修复：使用正确的模型路径格式
    model_paths = [os.path.join(models_dir, f"ple_model_1019_{i}.keras") for i in range(3)]
    
    # 检查模型文件是否存在
    for path in model_paths:
        if not os.path.exists(path):
            print(f"⚠️ 警告: 模型文件 {path} 不存在")
    
    # 加载模型并预测
    avg_det, avg_cls, avg_reg = load_and_predict(model_paths, X_test)

    # 1. 检测混淆矩阵
    cm_det = confusion_matrix(y_det_test, avg_det)
    plot_confusion_matrix(
        cm=cm_det,
        labels=['No Interference', 'Interference'],
        title='PLE Interference Detection Confusion Matrix',
        xlabel='Predicted',
        ylabel='True',
        filename='visualizations/PLE_detection_confusion_matrix.png',
        dpi=dpi,
        rotate_x=False
    )

    # 2. 分类混淆矩阵
    cm_type = confusion_matrix(y_type_test, avg_cls)
    plot_confusion_matrix(
        cm=cm_type,
        labels=[label2name[i] for i in sorted(label2name.keys())],
        title='PLE Classification Confusion Matrix',
        xlabel='Predicted',
        ylabel='True',
        filename='visualizations/PLE_classification_confusion_matrix.png',
        dpi=dpi,
        rotate_x=True
    )

    # 3. JNR vs 准确率
    jnr_acc = []
    for jnr in wanted_jnr:
        mask = jnr_test == jnr
        if np.sum(mask) == 0:
            jnr_acc.append(np.nan)
        else:
            acc = accuracy_score(y_type_test[mask], avg_cls[mask])
            jnr_acc.append(acc)

    plt.figure(figsize=(8, 5))
    valid_mask = ~np.isnan(jnr_acc)
    plt.plot(wanted_jnr[valid_mask], np.array(jnr_acc)[valid_mask], marker='o', linewidth=2)
    if np.any(~valid_mask):
        plt.scatter(wanted_jnr[~valid_mask], [1.0] * np.sum(~valid_mask),
                    facecolors='none', edgecolors='r', s=60)
    plt.xlabel('JNR (dB)', fontsize=14)
    plt.ylabel('Accuracy', fontsize=14)
    plt.title('PLE Classification Accuracy vs JNR', fontsize=16)
    plt.xticks(wanted_jnr)
    plt.grid(True, linestyle='--', alpha=0.5)
    plt.ylim(0, 1.05)
    plt.tight_layout()
    plt.savefig('visualizations/PLE_jnr_vs_accuracy.png', dpi=dpi)
    plt.close()

    # 4. 指标计算
    det_acc = accuracy_score(y_det_test, avg_det)
    cls_acc = accuracy_score(y_type_test, avg_cls)
    cls_precision = precision_score(y_type_test, avg_cls, average='weighted', zero_division=0)
    cls_recall = recall_score(y_type_test, avg_cls, average='weighted', zero_division=0)
    cls_f1 = f1_score(y_type_test, avg_cls, average='weighted', zero_division=0)
    param_mae = mean_absolute_error(y_param_test, avg_reg, multioutput='raw_values')

    # 按列分别归一化 NRMSE
    param_nrmse = nrmse_columnwise(y_param_test, avg_reg)
    param_names = ['Start Time (ms)', 'End Time (ms)', 'JNR (dB)']

    print("\n" + "="*50)
    print("📊 PLE 评估结果（修复版）")
    print("="*50)
    print(f"检测准确率: {det_acc:.4f}")
    print(f"分类准确率: {cls_acc:.4f}")
    print(f"分类精确率: {cls_precision:.4f}, 召回率: {cls_recall:.4f}, F1: {cls_f1:.4f}")
    print("\n参数估计误差（MAE & 按列 NRMSE）:")
    for i, (name, mae, nrmse) in enumerate(zip(param_names, param_mae, param_nrmse)):
        print(f"  {name}: MAE = {mae:.4f}, NRMSE = {nrmse:.4f}")
    print(f"  平均 NRMSE（三列分别归一化） = {np.mean(param_nrmse):.4f}")

    print("\nJNR 准确率:")
    for j, acc in zip(wanted_jnr, jnr_acc):
        acc_str = f"{acc:.4f}" if not np.isnan(acc) else 'N/A'
        print(f"  {int(j)}dB: {acc_str}")

    # 5. 保存报告
    report = {
        "detection_accuracy": float(det_acc),
        "classification_accuracy": float(cls_acc),
        "classification_precision": float(cls_precision),
        "classification_recall": float(cls_recall),
        "classification_f1": float(cls_f1),
        "parameter_mae": [float(m) for m in param_mae],
        "parameter_nrmse": [float(n) for n in param_nrmse],
        "average_columnwise_nrmse": float(np.mean(param_nrmse)),
        "parameter_details": {
            param_names[i]: {"mae": float(param_mae[i]), "nrmse": float(param_nrmse[i])}
            for i in range(len(param_names))
        },
        "jnr_accuracies": {f"{int(j)}dB": float(acc) if not np.isnan(acc) else None 
                          for j, acc in zip(wanted_jnr, jnr_acc)}
    }
    
    with open("reports/PLE_evaluation_report_fixed.json", "w", encoding='utf-8') as f:
        json.dump(report, f, indent=4, ensure_ascii=False)

    print("\n✅ PLE 评估完成！混淆矩阵与报告已保存。")

# ---------------------- 主函数 ----------------------
if __name__ == "__main__":
    evaluate(models_dir="models",
             npz_path="/root/yxun/20250826/dataset/interference_signals_natural_same_freq_1019.npz",
             wanted_jnr=np.arange(-10, 31, 5),
             dpi=150)

✅ 成功设置中文字体: ['WenQuanYi Micro Hei']

📊 开始评估 PLE 多任务模型（修复版）
🔁 加载模型: models/ple_model_1019_0.keras
✅ 预测时间: 9.52 秒
🔁 加载模型: models/ple_model_1019_1.keras
✅ 预测时间: 8.53 秒
🔁 加载模型: models/ple_model_1019_2.keras
✅ 预测时间: 8.80 秒

📊 PLE 评估结果（修复版）
检测准确率: 0.9813
分类准确率: 0.9402
分类精确率: 0.9414, 召回率: 0.9402, F1: 0.9403

参数估计误差（MAE & 按列 NRMSE）:
  Start Time (ms): MAE = 0.0046, NRMSE = 0.0872
  End Time (ms): MAE = 0.0047, NRMSE = 0.0880
  JNR (dB): MAE = 13.3333, NRMSE = 0.4082
  平均 NRMSE（三列分别归一化） = 0.1945

JNR 准确率:
  -10dB: 0.7970
  -5dB: 0.8747
  0dB: 0.9241
  5dB: 0.9508
  10dB: 0.9652
  15dB: 0.9790
  20dB: 0.9894
  25dB: 0.9901
  30dB: 0.9914

✅ PLE 评估完成！混淆矩阵与报告已保存。


============================================================
📊 开始评估 PLE 多任务模型（修复版）
============================================================
🔁 加载模型: models/ple_model_1019_0.keras
✅ 预测时间: 9.52 秒
🔁 加载模型: models/ple_model_1019_1.keras
✅ 预测时间: 8.53 秒
🔁 加载模型: models/ple_model_1019_2.keras
✅ 预测时间: 8.80 秒

==================================================
📊 PLE 评估结果（修复版）
==================================================
检测准确率: 0.9813
分类准确率: 0.9402
分类精确率: 0.9414, 召回率: 0.9402, F1: 0.9403

参数估计误差（MAE & 按列 NRMSE）:
  Start Time (ms): MAE = 0.0046, NRMSE = 0.0872
  End Time (ms): MAE = 0.0047, NRMSE = 0.0880
  JNR (dB): MAE = 13.3333, NRMSE = 0.4082
  平均 NRMSE（三列分别归一化） = 0.1945

JNR 准确率:
  -10dB: 0.7970
  -5dB: 0.8747
  0dB: 0.9241
  5dB: 0.9508
  10dB: 0.9652
  15dB: 0.9790
  20dB: 0.9894
  25dB: 0.9901
  30dB: 0.9914


  ============================================================
📊 开始评估 PLE 多任务模型（修复版）
============================================================
🔁 加载模型: models/ple_model_1019_0.keras
✅ 预测时间: 9.85 秒
🔁 加载模型: models/ple_model_1019_1.keras
✅ 预测时间: 8.91 秒
🔁 加载模型: models/ple_model_1019_2.keras
✅ 预测时间: 9.25 秒

==================================================
📊 PLE 评估结果（修复版）
==================================================
检测准确率: 0.9816
分类准确率: 0.9428
分类精确率: 0.9440, 召回率: 0.9428, F1: 0.9430

参数估计误差（MAE & 按列 NRMSE）:
  Start Time (ms): MAE = 0.0046, NRMSE = 0.0871
  End Time (ms): MAE = 0.0046, NRMSE = 0.0850
  JNR (dB): MAE = 13.3333, NRMSE = 0.4082
  平均 NRMSE（三列分别归一化） = 0.1934

JNR 准确率:
  -10dB: 0.7972
  -5dB: 0.8732
  0dB: 0.9349
  5dB: 0.9548