In [None]:
import pandas as pd
import os
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib
import numpy as np

In [None]:
# Folder receiving every figure saved below; created up front so savefig never hits a missing directory.
RESULT_DIR = "./figures_correlation"
os.makedirs(name=RESULT_DIR, exist_ok=True)

# Correlation coefficient passed to DataFrame.corr(method=...) for the cross-dataset figures.
# pandas accepts "pearson", "spearman" or "kendall"; the rank-based Kendall tau is used here.
method = "kendall"

In [None]:
# Notebook-wide matplotlib text defaults: Arial at 16 pt (individual calls override the size).
plt.rcParams.update({"font.family": "Arial", "font.size": 16})

In [None]:
# Human-readable panel titles keyed by metric identifier
# (the identifier is the prefix of the "<metric>_mean" columns in the results CSVs).
title_dict = dict(
    attribution_localization="Attribution Localization",
    auc="AUC",
    pointing_game="Pointing Game",
    relevance_rank_accuracy="Relevance Rank Accuracy",
    region_perturbation_morf="Region Perturbation (MoRF)",
    region_perturbation_lerf="Region Perturbation (LeRF)",
    faithfulness_correlation="Faithfulness Correlation",
)

In [None]:
# Final evaluation tables, one per dataset and evaluation family. Each CSV's first column becomes
# the row index (presumably the attribution method -- TODO confirm); the columns used below are
# "<metric>_mean". The run-directory names appear to encode each model's training hyper-parameters.
MITDB_RESULT = "results_evaluation_localization/mitdb_resnet18_7_bs32_lr5e-2_wd1e-4_ep20/_final_results.csv"
SVDB_RESULT = "results_evaluation_localization/svdb_resnet18_7_bs32_lr1e-2_wd1e-4_ep20/_final_results.csv"
INCARTDB_RESULT = "results_evaluation_localization/incartdb_resnet18_7_bs32_lr1e-3_wd1e-4_ep20/_final_results.csv"
ICENTIA_RESULT = "results_evaluation_localization/icentia11k_resnet18_7_bs32_lr1e-3_wd1e-4_ep20/_final_results.csv"

MITDB_RESULT_F = "results_evaluation_faithfulness/mitdb_resnet18_7_bs32_lr5e-2_wd1e-4_ep20/_final_results.csv"
SVDB_RESULT_F = "results_evaluation_faithfulness/svdb_resnet18_7_bs32_lr1e-2_wd1e-4_ep20/_final_results.csv"
INCARTDB_RESULT_F = "results_evaluation_faithfulness/incartdb_resnet18_7_bs32_lr1e-3_wd1e-4_ep20/_final_results.csv"
ICENTIA_RESULT_F = "results_evaluation_faithfulness/icentia11k_resnet18_7_bs32_lr1e-3_wd1e-4_ep20/_final_results.csv"

# Localization-metric tables, read in dataset order.
mitdb_df_l, svdb_df_l, incartdb_df_l, icentia_df_l = [
    pd.read_csv(path, index_col=0)
    for path in (MITDB_RESULT, SVDB_RESULT, INCARTDB_RESULT, ICENTIA_RESULT)
]
# Faithfulness-metric tables, read in the same dataset order.
mitdb_df_f, svdb_df_f, incartdb_df_f, icentia_df_f = [
    pd.read_csv(path, index_col=0)
    for path in (MITDB_RESULT_F, SVDB_RESULT_F, INCARTDB_RESULT_F, ICENTIA_RESULT_F)
]


metrics_l = ["attribution_localization", "auc", "pointing_game", "relevance_rank_accuracy"]
metrics_f = ["region_perturbation_morf", "region_perturbation_lerf", "faithfulness_correlation"]

dfs_l = {
    "mitdb": mitdb_df_l,
    "svdb": svdb_df_l,
    "incartdb": incartdb_df_l,
    "icentia": icentia_df_l,
}

dfs_f = {
    "mitdb": mitdb_df_f,
    "svdb": svdb_df_f,
    "incartdb": incartdb_df_f,
    "icentia": icentia_df_f,
}

# Column label used for each dataset key in the stacked per-metric tables (left-to-right order).
DATASET_LABELS = {"mitdb": "MITDB", "svdb": "SVDB", "incartdb": "INCARTDB", "icentia": "ICENTIA11K"}


def _stack_metric_means(dfs, metric):
    """Place each dataset's "<metric>_mean" column side by side, labelled via DATASET_LABELS.

    pd.concat(axis=1) aligns the columns on the row index of the individual CSV tables.
    """
    return pd.concat(
        [dfs[key][f"{metric}_mean"].rename(label) for key, label in DATASET_LABELS.items()],
        axis=1,
    )


# metric -> DataFrame(rows = CSV index, columns = MITDB / SVDB / INCARTDB / ICENTIA11K)
metric_results_l = {metric: _stack_metric_means(dfs_l, metric) for metric in metrics_l}
metric_results_f = {metric: _stack_metric_means(dfs_f, metric) for metric in metrics_f}


In [None]:
def _badge_text_color(rgba):
    """Return "k" (black) or "w" (white) so text stays legible on a badge filled with `rgba`.

    Computes the WCAG relative luminance of the RGB part of `rgba` and applies the same 0.408
    cut-off seaborn uses when colouring heatmap annotations. The badge text therefore matches
    the cell labels for any value and any colormap, without a hand-tuned value threshold.
    """
    rgb = np.asarray(rgba[:3], dtype=float)
    linear = np.where(rgb <= 0.03928, rgb / 12.92, ((rgb + 0.055) / 1.055) ** 2.4)
    luminance = float(linear @ np.array([0.2126, 0.7152, 0.0722]))
    return "k" if luminance > 0.408 else "w"


def _plot_corr_panel(ax, table, title, exclude_absolute=False):
    """Draw one dataset-vs-dataset correlation heatmap (lower triangle only) on `ax`.

    Parameters
    ----------
    ax : matplotlib Axes to draw on.
    table : DataFrame with one column per dataset. `table.corr(method=method)` is plotted,
        using the notebook-level `method` and `cmap`.
    title : panel title.
    exclude_absolute : if True, drop rows whose index label contains "absolute" before
        correlating (assumes a string index -- TODO confirm).

    The mean of the strictly-lower-triangle coefficients is written in a badge. The badge sits
    inside the masked upper triangle and is filled with the matching colormap colour.
    """
    if exclude_absolute:
        table = table[[("absolute" not in label) for label in table.index]]
    corr_matrix = table.corr(method=method)
    n = len(corr_matrix)
    # Mask diagonal + upper triangle: the matrix is symmetric and the diagonal is self-correlation.
    upper = np.triu(np.ones((n, n), dtype=bool))
    sns.heatmap(corr_matrix, vmin=0, vmax=1, cmap="viridis", fmt=".3f", annot=True, mask=upper,
                ax=ax, cbar=False, annot_kws={"size": 22})
    # nanmean: a constant column gives NaN coefficients, which would otherwise turn the summary into NaN.
    mean_corr = np.nanmean(corr_matrix.to_numpy()[~upper])
    face = cmap(mean_corr)
    # (n/2, n/2 - 0.5) lies in the empty upper triangle; for 4 datasets this is (2, 1.5).
    ax.text(n / 2, n / 2 - 0.5, f"{mean_corr:.3f}", fontsize=26, color=_badge_text_color(face),
            bbox=dict(facecolor=face, boxstyle="round,pad=0.5", edgecolor="none"))
    ax.set_title(title, fontsize=30)


fig, axs = plt.subplots(nrows=2, ncols=4, figsize=(30, 16.6))
cmap = matplotlib.colormaps.get_cmap("viridis")

# Panels are filled column-major (top row, then bottom row, then the next column).
# Localization metrics take the left columns; faithfulness metrics start in the next free column.
faith_col_offset = (len(metrics_l) + 1) // 2
drawn_axes = []
for idx, metric in enumerate(metrics_l):
    ax = axs[idx % 2][idx // 2]
    _plot_corr_panel(ax, metric_results_l[metric], title_dict[metric])
    drawn_axes.append(ax)
for idx, metric in enumerate(metrics_f):
    ax = axs[idx % 2][idx // 2 + faith_col_offset]
    _plot_corr_panel(ax, metric_results_f[metric], title_dict[metric])
    drawn_axes.append(ax)

# Hide grid cells left without a panel (with 4 + 3 metrics: the bottom-right one).
for ax in axs.flat:
    if not any(ax is drawn for drawn in drawn_axes):
        ax.axis("off")

fig.tight_layout()
fig.subplots_adjust(hspace=0.22)
fig.savefig(os.path.join(RESULT_DIR, f"correlation_{method}.png"))

In [None]:
# Stand-alone viridis colour bar spanning 0..1. The correlation heatmaps are drawn with
# cbar=False on the same scale, so this is presumably their shared legend.
fig, ax = plt.subplots(figsize=(5, 15))
sns.heatmap([[0, 1], [1, 0]], vmin=0, vmax=1, cmap="viridis", fmt=".3f", annot=True, cbar=True, ax=ax)
fig.tight_layout()
fig.savefig(os.path.join(RESULT_DIR, "cbar.png"))

In [None]:
# Correlation coefficient for the per-dataset metric-vs-metric figures below.
# NOTE(review): this re-assigns `method`, silently overriding the configuration cell.
# Keep both in sync, or delete this cell to reuse that choice ("pearson" / "spearman" / "kendall").
method = "kendall"

In [None]:
datasets = ["MITDB", "SVDB", "INCARTDB", "ICENTIA11K"]
metrics = metrics_l + metrics_f


def _dataset_metric_table(dataset):
    """Build the metric-per-column table for one dataset: localization metrics first, then faithfulness.

    Each column is that metric's "_mean" score for `dataset`, aligned on the results' row index.
    The LeRF column is negated, presumably so that its direction agrees with the other metrics
    (TODO confirm).
    """
    localization = pd.concat([metric_results_l[m][dataset] for m in metrics_l], axis=1, keys=metrics_l)
    faithfulness = pd.concat([metric_results_f[m][dataset] for m in metrics_f], axis=1, keys=metrics_f)
    faithfulness["region_perturbation_lerf"] = faithfulness["region_perturbation_lerf"] * -1
    return pd.concat([localization, faithfulness], axis=1)


result_dict = {dataset: _dataset_metric_table(dataset) for dataset in datasets}

In [None]:
fig, axs = plt.subplots(nrows=2, ncols=2, figsize=(22, 22))

# One panel per dataset, laid out column-major. Each shows the lower triangle of the
# metric-vs-metric correlation matrix built from result_dict, where the LeRF column is sign-flipped.
for idx, dataset in enumerate(datasets):
    table = result_dict[dataset]
    # Alternative: drop the "absolute" attribution variants first with
    # table = table[[("absolute" not in s) for s in table.index]]
    corr_matrix = table.corr(method=method)

    # Mask diagonal + upper triangle; the matrix is symmetric.
    up_triang = np.triu(np.ones(corr_matrix.shape, dtype=bool))
    ax = axs[idx % 2][idx // 2]
    # NOTE(review): vmin=0 paints every negative coefficient in the darkest colour; only the
    # annotation shows the sign.
    sns.heatmap(corr_matrix, vmin=0, vmax=1, cmap="viridis", fmt=".3f", annot=True,
                mask=up_triang, ax=ax, cbar=False, annot_kws={"size": 18})
    # Metric names are long: tilt and right-align the x tick labels so they end under their column.
    ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha="right")
    ax.set_title(dataset, fontsize=24)

fig.tight_layout()
fig.savefig(os.path.join(RESULT_DIR, f"metric_correlation_{method}.png"))