# Adding a New Explainer to GRETEL

In [1]:
import os
import sys

# Put the notebook's parent directory on the import path so the `src.*`
# packages imported in later cells can be resolved.
module_path = os.path.abspath(os.pardir)
sys.path.append(module_path)

In [9]:
from src.evaluation.evaluator_manager import EvaluatorManager

# Experiment configuration to run, and the results file it is expected to produce.
# TODO(review): these are absolute, machine-specific paths; consider deriving them
# from a configurable root so the notebook runs on other machines.
config_file_path, output_file_path = (
    '/NFSHOME/mprado/CODE/GRETEL/examples/config/config_autism_custom-oracle_dummy_explainer.json',
    '/NFSHOME/mprado/CODE/GRETEL/output/asd_custom_oracle/DummyExplainer/results_run-0.json',
)

## Creating a new Explainer

First, create the explainer: a subclass of `Explainer` that implements `explain(instance, oracle, dataset)`.

In [10]:
from src.evaluation.evaluation_metric_base import EvaluationMetric
from src.explainer.explainer_base import Explainer
from src.dataset.dataset_base import Dataset
from src.oracle.oracle_base import Oracle

class DummyExplainer(Explainer):
    """Baseline explainer that searches the dataset for the first counterfactual instance.

    An instance counts as a counterfactual when the oracle predicts a label for it
    that differs from the label predicted for the input instance.
    """

    def __init__(self, id, config_dict=None) -> None:
        super().__init__(id, config_dict)
        self._name = 'DummyExplainer'

    def explain(self, instance, oracle: Oracle, dataset: Dataset):
        """Return the first instance in ``dataset.instances`` whose predicted label differs.

        Args:
            instance: the instance to explain.
            oracle: model whose ``predict`` supplies the labels being compared.
            dataset: source of candidate counterfactuals, scanned in order.

        Returns:
            The first differently-labelled dataset instance, or ``instance`` itself
            when no such candidate exists.
        """
        input_label = oracle.predict(instance)

        # The generator is consumed lazily, so the oracle is queried only until the
        # first counterfactual is found; the input itself is the fallback result.
        return next(
            (candidate
             for candidate in dataset.instances
             if input_label != oracle.predict(candidate)),
            instance,
        )


Next, create a custom `ExplainerFactory` that extends the base factory so it can build the new explainer whenever the configuration requests `dummy_explainer`.

In [11]:
from src.explainer.explainer_factory import ExplainerFactory
from src.evaluation.evaluation_metric_factory import EvaluationMetricFactory

class CustomExplainerFactory(ExplainerFactory):
    """ExplainerFactory extended so that it can also build ``DummyExplainer``.

    Configurations whose ``name`` is ``'dummy_explainer'`` are handled here; every
    other explainer name is delegated unchanged to the base ``ExplainerFactory``.
    """

    def __init__(self, explainer_store_path):
        super().__init__(explainer_store_path)

    def get_explainer_by_name(self, explainer_dict, metric_factory : EvaluationMetricFactory) -> Explainer:
        """Build the explainer described by ``explainer_dict``.

        Args:
            explainer_dict: explainer configuration; its ``'name'`` key selects the explainer.
            metric_factory: forwarded to the base factory for non-dummy explainers.

        Returns:
            A ``DummyExplainer`` for ``'dummy_explainer'``, otherwise whatever the
            base factory returns for that name.
        """
        explainer_name = explainer_dict['name']

        # Check whether the configuration asks for the DummyExplainer
        if explainer_name == 'dummy_explainer':
            return self.get_dummy_explainer(explainer_dict)
        else:
            # Any other name is resolved by the standard GRETEL factory
            return super().get_explainer_by_name(explainer_dict, metric_factory)

    def get_dummy_explainer(self, config_dict=None):
        """Create a ``DummyExplainer`` with the next explainer id and advance the id counter.

        Args:
            config_dict: explainer configuration passed through to ``DummyExplainer``.
        """
        # `_explainer_id_counter` is not set in this class; presumably it is
        # initialised by the base ExplainerFactory — confirm.
        result = DummyExplainer(self._explainer_id_counter, config_dict)
        self._explainer_id_counter += 1
        return result
            


In [12]:
# Store directory handed to the custom explainer factory.
ex_store_path = '/NFSHOME/mprado/CODE/GRETEL/data/explainers/'
ex_factory = CustomExplainerFactory(ex_store_path)

# Only the explainer factory is customised; the other factories are passed as None
# (presumably EvaluatorManager then uses its own defaults — confirm).
factories = dict(
    dataset_factory=None,
    embedder_factory=None,
    oracle_factory=None,
    explainer_factory=ex_factory,
    evaluation_metric_factory=None,
)

# run_number differentiates repeated runs of the same configuration
eval_manager = EvaluatorManager(config_file_path, run_number=0, **factories)
eval_manager.create_evaluators()
eval_manager.evaluate()

In [13]:
# Read the raw JSON text of the run's results file (presumably written by
# eval_manager.evaluate()) and display it as a string.
with open(output_file_path, mode='r') as rs_json_reader:
    results = rs_json_reader.read()

results

'{"config": {"dataset": {"name": "autism", "parameters": {}}, "oracle": {"name": "asd_custom_oracle", "parameters": {}}, "explainer": {"name": "dummy_explainer", "parameters": {}}, "metrics": [{"name": "graph_edit_distance", "parameters": {}}, {"name": "oracle_calls", "parameters": {}}, {"name": "correctness", "parameters": {}}, {"name": "sparsity", "parameters": {}}, {"name": "fidelity", "parameters": {}}, {"name": "oracle_accuracy", "parameters": {}}]}, "runtime": [0.002142667770385742, 0.0005326271057128906, 0.0011496543884277344, 0.001150369644165039, 0.0011563301086425781, 0.001268625259399414, 0.0014653205871582031, 0.0011479854583740234, 0.0011401176452636719, 0.0011873245239257812, 0.0014743804931640625, 0.0014410018920898438, 0.0011568069458007812, 0.0011444091796875, 0.0014407634735107422, 0.0011823177337646484, 0.0011701583862304688, 0.0011718273162841797, 0.0011553764343261719, 0.0011591911315917969, 0.0011403560638427734, 0.0014922618865966797, 0.0011365413665771484, 0.001

In [14]:
from src.data_analysis.data_analyzer import DataAnalyzer

# Arguments are presumably (results directory, statistics output directory) — confirm.
dtan = DataAnalyzer('/NFSHOME/mprado/CODE/GRETEL/output', '/NFSHOME/mprado/CODE/GRETEL/stats')

# Run the analysis stages in order: aggregate data, aggregate runs, build tables.
for step in (dtan.aggregate_data, dtan.aggregate_runs, dtan.create_tables_by_oracle_dataset):
    step()

In [15]:
import pandas as pd
# Per-explainer summary table for the autism dataset / asd_custom_oracle pair.
# NOTE(review): the "Unnamed: 0" / "Unnamed: 0.1" columns look like index columns
# written back to the CSV; consider dropping them for display.
stats_csv = '/NFSHOME/mprado/CODE/GRETEL/stats/autism-asd_custom_oracle.csv'
results_table = pd.read_csv(stats_csv)
results_table

Unnamed: 0.1,Unnamed: 0,explainer,runtime,Graph_Edit_Distance,Oracle_Calls,Correctness,Sparsity,Fidelity,Oracle_Accuracy
0,0,dce_search,0.059996,1011.693069,102.0,1.0,1.311108,0.544554,0.772277
1,1,dummy_explainer,0.001329,1077.356436,2.534653,1.0,1.396227,0.544554,0.772277
