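# Adding a new Evaluation Metric to GRETEL

This notebook shows how to extend GRETEL with a custom evaluation metric. It implements a `ValidityMetric` and registers it through a subclass of `EvaluationMetricFactory`. It then runs the evaluation with `EvaluatorManager` and inspects the resulting JSON file.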

In [7]:
import sys
import os
# Make the parent of the current working directory (in Jupyter, normally the
# folder containing this notebook) importable, so the `src.*` imports resolve.
module_path = os.path.abspath(os.pardir)
# Guard against piling up duplicate entries when this cell is re-run in the same kernel.
if module_path not in sys.path:
    sys.path.append(module_path)

In [8]:
from src.evaluation.evaluator_manager import EvaluatorManager

# Build paths relative to the repository root (module_path) with os.path.join so the
# separators are correct on every platform.
config_file_path = os.path.join(
    module_path, 'examples', 'config', 'config_autism_custom-oracle_dce_validity.json')
# NOTE(review): presumably this location is derived from the oracle/explainer in the
# config and from run_number=0 used below — keep it in sync if either changes.
output_file_path = os.path.join(
    module_path, 'output', 'asd_custom_oracle', 'DCESearchExplainer', 'results_run-0.json')

## Creating a new Evaluation Metric

In [9]:
from src.evaluation.evaluation_metric_base import EvaluationMetric
from src.dataset.data_instance_base import DataInstance
from src.oracle.oracle_base import Oracle


class ValidityMetric(EvaluationMetric):
    """Binary metric checking whether a counterfactual changes the oracle's class.

    The score is 1 when the oracle labels the two compared instances differently
    and 0 when it gives them the same label.
    """

    def __init__(self, config_dict=None) -> None:
        super().__init__(config_dict)
        # Metric name (presumably the label it is reported under in the results).
        self._name = 'Validity'

    def evaluate(self, instance_1: DataInstance, instance_2: DataInstance, oracle: Oracle):
        """Compare the oracle's predictions for two instances.

        Args:
            instance_1: first instance to classify (presumably the original instance).
            instance_2: second instance to classify (presumably its counterfactual).
            oracle: the model whose predictions are compared.

        Returns:
            1 if the two predicted labels differ, otherwise 0.
        """
        first_label, second_label = [
            oracle.predict(candidate) for candidate in (instance_1, instance_2)
        ]

        # NOTE(review): assumes oracle.predict() increments `_call_counter` once per
        # call; subtracting the two predictions made here keeps this metric's own
        # queries out of the call count — verify against Oracle.predict.
        oracle._call_counter -= 2

        labels_differ = first_label != second_label
        return 1 if labels_differ else 0

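## Creating a custom Evaluation Metric Factory

The `EvaluatorManager` obtains its metric objects from an evaluation-metric factory. The factory looks each one up by the `name` field of the entries in the configuration's `metrics` list. To make the new metric available, we subclass `EvaluationMetricFactory`, handle the `'validity'` name ourselves, and delegate every other name to the base factory.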

In [10]:
from src.evaluation.evaluation_metric_factory import EvaluationMetricFactory


class CustomEvaluationMetricFactory(EvaluationMetricFactory):
    """Evaluation-metric factory that adds the custom 'validity' metric.

    Requests for any other metric name are delegated unchanged to the base
    EvaluationMetricFactory.
    """

    def get_evaluation_metric_by_name(self, metric_dict) -> EvaluationMetric:
        """Build the metric described by one configuration entry.

        Args:
            metric_dict: a metric entry from the configuration's ``metrics`` list,
                e.g. ``{"name": "validity", "parameters": {}}``.

        Returns:
            A ValidityMetric when the name is 'validity'; otherwise whatever the
            base factory returns for ``metric_dict``.
        """
        if metric_dict['name'] == 'validity':
            return self.get_validity_metric(config_dict=metric_dict)

        return super().get_evaluation_metric_by_name(metric_dict)

    def get_validity_metric(self, config_dict=None) -> EvaluationMetric:
        """Return a new ValidityMetric configured with ``config_dict``."""
        return ValidityMetric(config_dict)


## Using the new Evaluation Metric

In [11]:
# Only the metric factory is customised; the other factory arguments are passed as None.
# NOTE(review): presumably None makes EvaluatorManager fall back to its default
# factory for that component — confirm in EvaluatorManager.__init__.
em_factory = CustomEvaluationMetricFactory()

# run_number distinguishes repeated runs of the same configuration.
eval_manager = EvaluatorManager(
    config_file_path,
    run_number=0,
    dataset_factory=None,
    embedder_factory=None,
    oracle_factory=None,
    explainer_factory=None,
    evaluation_metric_factory=em_factory,
)

# Set up the evaluators from the loaded configuration, then run them; the results
# file read in the next cell is expected at output_file_path.
eval_manager.create_evaluators()
eval_manager.evaluate()

## Checking the results

In [12]:
# JSON text is UTF-8 by specification (RFC 8259); state the encoding explicitly
# instead of relying on the platform default (e.g. cp1252 on Windows).
with open(output_file_path, 'r', encoding='utf-8') as rs_json_reader:
    results = rs_json_reader.read()

results

'{"config": {"dataset": {"name": "autism", "parameters": {}}, "oracle": {"name": "asd_custom_oracle", "parameters": {}}, "explainer": {"name": "dce_search", "parameters": {"graph_distance": {"name": "graph_edit_distance", "parameters": {}}}}, "metrics": [{"name": "graph_edit_distance", "parameters": {}}, {"name": "oracle_calls", "parameters": {}}, {"name": "validity", "parameters": {}}, {"name": "sparsity", "parameters": {}}, {"name": "fidelity", "parameters": {}}, {"name": "oracle_accuracy", "parameters": {}}]}, "runtime": [0.16681551933288574, 0.07245635986328125, 0.06671309471130371, 0.06960177421569824, 0.07045936584472656, 0.0999002456665039, 0.13517308235168457, 0.07583427429199219, 0.06962323188781738, 0.1026611328125, 0.0667886734008789, 0.07045817375183105, 0.09930682182312012, 0.08687901496887207, 0.1194307804107666, 0.09802508354187012, 0.06516861915588379, 0.12294268608093262, 0.0696115493774414, 0.0756065845489502, 0.0829010009765625, 0.09508204460144043, 0.123664140701293