In [1]:
import inspect
import os
import pickle

import numpy as np
import pandas as pd
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score
from tqdm import tqdm

from utils_martina.my_utils import *
from src.dataset.instances.graph import GraphInstance
from src.evaluation.evaluation_metric_correctness import CorrectnessMetric
from src.evaluation.evaluation_metric_fidelity import FidelityMetric
from src.evaluation.evaluation_metric_implausibility import ImplausibilityMetric
from src.evaluation.evaluation_metric_dissimilarity import M_dissim_metric

#### Load data

In [2]:
# Folder (relative to this notebook) holding the pickled evaluation managers.
# Every backslash is escaped: the original "\G" was an invalid escape sequence
# that only worked because Python currently keeps unknown escapes verbatim.
eval_manager_path = "..\\..\\explainability\\GRETEL-repo\\output\\eval_manager\\"

# Drop only the final extension, so a stem that itself contains a dot survives
# and `file_name + '.pkl'` below still points at the selected file.
file_name = get_most_recent_file(eval_manager_path).rsplit('.', 1)[0]
print(file_name)

# file_name = "23256-Martina"

# NOTE(review): pickle.load executes arbitrary code from the file; only load
# evaluation managers produced by a trusted run.
with open(eval_manager_path + file_name + '.pkl', 'rb') as f:
    eval_manager = pickle.load(f)

3788-Martina


#### Oracle metrics

In [3]:
def get_oracle_metrics(eval_manager):
    """Tabulate the oracle's classification metrics, overall and per recording.

    Only the first evaluator of ``eval_manager`` is used: its dataset instances
    are the ground truth (``instance.label``) and its ``_oracle`` supplies the
    predictions. Instances are grouped by ``"{patient_id}_{record_id}"``.

    Parameters
    ----------
    eval_manager
        Object exposing ``evaluators[0].dataset.instances`` and
        ``evaluators[0]._oracle`` (with a ``predict(instance)`` method).

    Returns
    -------
    pandas.DataFrame
        Columns ``patient_record``, ``accuracy``, ``f1_score``, ``recall`` and
        ``precision`` (each rounded to 4 decimals). The first row, labelled
        ``"ALL"``, covers every instance; the following rows cover one
        patient/record group each, in order of first appearance.
    """
    evaluator = eval_manager.evaluators[0]
    instances = evaluator.dataset.instances
    oracle = evaluator._oracle

    def build_row(name, pairs):
        # `pairs` is a list of (true label, predicted label) tuples.
        y_true = [label for label, _ in pairs]
        y_pred = [pred for _, pred in pairs]
        return {
            "patient_record": name,
            "accuracy": round(accuracy_score(y_true, y_pred), 4),
            "f1_score": round(f1_score(y_true, y_pred), 4),
            "recall": round(recall_score(y_true, y_pred), 4),
            "precision": round(precision_score(y_true, y_pred), 4),
        }

    # Query the oracle exactly once per instance and reuse the prediction for
    # both the per-recording rows and the global row (the oracle was previously
    # called twice per instance). Assumes predict() is deterministic.
    all_pairs = []
    grouped = {}
    for inst in instances:
        pair = (inst.label, oracle.predict(inst))
        all_pairs.append(pair)
        grouped.setdefault(f"{inst.patient_id}_{inst.record_id}", []).append(pair)

    rows = [build_row("ALL", all_pairs)]
    rows.extend(build_row(key, pairs) for key, pairs in grouped.items())
    return pd.DataFrame(rows)

# Build the oracle metrics table for the loaded evaluation manager and print it
# as plain text (first row "ALL", then one row per patient/record group).
oracle_metrics = get_oracle_metrics(eval_manager)
print(oracle_metrics)

  patient_record  accuracy  f1_score  recall  precision
0            ALL    0.9040    0.9088  0.9102     0.9074
1       chb01_03    0.9279    0.9340  0.9877     0.8857
2       chb01_04    0.9576    0.9591  0.9661     0.9523
3       chb01_15    0.7901    0.8184  0.9167     0.7391
4       chb01_16    0.9041    0.8991  0.8288     0.9824
5       chb01_18    0.8714    0.8646  0.7745     0.9783
6       chb01_21    0.9380    0.9455  0.9582     0.9330
7       chb01_26    0.9411    0.9441  0.9383     0.9500


#### Get top $k$ counterfactuals and explainer metrics

In [None]:
# metrics[<metric name>][<evaluator index>] -> list with one entry per explained instance
metrics_keys = ['ID', 'Correctness', 'Fidelity', 'Implausibility', 'Dissimilarity']
metrics = {k: {} for k in metrics_keys}

correctness_metric = CorrectnessMetric()
fidelity_metric = FidelityMetric()
implausibility_metric = ImplausibilityMetric()
dissimilarity_metric = M_dissim_metric()

# Create the output folder before the (slow) explanation loop so that a missing
# directory cannot make the run fail only when the results are being saved.
cf_output_dir = "output/cf_dict"
os.makedirs(cf_output_dir, exist_ok=True)

for i, evaluator in enumerate(eval_manager._evaluators):
    # Initialize empty lists for each metric for this evaluator
    for key in metrics_keys:
        metrics[key][i] = []

    # instance.id -> whatever explainer.explain returned for that instance
    cf_dict = {}

    oracle = evaluator._oracle
    explainer = evaluator._explainer
    dataset = evaluator.dataset.instances

    # Select only correctly predicted positive instances
    list_instances = [
        inst for inst in dataset
        if inst.label == 1 and oracle.predict(inst) == 1
    ]

    # Detect support for `return_list` by parameter name rather than by counting
    # parameters: a count of 2 misfires for any explainer whose second parameter
    # is something else, or that takes additional parameters.
    has_return_list = "return_list" in inspect.signature(explainer.explain).parameters

    for instance in tqdm(list_instances, desc=f"Evaluator {i}"):
        # Generate counterfactual for the instance
        result = explainer.explain(instance, return_list=True) if has_return_list else explainer.explain(instance)
        cf_dict[instance.id] = result

        # Extract the counterfactual instance: either the result itself or the
        # second element of the first entry of the returned list.
        # NOTE(review): an empty list would raise IndexError here - confirm that
        # explainers always return at least one candidate.
        counterfactual = result if isinstance(result, GraphInstance) else result[0][1]

        # Record info and metrics
        metrics['ID'][i].append(instance.id)
        metrics['Correctness'][i].append(correctness_metric.evaluate(instance, counterfactual, oracle, explainer))
        metrics['Fidelity'][i].append(fidelity_metric.evaluate(instance, counterfactual, oracle, explainer))
        metrics['Implausibility'][i].append(implausibility_metric.evaluate(instance, counterfactual, dataset=dataset))
        metrics['Dissimilarity'][i].append(dissimilarity_metric.evaluate(instance, counterfactual, oracle, explainer))

    with open(f"{cf_output_dir}/cf_dict_{file_name}_{i}.pkl", 'wb') as f:
        pickle.dump(cf_dict, f)

    # NOTE(review): this stores the whole `metrics` dict, i.e. the results of
    # every evaluator processed so far, not only evaluator i.
    with open(f"{cf_output_dir}/metrics_{file_name}_{i}.pkl", 'wb') as f:
        pickle.dump(metrics, f)

Evaluator 0: 100%|██████████| 10/10 [00:44<00:00,  4.48s/it]


#### Explainer metrics (all patients, only top counterfactual)

In [5]:
# One row per evaluator: the mean of every explainer metric (all keys except
# 'ID') over that evaluator's explained instances, rounded to 4 decimals.
rows = [
    {
        "Evaluator": evaluator_id,
        **{
            name: np.mean(metrics[name][evaluator_id]).round(4)
            for name in metrics_keys[1:]
        },
    }
    for evaluator_id in metrics['ID']
]

summary_df = pd.DataFrame(rows).set_index("Evaluator")
print(summary_df)

           Correctness  Fidelity  Implausibility  Dissimilarity
Evaluator                                                      
0                  1.0       1.0             0.0         0.7918


#### Ablation study (single patient)

In [6]:
# Recording to analyse
patient_id = "chb01"
record_id = "03"


def format_metric(pairs):
    """Format a list of per-instance (mean, std) pairs as "mean ± std".

    Both numbers are averages over the instances: the mean of the per-instance
    means and the mean of the per-instance standard deviations. An empty list
    yields "nan ± nan".
    """
    if not pairs:
        return "nan ± nan"
    mean = np.mean([m for m, _ in pairs])
    std = np.mean([s for _, s in pairs])
    return f"{mean:.4f} ± {std:.4f}"


rows = []

for i, evaluator in enumerate(eval_manager._evaluators):
    explainer = evaluator._explainer

    # Only explainers whose name contains 'Temporal' are part of the ablation
    if 'Temporal' not in explainer.name:
        continue

    # Read back the counterfactuals saved for THIS evaluator. The in-memory
    # `cf_dict` / `list_instances` left over from the metrics loop only describe
    # the last evaluator processed, so reusing them would score every explainer
    # against the last explainer's counterfactuals.
    with open(f"output/cf_dict/cf_dict_{file_name}_{i}.pkl", 'rb') as f:
        evaluator_cf_dict = pickle.load(f)

    m_d, m_t, m_i = [], [], []

    for instance in evaluator.dataset.instances:
        if instance.patient_id != patient_id or instance.record_id != record_id:
            continue
        # Only instances that were actually explained have an entry
        if instance.id not in evaluator_cf_dict:
            continue

        result = evaluator_cf_dict[instance.id]
        # A result is either a single GraphInstance or a list of entries whose
        # second element is the counterfactual instance.
        if isinstance(result, GraphInstance):
            counterfactuals = [result]
        else:
            counterfactuals = [cf[1] for cf in result]

        d_vals, t_vals, i_vals = [], [], []

        for counterfactual in counterfactuals:
            d, t, instab = explainer.compute_metric_components(instance, counterfactual)
            d_vals.append(d)
            t_vals.append(np.sqrt(t))  # reported on the sqrt scale, see column name
            i_vals.append(instab)

        if d_vals:
            m_d.append((np.mean(d_vals), np.std(d_vals)))
            m_t.append((np.mean(t_vals), np.std(t_vals)))
            m_i.append((np.mean(i_vals), np.std(i_vals)))

    # Aggregate across instances (per evaluator)
    rows.append({
        "Evaluator": i,
        "M_dissim": format_metric(m_d),
        "sqrt(M_time)": format_metric(m_t),
        "M_instab": format_metric(m_i)
    })

if rows:
    summary_df = pd.DataFrame(rows).set_index("Evaluator")
    print(f"Record: {patient_id}_{record_id}\n")
    print(summary_df)

Record: chb01_03

                  M_dissim         sqrt(M_time)         M_instab
Evaluator                                                       
0          0.8071 ± 0.0203  316.0506 ± 124.5942  0.6784 ± 0.2537


## _______________________________________________________________

In [7]:
import datetime

# Stamp the notebook with the wall-clock time of this run (local time).
now = datetime.datetime.now()
print("Last full run:", format(now, "%d/%m/%y, hour %H:%M"))

Last full run: 03/07/25, hour 10:59
