From 0af7c41e129afb630b351dea5f409ddc0c6d01ff Mon Sep 17 00:00:00 2001 From: Michael Davies Date: Wed, 21 Jan 2015 17:06:25 +0000 Subject: [PATCH 01/22] SPARK-5309: Use Dictionary for Binary->String conversion Use of Dictionary will reduce conversion overhead which is significant when converting from Binary in UTF to Java String. Dictionaries are use in Parquet where columns hold a limited set of values which is often the case. --- .../spark/sql/parquet/ParquetConverter.scala | 50 ++++++++++++++----- 1 file changed, 38 insertions(+), 12 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala index b4aed04199129..d8038c3b1f0e4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala @@ -17,6 +17,9 @@ package org.apache.spark.sql.parquet + +import parquet.column.Dictionary + import scala.collection.mutable.{Buffer, ArrayBuffer, HashMap} import parquet.io.api.{PrimitiveConverter, GroupConverter, Binary, Converter} @@ -102,12 +105,8 @@ private[sql] object CatalystConverter { } // Strings, Shorts and Bytes do not have a corresponding type in Parquet // so we need to treat them separately - case StringType => { - new CatalystPrimitiveConverter(parent, fieldIndex) { - override def addBinary(value: Binary): Unit = - parent.updateString(fieldIndex, value) - } - } + case StringType => + new CatalystPrimitiveStringConverter(parent, fieldIndex) case ShortType => { new CatalystPrimitiveConverter(parent, fieldIndex) { override def addInt(value: Int): Unit = @@ -197,8 +196,8 @@ private[parquet] abstract class CatalystConverter extends GroupConverter { protected[parquet] def updateBinary(fieldIndex: Int, value: Binary): Unit = updateField(fieldIndex, value.getBytes) - protected[parquet] def updateString(fieldIndex: Int, value: Binary): Unit = - updateField(fieldIndex, value.toStringUsingUTF8) + protected[parquet] def updateString(fieldIndex: Int, value: String): Unit = + updateField(fieldIndex, value) protected[parquet] def updateDecimal(fieldIndex: Int, value: Binary, ctype: DecimalType): Unit = { updateField(fieldIndex, readDecimal(new Decimal(), value, ctype)) @@ -384,8 +383,8 @@ private[parquet] class CatalystPrimitiveRowConverter( override protected[parquet] def updateBinary(fieldIndex: Int, value: Binary): Unit = current.update(fieldIndex, value.getBytes) - override protected[parquet] def updateString(fieldIndex: Int, value: Binary): Unit = - current.setString(fieldIndex, value.toStringUsingUTF8) + override protected[parquet] def updateString(fieldIndex: Int, value: String): Unit = + current.setString(fieldIndex, value) override protected[parquet] def updateDecimal( fieldIndex: Int, value: Binary, ctype: DecimalType): Unit = { @@ -426,6 +425,33 @@ private[parquet] class CatalystPrimitiveConverter( parent.updateLong(fieldIndex, value) } +/** + * A `parquet.io.api.PrimitiveConverter` that converts Parquet Binary to Catalyst String. + * Supports dictionaries to reduce Binary to String conversion overhead. + * + * Follows pattern in Parquet of using dictionaries, where supported, for String conversion. + * + * @param parent The parent group converter. + * @param fieldIndex The index inside the record. 
+ */ +private[parquet] class CatalystPrimitiveStringConverter(parent: CatalystConverter, fieldIndex: Int) + extends CatalystPrimitiveConverter(parent, fieldIndex) { + + private var dict: Array[String] = null + + override def hasDictionarySupport: Boolean = true + + override def setDictionary(dictionary: Dictionary):Unit = + dict = (for (i <- 0 to dictionary.getMaxId) + yield dictionary.decodeToBinary(i).toStringUsingUTF8).toArray + + override def addValueFromDictionary(dictionaryId: Int): Unit = + parent.updateString(fieldIndex, dict(dictionaryId)) + + override def addBinary(value: Binary): Unit = + parent.updateString(fieldIndex, value.toStringUsingUTF8) +} + private[parquet] object CatalystArrayConverter { val INITIAL_ARRAY_SIZE = 20 } @@ -583,9 +609,9 @@ private[parquet] class CatalystNativeArrayConverter( elements += 1 } - override protected[parquet] def updateString(fieldIndex: Int, value: Binary): Unit = { + override protected[parquet] def updateString(fieldIndex: Int, value: String): Unit = { checkGrowBuffer() - buffer(elements) = value.toStringUsingUTF8.asInstanceOf[NativeType] + buffer(elements) = value.asInstanceOf[NativeType] elements += 1 } From aa1e22b17b4ce885febe6970a2451c7d17d0acfb Mon Sep 17 00:00:00 2001 From: Reza Zadeh Date: Wed, 21 Jan 2015 09:48:38 -0800 Subject: [PATCH 02/22] [MLlib] [SPARK-5301] Missing conversions and operations on IndexedRowMatrix and CoordinateMatrix * Transpose is missing from CoordinateMatrix (this is cheap to compute, so it should be there) * IndexedRowMatrix should be convertable to CoordinateMatrix (conversion added) Tests for both added. Author: Reza Zadeh Closes #4089 from rezazadeh/matutils and squashes the following commits: ec5238b [Reza Zadeh] Array -> Iterator to avoid temp array 3ce0b5d [Reza Zadeh] Array -> Iterator bbc907a [Reza Zadeh] Use 'i' for index, and zipWithIndex cb10ae5 [Reza Zadeh] remove unnecessary import a7ae048 [Reza Zadeh] Missing linear algebra utilities --- .../linalg/distributed/CoordinateMatrix.scala | 5 +++++ .../linalg/distributed/IndexedRowMatrix.scala | 17 +++++++++++++++++ .../distributed/CoordinateMatrixSuite.scala | 5 +++++ .../distributed/IndexedRowMatrixSuite.scala | 8 ++++++++ 4 files changed, 35 insertions(+) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrix.scala index 06d8915f3bfa1..b60559c853a50 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrix.scala @@ -69,6 +69,11 @@ class CoordinateMatrix( nRows } + /** Transposes this CoordinateMatrix. */ + def transpose(): CoordinateMatrix = { + new CoordinateMatrix(entries.map(x => MatrixEntry(x.j, x.i, x.value)), numCols(), numRows()) + } + /** Converts to IndexedRowMatrix. The number of columns must be within the integer range. 
*/ def toIndexedRowMatrix(): IndexedRowMatrix = { val nl = numCols() diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala index 181f507516485..c518271f04729 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala @@ -75,6 +75,23 @@ class IndexedRowMatrix( new RowMatrix(rows.map(_.vector), 0L, nCols) } + /** + * Converts this matrix to a + * [[org.apache.spark.mllib.linalg.distributed.CoordinateMatrix]]. + */ + def toCoordinateMatrix(): CoordinateMatrix = { + val entries = rows.flatMap { row => + val rowIndex = row.index + row.vector match { + case SparseVector(size, indices, values) => + Iterator.tabulate(indices.size)(i => MatrixEntry(rowIndex, indices(i), values(i))) + case DenseVector(values) => + Iterator.tabulate(values.size)(i => MatrixEntry(rowIndex, i, values(i))) + } + } + new CoordinateMatrix(entries, numRows(), numCols()) + } + /** * Computes the singular value decomposition of this IndexedRowMatrix. * Denote this matrix by A (m x n), this will compute matrices U, S, V such that A = U * S * V'. diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrixSuite.scala index f8709751efce6..80bef814ce50d 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrixSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrixSuite.scala @@ -73,6 +73,11 @@ class CoordinateMatrixSuite extends FunSuite with MLlibTestSparkContext { assert(mat.toBreeze() === expected) } + test("transpose") { + val transposed = mat.transpose() + assert(mat.toBreeze().t === transposed.toBreeze()) + } + test("toIndexedRowMatrix") { val indexedRowMatrix = mat.toIndexedRowMatrix() val expected = BDM( diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala index 741cd4997b853..b86c2ca5ff136 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala @@ -80,6 +80,14 @@ class IndexedRowMatrixSuite extends FunSuite with MLlibTestSparkContext { assert(rowMat.rows.collect().toSeq === data.map(_.vector).toSeq) } + test("toCoordinateMatrix") { + val idxRowMat = new IndexedRowMatrix(indexedRows) + val coordMat = idxRowMat.toCoordinateMatrix() + assert(coordMat.numRows() === m) + assert(coordMat.numCols() === n) + assert(coordMat.toBreeze() === idxRowMat.toBreeze()) + } + test("multiply a local matrix") { val A = new IndexedRowMatrix(indexedRows) val B = Matrices.dense(3, 2, Array(0.0, 1.0, 2.0, 3.0, 4.0, 5.0)) From 7450a992b3b543a373c34fc4444a528954ac4b4a Mon Sep 17 00:00:00 2001 From: "nate.crosswhite" Date: Wed, 21 Jan 2015 10:32:10 -0800 Subject: [PATCH 03/22] [SPARK-4749] [mllib]: Allow initializing KMeans clusters using a seed This implements the functionality for SPARK-4749 and provides units tests in Scala and PySpark Author: nate.crosswhite Author: nxwhite-str Author: Xiangrui Meng Closes #3610 from nxwhite-str/master and squashes the following commits: a2ebbd3 [nxwhite-str] Merge 
pull request #1 from mengxr/SPARK-4749-kmeans-seed 7668124 [Xiangrui Meng] minor updates f8d5928 [nate.crosswhite] Addressing PR issues 277d367 [nate.crosswhite] Merge remote-tracking branch 'upstream/master' 9156a57 [nate.crosswhite] Merge remote-tracking branch 'upstream/master' 5d087b4 [nate.crosswhite] Adding KMeans train with seed and Scala unit test 616d111 [nate.crosswhite] Merge remote-tracking branch 'upstream/master' 35c1884 [nate.crosswhite] Add kmeans initial seed to pyspark API --- .../mllib/api/python/PythonMLLibAPI.scala | 6 ++- .../spark/mllib/clustering/KMeans.scala | 48 +++++++++++++++---- .../spark/mllib/clustering/KMeansSuite.scala | 21 ++++++++ python/pyspark/mllib/clustering.py | 4 +- python/pyspark/mllib/tests.py | 17 ++++++- 5 files changed, 84 insertions(+), 12 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index 555da8c7e7ab3..430d763ef7ca7 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -266,12 +266,16 @@ class PythonMLLibAPI extends Serializable { k: Int, maxIterations: Int, runs: Int, - initializationMode: String): KMeansModel = { + initializationMode: String, + seed: java.lang.Long): KMeansModel = { val kMeansAlg = new KMeans() .setK(k) .setMaxIterations(maxIterations) .setRuns(runs) .setInitializationMode(initializationMode) + + if (seed != null) kMeansAlg.setSeed(seed) + try { kMeansAlg.run(data.rdd.persist(StorageLevel.MEMORY_AND_DISK)) } finally { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala index 54c301d3e9e14..6b5c934f015ba 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala @@ -19,14 +19,14 @@ package org.apache.spark.mllib.clustering import scala.collection.mutable.ArrayBuffer -import org.apache.spark.annotation.Experimental import org.apache.spark.Logging -import org.apache.spark.SparkContext._ +import org.apache.spark.annotation.Experimental import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.linalg.BLAS.{axpy, scal} import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel +import org.apache.spark.util.Utils import org.apache.spark.util.random.XORShiftRandom /** @@ -43,13 +43,14 @@ class KMeans private ( private var runs: Int, private var initializationMode: String, private var initializationSteps: Int, - private var epsilon: Double) extends Serializable with Logging { + private var epsilon: Double, + private var seed: Long) extends Serializable with Logging { /** * Constructs a KMeans instance with default parameters: {k: 2, maxIterations: 20, runs: 1, - * initializationMode: "k-means||", initializationSteps: 5, epsilon: 1e-4}. + * initializationMode: "k-means||", initializationSteps: 5, epsilon: 1e-4, seed: random}. */ - def this() = this(2, 20, 1, KMeans.K_MEANS_PARALLEL, 5, 1e-4) + def this() = this(2, 20, 1, KMeans.K_MEANS_PARALLEL, 5, 1e-4, Utils.random.nextLong()) /** Set the number of clusters to create (k). Default: 2. */ def setK(k: Int): this.type = { @@ -112,6 +113,12 @@ class KMeans private ( this } + /** Set the random seed for cluster initialization. 
*/ + def setSeed(seed: Long): this.type = { + this.seed = seed + this + } + /** * Train a K-means model on the given set of points; `data` should be cached for high * performance, because this is an iterative algorithm. @@ -255,7 +262,7 @@ class KMeans private ( private def initRandom(data: RDD[VectorWithNorm]) : Array[Array[VectorWithNorm]] = { // Sample all the cluster centers in one pass to avoid repeated scans - val sample = data.takeSample(true, runs * k, new XORShiftRandom().nextInt()).toSeq + val sample = data.takeSample(true, runs * k, new XORShiftRandom(this.seed).nextInt()).toSeq Array.tabulate(runs)(r => sample.slice(r * k, (r + 1) * k).map { v => new VectorWithNorm(Vectors.dense(v.vector.toArray), v.norm) }.toArray) @@ -273,7 +280,7 @@ class KMeans private ( private def initKMeansParallel(data: RDD[VectorWithNorm]) : Array[Array[VectorWithNorm]] = { // Initialize each run's center to a random point - val seed = new XORShiftRandom().nextInt() + val seed = new XORShiftRandom(this.seed).nextInt() val sample = data.takeSample(true, runs, seed).toSeq val centers = Array.tabulate(runs)(r => ArrayBuffer(sample(r).toDense)) @@ -333,7 +340,32 @@ object KMeans { /** * Trains a k-means model using the given set of parameters. * - * @param data training points stored as `RDD[Array[Double]]` + * @param data training points stored as `RDD[Vector]` + * @param k number of clusters + * @param maxIterations max number of iterations + * @param runs number of parallel runs, defaults to 1. The best model is returned. + * @param initializationMode initialization model, either "random" or "k-means||" (default). + * @param seed random seed value for cluster initialization + */ + def train( + data: RDD[Vector], + k: Int, + maxIterations: Int, + runs: Int, + initializationMode: String, + seed: Long): KMeansModel = { + new KMeans().setK(k) + .setMaxIterations(maxIterations) + .setRuns(runs) + .setInitializationMode(initializationMode) + .setSeed(seed) + .run(data) + } + + /** + * Trains a k-means model using the given set of parameters. + * + * @param data training points stored as `RDD[Vector]` * @param k number of clusters * @param maxIterations max number of iterations * @param runs number of parallel runs, defaults to 1. The best model is returned. 
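(Not part of the patch — a minimal usage sketch of the seeded `KMeans.train` overload added above, assuming Spark 1.x with MLlib on the classpath; the object name and toy data are illustrative only.)

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.clustering.KMeans
import org.apache.spark.mllib.linalg.Vectors

object SeededKMeansSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("seeded-kmeans").setMaster("local[2]"))
    // Toy data: 100 points along the diagonal (illustrative only).
    val data = sc.parallelize((0 until 100).map(i => Vectors.dense(i.toDouble, i.toDouble))).cache()

    // Two runs with the same seed start from the same initial centers.
    val model1 = KMeans.train(data, k = 3, maxIterations = 10, runs = 1,
      initializationMode = "k-means||", seed = 42L)
    val model2 = KMeans.train(data, k = 3, maxIterations = 10, runs = 1,
      initializationMode = "k-means||", seed = 42L)

    val reproducible = model1.clusterCenters.zip(model2.clusterCenters).forall { case (c1, c2) =>
      c1.toArray.zip(c2.toArray).forall { case (x, y) => math.abs(x - y) < 1e-12 }
    }
    println(s"Same cluster centers for seed 42: $reproducible")
    sc.stop()
  }
}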
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala index 9ebef8466c831..caee5917000aa 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala @@ -90,6 +90,27 @@ class KMeansSuite extends FunSuite with MLlibTestSparkContext { assert(model.clusterCenters.size === 3) } + test("deterministic initialization") { + // Create a large-ish set of points for clustering + val points = List.tabulate(1000)(n => Vectors.dense(n, n)) + val rdd = sc.parallelize(points, 3) + + for (initMode <- Seq(RANDOM, K_MEANS_PARALLEL)) { + // Create three deterministic models and compare cluster means + val model1 = KMeans.train(rdd, k = 10, maxIterations = 2, runs = 1, + initializationMode = initMode, seed = 42) + val centers1 = model1.clusterCenters + + val model2 = KMeans.train(rdd, k = 10, maxIterations = 2, runs = 1, + initializationMode = initMode, seed = 42) + val centers2 = model2.clusterCenters + + centers1.zip(centers2).foreach { case (c1, c2) => + assert(c1 ~== c2 absTol 1E-14) + } + } + } + test("single cluster with big dataset") { val smallData = Array( Vectors.dense(1.0, 2.0, 6.0), diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py index e2492eef5bd6a..6b713aa39374e 100644 --- a/python/pyspark/mllib/clustering.py +++ b/python/pyspark/mllib/clustering.py @@ -78,10 +78,10 @@ def predict(self, x): class KMeans(object): @classmethod - def train(cls, rdd, k, maxIterations=100, runs=1, initializationMode="k-means||"): + def train(cls, rdd, k, maxIterations=100, runs=1, initializationMode="k-means||", seed=None): """Train a k-means clustering model.""" model = callMLlibFunc("trainKMeansModel", rdd.map(_convert_to_vector), k, maxIterations, - runs, initializationMode) + runs, initializationMode, seed) centers = callJavaFunc(rdd.context, model.clusterCenters) return KMeansModel([c.toArray() for c in centers]) diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index 140c22b5fd4e8..f48e3d6dacb4b 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -140,7 +140,7 @@ class ListTests(PySparkTestCase): as NumPy arrays. """ - def test_clustering(self): + def test_kmeans(self): from pyspark.mllib.clustering import KMeans data = [ [0, 1.1], @@ -152,6 +152,21 @@ def test_clustering(self): self.assertEquals(clusters.predict(data[0]), clusters.predict(data[1])) self.assertEquals(clusters.predict(data[2]), clusters.predict(data[3])) + def test_kmeans_deterministic(self): + from pyspark.mllib.clustering import KMeans + X = range(0, 100, 10) + Y = range(0, 100, 10) + data = [[x, y] for x, y in zip(X, Y)] + clusters1 = KMeans.train(self.sc.parallelize(data), + 3, initializationMode="k-means||", seed=42) + clusters2 = KMeans.train(self.sc.parallelize(data), + 3, initializationMode="k-means||", seed=42) + centers1 = clusters1.centers + centers2 = clusters2.centers + for c1, c2 in zip(centers1, centers2): + # TODO: Allow small numeric difference. 
+ self.assertTrue(array_equal(c1, c2)) + def test_classification(self): from pyspark.mllib.classification import LogisticRegressionWithSGD, SVMWithSGD, NaiveBayes from pyspark.mllib.tree import DecisionTree From 3ee3ab592eee831d759c940eb68231817ad6d083 Mon Sep 17 00:00:00 2001 From: Kenji Kikushima Date: Wed, 21 Jan 2015 12:34:00 -0800 Subject: [PATCH 04/22] [SPARK-5064][GraphX] Add numEdges upperbound validation for R-MAT graph generator to prevent infinite loop I looked into GraphGenerators#chooseCell, and found that chooseCell can't generate more edges than pow(2, (2 * (log2(numVertices)-1))) to make a Power-law graph. (Ex. numVertices:4 upperbound:4, numVertices:8 upperbound:16, numVertices:16 upperbound:64) If we request more edges over the upperbound, rmatGraph fall into infinite loop. So, how about adding an argument validation? Author: Kenji Kikushima Closes #3950 from kj-ki/SPARK-5064 and squashes the following commits: 4ee18c7 [Ankur Dave] Reword error message and add unit test d760bc7 [Kenji Kikushima] Add numEdges upperbound validation for R-MAT graph generator to prevent infinite loop. --- .../org/apache/spark/graphx/util/GraphGenerators.scala | 6 ++++++ .../spark/graphx/util/GraphGeneratorsSuite.scala | 10 ++++++++++ 2 files changed, 16 insertions(+) diff --git a/graphx/src/main/scala/org/apache/spark/graphx/util/GraphGenerators.scala b/graphx/src/main/scala/org/apache/spark/graphx/util/GraphGenerators.scala index 8a13c74221546..2d6a825b61726 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/util/GraphGenerators.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/util/GraphGenerators.scala @@ -133,6 +133,12 @@ object GraphGenerators { // This ensures that the 4 quadrants are the same size at all recursion levels val numVertices = math.round( math.pow(2.0, math.ceil(math.log(requestedNumVertices) / math.log(2.0)))).toInt + val numEdgesUpperBound = + math.pow(2.0, 2 * ((math.log(numVertices) / math.log(2.0)) - 1)).toInt + if (numEdgesUpperBound < numEdges) { + throw new IllegalArgumentException( + s"numEdges must be <= $numEdgesUpperBound but was $numEdges") + } var edges: Set[Edge[Int]] = Set() while (edges.size < numEdges) { if (edges.size % 100 == 0) { diff --git a/graphx/src/test/scala/org/apache/spark/graphx/util/GraphGeneratorsSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/util/GraphGeneratorsSuite.scala index 3abefbe52fa8a..8d9c8ddccbb3c 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/util/GraphGeneratorsSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/util/GraphGeneratorsSuite.scala @@ -110,4 +110,14 @@ class GraphGeneratorsSuite extends FunSuite with LocalSparkContext { } } + test("SPARK-5064 GraphGenerators.rmatGraph numEdges upper bound") { + withSpark { sc => + val g1 = GraphGenerators.rmatGraph(sc, 4, 4) + assert(g1.edges.count() === 4) + intercept[IllegalArgumentException] { + val g2 = GraphGenerators.rmatGraph(sc, 4, 8) + } + } + } + } From 812d3679f5f97df7b667cbc3365a49866ebc02d5 Mon Sep 17 00:00:00 2001 From: Daoyuan Wang Date: Wed, 21 Jan 2015 12:59:41 -0800 Subject: [PATCH 05/22] [SPARK-5244] [SQL] add coalesce() in sql parser Author: Daoyuan Wang Closes #4040 from adrian-wang/coalesce and squashes the following commits: 0ac8e8f [Daoyuan Wang] add coalesce() in sql parser --- .../scala/org/apache/spark/sql/catalyst/SqlParser.scala | 2 ++ .../test/scala/org/apache/spark/sql/SQLQuerySuite.scala | 9 +++++++++ 2 files changed, 11 insertions(+) diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala index 0b36d8b9bfce5..388e2f74a0ecb 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala @@ -51,6 +51,7 @@ class SqlParser extends AbstractSparkSQLParser { protected val CACHE = Keyword("CACHE") protected val CASE = Keyword("CASE") protected val CAST = Keyword("CAST") + protected val COALESCE = Keyword("COALESCE") protected val COUNT = Keyword("COUNT") protected val DECIMAL = Keyword("DECIMAL") protected val DESC = Keyword("DESC") @@ -306,6 +307,7 @@ class SqlParser extends AbstractSparkSQLParser { { case s ~ p => Substring(s, p, Literal(Integer.MAX_VALUE)) } | (SUBSTR | SUBSTRING) ~ "(" ~> expression ~ ("," ~> expression) ~ ("," ~> expression) <~ ")" ^^ { case s ~ p ~ l => Substring(s, p, l) } + | COALESCE ~ "(" ~> repsep(expression, ",") <~ ")" ^^ { case exprs => Coalesce(exprs) } | SQRT ~ "(" ~> expression <~ ")" ^^ { case exp => Sqrt(exp) } | ABS ~ "(" ~> expression <~ ")" ^^ { case exp => Abs(exp) } | ident ~ ("(" ~> repsep(expression, ",")) <~ ")" ^^ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 03b44ca1d6695..64648bad385e7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -86,6 +86,15 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { setConf(SQLConf.CODEGEN_ENABLED, originalValue.toString) } + test("Add Parser of SQL COALESCE()") { + checkAnswer( + sql("""SELECT COALESCE(1, 2)"""), + 1) + checkAnswer( + sql("SELECT COALESCE(null, null, null)"), + null) + } + test("SPARK-3176 Added Parser of SQL LAST()") { checkAnswer( sql("SELECT LAST(n) FROM lowerCaseData"), From 8361078efae7d79742d6be94cf5a15637ec860dd Mon Sep 17 00:00:00 2001 From: Cheng Hao Date: Wed, 21 Jan 2015 13:05:56 -0800 Subject: [PATCH 06/22] [SPARK-5009] [SQL] Long keyword support in SQL Parsers * The `SqlLexical.allCaseVersions` will cause `StackOverflowException` if the key word is too long, the patch will fix that by normalizing all of the keywords in `SqlLexical`. * And make a unified SparkSQLParser for sharing the common code. 
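(Not part of the original commit message — a small, self-contained sketch of why the pre-patch approach scales badly. The helper below is the `allCaseVersions` method this patch removes from `SqlLexical`, reproduced only for illustration: it enumerates every upper/lower-case spelling of a keyword, i.e. 2^n variants for an n-character keyword, which is what normalizing keywords to a single case avoids.)

object AllCaseVersionsSketch {
  // Pre-patch helper from SqlLexical, copied here only to illustrate the blow-up.
  def allCaseVersions(s: String, prefix: String = ""): Stream[String] =
    if (s.isEmpty) {
      Stream(prefix)
    } else {
      allCaseVersions(s.tail, prefix + s.head.toLower) #:::
        allCaseVersions(s.tail, prefix + s.head.toUpper)
    }

  def main(args: Array[String]): Unit = {
    // "CACHE" already yields 2^5 = 32 reserved words; a 30-character keyword would yield over a billion.
    println(allCaseVersions("CACHE").size)
  }
}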
Author: Cheng Hao Closes #3926 from chenghao-intel/long_keyword and squashes the following commits: 686660f [Cheng Hao] Support Long Keyword and Refactor the SQLParsers --- .../sql/catalyst/AbstractSparkSQLParser.scala | 59 +++++++++++++----- .../apache/spark/sql/catalyst/SqlParser.scala | 15 +---- .../spark/sql/catalyst/SqlParserSuite.scala | 61 +++++++++++++++++++ .../org/apache/spark/sql/SQLContext.scala | 2 +- .../org/apache/spark/sql/SparkSQLParser.scala | 15 +---- .../org/apache/spark/sql/sources/ddl.scala | 39 +++++------- .../spark/sql/hive/ExtendedHiveQlParser.scala | 16 +---- .../apache/spark/sql/hive/HiveContext.scala | 2 +- 8 files changed, 128 insertions(+), 81 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SqlParserSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala index 93d74adbcc957..366be00473d1c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala @@ -25,15 +25,42 @@ import scala.util.parsing.input.CharArrayReader.EofCh import org.apache.spark.sql.catalyst.plans.logical._ +private[sql] object KeywordNormalizer { + def apply(str: String) = str.toLowerCase() +} + private[sql] abstract class AbstractSparkSQLParser extends StandardTokenParsers with PackratParsers { - def apply(input: String): LogicalPlan = phrase(start)(new lexical.Scanner(input)) match { - case Success(plan, _) => plan - case failureOrError => sys.error(failureOrError.toString) + def apply(input: String): LogicalPlan = { + // Initialize the Keywords. + lexical.initialize(reservedWords) + phrase(start)(new lexical.Scanner(input)) match { + case Success(plan, _) => plan + case failureOrError => sys.error(failureOrError.toString) + } } - protected case class Keyword(str: String) + protected case class Keyword(str: String) { + def normalize = KeywordNormalizer(str) + def parser: Parser[String] = normalize + } + + protected implicit def asParser(k: Keyword): Parser[String] = k.parser + + // By default, use Reflection to find the reserved words defined in the sub class. + // NOTICE, Since the Keyword properties defined by sub class, we couldn't call this + // method during the parent class instantiation, because the sub class instance + // isn't created yet. + protected lazy val reservedWords: Seq[String] = + this + .getClass + .getMethods + .filter(_.getReturnType == classOf[Keyword]) + .map(_.invoke(this).asInstanceOf[Keyword].normalize) + + // Set the keywords as empty by default, will change that later. 
+ override val lexical = new SqlLexical protected def start: Parser[LogicalPlan] @@ -52,18 +79,27 @@ private[sql] abstract class AbstractSparkSQLParser } } -class SqlLexical(val keywords: Seq[String]) extends StdLexical { +class SqlLexical extends StdLexical { case class FloatLit(chars: String) extends Token { override def toString = chars } - reserved ++= keywords.flatMap(w => allCaseVersions(w)) + /* This is a work around to support the lazy setting */ + def initialize(keywords: Seq[String]): Unit = { + reserved.clear() + reserved ++= keywords + } delimiters += ( "@", "*", "+", "-", "<", "=", "<>", "!=", "<=", ">=", ">", "/", "(", ")", ",", ";", "%", "{", "}", ":", "[", "]", ".", "&", "|", "^", "~", "<=>" ) + protected override def processIdent(name: String) = { + val token = KeywordNormalizer(name) + if (reserved contains token) Keyword(token) else Identifier(name) + } + override lazy val token: Parser[Token] = ( identChar ~ (identChar | digit).* ^^ { case first ~ rest => processIdent((first :: rest).mkString) } @@ -94,14 +130,5 @@ class SqlLexical(val keywords: Seq[String]) extends StdLexical { | '-' ~ '-' ~ chrExcept(EofCh, '\n').* | '/' ~ '*' ~ failure("unclosed comment") ).* - - /** Generate all variations of upper and lower case of a given string */ - def allCaseVersions(s: String, prefix: String = ""): Stream[String] = { - if (s.isEmpty) { - Stream(prefix) - } else { - allCaseVersions(s.tail, prefix + s.head.toLower) #::: - allCaseVersions(s.tail, prefix + s.head.toUpper) - } - } } + diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala index 388e2f74a0ecb..4ca4e05edd460 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala @@ -36,9 +36,8 @@ import org.apache.spark.sql.types._ * for a SQL like language should checkout the HiveQL support in the sql/hive sub-project. */ class SqlParser extends AbstractSparkSQLParser { - protected implicit def asParser(k: Keyword): Parser[String] = - lexical.allCaseVersions(k.str).map(x => x : Parser[String]).reduce(_ | _) - + // Keyword is a convention with AbstractSparkSQLParser, which will scan all of the `Keyword` + // properties via reflection the class in runtime for constructing the SqlLexical object protected val ABS = Keyword("ABS") protected val ALL = Keyword("ALL") protected val AND = Keyword("AND") @@ -108,16 +107,6 @@ class SqlParser extends AbstractSparkSQLParser { protected val WHEN = Keyword("WHEN") protected val WHERE = Keyword("WHERE") - // Use reflection to find the reserved words defined in this class. - protected val reservedWords = - this - .getClass - .getMethods - .filter(_.getReturnType == classOf[Keyword]) - .map(_.invoke(this).asInstanceOf[Keyword].str) - - override val lexical = new SqlLexical(reservedWords) - protected def assignAliases(exprs: Seq[Expression]): Seq[NamedExpression] = { exprs.zipWithIndex.map { case (ne: NamedExpression, _) => ne diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SqlParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SqlParserSuite.scala new file mode 100644 index 0000000000000..1a0a0e6154ad2 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SqlParserSuite.scala @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst + +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.logical.Command +import org.scalatest.FunSuite + +private[sql] case class TestCommand(cmd: String) extends Command + +private[sql] class SuperLongKeywordTestParser extends AbstractSparkSQLParser { + protected val EXECUTE = Keyword("THISISASUPERLONGKEYWORDTEST") + + override protected lazy val start: Parser[LogicalPlan] = set + + private lazy val set: Parser[LogicalPlan] = + EXECUTE ~> ident ^^ { + case fileName => TestCommand(fileName) + } +} + +private[sql] class CaseInsensitiveTestParser extends AbstractSparkSQLParser { + protected val EXECUTE = Keyword("EXECUTE") + + override protected lazy val start: Parser[LogicalPlan] = set + + private lazy val set: Parser[LogicalPlan] = + EXECUTE ~> ident ^^ { + case fileName => TestCommand(fileName) + } +} + +class SqlParserSuite extends FunSuite { + + test("test long keyword") { + val parser = new SuperLongKeywordTestParser + assert(TestCommand("NotRealCommand") === parser("ThisIsASuperLongKeyWordTest NotRealCommand")) + } + + test("test case insensitive") { + val parser = new CaseInsensitiveTestParser + assert(TestCommand("NotRealCommand") === parser("EXECUTE NotRealCommand")) + assert(TestCommand("NotRealCommand") === parser("execute NotRealCommand")) + assert(TestCommand("NotRealCommand") === parser("exEcute NotRealCommand")) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index f23cb18c92d5d..0a22968cc7807 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -107,7 +107,7 @@ class SQLContext(@transient val sparkContext: SparkContext) } protected[sql] def parseSql(sql: String): LogicalPlan = { - ddlParser(sql).getOrElse(sqlParser(sql)) + ddlParser(sql, false).getOrElse(sqlParser(sql)) } protected[sql] def executeSql(sql: String): this.QueryExecution = executePlan(parseSql(sql)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSQLParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSQLParser.scala index f10ee7b66feb7..f1a4053b79113 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSQLParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSQLParser.scala @@ -17,9 +17,10 @@ package org.apache.spark.sql + import scala.util.parsing.combinator.RegexParsers -import org.apache.spark.sql.catalyst.{SqlLexical, AbstractSparkSQLParser} +import org.apache.spark.sql.catalyst.AbstractSparkSQLParser import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import 
org.apache.spark.sql.execution.{UncacheTableCommand, CacheTableCommand, SetCommand} @@ -61,18 +62,6 @@ private[sql] class SparkSQLParser(fallback: String => LogicalPlan) extends Abstr protected val TABLE = Keyword("TABLE") protected val UNCACHE = Keyword("UNCACHE") - protected implicit def asParser(k: Keyword): Parser[String] = - lexical.allCaseVersions(k.str).map(x => x : Parser[String]).reduce(_ | _) - - private val reservedWords: Seq[String] = - this - .getClass - .getMethods - .filter(_.getReturnType == classOf[Keyword]) - .map(_.invoke(this).asInstanceOf[Keyword].str) - - override val lexical = new SqlLexical(reservedWords) - override protected lazy val start: Parser[LogicalPlan] = cache | uncache | set | others private lazy val cache: Parser[LogicalPlan] = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala index 381298caba6f2..171b816a26332 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala @@ -18,32 +18,32 @@ package org.apache.spark.sql.sources import scala.language.implicitConversions -import scala.util.parsing.combinator.syntactical.StandardTokenParsers -import scala.util.parsing.combinator.PackratParsers import org.apache.spark.Logging import org.apache.spark.sql.{SchemaRDD, SQLContext} import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.SqlLexical +import org.apache.spark.sql.catalyst.AbstractSparkSQLParser import org.apache.spark.sql.execution.RunnableCommand import org.apache.spark.sql.types._ import org.apache.spark.util.Utils + /** * A parser for foreign DDL commands. */ -private[sql] class DDLParser extends StandardTokenParsers with PackratParsers with Logging { - - def apply(input: String): Option[LogicalPlan] = { - phrase(ddl)(new lexical.Scanner(input)) match { - case Success(r, x) => Some(r) - case x => - logDebug(s"Not recognized as DDL: $x") - None +private[sql] class DDLParser extends AbstractSparkSQLParser with Logging { + + def apply(input: String, exceptionOnError: Boolean): Option[LogicalPlan] = { + try { + Some(apply(input)) + } catch { + case _ if !exceptionOnError => None + case x: Throwable => throw x } } def parseType(input: String): DataType = { + lexical.initialize(reservedWords) phrase(dataType)(new lexical.Scanner(input)) match { case Success(r, x) => r case x => @@ -51,11 +51,9 @@ private[sql] class DDLParser extends StandardTokenParsers with PackratParsers wi } } - protected case class Keyword(str: String) - - protected implicit def asParser(k: Keyword): Parser[String] = - lexical.allCaseVersions(k.str).map(x => x : Parser[String]).reduce(_ | _) + // Keyword is a convention with AbstractSparkSQLParser, which will scan all of the `Keyword` + // properties via reflection the class in runtime for constructing the SqlLexical object protected val CREATE = Keyword("CREATE") protected val TEMPORARY = Keyword("TEMPORARY") protected val TABLE = Keyword("TABLE") @@ -80,17 +78,10 @@ private[sql] class DDLParser extends StandardTokenParsers with PackratParsers wi protected val MAP = Keyword("MAP") protected val STRUCT = Keyword("STRUCT") - // Use reflection to find the reserved words defined in this class. 
- protected val reservedWords = - this.getClass - .getMethods - .filter(_.getReturnType == classOf[Keyword]) - .map(_.invoke(this).asInstanceOf[Keyword].str) - - override val lexical = new SqlLexical(reservedWords) - protected lazy val ddl: Parser[LogicalPlan] = createTable + protected def start: Parser[LogicalPlan] = ddl + /** * `CREATE [TEMPORARY] TABLE avroTable * USING org.apache.spark.sql.avro diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/ExtendedHiveQlParser.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/ExtendedHiveQlParser.scala index ebf7003ff9e57..3f20c6142e59a 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/ExtendedHiveQlParser.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/ExtendedHiveQlParser.scala @@ -20,30 +20,20 @@ package org.apache.spark.sql.hive import scala.language.implicitConversions import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.{AbstractSparkSQLParser, SqlLexical} +import org.apache.spark.sql.catalyst.AbstractSparkSQLParser import org.apache.spark.sql.hive.execution.{AddJar, AddFile, HiveNativeCommand} /** * A parser that recognizes all HiveQL constructs together with Spark SQL specific extensions. */ private[hive] class ExtendedHiveQlParser extends AbstractSparkSQLParser { - protected implicit def asParser(k: Keyword): Parser[String] = - lexical.allCaseVersions(k.str).map(x => x : Parser[String]).reduce(_ | _) - + // Keyword is a convention with AbstractSparkSQLParser, which will scan all of the `Keyword` + // properties via reflection the class in runtime for constructing the SqlLexical object protected val ADD = Keyword("ADD") protected val DFS = Keyword("DFS") protected val FILE = Keyword("FILE") protected val JAR = Keyword("JAR") - private val reservedWords = - this - .getClass - .getMethods - .filter(_.getReturnType == classOf[Keyword]) - .map(_.invoke(this).asInstanceOf[Keyword].str) - - override val lexical = new SqlLexical(reservedWords) - protected lazy val start: Parser[LogicalPlan] = dfs | addJar | addFile | hiveQl protected lazy val hiveQl: Parser[LogicalPlan] = diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 3e26fe3675768..274f83af5ac03 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -70,7 +70,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { if (conf.dialect == "sql") { super.sql(sqlText) } else if (conf.dialect == "hiveql") { - new SchemaRDD(this, ddlParser(sqlText).getOrElse(HiveQl.parseSql(sqlText))) + new SchemaRDD(this, ddlParser(sqlText, false).getOrElse(HiveQl.parseSql(sqlText))) } else { sys.error(s"Unsupported SQL dialect: ${conf.dialect}. Try 'sql' or 'hiveql'") } From b328ac6c8c489ef9abf850c45db5ad531da18d55 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Wed, 21 Jan 2015 14:27:43 -0800 Subject: [PATCH 07/22] Revert "[SPARK-5244] [SQL] add coalesce() in sql parser" This reverts commit 812d3679f5f97df7b667cbc3365a49866ebc02d5. 
--- .../scala/org/apache/spark/sql/catalyst/SqlParser.scala | 2 -- .../test/scala/org/apache/spark/sql/SQLQuerySuite.scala | 9 --------- 2 files changed, 11 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala index 4ca4e05edd460..eaadbe9fd5099 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala @@ -50,7 +50,6 @@ class SqlParser extends AbstractSparkSQLParser { protected val CACHE = Keyword("CACHE") protected val CASE = Keyword("CASE") protected val CAST = Keyword("CAST") - protected val COALESCE = Keyword("COALESCE") protected val COUNT = Keyword("COUNT") protected val DECIMAL = Keyword("DECIMAL") protected val DESC = Keyword("DESC") @@ -296,7 +295,6 @@ class SqlParser extends AbstractSparkSQLParser { { case s ~ p => Substring(s, p, Literal(Integer.MAX_VALUE)) } | (SUBSTR | SUBSTRING) ~ "(" ~> expression ~ ("," ~> expression) ~ ("," ~> expression) <~ ")" ^^ { case s ~ p ~ l => Substring(s, p, l) } - | COALESCE ~ "(" ~> repsep(expression, ",") <~ ")" ^^ { case exprs => Coalesce(exprs) } | SQRT ~ "(" ~> expression <~ ")" ^^ { case exp => Sqrt(exp) } | ABS ~ "(" ~> expression <~ ")" ^^ { case exp => Abs(exp) } | ident ~ ("(" ~> repsep(expression, ",")) <~ ")" ^^ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 64648bad385e7..03b44ca1d6695 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -86,15 +86,6 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { setConf(SQLConf.CODEGEN_ENABLED, originalValue.toString) } - test("Add Parser of SQL COALESCE()") { - checkAnswer( - sql("""SELECT COALESCE(1, 2)"""), - 1) - checkAnswer( - sql("SELECT COALESCE(null, null, null)"), - null) - } - test("SPARK-3176 Added Parser of SQL LAST()") { checkAnswer( sql("SELECT LAST(n) FROM lowerCaseData"), From ba19689fe77b90052b587640c9ff325c5a892c20 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Wed, 21 Jan 2015 14:38:10 -0800 Subject: [PATCH 08/22] [SQL] [Minor] Remove deprecated parquet tests This PR removes the deprecated `ParquetQuerySuite`, renamed `ParquetQuerySuite2` to `ParquetQuerySuite`, and refactored changes introduced in #4115 to `ParquetFilterSuite` . It is a follow-up of #3644. Notice that test cases in the old `ParquetQuerySuite` have already been well covered by other test suites introduced in #3644. 
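(Not part of the commit message — the refactored `ParquetFilterSuite` below threads the RDD under test through an implicit parameter instead of repeating it at every call site; a small standalone sketch of that pattern, with made-up names, is shown here for context.)

object ImplicitFixtureSketch {
  final case class Fixture(rows: Seq[Int])

  // The fixture is resolved implicitly, so each check reads as a one-liner at the call site.
  def checkContains(expected: Int)(implicit fixture: Fixture): Unit =
    assert(fixture.rows.contains(expected), s"$expected not found in ${fixture.rows}")

  def withFixture(rows: Seq[Int])(body: Fixture => Unit): Unit = body(Fixture(rows))

  def main(args: Array[String]): Unit =
    withFixture(1 to 4) { implicit fixture =>
      checkContains(1)
      checkContains(4)
    }
}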
[Review on Reviewable](https://reviewable.io/reviews/apache/spark/4116) Author: Cheng Lian Closes #4116 from liancheng/remove-deprecated-parquet-tests and squashes the following commits: f73b8f9 [Cheng Lian] Removes deprecated Parquet test suite --- .../sql/parquet/ParquetFilterSuite.scala | 373 +++--- .../spark/sql/parquet/ParquetQuerySuite.scala | 1040 +---------------- .../sql/parquet/ParquetQuerySuite2.scala | 88 -- 3 files changed, 212 insertions(+), 1289 deletions(-) delete mode 100644 sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite2.scala diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala index 4ad8c472007fc..1e7d3e06fc196 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala @@ -21,7 +21,7 @@ import parquet.filter2.predicate.Operators._ import parquet.filter2.predicate.{FilterPredicate, Operators} import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.catalyst.expressions.{Literal, Predicate, Row} +import org.apache.spark.sql.catalyst.expressions.{Attribute, Literal, Predicate, Row} import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.{QueryTest, SQLConf, SchemaRDD} @@ -40,15 +40,16 @@ import org.apache.spark.sql.{QueryTest, SQLConf, SchemaRDD} class ParquetFilterSuite extends QueryTest with ParquetTest { val sqlContext = TestSQLContext - private def checkFilterPushdown( + private def checkFilterPredicate( rdd: SchemaRDD, - output: Seq[Symbol], predicate: Predicate, filterClass: Class[_ <: FilterPredicate], - checker: (SchemaRDD, Any) => Unit, - expectedResult: Any): Unit = { + checker: (SchemaRDD, Seq[Row]) => Unit, + expected: Seq[Row]): Unit = { + val output = predicate.collect { case a: Attribute => a }.distinct + withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED -> "true") { - val query = rdd.select(output.map(_.attr): _*).where(predicate) + val query = rdd.select(output: _*).where(predicate) val maybeAnalyzedPredicate = query.queryExecution.executedPlan.collect { case plan: ParquetTableScan => plan.columnPruningPred @@ -58,209 +59,180 @@ class ParquetFilterSuite extends QueryTest with ParquetTest { maybeAnalyzedPredicate.foreach { pred => val maybeFilter = ParquetFilters.createFilter(pred) assert(maybeFilter.isDefined, s"Couldn't generate filter predicate for $pred") - maybeFilter.foreach(f => assert(f.getClass === filterClass)) + maybeFilter.foreach { f => + // Doesn't bother checking type parameters here (e.g. 
`Eq[Integer]`) + assert(f.getClass === filterClass) + } } - checker(query, expectedResult) + checker(query, expected) } } - private def checkFilterPushdown1 - (rdd: SchemaRDD, output: Symbol*) - (predicate: Predicate, filterClass: Class[_ <: FilterPredicate]) - (expectedResult: => Seq[Row]): Unit = { - checkFilterPushdown(rdd, output, predicate, filterClass, - (query, expected) => checkAnswer(query, expected.asInstanceOf[Seq[Row]]), expectedResult) - } - - private def checkFilterPushdown - (rdd: SchemaRDD, output: Symbol*) - (predicate: Predicate, filterClass: Class[_ <: FilterPredicate]) - (expectedResult: Int): Unit = { - checkFilterPushdown(rdd, output, predicate, filterClass, - (query, expected) => checkAnswer(query, expected.asInstanceOf[Seq[Row]]), Seq(Row(expectedResult))) + private def checkFilterPredicate + (predicate: Predicate, filterClass: Class[_ <: FilterPredicate], expected: Seq[Row]) + (implicit rdd: SchemaRDD): Unit = { + checkFilterPredicate(rdd, predicate, filterClass, checkAnswer(_, _: Seq[Row]), expected) } - def checkBinaryFilterPushdown - (rdd: SchemaRDD, output: Symbol*) - (predicate: Predicate, filterClass: Class[_ <: FilterPredicate]) - (expectedResult: => Any): Unit = { - def checkBinaryAnswer(rdd: SchemaRDD, result: Any): Unit = { - val actual = rdd.map(_.getAs[Array[Byte]](0).mkString(",")).collect().toSeq - val expected = result match { - case s: Seq[_] => s.map(_.asInstanceOf[Row].getAs[Array[Byte]](0).mkString(",")) - case s => Seq(s.asInstanceOf[Array[Byte]].mkString(",")) - } - assert(actual.sorted === expected.sorted) - } - checkFilterPushdown(rdd, output, predicate, filterClass, checkBinaryAnswer _, expectedResult) + private def checkFilterPredicate[T] + (predicate: Predicate, filterClass: Class[_ <: FilterPredicate], expected: T) + (implicit rdd: SchemaRDD): Unit = { + checkFilterPredicate(predicate, filterClass, Seq(Row(expected)))(rdd) } test("filter pushdown - boolean") { - withParquetRDD((true :: false :: Nil).map(b => Tuple1.apply(Option(b)))) { rdd => - checkFilterPushdown1(rdd, '_1)('_1.isNull, classOf[Eq[java.lang.Boolean]])(Seq.empty[Row]) - checkFilterPushdown1(rdd, '_1)('_1.isNotNull, classOf[NotEq[java.lang.Boolean]]) { - Seq(Row(true), Row(false)) - } + withParquetRDD((true :: false :: Nil).map(b => Tuple1.apply(Option(b)))) { implicit rdd => + checkFilterPredicate('_1.isNull, classOf[Eq [_]], Seq.empty[Row]) + checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], Seq(Row(true), Row(false))) - checkFilterPushdown1(rdd, '_1)('_1 === true, classOf[Eq[java.lang.Boolean]])(Seq(Row(true))) - checkFilterPushdown1(rdd, '_1)('_1 !== true, classOf[Operators.NotEq[java.lang.Boolean]])(Seq(Row(false))) + checkFilterPredicate('_1 === true, classOf[Eq [_]], true) + checkFilterPredicate('_1 !== true, classOf[NotEq[_]], false) } } test("filter pushdown - integer") { - withParquetRDD((1 to 4).map(i => Tuple1(Option(i)))) { rdd => - checkFilterPushdown1(rdd, '_1)('_1.isNull, classOf[Eq[Integer]])(Seq.empty[Row]) - checkFilterPushdown1(rdd, '_1)('_1.isNotNull, classOf[NotEq[Integer]]) { - (1 to 4).map(Row.apply(_)) - } - - checkFilterPushdown(rdd, '_1)('_1 === 1, classOf[Eq[Integer]])(1) - checkFilterPushdown1(rdd, '_1)('_1 !== 1, classOf[Operators.NotEq[Integer]]) { - (2 to 4).map(Row.apply(_)) - } - - checkFilterPushdown(rdd, '_1)('_1 < 2, classOf[Lt [Integer]])(1) - checkFilterPushdown(rdd, '_1)('_1 > 3, classOf[Gt [Integer]])(4) - checkFilterPushdown(rdd, '_1)('_1 <= 1, classOf[LtEq[Integer]])(1) - checkFilterPushdown(rdd, '_1)('_1 >= 4, 
classOf[GtEq[Integer]])(4) - - checkFilterPushdown(rdd, '_1)(Literal(1) === '_1, classOf[Eq [Integer]])(1) - checkFilterPushdown(rdd, '_1)(Literal(2) > '_1, classOf[Lt [Integer]])(1) - checkFilterPushdown(rdd, '_1)(Literal(3) < '_1, classOf[Gt [Integer]])(4) - checkFilterPushdown(rdd, '_1)(Literal(1) >= '_1, classOf[LtEq[Integer]])(1) - checkFilterPushdown(rdd, '_1)(Literal(4) <= '_1, classOf[GtEq[Integer]])(4) - - checkFilterPushdown(rdd, '_1)(!('_1 < 4), classOf[Operators.GtEq[Integer]])(4) - checkFilterPushdown(rdd, '_1)('_1 > 2 && '_1 < 4, classOf[Operators.And])(3) - checkFilterPushdown1(rdd, '_1)('_1 < 2 || '_1 > 3, classOf[Operators.Or]) { - Seq(Row(1), Row(4)) - } + withParquetRDD((1 to 4).map(i => Tuple1(Option(i)))) { implicit rdd => + checkFilterPredicate('_1.isNull, classOf[Eq [_]], Seq.empty[Row]) + checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_))) + + checkFilterPredicate('_1 === 1, classOf[Eq [_]], 1) + checkFilterPredicate('_1 !== 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_))) + + checkFilterPredicate('_1 < 2, classOf[Lt [_]], 1) + checkFilterPredicate('_1 > 3, classOf[Gt [_]], 4) + checkFilterPredicate('_1 <= 1, classOf[LtEq[_]], 1) + checkFilterPredicate('_1 >= 4, classOf[GtEq[_]], 4) + + checkFilterPredicate(Literal(1) === '_1, classOf[Eq [_]], 1) + checkFilterPredicate(Literal(2) > '_1, classOf[Lt [_]], 1) + checkFilterPredicate(Literal(3) < '_1, classOf[Gt [_]], 4) + checkFilterPredicate(Literal(1) >= '_1, classOf[LtEq[_]], 1) + checkFilterPredicate(Literal(4) <= '_1, classOf[GtEq[_]], 4) + + checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4) + checkFilterPredicate('_1 > 2 && '_1 < 4, classOf[Operators.And], 3) + checkFilterPredicate('_1 < 2 || '_1 > 3, classOf[Operators.Or], Seq(Row(1), Row(4))) } } test("filter pushdown - long") { - withParquetRDD((1 to 4).map(i => Tuple1(Option(i.toLong)))) { rdd => - checkFilterPushdown1(rdd, '_1)('_1.isNull, classOf[Eq[java.lang.Long]])(Seq.empty[Row]) - checkFilterPushdown1(rdd, '_1)('_1.isNotNull, classOf[NotEq[java.lang.Long]]) { - (1 to 4).map(Row.apply(_)) - } - - checkFilterPushdown(rdd, '_1)('_1 === 1, classOf[Eq[java.lang.Long]])(1) - checkFilterPushdown1(rdd, '_1)('_1 !== 1, classOf[Operators.NotEq[java.lang.Long]]) { - (2 to 4).map(Row.apply(_)) - } - - checkFilterPushdown(rdd, '_1)('_1 < 2, classOf[Lt [java.lang.Long]])(1) - checkFilterPushdown(rdd, '_1)('_1 > 3, classOf[Gt [java.lang.Long]])(4) - checkFilterPushdown(rdd, '_1)('_1 <= 1, classOf[LtEq[java.lang.Long]])(1) - checkFilterPushdown(rdd, '_1)('_1 >= 4, classOf[GtEq[java.lang.Long]])(4) - - checkFilterPushdown(rdd, '_1)(Literal(1) === '_1, classOf[Eq [Integer]])(1) - checkFilterPushdown(rdd, '_1)(Literal(2) > '_1, classOf[Lt [java.lang.Long]])(1) - checkFilterPushdown(rdd, '_1)(Literal(3) < '_1, classOf[Gt [java.lang.Long]])(4) - checkFilterPushdown(rdd, '_1)(Literal(1) >= '_1, classOf[LtEq[java.lang.Long]])(1) - checkFilterPushdown(rdd, '_1)(Literal(4) <= '_1, classOf[GtEq[java.lang.Long]])(4) - - checkFilterPushdown(rdd, '_1)(!('_1 < 4), classOf[Operators.GtEq[java.lang.Long]])(4) - checkFilterPushdown(rdd, '_1)('_1 > 2 && '_1 < 4, classOf[Operators.And])(3) - checkFilterPushdown1(rdd, '_1)('_1 < 2 || '_1 > 3, classOf[Operators.Or]) { - Seq(Row(1), Row(4)) - } + withParquetRDD((1 to 4).map(i => Tuple1(Option(i.toLong)))) { implicit rdd => + checkFilterPredicate('_1.isNull, classOf[Eq [_]], Seq.empty[Row]) + checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_))) + + checkFilterPredicate('_1 === 
1, classOf[Eq[_]], 1) + checkFilterPredicate('_1 !== 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_))) + + checkFilterPredicate('_1 < 2, classOf[Lt [_]], 1) + checkFilterPredicate('_1 > 3, classOf[Gt [_]], 4) + checkFilterPredicate('_1 <= 1, classOf[LtEq[_]], 1) + checkFilterPredicate('_1 >= 4, classOf[GtEq[_]], 4) + + checkFilterPredicate(Literal(1) === '_1, classOf[Eq [_]], 1) + checkFilterPredicate(Literal(2) > '_1, classOf[Lt [_]], 1) + checkFilterPredicate(Literal(3) < '_1, classOf[Gt [_]], 4) + checkFilterPredicate(Literal(1) >= '_1, classOf[LtEq[_]], 1) + checkFilterPredicate(Literal(4) <= '_1, classOf[GtEq[_]], 4) + + checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4) + checkFilterPredicate('_1 > 2 && '_1 < 4, classOf[Operators.And], 3) + checkFilterPredicate('_1 < 2 || '_1 > 3, classOf[Operators.Or], Seq(Row(1), Row(4))) } } test("filter pushdown - float") { - withParquetRDD((1 to 4).map(i => Tuple1(Option(i.toFloat)))) { rdd => - checkFilterPushdown1(rdd, '_1)('_1.isNull, classOf[Eq[java.lang.Float]])(Seq.empty[Row]) - checkFilterPushdown1(rdd, '_1)('_1.isNotNull, classOf[NotEq[java.lang.Float]]) { - (1 to 4).map(Row.apply(_)) - } - - checkFilterPushdown(rdd, '_1)('_1 === 1, classOf[Eq[java.lang.Float]])(1) - checkFilterPushdown1(rdd, '_1)('_1 !== 1, classOf[Operators.NotEq[java.lang.Float]]) { - (2 to 4).map(Row.apply(_)) - } - - checkFilterPushdown(rdd, '_1)('_1 < 2, classOf[Lt [java.lang.Float]])(1) - checkFilterPushdown(rdd, '_1)('_1 > 3, classOf[Gt [java.lang.Float]])(4) - checkFilterPushdown(rdd, '_1)('_1 <= 1, classOf[LtEq[java.lang.Float]])(1) - checkFilterPushdown(rdd, '_1)('_1 >= 4, classOf[GtEq[java.lang.Float]])(4) - - checkFilterPushdown(rdd, '_1)(Literal(1) === '_1, classOf[Eq [Integer]])(1) - checkFilterPushdown(rdd, '_1)(Literal(2) > '_1, classOf[Lt [java.lang.Float]])(1) - checkFilterPushdown(rdd, '_1)(Literal(3) < '_1, classOf[Gt [java.lang.Float]])(4) - checkFilterPushdown(rdd, '_1)(Literal(1) >= '_1, classOf[LtEq[java.lang.Float]])(1) - checkFilterPushdown(rdd, '_1)(Literal(4) <= '_1, classOf[GtEq[java.lang.Float]])(4) - - checkFilterPushdown(rdd, '_1)(!('_1 < 4), classOf[Operators.GtEq[java.lang.Float]])(4) - checkFilterPushdown(rdd, '_1)('_1 > 2 && '_1 < 4, classOf[Operators.And])(3) - checkFilterPushdown1(rdd, '_1)('_1 < 2 || '_1 > 3, classOf[Operators.Or]) { - Seq(Row(1), Row(4)) - } + withParquetRDD((1 to 4).map(i => Tuple1(Option(i.toFloat)))) { implicit rdd => + checkFilterPredicate('_1.isNull, classOf[Eq [_]], Seq.empty[Row]) + checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_))) + + checkFilterPredicate('_1 === 1, classOf[Eq [_]], 1) + checkFilterPredicate('_1 !== 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_))) + + checkFilterPredicate('_1 < 2, classOf[Lt [_]], 1) + checkFilterPredicate('_1 > 3, classOf[Gt [_]], 4) + checkFilterPredicate('_1 <= 1, classOf[LtEq[_]], 1) + checkFilterPredicate('_1 >= 4, classOf[GtEq[_]], 4) + + checkFilterPredicate(Literal(1) === '_1, classOf[Eq [_]], 1) + checkFilterPredicate(Literal(2) > '_1, classOf[Lt [_]], 1) + checkFilterPredicate(Literal(3) < '_1, classOf[Gt [_]], 4) + checkFilterPredicate(Literal(1) >= '_1, classOf[LtEq[_]], 1) + checkFilterPredicate(Literal(4) <= '_1, classOf[GtEq[_]], 4) + + checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4) + checkFilterPredicate('_1 > 2 && '_1 < 4, classOf[Operators.And], 3) + checkFilterPredicate('_1 < 2 || '_1 > 3, classOf[Operators.Or], Seq(Row(1), Row(4))) } } test("filter pushdown - double") { - withParquetRDD((1 to 4).map(i => 
Tuple1(Option(i.toDouble)))) { rdd => - checkFilterPushdown1(rdd, '_1)('_1.isNull, classOf[Eq[java.lang.Double]])(Seq.empty[Row]) - checkFilterPushdown1(rdd, '_1)('_1.isNotNull, classOf[NotEq[java.lang.Double]]) { - (1 to 4).map(Row.apply(_)) - } - - checkFilterPushdown(rdd, '_1)('_1 === 1, classOf[Eq[java.lang.Double]])(1) - checkFilterPushdown1(rdd, '_1)('_1 !== 1, classOf[Operators.NotEq[java.lang.Double]]) { - (2 to 4).map(Row.apply(_)) - } - - checkFilterPushdown(rdd, '_1)('_1 < 2, classOf[Lt [java.lang.Double]])(1) - checkFilterPushdown(rdd, '_1)('_1 > 3, classOf[Gt [java.lang.Double]])(4) - checkFilterPushdown(rdd, '_1)('_1 <= 1, classOf[LtEq[java.lang.Double]])(1) - checkFilterPushdown(rdd, '_1)('_1 >= 4, classOf[GtEq[java.lang.Double]])(4) - - checkFilterPushdown(rdd, '_1)(Literal(1) === '_1, classOf[Eq[Integer]])(1) - checkFilterPushdown(rdd, '_1)(Literal(2) > '_1, classOf[Lt [java.lang.Double]])(1) - checkFilterPushdown(rdd, '_1)(Literal(3) < '_1, classOf[Gt [java.lang.Double]])(4) - checkFilterPushdown(rdd, '_1)(Literal(1) >= '_1, classOf[LtEq[java.lang.Double]])(1) - checkFilterPushdown(rdd, '_1)(Literal(4) <= '_1, classOf[GtEq[java.lang.Double]])(4) - - checkFilterPushdown(rdd, '_1)(!('_1 < 4), classOf[Operators.GtEq[java.lang.Double]])(4) - checkFilterPushdown(rdd, '_1)('_1 > 2 && '_1 < 4, classOf[Operators.And])(3) - checkFilterPushdown1(rdd, '_1)('_1 < 2 || '_1 > 3, classOf[Operators.Or]) { - Seq(Row(1), Row(4)) - } + withParquetRDD((1 to 4).map(i => Tuple1(Option(i.toDouble)))) { implicit rdd => + checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) + checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_))) + + checkFilterPredicate('_1 === 1, classOf[Eq [_]], 1) + checkFilterPredicate('_1 !== 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_))) + + checkFilterPredicate('_1 < 2, classOf[Lt [_]], 1) + checkFilterPredicate('_1 > 3, classOf[Gt [_]], 4) + checkFilterPredicate('_1 <= 1, classOf[LtEq[_]], 1) + checkFilterPredicate('_1 >= 4, classOf[GtEq[_]], 4) + + checkFilterPredicate(Literal(1) === '_1, classOf[Eq [_]], 1) + checkFilterPredicate(Literal(2) > '_1, classOf[Lt [_]], 1) + checkFilterPredicate(Literal(3) < '_1, classOf[Gt [_]], 4) + checkFilterPredicate(Literal(1) >= '_1, classOf[LtEq[_]], 1) + checkFilterPredicate(Literal(4) <= '_1, classOf[GtEq[_]], 4) + + checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4) + checkFilterPredicate('_1 > 2 && '_1 < 4, classOf[Operators.And], 3) + checkFilterPredicate('_1 < 2 || '_1 > 3, classOf[Operators.Or], Seq(Row(1), Row(4))) } } test("filter pushdown - string") { - withParquetRDD((1 to 4).map(i => Tuple1(i.toString))) { rdd => - checkFilterPushdown1(rdd, '_1)('_1.isNull, classOf[Eq[java.lang.String]])(Seq.empty[Row]) - checkFilterPushdown1(rdd, '_1)('_1.isNotNull, classOf[NotEq[java.lang.String]]) { - (1 to 4).map(i => Row.apply(i.toString)) - } - - checkFilterPushdown1(rdd, '_1)('_1 === "1", classOf[Eq[String]])(Seq(Row("1"))) - checkFilterPushdown1(rdd, '_1)('_1 !== "1", classOf[Operators.NotEq[String]]) { - (2 to 4).map(i => Row.apply(i.toString)) - } + withParquetRDD((1 to 4).map(i => Tuple1(i.toString))) { implicit rdd => + checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) + checkFilterPredicate( + '_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(i => Row.apply(i.toString))) + + checkFilterPredicate('_1 === "1", classOf[Eq [_]], "1") + checkFilterPredicate('_1 !== "1", classOf[NotEq[_]], (2 to 4).map(i => Row.apply(i.toString))) + + checkFilterPredicate('_1 < "2", 
classOf[Lt [_]], "1") + checkFilterPredicate('_1 > "3", classOf[Gt [_]], "4") + checkFilterPredicate('_1 <= "1", classOf[LtEq[_]], "1") + checkFilterPredicate('_1 >= "4", classOf[GtEq[_]], "4") + + checkFilterPredicate(Literal("1") === '_1, classOf[Eq [_]], "1") + checkFilterPredicate(Literal("2") > '_1, classOf[Lt [_]], "1") + checkFilterPredicate(Literal("3") < '_1, classOf[Gt [_]], "4") + checkFilterPredicate(Literal("1") >= '_1, classOf[LtEq[_]], "1") + checkFilterPredicate(Literal("4") <= '_1, classOf[GtEq[_]], "4") + + checkFilterPredicate(!('_1 < "4"), classOf[GtEq[_]], "4") + checkFilterPredicate('_1 > "2" && '_1 < "4", classOf[Operators.And], "3") + checkFilterPredicate('_1 < "2" || '_1 > "3", classOf[Operators.Or], Seq(Row("1"), Row("4"))) + } + } - checkFilterPushdown1(rdd, '_1)('_1 < "2", classOf[Lt [java.lang.String]])(Seq(Row("1"))) - checkFilterPushdown1(rdd, '_1)('_1 > "3", classOf[Gt [java.lang.String]])(Seq(Row("4"))) - checkFilterPushdown1(rdd, '_1)('_1 <= "1", classOf[LtEq[java.lang.String]])(Seq(Row("1"))) - checkFilterPushdown1(rdd, '_1)('_1 >= "4", classOf[GtEq[java.lang.String]])(Seq(Row("4"))) - - checkFilterPushdown1(rdd, '_1)(Literal("1") === '_1, classOf[Eq [java.lang.String]])(Seq(Row("1"))) - checkFilterPushdown1(rdd, '_1)(Literal("2") > '_1, classOf[Lt [java.lang.String]])(Seq(Row("1"))) - checkFilterPushdown1(rdd, '_1)(Literal("3") < '_1, classOf[Gt [java.lang.String]])(Seq(Row("4"))) - checkFilterPushdown1(rdd, '_1)(Literal("1") >= '_1, classOf[LtEq[java.lang.String]])(Seq(Row("1"))) - checkFilterPushdown1(rdd, '_1)(Literal("4") <= '_1, classOf[GtEq[java.lang.String]])(Seq(Row("4"))) - - checkFilterPushdown1(rdd, '_1)(!('_1 < "4"), classOf[Operators.GtEq[java.lang.String]])(Seq(Row("4"))) - checkFilterPushdown1(rdd, '_1)('_1 > "2" && '_1 < "4", classOf[Operators.And])(Seq(Row("3"))) - checkFilterPushdown1(rdd, '_1)('_1 < "2" || '_1 > "3", classOf[Operators.Or]) { - Seq(Row("1"), Row("4")) + def checkBinaryFilterPredicate + (predicate: Predicate, filterClass: Class[_ <: FilterPredicate], expected: Seq[Row]) + (implicit rdd: SchemaRDD): Unit = { + def checkBinaryAnswer(rdd: SchemaRDD, expected: Seq[Row]) = { + assertResult(expected.map(_.getAs[Array[Byte]](0).mkString(",")).toSeq.sorted) { + rdd.map(_.getAs[Array[Byte]](0).mkString(",")).collect().toSeq.sorted } } + + checkFilterPredicate(rdd, predicate, filterClass, checkBinaryAnswer _, expected) + } + + def checkBinaryFilterPredicate + (predicate: Predicate, filterClass: Class[_ <: FilterPredicate], expected: Array[Byte]) + (implicit rdd: SchemaRDD): Unit = { + checkBinaryFilterPredicate(predicate, filterClass, Seq(Row(expected)))(rdd) } test("filter pushdown - binary") { @@ -268,33 +240,30 @@ class ParquetFilterSuite extends QueryTest with ParquetTest { def b: Array[Byte] = int.toString.getBytes("UTF-8") } - withParquetRDD((1 to 4).map(i => Tuple1(i.b))) { rdd => - checkBinaryFilterPushdown(rdd, '_1)('_1.isNull, classOf[Eq[java.lang.String]])(Seq.empty[Row]) - checkBinaryFilterPushdown(rdd, '_1)('_1.isNotNull, classOf[NotEq[java.lang.String]]) { - (1 to 4).map(i => Row.apply(i.b)).toSeq - } - - checkBinaryFilterPushdown(rdd, '_1)('_1 === 1.b, classOf[Eq[Array[Byte]]])(1.b) - checkBinaryFilterPushdown(rdd, '_1)('_1 !== 1.b, classOf[Operators.NotEq[Array[Byte]]]) { - (2 to 4).map(i => Row.apply(i.b)).toSeq - } - - checkBinaryFilterPushdown(rdd, '_1)('_1 < 2.b, classOf[Lt [Array[Byte]]])(1.b) - checkBinaryFilterPushdown(rdd, '_1)('_1 > 3.b, classOf[Gt [Array[Byte]]])(4.b) - checkBinaryFilterPushdown(rdd, 
'_1)('_1 <= 1.b, classOf[LtEq[Array[Byte]]])(1.b) - checkBinaryFilterPushdown(rdd, '_1)('_1 >= 4.b, classOf[GtEq[Array[Byte]]])(4.b) - - checkBinaryFilterPushdown(rdd, '_1)(Literal(1.b) === '_1, classOf[Eq [Array[Byte]]])(1.b) - checkBinaryFilterPushdown(rdd, '_1)(Literal(2.b) > '_1, classOf[Lt [Array[Byte]]])(1.b) - checkBinaryFilterPushdown(rdd, '_1)(Literal(3.b) < '_1, classOf[Gt [Array[Byte]]])(4.b) - checkBinaryFilterPushdown(rdd, '_1)(Literal(1.b) >= '_1, classOf[LtEq[Array[Byte]]])(1.b) - checkBinaryFilterPushdown(rdd, '_1)(Literal(4.b) <= '_1, classOf[GtEq[Array[Byte]]])(4.b) - - checkBinaryFilterPushdown(rdd, '_1)(!('_1 < 4.b), classOf[Operators.GtEq[Array[Byte]]])(4.b) - checkBinaryFilterPushdown(rdd, '_1)('_1 > 2.b && '_1 < 4.b, classOf[Operators.And])(3.b) - checkBinaryFilterPushdown(rdd, '_1)('_1 < 2.b || '_1 > 3.b, classOf[Operators.Or]) { - Seq(Row(1.b), Row(4.b)) - } + withParquetRDD((1 to 4).map(i => Tuple1(i.b))) { implicit rdd => + checkBinaryFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) + checkBinaryFilterPredicate( + '_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(i => Row.apply(i.b)).toSeq) + + checkBinaryFilterPredicate('_1 === 1.b, classOf[Eq [_]], 1.b) + checkBinaryFilterPredicate( + '_1 !== 1.b, classOf[NotEq[_]], (2 to 4).map(i => Row.apply(i.b)).toSeq) + + checkBinaryFilterPredicate('_1 < 2.b, classOf[Lt [_]], 1.b) + checkBinaryFilterPredicate('_1 > 3.b, classOf[Gt [_]], 4.b) + checkBinaryFilterPredicate('_1 <= 1.b, classOf[LtEq[_]], 1.b) + checkBinaryFilterPredicate('_1 >= 4.b, classOf[GtEq[_]], 4.b) + + checkBinaryFilterPredicate(Literal(1.b) === '_1, classOf[Eq [_]], 1.b) + checkBinaryFilterPredicate(Literal(2.b) > '_1, classOf[Lt [_]], 1.b) + checkBinaryFilterPredicate(Literal(3.b) < '_1, classOf[Gt [_]], 4.b) + checkBinaryFilterPredicate(Literal(1.b) >= '_1, classOf[LtEq[_]], 1.b) + checkBinaryFilterPredicate(Literal(4.b) <= '_1, classOf[GtEq[_]], 4.b) + + checkBinaryFilterPredicate(!('_1 < 4.b), classOf[GtEq[_]], 4.b) + checkBinaryFilterPredicate('_1 > 2.b && '_1 < 4.b, classOf[Operators.And], 3.b) + checkBinaryFilterPredicate( + '_1 < 2.b || '_1 > 3.b, classOf[Operators.Or], Seq(Row(1.b), Row(4.b))) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala index 2c5345b1f9148..1263ff818ea19 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala @@ -17,1030 +17,72 @@ package org.apache.spark.sql.parquet -import scala.reflect.ClassTag - -import org.apache.hadoop.fs.{FileSystem, Path} -import org.apache.hadoop.mapreduce.Job -import org.scalatest.{BeforeAndAfterAll, FunSuiteLike} -import parquet.filter2.predicate.{FilterPredicate, Operators} -import parquet.hadoop.ParquetFileWriter -import parquet.hadoop.util.ContextUtil -import parquet.io.api.Binary - -import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.expressions.{Row => _, _} -import org.apache.spark.sql.catalyst.util.getTempFilePath +import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.catalyst.expressions.Row import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.test.TestSQLContext._ -import org.apache.spark.sql.types._ -import org.apache.spark.util.Utils - -case class TestRDDEntry(key: Int, value: String) - -case class NullReflectData( - intField: java.lang.Integer, - longField: java.lang.Long, - floatField: 
java.lang.Float, - doubleField: java.lang.Double, - booleanField: java.lang.Boolean) - -case class OptionalReflectData( - intField: Option[Int], - longField: Option[Long], - floatField: Option[Float], - doubleField: Option[Double], - booleanField: Option[Boolean]) - -case class Nested(i: Int, s: String) - -case class Data(array: Seq[Int], nested: Nested) - -case class AllDataTypes( - stringField: String, - intField: Int, - longField: Long, - floatField: Float, - doubleField: Double, - shortField: Short, - byteField: Byte, - booleanField: Boolean) - -case class AllDataTypesWithNonPrimitiveType( - stringField: String, - intField: Int, - longField: Long, - floatField: Float, - doubleField: Double, - shortField: Short, - byteField: Byte, - booleanField: Boolean, - array: Seq[Int], - arrayContainsNull: Seq[Option[Int]], - map: Map[Int, Long], - mapValueContainsNull: Map[Int, Option[Long]], - data: Data) - -case class BinaryData(binaryData: Array[Byte]) - -case class NumericData(i: Int, d: Double) -class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterAll { - TestData // Load test data tables. - - private var testRDD: SchemaRDD = null - private val originalParquetFilterPushdownEnabled = TestSQLContext.conf.parquetFilterPushDown - - override def beforeAll() { - ParquetTestData.writeFile() - ParquetTestData.writeFilterFile() - ParquetTestData.writeNestedFile1() - ParquetTestData.writeNestedFile2() - ParquetTestData.writeNestedFile3() - ParquetTestData.writeNestedFile4() - ParquetTestData.writeGlobFiles() - testRDD = parquetFile(ParquetTestData.testDir.toString) - testRDD.registerTempTable("testsource") - parquetFile(ParquetTestData.testFilterDir.toString) - .registerTempTable("testfiltersource") - - setConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED, "true") - } - - override def afterAll() { - Utils.deleteRecursively(ParquetTestData.testDir) - Utils.deleteRecursively(ParquetTestData.testFilterDir) - Utils.deleteRecursively(ParquetTestData.testNestedDir1) - Utils.deleteRecursively(ParquetTestData.testNestedDir2) - Utils.deleteRecursively(ParquetTestData.testNestedDir3) - Utils.deleteRecursively(ParquetTestData.testNestedDir4) - Utils.deleteRecursively(ParquetTestData.testGlobDir) - // here we should also unregister the table?? - - setConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED, originalParquetFilterPushdownEnabled.toString) - } +/** + * A test suite that tests various Parquet queries. + */ +class ParquetQuerySuite extends QueryTest with ParquetTest { + val sqlContext = TestSQLContext - test("Read/Write All Types") { - val tempDir = getTempFilePath("parquetTest").getCanonicalPath - val range = (0 to 255) - val data = sparkContext.parallelize(range).map { x => - parquet.AllDataTypes( - s"$x", x, x.toLong, x.toFloat, x.toDouble, x.toShort, x.toByte, x % 2 == 0) + test("simple projection") { + withParquetTable((0 until 10).map(i => (i, i.toString)), "t") { + checkAnswer(sql("SELECT _1 FROM t"), (0 until 10).map(Row.apply(_))) } - - data.saveAsParquetFile(tempDir) - - checkAnswer( - parquetFile(tempDir), - data.toSchemaRDD.collect().toSeq) } - test("read/write binary data") { - // Since equality for Array[Byte] is broken we test this separately. 
- val tempDir = getTempFilePath("parquetTest").getCanonicalPath - sparkContext.parallelize(BinaryData("test".getBytes("utf8")) :: Nil).saveAsParquetFile(tempDir) - parquetFile(tempDir) - .map(r => new String(r(0).asInstanceOf[Array[Byte]], "utf8")) - .collect().toSeq == Seq("test") - } - - ignore("Treat binary as string") { - val oldIsParquetBinaryAsString = TestSQLContext.conf.isParquetBinaryAsString - - // Create the test file. - val file = getTempFilePath("parquet") - val path = file.toString - val range = (0 to 255) - val rowRDD = TestSQLContext.sparkContext.parallelize(range) - .map(i => org.apache.spark.sql.Row(i, s"val_$i".getBytes)) - // We need to ask Parquet to store the String column as a Binary column. - val schema = StructType( - StructField("c1", IntegerType, false) :: - StructField("c2", BinaryType, false) :: Nil) - val schemaRDD1 = applySchema(rowRDD, schema) - schemaRDD1.saveAsParquetFile(path) - checkAnswer( - parquetFile(path).select('c1, 'c2.cast(StringType)), - schemaRDD1.select('c1, 'c2.cast(StringType)).collect().toSeq) - - setConf(SQLConf.PARQUET_BINARY_AS_STRING, "true") - parquetFile(path).printSchema() - checkAnswer( - parquetFile(path), - schemaRDD1.select('c1, 'c2.cast(StringType)).collect().toSeq) - - - // Set it back. - TestSQLContext.setConf(SQLConf.PARQUET_BINARY_AS_STRING, oldIsParquetBinaryAsString.toString) - } - - test("Compression options for writing to a Parquetfile") { - val defaultParquetCompressionCodec = TestSQLContext.conf.parquetCompressionCodec - import scala.collection.JavaConversions._ - - val file = getTempFilePath("parquet") - val path = file.toString - val rdd = TestSQLContext.sparkContext.parallelize((1 to 100)) - .map(i => TestRDDEntry(i, s"val_$i")) - - // test default compression codec - rdd.saveAsParquetFile(path) - var actualCodec = ParquetTypesConverter.readMetaData(new Path(path), Some(TestSQLContext.sparkContext.hadoopConfiguration)) - .getBlocks.flatMap(block => block.getColumns).map(column => column.getCodec.name()).distinct - assert(actualCodec === TestSQLContext.conf.parquetCompressionCodec.toUpperCase :: Nil) - - parquetFile(path).registerTempTable("tmp") - checkAnswer( - sql("SELECT key, value FROM tmp WHERE value = 'val_5' OR value = 'val_7'"), - Row(5, "val_5") :: - Row(7, "val_7") :: Nil) - - Utils.deleteRecursively(file) - - // test uncompressed parquet file with property value "UNCOMPRESSED" - TestSQLContext.setConf(SQLConf.PARQUET_COMPRESSION, "UNCOMPRESSED") - - rdd.saveAsParquetFile(path) - actualCodec = ParquetTypesConverter.readMetaData(new Path(path), Some(TestSQLContext.sparkContext.hadoopConfiguration)) - .getBlocks.flatMap(block => block.getColumns).map(column => column.getCodec.name()).distinct - assert(actualCodec === TestSQLContext.conf.parquetCompressionCodec.toUpperCase :: Nil) - - parquetFile(path).registerTempTable("tmp") - checkAnswer( - sql("SELECT key, value FROM tmp WHERE value = 'val_5' OR value = 'val_7'"), - Row(5, "val_5") :: - Row(7, "val_7") :: Nil) - - Utils.deleteRecursively(file) - - // test uncompressed parquet file with property value "none" - TestSQLContext.setConf(SQLConf.PARQUET_COMPRESSION, "none") - - rdd.saveAsParquetFile(path) - actualCodec = ParquetTypesConverter.readMetaData(new Path(path), Some(TestSQLContext.sparkContext.hadoopConfiguration)) - .getBlocks.flatMap(block => block.getColumns).map(column => column.getCodec.name()).distinct - assert(actualCodec === "UNCOMPRESSED" :: Nil) - - parquetFile(path).registerTempTable("tmp") - checkAnswer( - sql("SELECT key, value FROM tmp 
WHERE value = 'val_5' OR value = 'val_7'"), - Row(5, "val_5") :: - Row(7, "val_7") :: Nil) - - Utils.deleteRecursively(file) - - // test gzip compression codec - TestSQLContext.setConf(SQLConf.PARQUET_COMPRESSION, "gzip") - - rdd.saveAsParquetFile(path) - actualCodec = ParquetTypesConverter.readMetaData(new Path(path), Some(TestSQLContext.sparkContext.hadoopConfiguration)) - .getBlocks.flatMap(block => block.getColumns).map(column => column.getCodec.name()).distinct - assert(actualCodec === TestSQLContext.conf.parquetCompressionCodec.toUpperCase :: Nil) - - parquetFile(path).registerTempTable("tmp") - checkAnswer( - sql("SELECT key, value FROM tmp WHERE value = 'val_5' OR value = 'val_7'"), - Row(5, "val_5") :: - Row(7, "val_7") :: Nil) - - Utils.deleteRecursively(file) - - // test snappy compression codec - TestSQLContext.setConf(SQLConf.PARQUET_COMPRESSION, "snappy") - - rdd.saveAsParquetFile(path) - actualCodec = ParquetTypesConverter.readMetaData(new Path(path), Some(TestSQLContext.sparkContext.hadoopConfiguration)) - .getBlocks.flatMap(block => block.getColumns).map(column => column.getCodec.name()).distinct - assert(actualCodec === TestSQLContext.conf.parquetCompressionCodec.toUpperCase :: Nil) - - parquetFile(path).registerTempTable("tmp") - checkAnswer( - sql("SELECT key, value FROM tmp WHERE value = 'val_5' OR value = 'val_7'"), - Row(5, "val_5") :: - Row(7, "val_7") :: Nil) - - Utils.deleteRecursively(file) - - // TODO: Lzo requires additional external setup steps so leave it out for now - // ref.: https://github.com/Parquet/parquet-mr/blob/parquet-1.5.0/parquet-hadoop/src/test/java/parquet/hadoop/example/TestInputOutputFormat.java#L169 - - // Set it back. - TestSQLContext.setConf(SQLConf.PARQUET_COMPRESSION, defaultParquetCompressionCodec) - } - - test("Read/Write All Types with non-primitive type") { - val tempDir = getTempFilePath("parquetTest").getCanonicalPath - val range = (0 to 255) - val data = sparkContext.parallelize(range).map { x => - parquet.AllDataTypesWithNonPrimitiveType( - s"$x", x, x.toLong, x.toFloat, x.toDouble, x.toShort, x.toByte, x % 2 == 0, - (0 until x), - (0 until x).map(Option(_).filter(_ % 3 == 0)), - (0 until x).map(i => i -> i.toLong).toMap, - (0 until x).map(i => i -> Option(i.toLong)).toMap + (x -> None), - parquet.Data((0 until x), parquet.Nested(x, s"$x"))) + test("appending") { + val data = (0 until 10).map(i => (i, i.toString)) + withParquetTable(data, "t") { + sql("INSERT INTO t SELECT * FROM t") + checkAnswer(table("t"), (data ++ data).map(Row.fromTuple)) } - data.saveAsParquetFile(tempDir) - - checkAnswer( - parquetFile(tempDir), - data.toSchemaRDD.collect().toSeq) } - test("self-join parquet files") { - val x = ParquetTestData.testData.as('x) - val y = ParquetTestData.testData.as('y) - val query = x.join(y).where("x.myint".attr === "y.myint".attr) - - // Check to make sure that the attributes from either side of the join have unique expression - // ids. 
- query.queryExecution.analyzed.output.filter(_.name == "myint") match { - case Seq(i1, i2) if(i1.exprId == i2.exprId) => - fail(s"Duplicate expression IDs found in query plan: $query") - case Seq(_, _) => // All good + test("self-join") { + // 4 rows, cells of column 1 of row 2 and row 4 are null + val data = (1 to 4).map { i => + val maybeInt = if (i % 2 == 0) None else Some(i) + (maybeInt, i.toString) } - val result = query.collect() - assert(result.size === 9, "self-join result has incorrect size") - assert(result(0).size === 12, "result row has incorrect size") - result.zipWithIndex.foreach { - case (row, index) => row.toSeq.zipWithIndex.foreach { - case (field, column) => assert(field != null, s"self-join contains null value in row $index field $column") - } - } - } + withParquetTable(data, "t") { + val selfJoin = sql("SELECT * FROM t x JOIN t y WHERE x._1 = y._1") + val queryOutput = selfJoin.queryExecution.analyzed.output - test("Import of simple Parquet file") { - val result = parquetFile(ParquetTestData.testDir.toString).collect() - assert(result.size === 15) - result.zipWithIndex.foreach { - case (row, index) => { - val checkBoolean = - if (index % 3 == 0) - row(0) == true - else - row(0) == false - assert(checkBoolean === true, s"boolean field value in line $index did not match") - if (index % 5 == 0) assert(row(1) === 5, s"int field value in line $index did not match") - assert(row(2) === "abc", s"string field value in line $index did not match") - assert(row(3) === (index.toLong << 33), s"long value in line $index did not match") - assert(row(4) === 2.5F, s"float field value in line $index did not match") - assert(row(5) === 4.5D, s"double field value in line $index did not match") + assertResult(4, s"Field count mismatches")(queryOutput.size) + assertResult(2, s"Duplicated expression ID in query plan:\n $selfJoin") { + queryOutput.filter(_.name == "_1").map(_.exprId).size } - } - } - test("Projection of simple Parquet file") { - val result = ParquetTestData.testData.select('myboolean, 'mylong).collect() - result.zipWithIndex.foreach { - case (row, index) => { - if (index % 3 == 0) - assert(row(0) === true, s"boolean field value in line $index did not match (every third row)") - else - assert(row(0) === false, s"boolean field value in line $index did not match") - assert(row(1) === (index.toLong << 33), s"long field value in line $index did not match") - assert(row.size === 2, s"number of columns in projection in line $index is incorrect") - } + checkAnswer(selfJoin, List(Row(1, "1", 1, "1"), Row(3, "3", 3, "3"))) } } - test("Writing metadata from scratch for table CREATE") { - val job = new Job() - val path = new Path(getTempFilePath("testtable").getCanonicalFile.toURI.toString) - val fs: FileSystem = FileSystem.getLocal(ContextUtil.getConfiguration(job)) - ParquetTypesConverter.writeMetaData( - ParquetTestData.testData.output, - path, - TestSQLContext.sparkContext.hadoopConfiguration) - assert(fs.exists(new Path(path, ParquetFileWriter.PARQUET_METADATA_FILE))) - val metaData = ParquetTypesConverter.readMetaData(path, Some(ContextUtil.getConfiguration(job))) - assert(metaData != null) - ParquetTestData - .testData - .parquetSchema - .checkContains(metaData.getFileMetaData.getSchema) // throws exception if incompatible - metaData - .getFileMetaData - .getSchema - .checkContains(ParquetTestData.testData.parquetSchema) // throws exception if incompatible - fs.delete(path, true) - } - - test("Creating case class RDD table") { - TestSQLContext.sparkContext.parallelize((1 to 100)) 
- .map(i => TestRDDEntry(i, s"val_$i")) - .registerTempTable("tmp") - val rdd = sql("SELECT * FROM tmp").collect().sortBy(_.getInt(0)) - var counter = 1 - rdd.foreach { - // '===' does not like string comparison? - row: Row => { - assert(row.getString(1).equals(s"val_$counter"), s"row $counter value ${row.getString(1)} does not match val_$counter") - counter = counter + 1 - } + test("nested data - struct with array field") { + val data = (1 to 10).map(i => Tuple1((i, Seq(s"val_$i")))) + withParquetTable(data, "t") { + checkAnswer(sql("SELECT _1._2[0] FROM t"), data.map { + case Tuple1((_, Seq(string))) => Row(string) + }) } } - test("Read a parquet file instead of a directory") { - val file = getTempFilePath("parquet") - val path = file.toString - val fsPath = new Path(path) - val fs: FileSystem = fsPath.getFileSystem(TestSQLContext.sparkContext.hadoopConfiguration) - val rdd = TestSQLContext.sparkContext.parallelize((1 to 100)) - .map(i => TestRDDEntry(i, s"val_$i")) - rdd.coalesce(1).saveAsParquetFile(path) - - val children = fs.listStatus(fsPath).filter(_.getPath.getName.endsWith(".parquet")) - assert(children.length > 0) - val readFile = parquetFile(path + "/" + children(0).getPath.getName) - readFile.registerTempTable("tmpx") - val rdd_copy = sql("SELECT * FROM tmpx").collect() - val rdd_orig = rdd.collect() - for(i <- 0 to 99) { - assert(rdd_copy(i).apply(0) === rdd_orig(i).key, s"key error in line $i") - assert(rdd_copy(i).apply(1) === rdd_orig(i).value, s"value error in line $i") + test("nested data - array of struct") { + val data = (1 to 10).map(i => Tuple1(Seq(i -> s"val_$i"))) + withParquetTable(data, "t") { + checkAnswer(sql("SELECT _1[0]._2 FROM t"), data.map { + case Tuple1(Seq((_, string))) => Row(string) + }) } - Utils.deleteRecursively(file) - } - - test("Insert (appending) to same table via Scala API") { - sql("INSERT INTO testsource SELECT * FROM testsource") - val double_rdd = sql("SELECT * FROM testsource").collect() - assert(double_rdd != null) - assert(double_rdd.size === 30) - - // let's restore the original test data - Utils.deleteRecursively(ParquetTestData.testDir) - ParquetTestData.writeFile() - } - - test("save and load case class RDD with nulls as parquet") { - val data = parquet.NullReflectData(null, null, null, null, null) - val rdd = sparkContext.parallelize(data :: Nil) - - val file = getTempFilePath("parquet") - val path = file.toString - rdd.saveAsParquetFile(path) - val readFile = parquetFile(path) - - val rdd_saved = readFile.collect() - assert(rdd_saved(0) === Row(null, null, null, null, null)) - Utils.deleteRecursively(file) - assert(true) - } - - test("save and load case class RDD with Nones as parquet") { - val data = parquet.OptionalReflectData(None, None, None, None, None) - val rdd = sparkContext.parallelize(data :: Nil) - - val file = getTempFilePath("parquet") - val path = file.toString - rdd.saveAsParquetFile(path) - val readFile = parquetFile(path) - - val rdd_saved = readFile.collect() - assert(rdd_saved(0) === Row(null, null, null, null, null)) - Utils.deleteRecursively(file) - assert(true) - } - - test("make RecordFilter for simple predicates") { - def checkFilter[T <: FilterPredicate : ClassTag]( - predicate: Expression, - defined: Boolean = true): Unit = { - val filter = ParquetFilters.createFilter(predicate) - if (defined) { - assert(filter.isDefined) - val tClass = implicitly[ClassTag[T]].runtimeClass - val filterGet = filter.get - assert( - tClass.isInstance(filterGet), - s"$filterGet of type ${filterGet.getClass} is not an instance 
of $tClass") - } else { - assert(filter.isEmpty) - } - } - - checkFilter[Operators.Eq[Integer]]('a.int === 1) - checkFilter[Operators.Eq[Integer]](Literal(1) === 'a.int) - - checkFilter[Operators.Lt[Integer]]('a.int < 4) - checkFilter[Operators.Lt[Integer]](Literal(4) > 'a.int) - checkFilter[Operators.LtEq[Integer]]('a.int <= 4) - checkFilter[Operators.LtEq[Integer]](Literal(4) >= 'a.int) - - checkFilter[Operators.Gt[Integer]]('a.int > 4) - checkFilter[Operators.Gt[Integer]](Literal(4) < 'a.int) - checkFilter[Operators.GtEq[Integer]]('a.int >= 4) - checkFilter[Operators.GtEq[Integer]](Literal(4) <= 'a.int) - - checkFilter[Operators.And]('a.int === 1 && 'a.int < 4) - checkFilter[Operators.Or]('a.int === 1 || 'a.int < 4) - checkFilter[Operators.NotEq[Integer]](!('a.int === 1)) - - checkFilter('a.int > 'b.int, defined = false) - checkFilter(('a.int > 'b.int) && ('a.int > 'b.int), defined = false) - } - - test("test filter by predicate pushdown") { - for(myval <- Seq("myint", "mylong", "mydouble", "myfloat")) { - val query1 = sql(s"SELECT * FROM testfiltersource WHERE $myval < 150 AND $myval >= 100") - assert( - query1.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan], - "Top operator should be ParquetTableScan after pushdown") - val result1 = query1.collect() - assert(result1.size === 50) - assert(result1(0)(1) === 100) - assert(result1(49)(1) === 149) - val query2 = sql(s"SELECT * FROM testfiltersource WHERE $myval > 150 AND $myval <= 200") - assert( - query2.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan], - "Top operator should be ParquetTableScan after pushdown") - val result2 = query2.collect() - assert(result2.size === 50) - if (myval == "myint" || myval == "mylong") { - assert(result2(0)(1) === 151) - assert(result2(49)(1) === 200) - } else { - assert(result2(0)(1) === 150) - assert(result2(49)(1) === 199) - } - } - for(myval <- Seq("myint", "mylong")) { - val query3 = sql(s"SELECT * FROM testfiltersource WHERE $myval > 190 OR $myval < 10") - assert( - query3.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan], - "Top operator should be ParquetTableScan after pushdown") - val result3 = query3.collect() - assert(result3.size === 20) - assert(result3(0)(1) === 0) - assert(result3(9)(1) === 9) - assert(result3(10)(1) === 191) - assert(result3(19)(1) === 200) - } - for(myval <- Seq("mydouble", "myfloat")) { - val result4 = - if (myval == "mydouble") { - val query4 = sql(s"SELECT * FROM testfiltersource WHERE $myval > 190.5 OR $myval < 10.0") - assert( - query4.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan], - "Top operator should be ParquetTableScan after pushdown") - query4.collect() - } else { - // CASTs are problematic. Here myfloat will be casted to a double and it seems there is - // currently no way to specify float constants in SqlParser? 
- sql(s"SELECT * FROM testfiltersource WHERE $myval > 190.5 OR $myval < 10").collect() - } - assert(result4.size === 20) - assert(result4(0)(1) === 0) - assert(result4(9)(1) === 9) - assert(result4(10)(1) === 191) - assert(result4(19)(1) === 200) - } - val query5 = sql(s"SELECT * FROM testfiltersource WHERE myboolean = true AND myint < 40") - assert( - query5.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan], - "Top operator should be ParquetTableScan after pushdown") - val booleanResult = query5.collect() - assert(booleanResult.size === 10) - for(i <- 0 until 10) { - if (!booleanResult(i).getBoolean(0)) { - fail(s"Boolean value in result row $i not true") - } - if (booleanResult(i).getInt(1) != i * 4) { - fail(s"Int value in result row $i should be ${4*i}") - } - } - val query6 = sql("SELECT * FROM testfiltersource WHERE mystring = \"100\"") - assert( - query6.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan], - "Top operator should be ParquetTableScan after pushdown") - val stringResult = query6.collect() - assert(stringResult.size === 1) - assert(stringResult(0).getString(2) == "100", "stringvalue incorrect") - assert(stringResult(0).getInt(1) === 100) - - val query7 = sql(s"SELECT * FROM testfiltersource WHERE myoptint < 40") - assert( - query7.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan], - "Top operator should be ParquetTableScan after pushdown") - val optResult = query7.collect() - assert(optResult.size === 20) - for(i <- 0 until 20) { - if (optResult(i)(7) != i * 2) { - fail(s"optional Int value in result row $i should be ${2*4*i}") - } - } - for(myval <- Seq("myoptint", "myoptlong", "myoptdouble", "myoptfloat")) { - val query8 = sql(s"SELECT * FROM testfiltersource WHERE $myval < 150 AND $myval >= 100") - assert( - query8.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan], - "Top operator should be ParquetTableScan after pushdown") - val result8 = query8.collect() - assert(result8.size === 25) - assert(result8(0)(7) === 100) - assert(result8(24)(7) === 148) - val query9 = sql(s"SELECT * FROM testfiltersource WHERE $myval > 150 AND $myval <= 200") - assert( - query9.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan], - "Top operator should be ParquetTableScan after pushdown") - val result9 = query9.collect() - assert(result9.size === 25) - if (myval == "myoptint" || myval == "myoptlong") { - assert(result9(0)(7) === 152) - assert(result9(24)(7) === 200) - } else { - assert(result9(0)(7) === 150) - assert(result9(24)(7) === 198) - } - } - val query10 = sql("SELECT * FROM testfiltersource WHERE myoptstring = \"100\"") - assert( - query10.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan], - "Top operator should be ParquetTableScan after pushdown") - val result10 = query10.collect() - assert(result10.size === 1) - assert(result10(0).getString(8) == "100", "stringvalue incorrect") - assert(result10(0).getInt(7) === 100) - val query11 = sql(s"SELECT * FROM testfiltersource WHERE myoptboolean = true AND myoptint < 40") - assert( - query11.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan], - "Top operator should be ParquetTableScan after pushdown") - val result11 = query11.collect() - assert(result11.size === 7) - for(i <- 0 until 6) { - if (!result11(i).getBoolean(6)) { - fail(s"optional Boolean value in result row $i not true") - } - if (result11(i).getInt(7) != i * 6) { - fail(s"optional Int value in result row $i should be ${6*i}") - } - } - - val query12 = sql("SELECT * FROM 
testfiltersource WHERE mystring >= \"50\"") - assert( - query12.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan], - "Top operator should be ParquetTableScan after pushdown") - val result12 = query12.collect() - assert(result12.size === 54) - assert(result12(0).getString(2) == "6") - assert(result12(4).getString(2) == "50") - assert(result12(53).getString(2) == "99") - - val query13 = sql("SELECT * FROM testfiltersource WHERE mystring > \"50\"") - assert( - query13.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan], - "Top operator should be ParquetTableScan after pushdown") - val result13 = query13.collect() - assert(result13.size === 53) - assert(result13(0).getString(2) == "6") - assert(result13(4).getString(2) == "51") - assert(result13(52).getString(2) == "99") - - val query14 = sql("SELECT * FROM testfiltersource WHERE mystring <= \"50\"") - assert( - query14.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan], - "Top operator should be ParquetTableScan after pushdown") - val result14 = query14.collect() - assert(result14.size === 148) - assert(result14(0).getString(2) == "0") - assert(result14(46).getString(2) == "50") - assert(result14(147).getString(2) == "200") - - val query15 = sql("SELECT * FROM testfiltersource WHERE mystring < \"50\"") - assert( - query15.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan], - "Top operator should be ParquetTableScan after pushdown") - val result15 = query15.collect() - assert(result15.size === 147) - assert(result15(0).getString(2) == "0") - assert(result15(46).getString(2) == "100") - assert(result15(146).getString(2) == "200") } test("SPARK-1913 regression: columns only referenced by pushed down filters should remain") { - val query = sql(s"SELECT mystring FROM testfiltersource WHERE myint < 10") - assert(query.collect().size === 10) - } - - test("Importing nested Parquet file (Addressbook)") { - val result = TestSQLContext - .parquetFile(ParquetTestData.testNestedDir1.toString) - .toSchemaRDD - .collect() - assert(result != null) - assert(result.size === 2) - val first_record = result(0) - val second_record = result(1) - assert(first_record != null) - assert(second_record != null) - assert(first_record.size === 3) - assert(second_record(1) === null) - assert(second_record(2) === null) - assert(second_record(0) === "A. 
Nonymous") - assert(first_record(0) === "Julien Le Dem") - val first_owner_numbers = first_record(1) - .asInstanceOf[CatalystConverter.ArrayScalaType[_]] - val first_contacts = first_record(2) - .asInstanceOf[CatalystConverter.ArrayScalaType[_]] - assert(first_owner_numbers != null) - assert(first_owner_numbers(0) === "555 123 4567") - assert(first_owner_numbers(2) === "XXX XXX XXXX") - assert(first_contacts(0) - .asInstanceOf[CatalystConverter.StructScalaType[_]].size === 2) - val first_contacts_entry_one = first_contacts(0) - .asInstanceOf[CatalystConverter.StructScalaType[_]] - assert(first_contacts_entry_one(0) === "Dmitriy Ryaboy") - assert(first_contacts_entry_one(1) === "555 987 6543") - val first_contacts_entry_two = first_contacts(1) - .asInstanceOf[CatalystConverter.StructScalaType[_]] - assert(first_contacts_entry_two(0) === "Chris Aniszczyk") - } - - test("Importing nested Parquet file (nested numbers)") { - val result = TestSQLContext - .parquetFile(ParquetTestData.testNestedDir2.toString) - .toSchemaRDD - .collect() - assert(result.size === 1, "number of top-level rows incorrect") - assert(result(0).size === 5, "number of fields in row incorrect") - assert(result(0)(0) === 1) - assert(result(0)(1) === 7) - val subresult1 = result(0)(2).asInstanceOf[CatalystConverter.ArrayScalaType[_]] - assert(subresult1.size === 3) - assert(subresult1(0) === (1.toLong << 32)) - assert(subresult1(1) === (1.toLong << 33)) - assert(subresult1(2) === (1.toLong << 34)) - val subresult2 = result(0)(3) - .asInstanceOf[CatalystConverter.ArrayScalaType[_]](0) - .asInstanceOf[CatalystConverter.StructScalaType[_]] - assert(subresult2.size === 2) - assert(subresult2(0) === 2.5) - assert(subresult2(1) === false) - val subresult3 = result(0)(4) - .asInstanceOf[CatalystConverter.ArrayScalaType[_]] - assert(subresult3.size === 2) - assert(subresult3(0).asInstanceOf[CatalystConverter.ArrayScalaType[_]].size === 2) - val subresult4 = subresult3(0).asInstanceOf[CatalystConverter.ArrayScalaType[_]] - assert(subresult4(0).asInstanceOf[CatalystConverter.ArrayScalaType[_]](0) === 7) - assert(subresult4(1).asInstanceOf[CatalystConverter.ArrayScalaType[_]](0) === 8) - assert(subresult3(1).asInstanceOf[CatalystConverter.ArrayScalaType[_]].size === 1) - assert(subresult3(1).asInstanceOf[CatalystConverter.ArrayScalaType[_]](0) - .asInstanceOf[CatalystConverter.ArrayScalaType[_]](0) === 9) - } - - test("Simple query on addressbook") { - val data = TestSQLContext - .parquetFile(ParquetTestData.testNestedDir1.toString) - .toSchemaRDD - val tmp = data.where('owner === "Julien Le Dem").select('owner as 'a, 'contacts as 'c).collect() - assert(tmp.size === 1) - assert(tmp(0)(0) === "Julien Le Dem") - } - - test("Projection in addressbook") { - val data = parquetFile(ParquetTestData.testNestedDir1.toString).toSchemaRDD - data.registerTempTable("data") - val query = sql("SELECT owner, contacts[1].name FROM data") - val tmp = query.collect() - assert(tmp.size === 2) - assert(tmp(0).size === 2) - assert(tmp(0)(0) === "Julien Le Dem") - assert(tmp(0)(1) === "Chris Aniszczyk") - assert(tmp(1)(0) === "A. 
Nonymous") - assert(tmp(1)(1) === null) - } - - test("Simple query on nested int data") { - val data = parquetFile(ParquetTestData.testNestedDir2.toString).toSchemaRDD - data.registerTempTable("data") - val result1 = sql("SELECT entries[0].value FROM data").collect() - assert(result1.size === 1) - assert(result1(0).size === 1) - assert(result1(0)(0) === 2.5) - val result2 = sql("SELECT entries[0] FROM data").collect() - assert(result2.size === 1) - val subresult1 = result2(0)(0).asInstanceOf[CatalystConverter.StructScalaType[_]] - assert(subresult1.size === 2) - assert(subresult1(0) === 2.5) - assert(subresult1(1) === false) - val result3 = sql("SELECT outerouter FROM data").collect() - val subresult2 = result3(0)(0) - .asInstanceOf[CatalystConverter.ArrayScalaType[_]](0) - .asInstanceOf[CatalystConverter.ArrayScalaType[_]] - assert(subresult2(0).asInstanceOf[CatalystConverter.ArrayScalaType[_]](0) === 7) - assert(subresult2(1).asInstanceOf[CatalystConverter.ArrayScalaType[_]](0) === 8) - assert(result3(0)(0) - .asInstanceOf[CatalystConverter.ArrayScalaType[_]](1) - .asInstanceOf[CatalystConverter.ArrayScalaType[_]](0) - .asInstanceOf[CatalystConverter.ArrayScalaType[_]](0) === 9) - } - - test("nested structs") { - val data = parquetFile(ParquetTestData.testNestedDir3.toString) - .toSchemaRDD - data.registerTempTable("data") - val result1 = sql("SELECT booleanNumberPairs[0].value[0].truth FROM data").collect() - assert(result1.size === 1) - assert(result1(0).size === 1) - assert(result1(0)(0) === false) - val result2 = sql("SELECT booleanNumberPairs[0].value[1].truth FROM data").collect() - assert(result2.size === 1) - assert(result2(0).size === 1) - assert(result2(0)(0) === true) - val result3 = sql("SELECT booleanNumberPairs[1].value[0].truth FROM data").collect() - assert(result3.size === 1) - assert(result3(0).size === 1) - assert(result3(0)(0) === false) - } - - test("simple map") { - val data = TestSQLContext - .parquetFile(ParquetTestData.testNestedDir4.toString) - .toSchemaRDD - data.registerTempTable("mapTable") - val result1 = sql("SELECT data1 FROM mapTable").collect() - assert(result1.size === 1) - assert(result1(0)(0) - .asInstanceOf[CatalystConverter.MapScalaType[String, _]] - .getOrElse("key1", 0) === 1) - assert(result1(0)(0) - .asInstanceOf[CatalystConverter.MapScalaType[String, _]] - .getOrElse("key2", 0) === 2) - val result2 = sql("""SELECT data1["key1"] FROM mapTable""").collect() - assert(result2(0)(0) === 1) - } - - test("map with struct values") { - val data = parquetFile(ParquetTestData.testNestedDir4.toString).toSchemaRDD - data.registerTempTable("mapTable") - val result1 = sql("SELECT data2 FROM mapTable").collect() - assert(result1.size === 1) - val entry1 = result1(0)(0) - .asInstanceOf[CatalystConverter.MapScalaType[String, CatalystConverter.StructScalaType[_]]] - .getOrElse("seven", null) - assert(entry1 != null) - assert(entry1(0) === 42) - assert(entry1(1) === "the answer") - val entry2 = result1(0)(0) - .asInstanceOf[CatalystConverter.MapScalaType[String, CatalystConverter.StructScalaType[_]]] - .getOrElse("eight", null) - assert(entry2 != null) - assert(entry2(0) === 49) - assert(entry2(1) === null) - val result2 = sql("""SELECT data2["seven"].payload1, data2["seven"].payload2 FROM mapTable""").collect() - assert(result2.size === 1) - assert(result2(0)(0) === 42.toLong) - assert(result2(0)(1) === "the answer") - } - - test("Writing out Addressbook and reading it back in") { - // TODO: find out why CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME - // has no 
effect in this test case - val tmpdir = Utils.createTempDir() - Utils.deleteRecursively(tmpdir) - val result = parquetFile(ParquetTestData.testNestedDir1.toString).toSchemaRDD - result.saveAsParquetFile(tmpdir.toString) - parquetFile(tmpdir.toString) - .toSchemaRDD - .registerTempTable("tmpcopy") - val tmpdata = sql("SELECT owner, contacts[1].name FROM tmpcopy").collect() - assert(tmpdata.size === 2) - assert(tmpdata(0).size === 2) - assert(tmpdata(0)(0) === "Julien Le Dem") - assert(tmpdata(0)(1) === "Chris Aniszczyk") - assert(tmpdata(1)(0) === "A. Nonymous") - assert(tmpdata(1)(1) === null) - Utils.deleteRecursively(tmpdir) - } - - test("Writing out Map and reading it back in") { - val data = parquetFile(ParquetTestData.testNestedDir4.toString).toSchemaRDD - val tmpdir = Utils.createTempDir() - Utils.deleteRecursively(tmpdir) - data.saveAsParquetFile(tmpdir.toString) - parquetFile(tmpdir.toString) - .toSchemaRDD - .registerTempTable("tmpmapcopy") - val result1 = sql("""SELECT data1["key2"] FROM tmpmapcopy""").collect() - assert(result1.size === 1) - assert(result1(0)(0) === 2) - val result2 = sql("SELECT data2 FROM tmpmapcopy").collect() - assert(result2.size === 1) - val entry1 = result2(0)(0) - .asInstanceOf[CatalystConverter.MapScalaType[String, CatalystConverter.StructScalaType[_]]] - .getOrElse("seven", null) - assert(entry1 != null) - assert(entry1(0) === 42) - assert(entry1(1) === "the answer") - val entry2 = result2(0)(0) - .asInstanceOf[CatalystConverter.MapScalaType[String, CatalystConverter.StructScalaType[_]]] - .getOrElse("eight", null) - assert(entry2 != null) - assert(entry2(0) === 49) - assert(entry2(1) === null) - val result3 = sql("""SELECT data2["seven"].payload1, data2["seven"].payload2 FROM tmpmapcopy""").collect() - assert(result3.size === 1) - assert(result3(0)(0) === 42.toLong) - assert(result3(0)(1) === "the answer") - Utils.deleteRecursively(tmpdir) - } - - test("read/write fixed-length decimals") { - for ((precision, scale) <- Seq((5, 2), (1, 0), (1, 1), (18, 10), (18, 17))) { - val tempDir = getTempFilePath("parquetTest").getCanonicalPath - val data = sparkContext.parallelize(0 to 1000) - .map(i => NumericData(i, i / 100.0)) - .select('i, 'd cast DecimalType(precision, scale)) - data.saveAsParquetFile(tempDir) - checkAnswer(parquetFile(tempDir), data.toSchemaRDD.collect().toSeq) - } - - // Decimals with precision above 18 are not yet supported - intercept[RuntimeException] { - val tempDir = getTempFilePath("parquetTest").getCanonicalPath - val data = sparkContext.parallelize(0 to 1000) - .map(i => NumericData(i, i / 100.0)) - .select('i, 'd cast DecimalType(19, 10)) - data.saveAsParquetFile(tempDir) - checkAnswer(parquetFile(tempDir), data.toSchemaRDD.collect().toSeq) - } - - // Unlimited-length decimals are not yet supported - intercept[RuntimeException] { - val tempDir = getTempFilePath("parquetTest").getCanonicalPath - val data = sparkContext.parallelize(0 to 1000) - .map(i => NumericData(i, i / 100.0)) - .select('i, 'd cast DecimalType.Unlimited) - data.saveAsParquetFile(tempDir) - checkAnswer(parquetFile(tempDir), data.toSchemaRDD.collect().toSeq) - } - } - - def checkFilter(predicate: Predicate, filterClass: Class[_ <: FilterPredicate]): Unit = { - val filter = ParquetFilters.createFilter(predicate) - assert(filter.isDefined) - assert(filter.get.getClass == filterClass) - } - - test("Pushdown IsNull predicate") { - checkFilter('a.int.isNull, classOf[Operators.Eq[Integer]]) - checkFilter('a.long.isNull, classOf[Operators.Eq[java.lang.Long]]) - 
checkFilter('a.float.isNull, classOf[Operators.Eq[java.lang.Float]]) - checkFilter('a.double.isNull, classOf[Operators.Eq[java.lang.Double]]) - checkFilter('a.string.isNull, classOf[Operators.Eq[Binary]]) - checkFilter('a.binary.isNull, classOf[Operators.Eq[Binary]]) - } - - test("Pushdown IsNotNull predicate") { - checkFilter('a.int.isNotNull, classOf[Operators.NotEq[Integer]]) - checkFilter('a.long.isNotNull, classOf[Operators.NotEq[java.lang.Long]]) - checkFilter('a.float.isNotNull, classOf[Operators.NotEq[java.lang.Float]]) - checkFilter('a.double.isNotNull, classOf[Operators.NotEq[java.lang.Double]]) - checkFilter('a.string.isNotNull, classOf[Operators.NotEq[Binary]]) - checkFilter('a.binary.isNotNull, classOf[Operators.NotEq[Binary]]) - } - - test("Pushdown EqualTo predicate") { - checkFilter('a.int === 0, classOf[Operators.Eq[Integer]]) - checkFilter('a.long === 0.toLong, classOf[Operators.Eq[java.lang.Long]]) - checkFilter('a.float === 0.toFloat, classOf[Operators.Eq[java.lang.Float]]) - checkFilter('a.double === 0.toDouble, classOf[Operators.Eq[java.lang.Double]]) - checkFilter('a.string === "foo", classOf[Operators.Eq[Binary]]) - checkFilter('a.binary === "foo".getBytes, classOf[Operators.Eq[Binary]]) - } - - test("Pushdown Not(EqualTo) predicate") { - checkFilter(!('a.int === 0), classOf[Operators.NotEq[Integer]]) - checkFilter(!('a.long === 0.toLong), classOf[Operators.NotEq[java.lang.Long]]) - checkFilter(!('a.float === 0.toFloat), classOf[Operators.NotEq[java.lang.Float]]) - checkFilter(!('a.double === 0.toDouble), classOf[Operators.NotEq[java.lang.Double]]) - checkFilter(!('a.string === "foo"), classOf[Operators.NotEq[Binary]]) - checkFilter(!('a.binary === "foo".getBytes), classOf[Operators.NotEq[Binary]]) - } - - test("Pushdown LessThan predicate") { - checkFilter('a.int < 0, classOf[Operators.Lt[Integer]]) - checkFilter('a.long < 0.toLong, classOf[Operators.Lt[java.lang.Long]]) - checkFilter('a.float < 0.toFloat, classOf[Operators.Lt[java.lang.Float]]) - checkFilter('a.double < 0.toDouble, classOf[Operators.Lt[java.lang.Double]]) - checkFilter('a.string < "foo", classOf[Operators.Lt[Binary]]) - checkFilter('a.binary < "foo".getBytes, classOf[Operators.Lt[Binary]]) - } - - test("Pushdown LessThanOrEqual predicate") { - checkFilter('a.int <= 0, classOf[Operators.LtEq[Integer]]) - checkFilter('a.long <= 0.toLong, classOf[Operators.LtEq[java.lang.Long]]) - checkFilter('a.float <= 0.toFloat, classOf[Operators.LtEq[java.lang.Float]]) - checkFilter('a.double <= 0.toDouble, classOf[Operators.LtEq[java.lang.Double]]) - checkFilter('a.string <= "foo", classOf[Operators.LtEq[Binary]]) - checkFilter('a.binary <= "foo".getBytes, classOf[Operators.LtEq[Binary]]) - } - - test("Pushdown GreaterThan predicate") { - checkFilter('a.int > 0, classOf[Operators.Gt[Integer]]) - checkFilter('a.long > 0.toLong, classOf[Operators.Gt[java.lang.Long]]) - checkFilter('a.float > 0.toFloat, classOf[Operators.Gt[java.lang.Float]]) - checkFilter('a.double > 0.toDouble, classOf[Operators.Gt[java.lang.Double]]) - checkFilter('a.string > "foo", classOf[Operators.Gt[Binary]]) - checkFilter('a.binary > "foo".getBytes, classOf[Operators.Gt[Binary]]) - } - - test("Pushdown GreaterThanOrEqual predicate") { - checkFilter('a.int >= 0, classOf[Operators.GtEq[Integer]]) - checkFilter('a.long >= 0.toLong, classOf[Operators.GtEq[java.lang.Long]]) - checkFilter('a.float >= 0.toFloat, classOf[Operators.GtEq[java.lang.Float]]) - checkFilter('a.double >= 0.toDouble, classOf[Operators.GtEq[java.lang.Double]]) - 
checkFilter('a.string >= "foo", classOf[Operators.GtEq[Binary]]) - checkFilter('a.binary >= "foo".getBytes, classOf[Operators.GtEq[Binary]]) - } - - test("Comparison with null should not be pushed down") { - val predicates = Seq( - 'a.int === null, - !('a.int === null), - - Literal(null) === 'a.int, - !(Literal(null) === 'a.int), - - 'a.int < null, - 'a.int <= null, - 'a.int > null, - 'a.int >= null, - - Literal(null) < 'a.int, - Literal(null) <= 'a.int, - Literal(null) > 'a.int, - Literal(null) >= 'a.int - ) - - predicates.foreach { p => - assert( - ParquetFilters.createFilter(p).isEmpty, - "Comparison predicate with null shouldn't be pushed down") - } - } - - test("Import of simple Parquet files using glob wildcard pattern") { - val testGlobDir = ParquetTestData.testGlobDir.toString - val globPatterns = Array(testGlobDir + "/*/*", testGlobDir + "/spark-*/*", testGlobDir + "/?pa?k-*/*") - globPatterns.foreach { path => - val result = parquetFile(path).collect() - assert(result.size === 45) - result.zipWithIndex.foreach { - case (row, index) => { - val checkBoolean = - if ((index % 15) % 3 == 0) - row(0) == true - else - row(0) == false - assert(checkBoolean === true, s"boolean field value in line $index did not match") - if ((index % 15) % 5 == 0) assert(row(1) === 5, s"int field value in line $index did not match") - assert(row(2) === "abc", s"string field value in line $index did not match") - assert(row(3) === ((index.toLong % 15) << 33), s"long value in line $index did not match") - assert(row(4) === 2.5F, s"float field value in line $index did not match") - assert(row(5) === 4.5D, s"double field value in line $index did not match") - } - } + withParquetTable((1 to 10).map(Tuple1.apply), "t") { + checkAnswer(sql(s"SELECT _1 FROM t WHERE _1 < 10"), (1 to 9).map(Row.apply(_))) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite2.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite2.scala deleted file mode 100644 index 7b3f8c22af2db..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite2.scala +++ /dev/null @@ -1,88 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.parquet - -import org.apache.spark.sql.QueryTest -import org.apache.spark.sql.catalyst.expressions.Row -import org.apache.spark.sql.test.TestSQLContext -import org.apache.spark.sql.test.TestSQLContext._ - -/** - * A test suite that tests various Parquet queries. 
- */ -class ParquetQuerySuite2 extends QueryTest with ParquetTest { - val sqlContext = TestSQLContext - - test("simple projection") { - withParquetTable((0 until 10).map(i => (i, i.toString)), "t") { - checkAnswer(sql("SELECT _1 FROM t"), (0 until 10).map(Row.apply(_))) - } - } - - test("appending") { - val data = (0 until 10).map(i => (i, i.toString)) - withParquetTable(data, "t") { - sql("INSERT INTO t SELECT * FROM t") - checkAnswer(table("t"), (data ++ data).map(Row.fromTuple)) - } - } - - test("self-join") { - // 4 rows, cells of column 1 of row 2 and row 4 are null - val data = (1 to 4).map { i => - val maybeInt = if (i % 2 == 0) None else Some(i) - (maybeInt, i.toString) - } - - withParquetTable(data, "t") { - val selfJoin = sql("SELECT * FROM t x JOIN t y WHERE x._1 = y._1") - val queryOutput = selfJoin.queryExecution.analyzed.output - - assertResult(4, s"Field count mismatches")(queryOutput.size) - assertResult(2, s"Duplicated expression ID in query plan:\n $selfJoin") { - queryOutput.filter(_.name == "_1").map(_.exprId).size - } - - checkAnswer(selfJoin, List(Row(1, "1", 1, "1"), Row(3, "3", 3, "3"))) - } - } - - test("nested data - struct with array field") { - val data = (1 to 10).map(i => Tuple1((i, Seq(s"val_$i")))) - withParquetTable(data, "t") { - checkAnswer(sql("SELECT _1._2[0] FROM t"), data.map { - case Tuple1((_, Seq(string))) => Row(string) - }) - } - } - - test("nested data - array of struct") { - val data = (1 to 10).map(i => Tuple1(Seq(i -> s"val_$i"))) - withParquetTable(data, "t") { - checkAnswer(sql("SELECT _1[0]._2 FROM t"), data.map { - case Tuple1(Seq((_, string))) => Row(string) - }) - } - } - - test("SPARK-1913 regression: columns only referenced by pushed down filters should remain") { - withParquetTable((1 to 10).map(Tuple1.apply), "t") { - checkAnswer(sql(s"SELECT _1 FROM t WHERE _1 < 10"), (1 to 9).map(Row.apply(_))) - } - } -} From 3be2a887bf6107b6398e472872b22175ea4ae1f7 Mon Sep 17 00:00:00 2001 From: wangfei Date: Wed, 21 Jan 2015 15:27:42 -0800 Subject: [PATCH 09/22] [SPARK-4984][CORE][WEBUI] Adding a pop-up containing the full job description when it is very long In some case the job description will be very long, such as a long sql. refer to #3718 This PR add a pop-up for job description when it is long. 
![image](https://cloud.githubusercontent.com/assets/7018048/5847400/c757cbbc-a207-11e4-891f-528821c2e68d.png) ![image](https://cloud.githubusercontent.com/assets/7018048/5847409/d434b2b4-a207-11e4-8813-03a74b43d766.png) Author: wangfei Closes #3819 from scwf/popup-descrip-ui and squashes the following commits: ba02b83 [wangfei] address comments a7c5e7b [wangfei] spot that it's been truncated fbf6162 [wangfei] Merge branch 'master' into popup-descrip-ui 0bca96d [wangfei] remove no use val 4b55c3b [wangfei] fix style issue 353c6f4 [wangfei] pop up the description of job with a styled read-only text form field --- .../main/resources/org/apache/spark/ui/static/webui.css | 8 ++++++++ .../main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala | 2 +- .../main/scala/org/apache/spark/ui/jobs/StageTable.scala | 3 +-- 3 files changed, 10 insertions(+), 3 deletions(-) diff --git a/core/src/main/resources/org/apache/spark/ui/static/webui.css b/core/src/main/resources/org/apache/spark/ui/static/webui.css index f02b035a980b1..a1f7133f897ee 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/webui.css +++ b/core/src/main/resources/org/apache/spark/ui/static/webui.css @@ -121,6 +121,14 @@ pre { border: none; } +.description-input { + overflow: hidden; + text-overflow: ellipsis; + width: 100%; + white-space: nowrap; + display: block; +} + .stacktrace-details { max-height: 300px; overflow-y: auto; diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala index 81212708ba524..045c69da06feb 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala @@ -64,7 +64,7 @@ private[ui] class AllJobsPage(parent: JobsTab) extends WebUIPage("") { {job.jobId} {job.jobGroup.map(id => s"($id)").getOrElse("")} -
{lastStageDescription}
+ {lastStageDescription} {lastStageName} diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala index e7d6244dcd679..703d43f9c640d 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala @@ -112,9 +112,8 @@ private[ui] class StageTableBase( stageData <- listener.stageIdToData.get((s.stageId, s.attemptId)) desc <- stageData.description } yield { -
{desc}
+ {desc} } -
{stageDesc.getOrElse("")} {killLink} {nameLink} {details}
} From 9bad062268676aaa66dcbddd1e0ab7f2d7742425 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Wed, 21 Jan 2015 16:51:42 -0800 Subject: [PATCH 10/22] [SPARK-5355] make SparkConf thread-safe The SparkConf is not thread-safe, but is accessed by many threads. The getAll() could return parts of the configs if another thread is access it. This PR changes SparkConf.settings to a thread-safe TrieMap. Author: Davies Liu Closes #4143 from davies/safe-conf and squashes the following commits: f8fa1cf [Davies Liu] change to TrieMap a1d769a [Davies Liu] make SparkConf thread-safe --- core/src/main/scala/org/apache/spark/SparkConf.scala | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index a0ce107f43b16..f9d4aa4240e9d 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -18,6 +18,7 @@ package org.apache.spark import scala.collection.JavaConverters._ +import scala.collection.concurrent.TrieMap import scala.collection.mutable.{HashMap, LinkedHashSet} import org.apache.spark.serializer.KryoSerializer @@ -46,7 +47,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { /** Create a SparkConf that loads defaults from system properties and the classpath */ def this() = this(true) - private[spark] val settings = new HashMap[String, String]() + private[spark] val settings = new TrieMap[String, String]() if (loadDefaults) { // Load any spark.* system properties @@ -177,7 +178,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { } /** Get all parameters as a list of pairs */ - def getAll: Array[(String, String)] = settings.clone().toArray + def getAll: Array[(String, String)] = settings.toArray /** Get a parameter as an integer, falling back to a default if not set */ def getInt(key: String, defaultValue: Int): Int = { From 27bccc5ea9742ca166f6026033e7771c8a90cc0a Mon Sep 17 00:00:00 2001 From: Cheng Hao Date: Wed, 21 Jan 2015 17:34:18 -0800 Subject: [PATCH 11/22] [SPARK-5202] [SQL] Add hql variable substitution support https://cwiki.apache.org/confluence/display/Hive/LanguageManual+VariableSubstitution This is a block issue for the CLI user, it impacts the existed hql scripts from Hive. 
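A minimal sketch of the behaviour this enables, mirroring the test added below (assumes an existing `HiveContext`; `sql` stands for `hiveContext.sql`):

```scala
// Define a Hive variable, then reference it with ${hiveconf:...} inside HQL.
sql("SET tbl=src")
sql("SELECT key FROM ${hiveconf:tbl} ORDER BY key LIMIT 1")  // resolves to table `src`

// Substitution can still be turned off through the usual Hive setting.
sql("SET hive.variable.substitute=false")
```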
Author: Cheng Hao Closes #4003 from chenghao-intel/substitution and squashes the following commits: bb41fd6 [Cheng Hao] revert the removed the implicit conversion af7c31a [Cheng Hao] add hql variable substitution support --- .../apache/spark/sql/hive/HiveContext.scala | 6 ++++-- .../sql/hive/execution/SQLQuerySuite.scala | 18 ++++++++++++++++++ 2 files changed, 22 insertions(+), 2 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 274f83af5ac03..9d2cfd8e0d669 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -29,6 +29,7 @@ import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.ql.Driver import org.apache.hadoop.hive.ql.metadata.Table import org.apache.hadoop.hive.ql.processors._ +import org.apache.hadoop.hive.ql.parse.VariableSubstitution import org.apache.hadoop.hive.ql.session.SessionState import org.apache.hadoop.hive.serde2.io.{DateWritable, TimestampWritable} @@ -66,11 +67,12 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { new this.QueryExecution { val logical = plan } override def sql(sqlText: String): SchemaRDD = { + val substituted = new VariableSubstitution().substitute(hiveconf, sqlText) // TODO: Create a framework for registering parsers instead of just hardcoding if statements. if (conf.dialect == "sql") { - super.sql(sqlText) + super.sql(substituted) } else if (conf.dialect == "hiveql") { - new SchemaRDD(this, ddlParser(sqlText, false).getOrElse(HiveQl.parseSql(sqlText))) + new SchemaRDD(this, ddlParser(sqlText, false).getOrElse(HiveQl.parseSql(substituted))) } else { sys.error(s"Unsupported SQL dialect: ${conf.dialect}. Try 'sql' or 'hiveql'") } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index f6bf2dbb5d6e4..7f9f1ac7cd80d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -104,6 +104,24 @@ class SQLQuerySuite extends QueryTest { ) } + test("command substitution") { + sql("set tbl=src") + checkAnswer( + sql("SELECT key FROM ${hiveconf:tbl} ORDER BY key, value limit 1"), + sql("SELECT key FROM src ORDER BY key, value limit 1").collect().toSeq) + + sql("set hive.variable.substitute=false") // disable the substitution + sql("set tbl2=src") + intercept[Exception] { + sql("SELECT key FROM ${hiveconf:tbl2} ORDER BY key, value limit 1").collect() + } + + sql("set hive.variable.substitute=true") // enable the substitution + checkAnswer( + sql("SELECT key FROM ${hiveconf:tbl2} ORDER BY key, value limit 1"), + sql("SELECT key FROM src ORDER BY key, value limit 1").collect().toSeq) + } + test("ordering not in select") { checkAnswer( sql("SELECT key FROM src ORDER BY value"), From ca7910d6dd7693be2a675a0d6a6fcc9eb0aaeb5d Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Wed, 21 Jan 2015 21:20:31 -0800 Subject: [PATCH 12/22] [SPARK-3424][MLLIB] cache point distances during k-means|| init This PR ports the following feature implemented in #2634 by derrickburns: * During k-means|| initialization, we should cache costs (squared distances) previously computed. 
It also contains the following optimization: * aggregate sumCosts directly * ran multiple (#runs) k-means++ in parallel I compared the performance locally on mnist-digit. Before this patch: ![before](https://cloud.githubusercontent.com/assets/829644/5845647/93080862-a172-11e4-9a35-044ec711afc4.png) with this patch: ![after](https://cloud.githubusercontent.com/assets/829644/5845653/a47c29e8-a172-11e4-8e9f-08db57fe3502.png) It is clear that each k-means|| iteration takes about the same amount of time with this patch. Authors: Derrick Burns Xiangrui Meng Closes #4144 from mengxr/SPARK-3424-kmeans-parallel and squashes the following commits: 0a875ec [Xiangrui Meng] address comments 4341bb8 [Xiangrui Meng] do not re-compute point distances during k-means|| --- .../spark/mllib/clustering/KMeans.scala | 65 ++++++++++++++----- 1 file changed, 50 insertions(+), 15 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala index 6b5c934f015ba..fc46da3a93425 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala @@ -279,45 +279,80 @@ class KMeans private ( */ private def initKMeansParallel(data: RDD[VectorWithNorm]) : Array[Array[VectorWithNorm]] = { - // Initialize each run's center to a random point + // Initialize empty centers and point costs. + val centers = Array.tabulate(runs)(r => ArrayBuffer.empty[VectorWithNorm]) + var costs = data.map(_ => Vectors.dense(Array.fill(runs)(Double.PositiveInfinity))).cache() + + // Initialize each run's first center to a random point. val seed = new XORShiftRandom(this.seed).nextInt() val sample = data.takeSample(true, runs, seed).toSeq - val centers = Array.tabulate(runs)(r => ArrayBuffer(sample(r).toDense)) + val newCenters = Array.tabulate(runs)(r => ArrayBuffer(sample(r).toDense)) + + /** Merges new centers to centers. */ + def mergeNewCenters(): Unit = { + var r = 0 + while (r < runs) { + centers(r) ++= newCenters(r) + newCenters(r).clear() + r += 1 + } + } // On each step, sample 2 * k points on average for each run with probability proportional - // to their squared distance from that run's current centers + // to their squared distance from that run's centers. Note that only distances between points + // and new centers are computed in each iteration. 
var step = 0 while (step < initializationSteps) { - val bcCenters = data.context.broadcast(centers) - val sumCosts = data.flatMap { point => - (0 until runs).map { r => - (r, KMeans.pointCost(bcCenters.value(r), point)) - } - }.reduceByKey(_ + _).collectAsMap() - val chosen = data.mapPartitionsWithIndex { (index, points) => + val bcNewCenters = data.context.broadcast(newCenters) + val preCosts = costs + costs = data.zip(preCosts).map { case (point, cost) => + Vectors.dense( + Array.tabulate(runs) { r => + math.min(KMeans.pointCost(bcNewCenters.value(r), point), cost(r)) + }) + }.cache() + val sumCosts = costs + .aggregate(Vectors.zeros(runs))( + seqOp = (s, v) => { + // s += v + axpy(1.0, v, s) + s + }, + combOp = (s0, s1) => { + // s0 += s1 + axpy(1.0, s1, s0) + s0 + } + ) + preCosts.unpersist(blocking = false) + val chosen = data.zip(costs).mapPartitionsWithIndex { (index, pointsWithCosts) => val rand = new XORShiftRandom(seed ^ (step << 16) ^ index) - points.flatMap { p => + pointsWithCosts.flatMap { case (p, c) => (0 until runs).filter { r => - rand.nextDouble() < 2.0 * KMeans.pointCost(bcCenters.value(r), p) * k / sumCosts(r) + rand.nextDouble() < 2.0 * c(r) * k / sumCosts(r) }.map((_, p)) } }.collect() + mergeNewCenters() chosen.foreach { case (r, p) => - centers(r) += p.toDense + newCenters(r) += p.toDense } step += 1 } + mergeNewCenters() + costs.unpersist(blocking = false) + // Finally, we might have a set of more than k candidate centers for each run; weigh each // candidate by the number of points in the dataset mapping to it and run a local k-means++ // on the weighted centers to pick just k of them val bcCenters = data.context.broadcast(centers) val weightMap = data.flatMap { p => - (0 until runs).map { r => + Iterator.tabulate(runs) { r => ((r, KMeans.findClosest(bcCenters.value(r), p)._1), 1.0) } }.reduceByKey(_ + _).collectAsMap() - val finalCenters = (0 until runs).map { r => + val finalCenters = (0 until runs).par.map { r => val myCenters = centers(r).toArray val myWeights = (0 until myCenters.length).map(i => weightMap.getOrElse((r, i), 0.0)).toArray LocalKMeans.kMeansPlusPlus(r, myCenters, myWeights, k, 30) From fcb3e1862ffe784f39bde467e8d24c1b7ed3afbb Mon Sep 17 00:00:00 2001 From: Basin Date: Wed, 21 Jan 2015 23:06:34 -0800 Subject: [PATCH 13/22] [SPARK-5317]Set BoostingStrategy.defaultParams With Enumeration Algo.Classification or Algo.Regression JIRA Issue: https://issues.apache.org/jira/browse/SPARK-5317 When setting the BoostingStrategy.defaultParams("Classification"), It's more straightforward to set it with the Enumeration Algo.Classification, just like BoostingStragety.defaultParams(Algo.Classification). I overload the method BoostingStragety.defaultParams(). Author: Basin Closes #4103 from Peishen-Jia/stragetyAlgo and squashes the following commits: 87bab1c [Basin] Docs and Code documentations updated. 3b72875 [Basin] defaultParams(algoStr: String) call defaultParams(algo: Algo). 7c1e6ee [Basin] Doc of Java updated. algo -> algoStr instead. d5c8a2e [Basin] Merge branch 'stragetyAlgo' of github.com:Peishen-Jia/spark into stragetyAlgo 65f96ce [Basin] mllib-ensembles doc modified. e04a5aa [Basin] boostingstrategy.defaultParam string algo to enumeration. 68cf544 [Basin] mllib-ensembles doc modified. a4aea51 [Basin] boostingstrategy.defaultParam string algo to enumeration. 
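With the new overload, callers can pass the `Algo` enumeration value directly instead of a string; a small sketch of the two equivalent entry points (using only the classes touched by this patch):

```scala
import org.apache.spark.mllib.tree.configuration.{Algo, BoostingStrategy}

// Both forms should yield the same defaults after this change;
// the String overload simply delegates through Algo.fromString.
val byEnum   = BoostingStrategy.defaultParams(Algo.Classification)
val byString = BoostingStrategy.defaultParams("Classification")
```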
--- .../tree/configuration/BoostingStrategy.scala | 25 +++++++++++++------ .../mllib/tree/configuration/Strategy.scala | 14 ++++++++--- 2 files changed, 28 insertions(+), 11 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala index cf51d041c65a9..ed8e6a796f8c4 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala @@ -68,6 +68,15 @@ case class BoostingStrategy( @Experimental object BoostingStrategy { + /** + * Returns default configuration for the boosting algorithm + * @param algo Learning goal. Supported: "Classification" or "Regression" + * @return Configuration for boosting algorithm + */ + def defaultParams(algo: String): BoostingStrategy = { + defaultParams(Algo.fromString(algo)) + } + /** * Returns default configuration for the boosting algorithm * @param algo Learning goal. Supported: @@ -75,15 +84,15 @@ object BoostingStrategy { * [[org.apache.spark.mllib.tree.configuration.Algo.Regression]] * @return Configuration for boosting algorithm */ - def defaultParams(algo: String): BoostingStrategy = { - val treeStrategy = Strategy.defaultStrategy(algo) - treeStrategy.maxDepth = 3 + def defaultParams(algo: Algo): BoostingStrategy = { + val treeStragtegy = Strategy.defaultStategy(algo) + treeStragtegy.maxDepth = 3 algo match { - case "Classification" => - treeStrategy.numClasses = 2 - new BoostingStrategy(treeStrategy, LogLoss) - case "Regression" => - new BoostingStrategy(treeStrategy, SquaredError) + case Algo.Classification => + treeStragtegy.numClasses = 2 + new BoostingStrategy(treeStragtegy, LogLoss) + case Algo.Regression => + new BoostingStrategy(treeStragtegy, SquaredError) case _ => throw new IllegalArgumentException(s"$algo is not supported by boosting.") } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala index d5cd89ab94e81..972959885f396 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala @@ -173,11 +173,19 @@ object Strategy { * Construct a default set of parameters for [[org.apache.spark.mllib.tree.DecisionTree]] * @param algo "Classification" or "Regression" */ - def defaultStrategy(algo: String): Strategy = algo match { - case "Classification" => + def defaultStrategy(algo: String): Strategy = { + defaultStategy(Algo.fromString(algo)) + } + + /** + * Construct a default set of parameters for [[org.apache.spark.mllib.tree.DecisionTree]] + * @param algo Algo.Classification or Algo.Regression + */ + def defaultStategy(algo: Algo): Strategy = algo match { + case Algo.Classification => new Strategy(algo = Classification, impurity = Gini, maxDepth = 10, numClasses = 2) - case "Regression" => + case Algo.Regression => new Strategy(algo = Regression, impurity = Variance, maxDepth = 10, numClasses = 0) } From 3027f06b4127ab23a43c5ce8cebf721e3b6766e5 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Wed, 21 Jan 2015 23:41:44 -0800 Subject: [PATCH 14/22] [SPARK-5147][Streaming] Delete the received data WAL log periodically This is a refactored fix based on jerryshao 's PR #4037 This enabled deletion of old WAL files containing the received block data. 
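For reference, this cleanup path is only active when the receiver write-ahead log is enabled and a checkpoint directory is set; a minimal setup sketch (master, batch interval and path are illustrative only):

```scala
import org.apache.spark.SparkConf
import org.apache.spark.streaming.{Seconds, StreamingContext}

val conf = new SparkConf()
  .setMaster("local[2]")  // a receiver occupies one core, so use at least two
  .setAppName("wal-cleanup-sketch")
  .set("spark.streaming.receiver.writeAheadLog.enable", "true")  // the flag exercised by the new test
val ssc = new StreamingContext(conf, Seconds(1))
ssc.checkpoint("/tmp/streaming-checkpoint")  // hypothetical path; the WAL files live under it
```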
Improvements over #4037 - Respecting the rememberDuration of all receiver streams. In #4037, if there were two receiver streams with multiple remember durations, the deletion would have delete based on the shortest remember duration, thus deleting data prematurely for the receiver stream with longer remember duration. - Added unit test to test creation of receiver WAL, automatic deletion, and respecting of remember duration. jerryshao I am going to merge this ASAP to make it 1.2.1 Thanks for the initial draft of this PR. Made my job much easier. Author: Tathagata Das Author: jerryshao Closes #4149 from tdas/SPARK-5147 and squashes the following commits: 730798b [Tathagata Das] Added comments. c4cf067 [Tathagata Das] Minor fixes 2579b27 [Tathagata Das] Refactored the fix to make sure that the cleanup respects the remember duration of all the receiver streams 2736fd1 [jerryshao] Delete the old WAL log periodically --- .../apache/spark/streaming/DStreamGraph.scala | 8 + .../dstream/ReceiverInputDStream.scala | 11 -- .../streaming/receiver/ReceiverMessage.scala | 5 +- .../receiver/ReceiverSupervisorImpl.scala | 9 + .../streaming/scheduler/JobGenerator.scala | 11 +- .../scheduler/ReceivedBlockTracker.scala | 1 - .../streaming/scheduler/ReceiverTracker.scala | 18 +- .../spark/streaming/util/HdfsUtils.scala | 2 +- .../spark/streaming/ReceiverSuite.scala | 157 ++++++++++++++---- 9 files changed, 172 insertions(+), 50 deletions(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala b/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala index e59c24adb84af..0e285d6088ec1 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala @@ -160,6 +160,14 @@ final private[streaming] class DStreamGraph extends Serializable with Logging { } } + /** + * Get the maximum remember duration across all the input streams. This is a conservative but + * safe remember duration which can be used to perform cleanup operations. + */ + def getMaxInputStreamRememberDuration(): Duration = { + inputStreams.map { _.rememberDuration }.maxBy { _.milliseconds } + } + @throws(classOf[IOException]) private def writeObject(oos: ObjectOutputStream): Unit = Utils.tryOrIOException { logDebug("DStreamGraph.writeObject used") diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala index afd3c4bc4c4fe..8be04314c4285 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala @@ -94,15 +94,4 @@ abstract class ReceiverInputDStream[T: ClassTag](@transient ssc_ : StreamingCont } Some(blockRDD) } - - /** - * Clear metadata that are older than `rememberDuration` of this DStream. - * This is an internal method that should not be called directly. This - * implementation overrides the default implementation to clear received - * block information. 
- */ - private[streaming] override def clearMetadata(time: Time) { - super.clearMetadata(time) - ssc.scheduler.receiverTracker.cleanupOldMetadata(time - rememberDuration) - } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverMessage.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverMessage.scala index ab9fa192191aa..7bf3c33319491 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverMessage.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverMessage.scala @@ -17,7 +17,10 @@ package org.apache.spark.streaming.receiver -/** Messages sent to the NetworkReceiver. */ +import org.apache.spark.streaming.Time + +/** Messages sent to the Receiver. */ private[streaming] sealed trait ReceiverMessage extends Serializable private[streaming] object StopReceiver extends ReceiverMessage +private[streaming] case class CleanupOldBlocks(threshTime: Time) extends ReceiverMessage diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala index d7229c2b96d0b..716cf2c7f32fc 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala @@ -30,6 +30,7 @@ import org.apache.hadoop.conf.Configuration import org.apache.spark.{Logging, SparkEnv, SparkException} import org.apache.spark.storage.StreamBlockId +import org.apache.spark.streaming.Time import org.apache.spark.streaming.scheduler._ import org.apache.spark.util.{AkkaUtils, Utils} @@ -82,6 +83,9 @@ private[streaming] class ReceiverSupervisorImpl( case StopReceiver => logInfo("Received stop signal") stop("Stopped by driver", None) + case CleanupOldBlocks(threshTime) => + logDebug("Received delete old batch signal") + cleanupOldBlocks(threshTime) } def ref = self @@ -193,4 +197,9 @@ private[streaming] class ReceiverSupervisorImpl( /** Generate new block ID */ private def nextBlockId = StreamBlockId(streamId, newBlockId.getAndIncrement) + + private def cleanupOldBlocks(cleanupThreshTime: Time): Unit = { + logDebug(s"Cleaning up blocks older then $cleanupThreshTime") + receivedBlockHandler.cleanupOldBlocks(cleanupThreshTime.milliseconds) + } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala index 39b66e1130768..d86f852aba97e 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala @@ -238,13 +238,17 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { /** Clear DStream metadata for the given `time`. */ private def clearMetadata(time: Time) { ssc.graph.clearMetadata(time) - jobScheduler.receiverTracker.cleanupOldMetadata(time - graph.batchDuration) // If checkpointing is enabled, then checkpoint, // else mark batch to be fully processed if (shouldCheckpoint) { eventActor ! DoCheckpoint(time) } else { + // If checkpointing is not enabled, then delete metadata information about + // received blocks (block data not saved in any case). Otherwise, wait for + // checkpointing of this batch to complete. 
+ val maxRememberDuration = graph.getMaxInputStreamRememberDuration() + jobScheduler.receiverTracker.cleanupOldBlocksAndBatches(time - maxRememberDuration) markBatchFullyProcessed(time) } } @@ -252,6 +256,11 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { /** Clear DStream checkpoint data for the given `time`. */ private def clearCheckpointData(time: Time) { ssc.graph.clearCheckpointData(time) + + // All the checkpoint information about which batches have been processed, etc have + // been saved to checkpoints, so its safe to delete block metadata and data WAL files + val maxRememberDuration = graph.getMaxInputStreamRememberDuration() + jobScheduler.receiverTracker.cleanupOldBlocksAndBatches(time - maxRememberDuration) markBatchFullyProcessed(time) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala index c3d9d7b6813d3..ef23b5c79f2e1 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala @@ -150,7 +150,6 @@ private[streaming] class ReceivedBlockTracker( writeToLog(BatchCleanupEvent(timesToCleanup)) timeToAllocatedBlocks --= timesToCleanup logManagerOption.foreach(_.cleanupOldLogs(cleanupThreshTime.milliseconds, waitForCompletion)) - log } /** Stop the block tracker. */ diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala index 8dbb42a86e3bd..4f998869731ed 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala @@ -24,9 +24,8 @@ import scala.language.existentials import akka.actor._ import org.apache.spark.{Logging, SerializableWritable, SparkEnv, SparkException} -import org.apache.spark.SparkContext._ import org.apache.spark.streaming.{StreamingContext, Time} -import org.apache.spark.streaming.receiver.{Receiver, ReceiverSupervisorImpl, StopReceiver} +import org.apache.spark.streaming.receiver.{CleanupOldBlocks, Receiver, ReceiverSupervisorImpl, StopReceiver} /** * Messages used by the NetworkReceiver and the ReceiverTracker to communicate @@ -119,9 +118,20 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false } } - /** Clean up metadata older than the given threshold time */ - def cleanupOldMetadata(cleanupThreshTime: Time) { + /** + * Clean up the data and metadata of blocks and batches that are strictly + * older than the threshold time. Note that this does not + */ + def cleanupOldBlocksAndBatches(cleanupThreshTime: Time) { + // Clean up old block and batch metadata receivedBlockTracker.cleanupOldBatches(cleanupThreshTime, waitForCompletion = false) + + // Signal the receivers to delete old block data + if (ssc.conf.getBoolean("spark.streaming.receiver.writeAheadLog.enable", false)) { + logInfo(s"Cleanup old received batch data: $cleanupThreshTime") + receiverInfo.values.flatMap { info => Option(info.actor) } + .foreach { _ ! 
CleanupOldBlocks(cleanupThreshTime) } + } } /** Register a receiver */ diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/HdfsUtils.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/HdfsUtils.scala index 27a28bab83ed5..858ba3c9eb4e5 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/HdfsUtils.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/HdfsUtils.scala @@ -63,7 +63,7 @@ private[streaming] object HdfsUtils { } def getFileSystemForPath(path: Path, conf: Configuration): FileSystem = { - // For local file systems, return the raw loca file system, such calls to flush() + // For local file systems, return the raw local file system, such calls to flush() // actually flushes the stream. val fs = path.getFileSystem(conf) fs match { diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala index e26c0c6859e57..e8c34a9ee40b9 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala @@ -17,21 +17,26 @@ package org.apache.spark.streaming +import java.io.File import java.nio.ByteBuffer import java.util.concurrent.Semaphore +import scala.collection.mutable import scala.collection.mutable.ArrayBuffer -import org.apache.spark.SparkConf -import org.apache.spark.storage.{StorageLevel, StreamBlockId} -import org.apache.spark.streaming.receiver.{BlockGenerator, BlockGeneratorListener, Receiver, ReceiverSupervisor} -import org.scalatest.FunSuite +import com.google.common.io.Files import org.scalatest.concurrent.Timeouts import org.scalatest.concurrent.Eventually._ import org.scalatest.time.SpanSugar._ +import org.apache.spark.SparkConf +import org.apache.spark.storage.StorageLevel +import org.apache.spark.storage.StreamBlockId +import org.apache.spark.streaming.receiver._ +import org.apache.spark.streaming.receiver.WriteAheadLogBasedBlockHandler._ + /** Testsuite for testing the network receiver behavior */ -class ReceiverSuite extends FunSuite with Timeouts { +class ReceiverSuite extends TestSuiteBase with Timeouts with Serializable { test("receiver life cycle") { @@ -192,7 +197,6 @@ class ReceiverSuite extends FunSuite with Timeouts { val minExpectedMessagesPerBlock = expectedMessagesPerBlock - 3 val maxExpectedMessagesPerBlock = expectedMessagesPerBlock + 1 val receivedBlockSizes = recordedBlocks.map { _.size }.mkString(",") - println(minExpectedMessagesPerBlock, maxExpectedMessagesPerBlock, ":", receivedBlockSizes) assert( // the first and last block may be incomplete, so we slice them out recordedBlocks.drop(1).dropRight(1).forall { block => @@ -203,39 +207,91 @@ class ReceiverSuite extends FunSuite with Timeouts { ) } - /** - * An implementation of NetworkReceiver that is used for testing a receiver's life cycle. + * Test whether write ahead logs are generated by received, + * and automatically cleaned up. The clean up must be aware of the + * remember duration of the input streams. E.g., input streams on which window() + * has been applied must remember the data for longer, and hence corresponding + * WALs should be cleaned later. 
*/ - class FakeReceiver extends Receiver[Int](StorageLevel.MEMORY_ONLY) { - @volatile var otherThread: Thread = null - @volatile var receiving = false - @volatile var onStartCalled = false - @volatile var onStopCalled = false - - def onStart() { - otherThread = new Thread() { - override def run() { - receiving = true - while(!isStopped()) { - Thread.sleep(10) - } + test("write ahead log - generating and cleaning") { + val sparkConf = new SparkConf() + .setMaster("local[4]") // must be at least 3 as we are going to start 2 receivers + .setAppName(framework) + .set("spark.ui.enabled", "true") + .set("spark.streaming.receiver.writeAheadLog.enable", "true") + .set("spark.streaming.receiver.writeAheadLog.rollingInterval", "1") + val batchDuration = Milliseconds(500) + val tempDirectory = Files.createTempDir() + val logDirectory1 = new File(checkpointDirToLogDir(tempDirectory.getAbsolutePath, 0)) + val logDirectory2 = new File(checkpointDirToLogDir(tempDirectory.getAbsolutePath, 1)) + val allLogFiles1 = new mutable.HashSet[String]() + val allLogFiles2 = new mutable.HashSet[String]() + logInfo("Temp checkpoint directory = " + tempDirectory) + + def getBothCurrentLogFiles(): (Seq[String], Seq[String]) = { + (getCurrentLogFiles(logDirectory1), getCurrentLogFiles(logDirectory2)) + } + + def getCurrentLogFiles(logDirectory: File): Seq[String] = { + try { + if (logDirectory.exists()) { + logDirectory1.listFiles().filter { _.getName.startsWith("log") }.map { _.toString } + } else { + Seq.empty } + } catch { + case e: Exception => + Seq.empty } - onStartCalled = true - otherThread.start() - } - def onStop() { - onStopCalled = true - otherThread.join() + def printLogFiles(message: String, files: Seq[String]) { + logInfo(s"$message (${files.size} files):\n" + files.mkString("\n")) } - def reset() { - receiving = false - onStartCalled = false - onStopCalled = false + withStreamingContext(new StreamingContext(sparkConf, batchDuration)) { ssc => + tempDirectory.deleteOnExit() + val receiver1 = ssc.sparkContext.clean(new FakeReceiver(sendData = true)) + val receiver2 = ssc.sparkContext.clean(new FakeReceiver(sendData = true)) + val receiverStream1 = ssc.receiverStream(receiver1) + val receiverStream2 = ssc.receiverStream(receiver2) + receiverStream1.register() + receiverStream2.window(batchDuration * 6).register() // 3 second window + ssc.checkpoint(tempDirectory.getAbsolutePath()) + ssc.start() + + // Run until sufficient WAL files have been generated and + // the first WAL files has been deleted + eventually(timeout(20 seconds), interval(batchDuration.milliseconds millis)) { + val (logFiles1, logFiles2) = getBothCurrentLogFiles() + allLogFiles1 ++= logFiles1 + allLogFiles2 ++= logFiles2 + if (allLogFiles1.size > 0) { + assert(!logFiles1.contains(allLogFiles1.toSeq.sorted.head)) + } + if (allLogFiles2.size > 0) { + assert(!logFiles2.contains(allLogFiles2.toSeq.sorted.head)) + } + assert(allLogFiles1.size >= 7) + assert(allLogFiles2.size >= 7) + } + ssc.stop(stopSparkContext = true, stopGracefully = true) + + val sortedAllLogFiles1 = allLogFiles1.toSeq.sorted + val sortedAllLogFiles2 = allLogFiles2.toSeq.sorted + val (leftLogFiles1, leftLogFiles2) = getBothCurrentLogFiles() + + printLogFiles("Receiver 0: all", sortedAllLogFiles1) + printLogFiles("Receiver 0: left", leftLogFiles1) + printLogFiles("Receiver 1: all", sortedAllLogFiles2) + printLogFiles("Receiver 1: left", leftLogFiles2) + + // Verify that necessary latest log files are not deleted + // receiverStream1 needs to retain just the last batch = 1 log 
file + // receiverStream2 needs to retain 3 seconds (3-seconds window) = 3 log files + assert(sortedAllLogFiles1.takeRight(1).forall(leftLogFiles1.contains)) + assert(sortedAllLogFiles2.takeRight(3).forall(leftLogFiles2.contains)) } } @@ -315,3 +371,42 @@ class ReceiverSuite extends FunSuite with Timeouts { } } +/** + * An implementation of Receiver that is used for testing a receiver's life cycle. + */ +class FakeReceiver(sendData: Boolean = false) extends Receiver[Int](StorageLevel.MEMORY_ONLY) { + @volatile var otherThread: Thread = null + @volatile var receiving = false + @volatile var onStartCalled = false + @volatile var onStopCalled = false + + def onStart() { + otherThread = new Thread() { + override def run() { + receiving = true + var count = 0 + while(!isStopped()) { + if (sendData) { + store(count) + count += 1 + } + Thread.sleep(10) + } + } + } + onStartCalled = true + otherThread.start() + } + + def onStop() { + onStopCalled = true + otherThread.join() + } + + def reset() { + receiving = false + onStartCalled = false + onStopCalled = false + } +} + From 246111d179a2f3f6b97a5c2b121d8ddbfd1c9aad Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 22 Jan 2015 08:16:35 -0800 Subject: [PATCH 15/22] [SPARK-5365][MLlib] Refactor KMeans to reduce redundant data If a point is selected as new centers for many runs, it would collect many redundant data. This pr refactors it. Author: Liang-Chi Hsieh Closes #4159 from viirya/small_refactor_kmeans and squashes the following commits: 25487e6 [Liang-Chi Hsieh] Refactor codes to reduce redundant data. --- .../scala/org/apache/spark/mllib/clustering/KMeans.scala | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala index fc46da3a93425..11633e8242313 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala @@ -328,14 +328,15 @@ class KMeans private ( val chosen = data.zip(costs).mapPartitionsWithIndex { (index, pointsWithCosts) => val rand = new XORShiftRandom(seed ^ (step << 16) ^ index) pointsWithCosts.flatMap { case (p, c) => - (0 until runs).filter { r => + val rs = (0 until runs).filter { r => rand.nextDouble() < 2.0 * c(r) * k / sumCosts(r) - }.map((_, p)) + } + if (rs.length > 0) Some(p, rs) else None } }.collect() mergeNewCenters() - chosen.foreach { case (r, p) => - newCenters(r) += p.toDense + chosen.foreach { case (p, rs) => + rs.foreach(newCenters(_) += p.toDense) } step += 1 } From 820ce03597350257abe0c5c96435c555038e3e6c Mon Sep 17 00:00:00 2001 From: Sandy Ryza Date: Thu, 22 Jan 2015 13:49:35 -0600 Subject: [PATCH 16/22] SPARK-5370. [YARN] Remove some unnecessary synchronization in YarnAlloca... ...tor Author: Sandy Ryza Closes #4164 from sryza/sandy-spark-5370 and squashes the following commits: 0c8d736 [Sandy Ryza] SPARK-5370. 
[YARN] Remove some unnecessary synchronization in YarnAllocator --- .../spark/deploy/yarn/YarnAllocator.scala | 23 ++++++++----------- 1 file changed, 10 insertions(+), 13 deletions(-) diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala index 4c35b60c57df3..d00f29665a58f 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala @@ -60,7 +60,6 @@ private[yarn] class YarnAllocator( import YarnAllocator._ - // These two complementary data structures are locked on allocatedHostToContainersMap. // Visible for testing. val allocatedHostToContainersMap = new HashMap[String, collection.mutable.Set[ContainerId]] @@ -355,20 +354,18 @@ private[yarn] class YarnAllocator( } } - allocatedHostToContainersMap.synchronized { - if (allocatedContainerToHostMap.containsKey(containerId)) { - val host = allocatedContainerToHostMap.get(containerId).get - val containerSet = allocatedHostToContainersMap.get(host).get + if (allocatedContainerToHostMap.containsKey(containerId)) { + val host = allocatedContainerToHostMap.get(containerId).get + val containerSet = allocatedHostToContainersMap.get(host).get - containerSet.remove(containerId) - if (containerSet.isEmpty) { - allocatedHostToContainersMap.remove(host) - } else { - allocatedHostToContainersMap.update(host, containerSet) - } - - allocatedContainerToHostMap.remove(containerId) + containerSet.remove(containerId) + if (containerSet.isEmpty) { + allocatedHostToContainersMap.remove(host) + } else { + allocatedHostToContainersMap.update(host, containerSet) } + + allocatedContainerToHostMap.remove(containerId) } } } From 3c3fa632e6ba45ce536065aa1145698385301fb2 Mon Sep 17 00:00:00 2001 From: jerryshao Date: Thu, 22 Jan 2015 21:58:53 -0800 Subject: [PATCH 17/22] [SPARK-5233][Streaming] Fix error replaying of WAL introduced bug Because of lacking of `BlockAllocationEvent` in WAL recovery, the dangled event will mix into the new batch, which will lead to the wrong result. Details can be seen in [SPARK-5233](https://issues.apache.org/jira/browse/SPARK-5233). 
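After the fix, the recovery-time contract of the block tracker looks roughly as follows (a sketch of internal behaviour, as exercised by the updated `ReceivedBlockTrackerSuite`; `tracker`, `batchTime` and `streamId` stand in for the values used there):

```scala
// Previously, a second allocation for the same (or an older) batch threw a SparkException,
// which broke replay of dangling events from the WAL. It is now a logged no-op.
tracker.allocateBlocksToBatch(batchTime)
val blocks = tracker.getBlocksOfBatchAndStream(batchTime, streamId)
tracker.allocateBlocksToBatch(batchTime)  // no-op during recovery; `blocks` stays allocated
```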
Author: jerryshao Closes #4032 from jerryshao/SPARK-5233 and squashes the following commits: f0b0c0b [jerryshao] Further address the comments a237c75 [jerryshao] Address the comments e356258 [jerryshao] Fix bug in unit test 558bdc3 [jerryshao] Correctly replay the WAL log when recovering from failure --- .../examples/streaming/KafkaWordCount.scala | 2 +- .../streaming/scheduler/JobGenerator.scala | 18 +++++++++++------ .../scheduler/ReceivedBlockTracker.scala | 12 ++++++++--- .../streaming/ReceivedBlockTrackerSuite.scala | 20 +++++++++---------- 4 files changed, 32 insertions(+), 20 deletions(-) diff --git a/examples/scala-2.10/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala b/examples/scala-2.10/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala index 2adc63f7ff30e..387c0e421334b 100644 --- a/examples/scala-2.10/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala +++ b/examples/scala-2.10/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala @@ -76,7 +76,7 @@ object KafkaWordCountProducer { val Array(brokers, topic, messagesPerSec, wordsPerMessage) = args - // Zookeper connection properties + // Zookeeper connection properties val props = new Properties() props.put("metadata.broker.list", brokers) props.put("serializer.class", "kafka.serializer.StringEncoder") diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala index d86f852aba97e..8632c94349bf9 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala @@ -17,12 +17,14 @@ package org.apache.spark.streaming.scheduler -import akka.actor.{ActorRef, ActorSystem, Props, Actor} -import org.apache.spark.{SparkException, SparkEnv, Logging} -import org.apache.spark.streaming.{Checkpoint, Time, CheckpointWriter} -import org.apache.spark.streaming.util.{ManualClock, RecurringTimer, Clock} import scala.util.{Failure, Success, Try} +import akka.actor.{ActorRef, Props, Actor} + +import org.apache.spark.{SparkEnv, Logging} +import org.apache.spark.streaming.{Checkpoint, CheckpointWriter, Time} +import org.apache.spark.streaming.util.{Clock, ManualClock, RecurringTimer} + /** Event classes for JobGenerator */ private[scheduler] sealed trait JobGeneratorEvent private[scheduler] case class GenerateJobs(time: Time) extends JobGeneratorEvent @@ -206,9 +208,13 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { val timesToReschedule = (pendingTimes ++ downTimes).distinct.sorted(Time.ordering) logInfo("Batches to reschedule (" + timesToReschedule.size + " batches): " + timesToReschedule.mkString(", ")) - timesToReschedule.foreach(time => + timesToReschedule.foreach { time => + // Allocate the related blocks when recovering from failure, because some blocks that were + // added but not allocated, are dangling in the queue after recovering, we have to allocate + // those blocks to the next batch, which is the batch they were supposed to go. 
+ jobScheduler.receiverTracker.allocateBlocksToBatch(time) // allocate received blocks to batch jobScheduler.submitJobSet(JobSet(time, graph.generateJobs(time))) - ) + } // Restart the timer timer.start(restartTime.milliseconds) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala index ef23b5c79f2e1..e19ac939f9ac5 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala @@ -67,7 +67,7 @@ private[streaming] class ReceivedBlockTracker( extends Logging { private type ReceivedBlockQueue = mutable.Queue[ReceivedBlockInfo] - + private val streamIdToUnallocatedBlockQueues = new mutable.HashMap[Int, ReceivedBlockQueue] private val timeToAllocatedBlocks = new mutable.HashMap[Time, AllocatedBlocks] private val logManagerOption = createLogManager() @@ -107,8 +107,14 @@ private[streaming] class ReceivedBlockTracker( lastAllocatedBatchTime = batchTime allocatedBlocks } else { - throw new SparkException(s"Unexpected allocation of blocks, " + - s"last batch = $lastAllocatedBatchTime, batch time to allocate = $batchTime ") + // This situation occurs when: + // 1. WAL is ended with BatchAllocationEvent, but without BatchCleanupEvent, + // possibly processed batch job or half-processed batch job need to be processed again, + // so the batchTime will be equal to lastAllocatedBatchTime. + // 2. Slow checkpointing makes recovered batch time older than WAL recovered + // lastAllocatedBatchTime. + // This situation will only occurs in recovery time. + logInfo(s"Possibly processed batch $batchTime need to be processed again in WAL recovery") } } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala index de7e9d624bf6b..fbb7b0bfebafc 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala @@ -82,15 +82,15 @@ class ReceivedBlockTrackerSuite receivedBlockTracker.allocateBlocksToBatch(2) receivedBlockTracker.getBlocksOfBatchAndStream(2, streamId) shouldBe empty - // Verify that batch 2 cannot be allocated again - intercept[SparkException] { - receivedBlockTracker.allocateBlocksToBatch(2) - } + // Verify that older batches have no operation on batch allocation, + // will return the same blocks as previously allocated. 
+ receivedBlockTracker.allocateBlocksToBatch(1) + receivedBlockTracker.getBlocksOfBatchAndStream(1, streamId) shouldEqual blockInfos - // Verify that older batches cannot be allocated again - intercept[SparkException] { - receivedBlockTracker.allocateBlocksToBatch(1) - } + blockInfos.map(receivedBlockTracker.addBlock) + receivedBlockTracker.allocateBlocksToBatch(2) + receivedBlockTracker.getBlocksOfBatchAndStream(2, streamId) shouldBe empty + receivedBlockTracker.getUnallocatedBlocks(streamId) shouldEqual blockInfos } test("block addition, block to batch allocation and cleanup with write ahead log") { @@ -186,14 +186,14 @@ class ReceivedBlockTrackerSuite tracker4.getBlocksOfBatchAndStream(batchTime1, streamId) shouldBe empty // should be cleaned tracker4.getBlocksOfBatchAndStream(batchTime2, streamId) shouldEqual blockInfos2 } - + test("enabling write ahead log but not setting checkpoint dir") { conf.set("spark.streaming.receiver.writeAheadLog.enable", "true") intercept[SparkException] { createTracker(setCheckpointDir = false) } } - + test("setting checkpoint dir but not enabling write ahead log") { // When WAL config is not set, log manager should not be enabled val tracker1 = createTracker(setCheckpointDir = true) From e0f7fb7f9f497b34d42f9ba147197cf9ffc51607 Mon Sep 17 00:00:00 2001 From: jerryshao Date: Thu, 22 Jan 2015 22:04:21 -0800 Subject: [PATCH 18/22] [SPARK-5315][Streaming] Fix reduceByWindow Java API not work bug `reduceByWindow` for Java API is actually not Java compatible, change to make it Java compatible. Current solution is to deprecate the old one and add a new API, but since old API actually is not correct, so is keeping the old one meaningful? just to keep the binary compatible? Also even adding new API still need to add to Mima exclusion, I'm not sure to change the API, or deprecate the old API and add a new one, which is the best solution? 
Author: jerryshao Closes #4104 from jerryshao/SPARK-5315 and squashes the following commits: 5bc8987 [jerryshao] Address the comment c7aa1b4 [jerryshao] Deprecate the old one to keep binary compatible 8e9dc67 [jerryshao] Fix JavaDStream reduceByWindow signature error --- project/MimaExcludes.scala | 4 ++++ .../streaming/api/java/JavaDStreamLike.scala | 20 +++++++++++++++++++ .../apache/spark/streaming/JavaAPISuite.java | 20 +++++++++++++++++-- 3 files changed, 42 insertions(+), 2 deletions(-) diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 127973b658190..bc5d81f12d746 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -90,6 +90,10 @@ object MimaExcludes { // SPARK-5297 Java FileStream do not work with custom key/values ProblemFilters.exclude[MissingMethodProblem]( "org.apache.spark.streaming.api.java.JavaStreamingContext.fileStream") + ) ++ Seq( + // SPARK-5315 Spark Streaming Java API returns Scala DStream + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.streaming.api.java.JavaDStreamLike.reduceByWindow") ) case v if v.startsWith("1.2") => diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala index e0542eda1383f..c382a12f4d099 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala @@ -211,7 +211,9 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T * @param slideDuration sliding interval of the window (i.e., the interval after which * the new DStream will generate RDDs); must be a multiple of this * DStream's batching interval + * @deprecated As this API is not Java compatible. */ + @deprecated("Use Java-compatible version of reduceByWindow", "1.3.0") def reduceByWindow( reduceFunc: (T, T) => T, windowDuration: Duration, @@ -220,6 +222,24 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T dstream.reduceByWindow(reduceFunc, windowDuration, slideDuration) } + /** + * Return a new DStream in which each RDD has a single element generated by reducing all + * elements in a sliding window over this DStream. + * @param reduceFunc associative reduce function + * @param windowDuration width of the window; must be a multiple of this DStream's + * batching interval + * @param slideDuration sliding interval of the window (i.e., the interval after which + * the new DStream will generate RDDs); must be a multiple of this + * DStream's batching interval + */ + def reduceByWindow( + reduceFunc: JFunction2[T, T, T], + windowDuration: Duration, + slideDuration: Duration + ): JavaDStream[T] = { + dstream.reduceByWindow(reduceFunc, windowDuration, slideDuration) + } + /** * Return a new DStream in which each RDD has a single element generated by reducing all * elements in a sliding window over this DStream. 
However, the reduction is done incrementally diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java index d92e7fe899a09..d4c40745658c2 100644 --- a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java +++ b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java @@ -306,7 +306,17 @@ public void testReduce() { @SuppressWarnings("unchecked") @Test - public void testReduceByWindow() { + public void testReduceByWindowWithInverse() { + testReduceByWindow(true); + } + + @SuppressWarnings("unchecked") + @Test + public void testReduceByWindowWithoutInverse() { + testReduceByWindow(false); + } + + private void testReduceByWindow(boolean withInverse) { List> inputData = Arrays.asList( Arrays.asList(1,2,3), Arrays.asList(4,5,6), @@ -319,8 +329,14 @@ public void testReduceByWindow() { Arrays.asList(24)); JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); - JavaDStream reducedWindowed = stream.reduceByWindow(new IntegerSum(), + JavaDStream reducedWindowed = null; + if (withInverse) { + reducedWindowed = stream.reduceByWindow(new IntegerSum(), new IntegerDifference(), new Duration(2000), new Duration(1000)); + } else { + reducedWindowed = stream.reduceByWindow(new IntegerSum(), + new Duration(2000), new Duration(1000)); + } JavaTestUtils.attachTestOutputStream(reducedWindowed); List> result = JavaTestUtils.runStreams(ssc, 4, 4); From ea74365b7c5a3ac29cae9ba66f140f1fa5e8d312 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Thu, 22 Jan 2015 22:09:13 -0800 Subject: [PATCH 19/22] [SPARK-3541][MLLIB] New ALS implementation with improved storage This PR adds a new ALS implementation to `spark.ml` using the pipeline API, which should be able to scale to billions of ratings. Compared with the ALS under `spark.mllib`, the new implementation 1. uses the same algorithm, 2. uses float type for ratings, 3. uses primitive arrays to avoid GC, 4. sorts and compresses ratings on each block so that we can solve least squares subproblems one by one using only one normal equation instance. The following figure shows performance comparison on copies of the Amazon Reviews dataset using a 16-node (m3.2xlarge) EC2 cluster (the same setup as in http://databricks.com/blog/2014/07/23/scalable-collaborative-filtering-with-spark-mllib.html): ![als-wip](https://cloud.githubusercontent.com/assets/829644/5659447/4c4ff8e0-96c7-11e4-87a9-73c1c63d07f3.png) I keep the `spark.mllib`'s ALS untouched for easy comparison. If the new implementation works well, I'm going to match the features of the ALS under `spark.mllib` and then make it a wrapper of the new implementation, in a separate PR. TODO: - [X] Add unit tests for implicit preferences. 
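A minimal usage sketch of the new estimator, following the `MovieLensALS` example added in this patch (`training` and `test` are assumed to be SchemaRDDs with integer user/item id columns and a numeric rating column):

```scala
import org.apache.spark.ml.recommendation.ALS

val als = new ALS()
  .setUserCol("userId")
  .setItemCol("movieId")
  .setRank(10)
  .setMaxIter(10)
  .setRegParam(0.1)

val model = als.fit(training)            // learns the user and item factor matrices
val predictions = model.transform(test)  // appends a float "prediction" column
```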
Author: Xiangrui Meng Closes #3720 from mengxr/SPARK-3541 and squashes the following commits: 1b9e852 [Xiangrui Meng] fix compile 5129be9 [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into SPARK-3541 dd0d0e8 [Xiangrui Meng] simplify test code c627de3 [Xiangrui Meng] add tests for implicit feedback b84f41c [Xiangrui Meng] address comments a76da7b [Xiangrui Meng] update ALS tests 2a8deb3 [Xiangrui Meng] add some ALS tests 857e876 [Xiangrui Meng] add tests for rating block and encoded block d3c1ac4 [Xiangrui Meng] rename some classes for better code readability add more doc and comments 213d163 [Xiangrui Meng] org imports 771baf3 [Xiangrui Meng] chol doc update ca9ad9d [Xiangrui Meng] add unit tests for chol b4fd17c [Xiangrui Meng] add unit tests for NormalEquation d0f99d3 [Xiangrui Meng] add tests for LocalIndexEncoder 80b8e61 [Xiangrui Meng] fix imports 4937fd4 [Xiangrui Meng] update ALS example 56c253c [Xiangrui Meng] rename product to item bce8692 [Xiangrui Meng] doc for parameters and project the output columns 3f2d81a [Xiangrui Meng] add doc 1efaecf [Xiangrui Meng] add example code 8ae86b5 [Xiangrui Meng] add a working copy of the new ALS implementation --- .../spark/examples/ml/MovieLensALS.scala | 174 ++++ .../apache/spark/ml/recommendation/ALS.scala | 973 ++++++++++++++++++ .../spark/mllib/recommendation/ALS.scala | 2 +- .../spark/ml/recommendation/ALSSuite.scala | 435 ++++++++ .../spark/mllib/recommendation/ALSSuite.scala | 2 +- 5 files changed, 1584 insertions(+), 2 deletions(-) create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/MovieLensALS.scala create mode 100644 mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala create mode 100644 mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/MovieLensALS.scala b/examples/src/main/scala/org/apache/spark/examples/ml/MovieLensALS.scala new file mode 100644 index 0000000000000..cf62772b92651 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/MovieLensALS.scala @@ -0,0 +1,174 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml + +import scopt.OptionParser + +import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.examples.mllib.AbstractParams +import org.apache.spark.ml.recommendation.ALS +import org.apache.spark.sql.{Row, SQLContext} + +/** + * An example app for ALS on MovieLens data (http://grouplens.org/datasets/movielens/). 
+ * Run with + * {{{ + * bin/run-example ml.MovieLensALS + * }}} + */ +object MovieLensALS { + + case class Rating(userId: Int, movieId: Int, rating: Float, timestamp: Long) + + object Rating { + def parseRating(str: String): Rating = { + val fields = str.split("::") + assert(fields.size == 4) + Rating(fields(0).toInt, fields(1).toInt, fields(2).toFloat, fields(3).toLong) + } + } + + case class Movie(movieId: Int, title: String, genres: Seq[String]) + + object Movie { + def parseMovie(str: String): Movie = { + val fields = str.split("::") + assert(fields.size == 3) + Movie(fields(0).toInt, fields(1), fields(2).split("|")) + } + } + + case class Params( + ratings: String = null, + movies: String = null, + maxIter: Int = 10, + regParam: Double = 0.1, + rank: Int = 10, + numBlocks: Int = 10) extends AbstractParams[Params] + + def main(args: Array[String]) { + val defaultParams = Params() + + val parser = new OptionParser[Params]("MovieLensALS") { + head("MovieLensALS: an example app for ALS on MovieLens data.") + opt[String]("ratings") + .required() + .text("path to a MovieLens dataset of ratings") + .action((x, c) => c.copy(ratings = x)) + opt[String]("movies") + .required() + .text("path to a MovieLens dataset of movies") + .action((x, c) => c.copy(movies = x)) + opt[Int]("rank") + .text(s"rank, default: ${defaultParams.rank}}") + .action((x, c) => c.copy(rank = x)) + opt[Int]("maxIter") + .text(s"max number of iterations, default: ${defaultParams.maxIter}") + .action((x, c) => c.copy(maxIter = x)) + opt[Double]("regParam") + .text(s"regularization parameter, default: ${defaultParams.regParam}") + .action((x, c) => c.copy(regParam = x)) + opt[Int]("numBlocks") + .text(s"number of blocks, default: ${defaultParams.numBlocks}") + .action((x, c) => c.copy(numBlocks = x)) + note( + """ + |Example command line to run this app: + | + | bin/spark-submit --class org.apache.spark.examples.ml.MovieLensALS \ + | examples/target/scala-*/spark-examples-*.jar \ + | --rank 10 --maxIter 15 --regParam 0.1 \ + | --movies path/to/movielens/movies.dat \ + | --ratings path/to/movielens/ratings.dat + """.stripMargin) + } + + parser.parse(args, defaultParams).map { params => + run(params) + } getOrElse { + System.exit(1) + } + } + + def run(params: Params) { + val conf = new SparkConf().setAppName(s"MovieLensALS with $params") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + import sqlContext._ + + val ratings = sc.textFile(params.ratings).map(Rating.parseRating).cache() + + val numRatings = ratings.count() + val numUsers = ratings.map(_.userId).distinct().count() + val numMovies = ratings.map(_.movieId).distinct().count() + + println(s"Got $numRatings ratings from $numUsers users on $numMovies movies.") + + val splits = ratings.randomSplit(Array(0.8, 0.2), 0L) + val training = splits(0).cache() + val test = splits(1).cache() + + val numTraining = training.count() + val numTest = test.count() + println(s"Training: $numTraining, test: $numTest.") + + ratings.unpersist(blocking = false) + + val als = new ALS() + .setUserCol("userId") + .setItemCol("movieId") + .setRank(params.rank) + .setMaxIter(params.maxIter) + .setRegParam(params.regParam) + .setNumBlocks(params.numBlocks) + + val model = als.fit(training) + + val predictions = model.transform(test).cache() + + // Evaluate the model. + // TODO: Create an evaluator to compute RMSE. 
+ val mse = predictions.select('rating, 'prediction) + .flatMap { case Row(rating: Float, prediction: Float) => + val err = rating.toDouble - prediction + val err2 = err * err + if (err2.isNaN) { + None + } else { + Some(err2) + } + }.mean() + val rmse = math.sqrt(mse) + println(s"Test RMSE = $rmse.") + + // Inspect false positives. + predictions.registerTempTable("prediction") + sc.textFile(params.movies).map(Movie.parseMovie).registerTempTable("movie") + sqlContext.sql( + """ + |SELECT userId, prediction.movieId, title, rating, prediction + | FROM prediction JOIN movie ON prediction.movieId = movie.movieId + | WHERE rating <= 1 AND prediction >= 4 + | LIMIT 100 + """.stripMargin) + .collect() + .foreach(println) + + sc.stop() + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala new file mode 100644 index 0000000000000..2d89e76a4c8b2 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala @@ -0,0 +1,973 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.recommendation + +import java.{util => ju} + +import scala.collection.mutable + +import com.github.fommil.netlib.BLAS.{getInstance => blas} +import com.github.fommil.netlib.LAPACK.{getInstance => lapack} +import org.netlib.util.intW + +import org.apache.spark.{HashPartitioner, Logging, Partitioner} +import org.apache.spark.ml.{Estimator, Model} +import org.apache.spark.ml.param._ +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.SchemaRDD +import org.apache.spark.sql.catalyst.dsl._ +import org.apache.spark.sql.catalyst.expressions.Cast +import org.apache.spark.sql.catalyst.plans.LeftOuter +import org.apache.spark.sql.types.{DoubleType, FloatType, IntegerType, StructField, StructType} +import org.apache.spark.util.Utils +import org.apache.spark.util.collection.{OpenHashMap, OpenHashSet, SortDataFormat, Sorter} +import org.apache.spark.util.random.XORShiftRandom + +/** + * Common params for ALS. + */ +private[recommendation] trait ALSParams extends Params with HasMaxIter with HasRegParam + with HasPredictionCol { + + /** Param for rank of the matrix factorization. */ + val rank = new IntParam(this, "rank", "rank of the factorization", Some(10)) + def getRank: Int = get(rank) + + /** Param for number of user blocks. */ + val numUserBlocks = new IntParam(this, "numUserBlocks", "number of user blocks", Some(10)) + def getNumUserBlocks: Int = get(numUserBlocks) + + /** Param for number of item blocks. */ + val numItemBlocks = + new IntParam(this, "numItemBlocks", "number of item blocks", Some(10)) + def getNumItemBlocks: Int = get(numItemBlocks) + + /** Param to decide whether to use implicit preference. 
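+   * When true, [[fit]] optimizes the confidence-weighted preference matrix described in the
+   * [[ALS]] class doc, with confidence derived from the ratings via [[alpha]].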
*/ + val implicitPrefs = + new BooleanParam(this, "implicitPrefs", "whether to use implicit preference", Some(false)) + def getImplicitPrefs: Boolean = get(implicitPrefs) + + /** Param for the alpha parameter in the implicit preference formulation. */ + val alpha = new DoubleParam(this, "alpha", "alpha for implicit preference", Some(1.0)) + def getAlpha: Double = get(alpha) + + /** Param for the column name for user ids. */ + val userCol = new Param[String](this, "userCol", "column name for user ids", Some("user")) + def getUserCol: String = get(userCol) + + /** Param for the column name for item ids. */ + val itemCol = + new Param[String](this, "itemCol", "column name for item ids", Some("item")) + def getItemCol: String = get(itemCol) + + /** Param for the column name for ratings. */ + val ratingCol = new Param[String](this, "ratingCol", "column name for ratings", Some("rating")) + def getRatingCol: String = get(ratingCol) + + /** + * Validates and transforms the input schema. + * @param schema input schema + * @param paramMap extra params + * @return output schema + */ + protected def validateAndTransformSchema(schema: StructType, paramMap: ParamMap): StructType = { + val map = this.paramMap ++ paramMap + assert(schema(map(userCol)).dataType == IntegerType) + assert(schema(map(itemCol)).dataType== IntegerType) + val ratingType = schema(map(ratingCol)).dataType + assert(ratingType == FloatType || ratingType == DoubleType) + val predictionColName = map(predictionCol) + assert(!schema.fieldNames.contains(predictionColName), + s"Prediction column $predictionColName already exists.") + val newFields = schema.fields :+ StructField(map(predictionCol), FloatType, nullable = false) + StructType(newFields) + } +} + +/** + * Model fitted by ALS. + */ +class ALSModel private[ml] ( + override val parent: ALS, + override val fittingParamMap: ParamMap, + k: Int, + userFactors: RDD[(Int, Array[Float])], + itemFactors: RDD[(Int, Array[Float])]) + extends Model[ALSModel] with ALSParams { + + def setPredictionCol(value: String): this.type = set(predictionCol, value) + + override def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = { + import dataset.sqlContext._ + import org.apache.spark.ml.recommendation.ALSModel.Factor + val map = this.paramMap ++ paramMap + // TODO: Add DSL to simplify the code here. 
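+    // Register the input and the user/item factors under unique aliases, left-outer-join the
+    // factors onto the input rows by the user and item id columns, and append a prediction
+    // column computed as the dot product of the two factor vectors (NaN if either side is
+    // missing).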
+ val instanceTable = s"instance_$uid" + val userTable = s"user_$uid" + val itemTable = s"item_$uid" + val instances = dataset.as(Symbol(instanceTable)) + val users = userFactors.map { case (id, features) => + Factor(id, features) + }.as(Symbol(userTable)) + val items = itemFactors.map { case (id, features) => + Factor(id, features) + }.as(Symbol(itemTable)) + val predict: (Seq[Float], Seq[Float]) => Float = (userFeatures, itemFeatures) => { + if (userFeatures != null && itemFeatures != null) { + blas.sdot(k, userFeatures.toArray, 1, itemFeatures.toArray, 1) + } else { + Float.NaN + } + } + val inputColumns = dataset.schema.fieldNames + val prediction = + predict.call(s"$userTable.features".attr, s"$itemTable.features".attr) as map(predictionCol) + val outputColumns = inputColumns.map(f => s"$instanceTable.$f".attr as f) :+ prediction + instances + .join(users, LeftOuter, Some(map(userCol).attr === s"$userTable.id".attr)) + .join(items, LeftOuter, Some(map(itemCol).attr === s"$itemTable.id".attr)) + .select(outputColumns: _*) + } + + override private[ml] def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { + validateAndTransformSchema(schema, paramMap) + } +} + +private object ALSModel { + /** Case class to convert factors to SchemaRDDs */ + private case class Factor(id: Int, features: Seq[Float]) +} + +/** + * Alternating Least Squares (ALS) matrix factorization. + * + * ALS attempts to estimate the ratings matrix `R` as the product of two lower-rank matrices, + * `X` and `Y`, i.e. `X * Yt = R`. Typically these approximations are called 'factor' matrices. + * The general approach is iterative. During each iteration, one of the factor matrices is held + * constant, while the other is solved for using least squares. The newly-solved factor matrix is + * then held constant while solving for the other factor matrix. + * + * This is a blocked implementation of the ALS factorization algorithm that groups the two sets + * of factors (referred to as "users" and "products") into blocks and reduces communication by only + * sending one copy of each user vector to each product block on each iteration, and only for the + * product blocks that need that user's feature vector. This is achieved by pre-computing some + * information about the ratings matrix to determine the "out-links" of each user (which blocks of + * products it will contribute to) and "in-link" information for each product (which of the feature + * vectors it receives from each user block it will depend on). This allows us to send only an + * array of feature vectors between each user block and product block, and have the product block + * find the users' ratings and update the products based on these messages. + * + * For implicit preference data, the algorithm used is based on + * "Collaborative Filtering for Implicit Feedback Datasets", available at + * [[http://dx.doi.org/10.1109/ICDM.2008.22]], adapted for the blocked approach used here. + * + * Essentially instead of finding the low-rank approximations to the rating matrix `R`, + * this finds the approximations for a preference matrix `P` where the elements of `P` are 1 if + * r > 0 and 0 if r <= 0. The ratings then act as 'confidence' values related to strength of + * indicated user + * preferences rather than explicit ratings given to items. 
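+ *
+ * A minimal usage sketch (for illustration only; `df` and `test` stand for SchemaRDDs whose
+ * user/item/rating columns match the configured column names):
+ * {{{
+ *   val als = new ALS()
+ *     .setUserCol("userId")
+ *     .setItemCol("movieId")
+ *     .setRank(10)
+ *     .setMaxIter(10)
+ *     .setRegParam(0.1)
+ *   val model = als.fit(df)
+ *   val predictions = model.transform(test)
+ * }}}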
+ */ +class ALS extends Estimator[ALSModel] with ALSParams { + + import org.apache.spark.ml.recommendation.ALS.Rating + + def setRank(value: Int): this.type = set(rank, value) + def setNumUserBlocks(value: Int): this.type = set(numUserBlocks, value) + def setNumItemBlocks(value: Int): this.type = set(numItemBlocks, value) + def setImplicitPrefs(value: Boolean): this.type = set(implicitPrefs, value) + def setAlpha(value: Double): this.type = set(alpha, value) + def setUserCol(value: String): this.type = set(userCol, value) + def setItemCol(value: String): this.type = set(itemCol, value) + def setRatingCol(value: String): this.type = set(ratingCol, value) + def setPredictionCol(value: String): this.type = set(predictionCol, value) + def setMaxIter(value: Int): this.type = set(maxIter, value) + def setRegParam(value: Double): this.type = set(regParam, value) + + /** Sets both numUserBlocks and numItemBlocks to the specific value. */ + def setNumBlocks(value: Int): this.type = { + setNumUserBlocks(value) + setNumItemBlocks(value) + this + } + + setMaxIter(20) + setRegParam(1.0) + + override def fit(dataset: SchemaRDD, paramMap: ParamMap): ALSModel = { + import dataset.sqlContext._ + val map = this.paramMap ++ paramMap + val ratings = + dataset.select(map(userCol).attr, map(itemCol).attr, Cast(map(ratingCol).attr, FloatType)) + .map { row => + new Rating(row.getInt(0), row.getInt(1), row.getFloat(2)) + } + val (userFactors, itemFactors) = ALS.train(ratings, rank = map(rank), + numUserBlocks = map(numUserBlocks), numItemBlocks = map(numItemBlocks), + maxIter = map(maxIter), regParam = map(regParam), implicitPrefs = map(implicitPrefs), + alpha = map(alpha)) + val model = new ALSModel(this, map, map(rank), userFactors, itemFactors) + Params.inheritValues(map, this, model) + model + } + + override private[ml] def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { + validateAndTransformSchema(schema, paramMap) + } +} + +private[recommendation] object ALS extends Logging { + + /** Rating class for better code readability. */ + private[recommendation] case class Rating(user: Int, item: Int, rating: Float) + + /** Cholesky solver for least square problems. */ + private[recommendation] class CholeskySolver { + + private val upper = "U" + private val info = new intW(0) + + /** + * Solves a least squares problem with L2 regularization: + * + * min norm(A x - b)^2^ + lambda * n * norm(x)^2^ + * + * @param ne a [[NormalEquation]] instance that contains AtA, Atb, and n (number of instances) + * @param lambda regularization constant, which will be scaled by n + * @return the solution x + */ + def solve(ne: NormalEquation, lambda: Double): Array[Float] = { + val k = ne.k + // Add scaled lambda to the diagonals of AtA. + val scaledlambda = lambda * ne.n + var i = 0 + var j = 2 + while (i < ne.triK) { + ne.ata(i) += scaledlambda + i += j + j += 1 + } + lapack.dppsv(upper, k, 1, ne.ata, ne.atb, k, info) + val code = info.`val` + assert(code == 0, s"lapack.dppsv returned $code.") + val x = new Array[Float](k) + i = 0 + while (i < k) { + x(i) = ne.atb(i).toFloat + i += 1 + } + ne.reset() + x + } + } + + /** Representing a normal equation (ALS' subproblem). */ + private[recommendation] class NormalEquation(val k: Int) extends Serializable { + + /** Number of entries in the upper triangular part of a k-by-k matrix. */ + val triK = k * (k + 1) / 2 + /** A^T^ * A */ + val ata = new Array[Double](triK) + /** A^T^ * b */ + val atb = new Array[Double](k) + /** Number of observations. 
*/ + var n = 0 + + private val da = new Array[Double](k) + private val upper = "U" + + private def copyToDouble(a: Array[Float]): Unit = { + var i = 0 + while (i < k) { + da(i) = a(i) + i += 1 + } + } + + /** Adds an observation. */ + def add(a: Array[Float], b: Float): this.type = { + require(a.size == k) + copyToDouble(a) + blas.dspr(upper, k, 1.0, da, 1, ata) + blas.daxpy(k, b.toDouble, da, 1, atb, 1) + n += 1 + this + } + + /** + * Adds an observation with implicit feedback. Note that this does not increment the counter. + */ + def addImplicit(a: Array[Float], b: Float, alpha: Double): this.type = { + require(a.size == k) + // Extension to the original paper to handle b < 0. confidence is a function of |b| instead + // so that it is never negative. + val confidence = 1.0 + alpha * math.abs(b) + copyToDouble(a) + blas.dspr(upper, k, confidence - 1.0, da, 1, ata) + // For b <= 0, the corresponding preference is 0. So the term below is only added for b > 0. + if (b > 0) { + blas.daxpy(k, confidence, da, 1, atb, 1) + } + this + } + + /** Merges another normal equation object. */ + def merge(other: NormalEquation): this.type = { + require(other.k == k) + blas.daxpy(ata.size, 1.0, other.ata, 1, ata, 1) + blas.daxpy(atb.size, 1.0, other.atb, 1, atb, 1) + n += other.n + this + } + + /** Resets everything to zero, which should be called after each solve. */ + def reset(): Unit = { + ju.Arrays.fill(ata, 0.0) + ju.Arrays.fill(atb, 0.0) + n = 0 + } + } + + /** + * Implementation of the ALS algorithm. + */ + private def train( + ratings: RDD[Rating], + rank: Int = 10, + numUserBlocks: Int = 10, + numItemBlocks: Int = 10, + maxIter: Int = 10, + regParam: Double = 1.0, + implicitPrefs: Boolean = false, + alpha: Double = 1.0): (RDD[(Int, Array[Float])], RDD[(Int, Array[Float])]) = { + val userPart = new HashPartitioner(numUserBlocks) + val itemPart = new HashPartitioner(numItemBlocks) + val userLocalIndexEncoder = new LocalIndexEncoder(userPart.numPartitions) + val itemLocalIndexEncoder = new LocalIndexEncoder(itemPart.numPartitions) + val blockRatings = partitionRatings(ratings, userPart, itemPart).cache() + val (userInBlocks, userOutBlocks) = makeBlocks("user", blockRatings, userPart, itemPart) + // materialize blockRatings and user blocks + userOutBlocks.count() + val swappedBlockRatings = blockRatings.map { + case ((userBlockId, itemBlockId), RatingBlock(userIds, itemIds, localRatings)) => + ((itemBlockId, userBlockId), RatingBlock(itemIds, userIds, localRatings)) + } + val (itemInBlocks, itemOutBlocks) = makeBlocks("item", swappedBlockRatings, itemPart, userPart) + // materialize item blocks + itemOutBlocks.count() + var userFactors = initialize(userInBlocks, rank) + var itemFactors = initialize(itemInBlocks, rank) + if (implicitPrefs) { + for (iter <- 1 to maxIter) { + userFactors.setName(s"userFactors-$iter").persist() + val previousItemFactors = itemFactors + itemFactors = computeFactors(userFactors, userOutBlocks, itemInBlocks, rank, regParam, + userLocalIndexEncoder, implicitPrefs, alpha) + previousItemFactors.unpersist() + itemFactors.setName(s"itemFactors-$iter").persist() + val previousUserFactors = userFactors + userFactors = computeFactors(itemFactors, itemOutBlocks, userInBlocks, rank, regParam, + itemLocalIndexEncoder, implicitPrefs, alpha) + previousUserFactors.unpersist() + } + } else { + for (iter <- 0 until maxIter) { + itemFactors = computeFactors(userFactors, userOutBlocks, itemInBlocks, rank, regParam, + userLocalIndexEncoder) + userFactors = computeFactors(itemFactors, 
itemOutBlocks, userInBlocks, rank, regParam, + itemLocalIndexEncoder) + } + } + val userIdAndFactors = userInBlocks + .mapValues(_.srcIds) + .join(userFactors) + .values + .setName("userFactors") + .cache() + userIdAndFactors.count() + itemFactors.unpersist() + val itemIdAndFactors = itemInBlocks + .mapValues(_.srcIds) + .join(itemFactors) + .values + .setName("itemFactors") + .cache() + itemIdAndFactors.count() + userInBlocks.unpersist() + userOutBlocks.unpersist() + itemInBlocks.unpersist() + itemOutBlocks.unpersist() + blockRatings.unpersist() + val userOutput = userIdAndFactors.flatMap { case (ids, factors) => + ids.view.zip(factors) + } + val itemOutput = itemIdAndFactors.flatMap { case (ids, factors) => + ids.view.zip(factors) + } + (userOutput, itemOutput) + } + + /** + * Factor block that stores factors (Array[Float]) in an Array. + */ + private type FactorBlock = Array[Array[Float]] + + /** + * Out-link block that stores, for each dst (item/user) block, which src (user/item) factors to + * send. For example, outLinkBlock(0) contains the local indices (not the original src IDs) of the + * src factors in this block to send to dst block 0. + */ + private type OutBlock = Array[Array[Int]] + + /** + * In-link block for computing src (user/item) factors. This includes the original src IDs + * of the elements within this block as well as encoded dst (item/user) indices and corresponding + * ratings. The dst indices are in the form of (blockId, localIndex), which are not the original + * dst IDs. To compute src factors, we expect receiving dst factors that match the dst indices. + * For example, if we have an in-link record + * + * {srcId: 0, dstBlockId: 2, dstLocalIndex: 3, rating: 5.0}, + * + * and assume that the dst factors are stored as dstFactors: Map[Int, Array[Array[Float]]], which + * is a blockId to dst factors map, the corresponding dst factor of the record is dstFactor(2)(3). + * + * We use a CSC-like (compressed sparse column) format to store the in-link information. So we can + * compute src factors one after another using only one normal equation instance. + * + * @param srcIds src ids (ordered) + * @param dstPtrs dst pointers. Elements in range [dstPtrs(i), dstPtrs(i+1)) of dst indices and + * ratings are associated with srcIds(i). + * @param dstEncodedIndices encoded dst indices + * @param ratings ratings + * + * @see [[LocalIndexEncoder]] + */ + private[recommendation] case class InBlock( + srcIds: Array[Int], + dstPtrs: Array[Int], + dstEncodedIndices: Array[Int], + ratings: Array[Float]) { + /** Size of the block. */ + val size: Int = ratings.size + + require(dstEncodedIndices.size == size) + require(dstPtrs.size == srcIds.size + 1) + } + + /** + * Initializes factors randomly given the in-link blocks. + * + * @param inBlocks in-link blocks + * @param rank rank + * @return initialized factor blocks + */ + private def initialize(inBlocks: RDD[(Int, InBlock)], rank: Int): RDD[(Int, FactorBlock)] = { + // Choose a unit vector uniformly at random from the unit sphere, but from the + // "first quadrant" where all elements are nonnegative. This can be done by choosing + // elements distributed as Normal(0,1) and taking the absolute value, and then normalizing. + // This appears to create factorizations that have a slightly better reconstruction + // (<1%) compared picking elements uniformly at random in [0,1]. 
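+    // Seed each block's generator with its block ID so initialization is deterministic for a
+    // given partitioning.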
+ inBlocks.map { case (srcBlockId, inBlock) => + val random = new XORShiftRandom(srcBlockId) + val factors = Array.fill(inBlock.srcIds.size) { + val factor = Array.fill(rank)(random.nextGaussian().toFloat) + val nrm = blas.snrm2(rank, factor, 1) + blas.sscal(rank, 1.0f / nrm, factor, 1) + factor + } + (srcBlockId, factors) + } + } + + /** + * A rating block that contains src IDs, dst IDs, and ratings, stored in primitive arrays. + */ + private[recommendation] + case class RatingBlock(srcIds: Array[Int], dstIds: Array[Int], ratings: Array[Float]) { + /** Size of the block. */ + val size: Int = srcIds.size + require(dstIds.size == size) + require(ratings.size == size) + } + + /** + * Builder for [[RatingBlock]]. [[mutable.ArrayBuilder]] is used to avoid boxing/unboxing. + */ + private[recommendation] class RatingBlockBuilder extends Serializable { + + private val srcIds = mutable.ArrayBuilder.make[Int] + private val dstIds = mutable.ArrayBuilder.make[Int] + private val ratings = mutable.ArrayBuilder.make[Float] + var size = 0 + + /** Adds a rating. */ + def add(r: Rating): this.type = { + size += 1 + srcIds += r.user + dstIds += r.item + ratings += r.rating + this + } + + /** Merges another [[RatingBlockBuilder]]. */ + def merge(other: RatingBlock): this.type = { + size += other.srcIds.size + srcIds ++= other.srcIds + dstIds ++= other.dstIds + ratings ++= other.ratings + this + } + + /** Builds a [[RatingBlock]]. */ + def build(): RatingBlock = { + RatingBlock(srcIds.result(), dstIds.result(), ratings.result()) + } + } + + /** + * Partitions raw ratings into blocks. + * + * @param ratings raw ratings + * @param srcPart partitioner for src IDs + * @param dstPart partitioner for dst IDs + * + * @return an RDD of rating blocks in the form of ((srcBlockId, dstBlockId), ratingBlock) + */ + private def partitionRatings( + ratings: RDD[Rating], + srcPart: Partitioner, + dstPart: Partitioner): RDD[((Int, Int), RatingBlock)] = { + + /* The implementation produces the same result as the following but generates less objects. + + ratings.map { r => + ((srcPart.getPartition(r.user), dstPart.getPartition(r.item)), r) + }.aggregateByKey(new RatingBlockBuilder)( + seqOp = (b, r) => b.add(r), + combOp = (b0, b1) => b0.merge(b1.build())) + .mapValues(_.build()) + */ + + val numPartitions = srcPart.numPartitions * dstPart.numPartitions + ratings.mapPartitions { iter => + val builders = Array.fill(numPartitions)(new RatingBlockBuilder) + iter.flatMap { r => + val srcBlockId = srcPart.getPartition(r.user) + val dstBlockId = dstPart.getPartition(r.item) + val idx = srcBlockId + srcPart.numPartitions * dstBlockId + val builder = builders(idx) + builder.add(r) + if (builder.size >= 2048) { // 2048 * (3 * 4) = 24k + builders(idx) = new RatingBlockBuilder + Iterator.single(((srcBlockId, dstBlockId), builder.build())) + } else { + Iterator.empty + } + } ++ { + builders.view.zipWithIndex.filter(_._1.size > 0).map { case (block, idx) => + val srcBlockId = idx % srcPart.numPartitions + val dstBlockId = idx / srcPart.numPartitions + ((srcBlockId, dstBlockId), block.build()) + } + } + }.groupByKey().mapValues { blocks => + val builder = new RatingBlockBuilder + blocks.foreach(builder.merge) + builder.build() + }.setName("ratingBlocks") + } + + /** + * Builder for uncompressed in-blocks of (srcId, dstEncodedIndex, rating) tuples. 
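+   * Dst indices are packed with the given [[LocalIndexEncoder]]; the built block is later
+   * compressed into an [[InBlock]] via [[UncompressedInBlock.compress]].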
+ * @param encoder encoder for dst indices + */ + private[recommendation] class UncompressedInBlockBuilder(encoder: LocalIndexEncoder) { + + private val srcIds = mutable.ArrayBuilder.make[Int] + private val dstEncodedIndices = mutable.ArrayBuilder.make[Int] + private val ratings = mutable.ArrayBuilder.make[Float] + + /** + * Adds a dst block of (srcId, dstLocalIndex, rating) tuples. + * + * @param dstBlockId dst block ID + * @param srcIds original src IDs + * @param dstLocalIndices dst local indices + * @param ratings ratings + */ + def add( + dstBlockId: Int, + srcIds: Array[Int], + dstLocalIndices: Array[Int], + ratings: Array[Float]): this.type = { + val sz = srcIds.size + require(dstLocalIndices.size == sz) + require(ratings.size == sz) + this.srcIds ++= srcIds + this.ratings ++= ratings + var j = 0 + while (j < sz) { + this.dstEncodedIndices += encoder.encode(dstBlockId, dstLocalIndices(j)) + j += 1 + } + this + } + + /** Builds a [[UncompressedInBlock]]. */ + def build(): UncompressedInBlock = { + new UncompressedInBlock(srcIds.result(), dstEncodedIndices.result(), ratings.result()) + } + } + + /** + * A block of (srcId, dstEncodedIndex, rating) tuples stored in primitive arrays. + */ + private[recommendation] class UncompressedInBlock( + val srcIds: Array[Int], + val dstEncodedIndices: Array[Int], + val ratings: Array[Float]) { + + /** Size the of block. */ + def size: Int = srcIds.size + + /** + * Compresses the block into an [[InBlock]]. The algorithm is the same as converting a + * sparse matrix from coordinate list (COO) format into compressed sparse column (CSC) format. + * Sorting is done using Spark's built-in Timsort to avoid generating too many objects. + */ + def compress(): InBlock = { + val sz = size + assert(sz > 0, "Empty in-link block should not exist.") + sort() + val uniqueSrcIdsBuilder = mutable.ArrayBuilder.make[Int] + val dstCountsBuilder = mutable.ArrayBuilder.make[Int] + var preSrcId = srcIds(0) + uniqueSrcIdsBuilder += preSrcId + var curCount = 1 + var i = 1 + var j = 0 + while (i < sz) { + val srcId = srcIds(i) + if (srcId != preSrcId) { + uniqueSrcIdsBuilder += srcId + dstCountsBuilder += curCount + preSrcId = srcId + j += 1 + curCount = 0 + } + curCount += 1 + i += 1 + } + dstCountsBuilder += curCount + val uniqueSrcIds = uniqueSrcIdsBuilder.result() + val numUniqueSrdIds = uniqueSrcIds.size + val dstCounts = dstCountsBuilder.result() + val dstPtrs = new Array[Int](numUniqueSrdIds + 1) + var sum = 0 + i = 0 + while (i < numUniqueSrdIds) { + sum += dstCounts(i) + i += 1 + dstPtrs(i) = sum + } + InBlock(uniqueSrcIds, dstPtrs, dstEncodedIndices, ratings) + } + + private def sort(): Unit = { + val sz = size + // Since there might be interleaved log messages, we insert a unique id for easy pairing. + val sortId = Utils.random.nextInt() + logDebug(s"Start sorting an uncompressed in-block of size $sz. (sortId = $sortId)") + val start = System.nanoTime() + val sorter = new Sorter(new UncompressedInBlockSort) + sorter.sort(this, 0, size, Ordering[IntWrapper]) + val duration = (System.nanoTime() - start) / 1e9 + logDebug(s"Sorting took $duration seconds. (sortId = $sortId)") + } + } + + /** + * A wrapper that holds a primitive integer key. + * + * @see [[UncompressedInBlockSort]] + */ + private class IntWrapper(var key: Int = 0) extends Ordered[IntWrapper] { + override def compare(that: IntWrapper): Int = { + key.compare(that.key) + } + } + + /** + * [[SortDataFormat]] of [[UncompressedInBlock]] used by [[Sorter]]. 
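+   * Sorts the three parallel arrays (srcIds, dstEncodedIndices, ratings) in lockstep, keyed
+   * by srcId.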
+ */ + private class UncompressedInBlockSort extends SortDataFormat[IntWrapper, UncompressedInBlock] { + + override def newKey(): IntWrapper = new IntWrapper() + + override def getKey( + data: UncompressedInBlock, + pos: Int, + reuse: IntWrapper): IntWrapper = { + if (reuse == null) { + new IntWrapper(data.srcIds(pos)) + } else { + reuse.key = data.srcIds(pos) + reuse + } + } + + override def getKey( + data: UncompressedInBlock, + pos: Int): IntWrapper = { + getKey(data, pos, null) + } + + private def swapElements[@specialized(Int, Float) T]( + data: Array[T], + pos0: Int, + pos1: Int): Unit = { + val tmp = data(pos0) + data(pos0) = data(pos1) + data(pos1) = tmp + } + + override def swap(data: UncompressedInBlock, pos0: Int, pos1: Int): Unit = { + swapElements(data.srcIds, pos0, pos1) + swapElements(data.dstEncodedIndices, pos0, pos1) + swapElements(data.ratings, pos0, pos1) + } + + override def copyRange( + src: UncompressedInBlock, + srcPos: Int, + dst: UncompressedInBlock, + dstPos: Int, + length: Int): Unit = { + System.arraycopy(src.srcIds, srcPos, dst.srcIds, dstPos, length) + System.arraycopy(src.dstEncodedIndices, srcPos, dst.dstEncodedIndices, dstPos, length) + System.arraycopy(src.ratings, srcPos, dst.ratings, dstPos, length) + } + + override def allocate(length: Int): UncompressedInBlock = { + new UncompressedInBlock( + new Array[Int](length), new Array[Int](length), new Array[Float](length)) + } + + override def copyElement( + src: UncompressedInBlock, + srcPos: Int, + dst: UncompressedInBlock, + dstPos: Int): Unit = { + dst.srcIds(dstPos) = src.srcIds(srcPos) + dst.dstEncodedIndices(dstPos) = src.dstEncodedIndices(srcPos) + dst.ratings(dstPos) = src.ratings(srcPos) + } + } + + /** + * Creates in-blocks and out-blocks from rating blocks. 
+ * @param prefix prefix for in/out-block names + * @param ratingBlocks rating blocks + * @param srcPart partitioner for src IDs + * @param dstPart partitioner for dst IDs + * @return (in-blocks, out-blocks) + */ + private def makeBlocks( + prefix: String, + ratingBlocks: RDD[((Int, Int), RatingBlock)], + srcPart: Partitioner, + dstPart: Partitioner): (RDD[(Int, InBlock)], RDD[(Int, OutBlock)]) = { + val inBlocks = ratingBlocks.map { + case ((srcBlockId, dstBlockId), RatingBlock(srcIds, dstIds, ratings)) => + // The implementation is a faster version of + // val dstIdToLocalIndex = dstIds.toSet.toSeq.sorted.zipWithIndex.toMap + val start = System.nanoTime() + val dstIdSet = new OpenHashSet[Int](1 << 20) + dstIds.foreach(dstIdSet.add) + val sortedDstIds = new Array[Int](dstIdSet.size) + var i = 0 + var pos = dstIdSet.nextPos(0) + while (pos != -1) { + sortedDstIds(i) = dstIdSet.getValue(pos) + pos = dstIdSet.nextPos(pos + 1) + i += 1 + } + assert(i == dstIdSet.size) + ju.Arrays.sort(sortedDstIds) + val dstIdToLocalIndex = new OpenHashMap[Int, Int](sortedDstIds.size) + i = 0 + while (i < sortedDstIds.size) { + dstIdToLocalIndex.update(sortedDstIds(i), i) + i += 1 + } + logDebug( + "Converting to local indices took " + (System.nanoTime() - start) / 1e9 + " seconds.") + val dstLocalIndices = dstIds.map(dstIdToLocalIndex.apply) + (srcBlockId, (dstBlockId, srcIds, dstLocalIndices, ratings)) + }.groupByKey(new HashPartitioner(srcPart.numPartitions)) + .mapValues { iter => + val builder = + new UncompressedInBlockBuilder(new LocalIndexEncoder(dstPart.numPartitions)) + iter.foreach { case (dstBlockId, srcIds, dstLocalIndices, ratings) => + builder.add(dstBlockId, srcIds, dstLocalIndices, ratings) + } + builder.build().compress() + }.setName(prefix + "InBlocks").cache() + val outBlocks = inBlocks.mapValues { case InBlock(srcIds, dstPtrs, dstEncodedIndices, _) => + val encoder = new LocalIndexEncoder(dstPart.numPartitions) + val activeIds = Array.fill(dstPart.numPartitions)(mutable.ArrayBuilder.make[Int]) + var i = 0 + val seen = new Array[Boolean](dstPart.numPartitions) + while (i < srcIds.size) { + var j = dstPtrs(i) + ju.Arrays.fill(seen, false) + while (j < dstPtrs(i + 1)) { + val dstBlockId = encoder.blockId(dstEncodedIndices(j)) + if (!seen(dstBlockId)) { + activeIds(dstBlockId) += i // add the local index in this out-block + seen(dstBlockId) = true + } + j += 1 + } + i += 1 + } + activeIds.map { x => + x.result() + } + }.setName(prefix + "OutBlocks").cache() + (inBlocks, outBlocks) + } + + /** + * Compute dst factors by constructing and solving least square problems. 
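+   * For each dst id, the src factors referenced by its ratings are accumulated into a
+   * [[NormalEquation]] and solved with [[CholeskySolver]]; with implicit preference the
+   * precomputed Gramian (YtY) is merged in first.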
+ * + * @param srcFactorBlocks src factors + * @param srcOutBlocks src out-blocks + * @param dstInBlocks dst in-blocks + * @param rank rank + * @param regParam regularization constant + * @param srcEncoder encoder for src local indices + * @param implicitPrefs whether to use implicit preference + * @param alpha the alpha constant in the implicit preference formulation + * + * @return dst factors + */ + private def computeFactors( + srcFactorBlocks: RDD[(Int, FactorBlock)], + srcOutBlocks: RDD[(Int, OutBlock)], + dstInBlocks: RDD[(Int, InBlock)], + rank: Int, + regParam: Double, + srcEncoder: LocalIndexEncoder, + implicitPrefs: Boolean = false, + alpha: Double = 1.0): RDD[(Int, FactorBlock)] = { + val numSrcBlocks = srcFactorBlocks.partitions.size + val YtY = if (implicitPrefs) Some(computeYtY(srcFactorBlocks, rank)) else None + val srcOut = srcOutBlocks.join(srcFactorBlocks).flatMap { + case (srcBlockId, (srcOutBlock, srcFactors)) => + srcOutBlock.view.zipWithIndex.map { case (activeIndices, dstBlockId) => + (dstBlockId, (srcBlockId, activeIndices.map(idx => srcFactors(idx)))) + } + } + val merged = srcOut.groupByKey(new HashPartitioner(dstInBlocks.partitions.size)) + dstInBlocks.join(merged).mapValues { + case (InBlock(dstIds, srcPtrs, srcEncodedIndices, ratings), srcFactors) => + val sortedSrcFactors = new Array[FactorBlock](numSrcBlocks) + srcFactors.foreach { case (srcBlockId, factors) => + sortedSrcFactors(srcBlockId) = factors + } + val dstFactors = new Array[Array[Float]](dstIds.size) + var j = 0 + val ls = new NormalEquation(rank) + val solver = new CholeskySolver // TODO: add NNLS solver + while (j < dstIds.size) { + ls.reset() + if (implicitPrefs) { + ls.merge(YtY.get) + } + var i = srcPtrs(j) + while (i < srcPtrs(j + 1)) { + val encoded = srcEncodedIndices(i) + val blockId = srcEncoder.blockId(encoded) + val localIndex = srcEncoder.localIndex(encoded) + val srcFactor = sortedSrcFactors(blockId)(localIndex) + val rating = ratings(i) + if (implicitPrefs) { + ls.addImplicit(srcFactor, rating, alpha) + } else { + ls.add(srcFactor, rating) + } + i += 1 + } + dstFactors(j) = solver.solve(ls, regParam) + j += 1 + } + dstFactors + } + } + + /** + * Computes the Gramian matrix of user or item factors, which is only used in implicit preference. + * Caching of the input factors is handled in [[ALS#train]]. + */ + private def computeYtY(factorBlocks: RDD[(Int, FactorBlock)], rank: Int): NormalEquation = { + factorBlocks.values.aggregate(new NormalEquation(rank))( + seqOp = (ne, factors) => { + factors.foreach(ne.add(_, 0.0f)) + ne + }, + combOp = (ne1, ne2) => ne1.merge(ne2)) + } + + /** + * Encoder for storing (blockId, localIndex) into a single integer. + * + * We use the leading bits (including the sign bit) to store the block id and the rest to store + * the local index. This is based on the assumption that users/items are approximately evenly + * partitioned. With this assumption, we should be able to encode two billion distinct values. + * + * @param numBlocks number of blocks + */ + private[recommendation] class LocalIndexEncoder(numBlocks: Int) extends Serializable { + + require(numBlocks > 0, s"numBlocks must be positive but found $numBlocks.") + + private[this] final val numLocalIndexBits = + math.min(java.lang.Integer.numberOfLeadingZeros(numBlocks - 1), 31) + private[this] final val localIndexMask = (1 << numLocalIndexBits) - 1 + + /** Encodes a (blockId, localIndex) into a single integer. 
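+     * For example, with numBlocks = 10 there are 28 local-index bits, so encode(3, 5) returns
+     * (3 << 28) | 5, from which blockId and localIndex recover 3 and 5.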
*/ + def encode(blockId: Int, localIndex: Int): Int = { + require(blockId < numBlocks) + require((localIndex & ~localIndexMask) == 0) + (blockId << numLocalIndexBits) | localIndex + } + + /** Gets the block id from an encoded index. */ + @inline + def blockId(encoded: Int): Int = { + encoded >>> numLocalIndexBits + } + + /** Gets the local index from an encoded index. */ + @inline + def localIndex(encoded: Int): Int = { + encoded & localIndexMask + } + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala index bee951a2e5e26..5f84677be238d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala @@ -90,7 +90,7 @@ case class Rating(user: Int, product: Int, rating: Double) * * Essentially instead of finding the low-rank approximations to the rating matrix `R`, * this finds the approximations for a preference matrix `P` where the elements of `P` are 1 if - * r > 0 and 0 if r = 0. The ratings then act as 'confidence' values related to strength of + * r > 0 and 0 if r <= 0. The ratings then act as 'confidence' values related to strength of * indicated user * preferences rather than explicit ratings given to items. */ diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala new file mode 100644 index 0000000000000..cdd4db1b5b7dc --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala @@ -0,0 +1,435 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.recommendation + +import java.util.Random + +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer + +import com.github.fommil.netlib.BLAS.{getInstance => blas} +import org.scalatest.FunSuite + +import org.apache.spark.Logging +import org.apache.spark.ml.recommendation.ALS._ +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{Row, SQLContext} + +class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging { + + private var sqlContext: SQLContext = _ + + override def beforeAll(): Unit = { + super.beforeAll() + sqlContext = new SQLContext(sc) + } + + test("LocalIndexEncoder") { + val random = new Random + for (numBlocks <- Seq(1, 2, 5, 10, 20, 50, 100)) { + val encoder = new LocalIndexEncoder(numBlocks) + val maxLocalIndex = Int.MaxValue / numBlocks + val tests = Seq.fill(5)((random.nextInt(numBlocks), random.nextInt(maxLocalIndex))) ++ + Seq((0, 0), (numBlocks - 1, maxLocalIndex)) + tests.foreach { case (blockId, localIndex) => + val err = s"Failed with numBlocks=$numBlocks, blockId=$blockId, and localIndex=$localIndex." + val encoded = encoder.encode(blockId, localIndex) + assert(encoder.blockId(encoded) === blockId, err) + assert(encoder.localIndex(encoded) === localIndex, err) + } + } + } + + test("normal equation construction with explict feedback") { + val k = 2 + val ne0 = new NormalEquation(k) + .add(Array(1.0f, 2.0f), 3.0f) + .add(Array(4.0f, 5.0f), 6.0f) + assert(ne0.k === k) + assert(ne0.triK === k * (k + 1) / 2) + assert(ne0.n === 2) + // NumPy code that computes the expected values: + // A = np.matrix("1 2; 4 5") + // b = np.matrix("3; 6") + // ata = A.transpose() * A + // atb = A.transpose() * b + assert(Vectors.dense(ne0.ata) ~== Vectors.dense(17.0, 22.0, 29.0) relTol 1e-8) + assert(Vectors.dense(ne0.atb) ~== Vectors.dense(27.0, 36.0) relTol 1e-8) + + val ne1 = new NormalEquation(2) + .add(Array(7.0f, 8.0f), 9.0f) + ne0.merge(ne1) + assert(ne0.n === 3) + // NumPy code that computes the expected values: + // A = np.matrix("1 2; 4 5; 7 8") + // b = np.matrix("3; 6; 9") + // ata = A.transpose() * A + // atb = A.transpose() * b + assert(Vectors.dense(ne0.ata) ~== Vectors.dense(66.0, 78.0, 93.0) relTol 1e-8) + assert(Vectors.dense(ne0.atb) ~== Vectors.dense(90.0, 108.0) relTol 1e-8) + + intercept[IllegalArgumentException] { + ne0.add(Array(1.0f), 2.0f) + } + intercept[IllegalArgumentException] { + ne0.add(Array(1.0f, 2.0f, 3.0f), 4.0f) + } + intercept[IllegalArgumentException] { + val ne2 = new NormalEquation(3) + ne0.merge(ne2) + } + + ne0.reset() + assert(ne0.n === 0) + assert(ne0.ata.forall(_ == 0.0)) + assert(ne0.atb.forall(_ == 0.0)) + } + + test("normal equation construction with implicit feedback") { + val k = 2 + val alpha = 0.5 + val ne0 = new NormalEquation(k) + .addImplicit(Array(-5.0f, -4.0f), -3.0f, alpha) + .addImplicit(Array(-2.0f, -1.0f), 0.0f, alpha) + .addImplicit(Array(1.0f, 2.0f), 3.0f, alpha) + assert(ne0.k === k) + assert(ne0.triK === k * (k + 1) / 2) + assert(ne0.n === 0) // addImplicit doesn't increase the count. 
+ // NumPy code that computes the expected values: + // alpha = 0.5 + // A = np.matrix("-5 -4; -2 -1; 1 2") + // b = np.matrix("-3; 0; 3") + // b1 = b > 0 + // c = 1.0 + alpha * np.abs(b) + // C = np.diag(c.A1) + // I = np.eye(3) + // ata = A.transpose() * (C - I) * A + // atb = A.transpose() * C * b1 + assert(Vectors.dense(ne0.ata) ~== Vectors.dense(39.0, 33.0, 30.0) relTol 1e-8) + assert(Vectors.dense(ne0.atb) ~== Vectors.dense(2.5, 5.0) relTol 1e-8) + } + + test("CholeskySolver") { + val k = 2 + val ne0 = new NormalEquation(k) + .add(Array(1.0f, 2.0f), 4.0f) + .add(Array(1.0f, 3.0f), 9.0f) + .add(Array(1.0f, 4.0f), 16.0f) + val ne1 = new NormalEquation(k) + .merge(ne0) + + val chol = new CholeskySolver + val x0 = chol.solve(ne0, 0.0).map(_.toDouble) + // NumPy code that computes the expected solution: + // A = np.matrix("1 2; 1 3; 1 4") + // b = b = np.matrix("3; 6") + // x0 = np.linalg.lstsq(A, b)[0] + assert(Vectors.dense(x0) ~== Vectors.dense(-8.333333, 6.0) relTol 1e-6) + + assert(ne0.n === 0) + assert(ne0.ata.forall(_ == 0.0)) + assert(ne0.atb.forall(_ == 0.0)) + + val x1 = chol.solve(ne1, 0.5).map(_.toDouble) + // NumPy code that computes the expected solution, where lambda is scaled by n: + // x0 = np.linalg.solve(A.transpose() * A + 0.5 * 3 * np.eye(2), A.transpose() * b) + assert(Vectors.dense(x1) ~== Vectors.dense(-0.1155556, 3.28) relTol 1e-6) + } + + test("RatingBlockBuilder") { + val emptyBuilder = new RatingBlockBuilder() + assert(emptyBuilder.size === 0) + val emptyBlock = emptyBuilder.build() + assert(emptyBlock.srcIds.isEmpty) + assert(emptyBlock.dstIds.isEmpty) + assert(emptyBlock.ratings.isEmpty) + + val builder0 = new RatingBlockBuilder() + .add(Rating(0, 1, 2.0f)) + .add(Rating(3, 4, 5.0f)) + assert(builder0.size === 2) + val builder1 = new RatingBlockBuilder() + .add(Rating(6, 7, 8.0f)) + .merge(builder0.build()) + assert(builder1.size === 3) + val block = builder1.build() + val ratings = Seq.tabulate(block.size) { i => + (block.srcIds(i), block.dstIds(i), block.ratings(i)) + }.toSet + assert(ratings === Set((0, 1, 2.0f), (3, 4, 5.0f), (6, 7, 8.0f))) + } + + test("UncompressedInBlock") { + val encoder = new LocalIndexEncoder(10) + val uncompressed = new UncompressedInBlockBuilder(encoder) + .add(0, Array(1, 0, 2), Array(0, 1, 4), Array(1.0f, 2.0f, 3.0f)) + .add(1, Array(3, 0), Array(2, 5), Array(4.0f, 5.0f)) + .build() + assert(uncompressed.size === 5) + val records = Seq.tabulate(uncompressed.size) { i => + val dstEncodedIndex = uncompressed.dstEncodedIndices(i) + val dstBlockId = encoder.blockId(dstEncodedIndex) + val dstLocalIndex = encoder.localIndex(dstEncodedIndex) + (uncompressed.srcIds(i), dstBlockId, dstLocalIndex, uncompressed.ratings(i)) + }.toSet + val expected = + Set((1, 0, 0, 1.0f), (0, 0, 1, 2.0f), (2, 0, 4, 3.0f), (3, 1, 2, 4.0f), (0, 1, 5, 5.0f)) + assert(records === expected) + + val compressed = uncompressed.compress() + assert(compressed.size === 5) + assert(compressed.srcIds.toSeq === Seq(0, 1, 2, 3)) + assert(compressed.dstPtrs.toSeq === Seq(0, 2, 3, 4, 5)) + var decompressed = ArrayBuffer.empty[(Int, Int, Int, Float)] + var i = 0 + while (i < compressed.srcIds.size) { + var j = compressed.dstPtrs(i) + while (j < compressed.dstPtrs(i + 1)) { + val dstEncodedIndex = compressed.dstEncodedIndices(j) + val dstBlockId = encoder.blockId(dstEncodedIndex) + val dstLocalIndex = encoder.localIndex(dstEncodedIndex) + decompressed += ((compressed.srcIds(i), dstBlockId, dstLocalIndex, compressed.ratings(j))) + j += 1 + } + i += 1 + } + 
assert(decompressed.toSet === expected) + } + + /** + * Generates an explicit feedback dataset for testing ALS. + * @param numUsers number of users + * @param numItems number of items + * @param rank rank + * @param noiseStd the standard deviation of additive Gaussian noise on training data + * @param seed random seed + * @return (training, test) + */ + def genExplicitTestData( + numUsers: Int, + numItems: Int, + rank: Int, + noiseStd: Double = 0.0, + seed: Long = 11L): (RDD[Rating], RDD[Rating]) = { + val trainingFraction = 0.6 + val testFraction = 0.3 + val totalFraction = trainingFraction + testFraction + val random = new Random(seed) + val userFactors = genFactors(numUsers, rank, random) + val itemFactors = genFactors(numItems, rank, random) + val training = ArrayBuffer.empty[Rating] + val test = ArrayBuffer.empty[Rating] + for ((userId, userFactor) <- userFactors; (itemId, itemFactor) <- itemFactors) { + val x = random.nextDouble() + if (x < totalFraction) { + val rating = blas.sdot(rank, userFactor, 1, itemFactor, 1) + if (x < trainingFraction) { + val noise = noiseStd * random.nextGaussian() + training += Rating(userId, itemId, rating + noise.toFloat) + } else { + test += Rating(userId, itemId, rating) + } + } + } + logInfo(s"Generated an explicit feedback dataset with ${training.size} ratings for training " + + s"and ${test.size} for test.") + (sc.parallelize(training, 2), sc.parallelize(test, 2)) + } + + /** + * Generates an implicit feedback dataset for testing ALS. + * @param numUsers number of users + * @param numItems number of items + * @param rank rank + * @param noiseStd the standard deviation of additive Gaussian noise on training data + * @param seed random seed + * @return (training, test) + */ + def genImplicitTestData( + numUsers: Int, + numItems: Int, + rank: Int, + noiseStd: Double = 0.0, + seed: Long = 11L): (RDD[Rating], RDD[Rating]) = { + // The assumption of the implicit feedback model is that unobserved ratings are more likely to + // be negatives. + val positiveFraction = 0.8 + val negativeFraction = 1.0 - positiveFraction + val trainingFraction = 0.6 + val testFraction = 0.3 + val totalFraction = trainingFraction + testFraction + val random = new Random(seed) + val userFactors = genFactors(numUsers, rank, random) + val itemFactors = genFactors(numItems, rank, random) + val training = ArrayBuffer.empty[Rating] + val test = ArrayBuffer.empty[Rating] + for ((userId, userFactor) <- userFactors; (itemId, itemFactor) <- itemFactors) { + val rating = blas.sdot(rank, userFactor, 1, itemFactor, 1) + val threshold = if (rating > 0) positiveFraction else negativeFraction + val observed = random.nextDouble() < threshold + if (observed) { + val x = random.nextDouble() + if (x < totalFraction) { + if (x < trainingFraction) { + val noise = noiseStd * random.nextGaussian() + training += Rating(userId, itemId, rating + noise.toFloat) + } else { + test += Rating(userId, itemId, rating) + } + } + } + } + logInfo(s"Generated an implicit feedback dataset with ${training.size} ratings for training " + + s"and ${test.size} for test.") + (sc.parallelize(training, 2), sc.parallelize(test, 2)) + } + + /** + * Generates random user/item factors, with i.i.d. values drawn from U(a, b). 
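+   * IDs are sampled without replacement from the full Int range and returned in sorted order.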
+ * @param size number of users/items + * @param rank number of features + * @param random random number generator + * @param a min value of the support (default: -1) + * @param b max value of the support (default: 1) + * @return a sequence of (ID, factors) pairs + */ + private def genFactors( + size: Int, + rank: Int, + random: Random, + a: Float = -1.0f, + b: Float = 1.0f): Seq[(Int, Array[Float])] = { + require(size > 0 && size < Int.MaxValue / 3) + require(b > a) + val ids = mutable.Set.empty[Int] + while (ids.size < size) { + ids += random.nextInt() + } + val width = b - a + ids.toSeq.sorted.map(id => (id, Array.fill(rank)(a + random.nextFloat() * width))) + } + + /** + * Test ALS using the given training/test splits and parameters. + * @param training training dataset + * @param test test dataset + * @param rank rank of the matrix factorization + * @param maxIter max number of iterations + * @param regParam regularization constant + * @param implicitPrefs whether to use implicit preference + * @param numUserBlocks number of user blocks + * @param numItemBlocks number of item blocks + * @param targetRMSE target test RMSE + */ + def testALS( + training: RDD[Rating], + test: RDD[Rating], + rank: Int, + maxIter: Int, + regParam: Double, + implicitPrefs: Boolean = false, + numUserBlocks: Int = 2, + numItemBlocks: Int = 3, + targetRMSE: Double = 0.05): Unit = { + val sqlContext = this.sqlContext + import sqlContext.{createSchemaRDD, symbolToUnresolvedAttribute} + val als = new ALS() + .setRank(rank) + .setRegParam(regParam) + .setImplicitPrefs(implicitPrefs) + .setNumUserBlocks(numUserBlocks) + .setNumItemBlocks(numItemBlocks) + val alpha = als.getAlpha + val model = als.fit(training) + val predictions = model.transform(test) + .select('rating, 'prediction) + .map { case Row(rating: Float, prediction: Float) => + (rating.toDouble, prediction.toDouble) + } + val rmse = + if (implicitPrefs) { + // TODO: Use a better (rank-based?) evaluation metric for implicit feedback. + // We limit the ratings and the predictions to interval [0, 1] and compute the weighted RMSE + // with the confidence scores as weights. 
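+        // That is, rmse = sqrt(sum(c_i * err_i^2) / sum(c_i)) with c_i = 1 + alpha * |rating_i|
+        // and err_i taken on the clamped values.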
+ val (totalWeight, weightedSumSq) = predictions.map { case (rating, prediction) => + val confidence = 1.0 + alpha * math.abs(rating) + val rating01 = math.max(math.min(rating, 1.0), 0.0) + val prediction01 = math.max(math.min(prediction, 1.0), 0.0) + val err = prediction01 - rating01 + (confidence, confidence * err * err) + }.reduce { case ((c0, e0), (c1, e1)) => + (c0 + c1, e0 + e1) + } + math.sqrt(weightedSumSq / totalWeight) + } else { + val mse = predictions.map { case (rating, prediction) => + val err = rating - prediction + err * err + }.mean() + math.sqrt(mse) + } + logInfo(s"Test RMSE is $rmse.") + assert(rmse < targetRMSE) + } + + test("exact rank-1 matrix") { + val (training, test) = genExplicitTestData(numUsers = 20, numItems = 40, rank = 1) + testALS(training, test, maxIter = 1, rank = 1, regParam = 1e-5, targetRMSE = 0.001) + testALS(training, test, maxIter = 1, rank = 2, regParam = 1e-5, targetRMSE = 0.001) + } + + test("approximate rank-1 matrix") { + val (training, test) = + genExplicitTestData(numUsers = 20, numItems = 40, rank = 1, noiseStd = 0.01) + testALS(training, test, maxIter = 2, rank = 1, regParam = 0.01, targetRMSE = 0.02) + testALS(training, test, maxIter = 2, rank = 2, regParam = 0.01, targetRMSE = 0.02) + } + + test("approximate rank-2 matrix") { + val (training, test) = + genExplicitTestData(numUsers = 20, numItems = 40, rank = 2, noiseStd = 0.01) + testALS(training, test, maxIter = 4, rank = 2, regParam = 0.01, targetRMSE = 0.03) + testALS(training, test, maxIter = 4, rank = 3, regParam = 0.01, targetRMSE = 0.03) + } + + test("different block settings") { + val (training, test) = + genExplicitTestData(numUsers = 20, numItems = 40, rank = 2, noiseStd = 0.01) + for ((numUserBlocks, numItemBlocks) <- Seq((1, 1), (1, 2), (2, 1), (2, 2))) { + testALS(training, test, maxIter = 4, rank = 2, regParam = 0.01, targetRMSE = 0.03, + numUserBlocks = numUserBlocks, numItemBlocks = numItemBlocks) + } + } + + test("more blocks than ratings") { + val (training, test) = + genExplicitTestData(numUsers = 4, numItems = 4, rank = 1) + testALS(training, test, maxIter = 2, rank = 1, regParam = 1e-4, targetRMSE = 0.002, + numItemBlocks = 5, numUserBlocks = 5) + } + + test("implicit feedback") { + val (training, test) = + genImplicitTestData(numUsers = 20, numItems = 40, rank = 2, noiseStd = 0.01) + testALS(training, test, maxIter = 4, rank = 2, regParam = 0.01, implicitPrefs = true, + targetRMSE = 0.3) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala index f3b7bfda788fa..e9fc37e000526 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala @@ -215,7 +215,7 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext { * @param samplingRate what fraction of the user-product pairs are known * @param matchThreshold max difference allowed to consider a predicted rating correct * @param implicitPrefs flag to test implicit feedback - * @param bulkPredict flag to test bulk prediciton + * @param bulkPredict flag to test bulk predicition * @param negativeWeights whether the generated data can contain negative values * @param numUserBlocks number of user blocks to partition users into * @param numProductBlocks number of product blocks to partition products into From f383c15b64ad0d674c09b70dd632f9a93fce44f6 Mon Sep 17 00:00:00 2001 From: Michael Davies Date: Fri, 
23 Jan 2015 19:03:05 +0000 Subject: [PATCH 20/22] [SQL][SPARK-5309] Add test that exercises dictionary based ParquetConverter --- .../spark/sql/parquet/ParquetQuerySuite2.scala | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite2.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite2.scala index 4c081fb4510b2..1aa1ea1f7c30d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite2.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite2.scala @@ -85,4 +85,16 @@ class ParquetQuerySuite2 extends QueryTest with ParquetTest { checkAnswer(sql(s"SELECT _1 FROM t WHERE _1 < 10"), (1 to 9).map(Row.apply(_))) } } + + + test("SPARK-5309 strings stored using dictionary compression in parquet") { + withParquetTable((0 until 1000).map(i => ("same", "run_" + i /100, 1)), "t") { + + checkAnswer(sql(s"SELECT _1, _2, SUM(_3) FROM t GROUP BY _1, _2"), + (0 until 10).map(i => Row("same", "run_" + i, 100))) + + checkAnswer(sql(s"SELECT _1, _2, SUM(_3) FROM t WHERE _2 = 'run_5' GROUP BY _1, _2"), + List(Row("same", "run_5", 100))) + } + } } From bb9bf987b9b652c411d874f9d0c283ba4fe68671 Mon Sep 17 00:00:00 2001 From: Michael Davies Date: Wed, 21 Jan 2015 17:06:25 +0000 Subject: [PATCH 21/22] SPARK-5309: Use Dictionary for Binary->String conversion Use of Dictionary will reduce conversion overhead which is significant when converting from Binary in UTF to Java String. Dictionaries are use in Parquet where columns hold a limited set of values which is often the case. --- .../spark/sql/parquet/ParquetConverter.scala | 50 ++++++++++++++----- 1 file changed, 38 insertions(+), 12 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala index 9d9150246c8d4..b1df6e5d9b9df 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala @@ -17,6 +17,9 @@ package org.apache.spark.sql.parquet + +import parquet.column.Dictionary + import scala.collection.mutable.{Buffer, ArrayBuffer, HashMap} import parquet.io.api.{PrimitiveConverter, GroupConverter, Binary, Converter} @@ -102,12 +105,8 @@ private[sql] object CatalystConverter { } // Strings, Shorts and Bytes do not have a corresponding type in Parquet // so we need to treat them separately - case StringType => { - new CatalystPrimitiveConverter(parent, fieldIndex) { - override def addBinary(value: Binary): Unit = - parent.updateString(fieldIndex, value) - } - } + case StringType => + new CatalystPrimitiveStringConverter(parent, fieldIndex) case ShortType => { new CatalystPrimitiveConverter(parent, fieldIndex) { override def addInt(value: Int): Unit = @@ -197,8 +196,8 @@ private[parquet] abstract class CatalystConverter extends GroupConverter { protected[parquet] def updateBinary(fieldIndex: Int, value: Binary): Unit = updateField(fieldIndex, value.getBytes) - protected[parquet] def updateString(fieldIndex: Int, value: Binary): Unit = - updateField(fieldIndex, value.toStringUsingUTF8) + protected[parquet] def updateString(fieldIndex: Int, value: String): Unit = + updateField(fieldIndex, value) protected[parquet] def updateDecimal(fieldIndex: Int, value: Binary, ctype: DecimalType): Unit = { updateField(fieldIndex, readDecimal(new Decimal(), value, ctype)) @@ -384,8 +383,8 @@ 
private[parquet] class CatalystPrimitiveRowConverter( override protected[parquet] def updateBinary(fieldIndex: Int, value: Binary): Unit = current.update(fieldIndex, value.getBytes) - override protected[parquet] def updateString(fieldIndex: Int, value: Binary): Unit = - current.setString(fieldIndex, value.toStringUsingUTF8) + override protected[parquet] def updateString(fieldIndex: Int, value: String): Unit = + current.setString(fieldIndex, value) override protected[parquet] def updateDecimal( fieldIndex: Int, value: Binary, ctype: DecimalType): Unit = { @@ -426,6 +425,33 @@ private[parquet] class CatalystPrimitiveConverter( parent.updateLong(fieldIndex, value) } +/** + * A `parquet.io.api.PrimitiveConverter` that converts Parquet Binary to Catalyst String. + * Supports dictionaries to reduce Binary to String conversion overhead. + * + * Follows pattern in Parquet of using dictionaries, where supported, for String conversion. + * + * @param parent The parent group converter. + * @param fieldIndex The index inside the record. + */ +private[parquet] class CatalystPrimitiveStringConverter(parent: CatalystConverter, fieldIndex: Int) + extends CatalystPrimitiveConverter(parent, fieldIndex) { + + private var dict: Array[String] = null + + override def hasDictionarySupport: Boolean = true + + override def setDictionary(dictionary: Dictionary):Unit = + dict = (for (i <- 0 to dictionary.getMaxId) + yield dictionary.decodeToBinary(i).toStringUsingUTF8).toArray + + override def addValueFromDictionary(dictionaryId: Int): Unit = + parent.updateString(fieldIndex, dict(dictionaryId)) + + override def addBinary(value: Binary): Unit = + parent.updateString(fieldIndex, value.toStringUsingUTF8) +} + private[parquet] object CatalystArrayConverter { val INITIAL_ARRAY_SIZE = 20 } @@ -583,9 +609,9 @@ private[parquet] class CatalystNativeArrayConverter( elements += 1 } - override protected[parquet] def updateString(fieldIndex: Int, value: Binary): Unit = { + override protected[parquet] def updateString(fieldIndex: Int, value: String): Unit = { checkGrowBuffer() - buffer(elements) = value.toStringUsingUTF8.asInstanceOf[NativeType] + buffer(elements) = value.asInstanceOf[NativeType] elements += 1 } From bbf93a4d3c516698e78f56ea73696245a73f2edb Mon Sep 17 00:00:00 2001 From: Michael Davies Date: Fri, 23 Jan 2015 19:03:05 +0000 Subject: [PATCH 22/22] [SQL][SPARK-5309] Add test that exercises dictionary based ParquetConverter --- .../apache/spark/sql/parquet/ParquetQuerySuite.scala | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala index 1263ff818ea19..51ef540522b36 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala @@ -85,4 +85,16 @@ class ParquetQuerySuite extends QueryTest with ParquetTest { checkAnswer(sql(s"SELECT _1 FROM t WHERE _1 < 10"), (1 to 9).map(Row.apply(_))) } } + + + test("SPARK-5309 strings stored using dictionary compression in parquet") { + withParquetTable((0 until 1000).map(i => ("same", "run_" + i /100, 1)), "t") { + + checkAnswer(sql(s"SELECT _1, _2, SUM(_3) FROM t GROUP BY _1, _2"), + (0 until 10).map(i => Row("same", "run_" + i, 100))) + + checkAnswer(sql(s"SELECT _1, _2, SUM(_3) FROM t WHERE _2 = 'run_5' GROUP BY _1, _2"), + List(Row("same", "run_5", 100))) + } + } }
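
The dictionary path added in CatalystPrimitiveStringConverter above comes down to a small
memoization pattern: decode every dictionary entry to a String once in setDictionary, then serve
repeated values by index instead of re-decoding UTF-8 bytes per value. Below is a minimal,
self-contained sketch of that idea; the Dict trait is a stand-in for parquet.column.Dictionary
and is illustrative only, not part of the patches.

import java.nio.charset.StandardCharsets

// Stand-in for parquet.column.Dictionary (illustrative only).
trait Dict {
  def getMaxId: Int
  def decodeToBytes(id: Int): Array[Byte]
}

class MemoizingStringDecoder {
  private var cache: Array[String] = null

  // Decode each dictionary entry exactly once.
  def setDictionary(dict: Dict): Unit =
    cache = (0 to dict.getMaxId)
      .map(i => new String(dict.decodeToBytes(i), StandardCharsets.UTF_8))
      .toArray

  // Dictionary-encoded values: an array lookup, no per-value decoding.
  def fromDictionary(id: Int): String = cache(id)

  // Fallback for values that are not dictionary-encoded.
  def fromBinary(bytes: Array[Byte]): String = new String(bytes, StandardCharsets.UTF_8)
}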