Fix PySpark GBT Issue #2700
base: master
Conversation
@slundberg @thatlittleboy @CloseChoice Could anyone merge this PR? Thanks!
@venser12 Thanks for following this up. Could you please add a test for this? I will take a closer look later this week.
shap/explainers/_tree.py (Outdated)

```diff
@@ -1196,7 +1196,7 @@ def buildTree(index, node):
             self.values[index] = [node.prediction()]  # prediction for the node
         else:
             self.values[index] = [e for e in node.impurityStats().stats()]  # for gini: NDarray(numLabel), 1 per label: number of items for each label which went through this node
-        self.node_sample_weight[index] = node.impurityStats().count()  # weighted count of elements through this node
+        self.node_sample_weight[index] = sum([e for e in node.impurityStats().stats()])  # weighted count of elements through this node
```
Could you please explain why this is the correct way to do things? I am not deeply familiar with this code.
You can find the answer above.
It is because `node.impurityStats().count()` rounds the values in PySpark models, so it returns zero where the true values are not zero.
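To illustrate the rounding with a minimal sketch (`FakeImpurityStats` is a hypothetical stand-in for Spark's `GiniCalculator`; it is not the real PySpark/py4j object, it just mimics the truncation described in this thread):

```python
# Hypothetical stand-in for Spark's GiniCalculator, only to illustrate
# the truncation described above; not the real PySpark/py4j object.
class FakeImpurityStats:
    def __init__(self, stats):
        self._stats = stats

    def stats(self):
        # per-label weighted counts that passed through this node
        return self._stats

    def count(self):
        # mimics the rounding: the float sum is truncated to an integer
        return int(sum(self._stats))


node_stats = FakeImpurityStats([0.3, 0.4])  # fractional sample weights in a node
print(node_stats.count())       # 0   -> node_sample_weight becomes 0
print(sum(node_stats.stats()))  # 0.7 -> the non-zero float weight the fix keeps
```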
If you can, please merge this PR. Thanks.
@CloseChoice, you can find the explanation at the top. Here you have a test:

```python
import random

import numpy as np
from pyspark.sql import SparkSession
from pyspark.ml import Pipeline
from pyspark.ml.classification import GBTClassifier, GBTClassificationModel
from pyspark.ml.feature import VectorAssembler
import shap

# Create Spark session
spark = SparkSession.builder.appName("Shap").getOrCreate()

# Create random data
data = [(random.randint(1, 100), random.randint(1, 50), random.randint(0, 1)) for _ in range(10)]

# Create Spark DataFrame
df = spark.createDataFrame(data, ["numeric", "numeric_2", "label"])

# Assemble the feature columns
assembler = VectorAssembler(inputCols=["numeric", "numeric_2"], outputCol="features")

# GBT classifier
gbt_classifier = GBTClassifier(featuresCol="features", labelCol="label")

# Pipeline with both steps
pipeline = Pipeline(stages=[assembler, gbt_classifier])

# Train
pipeline = pipeline.fit(df)
model = pipeline.stages[-1]
model.save("path_to_model")

# Reload the fitted model and explain it
loaded_model = GBTClassificationModel.load("path_to_model")
explainer = shap.Explainer(loaded_model)
explainer.shap_values(np.array(df.select("numeric", "numeric_2").collect()[0]))
```
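To turn that last call into an actual check, one could append something along these lines (a sketch; that the values must be non-zero is the claim made in this thread, not a guaranteed property of randomly generated data):

```python
# Continuing from the repro above: explain all ten rows and check the result.
X = np.array(df.select("numeric", "numeric_2").collect())
shap_values = explainer.shap_values(X)

# Before the fix this path could trip the fully_defined_weighting check or
# yield all-zero attributions; afterwards the node weights stay float.
assert not np.allclose(shap_values, 0.0)
```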
@CloseChoice, I don't know why this branch has conflicts. We only changed one line. I tried to merge it manually and it worked.
Why do you use `model.save` followed by `GBTClassificationModel.load` instead of the fitted model directly? To fix the conflicts, please rebase this branch onto the latest master.
@CloseChoice, I used `model.save` and `GBTClassificationModel.load` to get the fitted model object. Moreover, SHAP only supports the fitted `GBTClassificationModel`; otherwise I get:

Exception: The passed model is not callable and cannot be analyzed directly with the given masker! Model: GBTClassifier_636da3917b9d

So with the trained pipeline, the relevant steps are:

```python
import shap
import numpy as np
from pyspark.ml.classification import GBTClassificationModel

# GBTClassificationModel object
model = pipeline.stages[-1]

# Save the GBTClassificationModel
model.save("path_to_model")

# Load the GBTClassificationModel
loaded_model = GBTClassificationModel.load("path_to_model")
```

For no apparent reason I get zero values when they are clearly not zero, so we must change this line in `_tree.py`. I still think that changing that line will fix the error. I also opened a new PR with no conflicts: #3384. Thanks!
Fix PySpark GBT Issue [fix #884, fix #2480]

ISSUE:

- `feature_perturbation="interventional"` as suggested does not work with PySpark models, as the `predict` function is not implemented for PySpark models.
- `feature_perturbation="tree_path_dependent"` is failing due to a check in the code that is meant to ensure that the background dataset lands in every leaf. This should by right pass in all cases, considering that no background dataset is required for `feature_perturbation="tree_path_dependent"`.
- In some cases for PySpark GBT models, `fully_defined_weighting` is incorrectly set to `False`. `fully_defined_weighting` is determined by the values of `node_sample_weight`, which are set by the line changed in this PR: `self.node_sample_weight[index] = node.impurityStats().count()`.
- `node.impurityStats()` returns a `GiniCalculator`, and the method `.count()` should return a `float` instead of an `int` (see the Spark source). If you inspect `node.impurityStats().count()`, you will notice that the values have been rounded down to `int`.
- `node.impurityStats().count()` should return the same values as `sum([e for e in node.impurityStats().stats()])` per the Spark source above. It is, however, rounding down the values, and in some cases values greater than 0 and less than 1 are rounded down to 0. This causes `self.fully_defined_weighting` to be set to `False`, even when the values are clearly not zero.

SOLUTION:

Remove `node.impurityStats().count()`. Replace it with `sum([e for e in node.impurityStats().stats()])`, which does exactly the same but retains the value as a `float`.
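For reference, a minimal sketch of how the fixed path is exercised (assuming `loaded_model` is the fitted `GBTClassificationModel` loaded earlier in this thread; the variable name is illustrative):

```python
# tree_path_dependent needs no background dataset, so constructing the
# explainer goes through the node_sample_weight / fully_defined_weighting
# logic this PR changes. `loaded_model` is the fitted GBTClassificationModel
# from the repro earlier in the thread.
import shap

explainer = shap.TreeExplainer(
    loaded_model,
    feature_perturbation="tree_path_dependent",
)
# With count(), fractional node weights were truncated to 0 and the explainer
# could reject the model; with sum(stats()), the weights remain floats.
```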