In [1]:
import os
import pickle

import pandas as pd
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score
from tqdm import tqdm

# NOTE(review): star import; `get_most_recent_file` (used below) comes from here.
# Prefer explicit imports so readers can see where names originate.
from utils_martina.my_utils import *
from src.dataset.instances.graph import GraphInstance
from src.evaluation.evaluation_metric_ged import GraphEditDistanceMetric
from src.evaluation.evaluation_metric_correctness import CorrectnessMetric
from src.evaluation.evaluation_metric_fidelity import FidelityMetric
from src.utils.metrics.sparsity import SparsityMetric

#### Load data

In [2]:
# Directory holding the pickled evaluation managers. Built with os.path.join instead
# of a hand-written "\\"-separated literal (which contained the invalid escape "\G"
# and only worked on Windows). The trailing separator keeps the same string shape
# as before, in case `get_most_recent_file` concatenates file names onto it.
eval_manager_path = os.path.join("..", "..", "explainability", "GRETEL-repo", "output", "eval_manager") + os.sep

# Name (without extension) of the most recent file in that directory; it is reused
# later to name the per-evaluator counterfactual dumps. splitext only strips the
# final extension, so stems containing dots are preserved (unlike `.split('.')[0]`).
file_name = os.path.splitext(os.path.basename(get_most_recent_file(eval_manager_path)))[0]
print(file_name)

# NOTE(review): pickle.load can execute arbitrary code; only load pickles produced
# by this project's own runs.
with open(os.path.join(eval_manager_path, file_name + ".pkl"), "rb") as f:
    eval_manager = pickle.load(f)

28892-Martina


#### Get oracle metrics

In [3]:
def _classification_metrics(y_true, y_pred):
    """Return accuracy, F1, recall and precision for a set of labels, rounded to 4 decimals.

    Uses sklearn's default binary averaging. ``zero_division=0`` makes explicit the
    value sklearn already returns (with a warning) when a group has no true or no
    predicted positives.
    """
    return {
        "accuracy": round(accuracy_score(y_true, y_pred), 4),
        "f1_score": round(f1_score(y_true, y_pred, zero_division=0), 4),
        "recall": round(recall_score(y_true, y_pred, zero_division=0), 4),
        "precision": round(precision_score(y_true, y_pred, zero_division=0), 4),
    }


def get_oracle_metrics(eval_manager, evaluator_index=0):
    """Summarize the oracle's classification performance on the evaluator's dataset.

    Parameters
    ----------
    eval_manager
        Evaluation manager whose ``evaluators[evaluator_index]`` exposes
        ``dataset.instances`` (items with ``label``, ``patient_id``, ``record_id``)
        and ``_oracle`` (with a ``predict(instance)`` method).
    evaluator_index : int, default 0
        Which evaluator's dataset and oracle to use.

    Returns
    -------
    pandas.DataFrame
        Columns ``patient_record, accuracy, f1_score, recall, precision``. The first
        row (``"ALL"``) is computed over every instance. It is followed by one row per
        ``"<patient_id>_<record_id>"`` group, in order of first appearance.
    """
    evaluator = eval_manager.evaluators[evaluator_index]
    instances = evaluator.dataset.instances
    oracle = evaluator._oracle

    # Predict each instance exactly once and reuse the predictions for both the
    # global row and the per-group rows (prediction may be expensive).
    y_true_all = [inst.label for inst in instances]
    y_pred_all = [oracle.predict(inst) for inst in instances]

    # Collect (true labels, predictions) per patient/record, preserving first-seen order.
    grouped = {}
    for inst, y_true, y_pred in zip(instances, y_true_all, y_pred_all):
        key = f"{inst.patient_id}_{inst.record_id}"
        trues, preds = grouped.setdefault(key, ([], []))
        trues.append(y_true)
        preds.append(y_pred)

    rows = [{"patient_record": "ALL", **_classification_metrics(y_true_all, y_pred_all)}]
    rows += [
        {"patient_record": key, **_classification_metrics(trues, preds)}
        for key, (trues, preds) in grouped.items()
    ]
    return pd.DataFrame(rows)

oracle_metrics = get_oracle_metrics(eval_manager)
# Bare last expression: Jupyter renders the DataFrame as a rich table instead of plain text.
oracle_metrics

  patient_record  accuracy  f1_score  recall  precision
0            ALL    0.5220    0.6833     1.0     0.5190
1       chb01_03    0.5171    0.6811     1.0     0.5165
2       chb01_04    0.5268    0.6855     1.0     0.5215


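#### Get top-$k$ counterfactuals and explainer metrics

For every instance with label 1, the explainer's full list of counterfactuals is pickled to the `output/cf_dict` folder. The metrics below are computed on the first entry of that list.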

In [4]:
# Per-evaluator accumulators: each dict maps evaluator index -> list with one entry
# per explained (label == 1) instance, in dataset order.
ids = {}
patient_record = {}
correctness = {}
fidelity = {}
sparsity = {}
ged = {}

# Components returned by explainer.compute_metric_components, same layout as above.
M_dissim = {}
M_time = {}
M_instab = {}

# Output folder for the pickled counterfactual lists. os.path.join replaces the old
# "output\cf_dict\..." literal (invalid "\c" escapes, Windows-only). The folder is
# created up front so a missing directory cannot crash the dump after the long loop.
cf_dict_dir = os.path.join("output", "cf_dict")
os.makedirs(cf_dict_dir, exist_ok=True)

for i, evaluator in enumerate(eval_manager._evaluators):
    # Start empty lists for evaluator i.
    for accumulator in (ids, patient_record, correctness, fidelity, sparsity, ged,
                        M_dissim, M_time, M_instab):
        accumulator[i] = []

    # instance.id -> raw result of explainer.explain(instance, return_list=True).
    cf_dict = {}

    oracle = evaluator._oracle
    explainer = evaluator._explainer

    # Only instances whose ground-truth label is 1 are explained.
    list_instances = [instance for instance in evaluator.dataset.instances if instance.label == 1]

    for instance in tqdm(list_instances, desc=f"Evaluator {i}"):
        cf_list = explainer.explain(instance, return_list=True)
        cf_dict[instance.id] = cf_list

        if isinstance(cf_list, GraphInstance):
            # A single instance instead of a list: per the original author's note,
            # these are the cases in which the oracle was wrong.
            counterfactual = cf_list
        else:
            # Second field of the first entry; presumably the top-ranked
            # counterfactual of (score, instance) pairs - TODO confirm the layout.
            counterfactual = cf_list[0][1]

        # Compute every metric before appending anything, so an interrupted run
        # (e.g. KeyboardInterrupt) leaves all of evaluator i's lists the same length.
        inst_correctness = CorrectnessMetric().evaluate(instance, counterfactual, oracle, explainer)
        inst_fidelity = FidelityMetric().evaluate(instance, counterfactual, oracle, explainer)
        inst_sparsity = SparsityMetric().evaluate(instance, counterfactual)
        inst_ged = GraphEditDistanceMetric().evaluate(instance, counterfactual)
        m_dissim, m_time, m_instab = explainer.compute_metric_components(instance, counterfactual)

        ids[i].append(instance.id)
        patient_record[i].append(f"{instance.patient_id}_{instance.record_id}")
        correctness[i].append(inst_correctness)
        fidelity[i].append(inst_fidelity)
        sparsity[i].append(inst_sparsity)
        ged[i].append(inst_ged)
        M_dissim[i].append(m_dissim)
        M_time[i].append(m_time)
        M_instab[i].append(m_instab)

    with open(os.path.join(cf_dict_dir, f"cf_dict_{file_name}_{i}.pkl"), "wb") as f:
        pickle.dump(cf_dict, f)

Evaluator 0:   0%|          | 0/821 [00:00<?, ?it/s]

Evaluator 0:  88%|████████▊ | 719/821 [22:07<03:08,  1.85s/it]


KeyboardInterrupt: 

In [None]:
# One per-instance DataFrame per evaluator, built from the lists filled by the
# explainer loop (all lists of an evaluator must have the same length).
dfs = {
    i: pd.DataFrame({
        "patient_record": patient_record[i],
        "correctness": correctness[i],
        "fidelity": fidelity[i],
        "sparsity": sparsity[i],
        "ged": ged[i],
        "M_dissim": M_dissim[i],
        "M_time": M_time[i],
        "M_instab": M_instab[i],
    })
    for i in ids
}

# Summary column order: identifier first, then the metrics.
cols = ["patient_record", "correctness", "fidelity", "sparsity", "ged", "M_dissim", "M_time", "M_instab"]

# Keep every evaluator's summary (not just the last one) for later cells.
explainer_metrics_by_evaluator = {}

for i, df in dfs.items():
    # "ALL" row: mean of each metric over all of this evaluator's instances.
    global_row = df.drop(columns="patient_record").mean().round(4)
    global_row["patient_record"] = "ALL"
    # One row per patient/record with the per-group means.
    grouped = df.groupby("patient_record").mean().round(4).reset_index()

    explainer_metrics = pd.concat([pd.DataFrame([global_row]), grouped], ignore_index=True)[cols]
    explainer_metrics_by_evaluator[i] = explainer_metrics

    print(f"Evaluator {i} metrics:")
    # Rich display keeps the wide table in one piece instead of print's wrapped text.
    display(explainer_metrics)

Evaluator 0 metrics:
  patient_record  correctness  fidelity  sparsity      ged  M_dissim  \
0            ALL       0.9102    0.9102    1.0122  49.5060    1.1638   
1       chb01_03       0.9877    0.9877    1.1227  55.4632    1.6573   
2       chb01_04       0.9661    0.9661    1.0439  51.7433    1.2546   
3       chb01_15       0.9167    0.9167    1.0076  48.5735    1.2869   
4       chb01_16       0.8288    0.8288    0.9431  46.5658    1.2631   
5       chb01_18       0.7745    0.7745    0.8848  42.8554    0.6223   
6       chb01_21       0.9582    0.9582    1.0563  51.4472    0.9581   
7       chb01_26       0.9383    0.9383    1.0262  49.8370    1.1037   

     M_time  M_instab  
0  270.6371    0.7362  
1  329.3358    0.6878  
2  279.9225    0.7054  
3  253.4804    0.8700  
4  253.1092    0.6918  
5  240.2206    0.7951  
6  258.1867    0.6507  
7  279.9136    0.7667  


Evaluator 1 metrics:
  patient_record  correctness  fidelity  sparsity      ged  M_dissim  \
0            ALL   

In [None]:
# TODO: it could be useful to also save the lists (e.g. the per-instance metric lists above), not only the printed summaries.

In [None]:
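# NOTE: the commented-out block below hard-codes summary values copied from printed
# results (the Evaluator 0 numbers match the table printed above), and "fidelity"
# reuses the correctness values. Uncommenting it rebuilds `dfs` directly in summary
# form (including the "ALL" row) without re-running the slow explainer loop.
# import pandas as pd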

# patient_record = [
#     "ALL",
#     "chb01_03",
#     "chb01_04",
#     "chb01_15",
#     "chb01_16",
#     "chb01_18",
#     "chb01_21",
#     "chb01_26"
# ]

# # Evaluator 0
# correctness_0 = [0.9102, 0.9877, 0.9661, 0.9167, 0.8288, 0.7745, 0.9582, 0.9383]
# sparsity_0 = [1.0122, 1.1227, 1.0439, 1.0076, 0.9431, 0.8848, 1.0563, 1.0262]
# ged_0 = [49.5060, 55.4632, 51.7433, 48.5735, 46.5658, 42.8554, 51.4472, 49.8370]
# M_dissim_0 = [1.1638, 1.6573, 1.2546, 1.2869, 1.2631, 0.6223, 0.9581, 1.1037]
# M_time_0 = [270.6371, 329.3358, 279.9225, 253.4804, 253.1092, 240.2206, 258.1867, 279.9136]
# M_instab_0 = [0.7362, 0.6878, 0.7054, 0.8700, 0.6918, 0.7951, 0.6507, 0.7667]

# # Evaluator 1
# correctness_1 = [0.9102, 0.9877, 0.9661, 0.9167, 0.8288, 0.7745, 0.9582, 0.9383]
# sparsity_1 = [1.0145, 1.1387, 1.0458, 1.0087, 0.9604, 0.8901, 1.0570, 0.9999]
# ged_1 = [49.6248, 56.2377, 51.8450, 48.6324, 47.4467, 43.1152, 51.4791, 48.5605]
# M_dissim_1 = [1.2083, 1.6976, 1.2964, 1.3393, 1.3504, 0.6366, 1.0067, 1.1310]
# M_time_1 = [223.9748, 300.4461, 249.5738, 212.1520, 194.2159, 180.8799, 209.5135, 220.3012]
# M_instab_1 = [0.5361, 0.3936, 0.3252, 0.8671, 0.2503, 0.7947, 0.4355, 0.7237]

# # Evaluator 2
# correctness_2 = [0.9102, 0.9877, 0.9661, 0.9167, 0.8288, 0.7745, 0.9582, 0.9383]
# sparsity_2 = [1.0140, 1.1222, 1.0414, 1.0063, 0.9561, 0.8948, 1.0631, 1.0130]
# ged_2 = [49.5943, 55.4436, 51.6126, 48.5172, 47.2208, 43.3431, 51.7715, 49.2000]
# M_dissim_2 = [1.1880, 1.6820, 1.2743, 1.3178, 1.3073, 0.6299, 0.9833, 1.1207]
# M_time_2 = [230.9660, 296.0098, 249.6392, 221.0980, 202.2060, 196.6152, 219.8993, 230.6840]
# M_instab_2 = [0.6509, 0.5598, 0.5404, 0.8787, 0.4866, 0.8145, 0.5473, 0.7549]

# # Costruzione dei dataframe
# dfs = {
#     0: pd.DataFrame({
#         "patient_record": patient_record,
#         "correctness": correctness_0,
#         "fidelity": correctness_0,
#         "sparsity": sparsity_0,
#         "ged": ged_0,
#         "M_dissim": M_dissim_0,
#         "M_time": M_time_0,
#         "M_instab": M_instab_0,
#     }),
#     1: pd.DataFrame({
#         "patient_record": patient_record,
#         "correctness": correctness_1,
#         "fidelity": correctness_1,
#         "sparsity": sparsity_1,
#         "ged": ged_1,
#         "M_dissim": M_dissim_1,
#         "M_time": M_time_1,
#         "M_instab": M_instab_1,
#     }),
#     2: pd.DataFrame({
#         "patient_record": patient_record,
#         "correctness": correctness_2,
#         "fidelity": correctness_2,
#         "sparsity": sparsity_2,
#         "ged": ged_2,
#         "M_dissim": M_dissim_2,
#         "M_time": M_time_2,
#         "M_instab": M_instab_2,
#     }),
# }