In [1]:
import os
import pandas as pd
import random
from string import Template
from typing import Any, List, Tuple, Union

In [2]:
#######################################################
# CONSTANTS
#######################################################
LANG_ES = 'es'
LANG_EN = 'en'
NODE_IMPORTANCE = 'node_importance'
NODE_NAME = 'node_name'
QUALITY_MEASURE = 'quality_measure'
RAND = 'rand'
RANK = 'rank'
REASON = 'reason'
TARGET = 'target'

In [39]:
class Why:
    
    _LOCAL = 'local'
    _GLOBAL = 'global'
    _SEP_LAST = {
        LANG_ES: ' y ',
        LANG_EN: ' and '
    }
    
    def __init__(self, language: str, local_expl: pd.DataFrame, local_nodes: pd.DataFrame, why_elements: pd.DataFrame,
                 why_target: pd.DataFrame, why_templates: pd.DataFrame, quality_measure: pd.DataFrame,
                 n_local_features: int = 2, n_global_features: int = 2, min_quality: float = 0.0):
        self.language = language
        if self.language not in self._SEP_LAST.keys():
            raise Exception('Language {} not supported'.format(language))
        self.local_expl = local_expl
        self.local_nodes = local_nodes
        self.why_elements = why_elements
        self.why_target = why_target
        self.why_templates = why_templates
        self.quality_measure = quality_measure
        self.n_local_features = n_local_features
        self.n_global_features = n_global_features
        self.min_quality = min_quality
        
    def __build_template(self, items: Union[List, Tuple], sep=', ') -> str:
        sep_last = self._SEP_LAST[self.language]
        i_list = ['$' + i for i in (items if isinstance(items, list) else list(items))]
        return (sep.join(i_list[:-1]) + sep_last + i_list[-1]) if len(i_list) > 1 else i_list[0]
        
    def get_reason_why(self, key_column: str, key_value: Any, template_index: Union[str, int] = RAND) -> str:
        # Check the quality of the explainability values
        quality = self.quality_measure[self.quality_measure[key_column] == key_value][QUALITY_MEASURE].loc[0]
        if quality < self.min_quality:
            return self.why_templates.iloc[0, 0]
        
        # Build a dataframe with all the case information
        df = (self.local_expl[self.local_expl[key_column] == key_value][[key_column, TARGET]]
              .merge(self.local_nodes[self.local_nodes[key_column] == key_value][[key_column, NODE_NAME, NODE_IMPORTANCE]],
                     on=key_column, how='inner')
              .merge(self.why_elements
                     .rename(columns={'feature': NODE_NAME}),
                     on=NODE_NAME, how='inner')
              .merge(self.why_target
                     .rename(columns={'feature': NODE_NAME}),
                     on=[TARGET, NODE_NAME], how='inner', suffixes=['_' + self._LOCAL, '_' + self._GLOBAL])
              .sort_values(by=NODE_IMPORTANCE, ascending=False))

        # Build why sentence
        kw_local = dict([('v_' + self._LOCAL + '_' + str(i), v) for i, v in enumerate(df['reason_' + self._LOCAL].iloc[:self.n_local_features])])
        kw_global = dict([('v_' + self._GLOBAL + '_' + str(i), v) for i, v in enumerate(df['reason_' + self._GLOBAL].iloc[:self.n_global_features])])
        temp_local_explain = self.__build_template(items=kw_local)
        temp_global_explain = self.__build_template(items=kw_global)

        temp_idx_max = self.why_templates.size - 1
        temp_idx = random.randint(1, temp_idx_max) if template_index == RAND else min(template_index, temp_idx_max)
        temp_why_str = (Template(Template(self.why_templates.iloc[temp_idx, 0])
                                 .substitute(temp_local_explain=temp_local_explain,
                                             temp_global_explain=temp_global_explain,
                                             target=df[TARGET].iloc[0]))
                        .substitute(**kw_local, **kw_global)
                        .capitalize())
        return temp_why_str


In [35]:
# Read files as Pandas dataframes
# data_path = '../../examples/example_why_titanic
path_why = 'C:/datos/AIP/proyectos/XAIoGraphs/data/why'
path_explainer_output = os.path.join(path_why, 'tmp')
data_path = 'C:/datos/AIP/proyectos/XAIoGraphs/data/why/tmp'
local_expl = pd.read_csv(os.path.join(path_explainer_output, 'local_explainability.csv'))
local_nodes = pd.read_csv(os.path.join(path_explainer_output, 'local_graph_nodes.csv'))
qm = pd.read_csv(os.path.join(path_explainer_output, 'local_dataset_reliability.csv'))
why_elements = pd.read_csv(os.path.join(path_why, 'why_element.csv'), comment='#')
why_target = pd.read_csv(os.path.join(path_why, 'why_target.csv'), comment='#')
why_templates = pd.read_fwf(os.path.join(path_why, 'templates.csv'), header=None)

In [43]:
# local_expl[local_expl['id'] == 9]
# local_nodes[local_nodes['id'] == 9]
# local_expl.shape
# qm[qm['id'] == 9]
qm[['target']].value_counts()

target  
target_3    54
target_1    25
target_2    21
dtype: int64

In [33]:
# Instantiate Why object
why = Why(language='es', local_expl=local_expl, local_nodes=local_nodes, why_elements=why_elements, why_target=why_target,
          why_templates=why_templates, quality_measure=qm, n_local_features=3, n_global_features=3, min_quality=0.0)

In [34]:
# Request explanation
why.get_reason_why(key_column='id', key_value=9)

Index(['id', 'target', 'node_name', 'node_importance', 'reason_local',
       'reason_global'],
      dtype='object')
(0, 6)
target_1    1
Name: target, dtype: int64


IndexError: list index out of range