In [1]:
from sklearn import datasets
from sklearn.model_selection import train_test_split
#!pip install lime
import lime
import lime.lime_tabular
import shap

# Load the Iris dataset bundled with scikit-learn; its `data`, `target`,
# `target_names` and `feature_names` attributes are used by later cells.
dataset = datasets.load_iris()

In [2]:
# Feature matrix and class labels taken from the loaded dataset.
X = dataset.data
y = dataset.target

In [3]:
# Hold out 30% of the samples for testing, stratified on the labels so both
# splits keep the same class proportions. The fixed random_state makes the
# split reproducible, so the accuracy below and the instance indices used by
# the explanation UI refer to the same rows on every run.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, stratify=y, shuffle=True, random_state=42
)

In [4]:
from sklearn.linear_model import LogisticRegression
# Logistic-regression classifier constructed with C=0.01; it is fitted in the next cell.
lr_model = LogisticRegression(C=0.01)

In [5]:
# Train the classifier on the training split.
lr_model.fit(X_train, y_train)

LogisticRegression(C=0.01)

In [6]:
# Predicted class labels for the held-out test split.
y_preds = lr_model.predict(X_test)

In [7]:
# Fraction of test samples whose predicted label equals the true label (accuracy).
(y_preds == y_test).mean()

0.8444444444444444

In [8]:
%%javascript
// Register the URL that the `d3` module is loaded from.
require.config({ paths: { d3: "https://d3js.org/d3.v6.min" } });

// Load D3 and publish it on `window` so scripts in later cell output can use `d3`.
require(["d3"], (d3) => {
    window.d3 = d3;
});

<IPython.core.display.Javascript object>

In [9]:
def create_explain_target_function(model, X_test, class_names, feature_names):
    """Build the comm-open handler for the 'get_explain_instance' target.

    Parameters:
        model: either a fitted estimator exposing ``predict_proba`` or the
            ``predict_proba`` callable itself; it is handed to LIME as the
            prediction function.
        X_test: indexable collection of rows; the row at the requested
            instance index is the one explained.
        class_names: sequence of class labels; its order defines the label
            indices requested from LIME and used to index the SHAP values.
        feature_names: sequence of feature names paired with the SHAP values.

    Returns:
        A ``get_explain_instance(comm, open_msg)`` callable suitable for
        ``comm_manager.register_target``.

    The message handler never raises: any failure is reported back to the
    front end as ``{'message': 'Error ...'}``.
    """
    # Imported here so the error path works regardless of which notebook
    # cells have already been executed.
    import traceback

    # Use the model that was passed in rather than a notebook-level global.
    predict_fn = getattr(model, 'predict_proba', model)

    def get_explain_instance(comm, open_msg):
        # open_msg is unused but is part of the comm target signature.
        @comm.on_msg
        def _recv(msg):
            try:
                # Deferred so that a missing installation is reported over
                # the comm instead of breaking target registration.
                import dill

                instance = msg['content']['data']['explain_instance']
                # Values arrive from HTML inputs as strings; convert once.
                instance_index = int(instance['instance_index'])
                n_samples = int(instance['lime']['n_samples'])

                # NOTE: dill.load can execute arbitrary code. These files are
                # expected to be the ones written by the configuration
                # handler in this notebook; never point this at untrusted data.
                with open('lime_explainer', 'rb') as f:
                    lime_explainer = dill.load(f)
                with open('shap_values', 'rb') as f:
                    shap_values = dill.load(f)

                explanation = lime_explainer.explain_instance(
                    X_test[instance_index],
                    predict_fn,
                    num_samples=n_samples,
                    labels=list(range(len(class_names))),
                )

                lime_res = {}
                shap_res = {}
                for index, class_name in enumerate(class_names):
                    lime_res[class_name] = dict(explanation.as_list(label=index))
                    # assumes shap_values is indexable per class and then
                    # [row, feature] -- TODO confirm for the installed SHAP version
                    shap_res[class_name] = dict(
                        zip(feature_names, list(shap_values[index][instance_index, :]))
                    )

                comm.send({'message':"LIME explanation obtained successfully!", "lime_explanation":lime_res, "shap_explanation": shap_res})
            except Exception as e:
                exception_string = str(e)
                comm.send({'message':"Error {0}\n{1}\n{2}".format(exception_string, msg, traceback.format_exc())})
    return get_explain_instance

# Expose the explanation handler to the front end under the comm target
# name 'get_explain_instance'.
get_ipython().kernel.comm_manager.register_target(
    'get_explain_instance',
    create_explain_target_function(
        model=lr_model.predict_proba,
        X_test=X_test,
        class_names=dataset.target_names,
        feature_names=dataset.feature_names,
    ),
)

In [10]:
import traceback
import sys
import dill

def create_config_target_function(X_train, feature_names, class_names, model):
    """Build the comm-open handler for the 'get_configurations' target.

    Each configuration message creates a LIME tabular explainer and a SHAP
    linear explainer, computes SHAP values, and writes the three objects to
    the files 'lime_explainer', 'shap_explainer' and 'shap_values' in the
    working directory.

    Parameters:
        X_train: training data given to both explainers.
        feature_names: feature names passed to the LIME explainer.
        class_names: class names passed to the LIME explainer.
        model: fitted model passed to ``shap.LinearExplainer``.

    Returns:
        A ``get_configurations(comm, open_msg)`` callable suitable for
        ``comm_manager.register_target``.

    The message handler never raises: any failure is reported back to the
    front end as ``{'message': 'Error ...'}``.
    """
    def get_configurations(comm, open_msg):
        # open_msg is unused but is part of the comm target signature.
        @comm.on_msg
        def _recv(msg):
            try:
                config = msg['content']['data']['config']
                discretize_continuous = config['lime']['discretize_continuous']
                # The SHAP section is optional; 'interventional' is the
                # current name of the former "independent" option.
                shap_config = config.get('shap') or {}
                feature_perturbation = shap_config.get('feature_perturbation', 'interventional')

                lime_explainer = lime.lime_tabular.LimeTabularExplainer(
                    X_train,
                    discretize_continuous=discretize_continuous,
                    feature_names=feature_names,
                    class_names=class_names,
                )
                shap_explainer = shap.LinearExplainer(
                    model, X_train, feature_perturbation=feature_perturbation
                )
                # NOTE(review): X_test is the notebook-level test split, not a
                # parameter of this factory; it must exist when a message arrives.
                shap_values = shap_explainer.shap_values(X_test)

                # Persist everything the explanation handler reads back later.
                artifacts = (
                    ('lime_explainer', lime_explainer),
                    ('shap_explainer', shap_explainer),
                    ('shap_values', shap_values),
                )
                for path, artifact in artifacts:
                    with open(path, 'wb') as f:
                        dill.dump(artifact, f)

                comm.send({'message':"LIME and SHAP explainer created successfully!{0},{1}".format(discretize_continuous,config['lime']['discretize_continuous']), "class_names":list(lime_explainer.class_names)})
            except Exception as e:
                exception_string = str(e)
                comm.send({'message':"Error {0}\n {1} \n {2}".format(exception_string, traceback.format_exc(), sys.exc_info()[2])})
    return get_configurations

# Expose the configuration handler to the front end under the comm target
# name 'get_configurations'.
get_ipython().kernel.comm_manager.register_target(
    'get_configurations',
    create_config_target_function(
        X_train=X_train,
        feature_names=dataset.feature_names,
        class_names=dataset.target_names,
        model=lr_model,
    ),
)

In [11]:
from string import Template
from IPython.display import display, HTML

js_string = Template("""
<html>
<head>
<script>


    
    
function main(){
  // Wires the explanation UI to the kernel: draws the two (initially empty)
  // charts, opens the 'get_configurations' and 'get_explain_instance' comms,
  // and attaches the button and dropdown handlers.

  const lime_div_id = "#lime_graph_div"
  const shap_div_id = "#shap_graph_div"

  // Empty charts so both panels show their axes before any explanation exists.
  create_lime_graph(lime_div_id, {})
  create_shap_graph(shap_div_id, {})

  const comm_manager = Jupyter.notebook.kernel.comm_manager
  const comm_config = comm_manager.new_comm("get_configurations")
  const comm_explain = comm_manager.new_comm("get_explain_instance")

  // Most recent per-class explanations returned by the kernel.
  let lime_res_arrays = null
  let shap_res_arrays = null

  // Redraws both charts for the class currently selected in the dropdown.
  function redraw_selected_class(){
      if (!lime_res_arrays || !shap_res_arrays) {
          return  // nothing has been explained yet
      }
      const selected_class = document.getElementById("class").value
      // Fall back to an empty chart when the selected class has no entry.
      create_lime_graph(lime_div_id, lime_res_arrays[selected_class] || {})
      create_shap_graph(shap_div_id, shap_res_arrays[selected_class] || {})
  }

  comm_config.on_msg(function(msg){
      const res = msg.content.data
      console.log(res)
      // Error replies carry only a message, so keep the current dropdown then.
      if (Array.isArray(res['class_names'])) {
          update_class_dropdown(res['class_names'])
      }
  })

  comm_explain.on_msg(function(msg){
      const res = msg.content.data
      console.log(res)
      // Error replies carry only a message; leave the existing charts alone.
      if (!res['lime_explanation'] || !res['shap_explanation']) {
          return
      }
      lime_res_arrays = res['lime_explanation']
      shap_res_arrays = res['shap_explanation']
      redraw_selected_class()
  })

  document.getElementById("configure").addEventListener("click", function(){
      const config = {'lime':{}, 'shap':{}}
      config.lime.discretize_continuous = document.getElementById("lime_discretize_continuous").checked
      config.shap.feature_perturbation = document.getElementById("shap_feature_perturbation").value
      console.log(config)
      comm_config.send({'config':config})
  })

  document.getElementById("explain").addEventListener("click", function(){
      const explain_instance = {'lime':{}, 'shap':{}}
      explain_instance.instance_index = document.getElementById("instance_index").value
      explain_instance.lime.n_samples = document.getElementById("lime_n_samples").value
      comm_explain.send({'explain_instance':explain_instance})
  })

  // 'change' fires when a different class is chosen; redraw both charts.
  document.getElementById("class").addEventListener("change", redraw_selected_class)
}

function removeAll(selectBox) {
    // Empties a select element by repeatedly removing its first option.
    for (;;) {
        if (!(selectBox.options.length > 0)) {
            break;
        }
        selectBox.remove(0);
    }
}

function update_class_dropdown(class_names){
    // Replaces the options of the "class" dropdown with one option per entry
    // of class_names (the entry is used as both value and label).
    // Anything that is not an array is ignored, so the existing options
    // survive a reply that carries no class names.
    if (!Array.isArray(class_names)) {
        return
    }

    const select_object = document.getElementById("class")
    removeAll(select_object)

    for (const class_name of class_names) {
        const option_object = document.createElement("option")
        option_object.value = class_name
        option_object.textContent = class_name
        select_object.appendChild(option_object)
    }
}

function create_lime_graph(div_id, scores){
  // Renders the LIME feature weights of one class into the element selected
  // by div_id (a "#id" selector). scores maps feature label -> weight.
  _create_contribution_graph(div_id, scores, 'LIME Explanation')
}

function create_shap_graph(div_id, scores){
  // Renders the SHAP values of one class into the element selected by div_id
  // (a "#id" selector). scores maps feature name -> SHAP value.
  _create_contribution_graph(div_id, scores, 'SHAP values')
}

function _create_contribution_graph(div_id, scores, caption){
  // Shared renderer for both charts: clears the target element and draws a
  // horizontal bar chart diverging from a central zero line, one bar per key
  // of scores, with the caption placed under the x axis.
  // Strings are built by concatenation on purpose: this script lives inside a
  // Python string.Template, where a dollar sign has special meaning.

  const layout = {
    'height' : 150,
    'width' : 500,
    'margins' : {'top' : 10, 'bottom' : 50, 'left' : 10, 'right' : 10}
  }
  const midpoint = layout.width/2
  const heightBar = 15

  // Tolerate a missing explanation and non-numeric entries instead of throwing.
  const safe_scores = scores || {}
  const features = Object.keys(safe_scores)
  const value_of = feature => {
    const value = Number(safe_scores[feature])
    return Number.isFinite(value) ? value : 0
  }

  const graph_div = document.getElementById(div_id.slice(1))
  graph_div.innerHTML = ''

  const svg = d3.select(div_id).append("svg")
  .attr('height', layout.height + layout.margins.top + layout.margins.bottom)
  .attr('width', layout.width + layout.margins.left + layout.margins.right)

  // A <g> has no x/y attributes; the margins are applied with a transform.
  const graph = svg.append('g')
  .attr('transform', 'translate(' + layout.margins.left + ',' + layout.margins.top + ')')

  // Vertical zero line.
  graph.append('line')
  .attr('x1', midpoint)
  .attr('x2', midpoint)
  .attr('y1', 0)
  .attr('y2', layout.height)
  .attr('stroke', 'red')
  .attr('stroke-width', 1)

  // Values outside [-1, 1] are clamped so bars and labels stay on the canvas;
  // the label still shows the real value.
  const xScale = d3.scaleLinear()
  .domain([-1,1])
  .range([-150, 150])
  .clamp(true)

  const yScale = d3.scaleBand()
  .domain(features)
  .range([0,layout.height])

  const xAxis = d3.axisBottom(xScale)

  // Bars: negative values extend left of the zero line, the rest to the right.
  graph.append('g')
  .selectAll('rect')
  .data(features)
  .join('rect')
  .attr('x', feature => midpoint + Math.min(0, xScale(value_of(feature))))
  .attr('y', feature => yScale(feature))
  .attr('height', heightBar)
  .attr('width', feature => Math.abs(xScale(value_of(feature))))
  .attr('fill', 'steelblue')

  // Labels: placed just past the free end of each bar, vertically centred on it.
  graph.append('g')
  .selectAll('text')
  .data(features)
  .join('text')
  .attr('x', feature => {
    const offset = value_of(feature) > 0 ? 5 : -5
    return midpoint + xScale(value_of(feature)) + offset
  })
  .attr('y', feature => yScale(feature) + heightBar/2)
  .attr('text-anchor', feature => value_of(feature) > 0 ? 'start' : 'end')
  .attr('dominant-baseline', 'central')
  .text(feature => feature + ' : ' + value_of(feature).toFixed(2))

  // X axis centred on the zero line, with the chart caption underneath.
  graph.append('g')
  .attr('transform', 'translate(' + midpoint + ',' + layout.height + ')')
  .call(xAxis)
  .append('text')
  .attr('transform', 'translate(' + (midpoint/8) + ',30)')
  .attr('font-size', 12)
  .attr('stroke', 'black')
  .text(caption)
}
// Build the initial charts and attach the comm and button handlers as soon as this script is evaluated.
main()
</script>
</head>


<body>
<!-- Configuration and instance-selection controls; the element ids are read by main(). -->
<div id="parameters_div" style="border:2px solid black;">


    <div id="lime_config" style="display:inline-block;border:1px solid black; width:500px" >
        <b>LIME Configurations:</b><br>
        <label for="lime_discretize_continuous">Discretize Continuous</label>
        <input type='checkbox' id='lime_discretize_continuous'><br>
    </div>

    <div id="shap_config" style="display:inline-block; border:1px solid black; width:470px" >
        <b>SHAP Configurations: </b><br>
        <label for="shap_feature_perturbation">Feature Perturbation</label>
        <select id='shap_feature_perturbation'>
        <option value='interventional' name='Interventional'>Interventional</option>
        <option value='correlation_dependent' name='Correlation Dependent'>Correlation Dependent</option>

        </select><br>
    </div>
    <center><input type="button" id="configure" value="Configure"> <br></center>

    <div id="lime_instance_specs" style="display:inline-block; border:1px solid black; width:500px" >
        <b>LIME Parameters:</b><br>
        <label for="lime_n_samples">N Samples  :</label>
        <input id='lime_n_samples' size=5 value=100>
    </div>

    <!-- Ids must be unique: this input has its own id rather than repeating the LIME one. -->
    <div id="shap_instance_specs" style="display:inline-block; border:1px solid black; width:470px" >
        <b>SHAP Parameters:</b><br>
        <label for="shap_n_samples">N Samples  :</label>
        <input id='shap_n_samples' size=5 value=100>

    </div>

    <hr size="5" width="5">
    <center>
    <label for="instance_index">Instance Index  :</label>
    <input type='text' id='instance_index' size=5><br>

    <b>Configure &amp; Enter Instance to explain</b><br>
    <br>
    <input type="button" id="explain" value="Explain">
    </center>
    <div id='output'>
        Explanation for class: <select id='class'></select>
    </div>


</div>
<br>
<br>
<!-- Chart containers filled by create_lime_graph / create_shap_graph. -->
<div id="lime_graph_div" style="display:inline-block;width:470px; border:1px solid black;">
</div>
<div id="shap_graph_div" style="display:inline-block;width:470px; border:1px solid black;">
</div>
</body>
</html>


""")

# Turn the template into a plain string (no substitution values are supplied)
# and show it as this cell's HTML output.
js_string = Template.safe_substitute(js_string)
html_string = HTML(data=js_string)
_ = display(html_string)