In [1]:
import os
import pickle

import pandas as pd
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score
from tqdm import tqdm

from utils_martina.my_utils import *
from src.dataset.instances.graph import GraphInstance
from src.evaluation.evaluation_metric_ged import GraphEditDistanceMetric
from src.evaluation.evaluation_metric_correctness import CorrectnessMetric
from src.evaluation.evaluation_metric_fidelity import FidelityMetric
from src.utils.metrics.sparsity import SparsityMetric

### Load data

In [2]:
# Folder that holds the pickled evaluation managers. Built with os.path.join
# (the trailing "" keeps the final separator) instead of a hand-written
# backslash literal, which contained the invalid escape sequence "\G".
eval_manager_path = os.path.join(
    "..", "..", "explainability", "GRETEL-repo", "output", "eval_manager", ""
)

# Strip only the final extension so a stem that itself contains dots is kept whole.
file_name = os.path.splitext(get_most_recent_file(eval_manager_path))[0]
print(file_name)

# NOTE(review): pickle.load can execute arbitrary code; only load files
# produced by a trusted run of this project.
with open(eval_manager_path + file_name + '.pkl', 'rb') as f:
    eval_manager = pickle.load(f)

22268-Martina


### Get oracle metrics

In [3]:
def _oracle_score_row(name, y_true, y_pred):
    """Build one row of the oracle metrics table, rounded to 4 decimals."""
    return {
        "patient_record": name,
        "accuracy": round(accuracy_score(y_true, y_pred), 4),
        "f1_score": round(f1_score(y_true, y_pred), 4),
        "recall": round(recall_score(y_true, y_pred), 4),
        "precision": round(precision_score(y_true, y_pred), 4),
    }


def get_oracle_metrics(eval_manager):
    """Score the oracle's predictions over the whole dataset and per recording.

    The dataset and oracle are taken from the first evaluator of
    ``eval_manager``. Every instance is passed to ``oracle.predict`` exactly
    once; the resulting predictions are reused for the global row and for the
    per-recording rows.

    Args:
        eval_manager: object exposing ``evaluators``; the first evaluator must
            provide ``dataset.instances`` and ``_oracle``. Each instance must
            have ``label``, ``patient_id`` and ``record_id`` attributes.

    Returns:
        pandas.DataFrame with the columns ``patient_record``, ``accuracy``,
        ``f1_score``, ``recall`` and ``precision``. The first row is labelled
        ``"ALL"`` and covers every instance; each following row covers one
        ``"<patient_id>_<record_id>"`` group, in order of first appearance.
    """
    evaluator = eval_manager.evaluators[0]
    instances = evaluator.dataset.instances
    oracle = evaluator._oracle

    # Predict once per instance; the oracle call is the expensive part.
    y_true_all = [inst.label for inst in instances]
    y_pred_all = [oracle.predict(inst) for inst in instances]

    # Split labels/predictions by recording while keeping instance order.
    grouped = {}
    for inst, label, pred in zip(instances, y_true_all, y_pred_all):
        key = f"{inst.patient_id}_{inst.record_id}"
        labels, preds = grouped.setdefault(key, ([], []))
        labels.append(label)
        preds.append(pred)

    rows = [_oracle_score_row("ALL", y_true_all, y_pred_all)]
    rows.extend(
        _oracle_score_row(key, labels, preds)
        for key, (labels, preds) in grouped.items()
    )
    return pd.DataFrame(rows)

# Score the oracle of the loaded evaluation run and print the resulting table
# (global "ALL" row first, then one row per patient recording).
oracle_metrics = get_oracle_metrics(eval_manager=eval_manager)
print(oracle_metrics)

  patient_record  accuracy  f1_score  recall  precision
0            ALL    0.9454    0.9459  0.9257     0.9669
1       chb01_03    0.9583    0.9595  0.9583     0.9607
2       chb01_04    0.9326    0.9318  0.8935     0.9736


### Get top $k$ counterfactuals and explainer metrics

In [4]:
# Result containers keyed by evaluator index; each value is a list with one
# entry per explained instance, in processing order.
ids = {}
patient_record = {}
correctness = {}
fidelity = {}
sparsity = {}
ged = {}

M_dissim = {}
M_time = {}
M_instab = {}

# Create the output folder before the long explanation loop, so that a missing
# directory cannot discard the results at the final pickle.dump.
cf_dict_dir = os.path.join("output", "cf_dict")
os.makedirs(cf_dict_dir, exist_ok=True)

for i, evaluator in enumerate(eval_manager.evaluators):
    # Start with empty lists for evaluator i.
    ids[i] = []
    patient_record[i] = []
    correctness[i] = []
    fidelity[i] = []
    sparsity[i] = []
    ged[i] = []

    M_dissim[i] = []
    M_time[i] = []
    M_instab[i] = []

    # Raw explainer output per instance id; pickled once this evaluator is done.
    cf_dict = {}

    oracle = evaluator._oracle
    explainer = evaluator._explainer

    # Only instances whose label is 1 are explained.
    list_instances = evaluator.dataset.instances
    list_instances = [instance for instance in list_instances if instance.label == 1]

    for instance in tqdm(list_instances, desc=f"Evaluator {i}"):
        lista = explainer.explain(instance, return_list=True)
        cf_dict[instance.id] = lista

        if isinstance(lista, GraphInstance):  # These are the cases where the oracle was wrong
            counterfactual = lista
        else:
            # Take the instance stored in the first entry of the returned list
            # (presumably the best-ranked candidate; confirm against the explainer).
            counterfactual = lista[0][1]

        ids[i].append(instance.id)
        patient_record[i].append(f"{instance.patient_id}_{instance.record_id}")
        correctness[i].append(CorrectnessMetric().evaluate(instance, counterfactual, oracle, explainer))
        fidelity[i].append(FidelityMetric().evaluate(instance, counterfactual, oracle, explainer))
        sparsity[i].append(SparsityMetric().evaluate(instance, counterfactual))
        ged[i].append(GraphEditDistanceMetric().evaluate(instance, counterfactual))

        m_dissim, m_time, m_instab = explainer.compute_metric_components(instance, counterfactual)
        M_dissim[i].append(m_dissim)
        M_time[i].append(m_time)
        M_instab[i].append(m_instab)

    # Path built with os.path.join; the previous literal relied on the invalid
    # escape sequence "\c" to produce its backslashes.
    with open(os.path.join(cf_dict_dir, f"cf_dict_{file_name}_{i}.pkl"), 'wb') as f:
        pickle.dump(cf_dict, f)

Evaluator 0: 100%|██████████| 821/821 [48:20<00:00,  3.53s/it]


In [5]:
# Column name -> per-evaluator lists collected by the explanation loop.
metric_columns = {
    "patient_record": patient_record,
    "correctness": correctness,
    "fidelity": fidelity,
    "sparsity": sparsity,
    "ged": ged,
    "M_dissim": M_dissim,
    "M_time": M_time,
    "M_instab": M_instab,
}

# One per-instance results frame per evaluator.
dfs = {
    i: pd.DataFrame({name: values[i] for name, values in metric_columns.items()})
    for i in ids
}

# Columns shown in the printed summary; the M_* columns remain available in dfs.
cols = ["patient_record", "correctness", "fidelity", "sparsity", "ged"]

for i, df in dfs.items():
    # Mean over every explained instance, labelled "ALL".
    global_row = df.drop(columns="patient_record").mean().round(4)
    global_row["patient_record"] = "ALL"
    # Mean per patient recording.
    grouped = df.groupby("patient_record").mean().round(4).reset_index()

    stacked = pd.concat([pd.DataFrame([global_row]), grouped], ignore_index=True)
    explainer_metrics = stacked[cols]

    print(f"Evaluator {i} metrics:", explainer_metrics, "\n", sep="\n")

Evaluator 0 metrics:
  patient_record  correctness  fidelity  sparsity      ged
0            ALL       0.9257    0.9257    1.0216  50.5749
1       chb01_03       0.9583    0.9583    1.0844  53.6152
2       chb01_04       0.8935    0.8935    0.9596  47.5714


