In [1]:
import inspect
import pickle
from tqdm import tqdm
import numpy as np
import pandas as pd
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score

from utils_martina.my_utils import *
from src.dataset.instances.graph import GraphInstance
from src.evaluation.evaluation_metric_ged import GraphEditDistanceMetric
from src.evaluation.evaluation_metric_correctness import CorrectnessMetric
from src.evaluation.evaluation_metric_fidelity import FidelityMetric
from src.utils.metrics.sparsity import SparsityMetric

#### Load data

In [2]:
# Directory (relative to this notebook) that holds the pickled evaluation managers.
# Every backslash is escaped explicitly: the previous literal contained "\G", which
# is not a valid escape sequence and is reported as such by recent Python versions.
# The resulting runtime string is unchanged.
eval_manager_path = "..\\..\\explainability\\GRETEL-repo\\output\\eval_manager\\"

# get_most_recent_file comes from the star import above; judging by the printed
# value it returns a bare file name (TODO confirm). Everything from the first "."
# onwards is dropped, so the name is assumed to contain no other dots.
file_name = get_most_recent_file(eval_manager_path).split('.')[0]
print(file_name)

# NOTE: pickle.load can execute arbitrary code; only open files produced by a trusted run.
with open(eval_manager_path + file_name + '.pkl', 'rb') as f:
    eval_manager = pickle.load(f)

3788-Martina


#### Get oracle metrics

In [3]:
def _binary_scores(label, y_true, y_pred):
    """Return one result row with the four classification scores rounded to 4 decimals."""
    return {
        "patient_record": label,
        "accuracy": round(accuracy_score(y_true, y_pred), 4),
        "f1_score": round(f1_score(y_true, y_pred), 4),
        "recall": round(recall_score(y_true, y_pred), 4),
        "precision": round(precision_score(y_true, y_pred), 4),
    }


def get_oracle_metrics(eval_manager):
    """Score the oracle of the first evaluator, globally and per patient record.

    Parameters
    ----------
    eval_manager
        Object exposing ``evaluators``; only ``evaluators[0]`` is used, through
        its ``dataset.instances`` and ``_oracle`` attributes. Every instance must
        provide ``label``, ``patient_id`` and ``record_id``.

    Returns
    -------
    pandas.DataFrame
        Columns ``patient_record``, ``accuracy``, ``f1_score``, ``recall`` and
        ``precision``. The first row (``"ALL"``) covers every instance; the
        following rows cover each ``"<patient_id>_<record_id>"`` group in order
        of first appearance.
    """
    evaluator = eval_manager.evaluators[0]
    instances = evaluator.dataset.instances
    oracle = evaluator._oracle

    # Query the oracle once per instance and reuse the prediction for both the
    # per-record rows and the global row (previously every instance was predicted twice).
    y_true_all = [inst.label for inst in instances]
    y_pred_all = [oracle.predict(inst) for inst in instances]

    # Group (labels, predictions) by patient record, keeping insertion order.
    grouped = {}
    for inst, truth, pred in zip(instances, y_true_all, y_pred_all):
        key = f"{inst.patient_id}_{inst.record_id}"
        truths, preds = grouped.setdefault(key, ([], []))
        truths.append(truth)
        preds.append(pred)

    rows = [_binary_scores("ALL", y_true_all, y_pred_all)]
    rows.extend(_binary_scores(key, truths, preds) for key, (truths, preds) in grouped.items())
    return pd.DataFrame(rows)

oracle_metrics = get_oracle_metrics(eval_manager)
print(oracle_metrics)

  patient_record  accuracy  f1_score  recall  precision
0            ALL    0.9040    0.9088  0.9102     0.9074
1       chb01_03    0.9279    0.9340  0.9877     0.8857
2       chb01_04    0.9576    0.9591  0.9661     0.9523
3       chb01_15    0.7901    0.8184  0.9167     0.7391
4       chb01_16    0.9041    0.8991  0.8288     0.9824
5       chb01_18    0.8714    0.8646  0.7745     0.9783
6       chb01_21    0.9380    0.9455  0.9582     0.9330
7       chb01_26    0.9411    0.9441  0.9383     0.9500


#### Get top $k$ counterfactuals and explainer metrics

In [4]:
# Result containers. Each dict maps an evaluator index to a list holding one entry
# per explained instance; the lists of one evaluator are appended to in lockstep.
ids = {}
patient_record = {}
correctness = {}
fidelity = {}
sparsity = {}
ged = {}

# Values returned by explainer.compute_metric_components (NaN when not computed).
M_dissim = {}
M_time = {}
M_instab = {}

for i in range(len(eval_manager._evaluators)):
    ids[i] = []
    patient_record[i] = []
    correctness[i] = []
    fidelity[i] = []
    sparsity[i] = []
    ged[i] = []

    M_dissim[i] = []
    M_time[i] = []
    M_instab[i] = []

    # Raw explainer output per instance id; pickled once this evaluator is done.
    cf_dict = {}

    oracle = eval_manager._evaluators[i]._oracle
    explainer = eval_manager._evaluators[i]._explainer

    # Metrics are extracted only for instances with label 1 that the oracle also predicts as 1.
    list_instances = eval_manager._evaluators[i].dataset.instances
    list_instances = [instance for instance in list_instances if (instance.label == 1 and oracle.predict(instance) == 1)]

    # An explain() taking exactly two parameters is treated as the variant that
    # accepts return_list and whose explainer offers compute_metric_components.
    num_params = len(inspect.signature(explainer.explain).parameters)

    for instance in tqdm(list_instances, desc=f"Evaluator {i}"):
        if num_params == 2:
            result = explainer.explain(instance, return_list=True)
        else:
            result = explainer.explain(instance)

        cf_dict[instance.id] = result

        # A list result keeps the counterfactual at index 1 of each entry;
        # the first entry is the one used for the metrics below.
        if isinstance(result, GraphInstance):
            counterfactual = result
        else:
            counterfactual = result[0][1]

        ids[i].append(instance.id)
        patient_record[i].append(f"{instance.patient_id}_{instance.record_id}")
        correctness[i].append(CorrectnessMetric().evaluate(instance, counterfactual, oracle, explainer))
        fidelity[i].append(FidelityMetric().evaluate(instance, counterfactual, oracle, explainer))
        sparsity[i].append(SparsityMetric().evaluate(instance, counterfactual))
        ged[i].append(GraphEditDistanceMetric().evaluate(instance, counterfactual))

        if num_params == 2:
            m_dissim, m_time, m_instab = explainer.compute_metric_components(instance, counterfactual)
            M_dissim[i].append(m_dissim)
            M_time[i].append(m_time)
            M_instab[i].append(m_instab)
        else:
            # Keep the lists aligned with ids[i] when the components are unavailable.
            M_dissim[i].append(np.nan)
            M_time[i].append(np.nan)
            M_instab[i].append(np.nan)

    # Backslashes are escaped explicitly ("\c" is not a valid escape sequence);
    # the resulting path string is unchanged. open() does not create missing directories.
    with open(f"output\\cf_dict\\cf_dict_{file_name}_{i}.pkl", 'wb') as f:
        pickle.dump(cf_dict, f)

Evaluator 0: 100%|██████████| 2596/2596 [1:50:25<00:00,  2.55s/it]  


#### Summary over different explainers, only top counterfactual

In [15]:
# One metrics table per evaluator index, with one row per explained instance.
dfs = {}

for i in ids:
    dfs[i] = pd.DataFrame({
        "patient_record": patient_record[i],
        "correctness": correctness[i],
        "fidelity": fidelity[i],
        "sparsity": sparsity[i],
        "ged": ged[i],
        "M_dissim": M_dissim[i],
        # Stored as the square root of the raw value; NaN entries stay NaN.
        "M_time": np.sqrt(M_time[i]),
        "M_instab": M_instab[i]
    })

# The file is keyed by the part of file_name before the first "-".
# Backslashes are escaped explicitly ("\d" is not a valid escape sequence);
# the resulting path string is unchanged.
with open(f"output\\dfs\\dfs_{file_name.split('-')[0]}.pkl", "wb") as f:
    pickle.dump(dfs, f)

In [26]:
# Reload the per-evaluator metric tables from disk.
# Backslashes are escaped explicitly ("\d" is not a valid escape sequence);
# the resulting path string is unchanged.
# NOTE: pickle.load can execute arbitrary code; only load files produced by a trusted run.
with open(f"output\\dfs\\dfs_{file_name.split('-')[0]}.pkl", "rb") as f:
    dfs = pickle.load(f)

for i, df in dfs.items():
    # Mean of every metric over all rows, labelled "ALL".
    global_row = df.drop(columns="patient_record").mean().round(4)
    global_row["patient_record"] = "ALL"
    # Mean of every metric within each patient record.
    grouped = df.groupby("patient_record").mean().round(4).reset_index()

    # Global row first, then one row per patient record; only a subset of columns is shown.
    explainer_metrics = pd.concat([pd.DataFrame([global_row]), grouped], ignore_index=True)
    cols = ["patient_record", "correctness", "fidelity", "M_dissim"]
    explainer_metrics = explainer_metrics[cols]

    print(f"Evaluator {i} metrics:")
    print(explainer_metrics)
    print("\n")

# TODO: add plausibility.

Evaluator 0 metrics:
  patient_record  correctness  fidelity  M_dissim
0            ALL          1.0       1.0    1.2762
1       chb01_03          1.0       1.0    1.6761
2       chb01_04          1.0       1.0    1.2977
3       chb01_15          1.0       1.0    1.3996
4       chb01_16          1.0       1.0    1.5184
5       chb01_18          1.0       1.0    0.8033
6       chb01_21          1.0       1.0    0.9974
7       chb01_26          1.0       1.0    1.1744




In [None]:
# Disabled cell: everything below is wrapped in a triple-quoted string literal, so
# none of it executes. The wrapped code formats every metric of each evaluator as
# "mean ± std", once over all rows ("ALL") and once per patient_record, and renames
# the M_time column to "sqrt(M_time)" before printing.
# NOTE(review): the output saved under this cell appears to come from an earlier run
# in which the code was still active - confirm before relying on those numbers.
"""for i, df in dfs.items():
    def format_mean_std(row):
        return f"{row['mean']:.4f} ± {row['std']:.4f}"
    
    # Calcolo globale
    global_stats = df.drop(columns="patient_record").agg(['mean', 'std']).T
    global_stats['formatted'] = global_stats.apply(format_mean_std, axis=1)
    global_row = global_stats['formatted'].to_frame().T
    global_row['patient_record'] = 'ALL'
    
    # Calcolo per paziente
    grouped_stats = df.groupby('patient_record').agg(['mean', 'std'])
    
    formatted_cols = pd.DataFrame()
    for col in grouped_stats.columns.levels[0]:
        formatted_cols[col] = grouped_stats[col].apply(format_mean_std, axis=1)
    
    formatted_cols['patient_record'] = grouped_stats.index
    
    # Unisco globale e raggruppato
    explainer_metrics = pd.concat([global_row, formatted_cols], ignore_index=True)
    explainer_metrics = explainer_metrics[["patient_record", "correctness", "fidelity", "M_dissim", "M_time", "M_instab"]]
    explainer_metrics = explainer_metrics.rename(columns={"M_time": "sqrt(M_time)"})
    
    print(f"Evaluator {i} metrics:")
    print(explainer_metrics)
    print("\n")
"""

Evaluator 0 metrics:
  patient_record      correctness         fidelity         M_dissim  \
0            ALL  1.0000 ± 0.0000  1.0000 ± 0.0000  1.2762 ± 0.4641   
1       chb01_03  1.0000 ± 0.0000  1.0000 ± 0.0000  1.6761 ± 0.5261   
2       chb01_04  1.0000 ± 0.0000  1.0000 ± 0.0000  1.2977 ± 0.4933   
3       chb01_15  1.0000 ± 0.0000  1.0000 ± 0.0000  1.3996 ± 0.3442   
4       chb01_16  1.0000 ± 0.0000  1.0000 ± 0.0000  1.5184 ± 0.3811   
5       chb01_18  1.0000 ± 0.0000  1.0000 ± 0.0000  0.8033 ± 0.2419   
6       chb01_21  1.0000 ± 0.0000  1.0000 ± 0.0000  0.9974 ± 0.2123   
7       chb01_26  1.0000 ± 0.0000  1.0000 ± 0.0000  1.1744 ± 0.2693   

          sqrt(M_time)         M_instab  
0  280.8730 ± 178.1943  0.7316 ± 0.2607  
1  356.1451 ± 154.9936  0.6732 ± 0.2884  
2  241.4067 ± 147.8887  0.7085 ± 0.3068  
3  226.6177 ± 153.1784  0.8579 ± 0.1551  
4  341.7298 ± 192.6924  0.7107 ± 0.2755  
5  371.4191 ± 184.7661  0.7797 ± 0.2535  
6   141.4171 ± 80.2202  0.6599 ± 0.2400  
7  

#### Ablation study (single patient)

In [None]:
# Patient record whose counterfactuals are analysed.
patient_id = "chb01"
record_id = "03"

# One entry per matching instance: mean / standard deviation of each metric
# component over all counterfactuals stored for that instance.
M_dissim_mean = []
M_time_mean = []
M_instab_mean = []

M_dissim_std = []
M_time_std = []
M_instab_std = []

# NOTE: list_instances, cf_dict and explainer are the values left behind by the
# last iteration of the evaluator loop in the metrics cell above.
for instance in list_instances:
    if instance.patient_id == patient_id and instance.record_id == record_id:
        M_dissim_counterfactuals = []
        M_time_counterfactuals = []
        M_instab_counterfactuals = []

        for counterfactual in cf_dict[instance.id]:
            # Each stored entry keeps the counterfactual graph at index 1 (the metrics
            # cell reads result[0][1]); pass that graph rather than the whole entry.
            m_dissim, m_time, m_instab = explainer.compute_metric_components(instance, counterfactual[1])
            M_dissim_counterfactuals.append(m_dissim)
            # Raw M_time is kept here (no square root is applied in this cell).
            M_time_counterfactuals.append(m_time)
            M_instab_counterfactuals.append(m_instab)

        M_dissim_mean.append(np.mean(M_dissim_counterfactuals))
        M_dissim_std.append(np.std(M_dissim_counterfactuals))
        M_time_mean.append(np.mean(M_time_counterfactuals))
        M_time_std.append(np.std(M_time_counterfactuals))
        M_instab_mean.append(np.mean(M_instab_counterfactuals))
        M_instab_std.append(np.std(M_instab_counterfactuals))

In [None]:
# Patient record whose counterfactuals are analysed.
patient_id = "chb01"
record_id = "03"

print(f"{patient_id}_{record_id}")

for i in range(len(eval_manager._evaluators)):
    oracle = eval_manager._evaluators[i]._oracle
    explainer = eval_manager._evaluators[i]._explainer

    # Use the counterfactuals produced by THIS evaluator. The in-memory cf_dict
    # only holds the results of the last evaluator processed by the metrics cell,
    # so the per-evaluator pickle written by that cell is reloaded instead.
    # NOTE: pickle.load can execute arbitrary code; only load files produced by a trusted run.
    with open(f"output\\cf_dict\\cf_dict_{file_name}_{i}.pkl", "rb") as f:
        cf_dict = pickle.load(f)

    # Same selection as the metrics cell (label 1 and predicted 1), restricted to
    # the chosen patient record and evaluated with this evaluator's oracle.
    list_instances = [
        instance
        for instance in eval_manager._evaluators[i].dataset.instances
        if instance.patient_id == patient_id
        and instance.record_id == record_id
        and instance.label == 1
        and oracle.predict(instance) == 1
    ]

    # One entry per selected instance: mean / standard deviation of each metric
    # component over all counterfactuals stored for that instance.
    M_dissim_mean = []
    M_time_mean = []
    M_instab_mean = []

    M_dissim_std = []
    M_time_std = []
    M_instab_std = []

    for instance in list_instances:
        M_dissim_counterfactuals = []
        M_time_counterfactuals = []
        M_instab_counterfactuals = []

        for counterfactual in cf_dict[instance.id]:
            # Each stored entry keeps the counterfactual graph at index 1.
            m_dissim, m_time, m_instab = explainer.compute_metric_components(instance, counterfactual[1])
            M_dissim_counterfactuals.append(m_dissim)
            M_time_counterfactuals.append(np.sqrt(m_time))
            M_instab_counterfactuals.append(m_instab)

        M_dissim_mean.append(np.mean(M_dissim_counterfactuals))
        M_dissim_std.append(np.std(M_dissim_counterfactuals))
        M_time_mean.append(np.mean(M_time_counterfactuals))
        M_time_std.append(np.std(M_time_counterfactuals))
        M_instab_mean.append(np.mean(M_instab_counterfactuals))
        M_instab_std.append(np.std(M_instab_counterfactuals))

    # Aggregated printout for this evaluator.
    print(f"\nEvaluator {i}:")
    if not M_dissim_mean:
        # Averaging empty lists would only produce NaN plus a runtime warning.
        print("no explained instances for this patient record")
        print("")
        continue
    # The value after "±" is the average of the per-instance standard deviations,
    # not the standard deviation of the per-instance means.
    print(f"M_dissim:       {np.mean(M_dissim_mean):.4f} ± {np.mean(M_dissim_std):.4f}")
    print(f"sqrt(M_time): {np.mean(M_time_mean):.4f} ± {np.mean(M_time_std):.4f}")
    print(f"M_instab:       {np.mean(M_instab_mean):.4f} ± {np.mean(M_instab_std):.4f}")
    print("")

# TODO: these results still need to be double-checked!

chb01_03

Evaluator 0:
M_dissim:       1.7066 ± 0.0266
sqrt(M_time): 360.8770 ± 123.9334
M_instab:       0.7076 ± 0.2307


## _______________________________________________________________