In [1]:
import torch
from dataset import get_loader
from models.net import get_model
from constant import IMPLEMENTED_NETS, SUPPORTED_TASKS, NUM_TO_FOUR_CLASS
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix
import numpy as np

In [2]:
def plot_confusion_matrix(cm, classes, title, out_dir='figs'):
    """Render a confusion matrix as an annotated heatmap and save it as a PNG.

    Args:
        cm: Square matrix of integer counts (``fmt='d'`` requires integers);
            rows are labelled 'True label', columns 'Predicted label'.
        classes: Tick labels, one per row/column of ``cm``, in index order.
        title: Figure title; also embedded in the output filename
            ``confusion_matrix_{title}.png``.
        out_dir: Directory the PNG is written to; created if it does not exist.

    The figure is always closed afterwards (even if saving fails), so calling
    this in a loop does not accumulate open figures or show them inline.
    """
    import os  # local import keeps this cell self-contained

    os.makedirs(out_dir, exist_ok=True)
    fig, ax = plt.subplots(figsize=(10, 8))
    try:
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                    xticklabels=classes, yticklabels=classes, ax=ax)
        ax.set_title(title)
        ax.set_ylabel('True label')
        ax.set_xlabel('Predicted label')
        fig.tight_layout()
        fig.savefig(os.path.join(out_dir, f'confusion_matrix_{title}.png'))
    finally:
        plt.close(fig)


def calculate_metrics(cm, zero_division=0):
    """Compute one-vs-rest per-class metrics from a confusion matrix.

    Args:
        cm: Square matrix of counts (array-like). ``cm[i, j]`` is the number of
            samples whose true class is ``i`` and predicted class is ``j``.
        zero_division: Value reported for any ratio whose denominator is 0
            (e.g. FNR/recall for a class with no true samples, FPR when every
            sample belongs to that class). Defaults to 0, the fallback the
            precision/recall/F1 computations already used; applying it to
            FPR/FNR too keeps the averages free of NaN.

    Returns:
        Tuple ``(fpr, fnr, precision, recall, f1)`` of lists, each holding one
        float per class in index order.

    Raises:
        ValueError: if ``cm`` is not a square 2-D matrix.
    """
    cm = np.asarray(cm)
    if cm.ndim != 2 or cm.shape[0] != cm.shape[1]:
        raise ValueError(f"cm must be a square 2-D matrix, got shape {cm.shape}")

    tp = np.diag(cm)
    fp = cm.sum(axis=0) - tp   # predicted as class i, but truly another class
    fn = cm.sum(axis=1) - tp   # truly class i, but predicted as another class
    tn = cm.sum() - (tp + fp + fn)

    def _ratio(num, den):
        """Element-wise num / den, using `zero_division` wherever den == 0."""
        out = np.full(num.shape, zero_division, dtype=float)
        np.divide(num, den, out=out, where=den > 0)
        return out

    fpr = _ratio(fp, fp + tn)
    fnr = _ratio(fn, fn + tp)
    precision = _ratio(tp, tp + fp)
    recall = _ratio(tp, tp + fn)
    f1 = _ratio(2 * precision * recall, precision + recall)

    return fpr.tolist(), fnr.tolist(), precision.tolist(), recall.tolist(), f1.tolist()

In [3]:
# Evaluate every implemented architecture on the held-out evaluation split:
# save a confusion-matrix heatmap per model and print macro-averaged metrics.
task = 'four-class'
CKPT_DIR = f'/home/zg34/Desktop/MathWorks-Radar-Drone-Project/logs/pth_models/{task}'
EVAL_CSV = '/home/zg34/datasets/drone_project/eval_4cls.csv'
DATA_DIR = '/home/zg34/datasets/drone_project/data'
SHOW_PER_CLASS = False  # set True to also print per-class metrics

# Class names in label-index order (assumes the labels are 0, 1, 2, 3 --
# TODO confirm against NUM_TO_FOUR_CLASS / the dataset's label encoding).
class_names = ['drone', 'bird', 'cluster', 'noise']

# Fall back to CPU so the notebook still runs on a machine without a GPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# The evaluation loader does not depend on the architecture, so build it once.
loader = get_loader(EVAL_CSV, DATA_DIR, task,
                    batch_size=64, shuffle=False, drop_last=False, num_workers=8)

for arch in IMPLEMENTED_NETS:
    net = get_model(arch, SUPPORTED_TASKS[task], False)
    # map_location lets a checkpoint saved on another device load here.
    # NOTE(review): torch.load unpickles the file -- only load trusted
    # checkpoints (consider weights_only=True if the torch version supports it).
    state_dict = torch.load(f'{CKPT_DIR}/{task}-{arch}.pth', map_location=device)
    net.load_state_dict(state_dict)
    net = net.eval().to(device)

    all_preds = []
    all_labels = []
    with torch.no_grad():
        for inputs, labels in loader:
            outputs = net(inputs.to(device))
            # Predicted class = index of the highest score along dim 1.
            _, preds = torch.max(outputs, 1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.numpy())

    # Compute the confusion matrix over the full class set so it is always
    # len(class_names) x len(class_names); without `labels=`, a class absent
    # from both labels and predictions would be dropped and the rows/ticks
    # would no longer line up with class_names.
    cm = confusion_matrix(all_labels, all_preds,
                          labels=list(range(len(class_names))))

    # Plot and save the confusion matrix.
    plot_confusion_matrix(cm, class_names, arch)

    fpr, fnr, precision, recall, f1 = calculate_metrics(cm)

    print(f"\nresults for {arch}:")
    if SHOW_PER_CLASS:
        for i, name in enumerate(class_names):
            print(f"Class {i} ({name}):")
            print(f"  Precision: {precision[i]:.4f}")
            print(f"  Recall: {recall[i]:.4f}")
            print(f"  F1-score: {f1[i]:.4f}")
            print(f"  FPR: {fpr[i]:.4f}")
            print(f"  FNR: {fnr[i]:.4f}")
    print("\nAverage metrics:")
    print(f"  Precision: {np.mean(precision):.4f}")
    print(f"  Recall: {np.mean(recall):.4f}")
    print(f"  F1-score: {np.mean(f1):.4f}")
    print(f"  FPR: {np.mean(fpr):.4f}")
    print(f"  FNR: {np.mean(fnr):.4f}")


results for resnet18:

Average metrics:
  Precision: 0.9468
  Recall: 0.9491
  F1-score: 0.9467
  FPR: 0.0173
  FNR: 0.0509

results for resnet50:

Average metrics:
  Precision: 0.9426
  Recall: 0.9376
  F1-score: 0.9347
  FPR: 0.0212
  FNR: 0.0624

results for convnext_tiny:

Average metrics:
  Precision: 0.9387
  Recall: 0.9159
  F1-score: 0.9145
  FPR: 0.0291
  FNR: 0.0841

results for convnext_base:

Average metrics:
  Precision: 0.9395
  Recall: 0.9225
  F1-score: 0.9211
  FPR: 0.0268
  FNR: 0.0775

results for efficientnet_v2_s:

Average metrics:
  Precision: 0.9765
  Recall: 0.9742
  F1-score: 0.9745
  FPR: 0.0089
  FNR: 0.0258

results for efficientnet_v2_m:

Average metrics:
  Precision: 0.9930
  Recall: 0.9937
  F1-score: 0.9933
  FPR: 0.0022
  FNR: 0.0063

results for resnext50_32x4d:

Average metrics:
  Precision: 0.9294
  Recall: 0.9159
  F1-score: 0.9109
  FPR: 0.0287
  FNR: 0.0841

results for alexnet:

Average metrics:
  Precision: 0.9541
  Recall: 0.9440
  F1-score: 0