Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
fix: fix lightgbm stuck in multiclass scenario and added stratified repartition transformer (#618)
- Loading branch information
1 parent
85fb3fc
commit d518b8a
Showing
7 changed files
with
285 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
77 changes: 77 additions & 0 deletions
77
src/main/scala/com/microsoft/ml/spark/stages/StratifiedRepartition.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
// Copyright (C) Microsoft Corporation. All rights reserved. | ||
// Licensed under the MIT License. See LICENSE in project root for information. | ||
|
||
package com.microsoft.ml.spark.stages | ||
|
||
import com.microsoft.ml.spark.core.contracts.{HasLabelCol, Wrappable} | ||
import org.apache.spark.RangePartitioner | ||
import org.apache.spark.ml.Transformer | ||
import org.apache.spark.ml.param._ | ||
import org.apache.spark.ml.param.shared.HasSeed | ||
import org.apache.spark.ml.util._ | ||
import org.apache.spark.sql.types._ | ||
import org.apache.spark.sql.{DataFrame, Dataset} | ||
|
||
/** Constants for <code>StratifiedRepartition</code>. */
object SPConstants {
  /** Name of the aggregate count column produced by groupBy().count(). */
  val Count: String = "count"
  /** Mode: resample with replacement so every label is equally represented. */
  val Equal: String = "equal"
  /** Mode: keep the label ratios from the original dataset. */
  val Original: String = "original"
  /** Mode: heuristic blend between equal and original (the default). */
  val Mixed: String = "mixed"
}
|
||
object StratifiedRepartition extends DefaultParamsReadable[DropColumns] | ||
|
||
/** <code>StratifiedRepartition</code> repartitions the DataFrame such that each label is selected in each partition.
  * This may be necessary in some cases such as in LightGBM multiclass classification, where it is necessary for
  * at least one instance of each label to be present on each partition.
  */
class StratifiedRepartition(val uid: String) extends Transformer with Wrappable
  with DefaultParamsWritable with HasLabelCol with HasSeed {
  def this() = this(Identifiable.randomUID("StratifiedRepartition"))

  /** Repartitioning mode; one of SPConstants.{Equal, Original, Mixed}, default Mixed. */
  val mode = new Param[String](this, "mode",
    "Specify equal to repartition with replacement across all labels, specify " +
    "original to keep the ratios in the original dataset, or specify mixed to use a heuristic")
  setDefault(mode -> SPConstants.Mixed)

  def getMode: String = $(mode)
  def setMode(value: String): this.type = set(mode, value)

  /** Stratified-resamples the dataset and redistributes rows so every label value
    * appears on every partition.
    *
    * NOTE(review): the label column is read with row.getInt, so it must be of IntegerType.
    *
    * @param dataset - The input dataset, to be transformed
    * @return The DataFrame that results from stratified repartitioning
    */
  override def transform(dataset: Dataset[_]): DataFrame = {
    // Count unique values in label column
    val distinctLabelCounts = dataset.select(getLabelCol).groupBy(getLabelCol).count().collect()
    val labelToCount = distinctLabelCounts.map(row => (row.getInt(0), row.getLong(1)))
    // Per-label sampling fraction consumed by sampleByKeyExact below.
    val labelToFraction =
      getMode match {
        case SPConstants.Equal => getEqualLabelCount(labelToCount, dataset)
        case SPConstants.Mixed => {
          // Heuristic: scale original counts by the mean equal-mode fraction.
          // NOTE(review): count / normalizedRatio makes fractions proportional to the
          // original label counts — confirm this matches the intended heuristic.
          val equalLabelToCount = getEqualLabelCount(labelToCount, dataset)
          val normalizedRatio = equalLabelToCount.map { case (label, count) => count }.sum / labelToCount.size
          labelToCount.map { case (label, count) => (label, count / normalizedRatio)}.toMap
        }
        // Fraction 1.0 per label preserves the original label ratios.
        case SPConstants.Original => labelToCount.map { case (label, count) => (label, 1.0) }.toMap
        case _ => throw new Exception(s"Unknown mode specified to StratifiedRepartition: $getMode")
      }
    val labelColIndex = dataset.schema.fieldIndex(getLabelCol)
    // Stratified resample keyed by label, then re-key each row by its per-partition index
    // so the range partitioner below spreads every label's rows across all partitions.
    val spdata = dataset.toDF().rdd.keyBy(row => row.getInt(labelColIndex))
      .sampleByKeyExact(true, labelToFraction, getSeed)
      .mapPartitions(keyToRow => keyToRow.zipWithIndex.map { case ((key, row), index) => (index, row) })
    val rangePartitioner = new RangePartitioner(dataset.rdd.getNumPartitions, spdata)
    val rspdata = spdata.partitionBy(rangePartitioner).mapPartitions(keyToRow =>
      keyToRow.map{case (key, row) => row}).persist()
    dataset.sqlContext.createDataFrame(rspdata, dataset.schema)
  }

  /** Computes, per label, the with-replacement sampling fraction that would equalize all
    * label counts at max(largest label count, number of partitions). Fractions may exceed
    * 1.0, which is valid because sampling is done with replacement.
    */
  private def getEqualLabelCount(labelToCount: Array[(Int, Long)], dataset: Dataset[_]): Map[Int, Double] = {
    val maxLabelCount = Math.max(labelToCount.map { case (label, count) => count }.max, dataset.rdd.getNumPartitions)
    labelToCount.map { case (label, count) => (label, maxLabelCount.toDouble / count) }.toMap
  }

  // Repartitioning does not alter the schema.
  def transformSchema(schema: StructType): StructType = schema

  // Fixed copy-paste bug: return type was DropColumns, which made defaultCopy cast the
  // fresh StratifiedRepartition instance to an unrelated transformer type.
  def copy(extra: ParamMap): StratifiedRepartition = defaultCopy(extra)
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
93 changes: 93 additions & 0 deletions
93
src/test/scala/com/microsoft/ml/spark/stages/StratifiedRepartitionSuite.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
// Copyright (C) Microsoft Corporation. All rights reserved. | ||
// Licensed under the MIT License. See LICENSE in project root for information. | ||
|
||
package com.microsoft.ml.spark.stages | ||
|
||
import com.microsoft.ml.spark.core.test.base.TestBase | ||
import com.microsoft.ml.spark.core.test.fuzzing.{TestObject, TransformerFuzzing} | ||
import org.apache.spark.TaskContext | ||
import org.apache.spark.ml.util.MLReadable | ||
import org.apache.spark.sql.Row | ||
import org.apache.spark.sql.catalyst.encoders.RowEncoder | ||
import org.apache.spark.sql.types.{IntegerType, StringType, StructType} | ||
|
||
/** Tests for [[StratifiedRepartition]]: verifies that after the transform every label
  * value is present on every partition, for all three modes (Equal, Mixed, Original).
  */
class StratifiedRepartitionSuite extends TestBase with TransformerFuzzing[StratifiedRepartition] {

  import session.implicits._

  // Column names used throughout the suite.
  val values = "values"
  val colors = "colors"
  val const = "const"

  // Four label classes (0 through 3), three rows each.
  val input = Seq(
    (0, "Blue", 2),
    (0, "Red", 2),
    (0, "Green", 2),
    (1, "Purple", 2),
    (1, "Orange", 2),
    (1, "Indigo", 2),
    (2, "Violet", 2),
    (2, "Black", 2),
    (2, "White", 2),
    (3, "Gray", 2),
    (3, "Yellow", 2),
    (3, "Cerulean", 2)
  ).toDF(values, colors, const)

  test("Assert doing a stratified repartition will ensure all keys exist across all partitions") {
    val inputSchema = new StructType()
      .add(values, IntegerType).add(colors, StringType).add(const, IntegerType)
    val inputEnc = RowEncoder(inputSchema)
    val valuesFieldIndex = inputSchema.fieldIndex(values)
    val numPartitions = 3
    // Deliberately skew the data layout: strip label 0 from partition 1 and pad the
    // other partitions, so a naive repartition would leave partitions missing labels.
    val trainData = input.repartition(numPartitions).select(values, colors, const)
      .mapPartitions(iter => {
        val ctx = TaskContext.get
        val partId = ctx.partitionId
        // Remove all instances of 0 class on partition 1
        if (partId == 1) {
          iter.flatMap(row => {
            if (row.getInt(valuesFieldIndex) <= 0)
              None
            else Some(row)
          })
        } else {
          // Add back at least 3 instances on other partitions
          val oneOfEachExample = List(Row(0, "Blue", 2), Row(1, "Purple", 2), Row(2, "Black", 2), Row(3, "Gray", 2))
          (iter.toList.union(oneOfEachExample).union(oneOfEachExample).union(oneOfEachExample)).toIterator
        }
      })(inputEnc).cache()
    // Some debug to understand what data is on which partition
    trainData.foreachPartition { rows =>
      rows.foreach { row =>
        val ctx = TaskContext.get
        val partId = ctx.partitionId
        println(s"Row: $row partition id: $partId")
      }
    }
    val stratifiedInputData = new StratifiedRepartition().setLabelCol(values)
      .setMode(SPConstants.Equal).transform(trainData)
    // Assert stratified data contains all keys across all partitions, with extra count
    // for it to be evaluated
    stratifiedInputData
      .mapPartitions(iter => {
        val actualLabels = iter.map(row => row.getInt(valuesFieldIndex))
          .toArray.distinct.sorted.toList
        val expectedLabels = (0 to 3).toList
        if (actualLabels != expectedLabels)
          throw new Exception(s"Missing labels, actual: $actualLabels, expected: $expectedLabels")
        iter
      })(inputEnc).count()
    // Mixed mode resamples with replacement, so the row count should not shrink.
    val stratifiedMixedInputData = new StratifiedRepartition().setLabelCol(values)
      .setMode(SPConstants.Mixed).transform(trainData)
    assert(stratifiedMixedInputData.count() >= trainData.count())
    // Original mode keeps the dataset's label ratios, so the row count is unchanged.
    val stratifiedOriginalInputData = new StratifiedRepartition().setLabelCol(values)
      .setMode(SPConstants.Original).transform(trainData)
    assert(stratifiedOriginalInputData.count() == trainData.count())
  }

  // Fuzzing hooks required by TransformerFuzzing.
  def testObjects(): Seq[TestObject[StratifiedRepartition]] = List(new TestObject(
    new StratifiedRepartition().setLabelCol(values).setMode(SPConstants.Equal), input))

  def reader: MLReadable[_] = StratifiedRepartition
}