From 0a97b59d87200debfbd759d3da64c626cc6b6ac4 Mon Sep 17 00:00:00 2001
From: Bryan Cutler
Date: Thu, 9 Jul 2015 15:04:09 -0700
Subject: [PATCH] [SPARK-8400] Allow ml.ALS to use -1 block size to
 auto-configure

---
 .../org/apache/spark/ml/param/params.scala    | 13 +++++++++++++
 .../apache/spark/ml/recommendation/ALS.scala  | 23 ++++++++++++++++++-----
 .../apache/spark/ml/param/ParamsSuite.scala   | 10 ++++++++++
 .../spark/ml/recommendation/ALSSuite.scala    |  7 +++++++
 4 files changed, 48 insertions(+), 5 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
index 50c0d855066f8..21be83a1cccb2 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
@@ -134,6 +134,14 @@ object ParamValidators {
     getDouble(value) <= upperBound
   }
 
+  /** Check if value == requiredValue */
+  def eq[T](requiredValue: T): T => Boolean = {
+    case value @ (_: Double | _: Float) =>
+      throw new IllegalArgumentException("ParamValidators.eq not intended for real numbers, " +
+        s"value given is $value")
+    case value: Any => value == requiredValue
+  }
+
   /**
    * Check for value in range lowerBound to upperBound.
    * @param lowerInclusive If true, check for value >= lowerBound.
@@ -166,6 +174,11 @@ object ParamValidators {
   def inArray[T](allowed: java.util.List[T]): T => Boolean = { (value: T) =>
     allowed.contains(value)
   }
+
+  /** Use two validators together in a logical OR expression */
+  def or[T](a: T => Boolean, b: T => Boolean): T => Boolean = { (value: T) =>
+    a(value) || b(value)
+  }
 }
 
 // specialize primitive-typed params because Java doesn't recognize scala.Double, scala.Int, ...
diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
index 2e44cd4cc6a22..a7536a9deb232 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
@@ -88,23 +88,23 @@ private[recommendation] trait ALSParams extends ALSModelParams with HasMaxIter w
   def getRank: Int = $(rank)
 
   /**
-   * Param for number of user blocks (>= 1).
+   * Param for number of user blocks (>= 1, or -1 to auto-configure).
    * Default: 10
    * @group param
    */
   val numUserBlocks = new IntParam(this, "numUserBlocks", "number of user blocks",
-    ParamValidators.gtEq(1))
+    ParamValidators.or(ParamValidators.gtEq(1), ParamValidators.eq(-1)))
 
   /** @group getParam */
   def getNumUserBlocks: Int = $(numUserBlocks)
 
   /**
-   * Param for number of item blocks (>= 1).
+   * Param for number of item blocks (>= 1, or -1 to auto-configure).
    * Default: 10
    * @group param
    */
   val numItemBlocks = new IntParam(this, "numItemBlocks", "number of item blocks",
-    ParamValidators.gtEq(1))
+    ParamValidators.or(ParamValidators.gtEq(1), ParamValidators.eq(-1)))
 
   /** @group getParam */
   def getNumItemBlocks: Int = $(numItemBlocks)
@@ -321,8 +321,21 @@ class ALS(override val uid: String) extends Estimator[ALSModel] with ALSParams {
       .map { row =>
         Rating(row.getInt(0), row.getInt(1), row.getFloat(2))
       }
+
+    // check if num blocks should be auto-configured
+    val numUserBlocksFinal = if ($(numUserBlocks) == -1) {
+      math.max(dataset.sqlContext.sparkContext.defaultParallelism, ratings.partitions.size / 2)
+    } else {
+      $(numUserBlocks)
+    }
+    val numItemBlocksFinal = if ($(numItemBlocks) == -1) {
+      math.max(dataset.sqlContext.sparkContext.defaultParallelism, ratings.partitions.size / 2)
+    } else {
+      $(numItemBlocks)
+    }
+
     val (userFactors, itemFactors) = ALS.train(ratings, rank = $(rank),
-      numUserBlocks = $(numUserBlocks), numItemBlocks = $(numItemBlocks),
+      numUserBlocks = numUserBlocksFinal, numItemBlocks = numItemBlocksFinal,
       maxIter = $(maxIter), regParam = $(regParam), implicitPrefs = $(implicitPrefs),
       alpha = $(alpha), nonnegative = $(nonnegative),
       checkpointInterval = $(checkpointInterval), seed = $(seed))
diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
index 050d4170ea017..2066359b923cf 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
@@ -181,6 +181,13 @@ class ParamsSuite extends SparkFunSuite {
     val ltEq1Double = ParamValidators.ltEq[Double](1)
     assert(ltEq1Double(1.0) && !ltEq1Double(1.1))
 
+    val eq1Int = ParamValidators.eq[Int](1)
+    assert(eq1Int(1) && !eq1Int(0))
+    val eqDouble = ParamValidators.eq[Double](2/3.0)
+    intercept[IllegalArgumentException] {
+      eqDouble(1 - 1/3.0)
+    }
+
     val inRange02IntInclusive = ParamValidators.inRange[Int](0, 2)
     assert(inRange02IntInclusive(0) && inRange02IntInclusive(1) && inRange02IntInclusive(2) &&
       !inRange02IntInclusive(-1) && !inRange02IntInclusive(3))
@@ -199,6 +206,9 @@ class ParamsSuite extends SparkFunSuite {
 
     val inArray = ParamValidators.inArray[Int](Array(1, 2))
     assert(inArray(1) && inArray(2) && !inArray(0))
+
+    val orInt = ParamValidators.or(ParamValidators.eq[Int](-1), gtEq1Int)
+    assert(orInt(-1) && orInt(1) && !orInt(0))
   }
 }
 
diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
index 2e5cfe7027eb6..13a3142c0cd31 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
@@ -412,6 +412,13 @@ class ALSSuite extends SparkFunSuite with MLlibTestSparkContext with Logging {
       numItemBlocks = 5, numUserBlocks = 5)
   }
 
+  test("auto-configured block settings") {
+    val (training, test) =
+      genExplicitTestData(numUsers = 4, numItems = 4, rank = 1)
+    testALS(training, test, maxIter = 2, rank = 1, regParam = 1e-4, targetRMSE = 0.002,
+      numUserBlocks = -1, numItemBlocks = -1)
+  }
+
   test("implicit feedback") {
     val (training, test) =
       genImplicitTestData(numUsers = 20, numItems = 40, rank = 2, noiseStd = 0.01)
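Reviewer note, not part of the patch: a minimal sketch of how the combined
validator above behaves. Only ParamValidators.or, ParamValidators.gtEq, and
the new ParamValidators.eq come from this change; the name isValidBlockCount
is illustrative.

  import org.apache.spark.ml.param.ParamValidators

  // Composed exactly as in the numUserBlocks/numItemBlocks declarations:
  // accept any value >= 1, or exactly -1 as the auto-configure sentinel.
  val isValidBlockCount: Int => Boolean =
    ParamValidators.or(ParamValidators.gtEq[Int](1), ParamValidators.eq[Int](-1))

  assert(isValidBlockCount(10))  // explicit block counts >= 1 pass
  assert(isValidBlockCount(-1))  // the -1 sentinel passes
  assert(!isValidBlockCount(0))  // 0 and other negative values are rejected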
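And a caller-side sketch of the new auto-configure path, assuming a DataFrame
named ratings with ALS's default "user", "item", and "rating" columns (the
DataFrame itself is hypothetical and not defined in this patch):

  import org.apache.spark.ml.recommendation.ALS

  val als = new ALS()
    .setMaxIter(10)
    .setRegParam(0.01)
    .setNumUserBlocks(-1)  // -1: auto-configure from defaultParallelism
    .setNumItemBlocks(-1)  //     and the rating RDD's partition count
  val model = als.fit(ratings)

With -1, each block count resolves to
max(defaultParallelism, ratings.partitions.size / 2), matching the fit()
logic in the ALS.scala hunk above.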