In [None]:
import inspect
import os
import pickle

import numpy as np
import pandas as pd
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score
from tqdm import tqdm

from utils.utils import *
from src.dataset.instances.graph import GraphInstance
from src.evaluation.evaluation_metric_runtime import RuntimeMetric
from src.evaluation.evaluation_metric_correctness import CorrectnessMetric
from src.evaluation.evaluation_metric_fidelity import FidelityMetric
from src.evaluation.evaluation_metric_implausibility import ImplausibilityMetric
from src.evaluation.evaluation_metric_dissimilarity import M_dissim_metric

#### Load data

In [2]:
# Directory holding the pickled GRETEL evaluation managers (Windows-style relative path).
# Every backslash is escaped explicitly: a bare "\G" only works because Python keeps
# unknown escape sequences verbatim, which emits a SyntaxWarning on Python 3.12+.
eval_manager_path = "..\\..\\explainability\\GRETEL-repo\\output\\eval_manager\\"

# get_most_recent_file presumably picks the newest file in the directory (TODO confirm).
# NOTE(review): the following assignment pins a specific run and overrides that value;
# drop the pinned line to analyse the latest run instead.
file_name = get_most_recent_file(eval_manager_path).split('.')[0]
file_name = "19708-Martina"
print(file_name)

# pickle.load can execute arbitrary code: only load files produced by our own pipeline.
with open(eval_manager_path + file_name + '.pkl', 'rb') as f:
    eval_manager = pickle.load(f)

19708-Martina


#### Oracle metrics

In [3]:
def _classification_scores(y_true, y_pred):
    """Return accuracy, F1, recall and precision of binary predictions, rounded to 4 decimals.

    ``zero_division=0`` yields the same 0.0 sklearn reports by default when a group has
    no positive labels/predictions, but without emitting an UndefinedMetricWarning.
    """
    return {
        "accuracy": round(accuracy_score(y_true, y_pred), 4),
        "f1_score": round(f1_score(y_true, y_pred, zero_division=0), 4),
        "recall": round(recall_score(y_true, y_pred, zero_division=0), 4),
        "precision": round(precision_score(y_true, y_pred, zero_division=0), 4),
    }


def get_oracle_metrics(eval_manager):
    """Score the oracle's predictions against the instance labels.

    Uses the dataset and oracle of the first evaluator. Returns a DataFrame whose first
    row ("ALL") covers every instance, followed by one row per
    ``<patient_id>_<record_id>`` group in first-seen order. Columns: patient_record,
    accuracy, f1_score, recall, precision.
    """
    evaluator = eval_manager.evaluators[0]
    instances = evaluator.dataset.instances
    oracle = evaluator._oracle

    # Predict every instance exactly once and reuse the result for both the global
    # row and the per-record rows (oracle.predict can be expensive).
    y_true_all = [inst.label for inst in instances]
    y_pred_all = [oracle.predict(inst) for inst in instances]

    # key -> (labels, predictions) for that patient record
    grouped = {}
    for inst, label, pred in zip(instances, y_true_all, y_pred_all):
        key = f"{inst.patient_id}_{inst.record_id}"
        labels, preds = grouped.setdefault(key, ([], []))
        labels.append(label)
        preds.append(pred)

    rows = [{"patient_record": "ALL", **_classification_scores(y_true_all, y_pred_all)}]
    rows.extend(
        {"patient_record": key, **_classification_scores(labels, preds)}
        for key, (labels, preds) in grouped.items()
    )
    return pd.DataFrame(rows)

oracle_metrics = get_oracle_metrics(eval_manager)
print(oracle_metrics)

  patient_record  accuracy  f1_score  recall  precision
0            ALL    0.9402    0.9434  0.9470     0.9398
1       chb03_01    0.9336    0.9363  0.9124     0.9615
2       chb03_02    0.9550    0.9580  0.9779     0.9388
3       chb03_03    0.9554    0.9593  0.9709     0.9480
4       chb03_04    0.9708    0.9720  0.9732     0.9708
5       chb03_34    0.9052    0.9049  0.8686     0.9444
6       chb03_35    0.9376    0.9431  0.9854     0.9042
7       chb03_36    0.9246    0.9281  0.9407     0.9159


#### Get top $k$ counterfactuals and explainer metrics

In [4]:
# Containers for the explainer evaluation: metrics[<metric name>][<evaluator index>] -> list.
# 'Runtime' is not collected in this run.
metrics_keys = [
    'ID',
    'Correctness',
    'Fidelity',
    'Implausibility',
    'Dissimilarity',
    'Accuracy',
]
metrics = dict((metric_name, dict()) for metric_name in metrics_keys)

In [5]:
# Print each evaluator's index and explainer short name (text before the first '-').
# Temporal explainers also show (alpha, beta, gamma); the NoStability variant shows X for gamma.
for i, evaluator in enumerate(eval_manager._evaluators):
    explainer = evaluator._explainer
    name = explainer.name.split('-')[0]
    if "Temporal" not in name:
        print(f"({i}) {name}")
    elif 'NoStability' in name:
        print(f"({i}) {name}: ({explainer.alpha}, {explainer.beta}, X)")
    else:
        print(f"({i}) {name}: ({explainer.alpha}, {explainer.beta}, {explainer.gamma})")

(0) DataDrivenBidirectionalSearchExplainer
(1) ObliviousBidirectionalSearchExplainer
(2) TemporalDCESExplainer: (0.7, 0.2, 0.1)
(3) TemporalDCESExplainerNoStability: (1, 0, X)
(4) GNNMOExp


In [6]:
# Positions of the evaluators to run; all of them by default.
indices = [*range(len(eval_manager._evaluators))]

In [7]:
correctness_metric = CorrectnessMetric()
fidelity_metric = FidelityMetric()
implausibility_metric = ImplausibilityMetric()
dissimilarity_metric = M_dissim_metric()
# Runtime is not collected in this run (RuntimeMetric is imported but not instantiated).

# Optional cap on explained instances per evaluator, for quick smoke runs (e.g. 10).
# None explains every positive instance.
MAX_INSTANCES = None

CF_OUTPUT_DIR = "output/cf_dict"
# Create the output directory before the loop: the loop runs for hours, and a missing
# directory would otherwise only surface at the final pickle.dump, losing all results.
os.makedirs(CF_OUTPUT_DIR, exist_ok=True)

all_cf_dicts = {}

for i in indices:
    # Start fresh lists for every metric of this evaluator.
    for key in metrics_keys:
        metrics[key][i] = []

    cf_dict = {}

    evaluator = eval_manager._evaluators[i]
    oracle = evaluator._oracle
    explainer = evaluator._explainer
    dataset = evaluator.dataset.instances

    # Explain every positive (label == 1) instance. Whether the oracle also predicts it
    # as positive is recorded in metrics['Accuracy'] and applied later as a filter.
    list_instances = [inst for inst in dataset if inst.label == 1]
    if MAX_INSTANCES is not None:
        list_instances = list_instances[:MAX_INSTANCES]

    # Pass return_list=True only to explainers whose explain() declares that parameter
    # (checked by name rather than by counting parameters).
    has_return_list = 'return_list' in inspect.signature(explainer.explain).parameters

    for instance in tqdm(list_instances, desc=f"Evaluator {i}"):
        # Keep the full explainer output so the ablation study can score every candidate.
        result = explainer.explain(instance, return_list=True) if has_return_list else explainer.explain(instance)
        cf_dict[instance.id] = result

        # A list-returning explainer yields a sequence of entries whose [1] element is a
        # counterfactual; the first entry is used here.
        # NOTE(review): assumes entry 0 is the top-ranked candidate — confirm.
        counterfactual = result if isinstance(result, GraphInstance) else result[0][1]

        metrics['ID'][i].append(instance.id)
        metrics['Correctness'][i].append(correctness_metric.evaluate(instance, counterfactual, oracle, explainer))
        metrics['Fidelity'][i].append(fidelity_metric.evaluate(instance, counterfactual, oracle, explainer))
        metrics['Implausibility'][i].append(implausibility_metric.evaluate(instance, counterfactual, dataset=dataset))
        metrics['Dissimilarity'][i].append(dissimilarity_metric.evaluate(instance, counterfactual, oracle, explainer))
        # True when the oracle also predicts this (label == 1) instance as positive.
        metrics['Accuracy'][i].append(oracle.predict(instance) == 1)

    all_cf_dicts[i] = cf_dict

with open(f"{CF_OUTPUT_DIR}/cf_dict_{file_name}.pkl", 'wb') as f:
    pickle.dump(all_cf_dicts, f)

with open(f"{CF_OUTPUT_DIR}/metrics_{file_name}.pkl", 'wb') as f:
    pickle.dump(metrics, f)

Evaluator 0: 100%|██████████| 2870/2870 [7:14:34<00:00,  9.09s/it]  
Evaluator 1: 100%|██████████| 2870/2870 [3:32:55<00:00,  4.45s/it]  
Evaluator 2: 100%|██████████| 2870/2870 [1:15:50<00:00,  1.59s/it]
Evaluator 3: 100%|██████████| 2870/2870 [3:35:40<00:00,  4.51s/it]  
Evaluator 4: 100%|██████████| 2870/2870 [04:37<00:00, 10.33it/s]


In [8]:
# Persist the counterfactuals and the per-evaluator metrics (same files as the loop cell).
outputs_to_save = {
    f"output/cf_dict/cf_dict_{file_name}.pkl": all_cf_dicts,
    f"output/cf_dict/metrics_{file_name}.pkl": metrics,
}
for out_path, payload in outputs_to_save.items():
    with open(out_path, 'wb') as f:
        pickle.dump(payload, f)

In [9]:
# If needed, the Runtime metric will be computed a posteriori in the future.

#### Explainer metrics (all patients, only top counterfactual)

In [10]:
rows = []

with open(f"output/cf_dict/metrics_{file_name}.pkl", 'rb') as f:
    metrics = pickle.load(f)

# These metrics are computed for every explained instance but averaged only over the
# instances the oracle predicts correctly (metrics['Accuracy'] is True); the others
# are averaged over all explained instances.
ACCURATE_ONLY_METRICS = {'Implausibility', 'Dissimilarity'}


def _rounded_mean(values):
    """Mean of ``values`` rounded to 4 decimals; NaN if empty or if the mean is NaN."""
    if values.size == 0:
        return np.nan
    mean_val = np.mean(values)
    return np.nan if np.isnan(mean_val) else round(mean_val, 4)


for eval_id in metrics['ID'].keys():
    row = {"Evaluator": eval_id}
    # Explicit bool cast keeps this a boolean mask even for an empty metric list.
    acc = np.asarray(metrics['Accuracy'][eval_id], dtype=bool)

    for metric in metrics_keys[1:]:
        vals = np.array(metrics[metric][eval_id])
        if metric in ACCURATE_ONLY_METRICS:
            vals = vals[acc]
        row[metric] = _rounded_mean(vals)

    rows.append(row)

summary_df = pd.DataFrame(rows).set_index("Evaluator")
# 'Accuracy' (fraction of explained positives the oracle predicts as positive) is only
# an auxiliary column; hide it by name rather than by position.
print(summary_df.drop(columns="Accuracy"))

           Correctness  Fidelity  Implausibility  Dissimilarity
Evaluator                                                      
0               0.1160    0.0854          0.2392         0.2392
1               0.2819    0.2226          0.1547         0.1547
2               0.9470    0.9470          0.0000         0.9730
3               0.9997    0.8941          0.0000         0.9486
4               0.1889    0.1087             NaN            NaN


#### Ablation study (single patient record)

In [11]:
patient_id = "chb03"
record_id = "01"

# Load the cf_dicts of all evaluators from the single output file.
with open(f"output/cf_dict/cf_dict_{file_name}.pkl", "rb") as f:
    all_cf_dicts = pickle.load(f)

# For each evaluator, the set of instance IDs the oracle predicted correctly.
accurate_ids = {
    i: {id_ for id_, acc in zip(metrics['ID'][i], metrics['Accuracy'][i]) if acc}
    for i in metrics['ID'].keys()
}


def format_metric(pairs):
    """Format per-instance (mean, std) pairs as "<mean of means> ± <mean of stds>".

    Returns "nan ± nan" when there are no pairs.
    NOTE(review): the spread is the average per-instance std, not the std of the pooled
    values — confirm this is the intended statistic.
    """
    if not pairs:
        return "nan ± nan"
    mean = np.mean([m for m, _ in pairs])
    std = np.mean([s for _, s in pairs])
    return f"{mean:.4f} ± {std:.4f}"


rows = []

for i, evaluator in enumerate(eval_manager._evaluators):
    explainer = evaluator._explainer

    # Only the Temporal explainers (those exposing alpha/beta/gamma) are part of the ablation.
    if 'Temporal' not in explainer.name:
        continue

    # The cf_dict produced by this evaluator, keyed by instance ID.
    cf_dict = all_cf_dicts[i]
    evaluator_accurate_ids = accurate_ids.get(i, set())

    m_d, m_t, m_i = [], [], []

    for instance in evaluator.dataset.instances:
        if instance.patient_id != patient_id or instance.record_id != record_id:
            continue
        if instance.id not in evaluator_accurate_ids:
            continue

        d_vals, t_vals, i_vals = [], [], []

        # Score every counterfactual returned for this instance, not only the first one.
        for cf in cf_dict[instance.id]:
            d, t, instab = explainer.compute_metric_components(instance, cf[1])
            d_vals.append(d)
            t_vals.append(np.sqrt(t))  # square root of the time component
            i_vals.append(instab)

        if d_vals:
            m_d.append((np.mean(d_vals), np.std(d_vals)))
            m_t.append((np.mean(t_vals), np.std(t_vals)))
            m_i.append((np.mean(i_vals), np.std(i_vals)))

    if 'NoStability' not in explainer.name:
        parameters = f"({explainer.alpha}, {explainer.beta}, {explainer.gamma})"
    else:
        parameters = f"({explainer.alpha}, {explainer.beta}, X)"

    rows.append({
        "Parameters": parameters,
        "M_dissim": format_metric(m_d),
        "sqrt(M_time)": format_metric(m_t),
        "M_instab": format_metric(m_i)
    })

if rows:
    summary_df = pd.DataFrame(rows)
    print(f"Record: {patient_id}_{record_id}\n")
    print(summary_df)
else:
    # Make an empty result visible instead of silently printing nothing.
    print(f"No Temporal explainer results for record {patient_id}_{record_id}.")

Record: chb03_01

        Parameters         M_dissim        sqrt(M_time)         M_instab
0  (0.7, 0.2, 0.1)  0.4238 ± 0.0084   81.8231 ± 32.3935  0.4868 ± 0.2077
1        (1, 0, X)  0.4015 ± 0.0112  171.3438 ± 86.3431        nan ± nan


## _______________________________________________________________

In [12]:
# Stamp the notebook with the time of its last end-to-end execution.
import datetime

now = datetime.datetime.now()
run_stamp = now.strftime("%d/%m/%y, hour %H:%M")
print("Last full run:", run_stamp)

Last full run: 12/07/25, hour 14:44
