# Single

In [None]:
import os

import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix

# Model selected for the "Single" section: file-name key and display title.
model_name = "alexnet"
model_name_title = "AlexNet"

# Read the ground-truth labels and predictions. Each non-empty line holds one
# whitespace-separated "label pred" integer pair.
all_labels = []
all_preds = []
with open(f"./logs/best_{model_name}_model.txt", "r") as f:
    for line in f:
        fields = line.split()
        if not fields:
            # A blank line (e.g. an extra newline at the end of the file)
            # carries no sample; unpacking it below would raise ValueError.
            continue
        label, pred = map(int, fields)
        all_labels.append(label)
        all_preds.append(pred)

# Compute the confusion matrix (true labels first, predictions second).
conf_matrix = confusion_matrix(all_labels, all_preds)

# Draw the confusion matrix as an annotated heat map.
plt.figure(figsize=(10, 8))
sns.heatmap(
    conf_matrix,
    annot=True,
    fmt="d",
    cmap="Blues",
    cbar_kws={'label': 'Count'},
    annot_kws={"size": 15},
)
plt.xlabel("Predicted Label")
plt.ylabel("True Label")
plt.title(f"{model_name_title} Confusion Matrix")

# Save a high-resolution copy. The output directory is created first so this
# cell also works on a fresh checkout where ./figures does not exist yet.
os.makedirs("./figures", exist_ok=True)
plt.savefig(f"./figures/{model_name}_confusion_matrix.png", dpi=500)
plt.show()

In [None]:
import re

# Read the whole training log of the currently selected model.
log_path = f"./logs/{model_name}.log"
with open(log_path, 'r') as file:
    log_data = file.read()

# Format of one per-epoch summary line; group 1 is the epoch number and
# groups 2-9 are the eight metric values, in the order of `float_metric_keys`.
log_pattern = re.compile(
    r'Epoch: (\d+), Training loss: ([\d.]+), Training accuracy: ([\d.]+), '
    r'Validating loss: ([\d.]+), Validating accuracy: ([\d.]+), '
    r'Precision: ([\d.]+), Recall: ([\d.]+), f1 score: ([\d.]+), auc: ([\d.]+)'
)
float_metric_keys = (
    "training_loss",
    "training_accuracy",
    "validating_loss",
    "validating_accuracy",
    "precision",
    "recall",
    "f1_score",
    "auc",
)

# Per-epoch metric series, filled in the order the epochs appear in the log.
metrics = {"epoch": []}
for key in float_metric_keys:
    metrics[key] = []

# Extract the values from every line that matches the summary format; lines
# that do not match are ignored.
for line in log_data.splitlines():
    match = log_pattern.search(line)
    if match is None:
        continue
    epoch_text, *value_texts = match.groups()
    metrics["epoch"].append(int(epoch_text))
    for key, value_text in zip(float_metric_keys, value_texts):
        metrics[key].append(float(value_text))

# Without this message an unmatched log silently produces empty charts in the
# plotting cell that reads `metrics`.
if not metrics["epoch"]:
    print(f"[WARN] No metric lines found in {log_path}, nothing to plot.")

In [None]:
import matplotlib.pyplot as plt

epochs = metrics["epoch"]

# One figure holding six panels arranged as 2 rows x 3 columns.
fig, axs = plt.subplots(2, 3, figsize=(18, 10))

fig.suptitle(f"{model_name_title} Indicator Chart", fontsize=16)

# One entry per panel, in row-major order:
# (panel title, which is also the y-axis label,
#  [(metrics key, legend label, line colour), ...])
panels = [
    ('Loss', [
        ("training_loss", 'Training Loss', 'blue'),
        ("validating_loss", 'Validating Loss', 'orange'),
    ]),
    ('Accuracy', [
        ("training_accuracy", 'Training Accuracy', 'green'),
        ("validating_accuracy", 'Validating Accuracy', 'red'),
    ]),
    ('Precision', [("precision", 'Precision', 'purple')]),
    ('Recall', [("recall", 'Recall', 'brown')]),
    ('F1 Score', [("f1_score", 'F1 Score', 'darkviolet')]),
    ('AUC', [("auc", 'AUC', 'magenta')]),
]

# `axs.flat` walks the 2x3 grid row by row, matching the order of `panels`.
for ax, (panel_title, curves) in zip(axs.flat, panels):
    for metric_key, legend_label, line_colour in curves:
        ax.plot(epochs, metrics[metric_key], label=legend_label, color=line_colour)
    ax.set_title(panel_title)
    ax.set_xlabel('Epochs')
    ax.set_ylabel(panel_title)
    ax.legend()

# Adjust the spacing between the panels.
plt.tight_layout()

# Save the figure.
plt.savefig(f"./figures/{model_name}_indicator_chart.png", dpi=1000)

# Display the figure.
plt.show()

# All

In [None]:
import os
import re
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix

# ====== 1. Configure all models: (file-name key, display title) ======
model_list = [
    ("alexnet", "AlexNet"),
    ("googlenet", "GoogLeNet"),
    ("resnext", "ResNeXt"),
    ("resnet", "ResNet"),
    ("densenet", "DenseNet"),
    ("swint", "SwinTransformer"),
    ("mobilenet", "MobileNet"),
    ("mit", "MobileViT"),
    ("regnet", "RegNet"),
    ("convnext", "ConvNeXt"),
    ("efficientnet", "EfficientNet"),
    ("ev2", "EfficientNetV2"),
    ("vit", "VisionTransformer"),
    ("xception", "Xception"),
    ("vgg", "VGGNet"),
    ("shufflenet", "ShuffleNet"),
]

os.makedirs("./figures", exist_ok=True)

# Compile the per-epoch summary pattern once; group 1 is the epoch number and
# groups 2-9 are the metric values in the order of the `metrics` keys below.
log_pattern = re.compile(
    r'Epoch: (\d+), Training loss: ([\d.]+), Training accuracy: ([\d.]+), '
    r'Validating loss: ([\d.]+), Validating accuracy: ([\d.]+), '
    r'Precision: ([\d.]+), Recall: ([\d.]+), f1 score: ([\d.]+), auc: ([\d.]+)'
)

# Layout of the 2x3 indicator chart, one entry per panel in row-major order:
# (panel title, which is also the y-axis label,
#  [(metrics key, legend label, line colour), ...])
indicator_panels = [
    ('Loss', [
        ("training_loss", 'Training Loss', 'blue'),
        ("validating_loss", 'Validating Loss', 'orange'),
    ]),
    ('Accuracy', [
        ("training_accuracy", 'Training Accuracy', 'green'),
        ("validating_accuracy", 'Validating Accuracy', 'red'),
    ]),
    ('Precision', [("precision", 'Precision', 'purple')]),
    ('Recall', [("recall", 'Recall', 'brown')]),
    ('F1 Score', [("f1_score", 'F1 Score', 'darkviolet')]),
    ('AUC', [("auc", 'AUC', 'magenta')]),
]

# ====== 2. Process each model in turn ======
for model_name, model_name_title in model_list:
    print(f"\n=== Processing {model_name_title} ({model_name}) ===")

    # ---------- 2.1 Read best_xxx_model.txt and draw the confusion matrix ----------
    labels_file = f"./logs/best_{model_name}_model.txt"
    if not os.path.exists(labels_file):
        print(f"[WARN] {labels_file} not found, skip confusion matrix.")
    else:
        all_labels = []
        all_preds = []
        with open(labels_file, "r") as f:
            for line in f:
                # Each non-empty line is expected to be "label pred".
                fields = line.split()
                if not fields:
                    # Skip blank lines; unpacking them would raise ValueError.
                    continue
                label, pred = map(int, fields)
                all_labels.append(label)
                all_preds.append(pred)

        conf_matrix = confusion_matrix(all_labels, all_preds)

        cm_fig = plt.figure(figsize=(10, 8))
        sns.heatmap(
            conf_matrix,
            annot=True,
            fmt="d",
            cmap="Blues",
            cbar_kws={'label': 'Count'},
            annot_kws={"size": 15}
        )
        plt.xlabel("Predicted Label")
        plt.ylabel("True Label")
        plt.title(f"{model_name_title} Confusion Matrix")

        plt.savefig(f"./figures/{model_name}_confusion_matrix.png", dpi=500)
        plt.show()  # one confusion-matrix figure is displayed per model
        # Release the figure explicitly: this loop creates up to two figures
        # per model and they would otherwise stay registered with pyplot.
        plt.close(cm_fig)

    # ---------- 2.2 Read xxx.log, parse the metrics and draw the 6 panels ----------
    log_file = f"./logs/{model_name}.log"
    if not os.path.exists(log_file):
        print(f"[WARN] {log_file} not found, skip indicator chart.")
        continue

    with open(log_file, "r") as file:
        log_data = file.read()

    metrics = {
        "epoch": [],
        "training_loss": [],
        "training_accuracy": [],
        "validating_loss": [],
        "validating_accuracy": [],
        "precision": [],
        "recall": [],
        "f1_score": [],
        "auc": []
    }
    # Every key except "epoch", in the same order as regex groups 2-9.
    float_keys = [key for key in metrics if key != "epoch"]

    for line in log_data.splitlines():
        m = log_pattern.search(line)
        if m is None:
            continue
        epoch_text, *value_texts = m.groups()
        metrics["epoch"].append(int(epoch_text))
        for key, value_text in zip(float_keys, value_texts):
            metrics[key].append(float(value_text))

    if len(metrics["epoch"]) == 0:
        print(f"[WARN] No metric lines found in {log_file}, skip plotting.")
        continue

    epochs = metrics["epoch"]

    # Create the 2x3 grid of panels.
    fig, axs = plt.subplots(2, 3, figsize=(18, 10))
    fig.suptitle(f"{model_name_title} Indicator Chart", fontsize=16)

    # `axs.flat` walks the grid row by row, matching `indicator_panels`.
    for ax, (panel_title, curves) in zip(axs.flat, indicator_panels):
        for metric_key, legend_label, line_colour in curves:
            ax.plot(
                epochs, metrics[metric_key],
                label=legend_label, color=line_colour
            )
        ax.set_title(panel_title)
        ax.set_xlabel('Epochs')
        ax.set_ylabel(panel_title)
        ax.legend()

    plt.tight_layout(rect=[0, 0.03, 1, 0.95])  # leave room for the suptitle

    plt.savefig(f"./figures/{model_name}_indicator_chart.png", dpi=1000)
    plt.show()
    plt.close(fig)  # see the note on plt.close above

# Merge

In [None]:
# ===== 所有模型合并成一张大图（左指标 + 右混淆矩阵） =====
import os
import re
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix
from matplotlib import gridspec

# Model list: (file-name key, display title)
model_list = [
    ("alexnet", "AlexNet"),
    ("googlenet", "GoogLeNet"),
    ("resnext", "ResNeXt"),
    ("resnet", "ResNet"),
    ("densenet", "DenseNet"),
    ("swint", "SwinTransformer"),
    ("mobilenet", "MobileNet"),
    ("mit", "MobileViT"),
    ("regnet", "RegNet"),
    ("convnext", "ConvNeXt"),
    ("efficientnet", "EfficientNet"),
    ("ev2", "EfficientNetV2"),
    ("vit", "VisionTransformer"),
    ("xception", "Xception"),
    ("vgg", "VGGNet"),
    ("shufflenet", "ShuffleNet"),
]

# Per-epoch summary pattern; group 1 is the epoch number and groups 2-9 are the
# metric values in the order of the `metrics` keys below.
log_pattern = re.compile(
    r'Epoch: (\d+), Training loss: ([\d.]+), Training accuracy: ([\d.]+), '
    r'Validating loss: ([\d.]+), Validating accuracy: ([\d.]+), '
    r'Precision: ([\d.]+), Recall: ([\d.]+), f1 score: ([\d.]+), auc: ([\d.]+)'
)

# Left-hand indicator panels, one entry per panel in row-major order:
# (panel title, [(metrics key, legend label, line colour), ...])
combined_panels = [
    ('Loss', [
        ("training_loss", 'Training Loss', 'blue'),
        ("validating_loss", 'Validating Loss', 'orange'),
    ]),
    ('Accuracy', [
        ("training_accuracy", 'Training Acc', 'green'),
        ("validating_accuracy", 'Validating Acc', 'red'),
    ]),
    ('Precision', [("precision", 'Precision', 'purple')]),
    ('Recall', [("recall", 'Recall', 'brown')]),
    ('F1 Score', [("f1_score", 'F1 Score', 'darkviolet')]),
    ('AUC', [("auc", 'AUC', 'magenta')]),
]

# ============= 1. Load the data of every model first ==============
# Loading before creating the figure lets the grid get exactly one row per
# model that can be drawn; sizing it by len(model_list) would leave an empty
# band for every skipped model.
loaded_models = []  # (model_title, conf_matrix, metrics) per drawable model
for model_name, model_title in model_list:

    # ---- 1.1 Confusion-matrix data: "label pred" integer pairs ----
    label_file = f"./logs/best_{model_name}_model.txt"
    if not os.path.exists(label_file):
        print(f"[WARN] {label_file} missing, skip model {model_name}")
        continue

    all_labels, all_preds = [], []
    with open(label_file, "r") as f:
        for line in f:
            fields = line.split()
            if not fields:
                # Skip blank lines; unpacking them would raise ValueError.
                continue
            true_label, pred_label = map(int, fields)
            all_labels.append(true_label)
            all_preds.append(pred_label)
    conf_matrix = confusion_matrix(all_labels, all_preds)

    # ---- 1.2 Per-epoch metrics from the training log ----
    log_file = f"./logs/{model_name}.log"
    if not os.path.exists(log_file):
        print(f"[WARN] {log_file} missing, skip model {model_name}")
        continue

    metrics = {
        "epoch": [],
        "training_loss": [],
        "training_accuracy": [],
        "validating_loss": [],
        "validating_accuracy": [],
        "precision": [],
        "recall": [],
        "f1_score": [],
        "auc": []
    }
    # Every key except "epoch", in the same order as regex groups 2-9.
    float_keys = [key for key in metrics if key != "epoch"]

    with open(log_file, "r") as f:
        for line in f:
            m = log_pattern.search(line)
            if m is None:
                continue
            epoch_text, *value_texts = m.groups()
            metrics["epoch"].append(int(epoch_text))
            for key, value_text in zip(float_keys, value_texts):
                metrics[key].append(float(value_text))

    if len(metrics["epoch"]) == 0:
        print(f"[WARN] no metrics for {model_name}")
        continue

    loaded_models.append((model_title, conf_matrix, metrics))

if not loaded_models:
    # A GridSpec needs at least one row, and an empty figure is not worth saving.
    print("[WARN] no model data could be loaded, combined figure not created")
else:
    # ============= 2. Size the figure for the loaded models ==============
    row_height = 4  # figure height per model row; can be lowered to 3 or 2
    fig = plt.figure(figsize=(20, row_height * len(loaded_models)))

    # Outer grid: indicators on the left, confusion matrix on the right
    # (2 columns, one row per loaded model).
    outer_gs = gridspec.GridSpec(
        len(loaded_models), 2, width_ratios=[3, 2], hspace=0.4
    )

    # ============= 3. Draw one row per model ==============
    for r, (model_title, conf_matrix, metrics) in enumerate(loaded_models):
        epochs = metrics["epoch"]

        # ---- 3.1 Left: 2x3 indicator panels nested inside outer_gs[r, 0] ----
        left_gs = gridspec.GridSpecFromSubplotSpec(
            2, 3, subplot_spec=outer_gs[r, 0], wspace=0.3, hspace=0.4
        )

        axs = np.empty((2, 3), dtype=object)
        for i in range(2):
            for j in range(3):
                axs[i, j] = fig.add_subplot(left_gs[i, j])

        # `axs.flat` walks the 2x3 grid row by row, matching `combined_panels`.
        for ax, (panel_title, curves) in zip(axs.flat, combined_panels):
            for metric_key, legend_label, line_colour in curves:
                ax.plot(
                    epochs, metrics[metric_key],
                    label=legend_label, color=line_colour
                )
            ax.set_title(panel_title)
            ax.legend()

        # The first panel's y-label doubles as the row label for the model.
        axs[0, 0].set_ylabel(model_title, fontsize=12)

        # ---- 3.2 Right: confusion matrix ----
        ax_cm = fig.add_subplot(outer_gs[r, 1])
        sns.heatmap(
            conf_matrix,
            annot=True,
            fmt="d",
            cmap="Blues",
            cbar_kws={'label': 'Count'},
            annot_kws={"size": 8},
            ax=ax_cm
        )
        ax_cm.set_title(f"{model_title} Confusion Matrix")

    # ============= 4. Save the combined figure ==============
    # Create the output directory here so the cell does not depend on an
    # earlier cell having created it.
    os.makedirs("./figures", exist_ok=True)
    save_path = "./figures/all_models_combined.png"
    plt.savefig(save_path, dpi=500, bbox_inches="tight")
    plt.show()
    print("[INFO] Saved ->", save_path)

In [None]:
# ===== Cell 2：所有模型的 6 个验证指标对比图（三行两列），并保存 =====
import os
import re
import numpy as np
import matplotlib.pyplot as plt

# Model list: (file-name key, display title)
model_list = [
    ("alexnet", "AlexNet"),
    ("googlenet", "GoogLeNet"),
    ("resnext", "ResNeXt"),
    ("resnet", "ResNet"),
    ("densenet", "DenseNet"),
    ("swint", "SwinTransformer"),
    ("mobilenet", "MobileNet"),
    ("mit", "MobileViT"),
    ("regnet", "RegNet"),
    ("convnext", "ConvNeXt"),
    ("efficientnet", "EfficientNet"),
    ("ev2", "EfficientNetV2"),
    ("vit", "VisionTransformer"),
    ("xception", "Xception"),
    ("vgg", "VGGNet"),
    ("shufflenet", "ShuffleNet"),
]

# Per-epoch summary pattern; group 1 is the epoch number and groups 2-9 are the
# metric values in the order of the `metrics` keys below.
log_pattern = re.compile(
    r'Epoch: (\d+), Training loss: ([\d.]+), Training accuracy: ([\d.]+), '
    r'Validating loss: ([\d.]+), Validating accuracy: ([\d.]+), '
    r'Precision: ([\d.]+), Recall: ([\d.]+), f1 score: ([\d.]+), auc: ([\d.]+)'
)

# model_name -> {"title": display title, <metric key>: per-epoch list, ...}
# Only models with at least one parsed epoch are stored.
metrics_all = {}

for model_name, model_title in model_list:
    log_file = f"./logs/{model_name}.log"
    if not os.path.exists(log_file):
        print(f"[WARN] {log_file} not found, skip.")
        continue

    metrics = {
        "epoch": [],
        "training_loss": [],
        "training_accuracy": [],
        "validating_loss": [],
        "validating_accuracy": [],
        "precision": [],
        "recall": [],
        "f1_score": [],
        "auc": []
    }
    # Every key except "epoch", in the same order as regex groups 2-9.
    float_keys = [key for key in metrics if key != "epoch"]

    with open(log_file, "r") as f:
        for line in f:
            m = log_pattern.search(line)
            if m is None:
                continue
            epoch_text, *value_texts = m.groups()
            metrics["epoch"].append(int(epoch_text))
            for key, value_text in zip(float_keys, value_texts):
                metrics[key].append(float(value_text))

    if len(metrics["epoch"]) == 0:
        print(f"[WARN] No metric lines in {log_file}, skip.")
        continue

    metrics_all[model_name] = {
        "title": model_title,
        **metrics
    }

if not metrics_all:
    print("No metrics loaded, please check log files.")
else:
    # (metrics key, panel title, y-axis label, lower-is-better flag)
    indicator_cfg = [
        ("validating_loss", "Validating Loss", "Loss", True),
        ("validating_accuracy", "Validating Accuracy", "Accuracy", False),
        ("precision", "Precision", "Precision", False),
        ("recall", "Recall", "Recall", False),
        ("f1_score", "F1 Score", "F1 Score", False),
        ("auc", "AUC", "AUC", False),
    ]

    fig, axs = plt.subplots(3, 2, figsize=(18, 15))
    fig.suptitle("Validation Metrics Comparison Across Models", fontsize=18)

    # One colour per entry of model_list. Colours are indexed by the position
    # in model_list, so a model keeps its colour even when others are missing.
    cmap = plt.get_cmap("tab20", len(model_list))

    for idx, (metric_key, title, y_label, is_loss) in enumerate(indicator_cfg):
        # Fill the 3x2 grid row by row.
        row = idx // 2
        col = idx % 2
        ax = axs[row, col]

        for i, (model_name, model_title) in enumerate(model_list):
            if model_name not in metrics_all:
                continue

            # Series stored in metrics_all are never empty (see the loading
            # loop), so argmin/argmax below always have data to work on.
            data = metrics_all[model_name]
            epochs = data["epoch"]
            values = data[metric_key]

            color = cmap(i)

            # Metric curve of this model.
            ax.plot(epochs, values, label=model_title, color=color)

            # Best epoch: minimum for the loss, maximum for everything else.
            values_np = np.array(values)
            if is_loss:
                best_idx = int(np.argmin(values_np))
            else:
                best_idx = int(np.argmax(values_np))

            best_x = epochs[best_idx]
            best_y = values[best_idx]

            # Mark the best point and label it with its value.
            ax.scatter([best_x], [best_y], color=color, s=40)
            ax.annotate(
                f"{best_y:.4f}",
                (best_x, best_y),
                textcoords="offset points",
                xytext=(0, 6),
                ha="center",
                fontsize=7,
                color=color
            )

        ax.set_title(title)
        ax.set_xlabel("Epochs")
        ax.set_ylabel(y_label)
        ax.legend(fontsize=8)

    plt.tight_layout(rect=[0, 0.03, 1, 0.95])  # leave room for the suptitle

    # Create the output directory here so the cell does not depend on an
    # earlier cell having created it.
    os.makedirs("./figures", exist_ok=True)
    out_path = "./figures/all_models_validation_metrics.png"
    plt.savefig(out_path, dpi=500, bbox_inches="tight")
    plt.show()
    print(f"[INFO] saved comparison figure -> {out_path}")