Skip to content

Commit

Permalink
add barrier execution mode
Browse files Browse the repository at this point in the history
  • Loading branch information
imatiach-msft committed Jul 3, 2019
1 parent aac0536 commit 29c15cb
Show file tree
Hide file tree
Showing 7 changed files with 113 additions and 31 deletions.
18 changes: 18 additions & 0 deletions docs/lightgbm.md
Original file line number Diff line number Diff line change
Expand Up @@ -74,3 +74,21 @@ Models built can be saved as SparkML pipeline with native LightGBM model
using `saveNativeModel()`. Additionally, they are fully compatible with [PMML](https://en.wikipedia.org/wiki/Predictive_Model_Markup_Language) and
can be converted to PMML format through the
[JPMML-SparkML-LightGBM](https://github.com/alipay/jpmml-sparkml-lightgbm) plugin.

### Barrier Execution Mode

By default LightGBM uses regular spark paradigm for launching tasks and communicates with the driver to coordinate task execution.
The driver thread aggregates all task host:port information and then communicates the full list back to the workers in order for NetworkInit to be called.
There have been some issues on certain cluster configurations because the driver needs to know how many tasks there are, and this computation is surprisingly non-trivial in spark.
With the next v0.18 release there is a new UseBarrierExecutionMode flag, which when activated uses the barrier() stage to block all tasks.
The barrier execution mode simplifies the logic to aggregate host:port information across all tasks, so the driver will no longer need to precompute the number of tasks in advance.
To use it in scala, you can call setUseBarrierExecutionMode(true), for example:

```
val lgbm = new LightGBMClassifier()
.setLabelCol(labelColumn)
.setObjective(binaryObjective)
.setUseBarrierExecutionMode(true)
...
<train classifier>
```
16 changes: 10 additions & 6 deletions src/lightgbm/src/main/scala/LightGBMBase.scala
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ trait LightGBMBase[TrainedModel <: Model[TrainedModel]] extends Estimator[Traine
// Reduce number of partitions to number of executor cores
val df = dataset.select(trainingCols.map(name => col(name)):_*).toDF().coalesce(numWorkers)
val (inetAddress, port, future) =
LightGBMUtils.createDriverNodesThread(numWorkers, df, log, getTimeout)
LightGBMUtils.createDriverNodesThread(numWorkers, df, log, getTimeout, getUseBarrierExecutionMode)

/* Run a parallel job via map partitions to initialize the native library and network,
* translate the data to the LightGBM in-memory representation and train the models
Expand All @@ -73,17 +73,21 @@ trait LightGBMBase[TrainedModel <: Model[TrainedModel]] extends Estimator[Traine
categoricalSlotIndexesArr, categoricalSlotNamesArr)
val trainParams = getTrainParams(numWorkers, categoricalIndexes, dataset)
log.info(s"LightGBM parameters: ${trainParams.toString()}")
val networkParams = NetworkParams(getDefaultListenPort, inetAddress, port)
val networkParams = NetworkParams(getDefaultListenPort, inetAddress, port, getUseBarrierExecutionMode)
val validationData =
if (get(validationIndicatorCol).isDefined && dataset.columns.contains(getValidationIndicatorCol))
Some(sc.broadcast(df.filter(x => x.getBoolean(x.fieldIndex(getValidationIndicatorCol))).collect()))
else None
val preprocessedDF = preprocessData(df)
val schema = preprocessedDF.schema
val lightGBMBooster = preprocessedDF
.mapPartitions(TrainUtils.trainLightGBM(networkParams, getLabelCol, getFeaturesCol, get(weightCol),
get(initScoreCol), getOptGroupCol, validationData, log, trainParams, numCoresPerExec, schema))(encoder)
.reduce((booster1, _) => booster1)
val mapPartitionsFunc = TrainUtils.trainLightGBM(networkParams, getLabelCol, getFeaturesCol,
get(weightCol), get(initScoreCol), getOptGroupCol, validationData, log, trainParams, numCoresPerExec, schema)(_)
val lightGBMBooster =
if (getUseBarrierExecutionMode) {
preprocessedDF.rdd.barrier().mapPartitions(mapPartitionsFunc).reduce((booster1, _) => booster1)
} else {
preprocessedDF.mapPartitions(mapPartitionsFunc)(encoder).reduce((booster1, _) => booster1)
}
// Wait for future to complete (should be done by now)
Await.result(future, Duration(getTimeout, SECONDS))
getModel(trainParams, lightGBMBooster)
Expand Down
4 changes: 4 additions & 0 deletions src/lightgbm/src/main/scala/LightGBMConstants.scala
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,8 @@ object LightGBMConstants {
/** Ignore worker status, used to ignore workers that get empty partitions
*/
val ignoreStatus: String = "ignore"
/** Barrier execution flag telling driver that all tasks have completed
* sending port and host information
*/
val finishedStatus: String = "finished"
}
30 changes: 20 additions & 10 deletions src/lightgbm/src/main/scala/LightGBMParams.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,10 @@ package com.microsoft.ml.spark

import org.apache.spark.ml.param._
import org.apache.spark.ml.util.DefaultParamsWritable
import org.apache.spark.sql.DataFrame

/** Defines common parameters across all LightGBM learners.
/** Defines common LightGBM execution parameters.
*/
trait LightGBMParams extends Wrappable with DefaultParamsWritable with HasWeightCol
with HasValidationIndicatorCol with HasInitScoreCol {
trait LightGBMExecutionParams extends Wrappable {
val parallelism = new Param[String](this, "parallelism",
"Tree learner parallelism, can be set to data_parallel or voting_parallel")
setDefault(parallelism->"data_parallel")
Expand All @@ -26,6 +24,24 @@ trait LightGBMParams extends Wrappable with DefaultParamsWritable with HasWeight

setDefault(defaultListenPort -> LightGBMConstants.defaultLocalListenPort)

val timeout = new DoubleParam(this, "timeout", "Timeout in seconds")
setDefault(timeout -> 1200)

def getTimeout: Double = $(timeout)
def setTimeout(value: Double): this.type = set(timeout, value)

val useBarrierExecutionMode = new BooleanParam(this, "useBarrierExecutionMode",
"Use new barrier execution mode in Beta testing, off by default.")
setDefault(useBarrierExecutionMode -> false)

def getUseBarrierExecutionMode: Boolean = $(useBarrierExecutionMode)
def setUseBarrierExecutionMode(value: Boolean): this.type = set(useBarrierExecutionMode, value)
}

/** Defines common parameters across all LightGBM learners.
*/
trait LightGBMParams extends Wrappable with DefaultParamsWritable with HasWeightCol
with HasValidationIndicatorCol with HasInitScoreCol with LightGBMExecutionParams {
val numIterations = new IntParam(this, "numIterations",
"Number of iterations, LightGBM constructs num_class * num_iterations trees")
setDefault(numIterations->100)
Expand Down Expand Up @@ -102,12 +118,6 @@ trait LightGBMParams extends Wrappable with DefaultParamsWritable with HasWeight
def getMinSumHessianInLeaf: Double = $(minSumHessianInLeaf)
def setMinSumHessianInLeaf(value: Double): this.type = set(minSumHessianInLeaf, value)

val timeout = new DoubleParam(this, "timeout", "Timeout in seconds")
setDefault(timeout -> 1200)

def getTimeout: Double = $(timeout)
def setTimeout(value: Double): this.type = set(timeout, value)

val modelString = new Param[String](this, "modelString", "LightGBM model to retrain")
setDefault(modelString -> "")

Expand Down
50 changes: 36 additions & 14 deletions src/lightgbm/src/main/scala/LightGBMUtils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,8 @@ object LightGBMUtils {
* @return The address and port of the driver socket.
*/
def createDriverNodesThread(numWorkers: Int, df: DataFrame,
log: Logger, timeout: Double): (String, Int, Future[Unit]) = {
log: Logger, timeout: Double,
barrierExecutionMode: Boolean): (String, Int, Future[Unit]) = {
// Start a thread and open port to listen on
implicit val context = ExecutionContext.fromExecutor(Executors.newSingleThreadExecutor())
val driverServerSocket = new ServerSocket(0)
Expand All @@ -109,19 +110,40 @@ object LightGBMUtils {
val f = Future {
var emptyWorkerCounter = 0
val hostAndPorts = ListBuffer[(Socket, String)]()
log.info(s"driver expecting $numWorkers connections...")
while (hostAndPorts.size + emptyWorkerCounter < numWorkers) {
log.info("driver accepting a new connection...")
val driverSocket = driverServerSocket.accept()
val reader = new BufferedReader(new InputStreamReader(driverSocket.getInputStream))
val comm = reader.readLine()
if (comm == LightGBMConstants.ignoreStatus) {
log.info("driver received ignore status from worker")
emptyWorkerCounter += 1
} else {
log.info(s"driver received socket from worker: $comm")
val socketAndComm = (driverSocket, comm)
hostAndPorts += socketAndComm
if (barrierExecutionMode) {
log.info(s"driver using barrier execution mode")
var finished = false
while (!finished) {
log.info("driver accepting a new connection...")
val driverSocket = driverServerSocket.accept()
val reader = new BufferedReader(new InputStreamReader(driverSocket.getInputStream))
val comm = reader.readLine()
if (comm == LightGBMConstants.finishedStatus) {
log.info("driver received all workers from barrier stage")
finished = true
} else if (comm == LightGBMConstants.ignoreStatus) {
log.info("driver received ignore status from worker")
} else {
log.info(s"driver received socket from worker: $comm")
val socketAndComm = (driverSocket, comm)
hostAndPorts += socketAndComm
}
}
} else {
log.info(s"driver expecting $numWorkers connections...")
while (hostAndPorts.size + emptyWorkerCounter < numWorkers) {
log.info("driver accepting a new connection...")
val driverSocket = driverServerSocket.accept()
val reader = new BufferedReader(new InputStreamReader(driverSocket.getInputStream))
val comm = reader.readLine()
if (comm == LightGBMConstants.ignoreStatus) {
log.info("driver received ignore status from worker")
emptyWorkerCounter += 1
} else {
log.info(s"driver received socket from worker: $comm")
val socketAndComm = (driverSocket, comm)
hostAndPorts += socketAndComm
}
}
}
// Concatenate with commas, eg: host1:port1,host2:port2, ... etc
Expand Down
25 changes: 24 additions & 1 deletion src/lightgbm/src/main/scala/TrainUtils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,15 @@ import java.net._

import com.microsoft.ml.lightgbm._
import com.microsoft.ml.spark.StreamUtilities.using
import org.apache.spark.BarrierTaskContext
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.ml.attribute._
import org.apache.spark.ml.linalg.{DenseVector, SparseVector}
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.Row
import org.slf4j.Logger

case class NetworkParams(defaultListenPort: Int, addr: String, port: Int)
case class NetworkParams(defaultListenPort: Int, addr: String, port: Int, barrierExecutionMode: Boolean)

private object TrainUtils extends Serializable {

Expand Down Expand Up @@ -272,6 +273,20 @@ private object TrainUtils extends Serializable {
workerServerSocket
}

def setFinishedStatus(networkParams: NetworkParams,
localListenPort: Int, log: Logger): Unit = {
using(new Socket(networkParams.addr, networkParams.port)) {
driverSocket =>
using(new BufferedWriter(new OutputStreamWriter(driverSocket.getOutputStream))) {
driverOutput =>
log.info("sending finished status to driver")
// If barrier execution mode enabled, create a barrier across tasks
driverOutput.write(s"${LightGBMConstants.finishedStatus}\n")
driverOutput.flush()
}.get
}.get
}

def getNetworkInitNodes(networkParams: NetworkParams,
localListenPort: Int, log: Logger,
emptyPartition: Boolean): String = {
Expand All @@ -295,6 +310,14 @@ private object TrainUtils extends Serializable {
// Send the current host:port to the driver
driverOutput.write(s"$workerStatus\n")
driverOutput.flush()
// If barrier execution mode enabled, create a barrier across tasks
if (networkParams.barrierExecutionMode) {
val context = BarrierTaskContext.get()
context.barrier()
if (context.partitionId() == 0) {
setFinishedStatus(networkParams, localListenPort, log)
}
}
if (workerStatus != LightGBMConstants.ignoreStatus) {
// Wait to get the list of nodes from the driver
val nodes = driverInput.readLine()
Expand Down
1 change: 1 addition & 0 deletions src/lightgbm/src/test/scala/VerifyLightGBMClassifier.scala
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ class VerifyLightGBMClassifier extends Benchmarks with EstimatorFuzzing[LightGBM
.setNumLeaves(5)
.setNumIterations(10)
.setObjective(binaryObjective)
.setUseBarrierExecutionMode(true)

val paramGrid = new ParamGridBuilder()
.addGrid(lgbm.numLeaves, Array(5, 10))
Expand Down

0 comments on commit 29c15cb

Please sign in to comment.