
Fix PySpark GBT Issue #2700

Open · wants to merge 2 commits into master
Conversation

weishengtoh commented Sep 22, 2022

Fix PySpark GBT Issue [fix #884, fix #2480]

ISSUE:

  • Running PySpark GBT models sometimes causes the shap package to fail with the error message:
    "The background dataset you provided does not cover all the leaves in the model, so TreeExplainer cannot run with the feature_perturbation="tree_path_dependent" option! Try providing a larger background dataset, no background dataset, or using feature_perturbation="interventional"."
  • Using feature_perturbation="interventional" as suggested does not work with PySpark models, because the predict function is not implemented for them:

# lines 1082 to 1085 in shap/shap/explainers/_tree.py
if self.model_type == "pyspark":
    # import pyspark
    # TODO: support predict for pyspark
    raise NotImplementedError("Predict with pyspark isn't implemented. Don't run 'interventional' as feature_perturbation.")
  • The feature_perturbation="tree_path_dependent" option is failing due to a check in the code that is meant to ensure that the background dataset lands in every leaf:
# lines 1031 to 1033 in shap/shap/explainers/_tree.py
# ensure that the passed background dataset lands in every leaf
if np.min(self.trees[i].node_sample_weight) <= 0:
    self.fully_defined_weighting = False
  • By rights, this check should pass in all cases, since no background dataset is required for feature_perturbation="tree_path_dependent".

  • In some cases for PySpark GBT models, fully_defined_weighting is incorrectly set to False. fully_defined_weighting is determined by the values of node_sample_weight, which are set by the code below:

# line 1199 in shap/shap/explainers/_tree.py
self.node_sample_weight[index] = node.impurityStats().count() #weighted count of element trough this node
  • node.impurityStats() returns a GiniCalculator, and the method .count() should return a float rather than an int (see the Spark source).

[screenshot of the relevant Spark source]

  • However, if you create a PySpark GBT model and obtain the values of node.impurityStats().count(), you will notice that they have been rounded down to int.
  • node.impurityStats().count() should return the same value as sum([e for e in node.impurityStats().stats()]) if you follow the image above. It is, however, rounding the values down, and in some cases values greater than 0 and less than 1 are rounded down to 0.
  • This causes self.fully_defined_weighting to be set to False, even when the values are clearly not zero (a sketch demonstrating this follows below).
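
A minimal sketch (not part of the PR) of how the discrepancy can be observed on a fitted GBTClassificationModel named model, assuming PySpark's internal _call_java helper is available on the per-tree models:

# Inspect the root node of the first tree through PySpark's Java bridge.
# `_call_java` is an internal helper, so this is illustrative, not a supported API.
root = model.trees[0]._call_java("rootNode")   # Java Node object
stats = root.impurityStats()                   # Java ImpurityCalculator
print(stats.count())                           # rounded-down value shap currently uses
print(sum(e for e in stats.stats()))           # exact float sum, as in the proposed fix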

SOLUTION:

  • Avoid using node.impurityStats().count(). Replace it with sum([e for e in node.impurityStats().stats()]), which computes the same quantity but retains the value as a float.

venser12 commented Nov 13, 2023

@slundberg @thatlittleboy @CloseChoice Could anyone merge this PR?

Thanks!

CloseChoice (Collaborator) commented Nov 13, 2023

@venser12 Thanks for following this up. Could you please add a test for this?
You could write the model to a temporary directory and load it back from there for the test.
It would also be great if you could resolve the conflicts.

I will take a closer look later this week.
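
A minimal pytest sketch of the suggested round-trip test (illustrative only: the data and fixture usage are assumptions, and running it needs a working Java/Spark environment):

import numpy as np
import pytest

pytest.importorskip("pyspark")

from pyspark.ml.classification import GBTClassificationModel, GBTClassifier
from pyspark.ml.feature import VectorAssembler
from pyspark.sql import SparkSession


def test_pyspark_gbt_loaded_model(tmp_path):
    spark = SparkSession.builder.appName("shap-gbt-test").getOrCreate()
    try:
        df = spark.createDataFrame(
            [(float(i), float(i % 5), i % 2) for i in range(20)],
            ["a", "b", "label"],
        )
        features = VectorAssembler(inputCols=["a", "b"], outputCol="features").transform(df)
        model = GBTClassifier(featuresCol="features", labelCol="label").fit(features)

        # Round-trip through a temporary directory, as suggested above.
        path = str(tmp_path / "gbt_model")
        model.save(path)
        loaded = GBTClassificationModel.load(path)

        # Before the fix this could raise the "background dataset ... leaves" error.
        import shap
        explainer = shap.TreeExplainer(loaded)
        X = np.array(features.select("a", "b").collect())
        explainer.shap_values(X)
    finally:
        spark.stop()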

@@ -1196,7 +1196,7 @@ def buildTree(index, node):
             self.values[index] = [node.prediction()] #prediction for the node
         else:
             self.values[index] = [e for e in node.impurityStats().stats()] #for gini: NDarray(numLabel): 1 per label: number of item for each label which went through this node
-        self.node_sample_weight[index] = node.impurityStats().count() #weighted count of element trough this node
+        self.node_sample_weight[index] = sum([e for e in node.impurityStats().stats()]) #weighted count of element trough this node
Review comment (Collaborator):

Could you please explain why this is the correct way to do things? I am not deeply familiar with this code.

Reply:

You can find the answer above.

Reply:

It is because

 node.impurityStats().count()

rounds the values in PySpark models, so it returns zero for values that are not actually zero.

Reply:

If you can, please merge this PR. Thanks.

venser12 commented Nov 13, 2023

@CloseChoice, you can find the explanation at the top.

Here you have a test:

from pyspark.sql import SparkSession
import random

# Create the Spark session
spark = SparkSession.builder.appName("Shap").getOrCreate()

# Create sample data
data = [(random.randint(1, 100), random.randint(1, 50), random.randint(0, 1)) for _ in range(10)]

# Create the Spark DataFrame
df = spark.createDataFrame(data, ["numeric", "numeric_2", "label"])

from pyspark.ml.classification import GBTClassifier
from pyspark.ml.feature import VectorAssembler
from pyspark.ml import Pipeline

# Assemble the feature columns
assembler = VectorAssembler(inputCols=["numeric", "numeric_2"], outputCol="features")

# GBTClassifier
gbt_classifier = GBTClassifier(featuresCol="features", labelCol="label")

# Pipeline with one stage per step
pipeline = Pipeline(stages=[assembler, gbt_classifier])

# Train, extract the fitted model, and save it
pipeline = pipeline.fit(df)
model = pipeline.stages[-1]
model.save("path_to_model")

import shap
import numpy as np
from pyspark.ml.classification import GBTClassificationModel

# Load the model back and compute SHAP values for one sample
loaded_model = GBTClassificationModel.load("path_to_model")
explainer = shap.Explainer(loaded_model)
explainer.shap_values(np.array(df.select("numeric", "numeric_2").collect()[0]))

mriomoreno commented Nov 13, 2023

@CloseChoice, I don't know why this branch has conflicts. We only changed one line. I tried to merge it manually and it worked:

[screenshot of the successful manual merge]

CloseChoice (Collaborator) commented Nov 13, 2023

Why do you use GBTClassifier for fitting and GBTClassificationModel for loading?

To fix the conflicts, please do:

git pull upstream master

(or whatever you set the GitHub remote to).

Edit:
If I am not mistaken, one would need to install Hadoop and set the HADOOP_HOME variable correctly to test this locally. Since we are not doing this, I would suggest that we do not test this in our CI.

mriomoreno commented Nov 14, 2023

@CloseChoice, I used GBTClassificationModel because the fitted Pipeline stage for the model is a GBTClassificationModel instead of a GBTClassifier:

[screenshot of the fitted pipeline stages]
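
To illustrate the estimator/model distinction, a hedged sketch (assuming the assembler, gbt_classifier, and df objects from the test above):

from pyspark.ml import Pipeline

# Fitting an estimator Pipeline yields a PipelineModel whose stages are fitted models.
fitted = Pipeline(stages=[assembler, gbt_classifier]).fit(df)
print(type(fitted.stages[-1]))
# <class 'pyspark.ml.classification.GBTClassificationModel'>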

Moreover, SHAP only supports GBTClassificationModel. It does not support GBTClassifier objects:

Exception: The passed model is not callable and cannot be analyzed directly with the given masker! Model: GBTClassifier_636da3917b9d

So even with GBTClassificationModel I'm getting the same error:

# GBTClassificationModel object
model = pipeline.stages[-1]

# Save the GBTClassificationModel
model.save("path_to_model")

import shap
import numpy as np
from pyspark.ml.classification import GBTClassificationModel

# Load the GBTClassificationModel
loaded_model = GBTClassificationModel.load("path_to_model")

[screenshot of the resulting error]

I get zero values for node weights that are clearly not zero, so we must change this in _tree.py.

I still think that changing that line will fix the error.

I also opened a new PR with no conflicts #3384

Thanks!

Development

Successfully merging this pull request may close these issues.

  • Shap on pyspark doesn't work with a loaded model
  • Error with Pyspark GBTClassifier
5 participants