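# Adding a new Evaluation Metric to GRETEL

This notebook defines a custom `ValidityMetric`, exposes it through a custom `EvaluationMetricFactory`, runs an evaluation that uses it, and inspects the resulting output files.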

In [1]:
import sys
import os
# Absolute path of the parent of the current working directory; the `src.*`
# imports and the `examples/...` paths used later are resolved against it.
module_path = os.path.abspath(os.pardir)
# Guard the append so that re-running this cell does not keep adding
# duplicate entries to sys.path.
if module_path not in sys.path:
    sys.path.append(module_path)

In [2]:
from src.evaluation.evaluator_manager import EvaluatorManager

# Experiment configuration handed to the EvaluatorManager.
config_file_path = f'{module_path}/examples/config/config_autism_custom-oracle_dce_validity.json'

# Folders handed to the DataAnalyzer; both keep their trailing slash because
# file names are appended to them by plain string concatenation.
output_folder = f'{module_path}/examples/output_new_metric/'
stats_folder = f'{module_path}/examples/stats_new_metric/'

# Results file of run 0, read back when checking the results.
output_file_path = f'{output_folder}asd_custom_oracle/DCESearchExplainer/results_run-0.json'

2023-02-09 22:17:03.412468: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F AVX512_VNNI FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-02-09 22:17:03.843599: I tensorflow/core/util/port.cc:104] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.


## Creating a new Evaluation Metric

In [3]:
from src.evaluation.evaluation_metric_base import EvaluationMetric
from src.dataset.data_instance_base import DataInstance
from src.oracle.oracle_base import Oracle


class ValidityMetric(EvaluationMetric):
    """Checks that the oracle assigns the counterfactual example a class
    different from that of the original instance.
    """

    def __init__(self, config_dict=None) -> None:
        super().__init__(config_dict)
        self._name = 'Validity'

    def evaluate(self, instance_1: DataInstance, instance_2: DataInstance, oracle: Oracle):
        """Compare the oracle's predictions for two instances.

        Args:
            instance_1: First instance to classify.
            instance_2: Second instance to classify.
            oracle: Oracle whose ``predict`` method labels both instances.

        Returns:
            1 if the two predicted labels differ, 0 otherwise.
        """
        first_label = oracle.predict(instance_1)
        second_label = oracle.predict(instance_2)

        # Take the two predict() calls made above back off the oracle's call
        # counter (presumably so this metric does not inflate the reported
        # oracle-call count -- confirm against the Oracle implementation).
        oracle._call_counter -= 2

        if first_label != second_label:
            return 1
        return 0

## Creating a custom Evaluation Metric Factory

In [4]:
from src.evaluation.evaluation_metric_factory import EvaluationMetricFactory


class CustomEvaluationMetricFactory(EvaluationMetricFactory):
    """Evaluation metric factory that additionally builds the 'validity' metric.

    Every other metric name is delegated to ``EvaluationMetricFactory``.
    """

    def __init__(self) -> None:
        super().__init__()

    def get_evaluation_metric_by_name(self, metric_dict) -> EvaluationMetric:
        """Build the metric described by ``metric_dict``.

        Args:
            metric_dict: Metric configuration; its ``'name'`` entry selects
                which metric to build.

        Returns:
            A ``ValidityMetric`` when the name is ``'validity'``; otherwise
            whatever the parent factory returns for ``metric_dict``.

        Raises:
            KeyError: If ``metric_dict`` has no ``'name'`` entry.
        """
        # Only the name is inspected here; the whole dict is forwarded to the
        # metric as its configuration, so no separate 'parameters' lookup is
        # needed in this method.
        if metric_dict['name'] == 'validity':
            return self.get_validity_metric(config_dict=metric_dict)

        return super().get_evaluation_metric_by_name(metric_dict)

    def get_validity_metric(self, config_dict=None) -> EvaluationMetric:
        """Create a ``ValidityMetric`` configured with ``config_dict``."""
        return ValidityMetric(config_dict)


## Using the new Evaluation Metric

In [5]:
# Factory that can also build the custom 'validity' metric.
em_factory = CustomEvaluationMetricFactory()

# run_number distinguishes repeated runs of the same configuration.
# Only the evaluation metric factory is customised; the other factories are
# passed as None (presumably the manager then uses its own defaults -- confirm).
eval_manager = EvaluatorManager(
    config_file_path,
    run_number=0,
    dataset_factory=None,
    embedder_factory=None,
    oracle_factory=None,
    explainer_factory=None,
    evaluation_metric_factory=em_factory,
)
eval_manager.create_evaluators()
eval_manager.evaluate()

## Checking the results

In [6]:
# Read the raw JSON text of the run-0 results file and show it as the cell output.
with open(output_file_path, 'r') as results_file:
    results = results_file.read()

results

'{"config": {"dataset": {"name": "autism", "parameters": {}}, "oracle": {"name": "asd_custom_oracle", "parameters": {}}, "explainer": {"name": "dce_search", "parameters": {"graph_distance": {"name": "graph_edit_distance", "parameters": {}}}}, "metrics": [{"name": "graph_edit_distance", "parameters": {}}, {"name": "oracle_calls", "parameters": {}}, {"name": "validity", "parameters": {}}, {"name": "sparsity", "parameters": {}}, {"name": "fidelity", "parameters": {}}, {"name": "oracle_accuracy", "parameters": {}}]}, "runtime": [0.1267530918121338, 0.06608009338378906, 0.07161259651184082, 0.08130764961242676, 0.0654757022857666, 0.0670173168182373, 0.06748080253601074, 0.06579017639160156, 0.0644073486328125, 0.06924843788146973, 0.07727622985839844, 0.06996798515319824, 0.07122993469238281, 0.06432032585144043, 0.06716775894165039, 0.08757734298706055, 0.06232643127441406, 0.06529378890991211, 0.07831525802612305, 0.0719597339630127, 0.0836949348449707, 0.07471179962158203, 0.07313728332

In [7]:
from src.data_analysis.data_analyzer import DataAnalyzer
import pandas as pd

# Run the DataAnalyzer steps over output_folder / stats_folder; the table
# displayed below is read from stats_folder afterwards.
dtan = DataAnalyzer(output_folder, stats_folder)
dtan.aggregate_data()
dtan.aggregate_runs()
dtan.create_tables_by_oracle_dataset()

# The CSV contains saved index columns that pandas labels 'Unnamed: ...';
# drop them so only the explainer and metric columns are displayed.
results_table = (
    pd.read_csv(stats_folder + 'autism-asd_custom_oracle.csv')
    .loc[:, lambda table: ~table.columns.str.startswith('Unnamed')]
)
results_table

Unnamed: 0.1,Unnamed: 0,explainer,runtime,runtime-std,Graph_Edit_Distance,Graph_Edit_Distance-std,Oracle_Calls,Oracle_Calls-std,Validity,Validity-std,Sparsity,Sparsity-std,Fidelity,Fidelity-std,Oracle_Accuracy,Oracle_Accuracy-std
0,0,dce_search,0.070959,0.0,1011.693069,0.0,102.0,0.0,1.0,0.0,1.311108,0.0,0.544554,0.0,0.772277,0.0
