From 5e068e3c37d146c9cde0801d321b69709a87b591 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 4 Feb 2015 22:41:37 -0800 Subject: [PATCH] [MLlib] Minor: UDF style update. --- .../apache/spark/ml/classification/LogisticRegression.scala | 4 +++- .../main/scala/org/apache/spark/ml/recommendation/ALS.scala | 4 ++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index df90078de148f..bdb498e565b88 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -137,7 +137,9 @@ class LogisticRegressionModel private[ml] ( 1.0 / (1.0 + math.exp(-margin)) } : Double) val t = map(threshold) - val predictFunction = udf((score: Double) => { if (score > t) 1.0 else 0.0 } : Double) + val predictFunction = udf { score: Double => + if (score > t) 1.0 else 0.0 + } dataset .select($"*", scoreFunction(col(map(featuresCol))).as(map(scoreCol))) .select($"*", predictFunction(col(map(scoreCol))).as(map(predictionCol))) diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala index c7bec7a845a11..aa6d89030727b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala @@ -129,13 +129,13 @@ class ALSModel private[ml] ( // Register a UDF for DataFrame, and then // create a new column named map(predictionCol) by running the predict UDF. - val predict = udf((userFeatures: Seq[Float], itemFeatures: Seq[Float]) => { + val predict = udf { (userFeatures: Seq[Float], itemFeatures: Seq[Float]) => if (userFeatures != null && itemFeatures != null) { blas.sdot(k, userFeatures.toArray, 1, itemFeatures.toArray, 1) } else { Float.NaN } - } : Float) + } dataset .join(users, dataset(map(userCol)) === users("id"), "left") .join(items, dataset(map(itemCol)) === items("id"), "left")