Skip to content

Commit

Permalink
fix: Possible multithreading issue when two scores may come in parall…
Browse files Browse the repository at this point in the history
…el they may not safely fill pointer values (microsoft#799)

* Move transient var to transient lazy val

* Try guaranteeing NativeLibrary is loaded everytime it is needed but only loaded once

* Forgot some lightGBMlibConstants

* Add freeing of resources, handle all memory in BoosterHandler

* Remove serialization from BoosterHandler, some documentation and minor changes. By boosterHandler not being serializable, it forces Serializable objects to declare it as transient

* Style forbigs implementing finalize method, Java does not guarantee order of finalization so it is dangerous. I wonder how is it guaranteed that the memory will be cleared though

* Add finalize method to call free for Native C++ allocated memory

* predict normal correct constant

* Change implementation to a ThreadLocal based synchronization
  • Loading branch information
JoanFM authored and ocworld committed Mar 24, 2020
1 parent d65a86b commit 350f0b0
Showing 1 changed file with 103 additions and 54 deletions.
157 changes: 103 additions & 54 deletions src/main/scala/com/microsoft/ml/spark/lightgbm/LightGBMBooster.scala
Expand Up @@ -9,97 +9,143 @@ import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector}
import org.apache.spark.sql.{SaveMode, SparkSession}

//scalastyle:off
protected abstract class NativePtrHandler[T](val ptr: T) {
protected def freeNativePtr(): Unit
override def finalize(): Unit = {
if (ptr != null) {
freeNativePtr()
}
}
}

protected class DoubleNativePtrHandler(ptr: SWIGTYPE_p_double) extends NativePtrHandler[SWIGTYPE_p_double](ptr) {
override protected def freeNativePtr(): Unit = {
lightgbmlib.delete_doubleArray(ptr)
}
}

protected class LongLongNativePtrHandler(ptr: SWIGTYPE_p_long_long) extends NativePtrHandler[SWIGTYPE_p_long_long](ptr) {
override protected def freeNativePtr(): Unit = {
lightgbmlib.delete_int64_tp(ptr)
}
}

/** Wraps the boosterPtr and guarantees that Native library is initialized
* everytime it is needed
* @param model The string serialized representation of the learner
*/
protected class BoosterHandler(model: String) {

LightGBMUtils.initializeNativeLibrary()

var boosterPtr: SWIGTYPE_p_void = {
getBoosterPtrFromModelString(model)
}

var scoredDataOutPtr: SWIGTYPE_p_double = {
lightgbmlib.new_doubleArray(numClasses)
val scoredDataOutPtr: ThreadLocal[DoubleNativePtrHandler] = {
new ThreadLocal[DoubleNativePtrHandler] {
override def initialValue(): DoubleNativePtrHandler = {
new DoubleNativePtrHandler(lightgbmlib.new_doubleArray(numClasses))
}
}
}

val scoredDataLengthLongPtr: ThreadLocal[LongLongNativePtrHandler] = {
new ThreadLocal[LongLongNativePtrHandler] {
override def initialValue(): LongLongNativePtrHandler = {
val dataLongLengthPtr = lightgbmlib.new_int64_tp()
lightgbmlib.int64_tp_assign(dataLongLengthPtr, 1)
new LongLongNativePtrHandler(dataLongLengthPtr)
}
}
}

val leafIndexDataOutPtr: ThreadLocal[DoubleNativePtrHandler] = {
new ThreadLocal[DoubleNativePtrHandler] {
override def initialValue(): DoubleNativePtrHandler = {
new DoubleNativePtrHandler(lightgbmlib.new_doubleArray(numTotalModel))
}
}
}

var scoredDataLengthLongPtr: SWIGTYPE_p_long_long = {
val dataLongLengthPtr = lightgbmlib.new_int64_tp()
lightgbmlib.int64_tp_assign(dataLongLengthPtr, 1)
dataLongLengthPtr
val leafIndexDataLengthLongPtr: ThreadLocal[LongLongNativePtrHandler] = {
new ThreadLocal[LongLongNativePtrHandler] {
override def initialValue(): LongLongNativePtrHandler = {
val dataLongLengthPtr = lightgbmlib.new_int64_tp()
lightgbmlib.int64_tp_assign(dataLongLengthPtr, numTotalModel)
new LongLongNativePtrHandler(dataLongLengthPtr)
}
}
}

var leafIndexDataOutPtr: SWIGTYPE_p_double = lightgbmlib.new_doubleArray(numTotalModel)
val featureImportanceOutPtr: ThreadLocal[DoubleNativePtrHandler] = {
new ThreadLocal[DoubleNativePtrHandler] {
override def initialValue(): DoubleNativePtrHandler = {
new DoubleNativePtrHandler(lightgbmlib.new_doubleArray(numFeatures))
}
}
}

var leafIndexDataLengthLongPtr: SWIGTYPE_p_long_long = {
val dataLongLengthPtr = lightgbmlib.new_int64_tp()
lightgbmlib.int64_tp_assign(dataLongLengthPtr, numTotalModel)
dataLongLengthPtr
val dumpModelOutPtr: ThreadLocal[LongLongNativePtrHandler] = {
new ThreadLocal[LongLongNativePtrHandler] {
override def initialValue(): LongLongNativePtrHandler = {
new LongLongNativePtrHandler(lightgbmlib.new_int64_tp())
}
}
}

lazy val numClasses = getNumClasses
lazy val numFeatures = getNumFeatures
lazy val numTotalModel = getNumTotalModel
lazy val numTotalModelPerIteration = getNumModelPerIteration
lazy val numClasses: Int = getNumClasses
lazy val numFeatures: Int = getNumFeatures
lazy val numTotalModel: Int = getNumTotalModel
lazy val numTotalModelPerIteration: Int = getNumModelPerIteration

lazy val rawScoreConstant = lightgbmlibConstants.C_API_PREDICT_RAW_SCORE
lazy val normalScoreConstant = lightgbmlibConstants.C_API_PREDICT_NORMAL
lazy val leafIndexPredictConstant = lightgbmlibConstants.C_API_PREDICT_LEAF_INDEX
lazy val rawScoreConstant: Int = lightgbmlibConstants.C_API_PREDICT_RAW_SCORE
lazy val normalScoreConstant: Int = lightgbmlibConstants.C_API_PREDICT_NORMAL
lazy val leafIndexPredictConstant: Int = lightgbmlibConstants.C_API_PREDICT_LEAF_INDEX

lazy val dataInt32bitType = lightgbmlibConstants.C_API_DTYPE_INT32
lazy val data64bitType = lightgbmlibConstants.C_API_DTYPE_FLOAT64
lazy val dataInt32bitType: Int = lightgbmlibConstants.C_API_DTYPE_INT32
lazy val data64bitType: Int = lightgbmlibConstants.C_API_DTYPE_FLOAT64

private def getNumClasses: Int = {
val numClassesOut = lightgbmlib.new_intp()
LightGBMUtils.validate(
lightgbmlib.LGBM_BoosterGetNumClasses(boosterPtr, numClassesOut),
"Booster NumClasses")
lightgbmlib.intp_value(numClassesOut)
val out = lightgbmlib.intp_value(numClassesOut)
lightgbmlib.delete_intp(numClassesOut)
out
}

private def getNumModelPerIteration: Int = {
val numModelPerIterationOut = lightgbmlib.new_intp()
LightGBMUtils.validate(
lightgbmlib.LGBM_BoosterNumModelPerIteration(boosterPtr, numModelPerIterationOut),
"Booster models per iteration")
lightgbmlib.intp_value(numModelPerIterationOut)
val out = lightgbmlib.intp_value(numModelPerIterationOut)
lightgbmlib.delete_intp(numModelPerIterationOut)
out
}

private def getNumTotalModel: Int = {
val numModelOut = lightgbmlib.new_intp()
LightGBMUtils.validate(
lightgbmlib.LGBM_BoosterNumberOfTotalModel(boosterPtr, numModelOut),
"Booster total models")
lightgbmlib.intp_value(numModelOut)
val out = lightgbmlib.intp_value(numModelOut)
lightgbmlib.delete_intp(numModelOut)
out
}

private def getNumFeatures: Int = {
val numFeaturesOut = lightgbmlib.new_intp()
LightGBMUtils.validate(
lightgbmlib.LGBM_BoosterGetNumFeature(boosterPtr, numFeaturesOut),
"Booster NumFeature")
lightgbmlib.intp_value(numFeaturesOut)
val out = lightgbmlib.intp_value(numFeaturesOut)
lightgbmlib.delete_intp(numFeaturesOut)
out
}

private def freeNativeMemory(): Unit = {
if (scoredDataLengthLongPtr != null) {
lightgbmlib.delete_int64_tp(scoredDataLengthLongPtr)
scoredDataLengthLongPtr = null
}
if (scoredDataOutPtr != null) {
lightgbmlib.delete_doubleArray(scoredDataOutPtr)
scoredDataOutPtr = null
}
if (leafIndexDataLengthLongPtr != null) {
lightgbmlib.delete_int64_tp(leafIndexDataLengthLongPtr)
leafIndexDataLengthLongPtr = null
}
if (leafIndexDataOutPtr != null) {
lightgbmlib.delete_doubleArray(leafIndexDataOutPtr)
leafIndexDataOutPtr = null
}
if (boosterPtr != null) {
LightGBMUtils.validate(lightgbmlib.LGBM_BoosterFree(boosterPtr), "Finalize Booster")
boosterPtr = null
Expand All @@ -108,10 +154,10 @@ protected class BoosterHandler(model: String) {

override protected def finalize(): Unit = {
freeNativeMemory()
super.finalize()
}
}

//scalastyle:on
/** Represents a LightGBM Booster learner
* @param model The string serialized representation of the learner
*/
Expand Down Expand Up @@ -164,9 +210,9 @@ class LightGBMBooster(val model: String) extends Serializable {
sparseVector.numNonzeros,
boosterHandler.boosterPtr, dataInt32bitType, data64bitType, 2, numCols,
kind, -1, datasetParams,
boosterHandler.scoredDataLengthLongPtr, boosterHandler.scoredDataOutPtr), "Booster Predict")
boosterHandler.scoredDataLengthLongPtr.get().ptr, boosterHandler.scoredDataOutPtr.get().ptr), "Booster Predict")

predScoreToArray(classification, boosterHandler.scoredDataOutPtr, kind)
predScoreToArray(classification, boosterHandler.scoredDataOutPtr.get().ptr, kind)
}

protected def predictScoreForMat(row: Array[Double], kind: Int, classification: Boolean): Array[Double] = {
Expand All @@ -182,9 +228,9 @@ class LightGBMBooster(val model: String) extends Serializable {
row, boosterHandler.boosterPtr, data64bitType,
numCols,
isRowMajor, kind,
-1, datasetParams, boosterHandler.scoredDataLengthLongPtr, boosterHandler.scoredDataOutPtr),
-1, datasetParams, boosterHandler.scoredDataLengthLongPtr.get().ptr, boosterHandler.scoredDataOutPtr.get().ptr),
"Booster Predict")
predScoreToArray(classification, boosterHandler.scoredDataOutPtr, kind)
predScoreToArray(classification, boosterHandler.scoredDataOutPtr.get().ptr, kind)
}

protected def predictLeafForCSR(sparseVector: SparseVector): Array[Double] = {
Expand All @@ -200,9 +246,10 @@ class LightGBMBooster(val model: String) extends Serializable {
sparseVector.numNonzeros,
boosterHandler.boosterPtr, dataInt32bitType, data64bitType, 2, numCols,
boosterHandler.leafIndexPredictConstant, -1, datasetParams,
boosterHandler.leafIndexDataLengthLongPtr, boosterHandler.leafIndexDataOutPtr), "Booster Predict Leaf")
boosterHandler.leafIndexDataLengthLongPtr.get().ptr,
boosterHandler.leafIndexDataOutPtr.get().ptr), "Booster Predict Leaf")

predLeafToArray(boosterHandler.leafIndexDataOutPtr)
predLeafToArray(boosterHandler.leafIndexDataOutPtr.get().ptr)
}

protected def predictLeafForMat(row: Array[Double]): Array[Double] = {
Expand All @@ -218,10 +265,11 @@ class LightGBMBooster(val model: String) extends Serializable {
row, boosterHandler.boosterPtr, data64bitType,
numCols,
isRowMajor, boosterHandler.leafIndexPredictConstant,
-1, datasetParams, boosterHandler.leafIndexDataLengthLongPtr, boosterHandler.leafIndexDataOutPtr),
-1, datasetParams, boosterHandler.leafIndexDataLengthLongPtr.get().ptr,
boosterHandler.leafIndexDataOutPtr.get().ptr),
"Booster Predict Leaf")

predLeafToArray(boosterHandler.leafIndexDataOutPtr)
predLeafToArray(boosterHandler.leafIndexDataOutPtr.get().ptr)
}

def saveNativeModel(session: SparkSession, filename: String, overwrite: Boolean): Unit = {
Expand All @@ -236,7 +284,8 @@ class LightGBMBooster(val model: String) extends Serializable {
}

def dumpModel(session: SparkSession, filename: String, overwrite: Boolean): Unit = {
val json = lightgbmlib.LGBM_BoosterDumpModelSWIG(boosterHandler.boosterPtr, 0, 0, 1, lightgbmlib.new_int64_tp())
val json = lightgbmlib.LGBM_BoosterDumpModelSWIG(boosterHandler.boosterPtr, 0, 0, 1,
boosterHandler.dumpModelOutPtr.get().ptr)
val rdd = session.sparkContext.parallelize(Seq(json))
import session.sqlContext.implicits._
val dataset = session.sqlContext.createDataset(rdd)
Expand All @@ -251,11 +300,11 @@ class LightGBMBooster(val model: String) extends Serializable {
*/
def getFeatureImportances(importanceType: String): Array[Double] = {
val importanceTypeNum = if (importanceType.toLowerCase.trim == "gain") 1 else 0
val featureImportances = lightgbmlib.new_doubleArray(numFeatures)
LightGBMUtils.validate(
lightgbmlib.LGBM_BoosterFeatureImportance(boosterHandler.boosterPtr, -1, importanceTypeNum, featureImportances),
lightgbmlib.LGBM_BoosterFeatureImportance(boosterHandler.boosterPtr, -1,
importanceTypeNum, boosterHandler.featureImportanceOutPtr.get().ptr),
"Booster FeatureImportance")
(0 until numFeatures).map(lightgbmlib.doubleArray_getitem(featureImportances, _)).toArray
(0 until numFeatures).map(lightgbmlib.doubleArray_getitem(boosterHandler.featureImportanceOutPtr.get().ptr, _)).toArray
}

private def predScoreToArray(classification: Boolean, scoredDataOutPtr: SWIGTYPE_p_double,
Expand Down

0 comments on commit 350f0b0

Please sign in to comment.