# Adding a new Dataset to GRETEL

In [1]:
import sys
import os
# Absolute path of the parent of the current working directory (presumably the
# repository root); it is put on sys.path so the `src.*` imports used in the
# following cells can be resolved.
module_path = os.path.abspath('..')

# Guard the append so that re-running this cell does not pile up duplicate
# sys.path entries.
if module_path not in sys.path:
    sys.path.append(module_path)

In [8]:
from src.evaluation.evaluator_manager import EvaluatorManager

# Locations used throughout the notebook, all rooted at `module_path`.
config_file_path = f"{module_path}/examples/config/config_squares-triangles_trisqr_custom_oracle_dce.json"
output_file_path = f"{module_path}/examples/output/triangles_squares_custom_oracle/DCESearchExplainer/results_run-0.json"
output_folder = f"{module_path}/examples/output/"
stats_folder = f"{module_path}/examples/stats/"
datasets_folder = f"{module_path}/data/datasets/"

# Sanity check displayed as the cell output: one boolean per path, in the
# order config file, results file, output folder, stats folder, datasets folder.
# NOTE(review): the results file is read back in the "Reading the results"
# section; presumably it only exists once the evaluation cell has been run.
(
    os.path.isfile(config_file_path),
    os.path.isfile(output_file_path),
    os.path.isdir(output_folder),
    os.path.isdir(stats_folder),
    os.path.isdir(datasets_folder),
)

(True, True, True, True, True)

## Creating the Squares-Triangles Dataset

The Squares-Triangles Dataset is a synthetic dataset, generated on the fly, that is composed of cycle graphs: some of them are triangles (3-node cycles) and the others are squares (4-node cycles).

### Creating the dataset

In [4]:
from src.dataset.data_instance_base import DataInstance
from src.dataset.dataset_base import Dataset

import networkx as nx
import numpy as np

class SquaresTrianglesDataset(Dataset):
    """Synthetic dataset whose instances are 3-node or 4-node cycle graphs.

    Triangles carry graph label 1 and squares carry graph label 0.
    """

    def __init__(self, id, config_dict=None) -> None:
        super().__init__(id, config_dict)
        self.instances = []

    def create_cycle(self, cycle_size, role_label=1):
        """Build a cycle graph and uniform label dictionaries for it.

        Args:
            cycle_size: number of nodes in the cycle; nodes are numbered
                from 0 to ``cycle_size - 1``.
            role_label: value assigned to every node and every edge.

        Returns:
            A tuple ``(graph, node_labels, edge_labels)`` where ``graph`` is a
            ``networkx.Graph`` and the two dictionaries map each node / each
            edge of the graph to ``role_label``.
        """
        graph = nx.Graph()
        graph.add_nodes_from(range(cycle_size))

        # Chain consecutive nodes together, then close the loop back to node 0
        graph.add_edges_from((node, node + 1) for node in range(cycle_size - 1))
        graph.add_edge(cycle_size - 1, 0)

        # Every node and every edge receives the same label
        node_labels = dict.fromkeys(graph.nodes, role_label)
        edge_labels = dict.fromkeys(graph.edges, role_label)

        return graph, node_labels, edge_labels


    def generate_squares_triangles_dataset(self, n_instances):
        """Populate ``self.instances`` with ``n_instances`` random cycles.

        Each instance is independently drawn as a triangle or a square using
        ``np.random.randint(0, 2)``. The dataset name (``self._name``) is set
        from ``n_instances``, and ``self._instance_id_counter`` is advanced by
        one for every instance created. Nothing is returned.
        """
        self._name = 'squares-triangles_instances-' + str(n_instances)

        generated = []
        for index in range(n_instances):
            # One draw from {0, 1} per instance: truthy -> triangle, falsy -> square
            is_triangle = np.random.randint(0, 2)

            data_instance = DataInstance(id=self._instance_id_counter)
            self._instance_id_counter += 1

            # Shape-specific settings; the label is shared by the graph,
            # its nodes and its edges
            cycle_size, label = (3, 1) if is_triangle else (4, 0)
            graph, node_labels, edge_labels = self.create_cycle(cycle_size=cycle_size, role_label=label)

            data_instance.graph_label = label
            data_instance.graph = graph
            data_instance.node_labels = node_labels
            data_instance.edge_labels = edge_labels
            data_instance.minimum_counterfactual_distance = 4
            data_instance.name = 'g' + str(index)

            generated.append(data_instance)

        # Store the generated instances on the dataset
        self.instances = generated

    

### Creating the DatasetFactory

In [5]:
from src.dataset.dataset_base import Dataset
from src.dataset.dataset_factory import DatasetFactory
import os
import shutil


class CustomDatasetFactory(DatasetFactory):
    """Dataset factory that adds support for the squares-triangles dataset.

    Requests for any other dataset name are forwarded to the base
    ``DatasetFactory`` implementation.
    """

    def __init__(self, data_store_path) -> None:
        """Store the dataset folder and reset the dataset id counter.

        Args:
            data_store_path: folder under which datasets are read and written.
        """
        # NOTE(review): the base-class initializer is not invoked here; the two
        # attributes below are the only state this class itself relies on.
        self._data_store_path = data_store_path
        self._dataset_id_counter = 0


    def get_dataset_by_name(self, dataset_dict) -> Dataset:
        """Build the dataset described by a configuration dictionary.

        Args:
            dataset_dict: dictionary with a ``'name'`` entry and a
                ``'parameters'`` entry; for ``'squares-triangles'`` the
                parameters must contain ``'n_inst'``.

        Returns:
            The requested dataset.

        Raises:
            ValueError: if a squares-triangles dataset is requested without
                the ``'n_inst'`` parameter.
        """
        dataset_name = dataset_dict['name']
        params_dict = dataset_dict['parameters']

        # Check if the dataset is a squares-triangles dataset
        if dataset_name == 'squares-triangles':
            if 'n_inst' not in params_dict:
                raise ValueError('''"n_inst" parameter containing the number of instances in the dataset
                 is mandatory for squares-triangles dataset''')

            return self.get_squares_triangles_dataset(params_dict['n_inst'], False, dataset_dict)
        else:
            # Call the base method to generate any of the originally supported datasets
            return super().get_dataset_by_name(dataset_dict)


    def get_squares_triangles_dataset(self, n_instances=300, regenerate=False, config_dict=None) -> Dataset:
        """Load the squares-triangles dataset from disk, or generate and store it.

        Args:
            n_instances: number of instances in the dataset; also part of the
                folder name used to look the dataset up in the data store.
            regenerate: when True, an existing stored copy is deleted and the
                dataset is generated again.
            config_dict: configuration dictionary passed to the dataset.

        Returns:
            The loaded or freshly generated ``SquaresTrianglesDataset``.
        """
        result = SquaresTrianglesDataset(self._dataset_id_counter, config_dict)
        self._dataset_id_counter += 1

        # Create the name and uri of the dataset using the provided parameters
        ds_name = 'squares-triangles_instances-' + str(n_instances)
        ds_uri = os.path.join(self._data_store_path, ds_name)
        ds_exists = os.path.exists(ds_uri)

        # If regenerate is true and the dataset exists then remove it and generate it again
        if regenerate and ds_exists:
            shutil.rmtree(ds_uri)
            # The stored copy is gone: refresh the flag so the generation
            # branch below runs instead of reading from the deleted folder.
            ds_exists = False

        if ds_exists:
            # Load the stored dataset
            result.read_data(ds_uri)
        else:
            # Generate the dataset, split it and persist it in the data store
            result.generate_squares_triangles_dataset(n_instances)
            result.generate_splits()
            result.write_data(self._data_store_path)

        return result

### Evaluating the DCE explainer on the new dataset

In [6]:
# The custom factory reads and writes datasets under the datasets folder
ds_store_path = datasets_folder
ds_factory = CustomDatasetFactory(ds_store_path)

# run_number distinguishes repeated runs of the same configuration.
# Only the dataset factory is customised; every other factory is passed as None.
eval_manager = EvaluatorManager(
    config_file_path,
    run_number=0,
    dataset_factory=ds_factory,
    embedder_factory=None,
    oracle_factory=None,
    explainer_factory=None,
    evaluation_metric_factory=None,
)

# Build the evaluators described by the configuration file, then run them
eval_manager.create_evaluators()
eval_manager.evaluate()

### Reading the results

In [9]:
# Load the raw text of the results file into `results`
with open(output_file_path, mode='r') as results_file:
    results = results_file.read()

# Bare last expression so the notebook displays the string
results

'{"config": {"dataset": {"name": "squares-triangles", "parameters": {"n_inst": 100}}, "oracle": {"name": "trisqr_custom_oracle", "parameters": {"embedder": {"name": "graph2vec", "parameters": {}}}}, "explainer": {"name": "dce_search", "parameters": {"graph_distance": {"name": "graph_edit_distance", "parameters": {}}}}, "metrics": [{"name": "graph_edit_distance", "parameters": {}}, {"name": "oracle_calls", "parameters": {}}, {"name": "correctness", "parameters": {}}, {"name": "sparsity", "parameters": {}}, {"name": "fidelity", "parameters": {}}, {"name": "oracle_accuracy", "parameters": {}}]}, "runtime": [0.0013515949249267578, 0.0009775161743164062, 0.0010020732879638672, 0.0008046627044677734, 0.0008766651153564453, 0.0007398128509521484, 0.0008320808410644531, 0.0006115436553955078, 0.0007717609405517578, 0.0006520748138427734, 0.0006070137023925781, 0.0006251335144042969, 0.0005903244018554688, 0.0007309913635253906, 0.0023086071014404297, 0.0009348392486572266, 0.000741481781005859

### Creating the results table

In [10]:
from src.data_analysis.data_analyzer import DataAnalyzer

# Build the analyzer from the evaluation output folder and the stats folder.
dtan = DataAnalyzer(output_folder, stats_folder)
# The three calls below run in this order.
# NOTE(review): presumably they aggregate the per-run results and then write
# one table per oracle/dataset pair into the stats folder (a CSV is read from
# there in the next cell) - confirm against DataAnalyzer.
dtan.aggregate_data()
dtan.aggregate_runs()
dtan.create_tables_by_oracle_dataset()

In [11]:
import pandas as pd
# CSV table for the squares-triangles dataset / custom oracle pair,
# located in the stats folder
results_csv_path = stats_folder + 'squares-triangles-trisqr_custom_oracle.csv'
results_table = pd.read_csv(results_csv_path)

# Bare last expression so the notebook renders the table
results_table

Unnamed: 0.1,Unnamed: 0,explainer,runtime,runtime-std,Graph_Edit_Distance,Graph_Edit_Distance-std,Oracle_Calls,Oracle_Calls-std,Correctness,Correctness-std,Sparsity,Sparsity-std,Fidelity,Fidelity-std,Oracle_Accuracy,Oracle_Accuracy-std
0,0,dce_search,0.000888,0.0,4.0,0.0,101.0,0.0,1.0,0.0,0.578333,0.0,1.0,0.0,1.0,0.0
