From b1c02277c9490f7866834d4fb8028098b62066b9 Mon Sep 17 00:00:00 2001 From: Zheng RuiFeng Date: Thu, 7 Dec 2017 19:05:08 +0800 Subject: [PATCH 1/3] create pr --- .../spark/ml/classification/Classifier.scala | 4 ++ .../spark/ml/classification/OneVsRest.scala | 58 +++++-------------- 2 files changed, 17 insertions(+), 45 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala index bc0b49d48d323..5ed444a88e36f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala @@ -213,6 +213,10 @@ abstract class ClassificationModel[FeaturesType, M <: ClassificationModel[Featur */ protected def predictRaw(features: FeaturesType): Vector + protected[classification] def predictRaw(features: Any): Vector = { + predictRaw(features.asInstanceOf[FeaturesType]) + } + /** * Given a vector of raw predictions, select the predicted label. * This may be overridden to support thresholds which favor particular labels. diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala index 3ab99b35ece2b..9cee2ce57a2b0 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala @@ -156,54 +156,22 @@ final class OneVsRestModel private[ml] ( // Check schema transformSchema(dataset.schema, logging = true) - // determine the input columns: these need to be passed through - val origCols = dataset.schema.map(f => col(f.name)) - - // add an accumulator column to store predictions of all the models - val accColName = "mbc$acc" + UUID.randomUUID().toString - val initUDF = udf { () => Map[Int, Double]() } - val newDataset = dataset.withColumn(accColName, initUDF()) - - // persist if underlying dataset is not persistent. - val handlePersistence = dataset.storageLevel == StorageLevel.NONE - if (handlePersistence) { - newDataset.persist(StorageLevel.MEMORY_AND_DISK) - } - - // update the accumulator column with the result of prediction of models - val aggregatedDataset = models.zipWithIndex.foldLeft[DataFrame](newDataset) { - case (df, (model, index)) => - val rawPredictionCol = model.getRawPredictionCol - val columns = origCols ++ List(col(rawPredictionCol), col(accColName)) - - // add temporary column to store intermediate scores and update - val tmpColName = "mbc$tmp" + UUID.randomUUID().toString - val updateUDF = udf { (predictions: Map[Int, Double], prediction: Vector) => - predictions + ((index, prediction(1))) + val predictUDF = udf { (features: Any) => + var i = 0 + var maxIndex = -1 + var maxPred = Double.MinValue + while (i < models.length) { + val pred = models(i).predictRaw(features)(1) + if (pred > maxPred) { + maxIndex = i + maxPred = pred } - model.setFeaturesCol($(featuresCol)) - val transformedDataset = model.transform(df).select(columns: _*) - val updatedDataset = transformedDataset - .withColumn(tmpColName, updateUDF(col(accColName), col(rawPredictionCol))) - val newColumns = origCols ++ List(col(tmpColName)) - - // switch out the intermediate column with the accumulator column - updatedDataset.select(newColumns: _*).withColumnRenamed(tmpColName, accColName) - } - - if (handlePersistence) { - newDataset.unpersist() - } - - // output the index of the classifier with highest confidence as prediction - val labelUDF = udf { (predictions: Map[Int, Double]) => - predictions.maxBy(_._2)._1.toDouble + i += 1 + } + maxIndex } - // output label and label metadata as prediction - aggregatedDataset - .withColumn($(predictionCol), labelUDF(col(accColName)), labelMetadata) - .drop(accColName) + dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))), labelMetadata) } @Since("1.4.1") From 45f70c87822b7c4bd872c0729509eeed0a2f9441 Mon Sep 17 00:00:00 2001 From: Zheng RuiFeng Date: Fri, 8 Dec 2017 10:22:07 +0800 Subject: [PATCH 2/3] update pr --- .../org/apache/spark/ml/classification/Classifier.scala | 2 +- .../scala/org/apache/spark/ml/classification/OneVsRest.scala | 5 +---- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala index 5ed444a88e36f..6d9bd8f7591fb 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala @@ -213,7 +213,7 @@ abstract class ClassificationModel[FeaturesType, M <: ClassificationModel[Featur */ protected def predictRaw(features: FeaturesType): Vector - protected[classification] def predictRaw(features: Any): Vector = { + protected[classification] def predictRawAsFeaturesType(features: Any): Vector = { predictRaw(features.asInstanceOf[FeaturesType]) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala index 9cee2ce57a2b0..762a13dc3fd74 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala @@ -17,8 +17,6 @@ package org.apache.spark.ml.classification -import java.util.UUID - import scala.concurrent.Future import scala.concurrent.duration.Duration import scala.language.existentials @@ -32,7 +30,6 @@ import org.apache.spark.SparkContext import org.apache.spark.annotation.Since import org.apache.spark.ml._ import org.apache.spark.ml.attribute._ -import org.apache.spark.ml.linalg.Vector import org.apache.spark.ml.param.{Param, ParamMap, ParamPair, Params} import org.apache.spark.ml.param.shared.{HasParallelism, HasWeightCol} import org.apache.spark.ml.util._ @@ -161,7 +158,7 @@ final class OneVsRestModel private[ml] ( var maxIndex = -1 var maxPred = Double.MinValue while (i < models.length) { - val pred = models(i).predictRaw(features)(1) + val pred = models(i).predictRawAsFeaturesType(features)(1) if (pred > maxPred) { maxIndex = i maxPred = pred From b95795fefce3dcb0bc3b91cbf70f09d7747affd1 Mon Sep 17 00:00:00 2001 From: Zheng RuiFeng Date: Fri, 8 Dec 2017 11:18:33 +0800 Subject: [PATCH 3/3] int => double --- .../scala/org/apache/spark/ml/classification/OneVsRest.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala index 762a13dc3fd74..68ab32c051183 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala @@ -155,7 +155,7 @@ final class OneVsRestModel private[ml] ( val predictUDF = udf { (features: Any) => var i = 0 - var maxIndex = -1 + var maxIndex = Double.NaN var maxPred = Double.MinValue while (i < models.length) { val pred = models(i).predictRawAsFeaturesType(features)(1)