In [1]:
import numpy as np
import matplotlib.pyplot as plt 
import seaborn as sns 
import pandas as pd 
import re

In [2]:
def read(file_path, model: str, encoder: str):
    """Load a per-fold metric log and return it as a DataFrame.

    Every non-blank line of the file must describe one fold: its first
    comma-separated field holds a ``FOLD: <int>`` marker together with the
    first metric value, and the remaining five fields hold one value each.
    Values are assigned positionally to ACC, AUC, Kappa, F1, Precision and
    Recall. No header or separator lines are skipped.

    Parameters
    ----------
    file_path : str or path-like
        Path of the UTF-8 text log to read.
    model : str
        Label written to the ``classifier`` column of every row.
    encoder : str
        Label written to the ``encoder`` column of every row.

    Returns
    -------
    pandas.DataFrame
        One row per non-blank line with columns ``Fold`` (int64), ``ACC``,
        ``AUC``, ``Kappa``, ``F1``, ``Precision``, ``Recall`` (float64),
        ``classifier`` and ``encoder``.

    Raises
    ------
    ValueError
        If a line lacks the ``FOLD: <int>`` marker, does not have exactly six
        comma-separated fields, or has a field without a decimal number.
    """
    metric_names = ['ACC', 'AUC', 'Kappa', 'F1', 'Precision', 'Recall']
    fold_re = re.compile(r"FOLD: (\d+)")
    # Picks the same first decimal number that r"(\d+\.\d+)" would, but keeps
    # a leading sign (Kappa can be negative) and an exponent ("1.5e-05").
    # The look-behind prevents a hyphen glued to preceding text ("run-0.5")
    # from being interpreted as a minus sign.
    value_re = re.compile(r"(?:(?<![\w.])[-+])?\d+\.\d+(?:[eE][-+]?\d+)?")

    rows = []
    with open(file_path, "r", encoding="utf-8") as file:
        for line_no, raw_line in enumerate(file, start=1):
            line = raw_line.strip()
            if not line:
                continue  # blank lines carry no data

            parts = line.split(",")
            fold_match = fold_re.search(parts[0])
            value_matches = [value_re.search(part) for part in parts]
            if (
                fold_match is None
                or len(parts) != len(metric_names)
                or any(match is None for match in value_matches)
            ):
                # Fail with the offending line instead of an AttributeError
                # on ``None.group`` or a column-count error from pandas.
                raise ValueError(
                    f"{file_path}:{line_no}: expected 'FOLD: <int>' and "
                    f"{len(metric_names)} comma-separated decimal values, "
                    f"got {line!r}"
                )

            rows.append(
                [int(fold_match.group(1))]
                + [float(match.group(0)) for match in value_matches]
            )

    df = pd.DataFrame(rows, columns=["Fold"] + metric_names)
    # Explicit dtypes so that an empty log still yields numeric columns.
    df = df.astype({"Fold": "int64", **{name: "float64" for name in metric_names}})
    # A scalar is broadcast to every row.
    df['classifier'] = model
    df['encoder'] = encoder
    return df

In [3]:
# Per-fold metrics of the two runs being compared; both are tagged with the "uni" encoder.
df_1 = read(file_path='runs/2025_04_06_10_42_28/log_metric.txt', model='diff-5', encoder='uni')
df_2 = read(file_path='runs/2025_04_06_10_43_46/log_metric.txt', model='diff-9', encoder='uni')

In [4]:
# Column-wise averages over the folds. numeric_only=True is required because
# the frames also hold the string columns 'classifier' and 'encoder'; without
# it pandas 1.x warns and pandas >= 2.0 raises TypeError.
df_1.mean(numeric_only=True), df_2.mean(numeric_only=True)

  df_1.mean(), df_2.mean()


(Fold         2.000000
 ACC          0.838889
 AUC          0.883484
 Kappa        0.896304
 F1           0.836145
 Precision    0.848910
 Recall       0.824768
 dtype: float64,
 Fold         2.000000
 ACC          0.840741
 AUC          0.873831
 Kappa        0.891913
 F1           0.832983
 Precision    0.834461
 Recall       0.831581
 dtype: float64)

In [10]:
# Stack both runs row-wise into one long frame with a fresh 0..n-1 index.
combined_df = pd.concat((df_1, df_2), ignore_index=True)

In [14]:
# One ACC bar per classifier, coloured by classifier; seaborn's bar kind
# aggregates the rows of each classifier with its default estimator.
sns.catplot(
    data=combined_df,
    kind="bar",
    x="classifier",
    y='ACC',
    hue="classifier",
    palette='Set2',
)
# Show the full 0-1 range on the y-axis.
plt.ylim(0, 1)

(0.0, 1.0)