Skip to content

Commit

Permalink
feat: add interface function for updating learning_rate per each iteration in LightGBMDelegate (microsoft#849)
Browse files Browse the repository at this point in the history

* feat: add update learning_rate by using LightGBMDelegate

* feat: add update learning_rate by using LightGBMDelegate

* feat: add update learning_rate by using LightGBMDelegate

* feat: add update learning_rate by using LightGBMDelegate

* fix minor

* fix serialization error

* fix serialization error

* change LightGBMDelegate to trait for scala style

* change LightGBMDelegate to trait for scala style

* change LightGBMDelegate to trait for scala style
  • Loading branch information
Keunhyun Oh authored and ocworld committed Apr 8, 2020
1 parent 8143654 commit b8c527f
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,22 @@ package com.microsoft.ml.spark.lightgbm
import com.microsoft.ml.lightgbm.SWIGTYPE_p_void
import org.slf4j.Logger

// NOTE(review): this listing is a rendered diff without +/- markers, so a replaced line and its
// replacement appear back to back (the class/trait header and the two `: Unit` signature endings).
/**
 * Callback hooks that custom code can implement to observe and steer LightGBM training.
 *
 * In the trait form every hook has a default body, so an implementation only overrides
 * the hooks it needs. Extends Serializable, presumably so that an instance can be shipped
 * to the workers that run training -- TODO confirm.
 */
abstract class LightGBMDelegate extends Serializable {
trait LightGBMDelegate extends Serializable {
/**
 * Hook run by the training loop at the top of an iteration, before the learning rate is queried.
 *
 * @param partitionId id of the partition whose task is running the training loop
 * @param curIters    current iteration counter of the training loop (appears to start at 0 -- confirm)
 * @param log         logger of the training worker
 * @param trainParams training parameters of the current run
 * @param boosterPtr  optional handle to the native booster
 * @param hasValid    presumably whether validation data is present -- TODO confirm
 */
def beforeTrainIteration(partitionId: Int, curIters: Int, log: Logger, trainParams: TrainParams,
boosterPtr: Option[SWIGTYPE_p_void], hasValid: Boolean): Unit
boosterPtr: Option[SWIGTYPE_p_void], hasValid: Boolean): Unit = {
// Default is a no-op; override to add custom logic.
}

/**
 * Counterpart of [[beforeTrainIteration]]; the call site is not visible here, by its name it is
 * expected to run once an iteration has completed -- TODO confirm.
 *
 * @param partitionId      id of the partition whose task is running the training loop
 * @param curIters         current iteration counter of the training loop
 * @param log              logger of the training worker
 * @param trainParams      training parameters of the current run
 * @param boosterPtr       optional handle to the native booster
 * @param hasValid         presumably whether validation data is present -- TODO confirm
 * @param isFinished       presumably whether training stops after this iteration -- TODO confirm
 * @param trainEvalResults optional map of named evaluation values for the training data
 *                         (exact keys are not visible here)
 * @param validEvalResults optional map of named evaluation values for the validation data
 *                         (exact keys are not visible here)
 */
def afterTrainIteration(partitionId: Int, curIters: Int, log: Logger, trainParams: TrainParams,
boosterPtr: Option[SWIGTYPE_p_void], hasValid: Boolean, isFinished: Boolean,
trainEvalResults: Option[Map[String, Double]],
validEvalResults: Option[Map[String, Double]]): Unit
validEvalResults: Option[Map[String, Double]]): Unit = {
// Default is a no-op; override to add custom logic.
}

/**
 * Supplies the learning rate to use for the upcoming iteration.
 *
 * The training loop compares the returned value with the rate currently in effect; when they
 * differ it resets the booster's `learning_rate` parameter and passes the new value as
 * `previousLearningRate` on the next call.
 *
 * @param partitionId          id of the partition whose task is running the training loop
 * @param curIters             current iteration counter of the training loop
 * @param log                  logger of the training worker
 * @param trainParams          training parameters of the current run
 * @param previousLearningRate learning rate currently in effect
 * @return the learning rate to apply; the default returns `previousLearningRate` unchanged
 */
def getLearningRate(partitionId: Int, curIters: Int, log: Logger, trainParams: TrainParams,
previousLearningRate: Double): Double = {
// Default keeps the current rate; override to implement a schedule.
previousLearningRate
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -201,10 +201,19 @@ private object TrainUtils extends Serializable {
val bestIter = new Array[Int](evalCounts)
val delegate = trainParams.delegate
val partitionId = TaskContext.getPartitionId
var learningRate: Double = trainParams.learningRate
while (!isFinished && iters < trainParams.numIterations) {

if (delegate.isDefined) {
delegate.get.beforeTrainIteration(partitionId, iters, log, trainParams, boosterPtr, hasValid)
val newLearningRate = delegate.get.getLearningRate(partitionId, iters, log, trainParams, learningRate)
if (newLearningRate != learningRate) {
log.info(s"LightGBM worker calling LGBM_BoosterResetParameter to reset learningRate" +
s" (newLearningRate: $newLearningRate)")
LightGBMUtils.validate(lightgbmlib.LGBM_BoosterResetParameter(boosterPtr.get,
s"learning_rate=$newLearningRate"), "Booster Reset learning_rate Param")
learningRate = newLearningRate
}
}

try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package com.microsoft.ml.spark.lightgbm.split1
import java.io.File
import java.nio.file.{Files, Path, Paths}

import com.microsoft.ml.lightgbm.SWIGTYPE_p_void
import com.microsoft.ml.spark.core.test.base.TestBase
import com.microsoft.ml.spark.core.test.benchmarks.{Benchmarks, DatasetUtils}
import com.microsoft.ml.spark.core.test.fuzzing.{EstimatorFuzzing, TestObject}
Expand All @@ -23,6 +24,33 @@ import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.functions._
import org.slf4j.Logger

/**
 * Test delegate whose only behavior is a learning-rate change after the first iteration.
 *
 * Both iteration hooks are intentionally empty; see [[getLearningRate]] for the schedule.
 */
@SerialVersionUID(100L)
class TrainDelegate extends LightGBMDelegate {

  /** Intentionally a no-op. */
  override def beforeTrainIteration(partitionId: Int, curIters: Int, log: Logger, trainParams: TrainParams,
                                    boosterPtr: Option[SWIGTYPE_p_void], hasValid: Boolean): Unit = ()

  /** Intentionally a no-op. */
  override def afterTrainIteration(partitionId: Int, curIters: Int, log: Logger, trainParams: TrainParams,
                                   boosterPtr: Option[SWIGTYPE_p_void], hasValid: Boolean, isFinished: Boolean,
                                   trainEvalResults: Option[Map[String, Double]],
                                   validEvalResults: Option[Map[String, Double]]): Unit = ()

  /**
   * Keeps the rate as-is for iteration 0 and scales it by 0.05 for every other iteration.
   *
   * @param curIters             current iteration counter
   * @param previousLearningRate learning rate currently in effect
   * @return `previousLearningRate` when `curIters` is 0, otherwise `previousLearningRate * 0.05`
   */
  override def getLearningRate(partitionId: Int, curIters: Int, log: Logger, trainParams: TrainParams,
                               previousLearningRate: Double): Double =
    curIters match {
      case 0 => previousLearningRate
      case _ => previousLearningRate * 0.05
    }

}

// scalastyle:off magic.number
trait LightGBMTestUtils extends TestBase {
Expand Down Expand Up @@ -360,6 +388,21 @@ class VerifyLightGBMClassifier extends Benchmarks with EstimatorFuzzing[LightGBM
assert(metric > 0.8)
}

test("Verify LightGBM Classifier updating learning_rate on training by using LightGBMDelegate") {
  // Only the 80% side of the split is used; the remainder is discarded.
  val Array(trainingData, _) = indexedBankTrainDF.randomSplit(Array(0.8, 0.2), seed)
  val categoricalSlots = indexedBankTrainDF.columns.filter(name => name.startsWith("c_"))
  val learningRateDelegate = new TrainDelegate()

  // Starting from 0.1 with two iterations, TrainDelegate yields:
  // iteration 0 => 0.1 (unchanged), iteration 1 => 0.1 * 0.05 = 0.005.
  val estimator = baseModel
    .setCategoricalSlotNames(categoricalSlots)
    .setDelegate(learningRateDelegate)
    .setLearningRate(0.1)
    .setNumIterations(2)

  val fitted = estimator.fit(trainingData)

  // The reset learning rate is expected to show up in the serialized model text.
  val modelText = fitted.getModel.model
  assert(modelText.contains("learning_rate: 0.005"))
}

test("Verify LightGBM Classifier leaf prediction") {
val Array(train, test) = indexedBankTrainDF.randomSplit(Array(0.8, 0.2), seed)
val untrainedModel = baseModel
Expand Down

0 comments on commit b8c527f

Please sign in to comment.