Skip to content

Commit

Permalink
[SPARK-10349] [ML] OneVsRest use 'when ... otherwise' not UDF to gene…
Browse files Browse the repository at this point in the history
…rate new label at binary reduction

Currently OneVsRest use UDF to generate new binary label during training.
Considering that [SPARK-7321](https://issues.apache.org/jira/browse/SPARK-7321) has been merged, we can use ```when ... otherwise``` which will be more efficiency.

Author: Yanbo Liang <ybliang8@gmail.com>

Closes #8519 from yanboliang/spark-10349.
  • Loading branch information
yanboliang authored and mengxr committed Aug 31, 2015
1 parent 540bdee commit fe16fd0
Showing 1 changed file with 2 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,6 @@ final class OneVsRestModel private[ml] (
// add an accumulator column to store predictions of all the models
val accColName = "mbc$acc" + UUID.randomUUID().toString
val initUDF = udf { () => Map[Int, Double]() }
val mapType = MapType(IntegerType, DoubleType, valueContainsNull = false)
val newDataset = dataset.withColumn(accColName, initUDF())

// persist if underlying dataset is not persistent.
Expand Down Expand Up @@ -195,16 +194,11 @@ final class OneVsRest(override val uid: String)

// create k columns, one for each binary classifier.
val models = Range(0, numClasses).par.map { index =>
val labelUDF = udf { (label: Double) =>
if (label.toInt == index) 1.0 else 0.0
}

// generate new label metadata for the binary problem.
// TODO: use when ... otherwise after SPARK-7321 is merged
val newLabelMeta = BinaryAttribute.defaultAttr.withName("label").toMetadata()
val labelColName = "mc2b$" + index
val trainingDataset =
multiclassLabeled.withColumn(labelColName, labelUDF(col($(labelCol))), newLabelMeta)
val trainingDataset = multiclassLabeled.withColumn(
labelColName, when(col($(labelCol)) === index.toDouble, 1.0).otherwise(0.0), newLabelMeta)
val classifier = getClassifier
val paramMap = new ParamMap()
paramMap.put(classifier.labelCol -> labelColName)
Expand Down

0 comments on commit fe16fd0

Please sign in to comment.