<a href="https://colab.research.google.com/github/RayJohn-404notfound/Machine_learning_fire_identificaion/blob/main/testing_real_fire.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
from google.colab import drive
# Mount Google Drive at /content/drive; the checkpoint and validation-audio
# paths configured below all live under /content/drive/MyDrive.
drive.mount('/content/drive')



Mounted at /content/drive


In [None]:
!pip install pynvml --quiet


In [None]:
pip install snntorch

Collecting snntorch
  Downloading snntorch-0.9.4-py2.py3-none-any.whl.metadata (15 kB)
Downloading snntorch-0.9.4-py2.py3-none-any.whl (125 kB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 125.6/125.6 kB 11.3 MB/s eta 0:00:00
Installing collected packages: snntorch
Successfully installed snntorch-0.9.4


In [None]:
import os
import glob
import torch
import torch.nn as nn
import torchaudio
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import (
    confusion_matrix, roc_curve, auc,
    precision_recall_curve, average_precision_score,
    classification_report
)
from typing import List
from dataclasses import dataclass

# ========== 配置类 ==========
@dataclass
class CRNNConfig:
    """Hyperparameters for the CRNN model and its log-Mel front end.

    ``cnn_channels`` defaults to ``None`` and is replaced by a fresh
    ``[16, 32]`` list in ``__post_init__`` so instances never share it.
    """

    # --- Model architecture ---
    num_classes: int = 2
    cnn_channels: List[int] = None  # out-channels per conv stage; None -> [16, 32]
    kernel_size: int = 3
    padding: int = 1
    pool_size: tuple = (2, 2)

    # --- Recurrent layer ---
    hidden_size: int = 128
    rnn_layers: int = 2
    bidirectional: bool = True
    rnn_type: str = "GRU"  # "GRU" or "LSTM"

    # --- Dropout ---
    dropout_cnn: float = 0.35
    dropout_rnn: float = 0.55

    # --- Audio preprocessing ---
    sample_rate: int = 16000
    n_fft: int = 1024
    hop_length: int = 320
    n_mels: int = 64

    def __post_init__(self):
        # Build the default list per instance (avoids a shared mutable default).
        self.cnn_channels = [16, 32] if self.cnn_channels is None else self.cnn_channels

    def get_rnn_output_size(self) -> int:
        """Return the per-timestep RNN feature width (doubled when bidirectional)."""
        num_directions = 2 if self.bidirectional else 1
        return num_directions * self.hidden_size

# ========== 配置 ==========
# Checkpoint and validation-set locations on the mounted Google Drive.
MODEL_PATH = "/content/drive/MyDrive/trained_models/best_model_crnn.pt"
FIRE_DIR = "/content/drive/MyDrive/processed_audio_validation/processed_audio_fire"
NONFIRE_DIR = "/content/drive/MyDrive/processed_audio_validation/processed_audio_nonfire"
# Run on the GPU whenever PyTorch can see one.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Default hyperparameters. NOTE(review): these presumably must match the values
# the checkpoint at MODEL_PATH was trained with — confirm against the training code.
config = CRNNConfig()

# ========== CRNN模型定义 ==========
class CRNN(nn.Module):
    """Convolutional-recurrent classifier with learned attention pooling over time.

    Input is a batch ``(batch, C, H, W)`` whose ``C`` and ``H`` match
    ``input_shape[:2]``; the width (time axis) may vary. The conv stack shrinks
    H and W, every remaining time column becomes one RNN step, and a softmax-
    weighted sum of the RNN outputs feeds a linear layer that returns logits.

    Args:
        config: hyperparameters (conv channels/kernel/pool, RNN, dropout,
            ``num_classes``) and ``get_rnn_output_size()``.
        input_shape: ``(C, H, W)`` of one sample; only used to size the RNN input.

    Raises:
        ValueError: if ``config.rnn_type`` is neither ``"GRU"`` nor ``"LSTM"``.
    """

    def __init__(self, config: "CRNNConfig", input_shape=(1, 64, 128)):
        super().__init__()
        self.config = config

        # One Conv -> BatchNorm -> ReLU -> MaxPool stage per entry in cnn_channels.
        layers = []
        in_channels = input_shape[0]
        for out_channels in config.cnn_channels:
            layers.extend([
                nn.Conv2d(in_channels, out_channels,
                          kernel_size=config.kernel_size,
                          padding=config.padding),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(),
                nn.MaxPool2d(config.pool_size),
            ])
            in_channels = out_channels

        self.conv = nn.Sequential(*layers)
        self.dropout_cnn = nn.Dropout(config.dropout_cnn)

        # RNN feature size = conv output channels x remaining height.
        self.rnn_input_size = self._calculate_rnn_input_size(input_shape)

        rnn_classes = {"GRU": nn.GRU, "LSTM": nn.LSTM}
        if config.rnn_type not in rnn_classes:
            raise ValueError(f"Unsupported RNN type: {config.rnn_type}")
        self.rnn = rnn_classes[config.rnn_type](
            input_size=self.rnn_input_size,
            hidden_size=config.hidden_size,
            num_layers=config.rnn_layers,
            batch_first=True,
            bidirectional=config.bidirectional,
        )
        self.dropout_rnn = nn.Dropout(config.dropout_rnn)

        rnn_output_size = config.get_rnn_output_size()
        self.attn_layer = nn.Linear(rnn_output_size, 1)  # one attention score per time step
        self.fc = nn.Linear(rnn_output_size, config.num_classes)

    def _calculate_rnn_input_size(self, input_shape):
        """Return ``channels * height`` of the conv stack's output for *input_shape*."""
        dummy_input = torch.zeros(1, *input_shape)
        # A freshly built module is in training mode, where BatchNorm would fold
        # this all-zeros probe into its running statistics. Probe in eval mode
        # and restore the previous mode afterwards.
        was_training = self.conv.training
        self.conv.eval()
        try:
            with torch.no_grad():
                out = self.conv(dummy_input)
        finally:
            self.conv.train(was_training)
        _, channels, height, _ = out.shape
        return channels * height

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)``."""
        x = self.dropout_cnn(self.conv(x))  # (B, C', H', W')

        # Each time column is one RNN step: (B, W', C', H') -> (B, W', C' * H').
        x = x.permute(0, 3, 1, 2).contiguous()
        x = x.view(x.size(0), x.size(1), -1)

        # GRU returns (out, h_n) and LSTM returns (out, (h_n, c_n)); only out is used.
        rnn_out, _ = self.rnn(x)
        rnn_out = self.dropout_rnn(rnn_out)

        # Softmax over the time axis, then a weighted sum collapses time.
        attn_weights = torch.softmax(self.attn_layer(rnn_out), dim=1)
        context = torch.sum(attn_weights * rnn_out, dim=1)

        return self.fc(context)

# ========== 音频预处理 ==========
def preprocess_audio(file_path, config: CRNNConfig):
    """Load an audio file and turn it into a dB-scaled Mel spectrogram batch.

    The waveform is resampled to ``config.sample_rate`` when its native rate
    differs, averaged down to one channel, passed through a power Mel
    spectrogram (``n_fft``, ``hop_length``, ``n_mels`` from *config*, Hann
    window) and converted to decibels.

    Returns:
        The ``(1, n_mels, frames)`` spectrogram with a leading batch axis of size 1.
    """
    waveform, orig_sr = torchaudio.load(file_path)

    target_sr = config.sample_rate
    if orig_sr != target_sr:
        waveform = torchaudio.transforms.Resample(orig_sr, target_sr)(waveform)

    is_multichannel = waveform.shape[0] > 1
    if is_multichannel:
        waveform = waveform.mean(dim=0, keepdim=True)

    to_mel = torchaudio.transforms.MelSpectrogram(
        sample_rate=target_sr,
        n_fft=config.n_fft,
        hop_length=config.hop_length,
        n_mels=config.n_mels,
        window_fn=torch.hann_window,
    )
    to_db = torchaudio.transforms.AmplitudeToDB(stype='power')

    return to_db(to_mel(waveform)).unsqueeze(0)

# ========== 智能阈值搜索 ==========
def find_best_threshold(fire_probs, true_labels):
    """Grid-search the FIRE decision threshold, weighting F1 above accuracy.

    Thresholds k/100 for k = 1..98 are tried. A sample counts as a FIRE
    prediction when its probability is strictly greater than the threshold.
    Each threshold is scored as ``0.7 * F1 + 0.3 * accuracy`` (FIRE is the
    positive class); the lowest threshold reaching the highest score wins.

    Args:
        fire_probs: per-sample FIRE probabilities.
        true_labels: per-sample ground truth, 1 = FIRE and 0 = NOFIRE.

    Returns:
        ``(best_threshold, results)``: ``results`` holds one dict per grid
        threshold with keys ``threshold``, ``accuracy``, ``precision``,
        ``recall``, ``f1``, ``combined_score``, ``tp``, ``tn``, ``fp``, ``fn``.
        ``best_threshold`` is always one of those ``threshold`` values.
    """
    # Same 98 points as np.arange(0.01, 0.99, 0.01), but each one is the double
    # nearest k/100 rather than carrying accumulated float-step error
    # (values such as 0.060000000000000005 can appear with a float step).
    thresholds = np.arange(1, 99) / 100.0

    fire_probs = np.asarray(fire_probs, dtype=float)
    true_labels = np.asarray(true_labels)
    n_samples = len(true_labels)
    is_fire = true_labels == 1
    is_nofire = true_labels == 0

    results = []
    # Starting at -inf guarantees the first grid point is adopted, so the
    # returned threshold always has an entry in `results`. The previous fixed
    # 0.5 fallback survived when every score was 0 and was not on the grid.
    best_threshold = thresholds[0]
    best_score = -np.inf

    for threshold in thresholds:
        predicted_fire = fire_probs > threshold

        # Confusion-matrix counts with FIRE as the positive class.
        tp = np.sum(predicted_fire & is_fire)
        tn = np.sum(~predicted_fire & is_nofire)
        fp = np.sum(predicted_fire & is_nofire)
        fn = np.sum(~predicted_fire & is_fire)

        accuracy = (tp + tn) / n_samples if n_samples > 0 else 0
        precision = tp / (tp + fp) if (tp + fp) > 0 else 0
        recall = tp / (tp + fn) if (tp + fn) > 0 else 0
        f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0

        # F1 carries 70% of the score, accuracy 30%.
        combined_score = f1 * 0.7 + accuracy * 0.3

        results.append({
            'threshold': threshold,
            'accuracy': accuracy,
            'precision': precision,
            'recall': recall,
            'f1': f1,
            'combined_score': combined_score,
            'tp': tp, 'tn': tn, 'fp': fp, 'fn': fn
        })

        # Strict '>' keeps the lowest threshold among equal scores.
        if combined_score > best_score:
            best_score = combined_score
            best_threshold = threshold

    return best_threshold, results

# ========== 主要推理流程 ==========
def _load_crnn_model():
    """Build the CRNN, load MODEL_PATH into it and switch to eval mode; None on failure."""
    # The RNN input size depends only on channel count and Mel-bin height, not
    # on clip length, and preprocess_audio always yields 1 channel x
    # config.n_mels bins. Deriving the shape from the config avoids decoding
    # (and possibly crashing on) a sample file just to size the model.
    input_shape = (1, config.n_mels, 128)
    model = CRNN(config, input_shape).to(DEVICE)

    try:
        state_dict = torch.load(MODEL_PATH, map_location=DEVICE)
    except Exception as e:
        print(f"❌ 模型加载失败: {e}")
        return None

    try:
        model.load_state_dict(state_dict)
        print("✅ 模型加载成功!")
    except Exception as e:
        print(f"❌ 模型加载失败: {e}")
        print("🔧 尝试兼容性加载...")
        try:
            incompatible = model.load_state_dict(state_dict, strict=False)
        except Exception as e2:
            print(f"❌ 兼容性加载也失败: {e2}")
            return None
        print("✅ 兼容性加载成功!")
        # strict=False leaves unmatched parameters at their random init; make that visible.
        if incompatible.missing_keys:
            print(f"⚠️  未加载的参数 (保持随机初始化): {incompatible.missing_keys}")
        if incompatible.unexpected_keys:
            print(f"⚠️  权重文件中未使用的键: {incompatible.unexpected_keys}")

    model.eval()
    return model


def _predict_fire_probs(model, files):
    """Run *model* on each audio file (sorted by path) and collect FIRE probabilities.

    A file that fails to load or run is reported and skipped, so the returned
    lists can be shorter than *files*.

    Returns:
        ``(fire_probs, filenames)`` — parallel lists for the files that succeeded.
    """
    fire_probs, filenames = [], []
    for i, file_path in enumerate(sorted(files), start=1):
        filename = os.path.basename(file_path)
        try:
            input_tensor = preprocess_audio(file_path, config).to(DEVICE)
            with torch.no_grad():
                probs = torch.softmax(model(input_tensor), dim=1)
            fire_prob = probs[0][1].item()  # class index 1 is FIRE
        except Exception as e:  # best effort: one bad file must not abort the run
            print(f"    ❌ 处理 {file_path} 时出错: {e}")
            continue
        fire_probs.append(fire_prob)
        filenames.append(filename)
        print(f"    {i:2d}. {filename[:35]:35s} Fire概率: {fire_prob:.4f}")
    return fire_probs, filenames


def main():
    """Evaluate the CRNN checkpoint on the fire / non-fire validation folders.

    Stage 1 collects a FIRE probability per file. Stage 2 grid-searches a
    decision threshold. Stage 3 prints per-file predictions, summary metrics
    and a classification report, then saves analysis plots via ``generate_plots``.
    """
    print("🚀 开始CRNN智能推理...")

    print("📁 加载模型...")
    model = _load_crnn_model()
    if model is None:
        return

    fire_files = glob.glob(os.path.join(FIRE_DIR, "*.wav"))
    nonfire_files = glob.glob(os.path.join(NONFIRE_DIR, "*.wav"))
    if not fire_files or not nonfire_files:
        print("❌ 未找到音频文件，请检查路径设置")
        return

    print(f"📊 找到 {len(fire_files)} 个火灾音频，{len(nonfire_files)} 个非火灾音频")

    # ---- Stage 1: per-file FIRE probabilities ----
    print("\n🔍 第一阶段：收集预测结果...")
    print("  处理火灾音频...")
    fire_probs, fire_names = _predict_fire_probs(model, fire_files)
    print("  处理非火灾音频...")
    nonfire_probs, nonfire_names = _predict_fire_probs(model, nonfire_files)

    all_fire_probs = fire_probs + nonfire_probs
    all_true_labels = [1] * len(fire_probs) + [0] * len(nonfire_probs)
    all_filenames = fire_names + nonfire_names

    if not all_fire_probs:
        print("❌ 没有成功处理任何音频文件")
        return

    # ---- Stage 2: threshold search ----
    print(f"\n🎯 第二阶段：智能搜索最佳阈值...")
    best_threshold, threshold_results = find_best_threshold(all_fire_probs, all_true_labels)

    # Nearest grid entry instead of an exact-match next(...), which raises
    # StopIteration if the returned threshold is not one of the grid points.
    best_result = min(threshold_results, key=lambda r: abs(r['threshold'] - best_threshold))
    print(f"\n✨ 最佳阈值: {best_threshold:.3f}")
    print(f"   准确率: {best_result['accuracy']:.3f}")
    print(f"   精确率: {best_result['precision']:.3f}")
    print(f"   召回率: {best_result['recall']:.3f}")
    print(f"   F1分数: {best_result['f1']:.3f}")

    sorted_results = sorted(threshold_results, key=lambda r: r['combined_score'], reverse=True)
    print(f"\n🏆 Top 5 最佳阈值:")
    print(f"{'阈值':>6} {'准确率':>6} {'精确率':>6} {'召回率':>6} {'F1':>6} {'综合分':>6}")
    print("-" * 42)
    for result in sorted_results[:5]:
        print(f"{result['threshold']:>6.3f} {result['accuracy']:>6.3f} {result['precision']:>6.3f} "
              f"{result['recall']:>6.3f} {result['f1']:>6.3f} {result['combined_score']:>6.3f}")

    # ---- Stage 3: final predictions and report ----
    print(f"\n🚀 第三阶段：使用最佳阈值 ({best_threshold:.3f}) 进行最终预测...")

    labels_arr = np.asarray(all_true_labels)
    final_predictions = (np.asarray(all_fire_probs) > best_threshold).astype(int)
    correct_count = int(np.sum(final_predictions == labels_arr))

    print(f"\n📋 详细预测结果:")
    for filename, fire_prob, true_label, pred_label in zip(all_filenames, all_fire_probs, all_true_labels, final_predictions):
        true_str = "FIRE  " if true_label == 1 else "NOFIRE"
        pred_str = "FIRE  " if pred_label == 1 else "NOFIRE"
        status = "✅" if pred_label == true_label else "❌"
        print(f"{status} [{true_str}] {filename[:35]:35s} → {pred_str} (概率: {fire_prob:.4f})")

    total_samples = len(all_true_labels)
    fire_samples = len(fire_probs)
    nofire_samples = len(nonfire_probs)
    fire_correct = int(np.sum(final_predictions[labels_arr == 1] == 1))
    nofire_correct = int(np.sum(final_predictions[labels_arr == 0] == 0))
    # A class is empty when every one of its files failed; avoid dividing by zero.
    fire_acc = fire_correct / fire_samples if fire_samples else float("nan")
    nofire_acc = nofire_correct / nofire_samples if nofire_samples else float("nan")
    overall_acc = correct_count / total_samples

    print(f"\n" + "=" * 60)
    print(f"🎯 最终性能总结")
    print(f"=" * 60)
    print(f"总体准确率: {correct_count}/{total_samples} = {overall_acc:.2%}")
    print(f"🔥 火灾检测准确率: {fire_correct}/{fire_samples} = {fire_acc:.2%}")
    print(f"❄️  非火灾检测准确率: {nofire_correct}/{nofire_samples} = {nofire_acc:.2%}")
    print(f"🎯 最佳阈值: {best_threshold:.3f}")
    print(f"📊 F1分数: {best_result['f1']:.3f}")

    print(f"\n📊 详细分类报告:")
    # Explicit labels/zero_division keep the report valid if one class is missing.
    report = classification_report(all_true_labels, final_predictions,
                                   labels=[0, 1], target_names=["NOFIRE", "FIRE"],
                                   zero_division=0)
    print(report)

    if fire_samples and nofire_samples:
        print(f"\n📈 生成分析图表...")
        generate_plots(all_fire_probs, all_true_labels, final_predictions,
                       best_threshold, threshold_results)
    else:
        # ROC / PR curves are undefined when only one class is present.
        print("\n⚠️  只有一个类别有样本，跳过图表生成")

    print(f"\n🎉 CRNN智能推理完成！")
    print(f"💡 关键发现：使用阈值 {best_threshold:.3f} 可获得 {overall_acc:.1%} 的准确率")

def _save_and_close(path):
    """Lay out, save the current figure to *path* at 300 dpi, close it and return *path*."""
    plt.tight_layout()
    plt.savefig(path, dpi=300, bbox_inches='tight')
    plt.close()
    return path


def generate_plots(fire_probs, true_labels, predictions, best_threshold, threshold_results):
    """Save CRNN diagnostic plots as PNG files in the current working directory.

    Args:
        fire_probs: per-sample FIRE probabilities.
        true_labels: per-sample ground truth (1 = FIRE, 0 = NOFIRE).
        predictions: per-sample 0/1 predictions at *best_threshold*.
        best_threshold: threshold marked on the distribution and sweep plots.
        threshold_results: per-threshold dicts from ``find_best_threshold``
            (the ``threshold``, ``accuracy`` and ``f1`` keys are read).

    Plot text is English: matplotlib's default font (DejaVu Sans) has no CJK
    glyphs, so Chinese titles/labels triggered missing-glyph warnings and were
    rendered as empty boxes.
    """
    fire_probs = np.asarray(fire_probs, dtype=float)
    true_labels = np.asarray(true_labels)
    class_names = ["NOFIRE", "FIRE"]
    saved_files = []

    # 1. Confusion matrix. labels=[0, 1] keeps it 2x2, matching the two tick
    #    labels, even if one class never appears in true_labels/predictions.
    cm = confusion_matrix(true_labels, predictions, labels=[0, 1])
    plt.figure(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt="d", cmap="Blues",
                xticklabels=class_names, yticklabels=class_names)
    plt.title(f"Confusion Matrix - CRNN (threshold: {best_threshold:.3f})")
    plt.xlabel("Predicted")
    plt.ylabel("True")
    saved_files.append(_save_and_close("confusion_matrix_crnn_optimized.png"))

    # 2. ROC curve
    fpr, tpr, _ = roc_curve(true_labels, fire_probs)
    roc_auc = auc(fpr, tpr)
    plt.figure(figsize=(8, 6))
    plt.plot(fpr, tpr, label=f"ROC AUC = {roc_auc:.3f}", linewidth=2)
    plt.plot([0, 1], [0, 1], 'k--', alpha=0.5)
    plt.title("ROC Curve - CRNN")
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.legend()
    plt.grid(True, alpha=0.3)
    saved_files.append(_save_and_close("roc_curve_crnn_optimized.png"))

    # 3. Precision-recall curve. average_precision_score is a step-wise summary,
    #    not the trapezoidal area under this curve, so it is labelled "AP".
    precision, recall, _ = precision_recall_curve(true_labels, fire_probs)
    ap_score = average_precision_score(true_labels, fire_probs)
    plt.figure(figsize=(8, 6))
    plt.plot(recall, precision, label=f"AP = {ap_score:.3f}", linewidth=2)
    plt.title("Precision-Recall Curve - CRNN")
    plt.xlabel("Recall")
    plt.ylabel("Precision")
    plt.legend()
    plt.grid(True, alpha=0.3)
    saved_files.append(_save_and_close("pr_curve_crnn_optimized.png"))

    # 4. Per-class distribution of FIRE probabilities
    probs_fire = fire_probs[true_labels == 1]
    probs_nofire = fire_probs[true_labels == 0]

    plt.figure(figsize=(10, 6))
    plt.hist(probs_nofire, bins=20, alpha=0.7, label='Non-fire samples', color='blue', density=True)
    plt.hist(probs_fire, bins=20, alpha=0.7, label='Fire samples', color='red', density=True)
    plt.axvline(x=best_threshold, color='black', linestyle='--', linewidth=2,
                label=f'Best threshold = {best_threshold:.3f}')
    plt.axvline(x=0.5, color='gray', linestyle=':', alpha=0.7, label='Default threshold = 0.5')
    plt.xlabel('Fire probability')
    plt.ylabel('Density')
    plt.title('Fire Probability Distribution - CRNN')
    plt.legend()
    plt.grid(True, alpha=0.3)

    # Class means in the top-left corner; an empty class would make np.mean warn and yield nan.
    fire_mean = f"{probs_fire.mean():.3f}" if probs_fire.size else "n/a"
    nofire_mean = f"{probs_nofire.mean():.3f}" if probs_nofire.size else "n/a"
    axes = plt.gca()
    plt.text(0.05, 0.95, f'Fire mean: {fire_mean}',
             transform=axes.transAxes, fontsize=10,
             bbox=dict(boxstyle="round,pad=0.3", facecolor="red", alpha=0.3))
    plt.text(0.05, 0.88, f'Non-fire mean: {nofire_mean}',
             transform=axes.transAxes, fontsize=10,
             bbox=dict(boxstyle="round,pad=0.3", facecolor="blue", alpha=0.3))
    saved_files.append(_save_and_close("fire_prob_distribution_crnn.png"))

    # 5. Accuracy and F1 across the threshold grid
    thresholds = [r['threshold'] for r in threshold_results]
    accuracies = [r['accuracy'] for r in threshold_results]
    f1_scores = [r['f1'] for r in threshold_results]

    plt.figure(figsize=(10, 6))
    plt.plot(thresholds, accuracies, label='Accuracy', linewidth=2, color='blue')
    plt.plot(thresholds, f1_scores, label='F1 score', linewidth=2, color='red')
    plt.axvline(x=best_threshold, color='black', linestyle='--',
                label=f'Best threshold = {best_threshold:.3f}')
    plt.axvline(x=0.5, color='gray', linestyle=':', alpha=0.7, label='Default threshold = 0.5')
    plt.xlabel('Threshold')
    plt.ylabel('Score')
    plt.title('Threshold vs Performance - CRNN')
    plt.legend()
    plt.grid(True, alpha=0.3)
    saved_files.append(_save_and_close("threshold_performance_crnn.png"))

    print("  ✅ 图表已保存:")
    for path in saved_files:
        print(f"     - {path}")

# A notebook cell executes with __name__ == "__main__", so running this cell
# starts the CRNN evaluation immediately.
if __name__ == "__main__":
    main()

🚀 开始CRNN智能推理...
📁 加载模型...
✅ 模型加载成功!
📊 找到 57 个火灾音频，32 个非火灾音频

🔍 第一阶段：收集预测结果...
  处理火灾音频...
     1. fire_valid_01.wav                   Fire概率: 0.2368
     2. fire_valid_02.wav                   Fire概率: 0.9374
     3. fire_valid_03.wav                   Fire概率: 0.8125
     4. fire_valid_04.wav                   Fire概率: 0.8634
     5. fire_valid_05.wav                   Fire概率: 0.9631
     6. fire_valid_06.wav                   Fire概率: 0.9654
     7. fire_valid_07.wav                   Fire概率: 0.9561
     8. fire_valid_08.wav                   Fire概率: 0.9578
     9. fire_valid_09.wav                   Fire概率: 0.7510
    10. fire_valid_10.wav                   Fire概率: 0.9715
    11. fire_valid_11.wav                   Fire概率: 0.9328
    12. fire_valid_12.wav                   Fire概率: 0.9568
    13. fire_valid_13.wav                   Fire概率: 0.9204
    14. fire_valid_14.wav                   Fire概率: 0.9766
    15. fire_valid_15.wav                   Fire概率: 0.8843
    16. fire_valid_16.wav

  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()
  plt.savefig("confusion_matrix_crnn_optimized.png", dpi=300, bbox_inches='tight')
  plt.savefig("confusion_matrix_crnn_optimized.png", dpi=300, bbox_inches='tight')
  plt.savefig("confusion_matrix_crnn_optimized.png", dpi=300, bbox_inches='tight')
  plt.savefig("confusion_matrix_crnn_optimized.png", dpi=300, bbox_inches='tight')
  plt.savefig("confusion_matrix_crnn_optimized.png", dpi=300, bbox_inches='tight')
  plt.savefig("confusion_matrix_crnn_optimized.png", dpi=300, bbox_inches='tight')
  plt.savefig("confusion_matrix_crnn_optimized.png", dpi=300, bbox_inches='tight')
  plt.savefig("confusion_matrix_crnn_optimized.png", dpi=300, bbox_inches='tight')
  plt.savefig("confusion_matrix_crnn_optimized.png", dpi=300, bbox_inches='tight')
 

  ✅ 图表已保存:
     - confusion_matrix_crnn_optimized.png
     - roc_curve_crnn_optimized.png
     - pr_curve_crnn_optimized.png
     - fire_prob_distribution_crnn.png
     - threshold_performance_crnn.png

🎉 CRNN智能推理完成！
💡 关键发现：使用阈值 0.220 可获得 93.3% 的准确率


In [None]:
import os
import glob
import torch
import torch.nn as nn
import torchaudio
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import (
    confusion_matrix, roc_curve, auc,
    precision_recall_curve, average_precision_score,
    classification_report
)

# SNN相关导入
import snntorch as snn
import snntorch.surrogate as surrogate

# ========== 配置 ==========
# Checkpoint and validation-set locations on the mounted Google Drive.
MODEL_PATH = "/content/drive/MyDrive/trained_models/best_model_csnn_improved.pt"
FIRE_DIR = "/content/drive/MyDrive/processed_audio_validation/processed_audio_fire"
NONFIRE_DIR = "/content/drive/MyDrive/processed_audio_validation/processed_audio_nonfire"
# Run on the GPU whenever PyTorch can see one.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# ========== CSNN模型定义 ==========
class CSNN(nn.Module):
    """Convolutional spiking network built from snntorch ``Leaky`` neurons.

    The same input is presented for ``T`` time steps. Three conv -> LIF ->
    max-pool stages feed two fully connected LIF layers, and ``forward``
    returns the output spikes averaged over the ``T`` steps (a firing rate per
    class).

    ``fc1`` depends on the spectrogram shape, so it is ``None`` after
    construction and is created on the first forward pass unless a loader
    creates it beforehand (e.g. from a checkpoint's ``fc1.weight``).

    Args:
        T: number of simulation time steps per forward pass.
        beta: membrane decay factor shared by all LIF layers.
        num_classes: number of output neurons / classes.
    """

    def __init__(self, T=15, beta=0.5, num_classes=2):
        super(CSNN, self).__init__()
        self.T = T

        # Normalises the single-channel input before the first conv.
        self.batch_norm_in = nn.BatchNorm2d(num_features=1)

        # Stage 1: conv -> LIF -> pool
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=8, kernel_size=5, stride=1, padding=0)
        self.lif1 = snn.Leaky(beta=beta, threshold=1.0, spike_grad=surrogate.fast_sigmoid())
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Stage 2: conv -> LIF -> pool
        self.conv2 = nn.Conv2d(in_channels=8, out_channels=16, kernel_size=5, stride=1, padding=0)
        self.lif2 = snn.Leaky(beta=beta, threshold=1.0, spike_grad=surrogate.fast_sigmoid())
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Stage 3: conv -> LIF -> pool
        self.conv3 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5, stride=1, padding=0)
        self.lif3 = snn.Leaky(beta=beta, threshold=1.0, spike_grad=surrogate.fast_sigmoid())
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Fully connected head; fc1 is sized lazily (see _initialize_fc_layers).
        self.fc_input_size = None
        self.fc1 = None
        self.lif4 = snn.Leaky(beta=beta, threshold=1.0, spike_grad=surrogate.fast_sigmoid())
        self.fc2 = nn.Linear(128, num_classes)
        self.lif5 = snn.Leaky(beta=beta, threshold=1.0, spike_grad=surrogate.fast_sigmoid())

    def _initialize_fc_layers(self, input_shape, device):
        """Create ``fc1`` sized to the flattened conv/pool output for one *input_shape* sample."""
        # Only the output shape matters. Probe in eval mode so batch_norm_in does
        # not fold a synthetic batch into its running statistics, and use zeros
        # so the probe does not consume the global RNG.
        was_training = self.batch_norm_in.training
        self.batch_norm_in.eval()
        try:
            with torch.no_grad():
                x = torch.zeros(1, *input_shape, device=device)
                x = self.pool1(self.conv1(self.batch_norm_in(x)))
                x = self.pool2(self.conv2(x))
                x = self.pool3(self.conv3(x))
        finally:
            self.batch_norm_in.train(was_training)
        self.fc_input_size = x.view(1, -1).shape[1]
        self.fc1 = nn.Linear(self.fc_input_size, 128).to(device)

    def forward(self, x):
        """Return per-class output spike rates of shape ``(batch, num_classes)``."""
        if self.fc1 is None:
            self._initialize_fc_layers(x.shape[1:], x.device)

        batch_size = x.shape[0]

        # Fresh membrane potentials for every forward pass.
        mem1 = self.lif1.init_leaky()
        mem2 = self.lif2.init_leaky()
        mem3 = self.lif3.init_leaky()
        mem4 = self.lif4.init_leaky()
        mem5 = self.lif5.init_leaky()

        # Sized from fc2 rather than a hard-coded 2, so num_classes != 2 works.
        spk_out_sum = torch.zeros(batch_size, self.fc2.out_features, device=x.device)

        for _ in range(self.T):
            # The unchanged input is presented at every time step.
            x_t = self.batch_norm_in(x)

            x_t = self.conv1(x_t)
            spk1, mem1 = self.lif1(x_t, mem1)
            spk1 = self.pool1(spk1)

            x_t = self.conv2(spk1)
            spk2, mem2 = self.lif2(x_t, mem2)
            spk2 = self.pool2(spk2)

            x_t = self.conv3(spk2)
            spk3, mem3 = self.lif3(x_t, mem3)
            spk3 = self.pool3(spk3)

            spk3_flat = spk3.view(batch_size, -1)

            x_t = self.fc1(spk3_flat)
            spk4, mem4 = self.lif4(x_t, mem4)

            x_t = self.fc2(spk4)
            spk5, mem5 = self.lif5(x_t, mem5)

            spk_out_sum += spk5

        return spk_out_sum / self.T

# ========== 音频预处理 ==========
def preprocess_audio(file_path, sample_rate=16000, n_mels=64, n_fft=1024, hop_length=320):
    """Load an audio file and return its dB-scaled Mel spectrogram with a batch axis.

    The waveform is resampled to *sample_rate* if its native rate differs,
    averaged to mono, turned into a power Mel spectrogram (Hann window) and
    converted to decibels.

    Args:
        file_path: path passed to ``torchaudio.load``.
        sample_rate: target sample rate.
        n_mels: number of Mel bins.
        n_fft: STFT size (previously hard-coded; default unchanged).
        hop_length: STFT hop in samples (previously hard-coded; default unchanged).
            NOTE(review): these presumably must match the training front end.

    Returns:
        Tensor ``(1, 1, n_mels, frames)``.
    """
    waveform, sr = torchaudio.load(file_path)

    if sr != sample_rate:
        resampler = torchaudio.transforms.Resample(sr, sample_rate)
        waveform = resampler(waveform)

    # Collapse multi-channel audio to a single channel.
    if waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)

    mel_transform = torchaudio.transforms.MelSpectrogram(
        sample_rate=sample_rate,
        n_fft=n_fft,
        hop_length=hop_length,
        n_mels=n_mels,
        window_fn=torch.hann_window
    )
    amp_to_db = torchaudio.transforms.AmplitudeToDB(stype='power')

    mel_spec_db = amp_to_db(mel_transform(waveform))
    return mel_spec_db.unsqueeze(0)

# ========== 智能阈值搜索 ==========
def find_best_threshold(fire_rates, true_labels):
    """Grid-search the FIRE decision threshold, weighting F1 above accuracy.

    Thresholds k/100 for k = 1..98 are tried. A sample counts as a FIRE
    prediction when its FIRE firing rate is strictly greater than the
    threshold. Each threshold is scored as ``0.7 * F1 + 0.3 * accuracy``
    (FIRE is the positive class); the lowest threshold reaching the highest
    score wins.

    Args:
        fire_rates: per-sample FIRE output firing rates.
        true_labels: per-sample ground truth, 1 = FIRE and 0 = NOFIRE.

    Returns:
        ``(best_threshold, results)``: ``results`` holds one dict per grid
        threshold with keys ``threshold``, ``accuracy``, ``precision``,
        ``recall``, ``f1``, ``combined_score``, ``tp``, ``tn``, ``fp``, ``fn``.
        ``best_threshold`` is always one of those ``threshold`` values.
    """
    # Same 98 points as np.arange(0.01, 0.99, 0.01), but each one is the double
    # nearest k/100 rather than carrying accumulated float-step error.
    thresholds = np.arange(1, 99) / 100.0

    fire_rates = np.asarray(fire_rates, dtype=float)
    true_labels = np.asarray(true_labels)
    n_samples = len(true_labels)
    is_fire = true_labels == 1
    is_nofire = true_labels == 0

    results = []
    # Starting at -inf guarantees the first grid point is adopted, so the
    # returned threshold always has an entry in `results`. The previous fixed
    # 0.5 fallback survived when every score was 0 and was not on the grid.
    best_threshold = thresholds[0]
    best_score = -np.inf

    for threshold in thresholds:
        predicted_fire = fire_rates > threshold

        # Confusion-matrix counts with FIRE as the positive class.
        tp = np.sum(predicted_fire & is_fire)
        tn = np.sum(~predicted_fire & is_nofire)
        fp = np.sum(predicted_fire & is_nofire)
        fn = np.sum(~predicted_fire & is_fire)

        accuracy = (tp + tn) / n_samples if n_samples > 0 else 0
        precision = tp / (tp + fp) if (tp + fp) > 0 else 0
        recall = tp / (tp + fn) if (tp + fn) > 0 else 0
        f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0

        # F1 carries 70% of the score, accuracy 30%.
        combined_score = f1 * 0.7 + accuracy * 0.3

        results.append({
            'threshold': threshold,
            'accuracy': accuracy,
            'precision': precision,
            'recall': recall,
            'f1': f1,
            'combined_score': combined_score,
            'tp': tp, 'tn': tn, 'fp': fp, 'fn': fn
        })

        # Strict '>' keeps the lowest threshold among equal scores.
        if combined_score > best_score:
            best_score = combined_score
            best_threshold = threshold

    return best_threshold, results

# ========== 主要推理流程 ==========
def _load_csnn_model():
    """Build the CSNN, load MODEL_PATH into it and switch to eval mode; None on failure.

    CSNN creates ``fc1`` lazily, so it is pre-created here with the input size
    stored in the checkpoint's ``fc1.weight``. If the strict load fails, all
    weights except ``fc1.*`` are loaded and ``fc1`` is reset so the model
    re-creates it (randomly initialised) on the first forward pass.
    """
    model = CSNN(T=15, beta=0.5, num_classes=2).to(DEVICE)

    try:
        state_dict = torch.load(MODEL_PATH, map_location=DEVICE)
    except Exception as e:
        print(f"❌ 模型加载失败: {e}")
        return None

    try:
        if 'fc1.weight' in state_dict:
            fc1_input_size = state_dict['fc1.weight'].shape[1]
            print(f"  从权重文件检测到FC1输入维度: {fc1_input_size}")
            model.fc_input_size = fc1_input_size
            model.fc1 = nn.Linear(fc1_input_size, 128).to(DEVICE)
        model.load_state_dict(state_dict)
        print("✅ 模型加载成功!")
    except Exception as e:
        print(f"❌ 模型加载失败: {e}")
        print("🔧 尝试兼容性加载...")

        fc1_keys = [key for key in state_dict if key.startswith('fc1.')]
        for key in fc1_keys:
            print(f"  移除权重键: {key}")
        partial_state = {k: v for k, v in state_dict.items() if not k.startswith('fc1.')}

        # Drop any fc1 built above so CSNN.forward re-creates it from the real
        # input shape (as the success message below states).
        model.fc1 = None
        model.fc_input_size = None
        try:
            incompatible = model.load_state_dict(partial_state, strict=False)
        except Exception as e2:
            print(f"❌ 兼容性加载也失败: {e2}")
            return None
        print("✅ 兼容性加载成功! (fc1层将自动初始化)")
        # fc1 now has untrained weights, so predictions are not meaningful.
        print("⚠️  fc1 使用随机初始化权重，推理结果不可靠")
        if incompatible.missing_keys:
            print(f"⚠️  未加载的参数 (保持随机初始化): {incompatible.missing_keys}")
        if incompatible.unexpected_keys:
            print(f"⚠️  权重文件中未使用的键: {incompatible.unexpected_keys}")

    model.eval()
    return model


def _predict_fire_rates(model, files):
    """Run *model* on each audio file (sorted by path) and collect FIRE firing rates.

    A file that fails to load or run is reported and skipped, so the returned
    lists can be shorter than *files*.

    Returns:
        ``(fire_rates, filenames)`` — parallel lists for the files that succeeded.
    """
    fire_rates, filenames = [], []
    for i, file_path in enumerate(sorted(files), start=1):
        filename = os.path.basename(file_path)
        try:
            input_tensor = preprocess_audio(file_path).to(DEVICE)
            with torch.no_grad():
                fire_rate = model(input_tensor)[0][1].item()  # class index 1 is FIRE
        except Exception as e:  # best effort: one bad file must not abort the run
            print(f"    ❌ 处理 {file_path} 时出错: {e}")
            continue
        fire_rates.append(fire_rate)
        filenames.append(filename)
        print(f"    {i:2d}. {filename[:35]:35s} Fire发放率: {fire_rate:.4f}")
    return fire_rates, filenames


def main():
    """Evaluate the CSNN checkpoint on the fire / non-fire validation folders.

    Stage 1 collects a FIRE output firing rate per file. Stage 2 grid-searches
    a decision threshold. Stage 3 prints per-file predictions, summary metrics
    and a classification report, then saves analysis plots via ``generate_plots``.
    """
    print("🚀 开始CSNN智能推理...")

    print("📁 加载模型...")
    model = _load_csnn_model()
    if model is None:
        return

    fire_files = glob.glob(os.path.join(FIRE_DIR, "*.wav"))
    nonfire_files = glob.glob(os.path.join(NONFIRE_DIR, "*.wav"))
    if not fire_files or not nonfire_files:
        print("❌ 未找到音频文件，请检查路径设置")
        return

    print(f"📊 找到 {len(fire_files)} 个火灾音频，{len(nonfire_files)} 个非火灾音频")

    # ---- Stage 1: per-file FIRE firing rates ----
    print("\n🔍 第一阶段：收集预测结果...")
    print("  处理火灾音频...")
    fire_rates, fire_names = _predict_fire_rates(model, fire_files)
    print("  处理非火灾音频...")
    nonfire_rates, nonfire_names = _predict_fire_rates(model, nonfire_files)

    all_fire_rates = fire_rates + nonfire_rates
    all_true_labels = [1] * len(fire_rates) + [0] * len(nonfire_rates)
    all_filenames = fire_names + nonfire_names

    if not all_fire_rates:
        print("❌ 没有成功处理任何音频文件")
        return

    # ---- Stage 2: threshold search ----
    print(f"\n🎯 第二阶段：智能搜索最佳阈值...")
    best_threshold, threshold_results = find_best_threshold(all_fire_rates, all_true_labels)

    # Nearest grid entry instead of an exact-match next(...), which raises
    # StopIteration if the returned threshold is not one of the grid points.
    best_result = min(threshold_results, key=lambda r: abs(r['threshold'] - best_threshold))
    print(f"\n✨ 最佳阈值: {best_threshold:.3f}")
    print(f"   准确率: {best_result['accuracy']:.3f}")
    print(f"   精确率: {best_result['precision']:.3f}")
    print(f"   召回率: {best_result['recall']:.3f}")
    print(f"   F1分数: {best_result['f1']:.3f}")

    sorted_results = sorted(threshold_results, key=lambda r: r['combined_score'], reverse=True)
    print(f"\n🏆 Top 5 最佳阈值:")
    print(f"{'阈值':>6} {'准确率':>6} {'精确率':>6} {'召回率':>6} {'F1':>6} {'综合分':>6}")
    print("-" * 42)
    for result in sorted_results[:5]:
        print(f"{result['threshold']:>6.3f} {result['accuracy']:>6.3f} {result['precision']:>6.3f} "
              f"{result['recall']:>6.3f} {result['f1']:>6.3f} {result['combined_score']:>6.3f}")

    # ---- Stage 3: final predictions and report ----
    print(f"\n🚀 第三阶段：使用最佳阈值 ({best_threshold:.3f}) 进行最终预测...")

    labels_arr = np.asarray(all_true_labels)
    final_predictions = (np.asarray(all_fire_rates) > best_threshold).astype(int)
    correct_count = int(np.sum(final_predictions == labels_arr))

    print(f"\n📋 详细预测结果:")
    for filename, fire_rate, true_label, pred_label in zip(all_filenames, all_fire_rates, all_true_labels, final_predictions):
        true_str = "FIRE  " if true_label == 1 else "NOFIRE"
        pred_str = "FIRE  " if pred_label == 1 else "NOFIRE"
        status = "✅" if pred_label == true_label else "❌"
        print(f"{status} [{true_str}] {filename[:35]:35s} → {pred_str} (发放率: {fire_rate:.4f})")

    total_samples = len(all_true_labels)
    fire_samples = len(fire_rates)
    nofire_samples = len(nonfire_rates)
    fire_correct = int(np.sum(final_predictions[labels_arr == 1] == 1))
    nofire_correct = int(np.sum(final_predictions[labels_arr == 0] == 0))
    # A class is empty when every one of its files failed; avoid dividing by zero.
    fire_acc = fire_correct / fire_samples if fire_samples else float("nan")
    nofire_acc = nofire_correct / nofire_samples if nofire_samples else float("nan")
    overall_acc = correct_count / total_samples

    print(f"\n" + "=" * 60)
    print(f"🎯 最终性能总结")
    print(f"=" * 60)
    print(f"总体准确率: {correct_count}/{total_samples} = {overall_acc:.2%}")
    print(f"🔥 火灾检测准确率: {fire_correct}/{fire_samples} = {fire_acc:.2%}")
    print(f"❄️  非火灾检测准确率: {nofire_correct}/{nofire_samples} = {nofire_acc:.2%}")
    print(f"🎯 最佳阈值: {best_threshold:.3f}")
    print(f"📊 F1分数: {best_result['f1']:.3f}")

    print(f"\n📊 详细分类报告:")
    # Explicit labels/zero_division keep the report valid if one class is missing.
    report = classification_report(all_true_labels, final_predictions,
                                   labels=[0, 1], target_names=["NOFIRE", "FIRE"],
                                   zero_division=0)
    print(report)

    if fire_samples and nofire_samples:
        print(f"\n📈 生成分析图表...")
        generate_plots(all_fire_rates, all_true_labels, final_predictions,
                       best_threshold, threshold_results)
    else:
        # ROC curves are undefined when only one class is present.
        print("\n⚠️  只有一个类别有样本，跳过图表生成")

    print(f"\n🎉 CSNN智能推理完成！")
    print(f"💡 关键发现：使用阈值 {best_threshold:.3f} 可获得 {overall_acc:.1%} 的准确率")

def generate_plots(fire_rates, true_labels, predictions, best_threshold,
                   threshold_results, output_dir="."):
    """Render and save the four evaluation charts for a thresholded classifier.

    Files written into ``output_dir`` (in this order):
      1. ``confusion_matrix_optimized.png``: confusion matrix of
         ``true_labels`` vs ``predictions`` (0 = NOFIRE, 1 = FIRE).
      2. ``roc_curve_optimized.png``: ROC curve using ``fire_rates`` as scores.
      3. ``fire_rate_distribution.png``: per-class density histograms of
         ``fire_rates`` with the chosen threshold and the 0.5 default marked.
      4. ``threshold_performance.png``: accuracy and F1 versus threshold.

    Args:
        fire_rates: Per-sample score (Fire spike rate), parallel to
            ``true_labels``.
        true_labels: Ground-truth labels, 1 for fire and 0 for no fire.
        predictions: Predicted 0/1 labels obtained with ``best_threshold``.
        best_threshold: Threshold used for ``predictions``; drawn on the
            confusion-matrix title and as a vertical line on charts 3 and 4.
        threshold_results: Sequence of dicts with ``'threshold'``,
            ``'accuracy'`` and ``'f1'`` keys, one per evaluated threshold.
        output_dir: Directory the PNG files are written to. Defaults to the
            current working directory, matching the previous behaviour.

    NOTE(review): the notebook output shows repeated warnings raised from
    ``tight_layout``/``savefig``. These are most likely missing-glyph warnings
    for the Chinese titles/labels under matplotlib's default font, i.e. the
    labels may render as boxes. Confirm, and configure a CJK-capable font
    (``plt.rcParams['font.sans-serif']``) if so.
    """
    saved = [
        _plot_confusion_matrix(true_labels, predictions, best_threshold, output_dir),
        _plot_roc_curve(true_labels, fire_rates, output_dir),
        _plot_fire_rate_distribution(fire_rates, true_labels, best_threshold, output_dir),
        _plot_threshold_performance(threshold_results, best_threshold, output_dir),
    ]

    print("  ✅ 图表已保存:")
    for name in saved:
        print(f"     - {name}")


def _save_figure(fig, output_dir, filename):
    """Lay out ``fig``, save it at 300 dpi, always close it, and return ``filename``."""
    try:
        fig.tight_layout()
        fig.savefig(os.path.join(output_dir, filename), dpi=300, bbox_inches='tight')
    finally:
        # Close even when saving fails so figures do not accumulate in pyplot.
        plt.close(fig)
    return filename


def _plot_confusion_matrix(true_labels, predictions, best_threshold, output_dir):
    """Save the NOFIRE/FIRE confusion-matrix heatmap; return its file name."""
    # labels=[0, 1] forces a 2x2 matrix even if a class is missing from both
    # inputs. Otherwise sklearn returns a smaller matrix that no longer lines
    # up with the two tick labels below.
    cm = confusion_matrix(true_labels, predictions, labels=[0, 1])
    fig, ax = plt.subplots(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt="d", cmap="Blues",
                xticklabels=["NOFIRE", "FIRE"], yticklabels=["NOFIRE", "FIRE"],
                ax=ax)
    ax.set_title(f"混淆矩阵 (阈值: {best_threshold:.3f})")
    ax.set_xlabel("预测")
    ax.set_ylabel("真实")
    return _save_figure(fig, output_dir, "confusion_matrix_optimized.png")


def _plot_roc_curve(true_labels, fire_rates, output_dir):
    """Save the ROC curve (with AUC in the legend); return its file name."""
    fpr, tpr, _ = roc_curve(true_labels, fire_rates)
    roc_auc = auc(fpr, tpr)
    fig, ax = plt.subplots(figsize=(8, 6))
    ax.plot(fpr, tpr, label=f"ROC AUC = {roc_auc:.3f}", linewidth=2)
    ax.plot([0, 1], [0, 1], 'k--', alpha=0.5)  # chance-level diagonal
    ax.set_title("ROC曲线")
    ax.set_xlabel("假正率")
    ax.set_ylabel("真正率")
    ax.legend()
    ax.grid(True, alpha=0.3)
    return _save_figure(fig, output_dir, "roc_curve_optimized.png")


def _plot_fire_rate_distribution(fire_rates, true_labels, best_threshold, output_dir):
    """Save per-class density histograms of the fire rate; return its file name."""
    nofire_rates = [rate for rate, label in zip(fire_rates, true_labels) if label == 0]
    fire_class_rates = [rate for rate, label in zip(fire_rates, true_labels) if label == 1]

    fig, ax = plt.subplots(figsize=(10, 6))
    # Skip empty groups: a density histogram of no samples is 0/0 (NaN) and
    # would add a legend entry for a class that is not present.
    if nofire_rates:
        ax.hist(nofire_rates, bins=15, alpha=0.7, label='非火灾样本', color='blue', density=True)
    if fire_class_rates:
        ax.hist(fire_class_rates, bins=15, alpha=0.7, label='火灾样本', color='red', density=True)
    ax.axvline(x=best_threshold, color='black', linestyle='--', linewidth=2,
               label=f'最佳阈值 = {best_threshold:.3f}')
    ax.axvline(x=0.5, color='gray', linestyle=':', alpha=0.7, label='默认阈值 = 0.5')

    ax.set_xlabel('Fire发放率')
    ax.set_ylabel('密度')
    ax.set_title('Fire发放率分布')
    ax.legend()
    ax.grid(True, alpha=0.3)
    return _save_figure(fig, output_dir, "fire_rate_distribution.png")


def _plot_threshold_performance(threshold_results, best_threshold, output_dir):
    """Save the accuracy/F1-versus-threshold curves; return its file name."""
    thresholds = [r['threshold'] for r in threshold_results]
    accuracies = [r['accuracy'] for r in threshold_results]
    f1_scores = [r['f1'] for r in threshold_results]

    fig, ax = plt.subplots(figsize=(10, 6))
    ax.plot(thresholds, accuracies, label='准确率', linewidth=2, color='blue')
    ax.plot(thresholds, f1_scores, label='F1分数', linewidth=2, color='red')
    ax.axvline(x=best_threshold, color='black', linestyle='--',
               label=f'最佳阈值 = {best_threshold:.3f}')
    ax.axvline(x=0.5, color='gray', linestyle=':', alpha=0.7, label='默认阈值 = 0.5')

    ax.set_xlabel('阈值')
    ax.set_ylabel('分数')
    ax.set_title('阈值 vs 性能')
    ax.legend()
    ax.grid(True, alpha=0.3)
    return _save_figure(fig, output_dir, "threshold_performance.png")

# Script entry point: run the full CSNN inference pipeline. When this cell is
# executed in a notebook, ``__name__`` is "__main__" as well, so main() runs
# as part of the cell (the captured output below is from that run).
if __name__ == "__main__":
    main()

🚀 开始CSNN智能推理...
📁 加载模型...
  从权重文件检测到FC1输入维度: 3456
✅ 模型加载成功!
📊 找到 57 个火灾音频，32 个非火灾音频

🔍 第一阶段：收集预测结果...
  处理火灾音频...
     1. fire_valid_01.wav                   Fire发放率: 0.8667
     2. fire_valid_02.wav                   Fire发放率: 0.7333
     3. fire_valid_03.wav                   Fire发放率: 0.6667
     4. fire_valid_04.wav                   Fire发放率: 0.7333
     5. fire_valid_05.wav                   Fire发放率: 0.9333
     6. fire_valid_06.wav                   Fire发放率: 0.8000
     7. fire_valid_07.wav                   Fire发放率: 0.8667
     8. fire_valid_08.wav                   Fire发放率: 0.6667
     9. fire_valid_09.wav                   Fire发放率: 0.8000
    10. fire_valid_10.wav                   Fire发放率: 0.8000
    11. fire_valid_11.wav                   Fire发放率: 0.9333
    12. fire_valid_12.wav                   Fire发放率: 0.8667
    13. fire_valid_13.wav                   Fire发放率: 0.9333
    14. fire_valid_14.wav                   Fire发放率: 0.7333
    15. fire_valid_15.wav                   Fi

  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()
  plt.savefig("confusion_matrix_optimized.png", dpi=300, bbox_inches='tight')
  plt.savefig("confusion_matrix_optimized.png", dpi=300, bbox_inches='tight')
  plt.savefig("confusion_matrix_optimized.png", dpi=300, bbox_inches='tight')
  plt.savefig("confusion_matrix_optimized.png", dpi=300, bbox_inches='tight')
  plt.savefig("confusion_matrix_optimized.png", dpi=300, bbox_inches='tight')
  plt.savefig("confusion_matrix_optimized.png", dpi=300, bbox_inches='tight')
  plt.savefig("confusion_matrix_optimized.png", dpi=300, bbox_inches='tight')
  plt.savefig("confusion_matrix_optimized.png", dpi=300, bbox_inches='tight')
  plt.savefig("confusion_matrix_optimized.png", dpi=300, bbox_inches='tight')
  plt.savefig("confusion_matrix_optimized.png", dpi=300, bbox_inches='tight')
  plt.tigh

  ✅ 图表已保存:
     - confusion_matrix_optimized.png
     - roc_curve_optimized.png
     - fire_rate_distribution.png
     - threshold_performance.png

🎉 CSNN智能推理完成！
💡 关键发现：使用阈值 0.410 可获得 96.6% 的准确率
