# Adding a new Evaluation Metric to GRETEL

In [1]:
import sys
import os
# Put the parent of the current working directory on the import path so the
# `src` package imported in the following cells can be resolved.
module_path = os.path.abspath('..')
# Guard the append: re-running this cell must not add the same entry again.
if module_path not in sys.path:
    sys.path.append(module_path)

In [2]:
from src.evaluation.evaluator_manager import EvaluatorManager

# Build both paths from `module_path` (the directory added to sys.path above)
# instead of a user-specific absolute location, so the notebook is portable.
# NOTE(review): assumes `module_path` is the GRETEL repository root, i.e. the
# directory that contains `src/`, `examples/` and `output/` -- confirm.
config_file_path = os.path.join(
    module_path, 'examples', 'config',
    'config_autism_custom-oracle_dce_validity.json')
# NOTE(review): the folder the results are actually written to is presumably
# decided by the config file; verify it matches this location.
output_file_path = os.path.join(
    module_path, 'output', 'asd_custom_oracle', 'DCESearchExplainer',
    'results_run-0.json')

2022-06-15 11:32:38.503438: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /opt/gridengine/lib/lx-amd64:/opt/openmpi/lib
2022-06-15 11:32:38.503480: I tensorflow/stream_executor/cuda/cudart_stub.cc:29] Ignore above cudart dlerror if you do not have a GPU set up on your machine.


## Creating a new Evaluation Metric

In [3]:
from src.evaluation.evaluation_metric_base import EvaluationMetric
from src.dataset.data_instance_base import DataInstance
from src.oracle.oracle_base import Oracle


class ValidityMetric(EvaluationMetric):
    """Checks whether a counterfactual example changes the oracle's prediction.

    The metric is 1 when the oracle assigns different labels to the two
    instances it is given, and 0 when the labels are equal.
    """

    def __init__(self, config_dict=None) -> None:
        super().__init__(config_dict)
        self._name = 'Validity'

    def evaluate(self, instance_1: DataInstance, instance_2: DataInstance, oracle: Oracle):
        """Compare the oracle's labels for the two instances.

        Args:
            instance_1: first instance passed to ``oracle.predict``.
            instance_2: second instance passed to ``oracle.predict``.
            oracle: object used to label both instances.

        Returns:
            1 if the two predicted labels differ, otherwise 0.
        """
        first_label = oracle.predict(instance_1)
        second_label = oracle.predict(instance_2)

        # Take the two predictions made above back off the oracle's counter.
        # NOTE(review): presumably so that metric bookkeeping is not counted
        # as oracle calls made by the explainer -- confirm.
        oracle._call_counter -= 2

        if first_label != second_label:
            return 1
        return 0

## Creating a custom Evaluation Metric Factory

In [4]:
from src.evaluation.evaluation_metric_factory import EvaluationMetricFactory


class CustomEvaluationMetricFactory(EvaluationMetricFactory):
    """Evaluation metric factory extended with the custom 'validity' metric.

    Any metric name other than 'validity' is forwarded unchanged to the
    parent factory.
    """

    def __init__(self) -> None:
        super().__init__()

    def get_evaluation_metric_by_name(self, metric_dict) -> EvaluationMetric:
        """Create the metric described by ``metric_dict``.

        Args:
            metric_dict: metric configuration; its 'name' entry selects the
                metric to build.

        Returns:
            A ``ValidityMetric`` when the name is 'validity', otherwise
            whatever the parent factory returns for ``metric_dict``.

        Raises:
            KeyError: if ``metric_dict`` has no 'name' entry.
        """
        # Only 'name' is read here; the whole dictionary is handed on as the
        # metric's configuration, so no other key is required by this class.
        if metric_dict['name'] == 'validity':
            return self.get_validity_metric(config_dict=metric_dict)

        return super().get_evaluation_metric_by_name(metric_dict)

    def get_validity_metric(self, config_dict=None) -> EvaluationMetric:
        """Build a ``ValidityMetric`` from the given configuration dictionary."""
        return ValidityMetric(config_dict)


## Using the new Evaluation Metric

In [8]:
# Factory that can build the custom 'validity' metric defined above.
em_factory = CustomEvaluationMetricFactory()

# run_number tells apart repeated runs of the same configuration.
# Only the metric factory is customised here; the others are passed as None.
# NOTE(review): None presumably makes the manager use its default factories.
eval_manager = EvaluatorManager(
    config_file_path,
    run_number=0,
    dataset_factory=None,
    embedder_factory=None,
    oracle_factory=None,
    explainer_factory=None,
    evaluation_metric_factory=em_factory,
)
eval_manager.create_evaluators()
eval_manager.evaluate()

## Checking the results

In [9]:
# Read the whole results file at `output_file_path` as text.
with open(output_file_path) as rs_json_reader:
    results = rs_json_reader.read()

# Last expression of the cell: the notebook displays the raw JSON string.
results

'{"config": {"dataset": {"name": "autism", "parameters": {}}, "oracle": {"name": "asd_custom_oracle", "parameters": {}}, "explainer": {"name": "dce_search", "parameters": {"graph_distance": {"name": "graph_edit_distance", "parameters": {}}}}, "metrics": [{"name": "graph_edit_distance", "parameters": {}}, {"name": "oracle_calls", "parameters": {}}, {"name": "validity", "parameters": {}}, {"name": "sparsity", "parameters": {}}, {"name": "fidelity", "parameters": {}}, {"name": "oracle_accuracy", "parameters": {}}]}, "runtime": [0.2835714817047119, 0.16933059692382812, 0.16750001907348633, 0.16767382621765137, 0.1678917407989502, 0.16853713989257812, 0.150054931640625, 0.16866755485534668, 0.16694045066833496, 0.16832780838012695, 0.15080833435058594, 0.15102505683898926, 0.1688697338104248, 0.16669034957885742, 0.14879298210144043, 0.16776418685913086, 0.16547012329101562, 0.16491937637329102, 0.16670918464660645, 0.16618704795837402, 0.16722607612609863, 0.14999914169311523, 0.1674084663