# Adding a new Evaluation Metric to GRETEL

In [1]:
import sys
import os
# Make the parent of the current working directory importable so that the
# `src` package used below can be found (presumably the repository root,
# since this notebook lives one level below it — confirm if moved).
# The membership guard keeps re-runs of this cell from appending the same
# entry to sys.path again.
module_path = os.path.abspath('..')
if module_path not in sys.path:
    sys.path.append(module_path)

In [2]:
from src.evaluation.evaluator_manager import EvaluatorManager

# Experiment configuration consumed by the EvaluatorManager, and the results
# file this notebook later reads back. NOTE(review): the "run-0" suffix
# presumably corresponds to run_number=0 passed to EvaluatorManager — keep
# the two in sync if the run number changes.
config_file_path = os.path.join(
    module_path, 'examples', 'config',
    'config_autism_custom-oracle_dce_validity.json')
output_file_path = os.path.join(
    module_path, 'output', 'asd_custom_oracle', 'DCESearchExplainer',
    'results_run-0.json')

2023-02-09 16:53:48.537675: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F AVX512_VNNI FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-02-09 16:53:48.695075: I tensorflow/core/util/port.cc:104] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.


## Creating a new Evaluation Metric

In [3]:
from src.evaluation.evaluation_metric_base import EvaluationMetric
from src.dataset.data_instance_base import DataInstance
from src.oracle.oracle_base import Oracle


class ValidityMetric(EvaluationMetric):
    """Verifies that the class from the counterfactual example
    is different from that of the original instance.

    The score is 1 when the oracle assigns the two instances different
    labels and 0 otherwise. The oracle queries issued by this metric are
    not meant to be counted as oracle calls, so the oracle's call counter
    is restored after predicting.
    """

    def __init__(self, config_dict=None) -> None:
        """Create the metric and name it 'Validity'.

        Args:
            config_dict: metric configuration, forwarded unchanged to
                ``EvaluationMetric.__init__``.
        """
        super().__init__(config_dict)
        self._name = 'Validity'

    def evaluate(self, instance_1: DataInstance, instance_2: DataInstance, oracle: Oracle):
        """Return 1 if the oracle predicts different labels for the two instances.

        Args:
            instance_1: first instance to label (presumably the original
                instance — confirm against the evaluator's call order).
            instance_2: second instance to label (presumably the
                counterfactual).
            oracle: the oracle whose ``predict`` supplies the labels.

        Returns:
            int: 1 when the two predicted labels differ, 0 otherwise.
        """
        # ``predict`` presumably increments ``_call_counter``; these queries
        # are bookkeeping for the metric, not explainer work. Restoring a
        # snapshot (instead of subtracting a fixed 2) keeps the count exact
        # regardless of how much ``predict`` increments it, and the
        # ``finally`` also restores it if ``predict`` raises.
        calls_before = oracle._call_counter
        try:
            label_instance_1 = oracle.predict(instance_1)
            label_instance_2 = oracle.predict(instance_2)
        finally:
            oracle._call_counter = calls_before

        return 1 if label_instance_1 != label_instance_2 else 0

## Creating a custom Evaluation Metric Factory

In [4]:
from src.evaluation.evaluation_metric_factory import EvaluationMetricFactory


class CustomEvaluationMetricFactory(EvaluationMetricFactory):
    """Evaluation-metric factory that adds the custom ``validity`` metric.

    Metric names other than ``'validity'`` are delegated to the parent
    ``EvaluationMetricFactory``.
    """

    def __init__(self) -> None:
        super().__init__()

    def get_evaluation_metric_by_name(self, metric_dict) -> EvaluationMetric:
        """Build the evaluation metric described by ``metric_dict``.

        Args:
            metric_dict: metric configuration; its ``'name'`` entry selects
                the metric and the whole dict is passed on as the metric's
                configuration.

        Returns:
            EvaluationMetric: a ``ValidityMetric`` for ``'validity'``,
            otherwise whatever the parent factory returns.
        """
        metric_name = metric_dict['name']

        if metric_name == 'validity':
            return self.get_validity_metric(config_dict=metric_dict)

        # Fall back to the metrics the base factory already knows about.
        return super().get_evaluation_metric_by_name(metric_dict)

    def get_validity_metric(self, config_dict=None) -> EvaluationMetric:
        """Return a new ``ValidityMetric`` configured with ``config_dict``."""
        return ValidityMetric(config_dict)


## Using the new Evaluation Metric

In [5]:
# Only the evaluation-metric factory is customised; every other factory is
# passed as None (presumably the manager then uses its defaults — confirm).
em_factory = CustomEvaluationMetricFactory()

# run_number tells apart repeated runs of the same configuration.
eval_manager = EvaluatorManager(
    config_file_path,
    run_number=0,
    dataset_factory=None,
    embedder_factory=None,
    oracle_factory=None,
    explainer_factory=None,
    evaluation_metric_factory=em_factory,
)

eval_manager.create_evaluators()
eval_manager.evaluate()

## Checking the results

In [6]:
# Read the raw JSON text written by the evaluation run. The encoding is set
# explicitly so the result does not depend on the platform's locale default.
with open(output_file_path, 'r', encoding='utf-8') as rs_json_reader:
    results = rs_json_reader.read()

results

'{"config": {"dataset": {"name": "autism", "parameters": {}}, "oracle": {"name": "asd_custom_oracle", "parameters": {}}, "explainer": {"name": "dce_search", "parameters": {"graph_distance": {"name": "graph_edit_distance", "parameters": {}}}}, "metrics": [{"name": "graph_edit_distance", "parameters": {}}, {"name": "oracle_calls", "parameters": {}}, {"name": "validity", "parameters": {}}, {"name": "sparsity", "parameters": {}}, {"name": "fidelity", "parameters": {}}, {"name": "oracle_accuracy", "parameters": {}}]}, "runtime": [0.11985492706298828, 0.055229902267456055, 0.049341678619384766, 0.056453704833984375, 0.058402299880981445, 0.0653085708618164, 0.06378293037414551, 0.07247304916381836, 0.060153961181640625, 0.06528663635253906, 0.05909085273742676, 0.06982779502868652, 0.06344914436340332, 0.057685136795043945, 0.05154681205749512, 0.052282094955444336, 0.06466245651245117, 0.06275343894958496, 0.06007695198059082, 0.058843135833740234, 0.0713496208190918, 0.05478501319885254, 0