fix: fix lightgbm stuck in multiclass scenario and added stratified repartition transformer (#618)
imatiach-msft authored and mhamilton723 committed Aug 20, 2019
1 parent 85fb3fc commit d518b8a
Showing 7 changed files with 285 additions and 7 deletions.
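
The StratifiedRepartition transformer added here is exercised end to end in the tests below. As a minimal usage sketch (the DataFrame df and its column names are placeholder assumptions, not part of the commit):

import com.microsoft.ml.spark.lightgbm.LightGBMClassifier
import com.microsoft.ml.spark.stages.{SPConstants, StratifiedRepartition}

// df is a placeholder DataFrame with an integer label column "label"
// and an assembled feature vector column "features".
val stratified = new StratifiedRepartition()
  .setLabelCol("label")
  .setMode(SPConstants.Equal) // resample so every partition sees every label
  .transform(df)

// With every label present on every partition, multiclass training no longer stalls.
val model = new LightGBMClassifier()
  .setLabelCol("label")
  .setFeaturesCol("features")
  .setObjective("multiclass")
  .fit(stratified)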
@@ -38,20 +38,27 @@ class LightGBMClassifier(override val uid: String)
def getIsUnbalance: Boolean = $(isUnbalance)
def setIsUnbalance(value: Boolean): this.type = set(isUnbalance, value)

val generateMissingLabels = new BooleanParam(this, "generateMissingLabels",
"Instead of failing in lightgbm, generates dummy rows with missing labels within partitions")
setDefault(generateMissingLabels -> false)

def getGenerateMissingLabels: Boolean = $(generateMissingLabels)
def setGenerateMissingLabels(value: Boolean): this.type = set(generateMissingLabels, value)

def getTrainParams(numWorkers: Int, categoricalIndexes: Array[Int], dataset: Dataset[_]): TrainParams = {
/* The numClasses returned by the native code is always 1 unless this is a multiclass classification problem,
* so we infer the actual numClasses from the dataset here.
*/
val actualNumClasses = getNumClasses(dataset)
val metric =
if (getObjective == LightGBMConstants.BinaryObjective) "binary_logloss,auc"
else "multiclass"
else LightGBMConstants.MulticlassObjective
val modelStr = if (getModelString == null || getModelString.isEmpty) None else get(modelString)
ClassifierTrainParams(getParallelism, getNumIterations, getLearningRate, getNumLeaves,
getMaxBin, getBaggingFraction, getBaggingFreq, getBaggingSeed, getEarlyStoppingRound,
getFeatureFraction, getMaxDepth, getMinSumHessianInLeaf, numWorkers, getObjective, modelStr,
getIsUnbalance, getVerbosity, categoricalIndexes, actualNumClasses, metric, getBoostFromAverage,
-    getBoostingType, getLambdaL1, getLambdaL2, getIsProvideTrainingMetric)
+    getBoostingType, getLambdaL1, getLambdaL2, getIsProvideTrainingMetric, getGenerateMissingLabels)
}

def getModel(trainParams: TrainParams, lightGBMBooster: LightGBMBooster): LightGBMClassificationModel = {
@@ -19,6 +19,9 @@ object LightGBMConstants {
/** Binary classification objective
*/
val BinaryObjective: String = "binary"
/** Multiclass classification objective
*/
val MulticlassObjective: String = "multiclass"
/** Ignore worker status, used to ignore workers that get empty partitions
*/
val IgnoreStatus: String = "ignore"
@@ -52,7 +52,7 @@ case class ClassifierTrainParams(val parallelism: String, val numIterations: Int
val isUnbalance: Boolean, val verbosity: Int, val categoricalFeatures: Array[Int],
val numClass: Int, val metric: String, val boostFromAverage: Boolean,
val boostingType: String, val lambdaL1: Double, val lambdaL2: Double,
-    val isProvideTrainingMetric: Boolean)
+    val isProvideTrainingMetric: Boolean, val generateMissingLabels: Boolean)
extends TrainParams {
override def toString(): String = {
val extraStr =
37 changes: 34 additions & 3 deletions src/main/scala/com/microsoft/ml/spark/lightgbm/TrainUtils.scala
@@ -15,6 +15,7 @@ import org.apache.spark.sql.Row
import org.apache.spark.sql.types.StructType
import org.slf4j.Logger
import org.apache.spark.BarrierTaskContext
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema

case class NetworkParams(defaultListenPort: Int, addr: String, port: Int, barrierExecutionMode: Boolean)

@@ -23,9 +24,38 @@ private object TrainUtils extends Serializable {
def generateDataset(rows: Array[Row], labelColumn: String, featuresColumn: String,
weightColumn: Option[String], initScoreColumn: Option[String], groupColumn: Option[String],
referenceDataset: Option[LightGBMDataset], schema: StructType,
-    log: Logger): Option[LightGBMDataset] = {
+    log: Logger, trainParams: TrainParams): Option[LightGBMDataset] = {
val numRows = rows.length
val labels = rows.map(row => row.getDouble(schema.fieldIndex(labelColumn)))
if (trainParams.objective == LightGBMConstants.MulticlassObjective ||
trainParams.objective == LightGBMConstants.BinaryObjective) {
val distinctLabels = labels.distinct.map(_.toInt).sorted
// TODO: Temporary hack to append missing labels for debugging, off by default
// try to figure out a better fix in lightgbm
if (trainParams.asInstanceOf[ClassifierTrainParams].generateMissingLabels) {
val (count, missingLabels) =
distinctLabels.foldLeft((-1, List[Int]())) {
case ((baseCount, baseLabels), newLabel) => {
if (newLabel == baseCount + 1) (newLabel, baseLabels)
else (baseCount + 1, baseCount + 1 :: baseLabels)
}
}
if (!missingLabels.isEmpty) {
// Append missing labels to rows
val newRows = rows.take(missingLabels.size).zip(missingLabels).map { case (row, label) =>
val rowAsArray = row.toSeq.toArray
rowAsArray.update(schema.fieldIndex(labelColumn), label.toDouble)
new GenericRowWithSchema(rowAsArray, row.schema) }
return generateDataset(rows ++ newRows, labelColumn, featuresColumn, weightColumn, initScoreColumn,
groupColumn, referenceDataset, schema, log, trainParams)
}
} else {
val errMsg = "For classification, label values must start from 0 and increase " +
"by 1 to n for each partition."
distinctLabels.foldLeft(-1)((base, newLabel) => if (newLabel == base + 1) newLabel else
throw new Exception(s"$errMsg Missing label ${base + 1}, unique labels ${distinctLabels.mkString(",")}"))
}
}
val hrow = rows.head
var datasetPtr: Option[LightGBMDataset] = None
datasetPtr =
@@ -226,10 +256,11 @@
var validDatasetPtr: Option[LightGBMDataset] = None
try {
trainDatasetPtr = generateDataset(rows, labelColumn, featuresColumn,
-      weightColumn, initScoreColumn, groupColumn, None, schema, log)
+      weightColumn, initScoreColumn, groupColumn, None, schema, log, trainParams)
if (validationData.isDefined) {
validDatasetPtr = generateDataset(validationData.get.value, labelColumn,
-        featuresColumn, weightColumn, initScoreColumn, groupColumn, trainDatasetPtr, schema, log)
+        featuresColumn, weightColumn, initScoreColumn, groupColumn, trainDatasetPtr,
+        schema, log, trainParams)
}
var boosterPtr: Option[SWIGTYPE_p_void] = None
try {
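The label check added to generateDataset above enforces that each partition's distinct labels form the contiguous sequence 0, 1, ..., n. A simplified standalone sketch of that invariant (illustrative code, not from the commit):

// Illustrative only: report the labels missing from a sorted sequence of distinct labels.
def missingLabels(distinctSorted: Seq[Int]): Seq[Int] = {
  val present = distinctSorted.toSet
  (0 to distinctSorted.lastOption.getOrElse(-1)).filterNot(present)
}

missingLabels(Seq(0, 2, 3, 5)) // => Vector(1, 4)

The committed fold may not find every gap in a single pass; when generateMissingLabels is enabled it appends dummy rows for the labels it did find and calls generateDataset again, repeating until the sequence is contiguous, and otherwise it throws with the error message above.
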
@@ -0,0 +1,77 @@
// Copyright (C) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE in project root for information.

package com.microsoft.ml.spark.stages

import com.microsoft.ml.spark.core.contracts.{HasLabelCol, Wrappable}
import org.apache.spark.RangePartitioner
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared.HasSeed
import org.apache.spark.ml.util._
import org.apache.spark.sql.types._
import org.apache.spark.sql.{DataFrame, Dataset}

/** Constants for <code>StratifiedRepartition</code>. */
object SPConstants {
val Count = "count"
val Equal = "equal"
val Original = "original"
val Mixed = "mixed"
}

object StratifiedRepartition extends DefaultParamsReadable[StratifiedRepartition]

/** <code>StratifiedRepartition</code> repartitions the DataFrame such that each label is selected in each partition.
* This may be necessary in some cases such as in LightGBM multiclass classification, where it is necessary for
* at least one instance of each label to be present on each partition.
*/
class StratifiedRepartition(val uid: String) extends Transformer with Wrappable
with DefaultParamsWritable with HasLabelCol with HasSeed {
def this() = this(Identifiable.randomUID("StratifiedRepartition"))

val mode = new Param[String](this, "mode",
"Specify equal to repartition with replacement across all labels, specify " +
"original to keep the ratios in the original dataset, or specify mixed to use a heuristic")
setDefault(mode -> SPConstants.Mixed)

def getMode: String = $(mode)
def setMode(value: String): this.type = set(mode, value)

/** @param dataset - The input dataset, to be transformed
* @return The DataFrame that results from stratified repartitioning
*/
override def transform(dataset: Dataset[_]): DataFrame = {
// Count unique values in label column
val distinctLabelCounts = dataset.select(getLabelCol).groupBy(getLabelCol).count().collect()
val labelToCount = distinctLabelCounts.map(row => (row.getInt(0), row.getLong(1)))
val labelToFraction =
getMode match {
case SPConstants.Equal => getEqualLabelCount(labelToCount, dataset)
case SPConstants.Mixed => {
val equalLabelToCount = getEqualLabelCount(labelToCount, dataset)
val normalizedRatio = equalLabelToCount.map { case (label, count) => count }.sum / labelToCount.size
labelToCount.map { case (label, count) => (label, count / normalizedRatio)}.toMap
}
case SPConstants.Original => labelToCount.map { case (label, count) => (label, 1.0) }.toMap
case _ => throw new Exception(s"Unknown mode specified to StratifiedRepartition: $getMode")
}
val labelColIndex = dataset.schema.fieldIndex(getLabelCol)
val spdata = dataset.toDF().rdd.keyBy(row => row.getInt(labelColIndex))
.sampleByKeyExact(true, labelToFraction, getSeed)
.mapPartitions(keyToRow => keyToRow.zipWithIndex.map { case ((key, row), index) => (index, row) })
val rangePartitioner = new RangePartitioner(dataset.rdd.getNumPartitions, spdata)
val rspdata = spdata.partitionBy(rangePartitioner).mapPartitions(keyToRow =>
keyToRow.map{case (key, row) => row}).persist()
dataset.sqlContext.createDataFrame(rspdata, dataset.schema)
}

private def getEqualLabelCount(labelToCount: Array[(Int, Long)], dataset: Dataset[_]): Map[Int, Double] = {
val maxLabelCount = Math.max(labelToCount.map { case (label, count) => count }.max, dataset.rdd.getNumPartitions)
labelToCount.map { case (label, count) => (label, maxLabelCount.toDouble / count) }.toMap
}

def transformSchema(schema: StructType): StructType = schema

def copy(extra: ParamMap): StratifiedRepartition = defaultCopy(extra)
}
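
To make the Equal-mode arithmetic in getEqualLabelCount concrete, a small worked example (the counts and partition number are made up): with labels 0, 1 and 2 occurring 50, 30 and 20 times on 4 partitions, every label is upsampled toward the largest class:

// maxLabelCount = max(largest class count, number of partitions) = max(50, 4) = 50
val labelToCount = Array((0, 50L), (1, 30L), (2, 20L))
val maxLabelCount = math.max(labelToCount.map(_._2).max, 4)
val fractions = labelToCount.map { case (label, count) =>
  (label, maxLabelCount.toDouble / count)
}.toMap
// fractions == Map(0 -> 1.0, 1 -> 1.666..., 2 -> 2.5)

These fractions drive sampleByKeyExact with replacement; the RangePartitioner over within-partition row indices then redistributes the sampled rows so that, in expectation, each label lands on every partition.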
@@ -11,7 +11,7 @@ import com.microsoft.ml.spark.core.test.benchmarks.{Benchmarks, DatasetUtils}
import com.microsoft.ml.spark.core.test.fuzzing.{EstimatorFuzzing, TestObject}
import com.microsoft.ml.spark.featurize.ValueIndexer
import com.microsoft.ml.spark.lightgbm._
import com.microsoft.ml.spark.stages.MultiColumnAdapter
import com.microsoft.ml.spark.stages.{MultiColumnAdapter, SPConstants, StratifiedRepartition}
import org.apache.commons.io.FileUtils
import org.apache.spark.TaskContext
import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, MulticlassClassificationEvaluator}
@@ -387,6 +387,73 @@ class VerifyLightGBMClassifier extends Benchmarks with EstimatorFuzzing[LightGBM
assert(model != null)
}

test("Verify LightGBM Classifier won't get stuck on unbalanced classes in multiclass classification") {
assume(!isWindows)
// Increment port index
portIndex += numPartitions
val fileName = "BreastTissue.csv"
val labelColumnName = "Class"
val fileLocation = DatasetUtils.multiclassTrainFile(fileName).toString
val dataset = readCSV(fileName, fileLocation).repartition(numPartitions)
val featuresColumn = "_features"
val featurizer = LightGBMUtils.featurizeData(dataset, labelColumnName, featuresColumn)
val predCol = "pred"
val tmpTrainData = featurizer.transform(dataset)
val labelizer = new ValueIndexer().setInputCol(labelColumnName).setOutputCol(labelColumnName).fit(tmpTrainData)
val labelizedData = labelizer.transform(tmpTrainData).select(labelColumnName, featuresColumn)
val lgbm = new LightGBMClassifier()
.setLabelCol(labelColumnName)
.setFeaturesCol(featuresColumn)
.setPredictionCol(predCol)
.setDefaultListenPort(LightGBMConstants.DefaultLocalListenPort + portIndex)
.setObjective(multiclassObject)

val infoSchema = new StructType()
.add(labelColumnName, IntegerType).add(featuresColumn, org.apache.spark.ml.linalg.SQLDataTypes.VectorType)
val infoEnc = RowEncoder(infoSchema)
val labelColumnIndex = labelizedData.schema.fieldIndex(labelColumnName)
val trainData = labelizedData
.mapPartitions(iter => {
val ctx = TaskContext.get
val partId = ctx.partitionId
// Remove all instances of some classes
if (partId == 1) {
iter.flatMap(row => {
if (row.getInt(labelColumnIndex) <= 2)
None
else Some(row)
})
} else {
iter
}
})(infoEnc)
// Validate fit fails and doesn't get stuck
assertThrows[Exception] {
lgbm.fit(trainData)
}
// Validate using special mode works
val missingModel = lgbm.setGenerateMissingLabels(true).fit(trainData)
missingModel.transform(trainData)

val stratifiedTrainData = new StratifiedRepartition().setLabelCol(labelColumnName)
.setMode(SPConstants.Equal).transform(trainData)
// Assert stratified train data contains all labels across all partitions; the trailing
// count() forces evaluation
stratifiedTrainData.select(labelColumnName, featuresColumn)
.mapPartitions(iter => {
val actualLabels = iter.map(row => row.getInt(labelColumnIndex))
.toArray.distinct.sorted.toList
val expectedLabels = (0 to 5).toList
if (actualLabels != expectedLabels)
throw new Exception(s"Missing labels, actual: $actualLabels, expected: $expectedLabels")
iter
})(infoEnc).count()
// Validate with stratified repartitioned dataset fit passes
val model = lgbm.setGenerateMissingLabels(false).fit(stratifiedTrainData)
model.transform(stratifiedTrainData)
assert(model != null)
}

/** Reads a CSV file given the file name and file location.
* @param fileName The name of the csv file.
* @param fileLocation The full path to the csv file.
@@ -0,0 +1,93 @@
// Copyright (C) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE in project root for information.

package com.microsoft.ml.spark.stages

import com.microsoft.ml.spark.core.test.base.TestBase
import com.microsoft.ml.spark.core.test.fuzzing.{TestObject, TransformerFuzzing}
import org.apache.spark.TaskContext
import org.apache.spark.ml.util.MLReadable
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.types.{IntegerType, StringType, StructType}

class StratifiedRepartitionSuite extends TestBase with TransformerFuzzing[StratifiedRepartition] {

import session.implicits._

val values = "values"
val colors = "colors"
val const = "const"

val input = Seq(
(0, "Blue", 2),
(0, "Red", 2),
(0, "Green", 2),
(1, "Purple", 2),
(1, "Orange", 2),
(1, "Indigo", 2),
(2, "Violet", 2),
(2, "Black", 2),
(2, "White", 2),
(3, "Gray", 2),
(3, "Yellow", 2),
(3, "Cerulean", 2)
).toDF(values, colors, const)

test("Assert doing a stratified repartition will ensure all keys exist across all partitions") {
val inputSchema = new StructType()
.add(values, IntegerType).add(colors, StringType).add(const, IntegerType)
val inputEnc = RowEncoder(inputSchema)
val valuesFieldIndex = inputSchema.fieldIndex(values)
val numPartitions = 3
val trainData = input.repartition(numPartitions).select(values, colors, const)
.mapPartitions(iter => {
val ctx = TaskContext.get
val partId = ctx.partitionId
// Remove all instances of class 0 on partition 1
if (partId == 1) {
iter.flatMap(row => {
if (row.getInt(valuesFieldIndex) <= 0)
None
else Some(row)
})
} else {
// Add back three instances of each label on the other partitions
val oneOfEachExample = List(Row(0, "Blue", 2), Row(1, "Purple", 2), Row(2, "Black", 2), Row(3, "Gray", 2))
(iter.toList.union(oneOfEachExample).union(oneOfEachExample).union(oneOfEachExample)).toIterator
}
})(inputEnc).cache()
// Debug output: print which rows landed on which partition
trainData.foreachPartition { rows =>
rows.foreach { row =>
val ctx = TaskContext.get
val partId = ctx.partitionId
println(s"Row: $row partition id: $partId")
}
}
val stratifiedInputData = new StratifiedRepartition().setLabelCol(values)
.setMode(SPConstants.Equal).transform(trainData)
// Assert stratified data contains all labels across all partitions; the trailing
// count() forces evaluation
stratifiedInputData
.mapPartitions(iter => {
val actualLabels = iter.map(row => row.getInt(valuesFieldIndex))
.toArray.distinct.sorted.toList
val expectedLabels = (0 to 3).toList
if (actualLabels != expectedLabels)
throw new Exception(s"Missing labels, actual: $actualLabels, expected: $expectedLabels")
iter
})(inputEnc).count()
val stratifiedMixedInputData = new StratifiedRepartition().setLabelCol(values)
.setMode(SPConstants.Mixed).transform(trainData)
assert(stratifiedMixedInputData.count() >= trainData.count())
val stratifiedOriginalInputData = new StratifiedRepartition().setLabelCol(values)
.setMode(SPConstants.Original).transform(trainData)
assert(stratifiedOriginalInputData.count() == trainData.count())
}

def testObjects(): Seq[TestObject[StratifiedRepartition]] = List(new TestObject(
new StratifiedRepartition().setLabelCol(values).setMode(SPConstants.Equal), input))

def reader: MLReadable[_] = StratifiedRepartition
}
