
Fix PySpark loaded models #3384

Open — wants to merge 1 commit into base: master
Conversation

@mriomoreno commented Nov 13, 2023

Overview

fix #884 , fix #2480, fix #3383

ISSUE:

Running PySpark GBT models sometimes causes the shap package to fail with the error message:

The background dataset you provided does not cover all the leaves in the model, so TreeExplainer cannot run with the feature_perturbation="tree_path_dependent" option! Try providing a larger background dataset, no background dataset, or using feature_perturbation="interventional".

Using feature_perturbation="interventional" as suggested does not work either, because the predict function is not implemented for PySpark models:

# lines 1082 to 1085 in shap/shap/explainers/_tree.py
if self.model_type == "pyspark":
    # import pyspark
    # TODO: support predict for pyspark
    raise NotImplementedError("Predict with pyspark isn't implemented. Don't run 'interventional' as feature_perturbation.")

The feature_perturbation="tree_path_dependent" option is failing due to a check in the code that is meant to ensure that the background dataset lands in every leaf.

# lines 1031 to 1033 in shap/shap/explainers/_tree.py
# ensure that the passed background dataset lands in every leaf
if np.min(self.trees[i].node_sample_weight) <= 0:
    self.fully_defined_weighting = False

By rights this check should pass in all cases, considering that no background dataset is required for feature_perturbation="tree_path_dependent".

In some cases for PySpark GBT models, fully_defined_weighting is incorrectly set to False. It is determined by the values of node_sample_weight, which are populated by the code below:

# line 1199 in shap/shap/explainers/_tree.py
self.node_sample_weight[index] = node.impurityStats().count()  # weighted count of elements through this node

node.impurityStats() returns a GiniCalculator, and its .count() method should return a float rather than an int.
See the Spark source:

[screenshot of the relevant Spark source omitted]

However, if you create a PySpark GBT model and obtain the values from node.impurityStats().count(), you will notice that the values have been rounded down to int.

Following the image above, node.impurityStats().count() should return the same value as sum([e for e in node.impurityStats().stats()]). It is, however, rounding the values down, and in some cases values greater than 0 and less than 1 are rounded down to 0.

This causes self.fully_defined_weighting to be set to False, even when the values are clearly not zero.
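The failure mode can be illustrated with plain numpy. The weights below are made-up values for one tree (they are not taken from a real model): a weighted count of 0.6 truncated to an int becomes 0, which trips the `<= 0` check quoted above.

```python
import numpy as np

# Hypothetical weighted counts for the nodes of one tree; the 0.6
# could arise from instance weighting in a GBT.
true_weights = [10.0, 9.4, 0.6]

# What node.impurityStats().count() effectively returned here:
# each Double truncated to an int on the way back to Python.
rounded_weights = np.array([int(w) for w in true_weights])  # [10, 9, 0]

# The check from shap/explainers/_tree.py (lines 1031 to 1033):
fully_defined_weighting = not (np.min(rounded_weights) <= 0)
print(fully_defined_weighting)  # False: 0.6 was rounded down to 0

# With the float weights preserved, the same check passes:
print(not (np.min(np.array(true_weights)) <= 0))  # True
```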

SOLUTION

Avoid using node.impurityStats().count(). Replace it with sum([e for e in node.impurityStats().stats()]), which computes the same quantity but retains the value as a float.
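A minimal sketch of the proposed replacement, using a made-up stub in place of the real Py4J node object (`StubImpurityStats` is illustrative, not a shap or Spark class; `stats()` here holds per-class weighted counts, as described above for a GiniCalculator):

```python
class StubImpurityStats:
    """Stand-in for the object returned by node.impurityStats()."""

    def __init__(self, per_class_weights):
        self._stats = per_class_weights

    def stats(self):
        return list(self._stats)

    def count(self):
        # Mimics the rounding observed with older pyspark versions:
        # the Double is truncated to an int.
        return int(sum(self._stats))


node_stats = StubImpurityStats([0.4, 0.2])  # weighted counts per class

old_value = node_stats.count()                   # truncated to 0
new_value = sum(e for e in node_stats.stats())   # keeps the float, ~0.6

print(old_value, new_value)
```

With the old expression this node would fail the `np.min(...) <= 0` check; with the float sum it does not.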

@mriomoreno (Author)

@CloseChoice I created another PR with no conflicts; could you please merge it? Thanks!

@mriomoreno (Author)

@connortann , could we merge this PR? The solution is pretty easy.

Thanks

@mriomoreno (Author) commented Nov 14, 2023

This PR is a new version of the old PR: #2700

@connortann connortann added the bug Indicates an unexpected problem or unintended behaviour label Nov 14, 2023
@connortann (Collaborator)

Thanks for the PR! I think there are two things we need to fix.

Firstly, it looks like this change breaks an existing test: tests/explainers/test_tree.py::test_pyspark_regression_decision_tree.

Secondly, we need to create a regression test for the issue that this PR fixes. I tried running a unit test based on your example in the other thread: #2700 (comment)

However, this doesn't reproduce the issue: the test actually passes on master. Here's the test I tried:

# test_tree.py, a new test (probably to be put around line 286)

def test_pyspark_loaded_gbt(configure_pyspark_python, tmp_path, random_seed):
    pytest.importorskip("pyspark")
    pytest.importorskip("pyspark.ml")

    from pyspark.ml import Pipeline
    from pyspark.ml.classification import GBTClassificationModel, GBTClassifier
    from pyspark.ml.feature import VectorAssembler
    from pyspark.sql import SparkSession


    rs = np.random.RandomState(seed=random_seed)

    # Create Spark Session
    spark = SparkSession.builder.appName("Shap").getOrCreate()

    # Create DataFrame
    data = [(rs.randint(1, 100), rs.randint(1, 50), rs.randint(0, 2)) for _ in range(10)]  # randint's upper bound is exclusive, so labels fall in {0, 1}

    # Create Spark DataFrame
    df = spark.createDataFrame(data, ["numeric", "numeric_2", "label"])

    # Assembling columns
    assembler = VectorAssembler(inputCols=["numeric", "numeric_2"], outputCol='features')

    # GBTClassifier
    gbt_classifier = GBTClassifier(featuresCol="features", labelCol="label")

    # Pipeline for each step
    pipeline = Pipeline(stages=[assembler, gbt_classifier])

    # Train
    pipeline = pipeline.fit(df)
    model = pipeline.stages[-1]
    model_path = str(tmp_path / "model")
    model.save(model_path)

    loaded_model = GBTClassificationModel.load(model_path)

    explainer = shap.Explainer(loaded_model)
    explainer.shap_values(np.array(df.select("numeric", "numeric_2").collect()[0]))

@venser12 commented Nov 14, 2023

@connortann which version of pyspark do you have?

@connortann (Collaborator)

I tested locally with 3.4.0. I note the CI tests are using 3.5.0.

@mriomoreno (Author) commented Nov 15, 2023

@connortann can you show me the output of this?

pytest.importorskip("pyspark")
pytest.importorskip("pyspark.ml")

from pyspark.ml import Pipeline
from pyspark.ml.classification import GBTClassificationModel, GBTClassifier
from pyspark.ml.feature import VectorAssembler
from pyspark.sql import SparkSession

model_path = str(tmp_path / "model")

loaded_model = GBTClassificationModel.load(model_path)
explainer = shap.Explainer(loaded_model)

for i in range(20):
    print(explainer.model.trees[i].node_sample_weight)

I get zero values; maybe you are not getting zero values.

[screenshot of output containing zero values omitted]

I have pyspark 2.4.7 and shap 0.37.0; it is possible that the zero-values error has been fixed in newer versions.

@connortann (Collaborator) commented Nov 15, 2023

I do not get zero values. Here's the full test I ran, including generating the data with a fixed random seed:

def test_pyspark_loaded_gbt(configure_pyspark_python, tmp_path):
    pytest.importorskip("pyspark")
    pytest.importorskip("pyspark.ml")

    from pyspark.ml import Pipeline
    from pyspark.ml.classification import GBTClassificationModel, GBTClassifier
    from pyspark.ml.feature import VectorAssembler
    from pyspark.sql import SparkSession

    spark = SparkSession.builder.appName("Shap").getOrCreate()

    # Create DataFrame
    rs = np.random.RandomState(seed=0)
    data = [(rs.randint(1, 100), rs.randint(1, 50), rs.randint(0, 2)) for _ in range(10)]  # labels in {0, 1}
    df = spark.createDataFrame(data, ["numeric", "numeric_2", "label"])

    # Train model
    assembler = VectorAssembler(inputCols=["numeric", "numeric_2"], outputCol='features')
    gbt_classifier = GBTClassifier(featuresCol="features", labelCol="label")
    pipeline = Pipeline(stages=[assembler, gbt_classifier])
    pipeline = pipeline.fit(df)
    model = pipeline.stages[-1]
    
    # Save and reload
    model_path = str(tmp_path / "model")
    model.save(model_path)
    loaded_model = GBTClassificationModel.load(model_path)
    explainer = shap.Explainer(loaded_model)

    for i in range(20):
        print(explainer.model.trees[i].node_sample_weight)

    explainer.shap_values(np.array(df.select("numeric","numeric_2").collect()[0]))
    assert False  # So output is displayed

Here's the output:

[10.]
[10.  8.  2.]
[10.  7.  3.]
[10.]
[10.]
[10.]
[10.]
[10.  2.  8.]
[10.]
[10.]
[10.  2.  8.]
[10.  2.  8.]
[10.  9.  1.]
[10.  1.  9.]
[10.  7.  3.]
[10.]
[10.]
[10.  4.  6.]
[10.  3.  7.]
[10.  3.  7.]

It's important that the shap code works with recent versions of pyspark, so the CI failures on this PR would need to be fixed before we can consider merging a patch. Please also include the regression test above in this PR.

@mriomoreno (Author) commented Nov 15, 2023

@connortann thanks! I think the error is in the old pyspark version.

Just to make sure, can you show me the next 3 outputs? I'm sorry that I cannot upgrade my pyspark version to test it on my own.

[e for e in loaded_model.trees[12]._call_java("rootNode").rightChild().impurityStats().stats()]
[e for e in loaded_model.trees[12]._call_java("rootNode").leftChild().impurityStats().stats()]
[e for e in loaded_model.trees[12]._call_java("rootNode").impurityStats().stats()]

So we can see whether, in your 12th tree, values between 0 and 1 are being rounded down to zero. If you only get whole values like 1.0, the error may still persist.

Thanks!

@connortann (Collaborator)

A good way to check those outputs is to include the test in your branch and push your changes. The tests will run on CI and you will be able to inspect the output of any print statements there.

In any case here are those three lists when I run locally:

[1.0, -0.2392929041640417, 0.05726109398326118]
[9.0, -2.153636137476374, 0.5153498458493506]
[10.0, -2.3929290416404156, 0.5726109398326118]

@mriomoreno (Author) commented Nov 15, 2023

@connortann thanks, so the error likely comes from an older pyspark version:

You always get the whole number of samples in each node, instead of the weighted sum.

How can I run the test in CI on my own?

@connortann (Collaborator) commented Nov 15, 2023

Just push more commits to the branch associated with this PR (which is your fork mriomoreno:master), and you will see the tests run on GitHub Actions. The most recent runs will be shown below:

[screenshot of recent GitHub Actions runs omitted]

@connortann connortann added the awaiting feedback Indicates that further information is required from the issue creator label Nov 29, 2023