## Drawing ADI metrics

In [5]:
import os
import sys

sys.path.append("..")

import re
import glob
from utils.eval_utils import class_attribution_metric, attribution_metric
from utils import RIVAL10_constants

# Folder-name patterns for evaluation runs. Every `([\d.]+)` group captures one
# `_`-separated number from the directory name; the collectors below convert each
# group with float() and use the resulting tuple as the dictionary key.
# spss: "eval_spss_vl_cbm_train_simple_concepts_" followed by exactly three numbers.
spss = r"^eval_spss_vl_cbm_train_simple_concepts_([\d.]+)_([\d.]+)_([\d.]+)$"
# cspss: the same prefix followed by exactly four numbers. Because `[\d.]` cannot
# match `_` and both patterns are anchored with ^...$, a name matches at most one
# of spss / cspss.
cspss = r"^eval_spss_vl_cbm_train_simple_concepts_([\d.]+)_([\d.]+)_([\d.]+)_([\d.]+)$"
# Disabled *_softmax patterns (no "eval_" prefix and no trailing $ anchor):
# spss_softmax = r"^spss_vl_cbm_train_simple_concepts_([\d.]+)_([\d.]+)_([\d.]+)"
# cspss_softmax = (
#     r"^cspss_vl_cbm_train_simple_concepts_([\d.]+)_([\d.]+)_([\d.]+)_([\d.]+)"
# )


def collect_pt_files(folder_path, regax_str: str):
    """Map the numbers parsed from run-folder names to that run's ``.pt`` files.

    Every direct child of ``folder_path`` whose base name matches ``regax_str``
    (tested with ``re.search``) is considered. All capture groups of the regex
    are converted with ``float`` and form the key, so a pattern with three
    groups yields 3-tuples and one with four groups yields 4-tuples. Each
    matching name is printed as it is found. Only children that contain an
    ``evaluations`` sub-directory are kept; their value maps every ``*.pt``
    file's stem (file name without extension) to its full path.

    Args:
        folder_path (str): root directory whose children are scanned.
        regax_str (str): regular expression applied to each child's base name.

    Returns:
        dict: ``{(float, ...): {stem: pt_path, ...}, ...}`` ordered by key.
    """
    name_re = re.compile(regax_str)
    collected = {}

    for child in glob.glob(os.path.join(folder_path, "*")):
        child_name = os.path.basename(child)
        found = name_re.search(child_name)
        if not found:
            continue

        key = tuple(float(group) for group in found.groups())
        print(child_name)

        eval_dir = os.path.join(child, "evaluations")
        if not (os.path.exists(eval_dir) and os.path.isdir(eval_dir)):
            continue

        # Stem -> path for every .pt file in the evaluations directory.
        collected[key] = {
            os.path.splitext(os.path.basename(pt_path))[0]: pt_path
            for pt_path in glob.glob(os.path.join(eval_dir, "*.pt"))
        }

    return {key: collected[key] for key in sorted(collected)}


def collect_acc_files(folder_path, regax_str: str, log_name: str = "exp_log.log"):
    """Map the numbers parsed from run-folder names to each run's log file.

    Every direct child of ``folder_path`` whose base name matches ``regax_str``
    (tested with ``re.search``) is considered, and its name is printed. All
    capture groups of the regex are converted with ``float`` and form the key,
    so the tuple length equals the number of groups in the pattern. A run is
    kept only if ``<child>/<log_name>`` exists and is a regular file.

    Args:
        folder_path (str): root directory whose children are scanned.
        regax_str (str): regular expression applied to each child's base name.
        log_name (str): name of the log file inside each run folder
            (defaults to ``"exp_log.log"``).

    Returns:
        dict: ``{(float, ...): log_path, ...}`` ordered by key.
    """
    name_re = re.compile(regax_str)
    collected = {}

    for folder in glob.glob(os.path.join(folder_path, "*")):
        folder_name = os.path.basename(folder)
        match = name_re.search(folder_name)
        if not match:
            continue

        floats = tuple(map(float, match.groups()))
        print(folder_name)

        log_path = os.path.join(folder, log_name)
        # Skip runs that have not produced a log file (or where the name is a directory).
        if os.path.isfile(log_path):
            collected[floats] = log_path

    return dict(sorted(collected.items()))


import re


def extract_last_accuracies(file_path):
    """Return the final (class accuracy, concept accuracy) pair found in a log.

    The whole file is scanned for ``Val Class Accuracy = X`` followed, later on
    the same line, by ``Val Concept Accuracy = Y``. The pattern is used without
    ``re.DOTALL``, so ``.`` does not cross newlines and the two values must share
    a line. Each value must have the form ``digits.digits``.

    Args:
        file_path: path to a UTF-8 text log.

    Returns:
        ``(class_acc, concept_acc)`` as floats for the last pair in the file,
        or ``None`` when no pair is found.
    """
    accuracy_re = re.compile(
        r"Val Class Accuracy = (\d+\.\d+).*?Val Concept Accuracy = (\d+\.\d+)"
    )
    with open(file_path, "r", encoding="utf-8") as handle:
        content = handle.read()

    pairs = accuracy_re.findall(content)
    if not pairs:
        return None  # the log contains no accuracy pair

    class_acc, concept_acc = pairs[-1]
    return float(class_acc), float(concept_acc)

In [6]:
# Build one tab-separated table row per evaluation run from its concept
# segmentation metric. The fourth column marks the run family: "0.0" for
# folders matching `spss`, "1.0" for folders matching `cspss`.
# The `classes*` metric files are not needed for this table, so they are not loaded.
folder_path = "../outputs/lambda_ablation"

adi_txt = []

for regex_str, variant_col in ((spss, "0.0"), (cspss, "1.0")):
    result = collect_pt_files(folder_path, regex_str)
    for floats, pt_dict in result.items():
        seg_path = pt_dict.get("concepts_segmentation_metric")
        if seg_path is None:
            continue  # run has no concept segmentation metric file

        metric = attribution_metric("", 0, 0)
        metric.load_from_path(seg_path)
        adi_txt.append(
            f"1.0\t1e4\t{floats[2]}\t{variant_col}\t"
            + metric.format_output(
                RIVAL10_constants._ALL_CLASSNAMES,
                RIVAL10_constants._ALL_ATTRS,
                latex=True,
                sep="\t",
            )
        )

print("\n".join(adi_txt))

eval_spss_vl_cbm_train_simple_concepts_1.0_1.0_0.1
eval_spss_vl_cbm_train_simple_concepts_1.0_1.0_5.0
eval_spss_vl_cbm_train_simple_concepts_1.0_1.0_0.5
eval_spss_vl_cbm_train_simple_concepts_1.0_1.0_10.0
eval_spss_vl_cbm_train_simple_concepts_1.0_1.0_1.0
eval_spss_vl_cbm_train_simple_concepts_1.0_1.0_0.0
eval_spss_vl_cbm_train_simple_concepts_1.0_1.0_1.5
eval_spss_vl_cbm_train_simple_concepts_1.0_1.0_2.0
eval_spss_vl_cbm_train_simple_concepts_1.0_1.0_10.0_1.0
eval_spss_vl_cbm_train_simple_concepts_1.0_1.0_1.0_1.0
eval_spss_vl_cbm_train_simple_concepts_1.0_1.0_2.0_1.0
eval_spss_vl_cbm_train_simple_concepts_1.0_1.0_0.5_1.0
eval_spss_vl_cbm_train_simple_concepts_1.0_1.0_0.1_1.0
eval_spss_vl_cbm_train_simple_concepts_1.0_1.0_1.5_1.0
eval_spss_vl_cbm_train_simple_concepts_1.0_1.0_5.0_1.0
eval_spss_vl_cbm_train_simple_concepts_1.0_1.0_0.0_1.0
1.0	1e4	0.0	0.0		21.72		30.85		33.86	
1.0	1e4	0.1	0.0		22.66		31.29		37.04	
1.0	1e4	0.5	0.0		15.73		23.51		31.79	
1.0	1e4	1.0	0.0		15.58		24.01		40.72

In [3]:
import re
import os
import glob

# Folder-name patterns for evaluation runs; every `([\d.]+)` group captures one
# `_`-separated number that collect_pt_files converts to float for the key tuple.
# spss: prefix followed by exactly three numbers.
spss = r"^eval_spss_vl_cbm_train_simple_concepts_([\d.]+)_([\d.]+)_([\d.]+)$"
# cspss: the same "eval_spss_" prefix followed by exactly four numbers.
cspss = r"^eval_spss_vl_cbm_train_simple_concepts_([\d.]+)_([\d.]+)_([\d.]+)_([\d.]+)$"
# NOTE(review): spss_softmax is byte-identical to spss, so it selects exactly the
# same folders, whereas cspss_softmax expects a distinct "eval_cspss_" prefix.
# If softmax runs use their own prefix, this pattern likely needs updating — confirm.
spss_softmax = r"^eval_spss_vl_cbm_train_simple_concepts_([\d.]+)_([\d.]+)_([\d.]+)$"
# cspss_softmax: "eval_cspss_" prefix followed by exactly four numbers.
cspss_softmax = (
    r"^eval_cspss_vl_cbm_train_simple_concepts_([\d.]+)_([\d.]+)_([\d.]+)_([\d.]+)$"
)


def collect_pt_files(folder_path, regax_str: str):
    """Map the numbers parsed from run-folder names to that run's ``.pt`` files.

    Every direct child of ``folder_path`` whose base name matches ``regax_str``
    (tested with ``re.search``) is considered. All capture groups of the regex
    are converted with ``float`` and form the key, so a pattern with three
    groups yields 3-tuples and one with four groups yields 4-tuples. Each
    matching name is printed as it is found. Only children that contain an
    ``evaluations`` sub-directory are kept; their value maps every ``*.pt``
    file's stem (file name without extension) to its full path.

    Args:
        folder_path (str): root directory whose children are scanned.
        regax_str (str): regular expression applied to each child's base name.

    Returns:
        dict: ``{(float, ...): {stem: pt_path, ...}, ...}`` ordered by key.
    """
    name_re = re.compile(regax_str)
    collected = {}

    for child in glob.glob(os.path.join(folder_path, "*")):
        child_name = os.path.basename(child)
        found = name_re.search(child_name)
        if not found:
            continue

        key = tuple(float(group) for group in found.groups())
        print(child_name)

        eval_dir = os.path.join(child, "evaluations")
        if not (os.path.exists(eval_dir) and os.path.isdir(eval_dir)):
            continue

        # Stem -> path for every .pt file in the evaluations directory.
        collected[key] = {
            os.path.splitext(os.path.basename(pt_path))[0]: pt_path
            for pt_path in glob.glob(os.path.join(eval_dir, "*.pt"))
        }

    return {key: collected[key] for key in sorted(collected)}

In [8]:
import torch

original_txt = """1.0	1e4	0.0	0.0		21.72		30.85		33.86	
1.0	1e4	0.1	0.0		22.66		31.29		37.04	
1.0	1e4	0.5	0.0		15.73		23.51		31.79	
1.0	1e4	1.0	0.0		15.58		24.01		40.72	
1.0	1e4	1.5	0.0		14.71		22.94		41.73	
1.0	1e4	2.0	0.0		13.93		21.94		42.16	
1.0	1e4	5.0	0.0		16.60		24.87		44.45	
1.0	1e4	10.0	0.0		12.61		19.64		43.26	
1.0	1e4	0.0	1.0		22.81		32.08		36.79	
1.0	1e4	0.1	1.0		20.41		28.51		35.78	
1.0	1e4	0.5	1.0		17.67		26.09		35.07	
1.0	1e4	1.0	1.0		17.15		25.77		40.53	
1.0	1e4	1.5	1.0		14.97		23.30		42.16	
1.0	1e4	2.0	1.0		23.05		31.77		41.90	
1.0	1e4	5.0	1.0		15.74		23.89		45.02	
1.0	1e4	10.0	1.0		11.62		18.42		44.13   """.split("\n")

folder_path = "../outputs/lambda_ablation"

# Index the table rows by (column 3, column 4): the third folder-name number and
# the run-family flag (0.0 for `spss` runs, 1.0 for `cspss` runs). Each run's ADI
# values are then appended to the row it belongs to, instead of assuming that the
# n-th folder in sorted order lines up with the n-th row (a missing or extra run
# would silently shift every later row).
row_of = {}
for row_idx, row in enumerate(original_txt):
    cols = row.split("\t")
    if len(cols) >= 4:
        row_of[(float(cols[2]), float(cols[3]))] = row_idx


def _format_adi(adi_path):
    """Load an ``adi_pack`` file; return avg_drop, avg_inc, avg_gain (x100) tab-separated."""
    adi_pack = torch.load(adi_path)
    # Indexing via a loop variable avoids nesting double quotes inside an f-string,
    # which is a SyntaxError before Python 3.12 (PEP 701).
    return "\t".join(
        f"{adi_pack[key].item() * 100:.2f}" for key in ("avg_drop", "avg_inc", "avg_gain")
    )


for regex_str, variant in ((spss, 0.0), (cspss, 1.0)):
    result = collect_pt_files(folder_path, regex_str)
    for floats, pt_paths in result.items():
        row_idx = row_of.get((floats[2], variant))
        if row_idx is None:
            print(f"no table row for run {floats} (variant {variant}); skipped")
            continue
        adi_cols = _format_adi(pt_paths["adi_pack"])
        # rstrip() normalises the trailing tab / spaces rows end with, so every
        # row gets the same "\t\t" separator before the ADI columns.
        original_txt[row_idx] = original_txt[row_idx].rstrip() + "\t\t" + adi_cols

print("\n".join(original_txt))

eval_spss_vl_cbm_train_simple_concepts_1.0_1.0_0.1
eval_spss_vl_cbm_train_simple_concepts_1.0_1.0_5.0
eval_spss_vl_cbm_train_simple_concepts_1.0_1.0_0.5
eval_spss_vl_cbm_train_simple_concepts_1.0_1.0_10.0
eval_spss_vl_cbm_train_simple_concepts_1.0_1.0_1.0
eval_spss_vl_cbm_train_simple_concepts_1.0_1.0_0.0
eval_spss_vl_cbm_train_simple_concepts_1.0_1.0_1.5
eval_spss_vl_cbm_train_simple_concepts_1.0_1.0_2.0
eval_spss_vl_cbm_train_simple_concepts_1.0_1.0_10.0_1.0
eval_spss_vl_cbm_train_simple_concepts_1.0_1.0_1.0_1.0
eval_spss_vl_cbm_train_simple_concepts_1.0_1.0_2.0_1.0
eval_spss_vl_cbm_train_simple_concepts_1.0_1.0_0.5_1.0
eval_spss_vl_cbm_train_simple_concepts_1.0_1.0_0.1_1.0
eval_spss_vl_cbm_train_simple_concepts_1.0_1.0_1.5_1.0
eval_spss_vl_cbm_train_simple_concepts_1.0_1.0_5.0_1.0
eval_spss_vl_cbm_train_simple_concepts_1.0_1.0_0.0_1.0
1.0	1e4	0.0	0.0		21.72		30.85		33.86		3.67	50.57	13.57
1.0	1e4	0.1	0.0		22.66		31.29		37.04		3.47	50.03	14.16
1.0	1e4	0.5	0.0		15.73		23.51		31.79		3.