Skip to content

Commit

Permalink
feat: Add the option to get Feature Contributions in LightGBMBooster …
Browse files Browse the repository at this point in the history
…used by LightGBMRanker (#791)

* Allow LightGBMRanker to compute feature SHAP values

* Move featureShapGetter into a trait that can potentially be reused by other models

* Fix the data buffers used so they are the SHAP-specific ones, and add tests for getShapFeatures in LightGBMRanker

* Fix issues with merge conflict resolution

* Refactor to share predictForMat and predictForCSR from Score, Shap and LeafIndex usage

* Fix compilation issue
  • Loading branch information
JoanFM committed Feb 11, 2020
1 parent 875f89d commit f702921
Show file tree
Hide file tree
Showing 4 changed files with 95 additions and 55 deletions.
104 changes: 54 additions & 50 deletions src/main/scala/com/microsoft/ml/spark/lightgbm/LightGBMBooster.scala
Expand Up @@ -76,6 +76,24 @@ protected class BoosterHandler(model: String) {
}
}
}

/** Per-thread native output buffer for SHAP (feature contribution) predictions.
  * C_API_PREDICT_CONTRIB writes numFeatures + 1 doubles per row — one
  * contribution per feature plus the trailing expected-value (bias) term — so
  * the buffer is sized numFeatures + 1 to avoid a native heap overrun.
  */
val shapDataOutPtr: ThreadLocal[DoubleNativePtrHandler] = {
  new ThreadLocal[DoubleNativePtrHandler] {
    override def initialValue(): DoubleNativePtrHandler = {
      new DoubleNativePtrHandler(lightgbmlib.new_doubleArray(numFeatures + 1))
    }
  }
}

/** Per-thread native out-length pointer for SHAP predictions.
  * Initialized to numFeatures + 1 to match the contribution output size
  * (per-feature contributions plus the expected-value term); the native
  * predict call overwrites it with the actual output length.
  */
val shapDataLengthLongPtr: ThreadLocal[LongLongNativePtrHandler] = {
  new ThreadLocal[LongLongNativePtrHandler] {
    override def initialValue(): LongLongNativePtrHandler = {
      val dataLongLengthPtr = lightgbmlib.new_int64_tp()
      lightgbmlib.int64_tp_assign(dataLongLengthPtr, numFeatures + 1)
      new LongLongNativePtrHandler(dataLongLengthPtr)
    }
  }
}

val featureImportanceOutPtr: ThreadLocal[DoubleNativePtrHandler] = {
new ThreadLocal[DoubleNativePtrHandler] {
Expand All @@ -101,6 +119,7 @@ protected class BoosterHandler(model: String) {
// Native prediction-kind selectors passed to the LGBM_BoosterPredict* calls.
lazy val rawScoreConstant: Int = lightgbmlibConstants.C_API_PREDICT_RAW_SCORE
lazy val normalScoreConstant: Int = lightgbmlibConstants.C_API_PREDICT_NORMAL
lazy val leafIndexPredictConstant: Int = lightgbmlibConstants.C_API_PREDICT_LEAF_INDEX
// Explicit result type added for consistency with the sibling constants above.
lazy val contribPredictConstant: Int = lightgbmlibConstants.C_API_PREDICT_CONTRIB

// Native dtype selectors for the indices/values arrays handed to the C API.
lazy val dataInt32bitType: Int = lightgbmlibConstants.C_API_DTYPE_INT32
lazy val data64bitType: Int = lightgbmlibConstants.C_API_DTYPE_FLOAT64
Expand Down Expand Up @@ -175,16 +194,34 @@ class LightGBMBooster(val model: String) extends Serializable {
if (raw) boosterHandler.rawScoreConstant
else boosterHandler.normalScoreConstant
features match {
case dense: DenseVector => predictScoreForMat(dense.toArray, kind, classification)
case sparse: SparseVector => predictScoreForCSR(sparse, kind, classification)
case dense: DenseVector => predictForMat(dense.toArray, kind,
boosterHandler.scoredDataLengthLongPtr.get().ptr, boosterHandler.scoredDataOutPtr.get().ptr)
case sparse: SparseVector => predictForCSR(sparse, kind,
boosterHandler.scoredDataLengthLongPtr.get().ptr, boosterHandler.scoredDataOutPtr.get().ptr)
}
predScoreToArray(classification, boosterHandler.scoredDataOutPtr.get().ptr, kind)
}

/** Predicts the leaf indices for a single row of features.
  *
  * @param features the input feature vector (dense or sparse)
  * @return the predicted leaf index for each tree in the ensemble
  */
def predictLeaf(features: Vector): Array[Double] = {
  val predictKind = boosterHandler.leafIndexPredictConstant
  val lengthPtr = boosterHandler.leafIndexDataLengthLongPtr.get().ptr
  val outputPtr = boosterHandler.leafIndexDataOutPtr.get().ptr
  features match {
    case sparse: SparseVector =>
      predictForCSR(sparse, predictKind, lengthPtr, outputPtr)
    case dense: DenseVector =>
      predictForMat(dense.toArray, predictKind, lengthPtr, outputPtr)
  }
  predLeafToArray(outputPtr)
}

/** Computes SHAP feature contribution values for a single row of features.
  *
  * @param features the input feature vector (dense or sparse)
  * @return the per-feature contribution values read from the native buffer
  */
def featuresShap(features: Vector): Array[Double] = {
  val contribKind = boosterHandler.contribPredictConstant
  val lengthPtr = boosterHandler.shapDataLengthLongPtr.get().ptr
  val outputPtr = boosterHandler.shapDataOutPtr.get().ptr
  features match {
    case sparse: SparseVector =>
      predictForCSR(sparse, contribKind, lengthPtr, outputPtr)
    case dense: DenseVector =>
      predictForMat(dense.toArray, contribKind, lengthPtr, outputPtr)
  }
  shapToArray(outputPtr)
}

// Number of classes reported by the native booster (1 for regression/ranking).
lazy val numClasses: Int = boosterHandler.numClasses
Expand All @@ -197,7 +234,9 @@ class LightGBMBooster(val model: String) extends Serializable {

// Boosting rounds = total trees divided by trees produced per iteration.
lazy val numIterations: Int = numTotalModel / numModelPerIteration

protected def predictScoreForCSR(sparseVector: SparseVector, kind: Int, classification: Boolean): Array[Double] = {
protected def predictForCSR(sparseVector: SparseVector, kind: Int,
dataLengthLongPtr: SWIGTYPE_p_long_long,
dataOutPtr: SWIGTYPE_p_double): Unit = {
val numCols = sparseVector.size

val datasetParams = "max_bin=255"
Expand All @@ -210,12 +249,12 @@ class LightGBMBooster(val model: String) extends Serializable {
sparseVector.numNonzeros,
boosterHandler.boosterPtr, dataInt32bitType, data64bitType, 2, numCols,
kind, -1, datasetParams,
boosterHandler.scoredDataLengthLongPtr.get().ptr, boosterHandler.scoredDataOutPtr.get().ptr), "Booster Predict")

predScoreToArray(classification, boosterHandler.scoredDataOutPtr.get().ptr, kind)
dataLengthLongPtr, dataOutPtr), "Booster Predict")
}

protected def predictScoreForMat(row: Array[Double], kind: Int, classification: Boolean): Array[Double] = {
protected def predictForMat(row: Array[Double], kind: Int,
dataLengthLongPtr: SWIGTYPE_p_long_long,
dataOutPtr: SWIGTYPE_p_double): Unit = {
val data64bitType = boosterHandler.data64bitType

val numCols = row.length
Expand All @@ -228,48 +267,8 @@ class LightGBMBooster(val model: String) extends Serializable {
row, boosterHandler.boosterPtr, data64bitType,
numCols,
isRowMajor, kind,
-1, datasetParams, boosterHandler.scoredDataLengthLongPtr.get().ptr, boosterHandler.scoredDataOutPtr.get().ptr),
-1, datasetParams, dataLengthLongPtr, dataOutPtr),
"Booster Predict")
predScoreToArray(classification, boosterHandler.scoredDataOutPtr.get().ptr, kind)
}

/** Predicts leaf indices for a single sparse row via the CSR native entry point.
  *
  * @param sparseVector the sparse feature row to score
  * @return one leaf index per tree in the ensemble
  */
protected def predictLeafForCSR(sparseVector: SparseVector): Array[Double] = {
  val numCols = sparseVector.size
  val datasetParams = "max_bin=255 is_pre_partition=True"
  val indicesType = boosterHandler.dataInt32bitType
  val valuesType = boosterHandler.data64bitType
  val leafLenPtr = boosterHandler.leafIndexDataLengthLongPtr.get().ptr
  val leafOutPtr = boosterHandler.leafIndexDataOutPtr.get().ptr

  LightGBMUtils.validate(
    lightgbmlib.LGBM_BoosterPredictForCSRSingle(
      sparseVector.indices, sparseVector.values,
      sparseVector.numNonzeros,
      boosterHandler.boosterPtr, indicesType, valuesType, 2, numCols,
      boosterHandler.leafIndexPredictConstant, -1, datasetParams,
      leafLenPtr,
      leafOutPtr), "Booster Predict Leaf")

  predLeafToArray(leafOutPtr)
}

/** Predicts leaf indices for a single dense row via the matrix native entry point.
  *
  * @param row the dense feature row to score
  * @return one leaf index per tree in the ensemble
  */
protected def predictLeafForMat(row: Array[Double]): Array[Double] = {
  val numCols = row.length
  val isRowMajor = 1
  val datasetParams = "max_bin=255"
  val valuesType = boosterHandler.data64bitType
  val leafLenPtr = boosterHandler.leafIndexDataLengthLongPtr.get().ptr
  val leafOutPtr = boosterHandler.leafIndexDataOutPtr.get().ptr

  LightGBMUtils.validate(
    lightgbmlib.LGBM_BoosterPredictForMatSingle(
      row, boosterHandler.boosterPtr, valuesType,
      numCols,
      isRowMajor, boosterHandler.leafIndexPredictConstant,
      -1, datasetParams, leafLenPtr,
      leafOutPtr),
    "Booster Predict Leaf")

  predLeafToArray(leafOutPtr)
}

def saveNativeModel(session: SparkSession, filename: String, overwrite: Boolean): Unit = {
Expand Down Expand Up @@ -329,4 +328,9 @@ class LightGBMBooster(val model: String) extends Serializable {
(0 until numTotalModel).map(modelNum =>
lightgbmlib.doubleArray_getitem(leafIndexDataOutPtr, modelNum)).toArray
}

// Copies the native SHAP output buffer into a JVM array, one value per feature.
// NOTE(review): only numFeatures values are read here, but LightGBM's
// C_API_PREDICT_CONTRIB output typically also carries a trailing
// expected-value (bias) term — confirm whether numFeatures + 1 values
// should be surfaced to callers.
private def shapToArray(shapDataOutPtr: SWIGTYPE_p_double): Array[Double] = {
(0 until numFeatures).map(featNum =>
lightgbmlib.doubleArray_getitem(shapDataOutPtr, featNum)).toArray
}
}
Expand Up @@ -80,11 +80,20 @@ class LightGBMRanker(override val uid: String)
/** Returns a copy of this ranker with the extra params applied. */
override def copy(extra: ParamMap): LightGBMRanker = defaultCopy(extra)
}

/** Mixin that exposes SHAP feature-contribution values computed by an
  * underlying [[LightGBMBooster]].
  */
trait HasFeatureShapGetters {
  /** The trained booster that computes the contributions. */
  val model: LightGBMBooster

  /** Computes SHAP feature contributions for a single feature vector. */
  def getFeatureShaps(features: Vector): Array[Double] = model.featuresShap(features)
}

/** Model produced by [[LightGBMRanker]]. */
@InternalWrapper
class LightGBMRankerModel(override val uid: String, model: LightGBMBooster, labelColName: String,
class LightGBMRankerModel(override val uid: String, override val model: LightGBMBooster, labelColName: String,
featuresColName: String, predictionColName: String)
extends RankerModel[Vector, LightGBMRankerModel]
with HasFeatureShapGetters with HasFeatureImportanceGetters
with ConstructorWritable[LightGBMRankerModel] {

// Update the underlying Spark ML com.microsoft.ml.spark.core.serialize.params
Expand All @@ -106,15 +115,13 @@ class LightGBMRankerModel(override val uid: String, model: LightGBMBooster, labe
override def objectsToSave: List[Any] =
List(uid, model, getLabelCol, getFeaturesCol, getPredictionCol)

// Number of input features, delegated to the underlying native booster.
override def numFeatures: Int = model.numFeatures

/** Saves the underlying native LightGBM model to the given path.
  *
  * @param filename  destination path for the native model file
  * @param overwrite whether an existing file at the path may be replaced
  */
def saveNativeModel(filename: String, overwrite: Boolean): Unit = {
  val activeSession = SparkSession.builder().getOrCreate()
  model.saveNativeModel(activeSession, filename, overwrite)
}

/** Returns the per-feature importances from the underlying booster.
  *
  * @param importanceType the importance measure to report (e.g. split or gain)
  */
def getFeatureImportances(importanceType: String): Array[Double] =
  model.getFeatureImportances(importanceType)

// Accessor for the wrapped native booster.
def getModel: LightGBMBooster = this.model
}

Expand Down
Expand Up @@ -90,6 +90,12 @@ trait LightGBMTestUtils extends TestBase {
assert(splitLength == gainLength && splitLength == featuresLength)
}

/** Asserts that the model's SHAP output has one entry per feature column in df. */
def assertFeatureShapLengths(fitModel: Model[_] with HasFeatureShapGetters, features: Vector, df: DataFrame): Unit = {
  val expectedLength = df.select(featuresCol).first().getAs[Vector](featuresCol).size
  val actualLength = fitModel.getFeatureShaps(features).length
  assert(actualLength == expectedLength)
}

lazy val numPartitions = 2
val startingPortIndex = 0
private var portIndex = startingPortIndex
Expand Down
Expand Up @@ -10,6 +10,7 @@ import com.microsoft.ml.spark.lightgbm.split1.LightGBMTestUtils
import com.microsoft.ml.spark.lightgbm.{LightGBMRanker, LightGBMRankerModel, LightGBMUtils, TrainUtils}
import org.apache.spark.SparkException
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.util.MLReadable
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.{col, monotonically_increasing_id, _}
Expand Down Expand Up @@ -104,6 +105,28 @@ class VerifyLightGBMRanker extends Benchmarks with EstimatorFuzzing[LightGBMRank
df.withColumn(queryCol, concat(lit("str_"), col(queryCol))))
}

test("Verify LightGBM Ranker feature shaps") {
  val baseDF = Seq(
    (0L, 1, 1.2, 2.3),
    (0L, 0, 3.2, 2.35),
    (1L, 0, 1.72, 1.39),
    (1L, 1, 1.82, 3.8)
  ).toDF(queryCol, labelCol, "f1", "f2")

  val df = new VectorAssembler()
    .setInputCols(Array("f1", "f2"))
    .setOutputCol(featuresCol)
    .transform(baseDF)
    .select(queryCol, labelCol, featuresCol)

  // Dot-call instead of deprecated postfix operator syntax.
  val fitModel = baseModel.setEvalAt((1 to 3).toArray).fit(df)

  val featuresInput = Vectors.dense(Array[Double](0.0, 0.0))
  assert(fitModel.numFeatures == 2)
  assertFeatureShapLengths(fitModel, featuresInput, df)
  // SHAP contributions should sum to the model prediction; compare within a
  // tolerance rather than by exact Double equality, which is flaky under
  // floating-point accumulation-order differences.
  val shapSum = fitModel.getFeatureShaps(featuresInput).sum
  assert(math.abs(fitModel.predict(featuresInput) - shapSum) < 1e-7)
}

test("verify cardinality counts: int") {
val counts = TrainUtils.countCardinality(Seq(1, 1, 2, 2, 2, 3))

Expand Down

0 comments on commit f702921

Please sign in to comment.