### Graphis SHAP

In [79]:
### CON PLOTLY
import pandas as pd
import shap
import plotly.graph_objects as go
import plotly.express as px
import xgboost

def importance(shap_values, class_index: int, threshold: float = 0.01) -> pd.DataFrame:
    values_class = shap_values.values[:, class_index]
    shap_df = pd.DataFrame({
        'feature': shap_values.feature_names,
        'value': values_class
    })
    shap_df = shap_df[shap_df['value'].abs() > threshold]
    return shap_df

def graph_class_controller(shap_values_id) -> int:
    return shap_values_id.base_values.argmax()

def graphics_shap(shap_values: shap._explanation.Explanation, color_scheme: str, class_names: list, 
                  id_example: [list, None] = [0], by_predicted_class: bool = True, type_shap:str = "waterfall") -> go.Figure: # type: ignore
    try:
        fig = go.Figure()
        charFilter = id_example if id_example is not None else range(len(shap_values))

        for example in charFilter:
            shap_value_id = shap_values[example]
            if by_predicted_class:
                class_id = graph_class_controller(shap_value_id)
                target_class = class_names[class_id]
                shap_df = importance(shap_value_id, class_id)
                fig = settings_bars(fig, type_shap, shap_df, f'example{example}-{target_class}', color_scheme)
            else:
                for index, class_name in enumerate(class_names):
                    shap_df = importance(shap_value_id, index)
                    fig = settings_bars(fig,type_shap,shap_df, f'example{example}-{class_name}', color_scheme)
        
        fig = settings_shap(fig, title='SHAP Feature Importance')
        return fig
    except Exception as e:
        print(f'Error when graphing: {e}')

def settings_bars(fig, type_shap, shap_df: pd.DataFrame, example_name: str, color_scheme: [list, str]) -> go.Figure: # type: ignore
    increasing, decreasing = _get_color_scheme(color_scheme)
    settings = {
        "waterfall": {
            "name": example_name,
            "orientation": 'h',
            "measure": ['relative'] * len(shap_df),  
            "x": shap_df['value'],
            "y": shap_df['feature'],                    
            "text": shap_df['value'].apply(lambda x: f'{x:.2f}'),        
            "connector": {"line": {"width": 0}},  
            "increasing": dict(marker=dict(color=increasing, line=dict(width=0))),
            "decreasing": dict(marker=dict(color=decreasing, line=dict(width=0)))
        },
        "bar": {
            "title": example_name,
            "orientation": 'h',
            "x": shap_df['value'],                     
            "y": shap_df['feature'],                   
            "text": shap_df['value'].apply(lambda x: f'{x:.2f}'),  
            "marker": {
                "color": shap_df['value'].apply(lambda x: increasing if x > 0 else decreasing),
                "line": dict(width=0)  
                }
        }
    }
    if type_shap == "waterfall":
        fig.add_trace(go.Waterfall(**settings['waterfall']))
    elif type_shap == "bar":
        fig.add_trace(go.Bar(**settings['bar']))
    else:
        raise ValueError(f"Type of graph not supported: {type_shap}")
    return fig

def settings_shap(fig, width=None, height=None, title: str = '', xaxis_title: str = '', 
                  yaxis_title: str = '', show_legend: bool = True) -> go.Figure:
    fig.update_layout(
        title=title,
        yaxis_title=yaxis_title,
        xaxis_title=xaxis_title,
        width=width,
        height=height,
        showlegend=show_legend,
        xaxis=dict(autorange=True),
        yaxis=dict(tickangle=0),
        barmode='relative'
    )
    return fig

def _get_color_scheme(color_scheme: [list, str]) -> tuple: # type: ignore
    if isinstance(color_scheme, str):
        colors = getattr(px.colors.diverging, color_scheme)
        return colors[0], colors[1]
    elif isinstance(color_scheme, list) and len(color_scheme) == 2:
        return color_scheme[0], color_scheme[1]
    else:
        raise ValueError("The color scheme must have hexadecimal values or one of these supported schemes: 'RdBu', 'GnPR', 'CyPU', 'PkYg', 'DrDb', 'LpLb', 'YlDp', 'OrId'")


def shap_values(model: xgboost.XGBClassifier, X_test: pd.DataFrame):
    try:
        explainer = shap.Explainer(model)
        return explainer(X_test)
    except Exception as e:
        print(f'Error in the calculation of SHAP values: {e}')

def mapping_labels(Y_train: pd.Series, unique_classes):
    map = {class_name: idx for idx, class_name in enumerate(unique_classes)} 
    return Y_train.replace(map)

def training(classifier, X, Y, **kwargs):
    try:
        return classifier(**kwargs).fit(X, Y)
    except Exception as e:
        print(f'Error in training the model: {e}')

def export_graph_to_html(fig: go.Figure, filename: str):
    try:
        fig.write_html(filename)
    except Exception as e:
        print(f'{e}')
    
def load_data(route_input, patterns):
    data = {}
    files = os.listdir(route_input)
    
    for file in files:
        path_file = os.path.abspath(os.path.join(route_input, file))
        for key, pattern in patterns.items():
            if pattern in file:
                data[key] = pd.read_pickle(path_file)
                break  
    return data

# Carga de datos
if __name__ == '__main__':
    uploaded_data = load_data("C:\\Users\\matrix\\Carenne\\data", {'y_train': 'y_train', 'x_train': 'x_train', 'x_test': 'x_test'})
    Y, X, X_test = uploaded_data.get('y_train'), uploaded_data.get('x_train'), uploaded_data.get('x_test')
    
    unique_classes = Y.unique().tolist()
    route_html = 'shap_graph.html'
    color_scheme = ['rgb(255, 0, 0)', '#33BBFF']

    Y_map = mapping_labels(Y, unique_classes)
    model = training(xgboost.XGBClassifier, X, Y_map)

    shap_v = shap_values(model, X_test)
    type_result = True #

    fig = graphics_shap(shap_v, color_scheme, class_names=unique_classes, 
                        id_example=[0, 1], by_predicted_class=type_result) 
    
    export_graph_to_html(fig, route_html)
    print(f"Gráfico exportado a {route_html}")


Gráfico exportado a shap_graph.html


### Obtener los datos de las visualizaciones SHAP

In [45]:
import pandas as pd
import shap
import plotly.graph_objects as go
import plotly.express as px
import xgboost
import os

class Graph_Shap:
    def __init__(self, x_train:pd.DataFrame, y_train:pd.Series, x_test:pd.DataFrame, viz_type:bool=True):
        self.x_train = x_train 
        self.x_test = x_test
        self.y_train = y_train
        self.viz_type_shap = viz_type # Controla el tipo de visualización de los values shap de cada muestra, ya sea por clase predicha o por cada tipo de clase
        self.unique_classes = y_train.unique().tolist()
        self.fig = None
    
    def __mapping_labels(self, Y_train) -> pd.Series:
        if Y_train.dtype == "object":
            map = {class_name: idx for idx, class_name in enumerate(self.unique_classes)} 
            return Y_train.replace(map) # Downcasting behavior in `replace` is deprecated and will be removed in a future version.
        return Y_train
    
    def __training(self, classifier, X:pd.DataFrame, Y:pd.DataFrame, **kwargs:dict) -> xgboost.XGBClassifier:
        try:
            Y_map = self.__mapping_labels(Y)
            return classifier(**kwargs).fit(X, Y_map)
        except Exception as e:
            print(f'Error in training the model: {e}')
    
    def __get_scheme(self) -> tuple: # type: ignore
        try:
            if isinstance(self.color_scheme, str):
                colors = getattr(px.colors.diverging, self.color_scheme)
                return colors[0], colors[1]
            elif isinstance(self.color_scheme, list) and len(self.color_scheme) == 2:
                return self.color_scheme[0], self.color_scheme[1]
        except ValueError:
            print("The color scheme must have hexadecimal values or one of these supported schemes: 'RdBu', 'GnPR', 'CyPU', 'PkYg', 'DrDb', 'LpLb', 'YlDp', 'OrId'")
            
    def __importance(self, shap_values:shap.Explainer, class_index:int, threshold:float = 0.01) -> pd.DataFrame:
        values_class = shap_values.values[:, class_index]
        shap_df = pd.DataFrame({
            'feature': shap_values.feature_names,
            'value': values_class
        })
        shap_df = shap_df[shap_df['value'].abs() > threshold]
        return shap_df
    
    def __settings_bars(self, type_shap, shap_df: pd.DataFrame, example_name: str) -> go.Figure: # type: ignore
        increasing, decreasing = self.__get_scheme()
        settings = {
            "waterfall": {
                "name": example_name,
                "orientation": 'h',
                "measure": ['relative'] * len(shap_df),  
                "x": shap_df['value'],
                "y": shap_df['feature'],                    
                "text": shap_df['value'].apply(lambda x: f'{x:.2f}'),        
                "connector": {"line": {"width": 0}},  
                "increasing": dict(marker=dict(color=increasing, line=dict(width=0))),
                "decreasing": dict(marker=dict(color=decreasing, line=dict(width=0)))
            },
            "bar": {
                "title": example_name,
                "orientation": 'h',
                "x": shap_df['value'],                     
                "y": shap_df['feature'],                   
                "text": shap_df['value'].apply(lambda x: f'{x:.2f}'),  
                "marker": {
                    "color": shap_df['value'].apply(lambda x: increasing if x > 0 else decreasing),
                    "line": dict(width=0)  
                    }
            }
        }
        if type_shap == "waterfall":
            self.fig.add_trace(go.Waterfall(**settings['waterfall']))
        elif type_shap == "bar":
            self.fig.add_trace(go.Bar(**settings['bar']))
        else:
            raise ValueError(f"Type of graph not supported: {type_shap}")
        return self.fig

    def __settings_shap(self, width=None, height=None, title: str = '', xaxis_title: str = '', 
                  yaxis_title: str = '', show_legend: bool = True) -> go.Figure:
        self.fig.update_layout(
            title=title,
            yaxis_title=yaxis_title,
            xaxis_title=xaxis_title,
            width=width,
            height=height,
            showlegend=show_legend,
            xaxis=dict(autorange=True),
            yaxis=dict(tickangle=0),
            barmode='relative'
        )
        return self.fig
    
    def values_shap(self, model:xgboost.XGBClassifier, **kwargs_model:dict) -> shap.Explainer:
        try:
            model = self.__training(model, self.x_train, self.y_train, **kwargs_model)
            explainer = shap.Explainer(model)
            return explainer(X_test)
        except Exception as e:
            print(f'Error in the calculation of SHAP values: {e}')
    
    '''def __manage_figure(self):
        if self.viz_graph:
            if not hasattr(self, 'fig') or self.fig is None:
               self.fig = go.Figure()
        else:
            self.fig = go.Figure()
        return self.fig
        '''
    
    def graphics_shap(self, shap_values: shap._explanation.Explanation, color_scheme: str, id_example: [list, None] = [0], type_shap:str = "waterfall") -> go.Figure: # type: ignore
        self.color_scheme = color_scheme
        self.html_parts = []
        try: 
            charFilter = id_example if id_example is not None else range(len(shap_values))
            for example in charFilter:
                #self.fig = self.__manage_figure()
                self.fig = go.Figure()
                self.shap_value_id = shap_values[example]
                if self.viz_type_shap:
                    class_id = self.shap_value_id.base_values.argmax() # Conseguir la clase predicha
                    shap_df = self.__importance(self.shap_value_id, class_id) # Dejar solo las caracteristicas importantes que benefician a la clase predicha 
                    self.fig = self.__settings_bars(type_shap, shap_df, f'example{example}-{self.unique_classes[class_id]}')
                else:
                    for index, class_name in enumerate(self.unique_classes): # Valores shap para cada clase del conjunto en cada muestra 
                        shap_df = self.__importance(self.shap_value_id, index)
                        self.fig = self.__settings_bars(type_shap,shap_df, f'example{example}-{class_name}')
                    
                self.fig = self.__settings_shap(title='SHAP Feature Importance')
                self.html_parts.append(self.fig)

            return self.html_parts
        except Exception as e:
            print(f'Error when graphing: {e}')
    
def load_data(route_input:str, patterns:dict={'y_train': 'y_train', 'x_train': 'x_train', 'x_test': 'x_test'}) -> dict:
    data = {}
    files = os.listdir(route_input)
    for file in files:
        path_file = os.path.abspath(os.path.join(route_input, file))
        for key, pattern in patterns.items():
            if pattern in file:
                data[key] = pd.read_pickle(path_file)
                break  
    return data

def select_examples(mode: str, x_test, range_example:int=5, specific_example:list=[0, 1]) -> list:
    if mode == "limit":
        id_example = [ide for ide in range(range_example)]
    elif mode == "total":
        id_example = [ide for ide in range(len(x_test))]
    elif mode == "specific":
        id_example = [ide for ide in specific_example]
    return id_example


if __name__ == '__main__':
    viz_type = False # Controla el comportamiento de la importancia de x clase(False) o solo por la clase predicha(True)
    examples = 50
    current_path = os.getcwd()
    
    uploaded_data = load_data(f'{current_path}\\data')
    Y, X, X_test = uploaded_data.get('y_train'), uploaded_data.get('x_train'), uploaded_data.get('x_test')
    
    color_scheme = ['rgb(255, 0, 0)', '#33BBFF']

    test_shap = Graph_Shap(X, Y, X_test, viz_type=viz_type) 
    shap_v = test_shap.values_shap(xgboost.XGBClassifier) 
    
    id_example = select_examples("limit", X_test, range_example=examples)

    html_parts = test_shap.graphics_shap(shap_v, color_scheme, id_example, type_shap="waterfall") 


Downcasting behavior in `replace` is deprecated and will be removed in a future version. To retain the old behavior, explicitly call `result.infer_objects(copy=False)`. To opt-in to the future behavior, set `pd.set_option('future.no_silent_downcasting', True)`



### Concatenar gráficos SHAP en una sola salida HTML

In [46]:
import plotly
import json

class Html_master:
    def __init__(self, html_parts) -> None:
        self.html_parts = html_parts
        
        # Definir el HTML maestro con el script para Plotly
        self.html_master = """
            <!DOCTYPE html>
            <html lang="en">
            <head>
                <meta charset="UTF-8">
                <meta name="viewport" content="width=device-width, initial-scale=1.0">
                <title>SHAP Plots</title>
                <!-- Cargar Plotly.js -->
                <script src="https://cdn.plot.ly/plotly-latest.min.js"></script>
                <style>{style}</style>
                <script>
                    function resizeGraph(graphId) {
                        // Obtener el contenedor del gráfico
                        var container = document.querySelector('.container-shap');

                        // Obtener las dimensiones del contenedor
                        var containerSize = container.getBoundingClientRect();
                        var width = containerSize.width;
                        var height = containerSize.height;

                        // Cambiar el tamaño del gráfico dinámicamente usando las dimensiones del contenedor
                        Plotly.relayout(graphId, {
                            'width': width,
                            'height': height
                        });
                    }

                    function showGraph(graphId) {
                        var graphs = document.getElementsByClassName('shap-graph');
                        for (var i = 0; i < graphs.length; i++) {
                            graphs[i].style.display = 'none';  // Ocultar todos los gráficos
                        }
                        var currentGraph = document.getElementById(graphId);
                        currentGraph.style.display = 'block';  // Mostrar solo el gráfico seleccionado

                        resizeGraph(graphId);
                    }

                    // Evento para ajustar el tamaño del gráfico cuando la ventana cambie de tamaño
                    window.addEventListener('resize', function () {
                        var graphId = document.querySelector('.shap-graph[style*="block"]').id;  
                        resizeGraph(graphId);
                    });
                </script>
            </head>
            <body>
                <h1>Selecciona un gráfico SHAP</h1>
                <select id="graph-selector" onchange="showGraph(this.value)">
                    {options}
                </select>
                <div class="container-shap" style= "width: 100%; height: 100%">
                    {graphs}
                </div> 
            </body>
            </html>
            """
        
        self.style = """
            body {
                margin: 0;
                padding: 0;
                font-family: 'Arial', sans-serif;
                background-color: #fff; 
                height: 100vh;
                display: flex;
                flex-direction: column;
                align-items: center;
            }
            
            #graph-selector {
                background-color: #0272eb; 
                color: #ffffff;
                border: none;
                padding: 8px 12px;
                cursor: pointer;
                font-size: 14px;
                text-align: center;
            }

            #graph-selector:hover {
                background-color: #0056b3; 
            }

            #graph-selector option {
                background-color: #ffffff; 
                color: black;
            }
            
            .shap-graph {
                flex-grow: 1;
                width: 100%;
                height: 100%; 
                background-color: #fff;
                justify-content: center;
                align-items: center;
                padding: 0;
            } 
            """
        
        options = self.generate_options()
        graphs = self.generate_graphs()

        self.html_master = self.html_master.replace("{options}", options)
        self.html_master = self.html_master.replace("{graphs}", graphs)
        self.html_master = self.html_master.replace("{style}", self.style)
        
        with open("shap_combined_plots.html", "w") as f:
            f.write(self.html_master)
    
    def generate_options(self):
        return ''.join(f'<option value="graph{i}">Gráfico {i}</option>\n' for i in range(len(self.html_parts)))

    def generate_graphs(self):
        graphs = []
        for i, fig in enumerate(self.html_parts):
            display_style = "block" if i==0 else "none"
            graph_json = json.dumps(fig, cls=plotly.utils.PlotlyJSONEncoder)  # Convertir la figura a JSON para Plotly
            graphs.append(f'''
                <div id="graph{i}" class="shap-graph" style="display: {display_style}">
                    <script type="text/javascript">
                        Plotly.newPlot('graph{i}', {graph_json}); 
                    </script>
                </div>
            ''')
        return ''.join(graphs)

if __name__ == '__main__':
    Html_master(html_parts)
