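# Adding a new Evaluation Metric to GRETEL

This notebook shows how to implement a custom evaluation metric (`ValidityMetric`), register it through a custom evaluation metric factory, run an evaluation that uses it, and inspect the aggregated results.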

In [8]:
import sys
import os

# Make the repository root (the parent of the current working directory)
# importable so the `src.*` imports in later cells resolve.
# The membership check keeps the cell idempotent: re-running it does not
# append duplicate entries to sys.path.
module_path = os.path.abspath('..')
if module_path not in sys.path:
    sys.path.append(module_path)

In [9]:
from src.evaluation.evaluator_manager import EvaluatorManager

# All paths are rooted at the repository root computed in the first cell (module_path).
config_file_path = f'{module_path}/examples/config/config_autism_custom-oracle_dce_validity.json'

# Folders for evaluation output and aggregated statistics. The trailing slash is
# intentional: later cells append file names to these strings directly.
output_folder = f'{module_path}/examples/output_new_metric/'
stats_folder = f'{module_path}/examples/stats_new_metric/'

# Results file for run 0 of the DCESearchExplainer with the custom ASD oracle.
output_file_path = f'{output_folder}asd_custom_oracle/DCESearchExplainer/results_run-0.json'

## Creating a new Evaluation Metric

In [10]:
from src.evaluation.evaluation_metric_base import EvaluationMetric
from src.dataset.data_instance_base import DataInstance
from src.oracle.oracle_base import Oracle


class ValidityMetric(EvaluationMetric):
    """Validity metric for counterfactual explanations.

    Scores 1 when the oracle assigns a different class to the counterfactual
    example than to the original instance, and 0 when both get the same class.
    """

    def __init__(self, config_dict=None) -> None:
        super().__init__(config_dict)
        self._name = 'Validity'

    def evaluate(self, instance_1: DataInstance, instance_2: DataInstance, oracle: Oracle):
        """Compare the oracle's predictions for two instances.

        Args:
            instance_1: first instance to label (per the class description,
                presumably the original instance — TODO confirm with caller).
            instance_2: second instance to label (presumably the counterfactual).
            oracle: object whose ``predict`` method labels each instance.

        Returns:
            int: 1 if the two predicted labels differ, otherwise 0.
        """
        original_label = oracle.predict(instance_1)
        counterfactual_label = oracle.predict(instance_2)

        # Undo the two predict() calls made above on the oracle's call counter.
        # NOTE(review): assumes each predict() increments _call_counter by one,
        # so this metric's own queries do not inflate the oracle-calls count —
        # confirm against the Oracle implementation.
        oracle._call_counter -= 2

        if original_label != counterfactual_label:
            return 1
        return 0

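## Creating a custom Evaluation Metric Factory

The factory maps the metric name `validity`, as used in the configuration file, to `ValidityMetric`. It delegates every other metric name to the base `EvaluationMetricFactory`.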

In [11]:
from src.evaluation.evaluation_metric_factory import EvaluationMetricFactory


class CustomEvaluationMetricFactory(EvaluationMetricFactory):
    """Evaluation metric factory extended with the 'validity' metric.

    A metric configuration whose name is 'validity' produces a ValidityMetric.
    Every other name is delegated unchanged to EvaluationMetricFactory.
    """

    def get_evaluation_metric_by_name(self, metric_dict) -> EvaluationMetric:
        """Build the metric described by ``metric_dict``.

        Args:
            metric_dict: metric configuration; its 'name' key selects the metric
                (e.g. ``{"name": "validity", "parameters": {}}`` as in the config file).

        Returns:
            EvaluationMetric: a ValidityMetric for 'validity', otherwise whatever
            the base factory returns for that configuration.
        """
        if metric_dict['name'] == 'validity':
            return self.get_validity_metric(config_dict=metric_dict)

        return super().get_evaluation_metric_by_name(metric_dict)

    def get_validity_metric(self, config_dict=None) -> EvaluationMetric:
        """Return a new ValidityMetric built from ``config_dict``."""
        return ValidityMetric(config_dict)


## Using the new Evaluation Metric

In [12]:
# Only the metric factory is customised; the other factories are passed as None
# (presumably the manager then uses its built-in defaults — TODO confirm).
em_factory = CustomEvaluationMetricFactory()

# run_number distinguishes repeated runs of the same configuration.
eval_manager = EvaluatorManager(
    config_file_path,
    run_number=0,
    dataset_factory=None,
    embedder_factory=None,
    oracle_factory=None,
    explainer_factory=None,
    evaluation_metric_factory=em_factory,
)

eval_manager.create_evaluators()
eval_manager.evaluate()

## Checking the results

In [13]:
# Read the raw JSON text of the run-0 results file and display it unparsed.
with open(output_file_path, 'r') as results_file:
    results = results_file.read()

results

'{"config": {"dataset": {"name": "autism", "parameters": {}}, "oracle": {"name": "asd_custom_oracle", "parameters": {}}, "explainer": {"name": "dce_search", "parameters": {"graph_distance": {"name": "graph_edit_distance", "parameters": {}}}}, "metrics": [{"name": "graph_edit_distance", "parameters": {}}, {"name": "oracle_calls", "parameters": {}}, {"name": "validity", "parameters": {}}, {"name": "sparsity", "parameters": {}}, {"name": "fidelity", "parameters": {}}, {"name": "oracle_accuracy", "parameters": {}}]}, "runtime": [0.14621305465698242, 0.08006906509399414, 0.0641489028930664, 0.07153201103210449, 0.07117462158203125, 0.10097837448120117, 0.07581448554992676, 0.09013652801513672, 0.0803062915802002, 0.07727694511413574, 0.0703282356262207, 0.0743856430053711, 0.0836799144744873, 0.10117864608764648, 0.07774209976196289, 0.06998872756958008, 0.06997513771057129, 0.07279682159423828, 0.0685722827911377, 0.07396316528320312, 0.08022737503051758, 0.07196712493896484, 0.07120800018

In [14]:
from src.data_analysis.data_analyzer import DataAnalyzer
import pandas as pd

# The analyzer reads from output_folder and writes its tables to stats_folder.
dtan = DataAnalyzer(output_folder, stats_folder)

# Run the analyzer's steps in this order.
dtan.aggregate_data()
dtan.aggregate_runs()
dtan.create_tables_by_oracle_dataset()

# Load the table for the autism dataset evaluated with the custom ASD oracle.
results_table_path = f'{stats_folder}autism-asd_custom_oracle.csv'
results_table = pd.read_csv(results_table_path)
results_table

Unnamed: 0.1,Unnamed: 0,explainer,runtime,runtime-std,Graph_Edit_Distance,Graph_Edit_Distance-std,Oracle_Calls,Oracle_Calls-std,Validity,Validity-std,Sparsity,Sparsity-std,Fidelity,Fidelity-std,Oracle_Accuracy,Oracle_Accuracy-std
0,0,dce_search,0.074188,0.0,1011.693069,0.0,102.0,0.0,1.0,0.0,1.311108,0.0,0.544554,0.0,0.772277,0.0
