From 95e032c9c3d71aa9095c8193361f3eefb91e9a0c Mon Sep 17 00:00:00 2001 From: Wayne Zhang Date: Thu, 25 May 2017 17:09:12 -0700 Subject: [PATCH 001/133] add SQL trunc function --- R/pkg/R/functions.R | 26 +++++++++++++++++++++++ R/pkg/inst/tests/testthat/test_sparkSQL.R | 2 ++ 2 files changed, 28 insertions(+) diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index 06a90192bb12f..3f238359ea7f9 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -4015,3 +4015,29 @@ setMethod("input_file_name", signature("missing"), jc <- callJStatic("org.apache.spark.sql.functions", "input_file_name") column(jc) }) + +#' trunc +#' +#' Returns date truncated to the unit specified by the format. +#' +#' @param x Column to compute on. +#' @param format string used for specify the truncation method. For example, +#' "year", "yyyy", "yy" for truncate by year, or "month", "mon", "mm" for truncate by month. +#' +#' @rdname trunc +#' @name trunc +#' @family date time functions +#' @aliases trunc,Column-method +#' @export +#' @examples +#' \dontrun{ +#' trunc(df$c, "year") +#' trunc(df$c, "month") +#' } +#' @note trunc since 2.3.0 +setMethod("trunc", + signature(x = "Column"), + function(x, format = "year") { + jc <- callJStatic("org.apache.spark.sql.functions", "trunc", x@jc, format) + column(jc) + }) diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index b633b78d5bb4d..92d0f000af656 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -1404,6 +1404,8 @@ test_that("column functions", { c20 <- to_timestamp(c) + to_timestamp(c, "yyyy") + to_date(c, "yyyy") c21 <- posexplode_outer(c) + explode_outer(c) c22 <- not(c) + c23 <- trunc(c) + trunc(c, "year") + trunc(c, "yyyy") + trunc(c, "yy") + + trunc(c, "month") + trunc(c, "mon") + trunc(c, "mm") # Test if base::is.nan() is exposed expect_equal(is.nan(c("a", "b")), c(FALSE, FALSE)) From c1e7989c4ffd83c51f5c97998b4ff6fe8dd83cf4 Mon Sep 17 00:00:00 2001 From: Michael Allman Date: Fri, 26 May 2017 09:25:43 +0800 Subject: [PATCH 002/133] [SPARK-20888][SQL][DOCS] Document change of default setting of spark.sql.hive.caseSensitiveInferenceMode (Link to Jira: https://issues.apache.org/jira/browse/SPARK-20888) ## What changes were proposed in this pull request? Document change of default setting of spark.sql.hive.caseSensitiveInferenceMode configuration key from NEVER_INFO to INFER_AND_SAVE in the Spark SQL 2.1 to 2.2 migration notes. Author: Michael Allman Closes #18112 from mallman/spark-20888-document_infer_and_save. --- docs/sql-programming-guide.md | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 490c1ce8a7cc5..adb12d2489a57 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1223,7 +1223,7 @@ the following case-insensitive options: This is a JDBC writer related option. If specified, this option allows setting of database-specific table and partition options when creating a table (e.g., CREATE TABLE t (name string) ENGINE=InnoDB.). This option applies only to writing. - + createTableColumnTypes @@ -1444,6 +1444,10 @@ options. # Migration Guide +## Upgrading From Spark SQL 2.1 to 2.2 + + - Spark 2.1.1 introduced a new configuration key: `spark.sql.hive.caseSensitiveInferenceMode`. It had a default setting of `NEVER_INFER`, which kept behavior identical to 2.1.0. However, Spark 2.2.0 changes this setting's default value to `INFER_AND_SAVE` to restore compatibility with reading Hive metastore tables whose underlying file schema have mixed-case column names. With the `INFER_AND_SAVE` configuration value, on first access Spark will perform schema inference on any Hive metastore table for which it has not already saved an inferred schema. Note that schema inference can be a very time consuming operation for tables with thousands of partitions. If compatibility with mixed-case column names is not a concern, you can safely set `spark.sql.hive.caseSensitiveInferenceMode` to `NEVER_INFER` to avoid the initial overhead of schema inference. Note that with the new default `INFER_AND_SAVE` setting, the results of the schema inference are saved as a metastore key for future use. Therefore, the initial schema inference occurs only at a table's first access. + ## Upgrading From Spark SQL 2.0 to 2.1 - Datasource tables now store partition metadata in the Hive metastore. This means that Hive DDLs such as `ALTER TABLE PARTITION ... SET LOCATION` are now available for tables created with the Datasource API. From 2dbe0c5288b48733bae0e39a6c5d8047f4a55088 Mon Sep 17 00:00:00 2001 From: setjet Date: Fri, 26 May 2017 10:21:39 +0800 Subject: [PATCH 003/133] [SPARK-20775][SQL] Added scala support from_json ## What changes were proposed in this pull request? from_json function required to take in a java.util.Hashmap. For other functions, a java wrapper is provided which casts a java hashmap to a scala map. Only a java function is provided in this case, forcing scala users to pass in a java.util.Hashmap. Added the missing wrapper. ## How was this patch tested? Added a unit test for passing in a scala map Author: setjet Closes #18094 from setjet/spark-20775. --- .../org/apache/spark/sql/functions.scala | 22 +++++++++++++++++-- .../apache/spark/sql/JsonFunctionsSuite.scala | 9 +++++++- 2 files changed, 28 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 36c0f18b6e2e3..7eea6d8d85b6f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3060,8 +3060,9 @@ object functions { from_json(e, schema, Map.empty[String, String]) /** - * Parses a column containing a JSON string into a `StructType` or `ArrayType` of `StructType`s - * with the specified schema. Returns `null`, in the case of an unparseable string. + * (Java-specific) Parses a column containing a JSON string into a `StructType` or `ArrayType` + * of `StructType`s with the specified schema. Returns `null`, in the case of an unparseable + * string. * * @param e a string column containing JSON data. * @param schema the schema to use when parsing the json string as a json string. In Spark 2.1, @@ -3072,6 +3073,23 @@ object functions { * @since 2.1.0 */ def from_json(e: Column, schema: String, options: java.util.Map[String, String]): Column = { + from_json(e, schema, options.asScala.toMap) + } + + /** + * (Scala-specific) Parses a column containing a JSON string into a `StructType` or `ArrayType` + * of `StructType`s with the specified schema. Returns `null`, in the case of an unparseable + * string. + * + * @param e a string column containing JSON data. + * @param schema the schema to use when parsing the json string as a json string. In Spark 2.1, + * the user-provided schema has to be in JSON format. Since Spark 2.2, the DDL + * format is also supported for the schema. + * + * @group collection_funcs + * @since 2.3.0 + */ + def from_json(e: Column, schema: String, options: Map[String, String]): Column = { val dataType = try { DataType.fromJson(schema) } catch { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala index 69a500c845a7b..cf2d00fc94423 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala @@ -156,13 +156,20 @@ class JsonFunctionsSuite extends QueryTest with SharedSQLContext { Row(Seq(Row(1, "a"), Row(2, null), Row(null, null)))) } - test("from_json uses DDL strings for defining a schema") { + test("from_json uses DDL strings for defining a schema - java") { val df = Seq("""{"a": 1, "b": "haa"}""").toDS() checkAnswer( df.select(from_json($"value", "a INT, b STRING", new java.util.HashMap[String, String]())), Row(Row(1, "haa")) :: Nil) } + test("from_json uses DDL strings for defining a schema - scala") { + val df = Seq("""{"a": 1, "b": "haa"}""").toDS() + checkAnswer( + df.select(from_json($"value", "a INT, b STRING", Map[String, String]())), + Row(Row(1, "haa")) :: Nil) + } + test("to_json - struct") { val df = Seq(Tuple1(Tuple1(1))).toDF("a") From f47700c9cadd72a2495f97f250790449705f631f Mon Sep 17 00:00:00 2001 From: Wayne Zhang Date: Fri, 26 May 2017 10:44:40 +0800 Subject: [PATCH 004/133] [SPARK-14659][ML] RFormula consistent with R when handling strings ## What changes were proposed in this pull request? When handling strings, the category dropped by RFormula and R are different: - RFormula drops the least frequent level - R drops the first level after ascending alphabetical ordering This PR supports different string ordering types in StringIndexer #17879 so that RFormula can drop the same level as R when handling strings using`stringOrderType = "alphabetDesc"`. ## How was this patch tested? new tests Author: Wayne Zhang Closes #17967 from actuaryzhang/RFormula. --- .../apache/spark/ml/feature/RFormula.scala | 44 +++++++++- .../spark/ml/feature/StringIndexer.scala | 4 +- .../spark/ml/feature/RFormulaSuite.scala | 84 +++++++++++++++++++ 3 files changed, 129 insertions(+), 3 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala index 5a3e2929f5f52..1fad0a6fc9443 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala @@ -26,7 +26,7 @@ import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.{Estimator, Model, Pipeline, PipelineModel, PipelineStage, Transformer} import org.apache.spark.ml.attribute.AttributeGroup import org.apache.spark.ml.linalg.VectorUDT -import org.apache.spark.ml.param.{BooleanParam, Param, ParamMap} +import org.apache.spark.ml.param.{BooleanParam, Param, ParamMap, ParamValidators} import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasLabelCol} import org.apache.spark.ml.util._ import org.apache.spark.sql.{DataFrame, Dataset} @@ -37,6 +37,42 @@ import org.apache.spark.sql.types._ */ private[feature] trait RFormulaBase extends HasFeaturesCol with HasLabelCol { + /** + * Param for how to order categories of a string FEATURE column used by `StringIndexer`. + * The last category after ordering is dropped when encoding strings. + * Supported options: 'frequencyDesc', 'frequencyAsc', 'alphabetDesc', 'alphabetAsc'. + * The default value is 'frequencyDesc'. When the ordering is set to 'alphabetDesc', `RFormula` + * drops the same category as R when encoding strings. + * + * The options are explained using an example `'b', 'a', 'b', 'a', 'c', 'b'`: + * {{{ + * +-----------------+---------------------------------------+----------------------------------+ + * | Option | Category mapped to 0 by StringIndexer | Category dropped by RFormula | + * +-----------------+---------------------------------------+----------------------------------+ + * | 'frequencyDesc' | most frequent category ('b') | least frequent category ('c') | + * | 'frequencyAsc' | least frequent category ('c') | most frequent category ('b') | + * | 'alphabetDesc' | last alphabetical category ('c') | first alphabetical category ('a')| + * | 'alphabetAsc' | first alphabetical category ('a') | last alphabetical category ('c') | + * +-----------------+---------------------------------------+----------------------------------+ + * }}} + * Note that this ordering option is NOT used for the label column. When the label column is + * indexed, it uses the default descending frequency ordering in `StringIndexer`. + * + * @group param + */ + @Since("2.3.0") + final val stringIndexerOrderType: Param[String] = new Param(this, "stringIndexerOrderType", + "How to order categories of a string FEATURE column used by StringIndexer. " + + "The last category after ordering is dropped when encoding strings. " + + s"Supported options: ${StringIndexer.supportedStringOrderType.mkString(", ")}. " + + "The default value is 'frequencyDesc'. When the ordering is set to 'alphabetDesc', " + + "RFormula drops the same category as R when encoding strings.", + ParamValidators.inArray(StringIndexer.supportedStringOrderType)) + + /** @group getParam */ + @Since("2.3.0") + def getStringIndexerOrderType: String = $(stringIndexerOrderType) + protected def hasLabelCol(schema: StructType): Boolean = { schema.map(_.name).contains($(labelCol)) } @@ -125,6 +161,11 @@ class RFormula @Since("1.5.0") (@Since("1.5.0") override val uid: String) @Since("2.1.0") def setForceIndexLabel(value: Boolean): this.type = set(forceIndexLabel, value) + /** @group setParam */ + @Since("2.3.0") + def setStringIndexerOrderType(value: String): this.type = set(stringIndexerOrderType, value) + setDefault(stringIndexerOrderType, StringIndexer.frequencyDesc) + /** Whether the formula specifies fitting an intercept. */ private[ml] def hasIntercept: Boolean = { require(isDefined(formula), "Formula must be defined first.") @@ -155,6 +196,7 @@ class RFormula @Since("1.5.0") (@Since("1.5.0") override val uid: String) encoderStages += new StringIndexer() .setInputCol(term) .setOutputCol(indexCol) + .setStringOrderType($(stringIndexerOrderType)) prefixesToRewrite(indexCol + "_") = term + "_" (term, indexCol) case _ => diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index b2dc4fcb61964..dfc902bd0b0f1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -47,7 +47,7 @@ private[feature] trait StringIndexerBase extends Params with HasInputCol with Ha * @group param */ @Since("1.6.0") - val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", "how to handle " + + val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", "How to handle " + "invalid data (unseen labels or NULL values). " + "Options are 'skip' (filter out rows with invalid data), error (throw an error), " + "or 'keep' (put invalid data in a special additional bucket, at index numLabels).", @@ -73,7 +73,7 @@ private[feature] trait StringIndexerBase extends Params with HasInputCol with Ha */ @Since("2.3.0") final val stringOrderType: Param[String] = new Param(this, "stringOrderType", - "how to order labels of string column. " + + "How to order labels of string column. " + "The first label after ordering is assigned an index of 0. " + s"Supported options: ${StringIndexer.supportedStringOrderType.mkString(", ")}.", ParamValidators.inArray(StringIndexer.supportedStringOrderType)) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala index fbebd75d70ac5..41d0062c2cabd 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala @@ -129,6 +129,90 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul assert(result.collect() === expected.collect()) } + test("encodes string terms with string indexer order type") { + val formula = new RFormula().setFormula("id ~ a + b") + val original = Seq((1, "foo", 4), (2, "bar", 4), (3, "bar", 5), (4, "aaz", 5)) + .toDF("id", "a", "b") + + val expected = Seq( + Seq( + (1, "foo", 4, Vectors.dense(0.0, 0.0, 4.0), 1.0), + (2, "bar", 4, Vectors.dense(1.0, 0.0, 4.0), 2.0), + (3, "bar", 5, Vectors.dense(1.0, 0.0, 5.0), 3.0), + (4, "aaz", 5, Vectors.dense(0.0, 1.0, 5.0), 4.0) + ).toDF("id", "a", "b", "features", "label"), + Seq( + (1, "foo", 4, Vectors.dense(0.0, 1.0, 4.0), 1.0), + (2, "bar", 4, Vectors.dense(0.0, 0.0, 4.0), 2.0), + (3, "bar", 5, Vectors.dense(0.0, 0.0, 5.0), 3.0), + (4, "aaz", 5, Vectors.dense(1.0, 0.0, 5.0), 4.0) + ).toDF("id", "a", "b", "features", "label"), + Seq( + (1, "foo", 4, Vectors.dense(1.0, 0.0, 4.0), 1.0), + (2, "bar", 4, Vectors.dense(0.0, 1.0, 4.0), 2.0), + (3, "bar", 5, Vectors.dense(0.0, 1.0, 5.0), 3.0), + (4, "aaz", 5, Vectors.dense(0.0, 0.0, 5.0), 4.0) + ).toDF("id", "a", "b", "features", "label"), + Seq( + (1, "foo", 4, Vectors.dense(0.0, 0.0, 4.0), 1.0), + (2, "bar", 4, Vectors.dense(0.0, 1.0, 4.0), 2.0), + (3, "bar", 5, Vectors.dense(0.0, 1.0, 5.0), 3.0), + (4, "aaz", 5, Vectors.dense(1.0, 0.0, 5.0), 4.0) + ).toDF("id", "a", "b", "features", "label") + ) + + var idx = 0 + for (orderType <- StringIndexer.supportedStringOrderType) { + val model = formula.setStringIndexerOrderType(orderType).fit(original) + val result = model.transform(original) + val resultSchema = model.transformSchema(original.schema) + assert(result.schema.toString == resultSchema.toString) + assert(result.collect() === expected(idx).collect()) + idx += 1 + } + } + + test("test consistency with R when encoding string terms") { + /* + R code: + + df <- data.frame(id = c(1, 2, 3, 4), + a = c("foo", "bar", "bar", "aaz"), + b = c(4, 4, 5, 5)) + model.matrix(id ~ a + b, df)[, -1] + + abar afoo b + 0 1 4 + 1 0 4 + 1 0 5 + 0 0 5 + */ + val original = Seq((1, "foo", 4), (2, "bar", 4), (3, "bar", 5), (4, "aaz", 5)) + .toDF("id", "a", "b") + val formula = new RFormula().setFormula("id ~ a + b") + .setStringIndexerOrderType(StringIndexer.alphabetDesc) + + /* + Note that the category dropped after encoding is the same between R and Spark + (i.e., "aaz" is treated as the reference level). + However, the column order is still different: + R renders the columns in ascending alphabetical order ("bar", "foo"), while + RFormula renders the columns in descending alphabetical order ("foo", "bar"). + */ + val expected = Seq( + (1, "foo", 4, Vectors.dense(1.0, 0.0, 4.0), 1.0), + (2, "bar", 4, Vectors.dense(0.0, 1.0, 4.0), 2.0), + (3, "bar", 5, Vectors.dense(0.0, 1.0, 5.0), 3.0), + (4, "aaz", 5, Vectors.dense(0.0, 0.0, 5.0), 4.0) + ).toDF("id", "a", "b", "features", "label") + + val model = formula.fit(original) + val result = model.transform(original) + val resultSchema = model.transformSchema(original.schema) + assert(result.schema.toString == resultSchema.toString) + assert(result.collect() === expected.collect()) + } + test("index string label") { val formula = new RFormula().setFormula("id ~ a + b") val original = From 9bce1560b75b65c4bf3eed4e4d4384a4dac81397 Mon Sep 17 00:00:00 2001 From: Wayne Zhang Date: Thu, 25 May 2017 21:54:02 -0700 Subject: [PATCH 005/133] minor change --- R/pkg/R/functions.R | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index 3f238359ea7f9..e80a822649597 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -4021,8 +4021,8 @@ setMethod("input_file_name", signature("missing"), #' Returns date truncated to the unit specified by the format. #' #' @param x Column to compute on. -#' @param format string used for specify the truncation method. For example, -#' "year", "yyyy", "yy" for truncate by year, or "month", "mon", "mm" for truncate by month. +#' @param format string used for specify the truncation method. For example, "year", "yyyy", +#' "yy" for truncate by year, or "month", "mon", "mm" for truncate by month. #' #' @rdname trunc #' @name trunc From 4e357fda98a7e67bf5840bcb426de0c0828e0144 Mon Sep 17 00:00:00 2001 From: Wayne Zhang Date: Thu, 25 May 2017 22:05:42 -0700 Subject: [PATCH 006/133] remove default value --- R/pkg/R/functions.R | 2 +- R/pkg/inst/tests/testthat/test_sparkSQL.R | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index e80a822649597..08815c0d53d80 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -4037,7 +4037,7 @@ setMethod("input_file_name", signature("missing"), #' @note trunc since 2.3.0 setMethod("trunc", signature(x = "Column"), - function(x, format = "year") { + function(x, format) { jc <- callJStatic("org.apache.spark.sql.functions", "trunc", x@jc, format) column(jc) }) diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index 92d0f000af656..fd67cc4bfa753 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -1404,7 +1404,7 @@ test_that("column functions", { c20 <- to_timestamp(c) + to_timestamp(c, "yyyy") + to_date(c, "yyyy") c21 <- posexplode_outer(c) + explode_outer(c) c22 <- not(c) - c23 <- trunc(c) + trunc(c, "year") + trunc(c, "yyyy") + trunc(c, "yy") + + c23 <- trunc(c, "year") + trunc(c, "yyyy") + trunc(c, "yy") + trunc(c, "month") + trunc(c, "mon") + trunc(c, "mm") # Test if base::is.nan() is exposed From ccdcf5c2e75e2b56bc7f95fc66ef50fe3237925b Mon Sep 17 00:00:00 2001 From: actuaryzhang Date: Thu, 25 May 2017 22:10:13 -0700 Subject: [PATCH 007/133] fix example --- R/pkg/R/functions.R | 2 ++ 1 file changed, 2 insertions(+) diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index 08815c0d53d80..a6adcf4e9e59a 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -4032,7 +4032,9 @@ setMethod("input_file_name", signature("missing"), #' @examples #' \dontrun{ #' trunc(df$c, "year") +#' trunc(df$c, "yy") #' trunc(df$c, "month") +#' trunc(df$c, "mon") #' } #' @note trunc since 2.3.0 setMethod("trunc", From 8ce0d8ffb68bd9e89c23d3a026308dcc039a1b1d Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 26 May 2017 13:45:55 +0800 Subject: [PATCH 008/133] [SPARK-20392][SQL] Set barrier to prevent re-entering a tree ## What changes were proposed in this pull request? It is reported that there is performance downgrade when applying ML pipeline for dataset with many columns but few rows. A big part of the performance downgrade comes from some operations (e.g., `select`) on DataFrame/Dataset which re-create new DataFrame/Dataset with a new `LogicalPlan`. The cost can be ignored in the usage of SQL, normally. However, it's not rare to chain dozens of pipeline stages in ML. When the query plan grows incrementally during running those stages, the total cost spent on re-creation of DataFrame grows too. In particular, the `Analyzer` will go through the big query plan even most part of it is analyzed. By eliminating part of the cost, the time to run the example code locally is reduced from about 1min to about 30 secs. In particular, the time applying the pipeline locally is mostly spent on calling transform of the 137 `Bucketizer`s. Before the change, each call of `Bucketizer`'s transform can cost about 0.4 sec. So the total time spent on all `Bucketizer`s' transform is about 50 secs. After the change, each call only costs about 0.1 sec. We also make `boundEnc` as lazy variable to reduce unnecessary running time. ### Performance improvement The codes and datasets provided by Barry Becker to re-produce this issue and benchmark can be found on the JIRA. Before this patch: about 1 min After this patch: about 20 secs ## How was this patch tested? Existing tests. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Liang-Chi Hsieh Closes #17770 from viirya/SPARK-20392. --- .../sql/catalyst/analysis/Analyzer.scala | 75 +++++++++------ .../catalyst/analysis/DecimalPrecision.scala | 2 +- .../ResolveTableValuedFunctions.scala | 2 +- .../sql/catalyst/analysis/TypeCoercion.scala | 22 ++--- .../catalyst/analysis/timeZoneAnalysis.scala | 2 +- .../spark/sql/catalyst/analysis/view.scala | 2 +- .../sql/catalyst/optimizer/subquery.scala | 2 +- .../catalyst/plans/logical/LogicalPlan.scala | 35 ------- .../plans/logical/basicLogicalOperators.scala | 9 ++ .../sql/catalyst/analysis/AnalysisSuite.scala | 14 +++ .../sql/catalyst/plans/LogicalPlanSuite.scala | 26 +++--- .../scala/org/apache/spark/sql/Dataset.scala | 92 ++++++++++--------- .../execution/datasources/DataSource.scala | 2 +- .../sql/execution/datasources/rules.scala | 2 +- .../spark/sql/execution/PlannerSuite.scala | 2 +- .../spark/sql/hive/HiveStrategies.scala | 6 +- 16 files changed, 151 insertions(+), 144 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index d130962c63918..85cf8ddbaacf4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -166,14 +166,15 @@ class Analyzer( Batch("Subquery", Once, UpdateOuterReferences), Batch("Cleanup", fixedPoint, - CleanupAliases) + CleanupAliases, + EliminateBarriers) ) /** * Analyze cte definitions and substitute child plan with analyzed cte definitions. */ object CTESubstitution extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { case With(child, relations) => substituteCTE(child, relations.foldLeft(Seq.empty[(String, LogicalPlan)]) { case (resolved, (name, relation)) => @@ -201,7 +202,7 @@ class Analyzer( * Substitute child plan with WindowSpecDefinitions. */ object WindowsSubstitution extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { // Lookup WindowSpecDefinitions. This rule works with unresolved children. case WithWindowDefinition(windowDefinitions, child) => child.transform { @@ -243,7 +244,7 @@ class Analyzer( private def hasUnresolvedAlias(exprs: Seq[NamedExpression]) = exprs.exists(_.find(_.isInstanceOf[UnresolvedAlias]).isDefined) - def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { case Aggregate(groups, aggs, child) if child.resolved && hasUnresolvedAlias(aggs) => Aggregate(groups, assignAliases(aggs), child) @@ -615,7 +616,7 @@ class Analyzer( case _ => plan } - def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { case i @ InsertIntoTable(u: UnresolvedRelation, parts, child, _, _) if child.resolved => EliminateSubqueryAliases(lookupTableFromCatalog(u)) match { case v: View => @@ -670,7 +671,9 @@ class Analyzer( * Generate a new logical plan for the right child with different expression IDs * for all conflicting attributes. */ - private def dedupRight (left: LogicalPlan, right: LogicalPlan): LogicalPlan = { + private def dedupRight (left: LogicalPlan, oriRight: LogicalPlan): LogicalPlan = { + // Remove analysis barrier if any. + val right = EliminateBarriers(oriRight) val conflictingAttributes = left.outputSet.intersect(right.outputSet) logDebug(s"Conflicting attributes ${conflictingAttributes.mkString(",")} " + s"between $left and $right") @@ -713,7 +716,7 @@ class Analyzer( * that this rule cannot handle. When that is the case, there must be another rule * that resolves these conflicts. Otherwise, the analysis will fail. */ - right + oriRight case Some((oldRelation, newRelation)) => val attributeRewrites = AttributeMap(oldRelation.output.zip(newRelation.output)) val newRight = right transformUp { @@ -726,7 +729,7 @@ class Analyzer( s.withNewPlan(dedupOuterReferencesInSubquery(s.plan, attributeRewrites)) } } - newRight + AnalysisBarrier(newRight) } } @@ -787,7 +790,7 @@ class Analyzer( } } - def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { case p: LogicalPlan if !p.childrenResolved => p // If the projection list contains Stars, expand it. @@ -961,7 +964,7 @@ class Analyzer( * have no effect on the results. */ object ResolveOrdinalInOrderByAndGroupBy extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { case p if !p.childrenResolved => p // Replace the index with the related attribute for ORDER BY, // which is a 1-base position of the projection list. @@ -1017,7 +1020,7 @@ class Analyzer( }} } - override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { + override def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { case agg @ Aggregate(groups, aggs, child) if conf.groupByAliases && child.resolved && aggs.forall(_.resolved) && groups.exists(!_.resolved) => @@ -1041,11 +1044,13 @@ class Analyzer( * The HAVING clause could also used a grouping columns that is not presented in the SELECT. */ object ResolveMissingReferences extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { // Skip sort with aggregate. This will be handled in ResolveAggregateFunctions + case sa @ Sort(_, _, AnalysisBarrier(child: Aggregate)) => sa case sa @ Sort(_, _, child: Aggregate) => sa - case s @ Sort(order, _, child) if !s.resolved && child.resolved => + case s @ Sort(order, _, orgChild) if !s.resolved && orgChild.resolved => + val child = EliminateBarriers(orgChild) try { val newOrder = order.map(resolveExpressionRecursively(_, child).asInstanceOf[SortOrder]) val requiredAttrs = AttributeSet(newOrder).filter(_.resolved) @@ -1066,7 +1071,8 @@ class Analyzer( case ae: AnalysisException => s } - case f @ Filter(cond, child) if !f.resolved && child.resolved => + case f @ Filter(cond, orgChild) if !f.resolved && orgChild.resolved => + val child = EliminateBarriers(orgChild) try { val newCond = resolveExpressionRecursively(cond, child) val requiredAttrs = newCond.references.filter(_.resolved) @@ -1093,7 +1099,7 @@ class Analyzer( */ private def addMissingAttr(plan: LogicalPlan, missingAttrs: AttributeSet): LogicalPlan = { if (missingAttrs.isEmpty) { - return plan + return AnalysisBarrier(plan) } plan match { case p: Project => @@ -1165,7 +1171,7 @@ class Analyzer( * Replaces [[UnresolvedFunction]]s with concrete [[Expression]]s. */ object ResolveFunctions extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { case q: LogicalPlan => q transformExpressions { case u if !u.childrenResolved => u // Skip until children are resolved. @@ -1504,7 +1510,7 @@ class Analyzer( /** * Resolve and rewrite all subqueries in an operator tree.. */ - def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { // In case of HAVING (a filter after an aggregate) we use both the aggregate and // its child for resolution. case f @ Filter(_, a: Aggregate) if f.childrenResolved => @@ -1519,7 +1525,7 @@ class Analyzer( * Turns projections that contain aggregate expressions into aggregations. */ object GlobalAggregates extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { case Project(projectList, child) if containsAggregates(projectList) => Aggregate(Nil, projectList, child) } @@ -1545,7 +1551,9 @@ class Analyzer( * underlying aggregate operator and then projected away after the original operator. */ object ResolveAggregateFunctions extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { + case filter @ Filter(havingCondition, AnalysisBarrier(aggregate: Aggregate)) => + apply(Filter(havingCondition, aggregate)).mapChildren(AnalysisBarrier) case filter @ Filter(havingCondition, aggregate @ Aggregate(grouping, originalAggExprs, child)) if aggregate.resolved => @@ -1605,6 +1613,8 @@ class Analyzer( case ae: AnalysisException => filter } + case sort @ Sort(sortOrder, global, AnalysisBarrier(aggregate: Aggregate)) => + apply(Sort(sortOrder, global, aggregate)).mapChildren(AnalysisBarrier) case sort @ Sort(sortOrder, global, aggregate: Aggregate) if aggregate.resolved => // Try resolving the ordering as though it is in the aggregate clause. @@ -1717,7 +1727,7 @@ class Analyzer( } } - def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { case Project(projectList, _) if projectList.exists(hasNestedGenerator) => val nestedGenerator = projectList.find(hasNestedGenerator).get throw new AnalysisException("Generators are not supported when it's nested in " + @@ -1775,7 +1785,7 @@ class Analyzer( * that wrap the [[Generator]]. */ object ResolveGenerate extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { case g: Generate if !g.child.resolved || !g.generator.resolved => g case g: Generate if !g.resolved => g.copy(generatorOutput = makeGeneratorOutput(g.generator, g.generatorOutput.map(_.name))) @@ -2092,7 +2102,7 @@ class Analyzer( * put them into an inner Project and finally project them away at the outer Project. */ object PullOutNondeterministic extends Rule[LogicalPlan] { - override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { + override def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { case p if !p.resolved => p // Skip unresolved nodes. case p: Project => p case f: Filter => f @@ -2137,7 +2147,7 @@ class Analyzer( * and we should return null if the input is null. */ object HandleNullInputsForUDF extends Rule[LogicalPlan] { - override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { + override def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { case p if !p.resolved => p // Skip unresolved nodes. case p => p transformExpressionsUp { @@ -2202,7 +2212,7 @@ class Analyzer( * Then apply a Project on a normal Join to eliminate natural or using join. */ object ResolveNaturalAndUsingJoin extends Rule[LogicalPlan] { - override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { + override def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { case j @ Join(left, right, UsingJoin(joinType, usingCols), condition) if left.resolved && right.resolved && j.duplicateResolved => commonNaturalJoinProcessing(left, right, joinType, usingCols, None) @@ -2267,7 +2277,7 @@ class Analyzer( * to the given input attributes. */ object ResolveDeserializer extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { case p if !p.childrenResolved => p case p if p.resolved => p @@ -2353,7 +2363,7 @@ class Analyzer( * constructed is an inner class. */ object ResolveNewInstance extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { case p if !p.childrenResolved => p case p if p.resolved => p @@ -2387,7 +2397,7 @@ class Analyzer( "type of the field in the target object") } - def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { case p if !p.childrenResolved => p case p if p.resolved => p @@ -2441,7 +2451,7 @@ object CleanupAliases extends Rule[LogicalPlan] { case other => trimAliases(other) } - override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { + override def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { case Project(projectList, child) => val cleanedProjectList = projectList.map(trimNonTopLevelAliases(_).asInstanceOf[NamedExpression]) @@ -2470,6 +2480,13 @@ object CleanupAliases extends Rule[LogicalPlan] { } } +/** Remove the barrier nodes of analysis */ +object EliminateBarriers extends Rule[LogicalPlan] { + override def apply(plan: LogicalPlan): LogicalPlan = plan transformDown { + case AnalysisBarrier(child) => child + } +} + /** * Ignore event time watermark in batch query, which is only supported in Structured Streaming. * TODO: add this rule into analyzer rule list. @@ -2519,7 +2536,7 @@ object TimeWindowing extends Rule[LogicalPlan] { * @return the logical plan that will generate the time windows using the Expand operator, with * the Filter operator for correctness and Project for usability. */ - def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { case p: LogicalPlan if p.children.size == 1 => val child = p.children.head val windowExpressions = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala index 9c38dd2ee4e53..ac72bc4ef4200 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala @@ -78,7 +78,7 @@ object DecimalPrecision extends Rule[LogicalPlan] { PromotePrecision(Cast(e, dataType)) } - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { // fix decimal precision for expressions case q => q.transformExpressions( decimalAndDecimal.orElse(integralAndDecimalLiteral).orElse(nondecimalAndDecimal)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala index dad1340571cc8..40675359bec47 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala @@ -103,7 +103,7 @@ object ResolveTableValuedFunctions extends Rule[LogicalPlan] { }) ) - override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + override def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { case u: UnresolvedTableValuedFunction if u.functionArgs.forall(_.resolved) => val resolvedFunc = builtinFunctions.get(u.functionName.toLowerCase(Locale.ROOT)) match { case Some(tvf) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index e1dd010d37a95..c3645170589c8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -206,7 +206,7 @@ object TypeCoercion { * instances higher in the query tree. */ object PropagateTypes extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { // No propagation required for leaf nodes. case q: LogicalPlan if q.children.isEmpty => q @@ -261,7 +261,7 @@ object TypeCoercion { */ object WidenSetOperationTypes extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { case p if p.analyzed => p case s @ SetOperation(left, right) if s.childrenResolved && @@ -335,7 +335,7 @@ object TypeCoercion { } } - def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { + def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e @@ -391,7 +391,7 @@ object TypeCoercion { } } - def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { + def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e @@ -449,7 +449,7 @@ object TypeCoercion { private val trueValues = Seq(1.toByte, 1.toShort, 1, 1L, Decimal.ONE) private val falseValues = Seq(0.toByte, 0.toShort, 0, 0L, Decimal.ZERO) - def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { + def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e @@ -490,7 +490,7 @@ object TypeCoercion { * This ensure that the types for various functions are as expected. */ object FunctionArgumentConversion extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { + def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e @@ -580,7 +580,7 @@ object TypeCoercion { * converted to fractional types. */ object Division extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { + def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { // Skip nodes who has not been resolved yet, // as this is an extra rule which should be applied at last. case e if !e.childrenResolved => e @@ -602,7 +602,7 @@ object TypeCoercion { * Coerces the type of different branches of a CASE WHEN statement to a common type. */ object CaseWhenCoercion extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { + def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { case c: CaseWhen if c.childrenResolved && !c.valueTypesEqual => val maybeCommonType = findWiderCommonType(c.valueTypes) maybeCommonType.map { commonType => @@ -632,7 +632,7 @@ object TypeCoercion { * Coerces the type of different branches of If statement to a common type. */ object IfCoercion extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { + def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { case e if !e.childrenResolved => e // Find tightest common type for If, if the true value and false value have different types. case i @ If(pred, left, right) if left.dataType != right.dataType => @@ -656,7 +656,7 @@ object TypeCoercion { private val acceptedTypes = Seq(DateType, TimestampType, StringType) - def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { + def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e @@ -673,7 +673,7 @@ object TypeCoercion { * Casts types according to the expected input types for [[Expression]]s. */ object ImplicitTypeCasts extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { + def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/timeZoneAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/timeZoneAnalysis.scala index a27aa845bf0ae..af1f9165b0044 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/timeZoneAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/timeZoneAnalysis.scala @@ -38,7 +38,7 @@ case class ResolveTimeZone(conf: SQLConf) extends Rule[LogicalPlan] { } override def apply(plan: LogicalPlan): LogicalPlan = - plan.resolveExpressions(transformTimeZoneExprs) + plan.transformAllExpressions(transformTimeZoneExprs) def resolveTimeZones(e: Expression): Expression = e.transform(transformTimeZoneExprs) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/view.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/view.scala index ea46dd7282401..3bbe41cf8f15e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/view.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/view.scala @@ -48,7 +48,7 @@ import org.apache.spark.sql.internal.SQLConf * completely resolved during the batch of Resolution. */ case class AliasViewChild(conf: SQLConf) extends Rule[LogicalPlan] with CastSupport { - override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + override def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { case v @ View(desc, output, child) if child.resolved && output != child.output => val resolver = conf.resolver val queryColumnNames = desc.viewQueryColumnNames diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala index 2a3e07aebe709..46d1aac1857d7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala @@ -236,7 +236,7 @@ object PullupCorrelatedPredicates extends Rule[LogicalPlan] with PredicateHelper /** * Pull up the correlated predicates and rewrite all subqueries in an operator tree.. */ - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { case f @ Filter(_, a: Aggregate) => rewriteSubQueries(f, Seq(a, a.child)) // Only a few unary nodes (Project/Filter/Aggregate) can contain subqueries. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index 2ebb2ff323c6b..23520eb82b043 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -46,41 +46,6 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { /** Returns true if this subtree contains any streaming data sources. */ def isStreaming: Boolean = children.exists(_.isStreaming == true) - /** - * Returns a copy of this node where `rule` has been recursively applied first to all of its - * children and then itself (post-order). When `rule` does not apply to a given node, it is left - * unchanged. This function is similar to `transformUp`, but skips sub-trees that have already - * been marked as analyzed. - * - * @param rule the function use to transform this nodes children - */ - def resolveOperators(rule: PartialFunction[LogicalPlan, LogicalPlan]): LogicalPlan = { - if (!analyzed) { - val afterRuleOnChildren = mapChildren(_.resolveOperators(rule)) - if (this fastEquals afterRuleOnChildren) { - CurrentOrigin.withOrigin(origin) { - rule.applyOrElse(this, identity[LogicalPlan]) - } - } else { - CurrentOrigin.withOrigin(origin) { - rule.applyOrElse(afterRuleOnChildren, identity[LogicalPlan]) - } - } - } else { - this - } - } - - /** - * Recursively transforms the expressions of a tree, skipping nodes that have already - * been analyzed. - */ - def resolveExpressions(r: PartialFunction[Expression, Expression]): LogicalPlan = { - this resolveOperators { - case p => p.transformExpressions(r) - } - } - /** A cache for the estimated statistics, such that it will only be computed once. */ private var statsCache: Option[Statistics] = None diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index 6878b6b179c3a..b8c2f7670d7b1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.statsEstimation._ +import org.apache.spark.sql.catalyst.trees.CurrentOrigin import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -896,3 +897,11 @@ case class Deduplicate( override def output: Seq[Attribute] = child.output } + +/** A logical plan for setting a barrier of analysis */ +case class AnalysisBarrier(child: LogicalPlan) extends LeafNode { + override def output: Seq[Attribute] = child.output + override def analyzed: Boolean = true + override def isStreaming: Boolean = child.isStreaming + override lazy val canonicalized: LogicalPlan = child.canonicalized +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index 0896caeab8d7a..3b4289767ad0c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -441,6 +441,20 @@ class AnalysisSuite extends AnalysisTest with ShouldMatchers { checkAnalysis(SubqueryAlias("tbl", testRelation).as("tbl2"), testRelation) } + test("analysis barrier") { + // [[AnalysisBarrier]] will be removed after analysis + checkAnalysis( + Project(Seq(UnresolvedAttribute("tbl.a")), + AnalysisBarrier(SubqueryAlias("tbl", testRelation))), + Project(testRelation.output, SubqueryAlias("tbl", testRelation))) + + // Verify we won't go through a plan wrapped in a barrier. + // Since we wrap an unresolved plan and analyzer won't go through it. It remains unresolved. + val barrier = AnalysisBarrier(Project(Seq(UnresolvedAttribute("tbl.b")), + SubqueryAlias("tbl", testRelation))) + assertAnalysisError(barrier, Seq("cannot resolve '`tbl.b`'")) + } + test("SPARK-20311 range(N) as alias") { def rangeWithAliases(args: Seq[Int], outputNames: Seq[String]): LogicalPlan = { SubqueryAlias("t", UnresolvedTableValuedFunction("range", args.map(Literal(_)), outputNames)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala index cc86f1f6e2f48..215db848383eb 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala @@ -23,8 +23,8 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.types.IntegerType /** - * This suite is used to test [[LogicalPlan]]'s `resolveOperators` and make sure it can correctly - * skips sub-trees that have already been marked as analyzed. + * This suite is used to test [[LogicalPlan]]'s `transformUp` plus analysis barrier and make sure + * it can correctly skip sub-trees that have already been marked as analyzed. */ class LogicalPlanSuite extends SparkFunSuite { private var invocationCount = 0 @@ -36,37 +36,35 @@ class LogicalPlanSuite extends SparkFunSuite { private val testRelation = LocalRelation() - test("resolveOperator runs on operators") { + test("transformUp runs on operators") { invocationCount = 0 val plan = Project(Nil, testRelation) - plan resolveOperators function + plan transformUp function assert(invocationCount === 1) } - test("resolveOperator runs on operators recursively") { + test("transformUp runs on operators recursively") { invocationCount = 0 val plan = Project(Nil, Project(Nil, testRelation)) - plan resolveOperators function + plan transformUp function assert(invocationCount === 2) } - test("resolveOperator skips all ready resolved plans") { + test("transformUp skips all ready resolved plans wrapped in analysis barrier") { invocationCount = 0 - val plan = Project(Nil, Project(Nil, testRelation)) - plan.foreach(_.setAnalyzed()) - plan resolveOperators function + val plan = AnalysisBarrier(Project(Nil, Project(Nil, testRelation))) + plan transformUp function assert(invocationCount === 0) } - test("resolveOperator skips partially resolved plans") { + test("transformUp skips partially resolved plans wrapped in analysis barrier") { invocationCount = 0 - val plan1 = Project(Nil, testRelation) + val plan1 = AnalysisBarrier(Project(Nil, testRelation)) val plan2 = Project(Nil, plan1) - plan1.foreach(_.setAnalyzed()) - plan2 resolveOperators function + plan2 transformUp function assert(invocationCount === 1) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index cbab029b87b2a..f9bd8f3d278ad 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -187,6 +187,9 @@ class Dataset[T] private[sql]( } } + // Wraps analyzed logical plans with an analysis barrier so we won't traverse/resolve it again. + @transient private val planWithBarrier = AnalysisBarrier(logicalPlan) + /** * Currently [[ExpressionEncoder]] is the only implementation of [[Encoder]], here we turn the * passed in encoder to [[ExpressionEncoder]] explicitly, and mark it implicit so that we can use @@ -413,7 +416,7 @@ class Dataset[T] private[sql]( */ @Experimental @InterfaceStability.Evolving - def as[U : Encoder]: Dataset[U] = Dataset[U](sparkSession, logicalPlan) + def as[U : Encoder]: Dataset[U] = Dataset[U](sparkSession, planWithBarrier) /** * Converts this strongly typed collection of data to generic `DataFrame` with columns renamed. @@ -616,7 +619,7 @@ class Dataset[T] private[sql]( require(parsedDelay.milliseconds >= 0 && parsedDelay.months >= 0, s"delay threshold ($delayThreshold) should not be negative.") EliminateEventTimeWatermark( - EventTimeWatermark(UnresolvedAttribute(eventTime), parsedDelay, logicalPlan)) + EventTimeWatermark(UnresolvedAttribute(eventTime), parsedDelay, planWithBarrier)) } /** @@ -790,7 +793,7 @@ class Dataset[T] private[sql]( * @since 2.0.0 */ def join(right: Dataset[_]): DataFrame = withPlan { - Join(logicalPlan, right.logicalPlan, joinType = Inner, None) + Join(planWithBarrier, right.planWithBarrier, joinType = Inner, None) } /** @@ -868,7 +871,7 @@ class Dataset[T] private[sql]( // Analyze the self join. The assumption is that the analyzer will disambiguate left vs right // by creating a new instance for one of the branch. val joined = sparkSession.sessionState.executePlan( - Join(logicalPlan, right.logicalPlan, joinType = JoinType(joinType), None)) + Join(planWithBarrier, right.planWithBarrier, joinType = JoinType(joinType), None)) .analyzed.asInstanceOf[Join] withPlan { @@ -929,7 +932,7 @@ class Dataset[T] private[sql]( // Trigger analysis so in the case of self-join, the analyzer will clone the plan. // After the cloning, left and right side will have distinct expression ids. val plan = withPlan( - Join(logicalPlan, right.logicalPlan, JoinType(joinType), Some(joinExprs.expr))) + Join(planWithBarrier, right.planWithBarrier, JoinType(joinType), Some(joinExprs.expr))) .queryExecution.analyzed.asInstanceOf[Join] // If auto self join alias is disabled, return the plan. @@ -938,8 +941,8 @@ class Dataset[T] private[sql]( } // If left/right have no output set intersection, return the plan. - val lanalyzed = withPlan(this.logicalPlan).queryExecution.analyzed - val ranalyzed = withPlan(right.logicalPlan).queryExecution.analyzed + val lanalyzed = withPlan(this.planWithBarrier).queryExecution.analyzed + val ranalyzed = withPlan(right.planWithBarrier).queryExecution.analyzed if (lanalyzed.outputSet.intersect(ranalyzed.outputSet).isEmpty) { return withPlan(plan) } @@ -971,7 +974,7 @@ class Dataset[T] private[sql]( * @since 2.1.0 */ def crossJoin(right: Dataset[_]): DataFrame = withPlan { - Join(logicalPlan, right.logicalPlan, joinType = Cross, None) + Join(planWithBarrier, right.planWithBarrier, joinType = Cross, None) } /** @@ -1003,8 +1006,8 @@ class Dataset[T] private[sql]( // etc. val joined = sparkSession.sessionState.executePlan( Join( - this.logicalPlan, - other.logicalPlan, + this.planWithBarrier, + other.planWithBarrier, JoinType(joinType), Some(condition.expr))).analyzed.asInstanceOf[Join] @@ -1174,7 +1177,7 @@ class Dataset[T] private[sql]( */ @scala.annotation.varargs def hint(name: String, parameters: String*): Dataset[T] = withTypedPlan { - UnresolvedHint(name, parameters, logicalPlan) + UnresolvedHint(name, parameters, planWithBarrier) } /** @@ -1200,7 +1203,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ def as(alias: String): Dataset[T] = withTypedPlan { - SubqueryAlias(alias, logicalPlan) + SubqueryAlias(alias, planWithBarrier) } /** @@ -1238,7 +1241,7 @@ class Dataset[T] private[sql]( */ @scala.annotation.varargs def select(cols: Column*): DataFrame = withPlan { - Project(cols.map(_.named), logicalPlan) + Project(cols.map(_.named), planWithBarrier) } /** @@ -1293,8 +1296,8 @@ class Dataset[T] private[sql]( @InterfaceStability.Evolving def select[U1](c1: TypedColumn[T, U1]): Dataset[U1] = { implicit val encoder = c1.encoder - val project = Project(c1.withInputType(exprEnc, logicalPlan.output).named :: Nil, - logicalPlan) + val project = Project(c1.withInputType(exprEnc, planWithBarrier.output).named :: Nil, + planWithBarrier) if (encoder.flat) { new Dataset[U1](sparkSession, project, encoder) @@ -1312,8 +1315,8 @@ class Dataset[T] private[sql]( protected def selectUntyped(columns: TypedColumn[_, _]*): Dataset[_] = { val encoders = columns.map(_.encoder) val namedColumns = - columns.map(_.withInputType(exprEnc, logicalPlan.output).named) - val execution = new QueryExecution(sparkSession, Project(namedColumns, logicalPlan)) + columns.map(_.withInputType(exprEnc, planWithBarrier.output).named) + val execution = new QueryExecution(sparkSession, Project(namedColumns, planWithBarrier)) new Dataset(sparkSession, execution, ExpressionEncoder.tuple(encoders)) } @@ -1389,7 +1392,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ def filter(condition: Column): Dataset[T] = withTypedPlan { - Filter(condition.expr, logicalPlan) + Filter(condition.expr, planWithBarrier) } /** @@ -1566,7 +1569,7 @@ class Dataset[T] private[sql]( @Experimental @InterfaceStability.Evolving def groupByKey[K: Encoder](func: T => K): KeyValueGroupedDataset[K, T] = { - val inputPlan = logicalPlan + val inputPlan = planWithBarrier val withGroupingKey = AppendColumns(func, inputPlan) val executed = sparkSession.sessionState.executePlan(withGroupingKey) @@ -1712,7 +1715,7 @@ class Dataset[T] private[sql]( * @since 2.0.0 */ def limit(n: Int): Dataset[T] = withTypedPlan { - Limit(Literal(n), logicalPlan) + Limit(Literal(n), planWithBarrier) } /** @@ -1741,7 +1744,7 @@ class Dataset[T] private[sql]( def union(other: Dataset[T]): Dataset[T] = withSetOperator { // This breaks caching, but it's usually ok because it addresses a very specific use case: // using union to union many files or partitions. - CombineUnions(Union(logicalPlan, other.logicalPlan)) + CombineUnions(Union(logicalPlan, other.logicalPlan)).mapChildren(AnalysisBarrier) } /** @@ -1755,7 +1758,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ def intersect(other: Dataset[T]): Dataset[T] = withSetOperator { - Intersect(logicalPlan, other.logicalPlan) + Intersect(planWithBarrier, other.planWithBarrier) } /** @@ -1769,7 +1772,7 @@ class Dataset[T] private[sql]( * @since 2.0.0 */ def except(other: Dataset[T]): Dataset[T] = withSetOperator { - Except(logicalPlan, other.logicalPlan) + Except(planWithBarrier, other.planWithBarrier) } /** @@ -1790,7 +1793,7 @@ class Dataset[T] private[sql]( s"Fraction must be nonnegative, but got ${fraction}") withTypedPlan { - Sample(0.0, fraction, withReplacement, seed, logicalPlan)() + Sample(0.0, fraction, withReplacement, seed, planWithBarrier)() } } @@ -1832,15 +1835,15 @@ class Dataset[T] private[sql]( // overlapping splits. To prevent this, we explicitly sort each input partition to make the // ordering deterministic. Note that MapTypes cannot be sorted and are explicitly pruned out // from the sort order. - val sortOrder = logicalPlan.output + val sortOrder = planWithBarrier.output .filter(attr => RowOrdering.isOrderable(attr.dataType)) .map(SortOrder(_, Ascending)) val plan = if (sortOrder.nonEmpty) { - Sort(sortOrder, global = false, logicalPlan) + Sort(sortOrder, global = false, planWithBarrier) } else { // SPARK-12662: If sort order is empty, we materialize the dataset to guarantee determinism cache() - logicalPlan + planWithBarrier } val sum = weights.sum val normalizedCumWeights = weights.map(_ / sum).scanLeft(0.0d)(_ + _) @@ -1924,7 +1927,7 @@ class Dataset[T] private[sql]( withPlan { Generate(generator, join = true, outer = false, - qualifier = None, generatorOutput = Nil, logicalPlan) + qualifier = None, generatorOutput = Nil, planWithBarrier) } } @@ -1965,7 +1968,7 @@ class Dataset[T] private[sql]( withPlan { Generate(generator, join = true, outer = false, - qualifier = None, generatorOutput = Nil, logicalPlan) + qualifier = None, generatorOutput = Nil, planWithBarrier) } } @@ -2080,7 +2083,7 @@ class Dataset[T] private[sql]( u.name, sparkSession.sessionState.analyzer.resolver).getOrElse(u) case Column(expr: Expression) => expr } - val attrs = this.logicalPlan.output + val attrs = this.planWithBarrier.output val colsAfterDrop = attrs.filter { attr => attr != expression }.map(attr => Column(attr)) @@ -2128,7 +2131,7 @@ class Dataset[T] private[sql]( } cols } - Deduplicate(groupCols, logicalPlan, isStreaming) + Deduplicate(groupCols, planWithBarrier, isStreaming) } /** @@ -2277,7 +2280,7 @@ class Dataset[T] private[sql]( @Experimental @InterfaceStability.Evolving def filter(func: T => Boolean): Dataset[T] = { - withTypedPlan(TypedFilter(func, logicalPlan)) + withTypedPlan(TypedFilter(func, planWithBarrier)) } /** @@ -2291,7 +2294,7 @@ class Dataset[T] private[sql]( @Experimental @InterfaceStability.Evolving def filter(func: FilterFunction[T]): Dataset[T] = { - withTypedPlan(TypedFilter(func, logicalPlan)) + withTypedPlan(TypedFilter(func, planWithBarrier)) } /** @@ -2305,7 +2308,7 @@ class Dataset[T] private[sql]( @Experimental @InterfaceStability.Evolving def map[U : Encoder](func: T => U): Dataset[U] = withTypedPlan { - MapElements[T, U](func, logicalPlan) + MapElements[T, U](func, planWithBarrier) } /** @@ -2320,7 +2323,7 @@ class Dataset[T] private[sql]( @InterfaceStability.Evolving def map[U](func: MapFunction[T, U], encoder: Encoder[U]): Dataset[U] = { implicit val uEnc = encoder - withTypedPlan(MapElements[T, U](func, logicalPlan)) + withTypedPlan(MapElements[T, U](func, planWithBarrier)) } /** @@ -2336,7 +2339,7 @@ class Dataset[T] private[sql]( def mapPartitions[U : Encoder](func: Iterator[T] => Iterator[U]): Dataset[U] = { new Dataset[U]( sparkSession, - MapPartitions[T, U](func, logicalPlan), + MapPartitions[T, U](func, planWithBarrier), implicitly[Encoder[U]]) } @@ -2367,7 +2370,7 @@ class Dataset[T] private[sql]( val rowEncoder = encoder.asInstanceOf[ExpressionEncoder[Row]] Dataset.ofRows( sparkSession, - MapPartitionsInR(func, packageNames, broadcastVars, schema, rowEncoder, logicalPlan)) + MapPartitionsInR(func, packageNames, broadcastVars, schema, rowEncoder, planWithBarrier)) } /** @@ -2522,7 +2525,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ def repartition(numPartitions: Int): Dataset[T] = withTypedPlan { - Repartition(numPartitions, shuffle = true, logicalPlan) + Repartition(numPartitions, shuffle = true, planWithBarrier) } /** @@ -2536,7 +2539,7 @@ class Dataset[T] private[sql]( */ @scala.annotation.varargs def repartition(numPartitions: Int, partitionExprs: Column*): Dataset[T] = withTypedPlan { - RepartitionByExpression(partitionExprs.map(_.expr), logicalPlan, numPartitions) + RepartitionByExpression(partitionExprs.map(_.expr), planWithBarrier, numPartitions) } /** @@ -2552,7 +2555,8 @@ class Dataset[T] private[sql]( @scala.annotation.varargs def repartition(partitionExprs: Column*): Dataset[T] = withTypedPlan { RepartitionByExpression( - partitionExprs.map(_.expr), logicalPlan, sparkSession.sessionState.conf.numShufflePartitions) + partitionExprs.map(_.expr), planWithBarrier, + sparkSession.sessionState.conf.numShufflePartitions) } /** @@ -2573,7 +2577,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ def coalesce(numPartitions: Int): Dataset[T] = withTypedPlan { - Repartition(numPartitions, shuffle = false, logicalPlan) + Repartition(numPartitions, shuffle = false, planWithBarrier) } /** @@ -2662,7 +2666,7 @@ class Dataset[T] private[sql]( */ lazy val rdd: RDD[T] = { val objectType = exprEnc.deserializer.dataType - val deserialized = CatalystSerde.deserialize[T](logicalPlan) + val deserialized = CatalystSerde.deserialize[T](planWithBarrier) sparkSession.sessionState.executePlan(deserialized).toRdd.mapPartitions { rows => rows.map(_.get(0, objectType).asInstanceOf[T]) } @@ -2761,7 +2765,7 @@ class Dataset[T] private[sql]( comment = None, properties = Map.empty, originalText = None, - child = logicalPlan, + child = planWithBarrier, allowExisting = false, replace = replace, viewType = viewType) @@ -2932,7 +2936,7 @@ class Dataset[T] private[sql]( } } withTypedPlan { - Sort(sortOrder, global = global, logicalPlan) + Sort(sortOrder, global = global, planWithBarrier) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index 14c40605ea31c..9fce29b06b9d8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -424,7 +424,7 @@ case class DataSource( }.head } // For partitioned relation r, r.schema's column ordering can be different from the column - // ordering of data.logicalPlan (partition columns are all moved after data column). This + // ordering of data.logicalPlan (partition columns are all moved after data column). This // will be adjusted within InsertIntoHadoopFsRelation. val plan = InsertIntoHadoopFsRelationCommand( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala index 3f4a78580f1eb..5f65898f5312e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala @@ -38,7 +38,7 @@ class ResolveSQLOnFile(sparkSession: SparkSession) extends Rule[LogicalPlan] { sparkSession.sessionState.conf.runSQLonFile && u.tableIdentifier.database.isDefined } - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { case u: UnresolvedRelation if maybeSQLFile(u) => try { val dataSource = DataSource( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index 4d155d538d637..d02c8ffe33f0f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -241,7 +241,7 @@ class PlannerSuite extends SharedSQLContext { test("collapse adjacent repartitions") { val doubleRepartitioned = testData.repartition(10).repartition(20).coalesce(5) def countRepartitions(plan: LogicalPlan): Int = plan.collect { case r: Repartition => r }.length - assert(countRepartitions(doubleRepartitioned.queryExecution.logical) === 3) + assert(countRepartitions(doubleRepartitioned.queryExecution.analyzed) === 3) assert(countRepartitions(doubleRepartitioned.queryExecution.optimizedPlan) === 2) doubleRepartitioned.queryExecution.optimizedPlan match { case Repartition (numPartitions, shuffle, Repartition(_, shuffleChild, _)) => diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index 9c60d22d35ce1..662fc80661513 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -88,7 +88,7 @@ class ResolveHiveSerdeTable(session: SparkSession) extends Rule[LogicalPlan] { } } - override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + override def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { case c @ CreateTable(t, _, query) if DDLUtils.isHiveTable(t) => // Finds the database name if the name does not exist. val dbName = t.identifier.database.getOrElse(session.catalog.currentDatabase) @@ -115,7 +115,7 @@ class ResolveHiveSerdeTable(session: SparkSession) extends Rule[LogicalPlan] { } class DetermineTableStats(session: SparkSession) extends Rule[LogicalPlan] { - override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + override def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { case relation: CatalogRelation if DDLUtils.isHiveTable(relation.tableMeta) && relation.tableMeta.stats.isEmpty => val table = relation.tableMeta @@ -146,7 +146,7 @@ class DetermineTableStats(session: SparkSession) extends Rule[LogicalPlan] { * `PreprocessTableInsertion`. */ object HiveAnalysis extends Rule[LogicalPlan] { - override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + override def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { case InsertIntoTable(r: CatalogRelation, partSpec, query, overwrite, ifPartitionNotExists) if DDLUtils.isHiveTable(r.tableMeta) => InsertIntoHiveTable(r.tableMeta, partSpec, query, overwrite, ifPartitionNotExists) From a97c497045e9102b8eefcd0a0567ee08e61c838c Mon Sep 17 00:00:00 2001 From: Zheng RuiFeng Date: Thu, 25 May 2017 23:00:50 -0700 Subject: [PATCH 009/133] [SPARK-20849][DOC][SPARKR] Document R DecisionTree ## What changes were proposed in this pull request? 1, add an example for sparkr `decisionTree` 2, document it in user guide ## How was this patch tested? local submit Author: Zheng RuiFeng Closes #18067 from zhengruifeng/dt_example. --- R/pkg/vignettes/sparkr-vignettes.Rmd | 50 ++++++++++++++------- docs/ml-classification-regression.md | 7 +++ docs/sparkr.md | 1 + examples/src/main/r/ml/decisionTree.R | 65 +++++++++++++++++++++++++++ 4 files changed, 108 insertions(+), 15 deletions(-) create mode 100644 examples/src/main/r/ml/decisionTree.R diff --git a/R/pkg/vignettes/sparkr-vignettes.Rmd b/R/pkg/vignettes/sparkr-vignettes.Rmd index 13a399165c8b4..2301a64576d0e 100644 --- a/R/pkg/vignettes/sparkr-vignettes.Rmd +++ b/R/pkg/vignettes/sparkr-vignettes.Rmd @@ -503,6 +503,8 @@ SparkR supports the following machine learning models and algorithms. #### Tree - Classification and Regression +* Decision Tree + * Gradient-Boosted Trees (GBT) * Random Forest @@ -776,16 +778,32 @@ newDF <- createDataFrame(data.frame(x = c(1.5, 3.2))) head(predict(isoregModel, newDF)) ``` +#### Decision Tree + +`spark.decisionTree` fits a [decision tree](https://en.wikipedia.org/wiki/Decision_tree_learning) classification or regression model on a `SparkDataFrame`. +Users can call `summary` to get a summary of the fitted model, `predict` to make predictions, and `write.ml`/`read.ml` to save/load fitted models. + +We use the `Titanic` dataset to train a decision tree and make predictions: + +```{r} +t <- as.data.frame(Titanic) +df <- createDataFrame(t) +dtModel <- spark.decisionTree(df, Survived ~ ., type = "classification", maxDepth = 2) +summary(dtModel) +predictions <- predict(dtModel, df) +``` + #### Gradient-Boosted Trees `spark.gbt` fits a [gradient-boosted tree](https://en.wikipedia.org/wiki/Gradient_boosting) classification or regression model on a `SparkDataFrame`. Users can call `summary` to get a summary of the fitted model, `predict` to make predictions, and `write.ml`/`read.ml` to save/load fitted models. -We use the `longley` dataset to train a gradient-boosted tree and make predictions: +We use the `Titanic` dataset to train a gradient-boosted tree and make predictions: -```{r, warning=FALSE} -df <- createDataFrame(longley) -gbtModel <- spark.gbt(df, Employed ~ ., type = "regression", maxDepth = 2, maxIter = 2) +```{r} +t <- as.data.frame(Titanic) +df <- createDataFrame(t) +gbtModel <- spark.gbt(df, Survived ~ ., type = "classification", maxDepth = 2, maxIter = 2) summary(gbtModel) predictions <- predict(gbtModel, df) ``` @@ -795,11 +813,12 @@ predictions <- predict(gbtModel, df) `spark.randomForest` fits a [random forest](https://en.wikipedia.org/wiki/Random_forest) classification or regression model on a `SparkDataFrame`. Users can call `summary` to get a summary of the fitted model, `predict` to make predictions, and `write.ml`/`read.ml` to save/load fitted models. -In the following example, we use the `longley` dataset to train a random forest and make predictions: +In the following example, we use the `Titanic` dataset to train a random forest and make predictions: -```{r, warning=FALSE} -df <- createDataFrame(longley) -rfModel <- spark.randomForest(df, Employed ~ ., type = "regression", maxDepth = 2, numTrees = 2) +```{r} +t <- as.data.frame(Titanic) +df <- createDataFrame(t) +rfModel <- spark.randomForest(df, Survived ~ ., type = "classification", maxDepth = 2, numTrees = 2) summary(rfModel) predictions <- predict(rfModel, df) ``` @@ -965,17 +984,18 @@ Given a `SparkDataFrame`, the test compares continuous data in a given column `t specified by parameter `nullHypothesis`. Users can call `summary` to get a summary of the test results. -In the following example, we test whether the `longley` dataset's `Armed_Forces` column +In the following example, we test whether the `Titanic` dataset's `Freq` column follows a normal distribution. We set the parameters of the normal distribution using the mean and standard deviation of the sample. -```{r, warning=FALSE} -df <- createDataFrame(longley) -afStats <- head(select(df, mean(df$Armed_Forces), sd(df$Armed_Forces))) -afMean <- afStats[1] -afStd <- afStats[2] +```{r} +t <- as.data.frame(Titanic) +df <- createDataFrame(t) +freqStats <- head(select(df, mean(df$Freq), sd(df$Freq))) +freqMean <- freqStats[1] +freqStd <- freqStats[2] -test <- spark.kstest(df, "Armed_Forces", "norm", c(afMean, afStd)) +test <- spark.kstest(df, "Freq", "norm", c(freqMean, freqStd)) testSummary <- summary(test) testSummary ``` diff --git a/docs/ml-classification-regression.md b/docs/ml-classification-regression.md index ab6f587e09ef2..083df2e405d62 100644 --- a/docs/ml-classification-regression.md +++ b/docs/ml-classification-regression.md @@ -708,6 +708,13 @@ More details on parameters can be found in the [Python API documentation](api/py {% include_example python/ml/decision_tree_regression_example.py %} +
+ +Refer to the [R API docs](api/R/spark.decisionTree.html) for more details. + +{% include_example regression r/ml/decisionTree.R %} +
+ diff --git a/docs/sparkr.md b/docs/sparkr.md index 569b85e72c3cf..a3254e7654134 100644 --- a/docs/sparkr.md +++ b/docs/sparkr.md @@ -492,6 +492,7 @@ SparkR supports the following machine learning algorithms currently: #### Tree +* [`spark.decisionTree`](api/R/spark.decisionTree.html): `Decision Tree for` [`Regression`](ml-classification-regression.html#decision-tree-regression) `and` [`Classification`](ml-classification-regression.html#decision-tree-classifier) * [`spark.gbt`](api/R/spark.gbt.html): `Gradient Boosted Trees for` [`Regression`](ml-classification-regression.html#gradient-boosted-tree-regression) `and` [`Classification`](ml-classification-regression.html#gradient-boosted-tree-classifier) * [`spark.randomForest`](api/R/spark.randomForest.html): `Random Forest for` [`Regression`](ml-classification-regression.html#random-forest-regression) `and` [`Classification`](ml-classification-regression.html#random-forest-classifier) diff --git a/examples/src/main/r/ml/decisionTree.R b/examples/src/main/r/ml/decisionTree.R new file mode 100644 index 0000000000000..9e10ae5519cd3 --- /dev/null +++ b/examples/src/main/r/ml/decisionTree.R @@ -0,0 +1,65 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# To run this example use +# ./bin/spark-submit examples/src/main/r/ml/decisionTree.R + +# Load SparkR library into your R session +library(SparkR) + +# Initialize SparkSession +sparkR.session(appName = "SparkR-ML-decisionTree-example") + +# DecisionTree classification model + +# $example on:classification$ +# Load training data +df <- read.df("data/mllib/sample_libsvm_data.txt", source = "libsvm") +training <- df +test <- df + +# Fit a DecisionTree classification model with spark.decisionTree +model <- spark.decisionTree(training, label ~ features, "classification") + +# Model summary +summary(model) + +# Prediction +predictions <- predict(model, test) +head(predictions) +# $example off:classification$ + +# DecisionTree regression model + +# $example on:regression$ +# Load training data +df <- read.df("data/mllib/sample_linear_regression_data.txt", source = "libsvm") +training <- df +test <- df + +# Fit a DecisionTree regression model with spark.decisionTree +model <- spark.decisionTree(training, label ~ features, "regression") + +# Model summary +summary(model) + +# Prediction +predictions <- predict(model, test) +head(predictions) +# $example off:regression$ + +sparkR.session.stop() From d9ad78908f6189719cec69d34557f1a750d2e6af Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 26 May 2017 15:01:28 +0800 Subject: [PATCH 010/133] [SPARK-20868][CORE] UnsafeShuffleWriter should verify the position after FileChannel.transferTo ## What changes were proposed in this pull request? Long time ago we fixed a [bug](https://issues.apache.org/jira/browse/SPARK-3948) in shuffle writer about `FileChannel.transferTo`. We were not very confident about that fix, so we added a position check after the writing, try to discover the bug earlier. However this checking is missing in the new `UnsafeShuffleWriter`, this PR adds it. https://issues.apache.org/jira/browse/SPARK-18105 maybe related to that `FileChannel.transferTo` bug, hopefully we can find out the root cause after adding this position check. ## How was this patch tested? N/A Author: Wenchen Fan Closes #18091 from cloud-fan/shuffle. --- .../shuffle/sort/UnsafeShuffleWriter.java | 15 ++-- .../scala/org/apache/spark/util/Utils.scala | 71 +++++++++++-------- 2 files changed, 47 insertions(+), 39 deletions(-) diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java index 8a1771848dee6..2fde5c300f072 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java @@ -422,17 +422,14 @@ private long[] mergeSpillsWithTransferTo(SpillInfo[] spills, File outputFile) th for (int partition = 0; partition < numPartitions; partition++) { for (int i = 0; i < spills.length; i++) { final long partitionLengthInSpill = spills[i].partitionLengths[partition]; - long bytesToTransfer = partitionLengthInSpill; final FileChannel spillInputChannel = spillInputChannels[i]; final long writeStartTime = System.nanoTime(); - while (bytesToTransfer > 0) { - final long actualBytesTransferred = spillInputChannel.transferTo( - spillInputChannelPositions[i], - bytesToTransfer, - mergedFileOutputChannel); - spillInputChannelPositions[i] += actualBytesTransferred; - bytesToTransfer -= actualBytesTransferred; - } + Utils.copyFileStreamNIO( + spillInputChannel, + mergedFileOutputChannel, + spillInputChannelPositions[i], + partitionLengthInSpill); + spillInputChannelPositions[i] += partitionLengthInSpill; writeMetrics.incWriteTime(System.nanoTime() - writeStartTime); bytesWrittenToMergedFile += partitionLengthInSpill; partitionLengths[partition] += partitionLengthInSpill; diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index edfe229792323..ad39c74a0e232 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -22,7 +22,7 @@ import java.lang.management.{LockInfo, ManagementFactory, MonitorInfo, ThreadInf import java.math.{MathContext, RoundingMode} import java.net._ import java.nio.ByteBuffer -import java.nio.channels.Channels +import java.nio.channels.{Channels, FileChannel} import java.nio.charset.StandardCharsets import java.nio.file.{Files, Paths} import java.util.{Locale, Properties, Random, UUID} @@ -60,7 +60,6 @@ import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ import org.apache.spark.network.util.JavaUtils import org.apache.spark.serializer.{DeserializationStream, SerializationStream, SerializerInstance} -import org.apache.spark.util.logging.RollingFileAppender /** CallSite represents a place in user code. It can have a short and a long form. */ private[spark] case class CallSite(shortForm: String, longForm: String) @@ -319,41 +318,22 @@ private[spark] object Utils extends Logging { * copying is disabled by default unless explicitly set transferToEnabled as true, * the parameter transferToEnabled should be configured by spark.file.transferTo = [true|false]. */ - def copyStream(in: InputStream, - out: OutputStream, - closeStreams: Boolean = false, - transferToEnabled: Boolean = false): Long = - { - var count = 0L + def copyStream( + in: InputStream, + out: OutputStream, + closeStreams: Boolean = false, + transferToEnabled: Boolean = false): Long = { tryWithSafeFinally { if (in.isInstanceOf[FileInputStream] && out.isInstanceOf[FileOutputStream] && transferToEnabled) { // When both streams are File stream, use transferTo to improve copy performance. val inChannel = in.asInstanceOf[FileInputStream].getChannel() val outChannel = out.asInstanceOf[FileOutputStream].getChannel() - val initialPos = outChannel.position() val size = inChannel.size() - - // In case transferTo method transferred less data than we have required. - while (count < size) { - count += inChannel.transferTo(count, size - count, outChannel) - } - - // Check the position after transferTo loop to see if it is in the right position and - // give user information if not. - // Position will not be increased to the expected length after calling transferTo in - // kernel version 2.6.32, this issue can be seen in - // https://bugs.openjdk.java.net/browse/JDK-7052359 - // This will lead to stream corruption issue when using sort-based shuffle (SPARK-3948). - val finalPos = outChannel.position() - assert(finalPos == initialPos + size, - s""" - |Current position $finalPos do not equal to expected position ${initialPos + size} - |after transferTo, please check your kernel version to see if it is 2.6.32, - |this is a kernel bug which will lead to unexpected behavior when using transferTo. - |You can set spark.file.transferTo = false to disable this NIO feature. - """.stripMargin) + copyFileStreamNIO(inChannel, outChannel, 0, size) + size } else { + var count = 0L val buf = new Array[Byte](8192) var n = 0 while (n != -1) { @@ -363,8 +343,8 @@ private[spark] object Utils extends Logging { count += n } } + count } - count } { if (closeStreams) { try { @@ -376,6 +356,37 @@ private[spark] object Utils extends Logging { } } + def copyFileStreamNIO( + input: FileChannel, + output: FileChannel, + startPosition: Long, + bytesToCopy: Long): Unit = { + val initialPos = output.position() + var count = 0L + // In case transferTo method transferred less data than we have required. + while (count < bytesToCopy) { + count += input.transferTo(count + startPosition, bytesToCopy - count, output) + } + assert(count == bytesToCopy, + s"request to copy $bytesToCopy bytes, but actually copied $count bytes.") + + // Check the position after transferTo loop to see if it is in the right position and + // give user information if not. + // Position will not be increased to the expected length after calling transferTo in + // kernel version 2.6.32, this issue can be seen in + // https://bugs.openjdk.java.net/browse/JDK-7052359 + // This will lead to stream corruption issue when using sort-based shuffle (SPARK-3948). + val finalPos = output.position() + val expectedPos = initialPos + bytesToCopy + assert(finalPos == expectedPos, + s""" + |Current position $finalPos do not equal to expected position $expectedPos + |after transferTo, please check your kernel version to see if it is 2.6.32, + |this is a kernel bug which will lead to unexpected behavior when using transferTo. + |You can set spark.file.transferTo = false to disable this NIO feature. + """.stripMargin) + } + /** * Construct a URI container information used for authentication. * This also sets the default authenticator to properly negotiation the From b6f2017a6a5da5ce5aea85934b9df6bc6dcb32e1 Mon Sep 17 00:00:00 2001 From: Wil Selwood Date: Fri, 26 May 2017 11:29:52 +0100 Subject: [PATCH 011/133] [MINOR] document edge case of updateFunc usage ## What changes were proposed in this pull request? Include documentation of the fact that the updateFunc is sometimes called with no new values. This is documented in the main documentation here: https://spark.apache.org/docs/latest/streaming-programming-guide.html#updatestatebykey-operation however from the docs included with the code it is not clear that this is the case. ## How was this patch tested? PR only changes comments. Confirmed code still builds. Author: Wil Selwood Closes #18088 from wselwood/note-edge-case-in-docs. --- .../spark/streaming/dstream/PairDStreamFunctions.scala | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala index f38c1e7996595..dcb51d72fa588 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala @@ -389,6 +389,7 @@ class PairDStreamFunctions[K, V](self: DStream[(K, V)]) /** * Return a new "state" DStream where the state for each key is updated by applying * the given function on the previous state of the key and the new values of each key. + * In every batch the updateFunc will be called for each state even if there are no new values. * Hash partitioning is used to generate the RDDs with Spark's default number of partitions. * @param updateFunc State update function. If `this` function returns None, then * corresponding state key-value pair will be eliminated. @@ -403,6 +404,7 @@ class PairDStreamFunctions[K, V](self: DStream[(K, V)]) /** * Return a new "state" DStream where the state for each key is updated by applying * the given function on the previous state of the key and the new values of each key. + * In every batch the updateFunc will be called for each state even if there are no new values. * Hash partitioning is used to generate the RDDs with `numPartitions` partitions. * @param updateFunc State update function. If `this` function returns None, then * corresponding state key-value pair will be eliminated. @@ -419,6 +421,7 @@ class PairDStreamFunctions[K, V](self: DStream[(K, V)]) /** * Return a new "state" DStream where the state for each key is updated by applying * the given function on the previous state of the key and the new values of the key. + * In every batch the updateFunc will be called for each state even if there are no new values. * [[org.apache.spark.Partitioner]] is used to control the partitioning of each RDD. * @param updateFunc State update function. If `this` function returns None, then * corresponding state key-value pair will be eliminated. @@ -440,6 +443,7 @@ class PairDStreamFunctions[K, V](self: DStream[(K, V)]) /** * Return a new "state" DStream where the state for each key is updated by applying * the given function on the previous state of the key and the new values of each key. + * In every batch the updateFunc will be called for each state even if there are no new values. * [[org.apache.spark.Partitioner]] is used to control the partitioning of each RDD. * @param updateFunc State update function. Note, that this function may generate a different * tuple with a different key than the input key. Therefore keys may be removed @@ -464,6 +468,7 @@ class PairDStreamFunctions[K, V](self: DStream[(K, V)]) /** * Return a new "state" DStream where the state for each key is updated by applying * the given function on the previous state of the key and the new values of the key. + * In every batch the updateFunc will be called for each state even if there are no new values. * org.apache.spark.Partitioner is used to control the partitioning of each RDD. * @param updateFunc State update function. If `this` function returns None, then * corresponding state key-value pair will be eliminated. @@ -487,6 +492,7 @@ class PairDStreamFunctions[K, V](self: DStream[(K, V)]) /** * Return a new "state" DStream where the state for each key is updated by applying * the given function on the previous state of the key and the new values of each key. + * In every batch the updateFunc will be called for each state even if there are no new values. * org.apache.spark.Partitioner is used to control the partitioning of each RDD. * @param updateFunc State update function. Note, that this function may generate a different * tuple with a different key than the input key. Therefore keys may be removed @@ -513,6 +519,7 @@ class PairDStreamFunctions[K, V](self: DStream[(K, V)]) /** * Return a new "state" DStream where the state for each key is updated by applying * the given function on the previous state of the key and the new values of the key. + * In every batch the updateFunc will be called for each state even if there are no new values. * org.apache.spark.Partitioner is used to control the partitioning of each RDD. * @param updateFunc State update function. If `this` function returns None, then * corresponding state key-value pair will be eliminated. From 629f38e171409da614fd635bd8dd951b7fde17a4 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 26 May 2017 21:13:38 +0800 Subject: [PATCH 012/133] [SPARK-20887][CORE] support alternative keys in ConfigBuilder ## What changes were proposed in this pull request? `ConfigBuilder` builds `ConfigEntry` which can only read value with one key, if we wanna change the config name but still keep the old one, it's hard to do. This PR introduce `ConfigBuilder.withAlternative`, to support reading config value with alternative keys. And also rename `spark.scheduler.listenerbus.eventqueue.size` to `spark.scheduler.listenerbus.eventqueue.capacity` with this feature, according to https://github.com/apache/spark/pull/14269#discussion_r118432313 ## How was this patch tested? a new test Author: Wenchen Fan Closes #18110 from cloud-fan/config. --- .../scala/org/apache/spark/SparkConf.scala | 2 ++ .../spark/internal/config/ConfigBuilder.scala | 24 ++++++++----- .../spark/internal/config/ConfigEntry.scala | 36 +++++++++++-------- .../internal/config/ConfigProvider.scala | 16 ++------- .../spark/internal/config/ConfigReader.scala | 18 +++++++++- .../spark/internal/config/package.scala | 6 ++-- .../spark/scheduler/LiveListenerBus.scala | 15 ++------ .../internal/config/ConfigEntrySuite.scala | 27 ++++++++++++++ 8 files changed, 92 insertions(+), 52 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index 956724b14bba3..ba7a65f79c414 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -592,6 +592,8 @@ private[spark] object SparkConf extends Logging { * * The alternates are used in the order defined in this map. If deprecated configs are * present in the user's configuration, a warning is logged. + * + * TODO: consolidate it with `ConfigBuilder.withAlternative`. */ private val configsWithAlternatives = Map[String, Seq[AlternateConfig]]( "spark.executor.userClassPathFirst" -> Seq( diff --git a/core/src/main/scala/org/apache/spark/internal/config/ConfigBuilder.scala b/core/src/main/scala/org/apache/spark/internal/config/ConfigBuilder.scala index e5d60a7ef0984..515c9c24a9e2f 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/ConfigBuilder.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/ConfigBuilder.scala @@ -126,8 +126,8 @@ private[spark] class TypedConfigBuilder[T]( /** Creates a [[ConfigEntry]] that does not have a default value. */ def createOptional: OptionalConfigEntry[T] = { - val entry = new OptionalConfigEntry[T](parent.key, converter, stringConverter, parent._doc, - parent._public) + val entry = new OptionalConfigEntry[T](parent.key, parent._alternatives, converter, + stringConverter, parent._doc, parent._public) parent._onCreate.foreach(_(entry)) entry } @@ -140,8 +140,8 @@ private[spark] class TypedConfigBuilder[T]( createWithDefaultString(default.asInstanceOf[String]) } else { val transformedDefault = converter(stringConverter(default)) - val entry = new ConfigEntryWithDefault[T](parent.key, transformedDefault, converter, - stringConverter, parent._doc, parent._public) + val entry = new ConfigEntryWithDefault[T](parent.key, parent._alternatives, + transformedDefault, converter, stringConverter, parent._doc, parent._public) parent._onCreate.foreach(_(entry)) entry } @@ -149,8 +149,8 @@ private[spark] class TypedConfigBuilder[T]( /** Creates a [[ConfigEntry]] with a function to determine the default value */ def createWithDefaultFunction(defaultFunc: () => T): ConfigEntry[T] = { - val entry = new ConfigEntryWithDefaultFunction[T](parent.key, defaultFunc, converter, - stringConverter, parent._doc, parent._public) + val entry = new ConfigEntryWithDefaultFunction[T](parent.key, parent._alternatives, defaultFunc, + converter, stringConverter, parent._doc, parent._public) parent._onCreate.foreach(_ (entry)) entry } @@ -160,8 +160,8 @@ private[spark] class TypedConfigBuilder[T]( * [[String]] and must be a valid value for the entry. */ def createWithDefaultString(default: String): ConfigEntry[T] = { - val entry = new ConfigEntryWithDefaultString[T](parent.key, default, converter, stringConverter, - parent._doc, parent._public) + val entry = new ConfigEntryWithDefaultString[T](parent.key, parent._alternatives, default, + converter, stringConverter, parent._doc, parent._public) parent._onCreate.foreach(_(entry)) entry } @@ -180,6 +180,7 @@ private[spark] case class ConfigBuilder(key: String) { private[config] var _public = true private[config] var _doc = "" private[config] var _onCreate: Option[ConfigEntry[_] => Unit] = None + private[config] var _alternatives = List.empty[String] def internal(): ConfigBuilder = { _public = false @@ -200,6 +201,11 @@ private[spark] case class ConfigBuilder(key: String) { this } + def withAlternative(key: String): ConfigBuilder = { + _alternatives = _alternatives :+ key + this + } + def intConf: TypedConfigBuilder[Int] = { new TypedConfigBuilder(this, toNumber(_, _.toInt, key, "int")) } @@ -229,7 +235,7 @@ private[spark] case class ConfigBuilder(key: String) { } def fallbackConf[T](fallback: ConfigEntry[T]): ConfigEntry[T] = { - new FallbackConfigEntry(key, _doc, _public, fallback) + new FallbackConfigEntry(key, _alternatives, _doc, _public, fallback) } def regexConf: TypedConfigBuilder[Regex] = { diff --git a/core/src/main/scala/org/apache/spark/internal/config/ConfigEntry.scala b/core/src/main/scala/org/apache/spark/internal/config/ConfigEntry.scala index e86712e84d6ac..f1190289244e9 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/ConfigEntry.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/ConfigEntry.scala @@ -41,6 +41,7 @@ package org.apache.spark.internal.config */ private[spark] abstract class ConfigEntry[T] ( val key: String, + val alternatives: List[String], val valueConverter: String => T, val stringConverter: T => String, val doc: String, @@ -52,6 +53,10 @@ private[spark] abstract class ConfigEntry[T] ( def defaultValueString: String + protected def readString(reader: ConfigReader): Option[String] = { + alternatives.foldLeft(reader.get(key))((res, nextKey) => res.orElse(reader.get(nextKey))) + } + def readFrom(reader: ConfigReader): T def defaultValue: Option[T] = None @@ -59,63 +64,64 @@ private[spark] abstract class ConfigEntry[T] ( override def toString: String = { s"ConfigEntry(key=$key, defaultValue=$defaultValueString, doc=$doc, public=$isPublic)" } - } private class ConfigEntryWithDefault[T] ( key: String, + alternatives: List[String], _defaultValue: T, valueConverter: String => T, stringConverter: T => String, doc: String, isPublic: Boolean) - extends ConfigEntry(key, valueConverter, stringConverter, doc, isPublic) { + extends ConfigEntry(key, alternatives, valueConverter, stringConverter, doc, isPublic) { override def defaultValue: Option[T] = Some(_defaultValue) override def defaultValueString: String = stringConverter(_defaultValue) def readFrom(reader: ConfigReader): T = { - reader.get(key).map(valueConverter).getOrElse(_defaultValue) + readString(reader).map(valueConverter).getOrElse(_defaultValue) } } private class ConfigEntryWithDefaultFunction[T] ( key: String, + alternatives: List[String], _defaultFunction: () => T, valueConverter: String => T, stringConverter: T => String, doc: String, isPublic: Boolean) - extends ConfigEntry(key, valueConverter, stringConverter, doc, isPublic) { + extends ConfigEntry(key, alternatives, valueConverter, stringConverter, doc, isPublic) { override def defaultValue: Option[T] = Some(_defaultFunction()) override def defaultValueString: String = stringConverter(_defaultFunction()) def readFrom(reader: ConfigReader): T = { - reader.get(key).map(valueConverter).getOrElse(_defaultFunction()) + readString(reader).map(valueConverter).getOrElse(_defaultFunction()) } } private class ConfigEntryWithDefaultString[T] ( key: String, + alternatives: List[String], _defaultValue: String, valueConverter: String => T, stringConverter: T => String, doc: String, isPublic: Boolean) - extends ConfigEntry(key, valueConverter, stringConverter, doc, isPublic) { + extends ConfigEntry(key, alternatives, valueConverter, stringConverter, doc, isPublic) { override def defaultValue: Option[T] = Some(valueConverter(_defaultValue)) override def defaultValueString: String = _defaultValue def readFrom(reader: ConfigReader): T = { - val value = reader.get(key).getOrElse(reader.substitute(_defaultValue)) + val value = readString(reader).getOrElse(reader.substitute(_defaultValue)) valueConverter(value) } - } @@ -124,19 +130,20 @@ private class ConfigEntryWithDefaultString[T] ( */ private[spark] class OptionalConfigEntry[T]( key: String, + alternatives: List[String], val rawValueConverter: String => T, val rawStringConverter: T => String, doc: String, isPublic: Boolean) - extends ConfigEntry[Option[T]](key, s => Some(rawValueConverter(s)), + extends ConfigEntry[Option[T]](key, alternatives, + s => Some(rawValueConverter(s)), v => v.map(rawStringConverter).orNull, doc, isPublic) { override def defaultValueString: String = "" override def readFrom(reader: ConfigReader): Option[T] = { - reader.get(key).map(rawValueConverter) + readString(reader).map(rawValueConverter) } - } /** @@ -144,17 +151,18 @@ private[spark] class OptionalConfigEntry[T]( */ private class FallbackConfigEntry[T] ( key: String, + alternatives: List[String], doc: String, isPublic: Boolean, private[config] val fallback: ConfigEntry[T]) - extends ConfigEntry[T](key, fallback.valueConverter, fallback.stringConverter, doc, isPublic) { + extends ConfigEntry[T](key, alternatives, + fallback.valueConverter, fallback.stringConverter, doc, isPublic) { override def defaultValueString: String = s"" override def readFrom(reader: ConfigReader): T = { - reader.get(key).map(valueConverter).getOrElse(fallback.readFrom(reader)) + readString(reader).map(valueConverter).getOrElse(fallback.readFrom(reader)) } - } private[spark] object ConfigEntry { diff --git a/core/src/main/scala/org/apache/spark/internal/config/ConfigProvider.scala b/core/src/main/scala/org/apache/spark/internal/config/ConfigProvider.scala index 97f56a64d600f..5d98a1185f053 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/ConfigProvider.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/ConfigProvider.scala @@ -47,28 +47,16 @@ private[spark] class MapProvider(conf: JMap[String, String]) extends ConfigProvi } /** - * A config provider that only reads Spark config keys, and considers default values for known - * configs when fetching configuration values. + * A config provider that only reads Spark config keys. */ private[spark] class SparkConfigProvider(conf: JMap[String, String]) extends ConfigProvider { - import ConfigEntry._ - override def get(key: String): Option[String] = { if (key.startsWith("spark.")) { - Option(conf.get(key)).orElse(defaultValueString(key)) + Option(conf.get(key)) } else { None } } - private def defaultValueString(key: String): Option[String] = { - findEntry(key) match { - case e: ConfigEntryWithDefault[_] => Option(e.defaultValueString) - case e: ConfigEntryWithDefaultString[_] => Option(e.defaultValueString) - case e: FallbackConfigEntry[_] => get(e.fallback.key) - case _ => None - } - } - } diff --git a/core/src/main/scala/org/apache/spark/internal/config/ConfigReader.scala b/core/src/main/scala/org/apache/spark/internal/config/ConfigReader.scala index c62de9bfd8fc3..c1ab22150d024 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/ConfigReader.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/ConfigReader.scala @@ -92,7 +92,7 @@ private[spark] class ConfigReader(conf: ConfigProvider) { require(!usedRefs.contains(ref), s"Circular reference in $input: $ref") val replacement = bindings.get(prefix) - .flatMap(_.get(name)) + .flatMap(getOrDefault(_, name)) .map { v => substitute(v, usedRefs + ref) } .getOrElse(m.matched) Regex.quoteReplacement(replacement) @@ -102,4 +102,20 @@ private[spark] class ConfigReader(conf: ConfigProvider) { } } + /** + * Gets the value of a config from the given `ConfigProvider`. If no value is found for this + * config, and the `ConfigEntry` defines this config has default value, return the default value. + */ + private def getOrDefault(conf: ConfigProvider, key: String): Option[String] = { + conf.get(key).orElse { + ConfigEntry.findEntry(key) match { + case e: ConfigEntryWithDefault[_] => Option(e.defaultValueString) + case e: ConfigEntryWithDefaultString[_] => Option(e.defaultValueString) + case e: ConfigEntryWithDefaultFunction[_] => Option(e.defaultValueString) + case e: FallbackConfigEntry[_] => getOrDefault(conf, e.fallback.key) + case _ => None + } + } + } + } diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index f8139b706a7cc..4ad04b04c312d 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -151,9 +151,11 @@ package object config { .createOptional // End blacklist confs - private[spark] val LISTENER_BUS_EVENT_QUEUE_SIZE = - ConfigBuilder("spark.scheduler.listenerbus.eventqueue.size") + private[spark] val LISTENER_BUS_EVENT_QUEUE_CAPACITY = + ConfigBuilder("spark.scheduler.listenerbus.eventqueue.capacity") + .withAlternative("spark.scheduler.listenerbus.eventqueue.size") .intConf + .checkValue(_ > 0, "The capacity of listener bus event queue must not be negative") .createWithDefault(10000) // This property sets the root namespace for metrics reporting diff --git a/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala index 5533f7b1f2363..801dfaa62306a 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala @@ -22,7 +22,7 @@ import java.util.concurrent.atomic.{AtomicBoolean, AtomicLong} import scala.util.DynamicVariable -import org.apache.spark.{SparkContext, SparkException} +import org.apache.spark.SparkContext import org.apache.spark.internal.config._ import org.apache.spark.util.Utils @@ -34,23 +34,14 @@ import org.apache.spark.util.Utils * is stopped when `stop()` is called, and it will drop further events after stopping. */ private[spark] class LiveListenerBus(val sparkContext: SparkContext) extends SparkListenerBus { - self => import LiveListenerBus._ // Cap the capacity of the event queue so we get an explicit error (rather than // an OOM exception) if it's perpetually being added to more quickly than it's being drained. - private lazy val EVENT_QUEUE_CAPACITY = validateAndGetQueueSize() - private lazy val eventQueue = new LinkedBlockingQueue[SparkListenerEvent](EVENT_QUEUE_CAPACITY) - - private def validateAndGetQueueSize(): Int = { - val queueSize = sparkContext.conf.get(LISTENER_BUS_EVENT_QUEUE_SIZE) - if (queueSize <= 0) { - throw new SparkException("spark.scheduler.listenerbus.eventqueue.size must be > 0!") - } - queueSize - } + private lazy val eventQueue = new LinkedBlockingQueue[SparkListenerEvent]( + sparkContext.conf.get(LISTENER_BUS_EVENT_QUEUE_CAPACITY)) // Indicate if `start()` is called private val started = new AtomicBoolean(false) diff --git a/core/src/test/scala/org/apache/spark/internal/config/ConfigEntrySuite.scala b/core/src/test/scala/org/apache/spark/internal/config/ConfigEntrySuite.scala index b72cd8be24206..bf08276dbf971 100644 --- a/core/src/test/scala/org/apache/spark/internal/config/ConfigEntrySuite.scala +++ b/core/src/test/scala/org/apache/spark/internal/config/ConfigEntrySuite.scala @@ -261,4 +261,31 @@ class ConfigEntrySuite extends SparkFunSuite { data = 2 assert(conf.get(iConf) === 2) } + + test("conf entry: alternative keys") { + val conf = new SparkConf() + val iConf = ConfigBuilder(testKey("a")) + .withAlternative(testKey("b")) + .withAlternative(testKey("c")) + .intConf.createWithDefault(0) + + // no key is set, return default value. + assert(conf.get(iConf) === 0) + + // the primary key is set, the alternative keys are not set, return the value of primary key. + conf.set(testKey("a"), "1") + assert(conf.get(iConf) === 1) + + // the primary key and alternative keys are all set, return the value of primary key. + conf.set(testKey("b"), "2") + conf.set(testKey("c"), "3") + assert(conf.get(iConf) === 1) + + // the primary key is not set, (some of) the alternative keys are set, return the value of the + // first alternative key that is set. + conf.remove(testKey("a")) + assert(conf.get(iConf) === 2) + conf.remove(testKey("b")) + assert(conf.get(iConf) === 3) + } } From 0fd84b05dc9ac3de240791e2d4200d8bdffbb01a Mon Sep 17 00:00:00 2001 From: 10129659 Date: Fri, 26 May 2017 18:03:23 +0100 Subject: [PATCH 013/133] [SPARK-20835][CORE] It should exit directly when the --total-executor-cores parameter is setted less than 0 when submit a application ## What changes were proposed in this pull request? In my test, the submitted app running with out an error when the --total-executor-cores less than 0 and given the warnings: "2017-05-22 17:19:36,319 WARN org.apache.spark.scheduler.TaskSchedulerImpl: Initial job has not accepted any resources; check your cluster UI to ensure that workers are registered and have sufficient resources"; It should exit directly when the --total-executor-cores parameter is setted less than 0 when submit a application (Please fill in changes proposed in this fix) ## How was this patch tested? Run the ut tests (Please explain how this patch was tested. E.g. unit tests, integration tests, manual tests) (If this patch involves UI changes, please attach a screenshot; otherwise, remove this) Please review http://spark.apache.org/contributing.html before opening a pull request. Author: 10129659 Closes #18060 from eatoncys/totalcores. --- .../spark/deploy/SparkSubmitArguments.scala | 20 +++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index 0144fd1056bac..5100a17006e24 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -27,11 +27,14 @@ import java.util.jar.JarFile import scala.collection.JavaConverters._ import scala.collection.mutable.{ArrayBuffer, HashMap} import scala.io.Source +import scala.util.Try import org.apache.spark.deploy.SparkSubmitAction._ import org.apache.spark.launcher.SparkSubmitArgumentsParser +import org.apache.spark.network.util.JavaUtils import org.apache.spark.util.Utils + /** * Parses and encapsulates arguments from the spark-submit script. * The env argument is used for testing. @@ -253,6 +256,23 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S if (mainClass == null && SparkSubmit.isUserJar(primaryResource)) { SparkSubmit.printErrorAndExit("No main class set in JAR; please specify one with --class") } + if (driverMemory != null + && Try(JavaUtils.byteStringAsBytes(driverMemory)).getOrElse(-1L) <= 0) { + SparkSubmit.printErrorAndExit("Driver Memory must be a positive number") + } + if (executorMemory != null + && Try(JavaUtils.byteStringAsBytes(executorMemory)).getOrElse(-1L) <= 0) { + SparkSubmit.printErrorAndExit("Executor Memory cores must be a positive number") + } + if (executorCores != null && Try(executorCores.toInt).getOrElse(-1) <= 0) { + SparkSubmit.printErrorAndExit("Executor cores must be a positive number") + } + if (totalExecutorCores != null && Try(totalExecutorCores.toInt).getOrElse(-1) <= 0) { + SparkSubmit.printErrorAndExit("Total executor cores must be a positive number") + } + if (numExecutors != null && Try(numExecutors.toInt).getOrElse(-1) <= 0) { + SparkSubmit.printErrorAndExit("Number of executors must be a positive number") + } if (pyFiles != null && !isPython) { SparkSubmit.printErrorAndExit("--py-files given but primary resource is not a Python script") } From d935e0a9d9bb3d3c74e9529e161648caa50696b7 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Fri, 26 May 2017 13:33:23 -0700 Subject: [PATCH 014/133] [SPARK-20844] Remove experimental from Structured Streaming APIs Now that Structured Streaming has been out for several Spark release and has large production use cases, the `Experimental` label is no longer appropriate. I've left `InterfaceStability.Evolving` however, as I think we may make a few changes to the pluggable Source & Sink API in Spark 2.3. Author: Michael Armbrust Closes #18065 from marmbrus/streamingGA. --- .../structured-streaming-programming-guide.md | 4 +- python/pyspark/sql/context.py | 4 +- python/pyspark/sql/dataframe.py | 6 +-- python/pyspark/sql/session.py | 4 +- python/pyspark/sql/streaming.py | 42 +++++++++---------- .../spark/sql/streaming/OutputMode.java | 3 -- .../apache/spark/sql/streaming/Trigger.java | 7 ---- .../scala/org/apache/spark/sql/Dataset.scala | 2 - .../org/apache/spark/sql/ForeachWriter.scala | 4 +- .../org/apache/spark/sql/SQLContext.scala | 2 - .../org/apache/spark/sql/SparkSession.scala | 2 - .../org/apache/spark/sql/functions.scala | 8 +--- .../sql/streaming/DataStreamReader.scala | 3 +- .../sql/streaming/DataStreamWriter.scala | 4 +- .../spark/sql/streaming/ProcessingTime.scala | 6 +-- .../spark/sql/streaming/StreamingQuery.scala | 4 +- .../streaming/StreamingQueryException.scala | 4 +- .../streaming/StreamingQueryListener.scala | 14 +------ .../sql/streaming/StreamingQueryManager.scala | 6 +-- .../sql/streaming/StreamingQueryStatus.scala | 4 +- .../apache/spark/sql/streaming/progress.scala | 10 +---- 21 files changed, 42 insertions(+), 101 deletions(-) diff --git a/docs/structured-streaming-programming-guide.md b/docs/structured-streaming-programming-guide.md index bd01be944460b..6a25c9939c264 100644 --- a/docs/structured-streaming-programming-guide.md +++ b/docs/structured-streaming-programming-guide.md @@ -1,6 +1,6 @@ --- layout: global -displayTitle: Structured Streaming Programming Guide [Experimental] +displayTitle: Structured Streaming Programming Guide title: Structured Streaming Programming Guide --- @@ -10,7 +10,7 @@ title: Structured Streaming Programming Guide # Overview Structured Streaming is a scalable and fault-tolerant stream processing engine built on the Spark SQL engine. You can express your streaming computation the same way you would express a batch computation on static data. The Spark SQL engine will take care of running it incrementally and continuously and updating the final result as streaming data continues to arrive. You can use the [Dataset/DataFrame API](sql-programming-guide.html) in Scala, Java, Python or R to express streaming aggregations, event-time windows, stream-to-batch joins, etc. The computation is executed on the same optimized Spark SQL engine. Finally, the system ensures end-to-end exactly-once fault-tolerance guarantees through checkpointing and Write Ahead Logs. In short, *Structured Streaming provides fast, scalable, fault-tolerant, end-to-end exactly-once stream processing without the user having to reason about streaming.* -**Structured Streaming is still ALPHA in Spark 2.1** and the APIs are still experimental. In this guide, we are going to walk you through the programming model and the APIs. First, let's start with a simple example - a streaming word count. +In this guide, we are going to walk you through the programming model and the APIs. First, let's start with a simple example - a streaming word count. # Quick Example Let’s say you want to maintain a running word count of text data received from a data server listening on a TCP socket. Let’s see how you can express this using Structured Streaming. You can see the full code in diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index 5197a9e004610..426f07cd9410d 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -474,7 +474,7 @@ def readStream(self): Returns a :class:`DataStreamReader` that can be used to read data streams as a streaming :class:`DataFrame`. - .. note:: Experimental. + .. note:: Evolving. :return: :class:`DataStreamReader` @@ -490,7 +490,7 @@ def streams(self): """Returns a :class:`StreamingQueryManager` that allows managing all the :class:`StreamingQuery` StreamingQueries active on `this` context. - .. note:: Experimental. + .. note:: Evolving. """ from pyspark.sql.streaming import StreamingQueryManager return StreamingQueryManager(self._ssql_ctx.streams()) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 7b67985f2b320..fbe66f18a3613 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -209,7 +209,7 @@ def writeStream(self): Interface for saving the content of the streaming :class:`DataFrame` out into external storage. - .. note:: Experimental. + .. note:: Evolving. :return: :class:`DataStreamWriter` """ @@ -285,7 +285,7 @@ def isStreaming(self): :func:`collect`) will throw an :class:`AnalysisException` when there is a streaming source present. - .. note:: Experimental + .. note:: Evolving """ return self._jdf.isStreaming() @@ -368,7 +368,7 @@ def withWatermark(self, eventTime, delayThreshold): latest record that has been processed in the form of an interval (e.g. "1 minute" or "5 hours"). - .. note:: Experimental + .. note:: Evolving >>> sdf.select('name', sdf.time.cast('timestamp')).withWatermark('time', '10 minutes') DataFrame[name: string, time: timestamp] diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index c1bf2bd76fb7c..e3bf0f35ea15e 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -586,7 +586,7 @@ def readStream(self): Returns a :class:`DataStreamReader` that can be used to read data streams as a streaming :class:`DataFrame`. - .. note:: Experimental. + .. note:: Evolving. :return: :class:`DataStreamReader` """ @@ -598,7 +598,7 @@ def streams(self): """Returns a :class:`StreamingQueryManager` that allows managing all the :class:`StreamingQuery` StreamingQueries active on `this` context. - .. note:: Experimental. + .. note:: Evolving. :return: :class:`StreamingQueryManager` """ diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py index 65b59d480da36..76e8c4f47d8ad 100644 --- a/python/pyspark/sql/streaming.py +++ b/python/pyspark/sql/streaming.py @@ -41,7 +41,7 @@ class StreamingQuery(object): A handle to a query that is executing continuously in the background as new data arrives. All these methods are thread-safe. - .. note:: Experimental + .. note:: Evolving .. versionadded:: 2.0 """ @@ -197,7 +197,7 @@ def exception(self): class StreamingQueryManager(object): """A class to manage all the :class:`StreamingQuery` StreamingQueries active. - .. note:: Experimental + .. note:: Evolving .. versionadded:: 2.0 """ @@ -283,7 +283,7 @@ class DataStreamReader(OptionUtils): (e.g. file systems, key-value stores, etc). Use :func:`spark.readStream` to access this. - .. note:: Experimental. + .. note:: Evolving. .. versionadded:: 2.0 """ @@ -300,7 +300,7 @@ def _df(self, jdf): def format(self, source): """Specifies the input data source format. - .. note:: Experimental. + .. note:: Evolving. :param source: string, name of the data source, e.g. 'json', 'parquet'. @@ -317,7 +317,7 @@ def schema(self, schema): By specifying the schema here, the underlying data source can skip the schema inference step, and thus speed up data loading. - .. note:: Experimental. + .. note:: Evolving. :param schema: a :class:`pyspark.sql.types.StructType` object @@ -340,7 +340,7 @@ def option(self, key, value): in the JSON/CSV datasources or partition values. If it isn't set, it uses the default value, session local timezone. - .. note:: Experimental. + .. note:: Evolving. >>> s = spark.readStream.option("x", 1) """ @@ -356,7 +356,7 @@ def options(self, **options): in the JSON/CSV datasources or partition values. If it isn't set, it uses the default value, session local timezone. - .. note:: Experimental. + .. note:: Evolving. >>> s = spark.readStream.options(x="1", y=2) """ @@ -368,7 +368,7 @@ def options(self, **options): def load(self, path=None, format=None, schema=None, **options): """Loads a data stream from a data source and returns it as a :class`DataFrame`. - .. note:: Experimental. + .. note:: Evolving. :param path: optional string for file-system backed data sources. :param format: optional string for format of the data source. Default to 'parquet'. @@ -411,7 +411,7 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, If the ``schema`` parameter is not specified, this function goes through the input once to determine the input schema. - .. note:: Experimental. + .. note:: Evolving. :param path: string represents path to the JSON dataset, or RDD of Strings storing JSON objects. @@ -488,7 +488,7 @@ def parquet(self, path): Parquet part-files. This will override ``spark.sql.parquet.mergeSchema``. \ The default value is specified in ``spark.sql.parquet.mergeSchema``. - .. note:: Experimental. + .. note:: Evolving. >>> parquet_sdf = spark.readStream.schema(sdf_schema).parquet(tempfile.mkdtemp()) >>> parquet_sdf.isStreaming @@ -511,7 +511,7 @@ def text(self, path): Each line in the text file is a new row in the resulting DataFrame. - .. note:: Experimental. + .. note:: Evolving. :param paths: string, or list of strings, for input path(s). @@ -539,7 +539,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non ``inferSchema`` is enabled. To avoid going through the entire data once, disable ``inferSchema`` option or specify the schema explicitly using ``schema``. - .. note:: Experimental. + .. note:: Evolving. :param path: string, or list of strings, for input path(s). :param schema: an optional :class:`pyspark.sql.types.StructType` for the input schema. @@ -637,7 +637,7 @@ class DataStreamWriter(object): (e.g. file systems, key-value stores, etc). Use :func:`DataFrame.writeStream` to access this. - .. note:: Experimental. + .. note:: Evolving. .. versionadded:: 2.0 """ @@ -665,7 +665,7 @@ def outputMode(self, outputMode): written to the sink every time there are some updates. If the query doesn't contain aggregations, it will be equivalent to `append` mode. - .. note:: Experimental. + .. note:: Evolving. >>> writer = sdf.writeStream.outputMode('append') """ @@ -678,7 +678,7 @@ def outputMode(self, outputMode): def format(self, source): """Specifies the underlying output data source. - .. note:: Experimental. + .. note:: Evolving. :param source: string, name of the data source, which for now can be 'parquet'. @@ -696,7 +696,7 @@ def option(self, key, value): timestamps in the JSON/CSV datasources or partition values. If it isn't set, it uses the default value, session local timezone. - .. note:: Experimental. + .. note:: Evolving. """ self._jwrite = self._jwrite.option(key, to_str(value)) return self @@ -710,7 +710,7 @@ def options(self, **options): timestamps in the JSON/CSV datasources or partition values. If it isn't set, it uses the default value, session local timezone. - .. note:: Experimental. + .. note:: Evolving. """ for k in options: self._jwrite = self._jwrite.option(k, to_str(options[k])) @@ -723,7 +723,7 @@ def partitionBy(self, *cols): If specified, the output is laid out on the file system similar to Hive's partitioning scheme. - .. note:: Experimental. + .. note:: Evolving. :param cols: name of columns @@ -739,7 +739,7 @@ def queryName(self, queryName): :func:`start`. This name must be unique among all the currently active queries in the associated SparkSession. - .. note:: Experimental. + .. note:: Evolving. :param queryName: unique name for the query @@ -756,7 +756,7 @@ def trigger(self, processingTime=None, once=None): """Set the trigger for the stream query. If this is not set it will run the query as fast as possible, which is equivalent to setting the trigger to ``processingTime='0 seconds'``. - .. note:: Experimental. + .. note:: Evolving. :param processingTime: a processing time interval as a string, e.g. '5 seconds', '1 minute'. @@ -794,7 +794,7 @@ def start(self, path=None, format=None, outputMode=None, partitionBy=None, query If ``format`` is not specified, the default data source configured by ``spark.sql.sources.default`` will be used. - .. note:: Experimental. + .. note:: Evolving. :param path: the path in a Hadoop supported file system :param format: the format used to save diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/OutputMode.java b/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/OutputMode.java index 3f7cdb293e0fa..8410abd14fd59 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/OutputMode.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/OutputMode.java @@ -22,14 +22,11 @@ import org.apache.spark.sql.catalyst.streaming.InternalOutputModes; /** - * :: Experimental :: - * * OutputMode is used to what data will be written to a streaming sink when there is * new data available in a streaming DataFrame/Dataset. * * @since 2.0.0 */ -@Experimental @InterfaceStability.Evolving public class OutputMode { diff --git a/sql/core/src/main/java/org/apache/spark/sql/streaming/Trigger.java b/sql/core/src/main/java/org/apache/spark/sql/streaming/Trigger.java index 3e3997fa9bfec..d31790a285687 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/streaming/Trigger.java +++ b/sql/core/src/main/java/org/apache/spark/sql/streaming/Trigger.java @@ -21,22 +21,18 @@ import scala.concurrent.duration.Duration; -import org.apache.spark.annotation.Experimental; import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.execution.streaming.OneTimeTrigger$; /** - * :: Experimental :: * Policy used to indicate how often results should be produced by a [[StreamingQuery]]. * * @since 2.0.0 */ -@Experimental @InterfaceStability.Evolving public class Trigger { /** - * :: Experimental :: * A trigger policy that runs a query periodically based on an interval in processing time. * If `interval` is 0, the query will run as fast as possible. * @@ -47,7 +43,6 @@ public static Trigger ProcessingTime(long intervalMs) { } /** - * :: Experimental :: * (Java-friendly) * A trigger policy that runs a query periodically based on an interval in processing time. * If `interval` is 0, the query will run as fast as possible. @@ -64,7 +59,6 @@ public static Trigger ProcessingTime(long interval, TimeUnit timeUnit) { } /** - * :: Experimental :: * (Scala-friendly) * A trigger policy that runs a query periodically based on an interval in processing time. * If `duration` is 0, the query will run as fast as possible. @@ -80,7 +74,6 @@ public static Trigger ProcessingTime(Duration interval) { } /** - * :: Experimental :: * A trigger policy that runs a query periodically based on an interval in processing time. * If `interval` is effectively 0, the query will run as fast as possible. * diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index f9bd8f3d278ad..0e7415890e216 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -2786,13 +2786,11 @@ class Dataset[T] private[sql]( } /** - * :: Experimental :: * Interface for saving the content of the streaming Dataset out into external storage. * * @group basic * @since 2.0.0 */ - @Experimental @InterfaceStability.Evolving def writeStream: DataStreamWriter[T] = { if (!isStreaming) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ForeachWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/ForeachWriter.scala index 372ec262f5764..86e02e98c01f3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/ForeachWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/ForeachWriter.scala @@ -17,10 +17,9 @@ package org.apache.spark.sql -import org.apache.spark.annotation.{Experimental, InterfaceStability} +import org.apache.spark.annotation.InterfaceStability /** - * :: Experimental :: * A class to consume data generated by a `StreamingQuery`. Typically this is used to send the * generated data to external systems. Each partition will use a new deserialized instance, so you * usually should do all the initialization (e.g. opening a connection or initiating a transaction) @@ -66,7 +65,6 @@ import org.apache.spark.annotation.{Experimental, InterfaceStability} * }}} * @since 2.0.0 */ -@Experimental @InterfaceStability.Evolving abstract class ForeachWriter[T] extends Serializable { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index cc2983987eb90..7fde6e9469e5e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -505,7 +505,6 @@ class SQLContext private[sql](val sparkSession: SparkSession) /** - * :: Experimental :: * Returns a `DataStreamReader` that can be used to read streaming data in as a `DataFrame`. * {{{ * sparkSession.readStream.parquet("/path/to/directory/of/parquet/files") @@ -514,7 +513,6 @@ class SQLContext private[sql](val sparkSession: SparkSession) * * @since 2.0.0 */ - @Experimental @InterfaceStability.Evolving def readStream: DataStreamReader = sparkSession.readStream diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index a519492ed8f4f..d2bf350711936 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -636,7 +636,6 @@ class SparkSession private( def read: DataFrameReader = new DataFrameReader(self) /** - * :: Experimental :: * Returns a `DataStreamReader` that can be used to read streaming data in as a `DataFrame`. * {{{ * sparkSession.readStream.parquet("/path/to/directory/of/parquet/files") @@ -645,7 +644,6 @@ class SparkSession private( * * @since 2.0.0 */ - @Experimental @InterfaceStability.Evolving def readStream: DataStreamReader = new DataStreamReader(self) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 7eea6d8d85b6f..a347991d8490b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -23,7 +23,7 @@ import scala.reflect.runtime.universe.{typeTag, TypeTag} import scala.util.Try import scala.util.control.NonFatal -import org.apache.spark.annotation.{Experimental, InterfaceStability} +import org.apache.spark.annotation.InterfaceStability import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.analysis.{Star, UnresolvedFunction} import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder @@ -2801,8 +2801,6 @@ object functions { * @group datetime_funcs * @since 2.0.0 */ - @Experimental - @InterfaceStability.Evolving def window( timeColumn: Column, windowDuration: String, @@ -2855,8 +2853,6 @@ object functions { * @group datetime_funcs * @since 2.0.0 */ - @Experimental - @InterfaceStability.Evolving def window(timeColumn: Column, windowDuration: String, slideDuration: String): Column = { window(timeColumn, windowDuration, slideDuration, "0 second") } @@ -2894,8 +2890,6 @@ object functions { * @group datetime_funcs * @since 2.0.0 */ - @Experimental - @InterfaceStability.Evolving def window(timeColumn: Column, windowDuration: String): Column = { window(timeColumn, windowDuration, windowDuration, "0 second") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala index 746b2a94f102d..766776230257d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala @@ -21,7 +21,7 @@ import java.util.Locale import scala.collection.JavaConverters._ -import org.apache.spark.annotation.{Experimental, InterfaceStability} +import org.apache.spark.annotation.InterfaceStability import org.apache.spark.internal.Logging import org.apache.spark.sql.{AnalysisException, DataFrame, Dataset, SparkSession} import org.apache.spark.sql.execution.command.DDLUtils @@ -35,7 +35,6 @@ import org.apache.spark.sql.types.StructType * * @since 2.0.0 */ -@Experimental @InterfaceStability.Evolving final class DataStreamReader private[sql](sparkSession: SparkSession) extends Logging { /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala index 0d2611f9bbcce..14e7df672cc58 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala @@ -21,7 +21,7 @@ import java.util.Locale import scala.collection.JavaConverters._ -import org.apache.spark.annotation.{Experimental, InterfaceStability} +import org.apache.spark.annotation.InterfaceStability import org.apache.spark.sql.{AnalysisException, Dataset, ForeachWriter} import org.apache.spark.sql.catalyst.streaming.InternalOutputModes import org.apache.spark.sql.execution.command.DDLUtils @@ -29,13 +29,11 @@ import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.streaming.{ForeachSink, MemoryPlan, MemorySink} /** - * :: Experimental :: * Interface used to write a streaming `Dataset` to external storage systems (e.g. file systems, * key-value stores, etc). Use `Dataset.writeStream` to access this. * * @since 2.0.0 */ -@Experimental @InterfaceStability.Evolving final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/ProcessingTime.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/ProcessingTime.scala index 9ba1fc01cbd30..a033575d3d38f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/ProcessingTime.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/ProcessingTime.scala @@ -23,11 +23,10 @@ import scala.concurrent.duration.Duration import org.apache.commons.lang3.StringUtils -import org.apache.spark.annotation.{Experimental, InterfaceStability} +import org.apache.spark.annotation.InterfaceStability import org.apache.spark.unsafe.types.CalendarInterval /** - * :: Experimental :: * A trigger that runs a query periodically based on the processing time. If `interval` is 0, * the query will run as fast as possible. * @@ -49,7 +48,6 @@ import org.apache.spark.unsafe.types.CalendarInterval * * @since 2.0.0 */ -@Experimental @InterfaceStability.Evolving @deprecated("use Trigger.ProcessingTime(intervalMs)", "2.2.0") case class ProcessingTime(intervalMs: Long) extends Trigger { @@ -57,12 +55,10 @@ case class ProcessingTime(intervalMs: Long) extends Trigger { } /** - * :: Experimental :: * Used to create [[ProcessingTime]] triggers for [[StreamingQuery]]s. * * @since 2.0.0 */ -@Experimental @InterfaceStability.Evolving @deprecated("use Trigger.ProcessingTime(intervalMs)", "2.2.0") object ProcessingTime { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQuery.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQuery.scala index 12a1bb1db5779..f2dfbe42260d7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQuery.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQuery.scala @@ -19,16 +19,14 @@ package org.apache.spark.sql.streaming import java.util.UUID -import org.apache.spark.annotation.{Experimental, InterfaceStability} +import org.apache.spark.annotation.InterfaceStability import org.apache.spark.sql.SparkSession /** - * :: Experimental :: * A handle to a query that is executing continuously in the background as new data arrives. * All these methods are thread-safe. * @since 2.0.0 */ -@Experimental @InterfaceStability.Evolving trait StreamingQuery { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryException.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryException.scala index 234a1166a1953..03aeb14de502a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryException.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryException.scala @@ -17,10 +17,9 @@ package org.apache.spark.sql.streaming -import org.apache.spark.annotation.{Experimental, InterfaceStability} +import org.apache.spark.annotation.InterfaceStability /** - * :: Experimental :: * Exception that stopped a [[StreamingQuery]]. Use `cause` get the actual exception * that caused the failure. * @param message Message of this exception @@ -29,7 +28,6 @@ import org.apache.spark.annotation.{Experimental, InterfaceStability} * @param endOffset Ending offset in json of the range of data in exception occurred * @since 2.0.0 */ -@Experimental @InterfaceStability.Evolving class StreamingQueryException private[sql]( private val queryDebugString: String, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryListener.scala index c376913516ef7..6aa82b89ede81 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryListener.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryListener.scala @@ -19,17 +19,15 @@ package org.apache.spark.sql.streaming import java.util.UUID -import org.apache.spark.annotation.{Experimental, InterfaceStability} +import org.apache.spark.annotation.InterfaceStability import org.apache.spark.scheduler.SparkListenerEvent /** - * :: Experimental :: * Interface for listening to events related to [[StreamingQuery StreamingQueries]]. * @note The methods are not thread-safe as they may be called from different threads. * * @since 2.0.0 */ -@Experimental @InterfaceStability.Evolving abstract class StreamingQueryListener { @@ -66,32 +64,26 @@ abstract class StreamingQueryListener { /** - * :: Experimental :: * Companion object of [[StreamingQueryListener]] that defines the listener events. * @since 2.0.0 */ -@Experimental @InterfaceStability.Evolving object StreamingQueryListener { /** - * :: Experimental :: * Base type of [[StreamingQueryListener]] events * @since 2.0.0 */ - @Experimental @InterfaceStability.Evolving trait Event extends SparkListenerEvent /** - * :: Experimental :: * Event representing the start of a query * @param id An unique query id that persists across restarts. See `StreamingQuery.id()`. * @param runId A query id that is unique for every start/restart. See `StreamingQuery.runId()`. * @param name User-specified name of the query, null if not specified. * @since 2.1.0 */ - @Experimental @InterfaceStability.Evolving class QueryStartedEvent private[sql]( val id: UUID, @@ -99,17 +91,14 @@ object StreamingQueryListener { val name: String) extends Event /** - * :: Experimental :: * Event representing any progress updates in a query. * @param progress The query progress updates. * @since 2.1.0 */ - @Experimental @InterfaceStability.Evolving class QueryProgressEvent private[sql](val progress: StreamingQueryProgress) extends Event /** - * :: Experimental :: * Event representing that termination of a query. * * @param id An unique query id that persists across restarts. See `StreamingQuery.id()`. @@ -118,7 +107,6 @@ object StreamingQueryListener { * with an exception. Otherwise, it will be `None`. * @since 2.1.0 */ - @Experimental @InterfaceStability.Evolving class QueryTerminatedEvent private[sql]( val id: UUID, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala index 7810d9f6e9642..002c45413b4c2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala @@ -24,7 +24,7 @@ import scala.collection.mutable import org.apache.hadoop.fs.Path -import org.apache.spark.annotation.{Experimental, InterfaceStability} +import org.apache.spark.annotation.InterfaceStability import org.apache.spark.internal.Logging import org.apache.spark.sql.{AnalysisException, DataFrame, SparkSession} import org.apache.spark.sql.catalyst.analysis.UnsupportedOperationChecker @@ -34,12 +34,10 @@ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.util.{Clock, SystemClock, Utils} /** - * :: Experimental :: - * A class to manage all the [[StreamingQuery]] active on a `SparkSession`. + * A class to manage all the [[StreamingQuery]] active in a `SparkSession`. * * @since 2.0.0 */ -@Experimental @InterfaceStability.Evolving class StreamingQueryManager private[sql] (sparkSession: SparkSession) extends Logging { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryStatus.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryStatus.scala index 687b1267825fe..a0c9bcc8929eb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryStatus.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryStatus.scala @@ -22,10 +22,9 @@ import org.json4s.JsonAST.JValue import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ -import org.apache.spark.annotation.{Experimental, InterfaceStability} +import org.apache.spark.annotation.InterfaceStability /** - * :: Experimental :: * Reports information about the instantaneous status of a streaming query. * * @param message A human readable description of what the stream is currently doing. @@ -35,7 +34,6 @@ import org.apache.spark.annotation.{Experimental, InterfaceStability} * * @since 2.1.0 */ -@Experimental @InterfaceStability.Evolving class StreamingQueryStatus protected[sql]( val message: String, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/progress.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/progress.scala index 35fe6b8605fad..fb590e7df996b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/progress.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/progress.scala @@ -29,13 +29,11 @@ import org.json4s.JsonAST.JValue import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ -import org.apache.spark.annotation.{Experimental, InterfaceStability} +import org.apache.spark.annotation.InterfaceStability /** - * :: Experimental :: * Information about updates made to stateful operators in a [[StreamingQuery]] during a trigger. */ -@Experimental @InterfaceStability.Evolving class StateOperatorProgress private[sql]( val numRowsTotal: Long, @@ -54,7 +52,6 @@ class StateOperatorProgress private[sql]( } /** - * :: Experimental :: * Information about progress made in the execution of a [[StreamingQuery]] during * a trigger. Each event relates to processing done for a single trigger of the streaming * query. Events are emitted even when no new data is available to be processed. @@ -80,7 +77,6 @@ class StateOperatorProgress private[sql]( * @param sources detailed statistics on data being read from each of the streaming sources. * @since 2.1.0 */ -@Experimental @InterfaceStability.Evolving class StreamingQueryProgress private[sql]( val id: UUID, @@ -139,7 +135,6 @@ class StreamingQueryProgress private[sql]( } /** - * :: Experimental :: * Information about progress made for a source in the execution of a [[StreamingQuery]] * during a trigger. See [[StreamingQueryProgress]] for more information. * @@ -152,7 +147,6 @@ class StreamingQueryProgress private[sql]( * Spark. * @since 2.1.0 */ -@Experimental @InterfaceStability.Evolving class SourceProgress protected[sql]( val description: String, @@ -191,14 +185,12 @@ class SourceProgress protected[sql]( } /** - * :: Experimental :: * Information about progress made for a sink in the execution of a [[StreamingQuery]] * during a trigger. See [[StreamingQueryProgress]] for more information. * * @param description Description of the source corresponding to this status. * @since 2.1.0 */ -@Experimental @InterfaceStability.Evolving class SinkProgress protected[sql]( val description: String) extends Serializable { From 473d7552acb19f440a0cb082e6d3cba67579bd5a Mon Sep 17 00:00:00 2001 From: Sital Kedia Date: Fri, 26 May 2017 13:41:13 -0700 Subject: [PATCH 015/133] [SPARK-20014] Optimize mergeSpillsWithFileStream method ## What changes were proposed in this pull request? When the individual partition size in a spill is small, mergeSpillsWithTransferTo method does many small disk ios which is really inefficient. One way to improve the performance will be to use mergeSpillsWithFileStream method by turning off transfer to and using buffered file read/write to improve the io throughput. However, the current implementation of mergeSpillsWithFileStream does not do a buffer read/write of the files and in addition to that it unnecessarily flushes the output files for each partitions. ## How was this patch tested? Tested this change by running a job on the cluster and the map stage run time was reduced by around 20%. Author: Sital Kedia Closes #17343 from sitalkedia/upstream_mergeSpillsWithFileStream. --- .../shuffle/sort/UnsafeShuffleWriter.java | 44 ++++++++++++++----- 1 file changed, 33 insertions(+), 11 deletions(-) diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java index 2fde5c300f072..857ec8a4dadd2 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java @@ -40,6 +40,7 @@ import org.apache.spark.executor.ShuffleWriteMetrics; import org.apache.spark.io.CompressionCodec; import org.apache.spark.io.CompressionCodec$; +import org.apache.spark.io.NioBufferedFileInputStream; import org.apache.commons.io.output.CloseShieldOutputStream; import org.apache.commons.io.output.CountingOutputStream; import org.apache.spark.memory.TaskMemoryManager; @@ -98,6 +99,18 @@ private static final class MyByteArrayOutputStream extends ByteArrayOutputStream */ private boolean stopping = false; + private class CloseAndFlushShieldOutputStream extends CloseShieldOutputStream { + + CloseAndFlushShieldOutputStream(OutputStream outputStream) { + super(outputStream); + } + + @Override + public void flush() { + // do nothing + } + } + public UnsafeShuffleWriter( BlockManager blockManager, IndexShuffleBlockResolver shuffleBlockResolver, @@ -321,11 +334,15 @@ private long[] mergeSpills(SpillInfo[] spills, File outputFile) throws IOExcepti } /** - * Merges spill files using Java FileStreams. This code path is slower than the NIO-based merge, - * {@link UnsafeShuffleWriter#mergeSpillsWithTransferTo(SpillInfo[], File)}, so it's only used in - * cases where the IO compression codec does not support concatenation of compressed data, when - * encryption is enabled, or when users have explicitly disabled use of {@code transferTo} in - * order to work around kernel bugs. + * Merges spill files using Java FileStreams. This code path is typically slower than + * the NIO-based merge, {@link UnsafeShuffleWriter#mergeSpillsWithTransferTo(SpillInfo[], + * File)}, and it's mostly used in cases where the IO compression codec does not support + * concatenation of compressed data, when encryption is enabled, or when users have + * explicitly disabled use of {@code transferTo} in order to work around kernel bugs. + * This code path might also be faster in cases where individual partition size in a spill + * is small and UnsafeShuffleWriter#mergeSpillsWithTransferTo method performs many small + * disk ios which is inefficient. In those case, Using large buffers for input and output + * files helps reducing the number of disk ios, making the file merging faster. * * @param spills the spills to merge. * @param outputFile the file to write the merged data to. @@ -339,23 +356,28 @@ private long[] mergeSpillsWithFileStream( assert (spills.length >= 2); final int numPartitions = partitioner.numPartitions(); final long[] partitionLengths = new long[numPartitions]; - final InputStream[] spillInputStreams = new FileInputStream[spills.length]; + final InputStream[] spillInputStreams = new InputStream[spills.length]; + final OutputStream bos = new BufferedOutputStream( + new FileOutputStream(outputFile), + (int) sparkConf.getSizeAsKb("spark.shuffle.unsafe.file.output.buffer", "32k") * 1024); // Use a counting output stream to avoid having to close the underlying file and ask // the file system for its size after each partition is written. - final CountingOutputStream mergedFileOutputStream = new CountingOutputStream( - new FileOutputStream(outputFile)); + final CountingOutputStream mergedFileOutputStream = new CountingOutputStream(bos); + final int inputBufferSizeInBytes = (int) sparkConf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024; boolean threwException = true; try { for (int i = 0; i < spills.length; i++) { - spillInputStreams[i] = new FileInputStream(spills[i].file); + spillInputStreams[i] = new NioBufferedFileInputStream( + spills[i].file, + inputBufferSizeInBytes); } for (int partition = 0; partition < numPartitions; partition++) { final long initialFileLength = mergedFileOutputStream.getByteCount(); - // Shield the underlying output stream from close() calls, so that we can close the higher + // Shield the underlying output stream from close() and flush() calls, so that we can close the higher // level streams to make sure all data is really flushed and internal state is cleaned. - OutputStream partitionOutput = new CloseShieldOutputStream( + OutputStream partitionOutput = new CloseAndFlushShieldOutputStream( new TimeTrackingOutputStream(writeMetrics, mergedFileOutputStream)); partitionOutput = blockManager.serializerManager().wrapForEncryption(partitionOutput); if (compressionCodec != null) { From ae33abf71b353c638487948b775e966c7127cd46 Mon Sep 17 00:00:00 2001 From: zero323 Date: Fri, 26 May 2017 15:01:01 -0700 Subject: [PATCH 016/133] [SPARK-20694][DOCS][SQL] Document DataFrameWriter partitionBy, bucketBy and sortBy in SQL guide ## What changes were proposed in this pull request? - Add Scala, Python and Java examples for `partitionBy`, `sortBy` and `bucketBy`. - Add _Bucketing, Sorting and Partitioning_ section to SQL Programming Guide - Remove bucketing from Unsupported Hive Functionalities. ## How was this patch tested? Manual tests, docs build. Author: zero323 Closes #17938 from zero323/DOCS-BUCKETING-AND-PARTITIONING. --- docs/sql-programming-guide.md | 108 ++++++++++++++++++ .../sql/JavaSQLDataSourceExample.java | 16 +++ examples/src/main/python/sql/datasource.py | 20 ++++ .../examples/sql/SQLDataSourceExample.scala | 16 +++ 4 files changed, 160 insertions(+) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index adb12d2489a57..314ff6ef80d29 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -581,6 +581,114 @@ Starting from Spark 2.1, persistent datasource tables have per-partition metadat Note that partition information is not gathered by default when creating external datasource tables (those with a `path` option). To sync the partition information in the metastore, you can invoke `MSCK REPAIR TABLE`. +### Bucketing, Sorting and Partitioning + +For file-based data source, it is also possible to bucket and sort or partition the output. +Bucketing and sorting are applicable only to persistent tables: + +
+ +
+{% include_example write_sorting_and_bucketing scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala %} +
+ +
+{% include_example write_sorting_and_bucketing java/org/apache/spark/examples/sql/JavaSQLDataSourceExample.java %} +
+ +
+{% include_example write_sorting_and_bucketing python/sql/datasource.py %} +
+ +
+ +{% highlight sql %} + +CREATE TABLE users_bucketed_by_name( + name STRING, + favorite_color STRING, + favorite_numbers array +) USING parquet +CLUSTERED BY(name) INTO 42 BUCKETS; + +{% endhighlight %} + +
+ +
+ +while partitioning can be used with both `save` and `saveAsTable` when using the Dataset APIs. + + +
+ +
+{% include_example write_partitioning scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala %} +
+ +
+{% include_example write_partitioning java/org/apache/spark/examples/sql/JavaSQLDataSourceExample.java %} +
+ +
+{% include_example write_partitioning python/sql/datasource.py %} +
+ +
+ +{% highlight sql %} + +CREATE TABLE users_by_favorite_color( + name STRING, + favorite_color STRING, + favorite_numbers array +) USING csv PARTITIONED BY(favorite_color); + +{% endhighlight %} + +
+ +
+ +It is possible to use both partitioning and bucketing for a single table: + +
+ +
+{% include_example write_partition_and_bucket scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala %} +
+ +
+{% include_example write_partition_and_bucket java/org/apache/spark/examples/sql/JavaSQLDataSourceExample.java %} +
+ +
+{% include_example write_partition_and_bucket python/sql/datasource.py %} +
+ +
+ +{% highlight sql %} + +CREATE TABLE users_bucketed_and_partitioned( + name STRING, + favorite_color STRING, + favorite_numbers array +) USING parquet +PARTITIONED BY (favorite_color) +CLUSTERED BY(name) SORTED BY (favorite_numbers) INTO 42 BUCKETS; + +{% endhighlight %} + +
+ +
+ +`partitionBy` creates a directory structure as described in the [Partition Discovery](#partition-discovery) section. +Thus, it has limited applicability to columns with high cardinality. In contrast + `bucketBy` distributes +data across a fixed number of buckets and can be used when a number of unique values is unbounded. + ## Parquet Files [Parquet](http://parquet.io) is a columnar format that is supported by many other data processing systems. diff --git a/examples/src/main/java/org/apache/spark/examples/sql/JavaSQLDataSourceExample.java b/examples/src/main/java/org/apache/spark/examples/sql/JavaSQLDataSourceExample.java index b66abaed66000..706856b5215e4 100644 --- a/examples/src/main/java/org/apache/spark/examples/sql/JavaSQLDataSourceExample.java +++ b/examples/src/main/java/org/apache/spark/examples/sql/JavaSQLDataSourceExample.java @@ -120,6 +120,22 @@ private static void runBasicDataSourceExample(SparkSession spark) { Dataset sqlDF = spark.sql("SELECT * FROM parquet.`examples/src/main/resources/users.parquet`"); // $example off:direct_sql$ + // $example on:write_sorting_and_bucketing$ + peopleDF.write().bucketBy(42, "name").sortBy("age").saveAsTable("people_bucketed"); + // $example off:write_sorting_and_bucketing$ + // $example on:write_partitioning$ + usersDF.write().partitionBy("favorite_color").format("parquet").save("namesPartByColor.parquet"); + // $example off:write_partitioning$ + // $example on:write_partition_and_bucket$ + peopleDF + .write() + .partitionBy("favorite_color") + .bucketBy(42, "name") + .saveAsTable("people_partitioned_bucketed"); + // $example off:write_partition_and_bucket$ + + spark.sql("DROP TABLE IF EXISTS people_bucketed"); + spark.sql("DROP TABLE IF EXISTS people_partitioned_bucketed"); } private static void runBasicParquetExample(SparkSession spark) { diff --git a/examples/src/main/python/sql/datasource.py b/examples/src/main/python/sql/datasource.py index e4abb0933345d..8777cca66bfe9 100644 --- a/examples/src/main/python/sql/datasource.py +++ b/examples/src/main/python/sql/datasource.py @@ -35,15 +35,35 @@ def basic_datasource_example(spark): df.select("name", "favorite_color").write.save("namesAndFavColors.parquet") # $example off:generic_load_save_functions$ + # $example on:write_partitioning$ + df.write.partitionBy("favorite_color").format("parquet").save("namesPartByColor.parquet") + # $example off:write_partitioning$ + + # $example on:write_partition_and_bucket$ + df = spark.read.parquet("examples/src/main/resources/users.parquet") + (df + .write + .partitionBy("favorite_color") + .bucketBy(42, "name") + .saveAsTable("people_partitioned_bucketed")) + # $example off:write_partition_and_bucket$ + # $example on:manual_load_options$ df = spark.read.load("examples/src/main/resources/people.json", format="json") df.select("name", "age").write.save("namesAndAges.parquet", format="parquet") # $example off:manual_load_options$ + # $example on:write_sorting_and_bucketing$ + df.write.bucketBy(42, "name").sortBy("age").saveAsTable("people_bucketed") + # $example off:write_sorting_and_bucketing$ + # $example on:direct_sql$ df = spark.sql("SELECT * FROM parquet.`examples/src/main/resources/users.parquet`") # $example off:direct_sql$ + spark.sql("DROP TABLE IF EXISTS people_bucketed") + spark.sql("DROP TABLE IF EXISTS people_partitioned_bucketed") + def parquet_example(spark): # $example on:basic_parquet_example$ diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala b/examples/src/main/scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala index ad74da72bd5e6..6ff03bdb22129 100644 --- a/examples/src/main/scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala @@ -52,6 +52,22 @@ object SQLDataSourceExample { // $example on:direct_sql$ val sqlDF = spark.sql("SELECT * FROM parquet.`examples/src/main/resources/users.parquet`") // $example off:direct_sql$ + // $example on:write_sorting_and_bucketing$ + peopleDF.write.bucketBy(42, "name").sortBy("age").saveAsTable("people_bucketed") + // $example off:write_sorting_and_bucketing$ + // $example on:write_partitioning$ + usersDF.write.partitionBy("favorite_color").format("parquet").save("namesPartByColor.parquet") + // $example off:write_partitioning$ + // $example on:write_partition_and_bucket$ + peopleDF + .write + .partitionBy("favorite_color") + .bucketBy(42, "name") + .saveAsTable("people_partitioned_bucketed") + // $example off:write_partition_and_bucket$ + + spark.sql("DROP TABLE IF EXISTS people_bucketed") + spark.sql("DROP TABLE IF EXISTS people_partitioned_bucketed") } private def runBasicParquetExample(spark: SparkSession): Unit = { From c491e2ed90afc980d197d54cb965213bdd192b4b Mon Sep 17 00:00:00 2001 From: setjet Date: Fri, 26 May 2017 15:07:28 -0700 Subject: [PATCH 017/133] [SPARK-20873][SQL] Improve the error message for unsupported Column Type ## What changes were proposed in this pull request? Upon encountering an invalid columntype, the column type object is printed, rather than the type. This change improves this by outputting its name. ## How was this patch tested? Added a simple unit test to verify the contents of the raised exception Author: setjet Closes #18097 from setjet/spark-20873. --- .../spark/sql/execution/columnar/ColumnType.scala | 2 +- .../sql/execution/columnar/ColumnTypeSuite.scala | 14 ++++++++++++++ 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala index 703bde25316df..5cfb003e4f150 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala @@ -684,7 +684,7 @@ private[columnar] object ColumnType { case struct: StructType => STRUCT(struct) case udt: UserDefinedType[_] => apply(udt.sqlType) case other => - throw new Exception(s"Unsupported type: $other") + throw new Exception(s"Unsupported type: ${other.simpleString}") } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnTypeSuite.scala index 5f2a3aaff634c..ff05049551dc8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnTypeSuite.scala @@ -144,4 +144,18 @@ class ColumnTypeSuite extends SparkFunSuite with Logging { ColumnType(DecimalType(19, 0)) } } + + test("show type name in type mismatch error") { + val invalidType = new DataType { + override def defaultSize: Int = 1 + override private[spark] def asNullable: DataType = this + override def typeName: String = "invalid type name" + } + + val message = intercept[java.lang.Exception] { + ColumnType(invalidType) + }.getMessage + + assert(message.contains("Unsupported type: invalid type name")) + } } From 4af37812915763ac3bfd91a600a7f00a4b84d29a Mon Sep 17 00:00:00 2001 From: Yu Peng Date: Fri, 26 May 2017 16:28:36 -0700 Subject: [PATCH 018/133] [SPARK-10643][CORE] Make spark-submit download remote files to local in client mode ## What changes were proposed in this pull request? This PR makes spark-submit script download remote files to local file system for local/standalone client mode. ## How was this patch tested? - Unit tests - Manual tests by adding s3a jar and testing against file on s3. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Yu Peng Closes #18078 from loneknightpy/download-jar-in-spark-submit. --- .../org/apache/spark/deploy/SparkSubmit.scala | 48 +++++++++- .../spark/deploy/SparkSubmitSuite.scala | 95 ++++++++++++++++++- 2 files changed, 140 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 77005aa9040b5..c60a2a1706d5a 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -20,6 +20,7 @@ package org.apache.spark.deploy import java.io.{File, IOException} import java.lang.reflect.{InvocationTargetException, Modifier, UndeclaredThrowableException} import java.net.URL +import java.nio.file.Files import java.security.PrivilegedExceptionAction import java.text.ParseException @@ -28,7 +29,8 @@ import scala.collection.mutable.{ArrayBuffer, HashMap, Map} import scala.util.Properties import org.apache.commons.lang3.StringUtils -import org.apache.hadoop.fs.Path +import org.apache.hadoop.conf.{Configuration => HadoopConfiguration} +import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.security.UserGroupInformation import org.apache.ivy.Ivy import org.apache.ivy.core.LogOptions @@ -308,6 +310,15 @@ object SparkSubmit extends CommandLineUtils { RPackageUtils.checkAndBuildRPackage(args.jars, printStream, args.verbose) } + // In client mode, download remote files. + if (deployMode == CLIENT) { + val hadoopConf = new HadoopConfiguration() + args.primaryResource = Option(args.primaryResource).map(downloadFile(_, hadoopConf)).orNull + args.jars = Option(args.jars).map(downloadFileList(_, hadoopConf)).orNull + args.pyFiles = Option(args.pyFiles).map(downloadFileList(_, hadoopConf)).orNull + args.files = Option(args.files).map(downloadFileList(_, hadoopConf)).orNull + } + // Require all python files to be local, so we can add them to the PYTHONPATH // In YARN cluster mode, python files are distributed as regular files, which can be non-local. // In Mesos cluster mode, non-local python files are automatically downloaded by Mesos. @@ -825,6 +836,41 @@ object SparkSubmit extends CommandLineUtils { .mkString(",") if (merged == "") null else merged } + + /** + * Download a list of remote files to temp local files. If the file is local, the original file + * will be returned. + * @param fileList A comma separated file list. + * @return A comma separated local files list. + */ + private[deploy] def downloadFileList( + fileList: String, + hadoopConf: HadoopConfiguration): String = { + require(fileList != null, "fileList cannot be null.") + fileList.split(",").map(downloadFile(_, hadoopConf)).mkString(",") + } + + /** + * Download a file from the remote to a local temporary directory. If the input path points to + * a local path, returns it with no operation. + */ + private[deploy] def downloadFile(path: String, hadoopConf: HadoopConfiguration): String = { + require(path != null, "path cannot be null.") + val uri = Utils.resolveURI(path) + uri.getScheme match { + case "file" | "local" => + path + + case _ => + val fs = FileSystem.get(uri, hadoopConf) + val tmpFile = new File(Files.createTempDirectory("tmp").toFile, uri.getPath) + // scalastyle:off println + printStream.println(s"Downloading ${uri.toString} to ${tmpFile.getAbsolutePath}.") + // scalastyle:on println + fs.copyToLocalFile(new Path(uri), new Path(tmpFile.getAbsolutePath)) + Utils.resolveURI(tmpFile.getAbsolutePath).toString + } + } } /** Provides utility functions to be used inside SparkSubmit. */ diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index a43839a8815f9..6e9721c45931a 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -18,12 +18,15 @@ package org.apache.spark.deploy import java.io._ +import java.net.URI import java.nio.charset.StandardCharsets import scala.collection.mutable.ArrayBuffer import scala.io.Source import com.google.common.io.ByteStreams +import org.apache.commons.io.{FilenameUtils, FileUtils} +import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.scalatest.{BeforeAndAfterEach, Matchers} import org.scalatest.concurrent.Timeouts @@ -535,7 +538,7 @@ class SparkSubmitSuite test("resolves command line argument paths correctly") { val jars = "/jar1,/jar2" // --jars - val files = "hdfs:/file1,file2" // --files + val files = "local:/file1,file2" // --files val archives = "file:/archive1,archive2" // --archives val pyFiles = "py-file1,py-file2" // --py-files @@ -587,7 +590,7 @@ class SparkSubmitSuite test("resolves config paths correctly") { val jars = "/jar1,/jar2" // spark.jars - val files = "hdfs:/file1,file2" // spark.files / spark.yarn.dist.files + val files = "local:/file1,file2" // spark.files / spark.yarn.dist.files val archives = "file:/archive1,archive2" // spark.yarn.dist.archives val pyFiles = "py-file1,py-file2" // spark.submit.pyFiles @@ -705,6 +708,87 @@ class SparkSubmitSuite } // scalastyle:on println + private def checkDownloadedFile(sourcePath: String, outputPath: String): Unit = { + if (sourcePath == outputPath) { + return + } + + val sourceUri = new URI(sourcePath) + val outputUri = new URI(outputPath) + assert(outputUri.getScheme === "file") + + // The path and filename are preserved. + assert(outputUri.getPath.endsWith(sourceUri.getPath)) + assert(FileUtils.readFileToString(new File(outputUri.getPath)) === + FileUtils.readFileToString(new File(sourceUri.getPath))) + } + + private def deleteTempOutputFile(outputPath: String): Unit = { + val outputFile = new File(new URI(outputPath).getPath) + if (outputFile.exists) { + outputFile.delete() + } + } + + test("downloadFile - invalid url") { + intercept[IOException] { + SparkSubmit.downloadFile("abc:/my/file", new Configuration()) + } + } + + test("downloadFile - file doesn't exist") { + val hadoopConf = new Configuration() + // Set s3a implementation to local file system for testing. + hadoopConf.set("fs.s3a.impl", "org.apache.spark.deploy.TestFileSystem") + // Disable file system impl cache to make sure the test file system is picked up. + hadoopConf.set("fs.s3a.impl.disable.cache", "true") + intercept[FileNotFoundException] { + SparkSubmit.downloadFile("s3a:/no/such/file", hadoopConf) + } + } + + test("downloadFile does not download local file") { + // empty path is considered as local file. + assert(SparkSubmit.downloadFile("", new Configuration()) === "") + assert(SparkSubmit.downloadFile("/local/file", new Configuration()) === "/local/file") + } + + test("download one file to local") { + val jarFile = File.createTempFile("test", ".jar") + jarFile.deleteOnExit() + val content = "hello, world" + FileUtils.write(jarFile, content) + val hadoopConf = new Configuration() + // Set s3a implementation to local file system for testing. + hadoopConf.set("fs.s3a.impl", "org.apache.spark.deploy.TestFileSystem") + // Disable file system impl cache to make sure the test file system is picked up. + hadoopConf.set("fs.s3a.impl.disable.cache", "true") + val sourcePath = s"s3a://${jarFile.getAbsolutePath}" + val outputPath = SparkSubmit.downloadFile(sourcePath, hadoopConf) + checkDownloadedFile(sourcePath, outputPath) + deleteTempOutputFile(outputPath) + } + + test("download list of files to local") { + val jarFile = File.createTempFile("test", ".jar") + jarFile.deleteOnExit() + val content = "hello, world" + FileUtils.write(jarFile, content) + val hadoopConf = new Configuration() + // Set s3a implementation to local file system for testing. + hadoopConf.set("fs.s3a.impl", "org.apache.spark.deploy.TestFileSystem") + // Disable file system impl cache to make sure the test file system is picked up. + hadoopConf.set("fs.s3a.impl.disable.cache", "true") + val sourcePaths = Seq("/local/file", s"s3a://${jarFile.getAbsolutePath}") + val outputPaths = SparkSubmit.downloadFileList(sourcePaths.mkString(","), hadoopConf).split(",") + + assert(outputPaths.length === sourcePaths.length) + sourcePaths.zip(outputPaths).foreach { case (sourcePath, outputPath) => + checkDownloadedFile(sourcePath, outputPath) + deleteTempOutputFile(outputPath) + } + } + // NOTE: This is an expensive operation in terms of time (10 seconds+). Use sparingly. private def runSparkSubmit(args: Seq[String]): Unit = { val sparkHome = sys.props.getOrElse("spark.test.home", fail("spark.test.home is not set!")) @@ -807,3 +891,10 @@ object UserClasspathFirstTest { } } } + +class TestFileSystem extends org.apache.hadoop.fs.LocalFileSystem { + override def copyToLocalFile(src: Path, dst: Path): Unit = { + // Ignore the scheme for testing. + super.copyToLocalFile(new Path(src.toUri.getPath), dst) + } +} From 1d62f8aca82601506c44b6fd852f4faf3602d7e2 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Sat, 27 May 2017 10:57:43 +0800 Subject: [PATCH 019/133] [SPARK-19659][CORE][FOLLOW-UP] Fetch big blocks to disk when shuffle-read ## What changes were proposed in this pull request? This PR includes some minor improvement for the comments and tests in https://github.com/apache/spark/pull/16989 ## How was this patch tested? N/A Author: Wenchen Fan Closes #18117 from cloud-fan/follow. --- .../storage/ShuffleBlockFetcherIterator.scala | 9 ++-- .../ShuffleBlockFetcherIteratorSuite.scala | 50 ++++++++++--------- 2 files changed, 31 insertions(+), 28 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index ee35060926555..bded3a1e4eb54 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -214,11 +214,12 @@ final class ShuffleBlockFetcherIterator( } } - // Shuffle remote blocks to disk when the request is too large. - // TODO: Encryption and compression should be considered. + // Fetch remote shuffle blocks to disk when the request is too large. Since the shuffle data is + // already encrypted and compressed over the wire(w.r.t. the related configs), we can just fetch + // the data and write it to file directly. if (req.size > maxReqSizeShuffleToMem) { - val shuffleFiles = blockIds.map { - bId => blockManager.diskBlockManager.createTempLocalBlock()._2 + val shuffleFiles = blockIds.map { _ => + blockManager.diskBlockManager.createTempLocalBlock()._2 }.toArray shuffleFilesSet ++= shuffleFiles shuffleClient.fetchBlocks(address.host, address.port, address.executorId, blockIds.toArray, diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index 1f813a909fb8b..559b3faab8fd2 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -36,6 +36,7 @@ import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} import org.apache.spark.network.shuffle.BlockFetchingListener import org.apache.spark.network.util.LimitedInputStream import org.apache.spark.shuffle.FetchFailedException +import org.apache.spark.util.Utils class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodTester { @@ -420,9 +421,10 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT doReturn(localBmId).when(blockManager).blockManagerId val diskBlockManager = mock(classOf[DiskBlockManager]) + val tmpDir = Utils.createTempDir() doReturn{ - var blockId = new TempLocalBlockId(UUID.randomUUID()) - (blockId, new File(blockId.name)) + val blockId = TempLocalBlockId(UUID.randomUUID()) + (blockId, new File(tmpDir, blockId.name)) }.when(diskBlockManager).createTempLocalBlock() doReturn(diskBlockManager).when(blockManager).diskBlockManager @@ -443,34 +445,34 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT } }) + def fetchShuffleBlock(blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])]): Unit = { + // Set `maxBytesInFlight` and `maxReqsInFlight` to `Int.MaxValue`, so that during the + // construction of `ShuffleBlockFetcherIterator`, all requests to fetch remote shuffle blocks + // are issued. The `maxReqSizeShuffleToMem` is hard-coded as 200 here. + new ShuffleBlockFetcherIterator( + TaskContext.empty(), + transfer, + blockManager, + blocksByAddress, + (_, in) => in, + maxBytesInFlight = Int.MaxValue, + maxReqsInFlight = Int.MaxValue, + maxReqSizeShuffleToMem = 200, + detectCorrupt = true) + } + val blocksByAddress1 = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( (remoteBmId, remoteBlocks.keys.map(blockId => (blockId, 100L)).toSeq)) - // Set maxReqSizeShuffleToMem to be 200. - val iterator1 = new ShuffleBlockFetcherIterator( - TaskContext.empty(), - transfer, - blockManager, - blocksByAddress1, - (_, in) => in, - Int.MaxValue, - Int.MaxValue, - 200, - true) + fetchShuffleBlock(blocksByAddress1) + // `maxReqSizeShuffleToMem` is 200, which is greater than the block size 100, so don't fetch + // shuffle block to disk. assert(shuffleFiles === null) val blocksByAddress2 = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( (remoteBmId, remoteBlocks.keys.map(blockId => (blockId, 300L)).toSeq)) - // Set maxReqSizeShuffleToMem to be 200. - val iterator2 = new ShuffleBlockFetcherIterator( - TaskContext.empty(), - transfer, - blockManager, - blocksByAddress2, - (_, in) => in, - Int.MaxValue, - Int.MaxValue, - 200, - true) + fetchShuffleBlock(blocksByAddress2) + // `maxReqSizeShuffleToMem` is 200, which is smaller than the block size 300, so fetch + // shuffle block to disk. assert(shuffleFiles != null) } } From a0f8a072e33842f19e53fd28d7578444d1d26cb3 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Fri, 26 May 2017 20:59:14 -0700 Subject: [PATCH 020/133] [SPARK-20748][SQL] Add built-in SQL function CH[A]R. ## What changes were proposed in this pull request? Add built-in SQL function `CH[A]R`: For `CHR(bigint|double n)`, returns the ASCII character having the binary equivalent to `n`. If n is larger than 256 the result is equivalent to CHR(n % 256) ## How was this patch tested? unit tests Author: Yuming Wang Closes #18019 from wangyum/SPARK-20748. --- .../catalyst/analysis/FunctionRegistry.scala | 2 + .../expressions/stringExpressions.scala | 45 +++++++++++++++++++ .../expressions/StringExpressionsSuite.scala | 13 ++++++ 3 files changed, 60 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 7521a7e12432c..a4c7f7a8de223 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -276,6 +276,8 @@ object FunctionRegistry { // string functions expression[Ascii]("ascii"), + expression[Chr]("char"), + expression[Chr]("chr"), expression[Base64]("base64"), expression[Concat]("concat"), expression[ConcatWs]("concat_ws"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 5598a146997ca..aba2f5f81f831 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -1267,6 +1267,51 @@ case class Ascii(child: Expression) extends UnaryExpression with ImplicitCastInp } } +/** + * Returns the ASCII character having the binary equivalent to n. + * If n is larger than 256 the result is equivalent to chr(n % 256) + */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(expr) - Returns the ASCII character having the binary equivalent to `expr`. If n is larger than 256 the result is equivalent to chr(n % 256)", + extended = """ + Examples: + > SELECT _FUNC_(65); + A + """) +// scalastyle:on line.size.limit +case class Chr(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { + + override def dataType: DataType = StringType + override def inputTypes: Seq[DataType] = Seq(LongType) + + protected override def nullSafeEval(lon: Any): Any = { + val longVal = lon.asInstanceOf[Long] + if (longVal < 0) { + UTF8String.EMPTY_UTF8 + } else if ((longVal & 0xFF) == 0) { + UTF8String.fromString(Character.MIN_VALUE.toString) + } else { + UTF8String.fromString((longVal & 0xFF).toChar.toString) + } + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen(ctx, ev, lon => { + s""" + if ($lon < 0) { + ${ev.value} = UTF8String.EMPTY_UTF8; + } else if (($lon & 0xFF) == 0) { + ${ev.value} = UTF8String.fromString(String.valueOf(Character.MIN_VALUE)); + } else { + char c = (char)($lon & 0xFF); + ${ev.value} = UTF8String.fromString(String.valueOf(c)); + } + """ + }) + } +} + /** * Converts the argument from binary to a base 64 string. */ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala index 26978a0482fc7..9ae438d568a90 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala @@ -263,6 +263,19 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Ascii(Literal.create(null, StringType)), null, create_row("abdef")) } + test("string for ascii") { + val a = 'a.long.at(0) + checkEvaluation(Chr(Literal(48L)), "0", create_row("abdef")) + checkEvaluation(Chr(a), "a", create_row(97L)) + checkEvaluation(Chr(a), "a", create_row(97L + 256L)) + checkEvaluation(Chr(a), "", create_row(-9L)) + checkEvaluation(Chr(a), Character.MIN_VALUE.toString, create_row(0L)) + checkEvaluation(Chr(a), Character.MIN_VALUE.toString, create_row(256L)) + checkEvaluation(Chr(a), null, create_row(null)) + checkEvaluation(Chr(a), 149.toChar.toString, create_row(149L)) + checkEvaluation(Chr(Literal.create(null, LongType)), null, create_row("abdef")) + } + test("base64/unbase64 for string") { val a = 'a.string.at(0) val b = 'b.binary.at(0) From 6c1dbd6fc8d49acf7c1c902d2ebf89ed5e788a4e Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Fri, 26 May 2017 22:25:38 -0700 Subject: [PATCH 021/133] [SPARK-20843][CORE] Add a config to set driver terminate timeout ## What changes were proposed in this pull request? Add a `worker` configuration to set how long to wait before forcibly killing driver. ## How was this patch tested? Jenkins Author: Shixiong Zhu Closes #18126 from zsxwing/SPARK-20843. --- .../scala/org/apache/spark/deploy/worker/DriverRunner.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala index e878c10183f61..58a181128eb4d 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala @@ -57,7 +57,8 @@ private[deploy] class DriverRunner( @volatile private[worker] var finalException: Option[Exception] = None // Timeout to wait for when trying to terminate a driver. - private val DRIVER_TERMINATE_TIMEOUT_MS = 10 * 1000 + private val DRIVER_TERMINATE_TIMEOUT_MS = + conf.getTimeAsMs("spark.worker.driverTerminateTimeout", "10s") // Decoupled for testing def setClock(_clock: Clock): Unit = { From 8faffc41679cf545c0aea96b05d84f23da1b5eda Mon Sep 17 00:00:00 2001 From: liuzhaokun Date: Sat, 27 May 2017 13:26:01 +0100 Subject: [PATCH 022/133] [SPARK-20875] Spark should print the log when the directory has been deleted [https://issues.apache.org/jira/browse/SPARK-20875](https://issues.apache.org/jira/browse/SPARK-20875) When the "deleteRecursively" method is invoked,spark doesn't print any log if the path was deleted.For example,spark only print "Removing directory" when the worker began cleaning spark.work.dir,but didn't print any log about "the path has been delete".So, I can't judge whether the path was deleted form the worker's logfile,If there is any accidents about Linux. Author: liuzhaokun Closes #18102 from liu-zhaokun/master_log. --- core/src/main/scala/org/apache/spark/util/Utils.scala | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index ad39c74a0e232..bbb7999e2a144 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -1026,7 +1026,9 @@ private[spark] object Utils extends Logging { ShutdownHookManager.removeShutdownDeleteDir(file) } } finally { - if (!file.delete()) { + if (file.delete()) { + logTrace(s"${file.getAbsolutePath} has been deleted") + } else { // Delete can also fail if the file simply did not exist if (file.exists()) { throw new IOException("Failed to delete: " + file.getAbsolutePath) From 08ede46b897b7e52cfe8231ffc21d9515122cf49 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Sat, 27 May 2017 16:16:51 -0700 Subject: [PATCH 023/133] [SPARK-20897][SQL] cached self-join should not fail ## What changes were proposed in this pull request? The failed test case is, we have a `SortMergeJoinExec` for a self-join, which means we have a `ReusedExchange` node in the query plan. It works fine without caching, but throws an exception in `SortMergeJoinExec.outputPartitioning` if we cache it. The root cause is, `ReusedExchange` doesn't propagate the output partitioning from its child, so in `SortMergeJoinExec.outputPartitioning` we create `PartitioningCollection` with a hash partitioning and an unknown partitioning, and fail. This bug is mostly fine, because inserting the `ReusedExchange` is the last step to prepare the physical plan, we won't call `SortMergeJoinExec.outputPartitioning` anymore after this. However, if the dataframe is cached, the physical plan of it becomes `InMemoryTableScanExec`, which contains another physical plan representing the cached query, and it has gone through the entire planning phase and may have `ReusedExchange`. Then the planner call `InMemoryTableScanExec.outputPartitioning`, which then calls `SortMergeJoinExec.outputPartitioning` and trigger this bug. ## How was this patch tested? a new regression test Author: Wenchen Fan Closes #18121 from cloud-fan/bug. --- .../sql/execution/exchange/Exchange.scala | 21 ++++++++++++++++++- .../org/apache/spark/sql/DataFrameSuite.scala | 10 +++++++++ 2 files changed, 30 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala index d993ea6c6cef9..4b52f3e4c49b0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala @@ -23,7 +23,8 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.broadcast import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, Expression, SortOrder} +import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.{LeafExecNode, SparkPlan, UnaryExecNode} import org.apache.spark.sql.internal.SQLConf @@ -58,6 +59,24 @@ case class ReusedExchangeExec(override val output: Seq[Attribute], child: Exchan override protected[sql] def doExecuteBroadcast[T](): broadcast.Broadcast[T] = { child.executeBroadcast() } + + // `ReusedExchangeExec` can have distinct set of output attribute ids from its child, we need + // to update the attribute ids in `outputPartitioning` and `outputOrdering`. + private lazy val updateAttr: Expression => Expression = { + val originalAttrToNewAttr = AttributeMap(child.output.zip(output)) + e => e.transform { + case attr: Attribute => originalAttrToNewAttr.getOrElse(attr, attr) + } + } + + override def outputPartitioning: Partitioning = child.outputPartitioning match { + case h: HashPartitioning => h.copy(expressions = h.expressions.map(updateAttr)) + case other => other + } + + override def outputOrdering: Seq[SortOrder] = { + child.outputOrdering.map(updateAttr(_).asInstanceOf[SortOrder]) + } } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 2f52192b54030..9f691cb10f139 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -1855,4 +1855,14 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { .foldLeft(lit(false))((e, index) => e.or(df.col(df.columns(index)) =!= "string")) df.filter(filter).count } + + test("SPARK-20897: cached self-join should not fail") { + // force to plan sort merge join + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "0") { + val df = Seq(1 -> "a").toDF("i", "j") + val df1 = df.as("t1") + val df2 = df.as("t2") + assert(df1.join(df2, $"t1.i" === $"t2.i").cache().count() == 1) + } + } } From 3969a8078eef63c37a5ba52f9eb4b4666b67d78d Mon Sep 17 00:00:00 2001 From: liuxian Date: Sat, 27 May 2017 16:23:45 -0700 Subject: [PATCH 024/133] [SPARK-20876][SQL] If the input parameter is float type for ceil or floor,the result is not we expected ## What changes were proposed in this pull request? spark-sql>SELECT ceil(cast(12345.1233 as float)); spark-sql>12345 For this case, the result we expected is `12346` spark-sql>SELECT floor(cast(-12345.1233 as float)); spark-sql>-12345 For this case, the result we expected is `-12346` Because in `Ceil` or `Floor`, `inputTypes` has no FloatType, so it is converted to LongType. ## How was this patch tested? After the modification: spark-sql>SELECT ceil(cast(12345.1233 as float)); spark-sql>12346 spark-sql>SELECT floor(cast(-12345.1233 as float)); spark-sql>-12346 Author: liuxian Closes #18103 from 10110346/wip-lx-0525-1. --- .../expressions/mathExpressions.scala | 14 ++--- .../sql/catalyst/analysis/AnalysisSuite.scala | 2 +- .../expressions/MathExpressionsSuite.scala | 20 +++++++ .../resources/sql-tests/inputs/operators.sql | 3 -- .../sql-tests/results/operators.sql.out | 52 +++++-------------- 5 files changed, 43 insertions(+), 48 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala index 754b5c4f74e6a..7b64568c69659 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala @@ -232,19 +232,20 @@ case class Ceil(child: Expression) extends UnaryMathExpression(math.ceil, "CEIL" } override def inputTypes: Seq[AbstractDataType] = - Seq(TypeCollection(LongType, DoubleType, DecimalType)) + Seq(TypeCollection(DoubleType, DecimalType, LongType)) protected override def nullSafeEval(input: Any): Any = child.dataType match { case LongType => input.asInstanceOf[Long] case DoubleType => f(input.asInstanceOf[Double]).toLong - case DecimalType.Fixed(precision, scale) => input.asInstanceOf[Decimal].ceil + case DecimalType.Fixed(_, _) => input.asInstanceOf[Decimal].ceil } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { child.dataType match { case DecimalType.Fixed(_, 0) => defineCodeGen(ctx, ev, c => s"$c") - case DecimalType.Fixed(precision, scale) => + case DecimalType.Fixed(_, _) => defineCodeGen(ctx, ev, c => s"$c.ceil()") + case LongType => defineCodeGen(ctx, ev, c => s"$c") case _ => defineCodeGen(ctx, ev, c => s"(long)(java.lang.Math.${funcName}($c))") } } @@ -348,19 +349,20 @@ case class Floor(child: Expression) extends UnaryMathExpression(math.floor, "FLO } override def inputTypes: Seq[AbstractDataType] = - Seq(TypeCollection(LongType, DoubleType, DecimalType)) + Seq(TypeCollection(DoubleType, DecimalType, LongType)) protected override def nullSafeEval(input: Any): Any = child.dataType match { case LongType => input.asInstanceOf[Long] case DoubleType => f(input.asInstanceOf[Double]).toLong - case DecimalType.Fixed(precision, scale) => input.asInstanceOf[Decimal].floor + case DecimalType.Fixed(_, _) => input.asInstanceOf[Decimal].floor } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { child.dataType match { case DecimalType.Fixed(_, 0) => defineCodeGen(ctx, ev, c => s"$c") - case DecimalType.Fixed(precision, scale) => + case DecimalType.Fixed(_, _) => defineCodeGen(ctx, ev, c => s"$c.floor()") + case LongType => defineCodeGen(ctx, ev, c => s"$c") case _ => defineCodeGen(ctx, ev, c => s"(long)(java.lang.Math.${funcName}($c))") } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index 3b4289767ad0c..7eccca2e85649 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -262,7 +262,7 @@ class AnalysisSuite extends AnalysisTest with ShouldMatchers { val plan = testRelation2.select('c).orderBy(Floor('a).asc) val expected = testRelation2.select(c, a) - .orderBy(Floor(Cast(a, LongType, Option(TimeZone.getDefault().getID))).asc).select(c) + .orderBy(Floor(Cast(a, DoubleType, Option(TimeZone.getDefault().getID))).asc).select(c) checkAnalysis(plan, expected) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala index 8ed7a82b943b6..6af0cde73538b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala @@ -258,6 +258,16 @@ class MathExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkConsistencyBetweenInterpretedAndCodegen(Ceil, DecimalType(25, 3)) checkConsistencyBetweenInterpretedAndCodegen(Ceil, DecimalType(25, 0)) checkConsistencyBetweenInterpretedAndCodegen(Ceil, DecimalType(5, 0)) + + val doublePi: Double = 3.1415 + val floatPi: Float = 3.1415f + val longLit: Long = 12345678901234567L + checkEvaluation(Ceil(doublePi), 4L, EmptyRow) + checkEvaluation(Ceil(floatPi.toDouble), 4L, EmptyRow) + checkEvaluation(Ceil(longLit), longLit, EmptyRow) + checkEvaluation(Ceil(-doublePi), -3L, EmptyRow) + checkEvaluation(Ceil(-floatPi.toDouble), -3L, EmptyRow) + checkEvaluation(Ceil(-longLit), -longLit, EmptyRow) } test("floor") { @@ -268,6 +278,16 @@ class MathExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkConsistencyBetweenInterpretedAndCodegen(Floor, DecimalType(25, 3)) checkConsistencyBetweenInterpretedAndCodegen(Floor, DecimalType(25, 0)) checkConsistencyBetweenInterpretedAndCodegen(Floor, DecimalType(5, 0)) + + val doublePi: Double = 3.1415 + val floatPi: Float = 3.1415f + val longLit: Long = 12345678901234567L + checkEvaluation(Floor(doublePi), 3L, EmptyRow) + checkEvaluation(Floor(floatPi.toDouble), 3L, EmptyRow) + checkEvaluation(Floor(longLit), longLit, EmptyRow) + checkEvaluation(Floor(-doublePi), -4L, EmptyRow) + checkEvaluation(Floor(-floatPi.toDouble), -4L, EmptyRow) + checkEvaluation(Floor(-longLit), -longLit, EmptyRow) } test("factorial") { diff --git a/sql/core/src/test/resources/sql-tests/inputs/operators.sql b/sql/core/src/test/resources/sql-tests/inputs/operators.sql index f7167472b05c6..7e3b86b76a34a 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/operators.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/operators.sql @@ -64,12 +64,9 @@ select cot(-1); select ceiling(0); select ceiling(1); select ceil(1234567890123456); -select ceil(12345678901234567); select ceiling(1234567890123456); -select ceiling(12345678901234567); -- floor select floor(0); select floor(1); select floor(1234567890123456); -select floor(12345678901234567); diff --git a/sql/core/src/test/resources/sql-tests/results/operators.sql.out b/sql/core/src/test/resources/sql-tests/results/operators.sql.out index fe52005aa91da..28cfb744193ec 100644 --- a/sql/core/src/test/resources/sql-tests/results/operators.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/operators.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 38 +-- Number of queries: 45 -- !query 0 @@ -321,7 +321,7 @@ struct -- !query 38 select ceiling(0) -- !query 38 schema -struct +struct -- !query 38 output 0 @@ -329,7 +329,7 @@ struct -- !query 39 select ceiling(1) -- !query 39 schema -struct +struct -- !query 39 output 1 @@ -343,56 +343,32 @@ struct -- !query 41 -select ceil(12345678901234567) +select ceiling(1234567890123456) -- !query 41 schema -struct +struct -- !query 41 output -12345678901234567 +1234567890123456 -- !query 42 -select ceiling(1234567890123456) +select floor(0) -- !query 42 schema -struct +struct -- !query 42 output -1234567890123456 +0 -- !query 43 -select ceiling(12345678901234567) +select floor(1) -- !query 43 schema -struct +struct -- !query 43 output -12345678901234567 - - --- !query 44 -select floor(0) --- !query 44 schema -struct --- !query 44 output -0 - - --- !query 45 -select floor(1) --- !query 45 schema -struct --- !query 45 output 1 --- !query 46 +-- !query 44 select floor(1234567890123456) --- !query 46 schema +-- !query 44 schema struct --- !query 46 output +-- !query 44 output 1234567890123456 - - --- !query 47 -select floor(12345678901234567) --- !query 47 schema -struct --- !query 47 output -12345678901234567 From 06c155c90dc784b07002f33d98dcfe9be1e38002 Mon Sep 17 00:00:00 2001 From: Xiao Li Date: Sat, 27 May 2017 21:32:18 -0700 Subject: [PATCH 025/133] [SPARK-20908][SQL] Cache Manager: Hint should be ignored in plan matching ### What changes were proposed in this pull request? In Cache manager, the plan matching should ignore Hint. ```Scala val df1 = spark.range(10).join(broadcast(spark.range(10))) df1.cache() spark.range(10).join(spark.range(10)).explain() ``` The output plan of the above query shows that the second query is not using the cached data of the first query. ``` BroadcastNestedLoopJoin BuildRight, Inner :- *Range (0, 10, step=1, splits=2) +- BroadcastExchange IdentityBroadcastMode +- *Range (0, 10, step=1, splits=2) ``` After the fix, the plan becomes ``` InMemoryTableScan [id#20L, id#23L] +- InMemoryRelation [id#20L, id#23L], true, 10000, StorageLevel(disk, memory, deserialized, 1 replicas) +- BroadcastNestedLoopJoin BuildRight, Inner :- *Range (0, 10, step=1, splits=2) +- BroadcastExchange IdentityBroadcastMode +- *Range (0, 10, step=1, splits=2) ``` ### How was this patch tested? Added a test. Author: Xiao Li Closes #18131 from gatorsmile/HintCache. --- .../apache/spark/sql/catalyst/plans/logical/hints.scala | 2 ++ .../apache/spark/sql/catalyst/plans/SameResultSuite.scala | 8 +++++++- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala index b96d7bc9cfdb6..5fe6d2d8da064 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala @@ -40,6 +40,8 @@ case class ResolvedHint(child: LogicalPlan, hints: HintInfo = HintInfo()) override def output: Seq[Attribute] = child.output + override lazy val canonicalized: LogicalPlan = child.canonicalized + override def computeStats(conf: SQLConf): Statistics = { val stats = child.stats(conf) stats.copy(hints = hints) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/SameResultSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/SameResultSuite.scala index 467f76193cfc5..7c8ed78a49116 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/SameResultSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/SameResultSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.plans import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Union} +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, ResolvedHint, Union} import org.apache.spark.sql.catalyst.util._ /** @@ -66,4 +66,10 @@ class SameResultSuite extends SparkFunSuite { assertSameResult(Union(Seq(testRelation, testRelation2)), Union(Seq(testRelation2, testRelation))) } + + test("hint") { + val df1 = testRelation.join(ResolvedHint(testRelation)) + val df2 = testRelation.join(testRelation) + assertSameResult(df1, df2) + } } From 24d34281d78ff32bff584e9415ac592c0e7cdf2d Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Sun, 28 May 2017 13:23:18 -0700 Subject: [PATCH 026/133] [SPARK-20841][SQL] Support table column aliases in FROM clause ## What changes were proposed in this pull request? This pr added parsing rules to support table column aliases in FROM clause. ## How was this patch tested? Added tests in `PlanParserSuite`, `SQLQueryTestSuite`, and `PlanParserSuite`. Author: Takeshi Yamamuro Closes #18079 from maropu/SPARK-20841. --- .../spark/sql/catalyst/parser/SqlBase.g4 | 4 +- .../sql/catalyst/analysis/Analyzer.scala | 20 +++++- .../ResolveTableValuedFunctions.scala | 5 +- .../sql/catalyst/analysis/unresolved.scala | 20 +++++- .../sql/catalyst/parser/AstBuilder.scala | 16 +++-- .../sql/catalyst/analysis/AnalysisSuite.scala | 19 +++++- .../sql/catalyst/analysis/AnalysisTest.scala | 1 + .../sql/catalyst/parser/PlanParserSuite.scala | 11 +++- .../parser/TableIdentifierParserSuite.scala | 2 +- .../sql-tests/inputs/table-aliases.sql | 17 +++++ .../sql-tests/results/table-aliases.sql.out | 63 +++++++++++++++++++ .../benchmark/TPCDSQueryBenchmark.scala | 4 +- .../apache/spark/sql/hive/test/TestHive.scala | 2 +- 13 files changed, 165 insertions(+), 19 deletions(-) create mode 100644 sql/core/src/test/resources/sql-tests/inputs/table-aliases.sql create mode 100644 sql/core/src/test/resources/sql-tests/results/table-aliases.sql.out diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index dc11e536efc45..547013c23fd78 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -472,7 +472,7 @@ identifierComment ; relationPrimary - : tableIdentifier sample? (AS? strictIdentifier)? #tableName + : tableIdentifier sample? tableAlias #tableName | '(' queryNoWith ')' sample? (AS? strictIdentifier) #aliasedQuery | '(' relation ')' sample? (AS? strictIdentifier)? #aliasedRelation | inlineTable #inlineTableDefault2 @@ -711,7 +711,7 @@ nonReserved | ADD | OVER | PARTITION | RANGE | ROWS | PRECEDING | FOLLOWING | CURRENT | ROW | LAST | FIRST | AFTER | MAP | ARRAY | STRUCT - | LATERAL | WINDOW | REDUCE | TRANSFORM | USING | SERDE | SERDEPROPERTIES | RECORDREADER + | LATERAL | WINDOW | REDUCE | TRANSFORM | SERDE | SERDEPROPERTIES | RECORDREADER | DELIMITED | FIELDS | TERMINATED | COLLECTION | ITEMS | KEYS | ESCAPED | LINES | SEPARATED | EXTENDED | REFRESH | CLEAR | CACHE | UNCACHE | LAZY | GLOBAL | TEMPORARY | OPTIONS | GROUPING | CUBE | ROLLUP diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 85cf8ddbaacf4..8818404094eb1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -593,7 +593,25 @@ class Analyzer( def resolveRelation(plan: LogicalPlan): LogicalPlan = plan match { case u: UnresolvedRelation if !isRunningDirectlyOnFiles(u.tableIdentifier) => val defaultDatabase = AnalysisContext.get.defaultDatabase - val relation = lookupTableFromCatalog(u, defaultDatabase) + val foundRelation = lookupTableFromCatalog(u, defaultDatabase) + + // Add `Project` to rename output column names if a query has alias names: + // e.g., SELECT col1, col2 FROM testData AS t(col1, col2) + val relation = if (u.outputColumnNames.nonEmpty) { + val outputAttrs = foundRelation.output + // Checks if the number of the aliases equals to the number of columns in the table. + if (u.outputColumnNames.size != outputAttrs.size) { + u.failAnalysis(s"Number of column aliases does not match number of columns. " + + s"Table name: ${u.tableName}; number of column aliases: " + + s"${u.outputColumnNames.size}; number of columns: ${outputAttrs.size}.") + } + val aliases = outputAttrs.zip(u.outputColumnNames).map { + case (attr, name) => Alias(attr, name)() + } + Project(aliases, foundRelation) + } else { + foundRelation + } resolveRelation(relation) // The view's child should be a logical plan parsed from the `desc.viewText`, the variable // `viewText` should be defined, or else we throw an error on the generation of the View diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala index 40675359bec47..a214e59302cd9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala @@ -131,8 +131,9 @@ object ResolveTableValuedFunctions extends Rule[LogicalPlan] { val outputAttrs = resolvedFunc.output // Checks if the number of the aliases is equal to expected one if (u.outputNames.size != outputAttrs.size) { - u.failAnalysis(s"expected ${outputAttrs.size} columns but " + - s"found ${u.outputNames.size} columns") + u.failAnalysis(s"Number of given aliases does not match number of output columns. " + + s"Function name: ${u.functionName}; number of aliases: " + + s"${u.outputNames.size}; number of output columns: ${outputAttrs.size}.") } val aliases = outputAttrs.zip(u.outputNames).map { case (attr, name) => Alias(attr, name)() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index 51bef6e20b9fa..42b9641bef276 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -36,8 +36,21 @@ class UnresolvedException[TreeType <: TreeNode[_]](tree: TreeType, function: Str /** * Holds the name of a relation that has yet to be looked up in a catalog. + * We could add alias names for columns in a relation: + * {{{ + * // Assign alias names + * SELECT col1, col2 FROM testData AS t(col1, col2); + * }}} + * + * @param tableIdentifier table name + * @param outputColumnNames alias names of columns. If these names given, an analyzer adds + * [[Project]] to rename the columns. */ -case class UnresolvedRelation(tableIdentifier: TableIdentifier) extends LeafNode { +case class UnresolvedRelation( + tableIdentifier: TableIdentifier, + outputColumnNames: Seq[String] = Seq.empty) + extends LeafNode { + /** Returns a `.` separated name for this relation. */ def tableName: String = tableIdentifier.unquotedString @@ -71,6 +84,11 @@ case class UnresolvedInlineTable( * // Assign alias names * select t.a from range(10) t(a); * }}} + * + * @param functionName name of this table-value function + * @param functionArgs list of function arguments + * @param outputNames alias names of function output columns. If these names given, an analyzer + * adds [[Project]] to rename the output columns. */ case class UnresolvedTableValuedFunction( functionName: String, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 7d2e3a6fe7580..5f34d0777d5a1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -676,12 +676,16 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging * Create an aliased table reference. This is typically used in FROM clauses. */ override def visitTableName(ctx: TableNameContext): LogicalPlan = withOrigin(ctx) { - val table = UnresolvedRelation(visitTableIdentifier(ctx.tableIdentifier)) - - val tableWithAlias = Option(ctx.strictIdentifier).map(_.getText) match { - case Some(strictIdentifier) => - SubqueryAlias(strictIdentifier, table) - case _ => table + val tableId = visitTableIdentifier(ctx.tableIdentifier) + val table = if (ctx.tableAlias.identifierList != null) { + UnresolvedRelation(tableId, visitIdentifierList(ctx.tableAlias.identifierList)) + } else { + UnresolvedRelation(tableId) + } + val tableWithAlias = if (ctx.tableAlias.strictIdentifier != null) { + SubqueryAlias(ctx.tableAlias.strictIdentifier.getText, table) + } else { + table } tableWithAlias.optionalMap(ctx.sample)(withSample) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index 7eccca2e85649..5393786891e07 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -465,6 +465,23 @@ class AnalysisSuite extends AnalysisTest with ShouldMatchers { assertAnalysisSuccess(rangeWithAliases(2 :: 6 :: 2 :: Nil, "c" :: Nil)) assertAnalysisError( rangeWithAliases(3 :: Nil, "a" :: "b" :: Nil), - Seq("expected 1 columns but found 2 columns")) + Seq("Number of given aliases does not match number of output columns. " + + "Function name: range; number of aliases: 2; number of output columns: 1.")) + } + + test("SPARK-20841 Support table column aliases in FROM clause") { + def tableColumnsWithAliases(outputNames: Seq[String]): LogicalPlan = { + SubqueryAlias("t", UnresolvedRelation(TableIdentifier("TaBlE3"), outputNames)) + .select(star()) + } + assertAnalysisSuccess(tableColumnsWithAliases("col1" :: "col2" :: "col3" :: "col4" :: Nil)) + assertAnalysisError( + tableColumnsWithAliases("col1" :: Nil), + Seq("Number of column aliases does not match number of columns. Table name: TaBlE3; " + + "number of column aliases: 1; number of columns: 4.")) + assertAnalysisError( + tableColumnsWithAliases("col1" :: "col2" :: "col3" :: "col4" :: "col5" :: Nil), + Seq("Number of column aliases does not match number of columns. Table name: TaBlE3; " + + "number of column aliases: 5; number of columns: 4.")) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala index 82015b1e0671c..afc7ce4195a8b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala @@ -35,6 +35,7 @@ trait AnalysisTest extends PlanTest { val catalog = new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, conf) catalog.createTempView("TaBlE", TestRelations.testRelation, overrideIfExists = true) catalog.createTempView("TaBlE2", TestRelations.testRelation2, overrideIfExists = true) + catalog.createTempView("TaBlE3", TestRelations.testRelation3, overrideIfExists = true) new Analyzer(catalog, conf) { override val extendedResolutionRules = EliminateSubqueryAliases :: Nil } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala index 134e761460881..7a5357eef8f94 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala @@ -17,8 +17,8 @@ package org.apache.spark.sql.catalyst.parser -import org.apache.spark.sql.catalyst.FunctionIdentifier -import org.apache.spark.sql.catalyst.analysis.{UnresolvedGenerator, UnresolvedInlineTable, UnresolvedTableValuedFunction} +import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} +import org.apache.spark.sql.catalyst.analysis.{UnresolvedGenerator, UnresolvedInlineTable, UnresolvedRelation, UnresolvedTableValuedFunction} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ @@ -493,6 +493,13 @@ class PlanParserSuite extends PlanTest { .select(star())) } + test("SPARK-20841 Support table column aliases in FROM clause") { + assertEqual( + "SELECT * FROM testData AS t(col1, col2)", + SubqueryAlias("t", UnresolvedRelation(TableIdentifier("testData"), Seq("col1", "col2"))) + .select(star())) + } + test("inline table") { assertEqual("values 1, 2, 3, 4", UnresolvedInlineTable(Seq("col1"), Seq(1, 2, 3, 4).map(x => Seq(Literal(x))))) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala index 170c469197e73..f33abc5b2e049 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala @@ -49,7 +49,7 @@ class TableIdentifierParserSuite extends SparkFunSuite { "insert", "int", "into", "is", "lateral", "like", "local", "none", "null", "of", "order", "out", "outer", "partition", "percent", "procedure", "range", "reads", "revoke", "rollup", "row", "rows", "set", "smallint", "table", "timestamp", "to", "trigger", - "true", "truncate", "update", "user", "using", "values", "with", "regexp", "rlike", + "true", "truncate", "update", "user", "values", "with", "regexp", "rlike", "bigint", "binary", "boolean", "current_date", "current_timestamp", "date", "double", "float", "int", "smallint", "timestamp", "at") diff --git a/sql/core/src/test/resources/sql-tests/inputs/table-aliases.sql b/sql/core/src/test/resources/sql-tests/inputs/table-aliases.sql new file mode 100644 index 0000000000000..c90a9c7f85587 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/table-aliases.sql @@ -0,0 +1,17 @@ +-- Test data. +CREATE OR REPLACE TEMPORARY VIEW testData AS SELECT * FROM VALUES (1, 1), (1, 2), (2, 1) AS testData(a, b); + +-- Table column aliases in FROM clause +SELECT * FROM testData AS t(col1, col2) WHERE col1 = 1; + +SELECT * FROM testData AS t(col1, col2) WHERE col1 = 2; + +SELECT col1 AS k, SUM(col2) FROM testData AS t(col1, col2) GROUP BY k; + +-- Aliasing the wrong number of columns in the FROM clause +SELECT * FROM testData AS t(col1, col2, col3); + +SELECT * FROM testData AS t(col1); + +-- Check alias duplication +SELECT a AS col1, b AS col2 FROM testData AS t(c, d); diff --git a/sql/core/src/test/resources/sql-tests/results/table-aliases.sql.out b/sql/core/src/test/resources/sql-tests/results/table-aliases.sql.out new file mode 100644 index 0000000000000..c318018dced29 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/table-aliases.sql.out @@ -0,0 +1,63 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 7 + + +-- !query 0 +CREATE OR REPLACE TEMPORARY VIEW testData AS SELECT * FROM VALUES (1, 1), (1, 2), (2, 1) AS testData(a, b) +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +SELECT * FROM testData AS t(col1, col2) WHERE col1 = 1 +-- !query 1 schema +struct +-- !query 1 output +1 1 +1 2 + + +-- !query 2 +SELECT * FROM testData AS t(col1, col2) WHERE col1 = 2 +-- !query 2 schema +struct +-- !query 2 output +2 1 + + +-- !query 3 +SELECT col1 AS k, SUM(col2) FROM testData AS t(col1, col2) GROUP BY k +-- !query 3 schema +struct +-- !query 3 output +1 3 +2 1 + + +-- !query 4 +SELECT * FROM testData AS t(col1, col2, col3) +-- !query 4 schema +struct<> +-- !query 4 output +org.apache.spark.sql.AnalysisException +Number of column aliases does not match number of columns. Table name: testData; number of column aliases: 3; number of columns: 2.; line 1 pos 14 + + +-- !query 5 +SELECT * FROM testData AS t(col1) +-- !query 5 schema +struct<> +-- !query 5 output +org.apache.spark.sql.AnalysisException +Number of column aliases does not match number of columns. Table name: testData; number of column aliases: 1; number of columns: 2.; line 1 pos 14 + + +-- !query 6 +SELECT a AS col1, b AS col2 FROM testData AS t(c, d) +-- !query 6 schema +struct<> +-- !query 6 output +org.apache.spark.sql.AnalysisException +cannot resolve '`a`' given input columns: [c, d]; line 1 pos 7 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/TPCDSQueryBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/TPCDSQueryBenchmark.scala index a6249ce021400..6a5b74b01df80 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/TPCDSQueryBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/TPCDSQueryBenchmark.scala @@ -74,13 +74,13 @@ object TPCDSQueryBenchmark { // per-row processing time for those cases. val queryRelations = scala.collection.mutable.HashSet[String]() spark.sql(queryString).queryExecution.logical.map { - case ur @ UnresolvedRelation(t: TableIdentifier) => + case UnresolvedRelation(t: TableIdentifier, _) => queryRelations.add(t.table) case lp: LogicalPlan => lp.expressions.foreach { _ foreach { case subquery: SubqueryExpression => subquery.plan.foreach { - case ur @ UnresolvedRelation(t: TableIdentifier) => + case UnresolvedRelation(t: TableIdentifier, _) => queryRelations.add(t.table) case _ => } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index ee9ac21a738dc..e1534c797d55b 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -544,7 +544,7 @@ private[hive] class TestHiveQueryExecution( // Make sure any test tables referenced are loaded. val referencedTables = describedTables ++ - logical.collect { case UnresolvedRelation(tableIdent) => tableIdent.table } + logical.collect { case UnresolvedRelation(tableIdent, _) => tableIdent.table } val referencedTestTables = referencedTables.filter(sparkSession.testTables.contains) logDebug(s"Query references test tables: ${referencedTestTables.mkString(", ")}") referencedTestTables.foreach(sparkSession.loadTestTable) From 9d0db5a7f8f0168c3cd95994b10511f8d6a241c3 Mon Sep 17 00:00:00 2001 From: Zhenhua Wang Date: Sun, 28 May 2017 13:32:45 -0700 Subject: [PATCH 027/133] [SPARK-20881][SQL] Clearly document the mechanism to choose between two sources of statistics ## What changes were proposed in this pull request? Now, we have two sources of statistics, i.e. Spark's stats and Hive's stats. Spark's stats is generated by running "analyze" command in Spark. Once it's available, we respect this stats over Hive's. This pr is to clearly document in related code the mechanism to choose between these two sources of stats. ## How was this patch tested? Not related. Author: Zhenhua Wang Closes #18105 from wzhfy/cboSwitchStats. --- .../scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala | 4 +++- .../org/apache/spark/sql/hive/client/HiveClientImpl.scala | 2 ++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala index ba48facff2933..918459fe7c246 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala @@ -681,9 +681,11 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat } } - // construct Spark's statistics from information in Hive metastore + // Restore Spark's statistics from information in Metastore. val statsProps = table.properties.filterKeys(_.startsWith(STATISTICS_PREFIX)) + // Currently we have two sources of statistics: one from Hive and the other from Spark. + // In our design, if Spark's statistics is available, we respect it over Hive's statistics. if (statsProps.nonEmpty) { val colStats = new mutable.HashMap[String, ColumnStat] diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala index b970be740ab51..be024adac8eb0 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala @@ -434,6 +434,8 @@ private[hive] class HiveClientImpl( } val comment = properties.get("comment") + // Here we are reading statistics from Hive. + // Note that this statistics could be overridden by Spark's statistics if that's available. val totalSize = properties.get(StatsSetupConst.TOTAL_SIZE).map(BigInt(_)) val rawDataSize = properties.get(StatsSetupConst.RAW_DATA_SIZE).map(BigInt(_)) val rowCount = properties.get(StatsSetupConst.ROW_COUNT).map(BigInt(_)).filter(_ >= 0) From f9b59abeae16088c7c4d3a475762ef6c4ad42b4b Mon Sep 17 00:00:00 2001 From: Tejas Patil Date: Mon, 29 May 2017 12:21:34 +0200 Subject: [PATCH 028/133] [SPARK-20758][SQL] Add Constant propagation optimization ## What changes were proposed in this pull request? See class doc of `ConstantPropagation` for the approach used. ## How was this patch tested? - Added unit tests Author: Tejas Patil Closes #17993 from tejasapatil/SPARK-20758_const_propagation. --- .../sql/catalyst/optimizer/Optimizer.scala | 1 + .../sql/catalyst/optimizer/expressions.scala | 56 ++++++ .../optimizer/ConstantPropagationSuite.scala | 167 ++++++++++++++++++ .../datasources/FileSourceStrategySuite.scala | 18 +- 4 files changed, 235 insertions(+), 7 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantPropagationSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index ae2f6bfa94ae7..d16689a34298a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -92,6 +92,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: SQLConf) CombineUnions, // Constant folding and strength reduction NullPropagation(conf), + ConstantPropagation, FoldablePropagation, OptimizeIn(conf), ConstantFolding, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 8931eb2c8f3b1..51f749a8bf857 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -54,6 +54,62 @@ object ConstantFolding extends Rule[LogicalPlan] { } } +/** + * Substitutes [[Attribute Attributes]] which can be statically evaluated with their corresponding + * value in conjunctive [[Expression Expressions]] + * eg. + * {{{ + * SELECT * FROM table WHERE i = 5 AND j = i + 3 + * ==> SELECT * FROM table WHERE i = 5 AND j = 8 + * }}} + * + * Approach used: + * - Start from AND operator as the root + * - Get all the children conjunctive predicates which are EqualTo / EqualNullSafe such that they + * don't have a `NOT` or `OR` operator in them + * - Populate a mapping of attribute => constant value by looking at all the equals predicates + * - Using this mapping, replace occurrence of the attributes with the corresponding constant values + * in the AND node. + */ +object ConstantPropagation extends Rule[LogicalPlan] with PredicateHelper { + private def containsNonConjunctionPredicates(expression: Expression): Boolean = expression.find { + case _: Not | _: Or => true + case _ => false + }.isDefined + + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case f: Filter => f transformExpressionsUp { + case and: And => + val conjunctivePredicates = + splitConjunctivePredicates(and) + .filter(expr => expr.isInstanceOf[EqualTo] || expr.isInstanceOf[EqualNullSafe]) + .filterNot(expr => containsNonConjunctionPredicates(expr)) + + val equalityPredicates = conjunctivePredicates.collect { + case e @ EqualTo(left: AttributeReference, right: Literal) => ((left, right), e) + case e @ EqualTo(left: Literal, right: AttributeReference) => ((right, left), e) + case e @ EqualNullSafe(left: AttributeReference, right: Literal) => ((left, right), e) + case e @ EqualNullSafe(left: Literal, right: AttributeReference) => ((right, left), e) + } + + val constantsMap = AttributeMap(equalityPredicates.map(_._1)) + val predicates = equalityPredicates.map(_._2).toSet + + def replaceConstants(expression: Expression) = expression transform { + case a: AttributeReference => + constantsMap.get(a) match { + case Some(literal) => literal + case None => a + } + } + + and transform { + case e @ EqualTo(_, _) if !predicates.contains(e) => replaceConstants(e) + case e @ EqualNullSafe(_, _) if !predicates.contains(e) => replaceConstants(e) + } + } + } +} /** * Reorder associative integral-type operators and fold all constants into one. diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantPropagationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantPropagationSuite.scala new file mode 100644 index 0000000000000..81d2f3667e2d0 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantPropagationSuite.scala @@ -0,0 +1,167 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.rules.RuleExecutor + +/** + * Unit tests for constant propagation in expressions. + */ +class ConstantPropagationSuite extends PlanTest { + + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = + Batch("AnalysisNodes", Once, + EliminateSubqueryAliases) :: + Batch("ConstantPropagation", FixedPoint(10), + ColumnPruning, + ConstantPropagation, + ConstantFolding, + BooleanSimplification) :: Nil + } + + val testRelation = LocalRelation('a.int, 'b.int, 'c.int) + + private val columnA = 'a.int + private val columnB = 'b.int + private val columnC = 'c.int + + test("basic test") { + val query = testRelation + .select(columnA) + .where(columnA === Add(columnB, Literal(1)) && columnB === Literal(10)) + + val correctAnswer = + testRelation + .select(columnA) + .where(columnA === Literal(11) && columnB === Literal(10)).analyze + + comparePlans(Optimize.execute(query.analyze), correctAnswer) + } + + test("with combination of AND and OR predicates") { + val query = testRelation + .select(columnA) + .where( + columnA === Add(columnB, Literal(1)) && + columnB === Literal(10) && + (columnA === Add(columnC, Literal(3)) || columnB === columnC)) + .analyze + + val correctAnswer = + testRelation + .select(columnA) + .where( + columnA === Literal(11) && + columnB === Literal(10) && + (Literal(11) === Add(columnC, Literal(3)) || Literal(10) === columnC)) + .analyze + + comparePlans(Optimize.execute(query), correctAnswer) + } + + test("equality predicates outside a `NOT` can be propagated within a `NOT`") { + val query = testRelation + .select(columnA) + .where(Not(columnA === Add(columnB, Literal(1))) && columnB === Literal(10)) + .analyze + + val correctAnswer = + testRelation + .select(columnA) + .where(Not(columnA === Literal(11)) && columnB === Literal(10)) + .analyze + + comparePlans(Optimize.execute(query), correctAnswer) + } + + test("equality predicates inside a `NOT` should not be picked for propagation") { + val query = testRelation + .select(columnA) + .where(Not(columnB === Literal(10)) && columnA === Add(columnB, Literal(1))) + .analyze + + comparePlans(Optimize.execute(query), query) + } + + test("equality predicates outside a `OR` can be propagated within a `OR`") { + val query = testRelation + .select(columnA) + .where( + columnA === Literal(2) && + (columnA === Add(columnB, Literal(3)) || columnB === Literal(9))) + .analyze + + val correctAnswer = testRelation + .select(columnA) + .where( + columnA === Literal(2) && + (Literal(2) === Add(columnB, Literal(3)) || columnB === Literal(9))) + .analyze + + comparePlans(Optimize.execute(query), correctAnswer) + } + + test("equality predicates inside a `OR` should not be picked for propagation") { + val query = testRelation + .select(columnA) + .where( + columnA === Add(columnB, Literal(2)) && + (columnA === Add(columnB, Literal(3)) || columnB === Literal(9))) + .analyze + + comparePlans(Optimize.execute(query), query) + } + + test("equality operator not immediate child of root `AND` should not be used for propagation") { + val query = testRelation + .select(columnA) + .where( + columnA === Literal(0) && + ((columnB === columnA) === (columnB === Literal(0)))) + .analyze + + val correctAnswer = testRelation + .select(columnA) + .where( + columnA === Literal(0) && + ((columnB === Literal(0)) === (columnB === Literal(0)))) + .analyze + + comparePlans(Optimize.execute(query), correctAnswer) + } + + test("conflicting equality predicates") { + val query = testRelation + .select(columnA) + .where( + columnA === Literal(1) && columnA === Literal(2) && columnB === Add(columnA, Literal(3))) + + val correctAnswer = testRelation + .select(columnA) + .where(columnA === Literal(1) && columnA === Literal(2) && columnB === Literal(5)) + + comparePlans(Optimize.execute(query.analyze), correctAnswer) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala index fa3c69612704d..9a2dcafb5e4b3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala @@ -190,7 +190,7 @@ class FileSourceStrategySuite extends QueryTest with SharedSQLContext with Predi checkDataFilters(Set.empty) // Only one file should be read. - checkScan(table.where("p1 = 1 AND c1 = 1 AND (p1 + c1) = 1")) { partitions => + checkScan(table.where("p1 = 1 AND c1 = 1 AND (p1 + c1) = 2")) { partitions => assert(partitions.size == 1, "when checking partitions") assert(partitions.head.files.size == 1, "when checking files in partition 1") assert(partitions.head.files.head.partitionValues.getInt(0) == 1, @@ -217,7 +217,7 @@ class FileSourceStrategySuite extends QueryTest with SharedSQLContext with Predi checkDataFilters(Set.empty) // Only one file should be read. - checkScan(table.where("P1 = 1 AND C1 = 1 AND (P1 + C1) = 1")) { partitions => + checkScan(table.where("P1 = 1 AND C1 = 1 AND (P1 + C1) = 2")) { partitions => assert(partitions.size == 1, "when checking partitions") assert(partitions.head.files.size == 1, "when checking files in partition 1") assert(partitions.head.files.head.partitionValues.getInt(0) == 1, @@ -235,13 +235,17 @@ class FileSourceStrategySuite extends QueryTest with SharedSQLContext with Predi "p1=1/file1" -> 10, "p1=2/file2" -> 10)) - val df = table.where("p1 = 1 AND (p1 + c1) = 2 AND c1 = 1") + val df1 = table.where("p1 = 1 AND (p1 + c1) = 2 AND c1 = 1") // Filter on data only are advisory so we have to reevaluate. - assert(getPhysicalFilters(df) contains resolve(df, "c1 = 1")) - // Need to evalaute filters that are not pushed down. - assert(getPhysicalFilters(df) contains resolve(df, "(p1 + c1) = 2")) + assert(getPhysicalFilters(df1) contains resolve(df1, "c1 = 1")) // Don't reevaluate partition only filters. - assert(!(getPhysicalFilters(df) contains resolve(df, "p1 = 1"))) + assert(!(getPhysicalFilters(df1) contains resolve(df1, "p1 = 1"))) + + val df2 = table.where("(p1 + c2) = 2 AND c1 = 1") + // Filter on data only are advisory so we have to reevaluate. + assert(getPhysicalFilters(df2) contains resolve(df2, "c1 = 1")) + // Need to evalaute filters that are not pushed down. + assert(getPhysicalFilters(df2) contains resolve(df2, "(p1 + c2) = 2")) } test("bucketed table") { From ef9fd920c3241e05915c231bef50e3e51a655ce6 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Mon, 29 May 2017 11:47:31 -0700 Subject: [PATCH 029/133] [SPARK-20750][SQL] Built-in SQL Function Support - REPLACE ## What changes were proposed in this pull request? This PR adds built-in SQL function `(REPLACE(, [, ])` `REPLACE()` return that string that is replaced all occurrences with given string. ## How was this patch tested? added new test suites Author: Kazuaki Ishizaki Closes #18047 from kiszk/SPARK-20750. --- .../apache/spark/unsafe/types/UTF8String.java | 9 ++++ .../catalyst/analysis/FunctionRegistry.scala | 1 + .../expressions/stringExpressions.scala | 42 +++++++++++++++++++ .../expressions/StringExpressionsSuite.scala | 20 +++++++++ .../sql-tests/inputs/string-functions.sql | 4 ++ .../results/string-functions.sql.out | 18 +++++++- 6 files changed, 93 insertions(+), 1 deletion(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index 5437e998c085f..40b9fc9534f44 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -835,6 +835,15 @@ public UTF8String[] split(UTF8String pattern, int limit) { return res; } + public UTF8String replace(UTF8String search, UTF8String replace) { + if (EMPTY_UTF8.equals(search)) { + return this; + } + String replaced = toString().replace( + search.toString(), replace.toString()); + return fromString(replaced); + } + // TODO: Need to use `Code Point` here instead of Char in case the character longer than 2 bytes public UTF8String translate(Map dict) { String srcStr = this.toString(); diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index a4c7f7a8de223..549fa0dc8bd20 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -304,6 +304,7 @@ object FunctionRegistry { expression[RegExpExtract]("regexp_extract"), expression[RegExpReplace]("regexp_replace"), expression[StringRepeat]("repeat"), + expression[StringReplace]("replace"), expression[StringReverse]("reverse"), expression[RLike]("rlike"), expression[StringRPad]("rpad"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index aba2f5f81f831..1dbe098f96ef5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -340,6 +340,48 @@ case class EndsWith(left: Expression, right: Expression) extends StringPredicate } } +/** + * Replace all occurrences with string. + */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(str, search[, replace]) - Replaces all occurrences of `search` with `replace`.", + extended = """ + Arguments: + str - a string expression + search - a string expression. If `search` is not found in `str`, `str` is returned unchanged. + replace - a string expression. If `replace` is not specified or is an empty string, nothing replaces + the string that is removed from `str`. + + Examples: + > SELECT _FUNC_('ABCabc', 'abc', 'DEF'); + ABCDEF + """) +// scalastyle:on line.size.limit +case class StringReplace(srcExpr: Expression, searchExpr: Expression, replaceExpr: Expression) + extends TernaryExpression with ImplicitCastInputTypes { + + def this(srcExpr: Expression, searchExpr: Expression) = { + this(srcExpr, searchExpr, Literal("")) + } + + override def nullSafeEval(srcEval: Any, searchEval: Any, replaceEval: Any): Any = { + srcEval.asInstanceOf[UTF8String].replace( + searchEval.asInstanceOf[UTF8String], replaceEval.asInstanceOf[UTF8String]) + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen(ctx, ev, (src, search, replace) => { + s"""${ev.value} = $src.replace($search, $replace);""" + }) + } + + override def dataType: DataType = StringType + override def inputTypes: Seq[DataType] = Seq(StringType, StringType, StringType) + override def children: Seq[Expression] = srcExpr :: searchExpr :: replaceExpr :: Nil + override def prettyName: String = "replace" +} + object StringTranslate { def buildDict(matchingString: UTF8String, replaceString: UTF8String) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala index 9ae438d568a90..4bdb43bfed8b5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala @@ -372,6 +372,26 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(SoundEx(Literal("!!")), "!!") } + test("replace") { + checkEvaluation( + StringReplace(Literal("replace"), Literal("pl"), Literal("123")), "re123ace") + checkEvaluation(StringReplace(Literal("replace"), Literal("pl"), Literal("")), "reace") + checkEvaluation(StringReplace(Literal("replace"), Literal(""), Literal("123")), "replace") + checkEvaluation(StringReplace(Literal.create(null, StringType), + Literal("pl"), Literal("123")), null) + checkEvaluation(StringReplace(Literal("replace"), + Literal.create(null, StringType), Literal("123")), null) + checkEvaluation(StringReplace(Literal("replace"), + Literal("pl"), Literal.create(null, StringType)), null) + // test for multiple replace + checkEvaluation(StringReplace(Literal("abcabc"), Literal("b"), Literal("12")), "a12ca12c") + checkEvaluation(StringReplace(Literal("abcdabcd"), Literal("bc"), Literal("")), "adad") + // scalastyle:off + // non ascii characters are not allowed in the source code, so we disable the scalastyle. + checkEvaluation(StringReplace(Literal("花花世界"), Literal("花世"), Literal("ab")), "花ab界") + // scalastyle:on + } + test("translate") { checkEvaluation( StringTranslate(Literal("translate"), Literal("rnlt"), Literal("123")), "1a2s3ae") diff --git a/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql index da4b39cf7ddf3..e6dcea4972c18 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql @@ -8,3 +8,7 @@ select 'a' || 'b' || 'c'; -- Check if catalyst combine nested `Concat`s EXPLAIN EXTENDED SELECT (col1 || col2 || col3 || col4) col FROM (SELECT id col1, id col2, id col3, id col4 FROM range(10)) t; + +-- replace function +select replace('abc', 'b', '123'); +select replace('abc', 'b'); diff --git a/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out index 37727cadb628f..abf0cc44d6e42 100644 --- a/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 4 +-- Number of queries: 6 -- !query 0 @@ -54,3 +54,19 @@ Project [concat(cast(id#xL as string), cast(id#xL as string), cast(id#xL as stri == Physical Plan == *Project [concat(cast(id#xL as string), cast(id#xL as string), cast(id#xL as string), cast(id#xL as string)) AS col#x] +- *Range (0, 10, step=1, splits=2) + + +-- !query 4 +select replace('abc', 'b', '123') +-- !query 4 schema +struct +-- !query 4 output +a123c + + +-- !query 5 +select replace('abc', 'b') +-- !query 5 schema +struct +-- !query 5 output +ac From c9749068ecf8e0acabdfeeceeedff0f1f73293b7 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Mon, 29 May 2017 12:17:14 -0700 Subject: [PATCH 030/133] [SPARK-20907][TEST] Use testQuietly for test suites that generate long log output ## What changes were proposed in this pull request? Supress console output by using `testQuietly` in test suites ## How was this patch tested? Tested by `"SPARK-19372: Filter can be executed w/o generated code due to JVM code size limit"` in `DataFrameSuite` Author: Kazuaki Ishizaki Closes #18135 from kiszk/SPARK-20907. --- .../src/test/scala/org/apache/spark/sql/DataFrameSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 9f691cb10f139..9ea9951c24ef1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -1845,7 +1845,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { .count } - test("SPARK-19372: Filter can be executed w/o generated code due to JVM code size limit") { + testQuietly("SPARK-19372: Filter can be executed w/o generated code due to JVM code size limit") { val N = 400 val rows = Seq(Row.fromSeq(Seq.fill(N)("string"))) val schema = StructType(Seq.tabulate(N)(i => StructField(s"_c$i", StringType))) From 1c7db00c74ec6a91c7eefbdba85cbf41fbe8634a Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Mon, 29 May 2017 16:10:22 -0700 Subject: [PATCH 031/133] [SPARK-8184][SQL] Add additional function description for weekofyear ## What changes were proposed in this pull request? Add additional function description for weekofyear. ## How was this patch tested? manual tests ![weekofyear](https://cloud.githubusercontent.com/assets/5399861/26525752/08a1c278-4394-11e7-8988-7cbf82c3a999.gif) Author: Yuming Wang Closes #18132 from wangyum/SPARK-8184. --- .../spark/sql/catalyst/expressions/datetimeExpressions.scala | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index 43ca2cff58825..40983006c470c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -402,13 +402,15 @@ case class DayOfMonth(child: Expression) extends UnaryExpression with ImplicitCa } } +// scalastyle:off line.size.limit @ExpressionDescription( - usage = "_FUNC_(date) - Returns the week of the year of the given date.", + usage = "_FUNC_(date) - Returns the week of the year of the given date. A week is considered to start on a Monday and week 1 is the first week with >3 days.", extended = """ Examples: > SELECT _FUNC_('2008-02-20'); 8 """) +// scalastyle:on line.size.limit case class WeekOfYear(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { override def inputTypes: Seq[AbstractDataType] = Seq(DateType) From 96a4d1d0827fc3fba83f174510b061684f0d00f7 Mon Sep 17 00:00:00 2001 From: Prashant Sharma Date: Mon, 29 May 2017 18:12:01 -0700 Subject: [PATCH 032/133] [SPARK-19968][SS] Use a cached instance of `KafkaProducer` instead of creating one every batch. ## What changes were proposed in this pull request? In summary, cost of recreating a KafkaProducer for writing every batch is high as it starts a lot threads and make connections and then closes them. A KafkaProducer instance is promised to be thread safe in Kafka docs. Reuse of KafkaProducer instance while writing via multiple threads is encouraged. Furthermore, I have performance improvement of 10x in latency, with this patch. ### These are times that addBatch took in ms. Without applying this patch ![with-out_patch](https://cloud.githubusercontent.com/assets/992952/23994612/a9de4a42-0a6b-11e7-9d5b-7ae18775bee4.png) ### These are times that addBatch took in ms. After applying this patch ![with_patch](https://cloud.githubusercontent.com/assets/992952/23994616/ad8c11ec-0a6b-11e7-8634-2266ebb5033f.png) ## How was this patch tested? Running distributed benchmarks comparing runs with this patch and without it. Added relevant unit tests. Author: Prashant Sharma Closes #17308 from ScrapCodes/cached-kafka-producer. --- .../sql/kafka010/CachedKafkaProducer.scala | 112 ++++++++++++++++++ .../spark/sql/kafka010/KafkaSource.scala | 14 +-- .../spark/sql/kafka010/KafkaWriteTask.scala | 17 ++- .../spark/sql/kafka010/KafkaWriter.scala | 3 +- .../kafka010/CachedKafkaProducerSuite.scala | 78 ++++++++++++ 5 files changed, 206 insertions(+), 18 deletions(-) create mode 100644 external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaProducer.scala create mode 100644 external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/CachedKafkaProducerSuite.scala diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaProducer.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaProducer.scala new file mode 100644 index 0000000000000..571140b0afbc7 --- /dev/null +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaProducer.scala @@ -0,0 +1,112 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.kafka010 + +import java.{util => ju} +import java.util.concurrent.{ConcurrentMap, ExecutionException, TimeUnit} + +import com.google.common.cache._ +import com.google.common.util.concurrent.{ExecutionError, UncheckedExecutionException} +import org.apache.kafka.clients.producer.KafkaProducer +import scala.collection.JavaConverters._ +import scala.util.control.NonFatal + +import org.apache.spark.SparkEnv +import org.apache.spark.internal.Logging + +private[kafka010] object CachedKafkaProducer extends Logging { + + private type Producer = KafkaProducer[Array[Byte], Array[Byte]] + + private lazy val cacheExpireTimeout: Long = + SparkEnv.get.conf.getTimeAsMs("spark.kafka.producer.cache.timeout", "10m") + + private val cacheLoader = new CacheLoader[Seq[(String, Object)], Producer] { + override def load(config: Seq[(String, Object)]): Producer = { + val configMap = config.map(x => x._1 -> x._2).toMap.asJava + createKafkaProducer(configMap) + } + } + + private val removalListener = new RemovalListener[Seq[(String, Object)], Producer]() { + override def onRemoval( + notification: RemovalNotification[Seq[(String, Object)], Producer]): Unit = { + val paramsSeq: Seq[(String, Object)] = notification.getKey + val producer: Producer = notification.getValue + logDebug( + s"Evicting kafka producer $producer params: $paramsSeq, due to ${notification.getCause}") + close(paramsSeq, producer) + } + } + + private lazy val guavaCache: LoadingCache[Seq[(String, Object)], Producer] = + CacheBuilder.newBuilder().expireAfterAccess(cacheExpireTimeout, TimeUnit.MILLISECONDS) + .removalListener(removalListener) + .build[Seq[(String, Object)], Producer](cacheLoader) + + private def createKafkaProducer(producerConfiguration: ju.Map[String, Object]): Producer = { + val kafkaProducer: Producer = new Producer(producerConfiguration) + logDebug(s"Created a new instance of KafkaProducer for $producerConfiguration.") + kafkaProducer + } + + /** + * Get a cached KafkaProducer for a given configuration. If matching KafkaProducer doesn't + * exist, a new KafkaProducer will be created. KafkaProducer is thread safe, it is best to keep + * one instance per specified kafkaParams. + */ + private[kafka010] def getOrCreate(kafkaParams: ju.Map[String, Object]): Producer = { + val paramsSeq: Seq[(String, Object)] = paramsToSeq(kafkaParams) + try { + guavaCache.get(paramsSeq) + } catch { + case e @ (_: ExecutionException | _: UncheckedExecutionException | _: ExecutionError) + if e.getCause != null => + throw e.getCause + } + } + + private def paramsToSeq(kafkaParams: ju.Map[String, Object]): Seq[(String, Object)] = { + val paramsSeq: Seq[(String, Object)] = kafkaParams.asScala.toSeq.sortBy(x => x._1) + paramsSeq + } + + /** For explicitly closing kafka producer */ + private[kafka010] def close(kafkaParams: ju.Map[String, Object]): Unit = { + val paramsSeq = paramsToSeq(kafkaParams) + guavaCache.invalidate(paramsSeq) + } + + /** Auto close on cache evict */ + private def close(paramsSeq: Seq[(String, Object)], producer: Producer): Unit = { + try { + logInfo(s"Closing the KafkaProducer with params: ${paramsSeq.mkString("\n")}.") + producer.close() + } catch { + case NonFatal(e) => logWarning("Error while closing kafka producer.", e) + } + } + + private def clear(): Unit = { + logInfo("Cleaning up guava cache.") + guavaCache.invalidateAll() + } + + // Intended for testing purpose only. + private def getAsMap: ConcurrentMap[Seq[(String, Object)], Producer] = guavaCache.asMap() +} diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala index 1fb0a338299b7..7ac183776e20d 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala @@ -70,13 +70,13 @@ import org.apache.spark.unsafe.types.UTF8String * and not use wrong broker addresses. */ private[kafka010] class KafkaSource( - sqlContext: SQLContext, - kafkaReader: KafkaOffsetReader, - executorKafkaParams: ju.Map[String, Object], - sourceOptions: Map[String, String], - metadataPath: String, - startingOffsets: KafkaOffsetRangeLimit, - failOnDataLoss: Boolean) + sqlContext: SQLContext, + kafkaReader: KafkaOffsetReader, + executorKafkaParams: ju.Map[String, Object], + sourceOptions: Map[String, String], + metadataPath: String, + startingOffsets: KafkaOffsetRangeLimit, + failOnDataLoss: Boolean) extends Source with Logging { private val sc = sqlContext.sparkContext diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala index 6e160cbe2db52..6fd333e2f43ba 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala @@ -19,8 +19,7 @@ package org.apache.spark.sql.kafka010 import java.{util => ju} -import org.apache.kafka.clients.producer.{KafkaProducer, _} -import org.apache.kafka.common.serialization.ByteArraySerializer +import org.apache.kafka.clients.producer.{Callback, KafkaProducer, ProducerRecord, RecordMetadata} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Literal, UnsafeProjection} @@ -44,7 +43,7 @@ private[kafka010] class KafkaWriteTask( * Writes key value data out to topics. */ def execute(iterator: Iterator[InternalRow]): Unit = { - producer = new KafkaProducer[Array[Byte], Array[Byte]](producerConfiguration) + producer = CachedKafkaProducer.getOrCreate(producerConfiguration) while (iterator.hasNext && failedWrite == null) { val currentRow = iterator.next() val projectedRow = projection(currentRow) @@ -68,10 +67,10 @@ private[kafka010] class KafkaWriteTask( } def close(): Unit = { + checkForErrors() if (producer != null) { - checkForErrors - producer.close() - checkForErrors + producer.flush() + checkForErrors() producer = null } } @@ -88,7 +87,7 @@ private[kafka010] class KafkaWriteTask( case t => throw new IllegalStateException(s"${KafkaWriter.TOPIC_ATTRIBUTE_NAME} " + s"attribute unsupported type $t. ${KafkaWriter.TOPIC_ATTRIBUTE_NAME} " + - s"must be a ${StringType}") + "must be a StringType") } val keyExpression = inputSchema.find(_.name == KafkaWriter.KEY_ATTRIBUTE_NAME) .getOrElse(Literal(null, BinaryType)) @@ -100,7 +99,7 @@ private[kafka010] class KafkaWriteTask( } val valueExpression = inputSchema .find(_.name == KafkaWriter.VALUE_ATTRIBUTE_NAME).getOrElse( - throw new IllegalStateException(s"Required attribute " + + throw new IllegalStateException("Required attribute " + s"'${KafkaWriter.VALUE_ATTRIBUTE_NAME}' not found") ) valueExpression.dataType match { @@ -114,7 +113,7 @@ private[kafka010] class KafkaWriteTask( Cast(valueExpression, BinaryType)), inputSchema) } - private def checkForErrors: Unit = { + private def checkForErrors(): Unit = { if (failedWrite != null) { throw failedWrite } diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala index 61936e32fd837..0ed9d4e84d54d 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala @@ -21,7 +21,6 @@ import java.{util => ju} import org.apache.spark.internal.Logging import org.apache.spark.sql.{AnalysisException, SparkSession} -import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.execution.{QueryExecution, SQLExecution} import org.apache.spark.sql.types.{BinaryType, StringType} @@ -49,7 +48,7 @@ private[kafka010] object KafkaWriter extends Logging { topic: Option[String] = None): Unit = { val schema = queryExecution.analyzed.output schema.find(_.name == TOPIC_ATTRIBUTE_NAME).getOrElse( - if (topic == None) { + if (topic.isEmpty) { throw new AnalysisException(s"topic option required when no " + s"'$TOPIC_ATTRIBUTE_NAME' attribute is present. Use the " + s"${KafkaSourceProvider.TOPIC_OPTION_KEY} option for setting a topic.") diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/CachedKafkaProducerSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/CachedKafkaProducerSuite.scala new file mode 100644 index 0000000000000..789bffa9da126 --- /dev/null +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/CachedKafkaProducerSuite.scala @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.kafka010 + +import java.{util => ju} +import java.util.concurrent.ConcurrentMap + +import org.apache.kafka.clients.producer.KafkaProducer +import org.apache.kafka.common.serialization.ByteArraySerializer +import org.scalatest.PrivateMethodTester + +import org.apache.spark.sql.test.SharedSQLContext + +class CachedKafkaProducerSuite extends SharedSQLContext with PrivateMethodTester { + + type KP = KafkaProducer[Array[Byte], Array[Byte]] + + protected override def beforeEach(): Unit = { + super.beforeEach() + val clear = PrivateMethod[Unit]('clear) + CachedKafkaProducer.invokePrivate(clear()) + } + + test("Should return the cached instance on calling getOrCreate with same params.") { + val kafkaParams = new ju.HashMap[String, Object]() + kafkaParams.put("acks", "0") + // Here only host should be resolvable, it does not need a running instance of kafka server. + kafkaParams.put("bootstrap.servers", "127.0.0.1:9022") + kafkaParams.put("key.serializer", classOf[ByteArraySerializer].getName) + kafkaParams.put("value.serializer", classOf[ByteArraySerializer].getName) + val producer = CachedKafkaProducer.getOrCreate(kafkaParams) + val producer2 = CachedKafkaProducer.getOrCreate(kafkaParams) + assert(producer == producer2) + + val cacheMap = PrivateMethod[ConcurrentMap[Seq[(String, Object)], KP]]('getAsMap) + val map = CachedKafkaProducer.invokePrivate(cacheMap()) + assert(map.size == 1) + } + + test("Should close the correct kafka producer for the given kafkaPrams.") { + val kafkaParams = new ju.HashMap[String, Object]() + kafkaParams.put("acks", "0") + kafkaParams.put("bootstrap.servers", "127.0.0.1:9022") + kafkaParams.put("key.serializer", classOf[ByteArraySerializer].getName) + kafkaParams.put("value.serializer", classOf[ByteArraySerializer].getName) + val producer: KP = CachedKafkaProducer.getOrCreate(kafkaParams) + kafkaParams.put("acks", "1") + val producer2: KP = CachedKafkaProducer.getOrCreate(kafkaParams) + // With updated conf, a new producer instance should be created. + assert(producer != producer2) + + val cacheMap = PrivateMethod[ConcurrentMap[Seq[(String, Object)], KP]]('getAsMap) + val map = CachedKafkaProducer.invokePrivate(cacheMap()) + assert(map.size == 2) + + CachedKafkaProducer.close(kafkaParams) + val map2 = CachedKafkaProducer.invokePrivate(cacheMap()) + assert(map2.size == 1) + import scala.collection.JavaConverters._ + val (seq: Seq[(String, Object)], _producer: KP) = map2.asScala.toArray.apply(0) + assert(_producer == producer) + } +} From d797ed0ef10f3e2e4cade3fc47071839ae8c5fd4 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Tue, 30 May 2017 15:40:50 +0900 Subject: [PATCH 033/133] [SPARK-20909][SQL] Add build-int SQL function - DAYOFWEEK ## What changes were proposed in this pull request? Add build-int SQL function - DAYOFWEEK ## How was this patch tested? unit tests Author: Yuming Wang Closes #18134 from wangyum/SPARK-20909. --- .../catalyst/analysis/FunctionRegistry.scala | 1 + .../expressions/datetimeExpressions.scala | 38 +++++++++++++++++++ .../expressions/DateExpressionsSuite.scala | 14 +++++++ .../resources/sql-tests/inputs/datetime.sql | 2 + .../sql-tests/results/datetime.sql.out | 10 ++++- 5 files changed, 64 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 549fa0dc8bd20..8081036bed8a6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -360,6 +360,7 @@ object FunctionRegistry { expression[ToUTCTimestamp]("to_utc_timestamp"), expression[TruncDate]("trunc"), expression[UnixTimestamp]("unix_timestamp"), + expression[DayOfWeek]("dayofweek"), expression[WeekOfYear]("weekofyear"), expression[Year]("year"), expression[TimeWindow]("window"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index 40983006c470c..505ed945cd68e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -402,6 +402,44 @@ case class DayOfMonth(child: Expression) extends UnaryExpression with ImplicitCa } } +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(date) - Returns the day of the week for date/timestamp (1 = Sunday, 2 = Monday, ..., 7 = Saturday).", + extended = """ + Examples: + > SELECT _FUNC_('2009-07-30'); + 5 + """) +// scalastyle:on line.size.limit +case class DayOfWeek(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { + + override def inputTypes: Seq[AbstractDataType] = Seq(DateType) + + override def dataType: DataType = IntegerType + + @transient private lazy val c = { + Calendar.getInstance(DateTimeUtils.getTimeZone("UTC")) + } + + override protected def nullSafeEval(date: Any): Any = { + c.setTimeInMillis(date.asInstanceOf[Int] * 1000L * 3600L * 24L) + c.get(Calendar.DAY_OF_WEEK) + } + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen(ctx, ev, time => { + val cal = classOf[Calendar].getName + val c = ctx.freshName("cal") + val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") + ctx.addMutableState(cal, c, s"""$c = $cal.getInstance($dtu.getTimeZone("UTC"));""") + s""" + $c.setTimeInMillis($time * 1000L * 3600L * 24L); + ${ev.value} = $c.get($cal.DAY_OF_WEEK); + """ + }) + } +} + // scalastyle:off line.size.limit @ExpressionDescription( usage = "_FUNC_(date) - Returns the week of the year of the given date. A week is considered to start on a Monday and week 1 is the first week with >3 days.", diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala index 4ce68538c87a1..89d99f9678cda 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala @@ -196,6 +196,20 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } } + test("DayOfWeek") { + checkEvaluation(DayOfWeek(Literal.create(null, DateType)), null) + checkEvaluation(DayOfWeek(Literal(d)), Calendar.WEDNESDAY) + checkEvaluation(DayOfWeek(Cast(Literal(sdfDate.format(d)), DateType, gmtId)), + Calendar.WEDNESDAY) + checkEvaluation(DayOfWeek(Cast(Literal(ts), DateType, gmtId)), Calendar.FRIDAY) + checkEvaluation(DayOfWeek(Cast(Literal("2011-05-06"), DateType, gmtId)), Calendar.FRIDAY) + checkEvaluation(DayOfWeek(Literal(new Date(sdf.parse("2017-05-27 13:10:15").getTime))), + Calendar.SATURDAY) + checkEvaluation(DayOfWeek(Literal(new Date(sdf.parse("1582-10-15 13:10:15").getTime))), + Calendar.FRIDAY) + checkConsistencyBetweenInterpretedAndCodegen(DayOfWeek, DateType) + } + test("WeekOfYear") { checkEvaluation(WeekOfYear(Literal.create(null, DateType)), null) checkEvaluation(WeekOfYear(Literal(d)), 15) diff --git a/sql/core/src/test/resources/sql-tests/inputs/datetime.sql b/sql/core/src/test/resources/sql-tests/inputs/datetime.sql index e957f693a983f..616b6caee3f20 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/datetime.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/datetime.sql @@ -6,3 +6,5 @@ select current_date = current_date(), current_timestamp = current_timestamp(); select to_date(null), to_date('2016-12-31'), to_date('2016-12-31', 'yyyy-MM-dd'); select to_timestamp(null), to_timestamp('2016-12-31 00:12:00'), to_timestamp('2016-12-31', 'yyyy-MM-dd'); + +select dayofweek('2007-02-03'), dayofweek('2009-07-30'), dayofweek('2017-05-27'), dayofweek(null), dayofweek('1582-10-15 13:10:15'); diff --git a/sql/core/src/test/resources/sql-tests/results/datetime.sql.out b/sql/core/src/test/resources/sql-tests/results/datetime.sql.out index 13e1e48b038ad..a28b91c77324b 100644 --- a/sql/core/src/test/resources/sql-tests/results/datetime.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/datetime.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 3 +-- Number of queries: 4 -- !query 0 @@ -24,3 +24,11 @@ select to_timestamp(null), to_timestamp('2016-12-31 00:12:00'), to_timestamp('20 struct -- !query 2 output NULL 2016-12-31 00:12:00 2016-12-31 00:00:00 + + +-- !query 3 +select dayofweek('2007-02-03'), dayofweek('2009-07-30'), dayofweek('2017-05-27'), dayofweek(null), dayofweek('1582-10-15 13:10:15') +-- !query 3 schema +struct +-- !query 3 output +7 5 7 NULL 6 From 80fb24b85ddcea768c5261e82449d673993e39af Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Tue, 30 May 2017 12:15:54 +0100 Subject: [PATCH 034/133] [MINOR] Fix some indent issues. ## What changes were proposed in this pull request? Fix some indent issues. ## How was this patch tested? existing tests. Author: Yuming Wang Closes #18133 from wangyum/IndentIssues. --- .../org/apache/spark/sql/catalyst/expressions/hash.scala | 2 +- .../spark/sql/catalyst/expressions/nullExpressions.scala | 6 +++--- .../sql/catalyst/expressions/regexpExpressions.scala | 4 ++-- .../sql/catalyst/expressions/stringExpressions.scala | 8 ++++---- 4 files changed, 10 insertions(+), 10 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala index 2a5963d37f5e8..ffd0e64d86cff 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala @@ -524,7 +524,7 @@ abstract class InterpretedHashFunction { extended = """ Examples: > SELECT _FUNC_('Spark', array(123), 2); - -1321691492 + -1321691492 """) case class Murmur3Hash(children: Seq[Expression], seed: Int) extends HashExpression[Int] { def this(arguments: Seq[Expression]) = this(arguments, 42) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala index 92036b727dbbd..0866b8d791e01 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala @@ -116,9 +116,9 @@ case class IfNull(left: Expression, right: Expression, child: Expression) @ExpressionDescription( usage = "_FUNC_(expr1, expr2) - Returns null if `expr1` equals to `expr2`, or `expr1` otherwise.", extended = """ - Examples: - > SELECT _FUNC_(2, 2); - NULL + Examples: + > SELECT _FUNC_(2, 2); + NULL """) case class NullIf(left: Expression, right: Expression, child: Expression) extends RuntimeReplaceable { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala index aa5a1b5448c6d..5418acedbef21 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala @@ -99,7 +99,7 @@ abstract class StringRegexExpression extends BinaryExpression See also: Use RLIKE to match with standard regular expressions. -""") + """) case class Like(left: Expression, right: Expression) extends StringRegexExpression { override def escape(v: String): String = StringUtils.escapeLikeRegex(v) @@ -175,7 +175,7 @@ case class Like(left: Expression, right: Expression) extends StringRegexExpressi See also: Use LIKE to match with simple string pattern. -""") + """) case class RLike(left: Expression, right: Expression) extends StringRegexExpression { override def escape(v: String): String = v diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 1dbe098f96ef5..cc4d465c5d701 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -770,10 +770,10 @@ case class StringLPad(str: Expression, len: Expression, pad: Expression) """, extended = """ Examples: - > SELECT _FUNC_('hi', 5, '??'); - hi??? - > SELECT _FUNC_('hi', 1, '??'); - h + > SELECT _FUNC_('hi', 5, '??'); + hi??? + > SELECT _FUNC_('hi', 1, '??'); + h """) case class StringRPad(str: Expression, len: Expression, pad: Expression) extends TernaryExpression with ImplicitCastInputTypes { From 35b644bd03da74ee9cafd2d1626e4694d473236d Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 30 May 2017 06:28:43 -0700 Subject: [PATCH 035/133] [SPARK-20916][SQL] Improve error message for unaliased subqueries in FROM clause ## What changes were proposed in this pull request? We changed the parser to reject unaliased subqueries in the FROM clause in SPARK-20690. However, the error message that we now give isn't very helpful: scala> sql("""SELECT x FROM (SELECT 1 AS x)""") org.apache.spark.sql.catalyst.parser.ParseException: mismatched input 'FROM' expecting {, 'WHERE', 'GROUP', 'ORDER', 'HAVING', 'LIMIT', 'LATERAL', 'WINDOW', 'UNION', 'EXCEPT', 'MINUS', 'INTERSECT', 'SORT', 'CLUSTER', 'DISTRIBUTE'}(line 1, pos 9) We should modify the parser to throw a more clear error for such queries: scala> sql("""SELECT x FROM (SELECT 1 AS x)""") org.apache.spark.sql.catalyst.parser.ParseException: The unaliased subqueries in the FROM clause are not supported.(line 1, pos 14) ## How was this patch tested? Modified existing tests to reflect this change. Author: Liang-Chi Hsieh Closes #18141 from viirya/SPARK-20916. --- .../spark/sql/catalyst/parser/SqlBase.g4 | 2 +- .../sql/catalyst/parser/AstBuilder.scala | 7 +++ .../sql/catalyst/parser/PlanParserSuite.scala | 6 +- .../inputs/subquery/subquery-in-from.sql | 14 +++++ .../results/subquery/subquery-in-from.sql.out | 62 +++++++++++++++++++ 5 files changed, 88 insertions(+), 3 deletions(-) create mode 100644 sql/core/src/test/resources/sql-tests/inputs/subquery/subquery-in-from.sql create mode 100644 sql/core/src/test/resources/sql-tests/results/subquery/subquery-in-from.sql.out diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index 547013c23fd78..4584aea6196a6 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -473,7 +473,7 @@ identifierComment relationPrimary : tableIdentifier sample? tableAlias #tableName - | '(' queryNoWith ')' sample? (AS? strictIdentifier) #aliasedQuery + | '(' queryNoWith ')' sample? (AS? strictIdentifier)? #aliasedQuery | '(' relation ')' sample? (AS? strictIdentifier)? #aliasedRelation | inlineTable #inlineTableDefault2 | functionTable #tableValuedFunction diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 5f34d0777d5a1..4eb5560155781 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -749,6 +749,13 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging * hooks. */ override def visitAliasedQuery(ctx: AliasedQueryContext): LogicalPlan = withOrigin(ctx) { + // The unaliased subqueries in the FROM clause are disallowed. Instead of rejecting it in + // parser rules, we handle it here in order to provide better error message. + if (ctx.strictIdentifier == null) { + throw new ParseException("The unaliased subqueries in the FROM clause are not supported.", + ctx) + } + aliasPlan(ctx.strictIdentifier, plan(ctx.queryNoWith).optionalMap(ctx.sample)(withSample)) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala index 7a5357eef8f94..3a26adaef9db0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala @@ -448,13 +448,15 @@ class PlanParserSuite extends PlanTest { } test("aliased subquery") { + val errMsg = "The unaliased subqueries in the FROM clause are not supported" + assertEqual("select a from (select id as a from t0) tt", table("t0").select('id.as("a")).as("tt").select('a)) - intercept("select a from (select id as a from t0)", "mismatched input") + intercept("select a from (select id as a from t0)", errMsg) assertEqual("from (select id as a from t0) tt select a", table("t0").select('id.as("a")).as("tt").select('a)) - intercept("from (select id as a from t0) select a", "extraneous input 'a'") + intercept("from (select id as a from t0) select a", errMsg) } test("scalar sub-query") { diff --git a/sql/core/src/test/resources/sql-tests/inputs/subquery/subquery-in-from.sql b/sql/core/src/test/resources/sql-tests/inputs/subquery/subquery-in-from.sql new file mode 100644 index 0000000000000..1273b56b6344b --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/subquery/subquery-in-from.sql @@ -0,0 +1,14 @@ +-- Aliased subqueries in FROM clause +SELECT * FROM (SELECT * FROM testData) AS t WHERE key = 1; + +FROM (SELECT * FROM testData WHERE key = 1) AS t SELECT *; + +-- Optional `AS` keyword +SELECT * FROM (SELECT * FROM testData) t WHERE key = 1; + +FROM (SELECT * FROM testData WHERE key = 1) t SELECT *; + +-- Disallow unaliased subqueries in FROM clause +SELECT * FROM (SELECT * FROM testData) WHERE key = 1; + +FROM (SELECT * FROM testData WHERE key = 1) SELECT *; diff --git a/sql/core/src/test/resources/sql-tests/results/subquery/subquery-in-from.sql.out b/sql/core/src/test/resources/sql-tests/results/subquery/subquery-in-from.sql.out new file mode 100644 index 0000000000000..14553557d1ffc --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/subquery/subquery-in-from.sql.out @@ -0,0 +1,62 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 6 + + +-- !query 0 +SELECT * FROM (SELECT * FROM testData) AS t WHERE key = 1 +-- !query 0 schema +struct +-- !query 0 output +1 1 + + +-- !query 1 +FROM (SELECT * FROM testData WHERE key = 1) AS t SELECT * +-- !query 1 schema +struct +-- !query 1 output +1 1 + + +-- !query 2 +SELECT * FROM (SELECT * FROM testData) t WHERE key = 1 +-- !query 2 schema +struct +-- !query 2 output +1 1 + + +-- !query 3 +FROM (SELECT * FROM testData WHERE key = 1) t SELECT * +-- !query 3 schema +struct +-- !query 3 output +1 1 + + +-- !query 4 +SELECT * FROM (SELECT * FROM testData) WHERE key = 1 +-- !query 4 schema +struct<> +-- !query 4 output +org.apache.spark.sql.catalyst.parser.ParseException + +The unaliased subqueries in the FROM clause are not supported.(line 1, pos 14) + +== SQL == +SELECT * FROM (SELECT * FROM testData) WHERE key = 1 +--------------^^^ + + +-- !query 5 +FROM (SELECT * FROM testData WHERE key = 1) SELECT * +-- !query 5 schema +struct<> +-- !query 5 output +org.apache.spark.sql.catalyst.parser.ParseException + +The unaliased subqueries in the FROM clause are not supported.(line 1, pos 5) + +== SQL == +FROM (SELECT * FROM testData WHERE key = 1) SELECT * +-----^^^ From ff5676b01ffd8adfe753cb749582579cbd496e7f Mon Sep 17 00:00:00 2001 From: actuaryzhang Date: Wed, 31 May 2017 01:02:19 +0800 Subject: [PATCH 036/133] [SPARK-20899][PYSPARK] PySpark supports stringIndexerOrderType in RFormula ## What changes were proposed in this pull request? PySpark supports stringIndexerOrderType in RFormula as in #17967. ## How was this patch tested? docstring test Author: actuaryzhang Closes #18122 from actuaryzhang/PythonRFormula. --- python/pyspark/ml/feature.py | 33 ++++++++++++++++++++++++++++----- python/pyspark/ml/tests.py | 13 +++++++++++++ 2 files changed, 41 insertions(+), 5 deletions(-) diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 955bc9768ce77..77de1cc18246d 100755 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -3043,26 +3043,35 @@ class RFormula(JavaEstimator, HasFeaturesCol, HasLabelCol, JavaMLReadable, JavaM "Force to index label whether it is numeric or string", typeConverter=TypeConverters.toBoolean) + stringIndexerOrderType = Param(Params._dummy(), "stringIndexerOrderType", + "How to order categories of a string feature column used by " + + "StringIndexer. The last category after ordering is dropped " + + "when encoding strings. Supported options: frequencyDesc, " + + "frequencyAsc, alphabetDesc, alphabetAsc. The default value " + + "is frequencyDesc. When the ordering is set to alphabetDesc, " + + "RFormula drops the same category as R when encoding strings.", + typeConverter=TypeConverters.toString) + @keyword_only def __init__(self, formula=None, featuresCol="features", labelCol="label", - forceIndexLabel=False): + forceIndexLabel=False, stringIndexerOrderType="frequencyDesc"): """ __init__(self, formula=None, featuresCol="features", labelCol="label", \ - forceIndexLabel=False) + forceIndexLabel=False, stringIndexerOrderType="frequencyDesc") """ super(RFormula, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.RFormula", self.uid) - self._setDefault(forceIndexLabel=False) + self._setDefault(forceIndexLabel=False, stringIndexerOrderType="frequencyDesc") kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @since("1.5.0") def setParams(self, formula=None, featuresCol="features", labelCol="label", - forceIndexLabel=False): + forceIndexLabel=False, stringIndexerOrderType="frequencyDesc"): """ setParams(self, formula=None, featuresCol="features", labelCol="label", \ - forceIndexLabel=False) + forceIndexLabel=False, stringIndexerOrderType="frequencyDesc") Sets params for RFormula. """ kwargs = self._input_kwargs @@ -3096,6 +3105,20 @@ def getForceIndexLabel(self): """ return self.getOrDefault(self.forceIndexLabel) + @since("2.3.0") + def setStringIndexerOrderType(self, value): + """ + Sets the value of :py:attr:`stringIndexerOrderType`. + """ + return self._set(stringIndexerOrderType=value) + + @since("2.3.0") + def getStringIndexerOrderType(self): + """ + Gets the value of :py:attr:`stringIndexerOrderType` or its default value 'frequencyDesc'. + """ + return self.getOrDefault(self.stringIndexerOrderType) + def _create_model(self, java_model): return RFormulaModel(java_model) diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 0daf29d59cb74..17a39472e1fe5 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -538,6 +538,19 @@ def test_rformula_force_index_label(self): transformedDF2 = model2.transform(df) self.assertEqual(transformedDF2.head().label, 0.0) + def test_rformula_string_indexer_order_type(self): + df = self.spark.createDataFrame([ + (1.0, 1.0, "a"), + (0.0, 2.0, "b"), + (1.0, 0.0, "a")], ["y", "x", "s"]) + rf = RFormula(formula="y ~ x + s", stringIndexerOrderType="alphabetDesc") + self.assertEqual(rf.getStringIndexerOrderType(), 'alphabetDesc') + transformedDF = rf.fit(df).transform(df) + observed = transformedDF.select("features").collect() + expected = [[1.0, 0.0], [2.0, 1.0], [0.0, 0.0]] + for i in range(0, len(expected)): + self.assertTrue(all(observed[i]["features"].toArray() == expected[i])) + class HasInducedError(Params): From 4d57981cfb18e7500cde6c03ae46c7c9b697d064 Mon Sep 17 00:00:00 2001 From: Arman Date: Tue, 30 May 2017 11:09:21 -0700 Subject: [PATCH 037/133] [SPARK-19236][CORE] Added createOrReplaceGlobalTempView method ## What changes were proposed in this pull request? Added the createOrReplaceGlobalTempView method for dataset Author: Arman Closes #16598 from arman1371/patch-1. --- .../main/scala/org/apache/spark/sql/Dataset.scala | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 0e7415890e216..b98a705c73699 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -2747,6 +2747,21 @@ class Dataset[T] private[sql]( def createGlobalTempView(viewName: String): Unit = withPlan { createTempViewCommand(viewName, replace = false, global = true) } + + /** + * Creates or replaces a global temporary view using the given name. The lifetime of this + * temporary view is tied to this Spark application. + * + * Global temporary view is cross-session. Its lifetime is the lifetime of the Spark application, + * i.e. it will be automatically dropped when the application terminates. It's tied to a system + * preserved database `_global_temp`, and we must use the qualified name to refer a global temp + * view, e.g. `SELECT * FROM _global_temp.view1`. + * + * @group basic + */ + def createOrReplaceGlobalTempView(viewName: String): Unit = withPlan { + createTempViewCommand(viewName, replace = true, global = true) + } private def createTempViewCommand( viewName: String, From de953c214c025fbc7b0e94f85625d72091e7257e Mon Sep 17 00:00:00 2001 From: jinxing Date: Tue, 30 May 2017 14:02:33 -0500 Subject: [PATCH 038/133] [SPARK-20333] HashPartitioner should be compatible with num of child RDD's partitions. ## What changes were proposed in this pull request? Fix test "don't submit stage until its dependencies map outputs are registered (SPARK-5259)" , "run trivial shuffle with out-of-band executor failure and retry", "reduce tasks should be placed locally with map output" in DAGSchedulerSuite. Author: jinxing Closes #17634 from jinxing64/SPARK-20333. --- .../org/apache/spark/scheduler/DAGSchedulerSuite.scala | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index a10941b579fe2..67145e7445061 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -1277,10 +1277,10 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou */ test("don't submit stage until its dependencies map outputs are registered (SPARK-5259)") { val firstRDD = new MyRDD(sc, 3, Nil) - val firstShuffleDep = new ShuffleDependency(firstRDD, new HashPartitioner(2)) + val firstShuffleDep = new ShuffleDependency(firstRDD, new HashPartitioner(3)) val firstShuffleId = firstShuffleDep.shuffleId val shuffleMapRdd = new MyRDD(sc, 3, List(firstShuffleDep)) - val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(2)) + val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(1)) val reduceRdd = new MyRDD(sc, 1, List(shuffleDep)) submit(reduceRdd, Array(0)) @@ -1583,7 +1583,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou */ test("run trivial shuffle with out-of-band executor failure and retry") { val shuffleMapRdd = new MyRDD(sc, 2, Nil) - val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(2)) + val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(1)) val shuffleId = shuffleDep.shuffleId val reduceRdd = new MyRDD(sc, 1, List(shuffleDep), tracker = mapOutputTracker) submit(reduceRdd, Array(0)) @@ -1791,7 +1791,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou test("reduce tasks should be placed locally with map output") { // Create a shuffleMapRdd with 1 partition val shuffleMapRdd = new MyRDD(sc, 1, Nil) - val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(2)) + val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(1)) val shuffleId = shuffleDep.shuffleId val reduceRdd = new MyRDD(sc, 1, List(shuffleDep), tracker = mapOutputTracker) submit(reduceRdd, Array(0)) From 798a04fd7645224b26a05b0e17e565daeeff3b64 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 30 May 2017 12:22:23 -0700 Subject: [PATCH 039/133] HOTFIX: fix Scalastyle break introduced in 4d57981cfb18e7500cde6c03ae46c7c9b697d064 --- sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index b98a705c73699..1cd6fda5edc87 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -2747,7 +2747,7 @@ class Dataset[T] private[sql]( def createGlobalTempView(viewName: String): Unit = withPlan { createTempViewCommand(viewName, replace = false, global = true) } - + /** * Creates or replaces a global temporary view using the given name. The lifetime of this * temporary view is tied to this Spark application. From 4bb6a53ebd06de3de97139a2dbc7c85fc3aa3e66 Mon Sep 17 00:00:00 2001 From: Xiao Li Date: Tue, 30 May 2017 14:06:19 -0700 Subject: [PATCH 040/133] [SPARK-20924][SQL] Unable to call the function registered in the not-current database ### What changes were proposed in this pull request? We are unable to call the function registered in the not-current database. ```Scala sql("CREATE DATABASE dAtABaSe1") sql(s"CREATE FUNCTION dAtABaSe1.test_avg AS '${classOf[GenericUDAFAverage].getName}'") sql("SELECT dAtABaSe1.test_avg(1)") ``` The above code returns an error: ``` Undefined function: 'dAtABaSe1.test_avg'. This function is neither a registered temporary function nor a permanent function registered in the database 'default'.; line 1 pos 7 ``` This PR is to fix the above issue. ### How was this patch tested? Added test cases. Author: Xiao Li Closes #18146 from gatorsmile/qualifiedFunction. --- .../sql/catalyst/catalog/SessionCatalog.scala | 17 ++++++++-------- .../spark/sql/hive/HiveSessionCatalog.scala | 6 +++--- .../sql/hive/execution/HiveUDFSuite.scala | 20 +++++++++++++++++++ 3 files changed, 32 insertions(+), 11 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index f6653d384fe1d..a78440df4f3e1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -1105,8 +1105,9 @@ class SessionCatalog( !hiveFunctions.contains(name.funcName.toLowerCase(Locale.ROOT)) } - protected def failFunctionLookup(name: String): Nothing = { - throw new NoSuchFunctionException(db = currentDb, func = name) + protected def failFunctionLookup(name: FunctionIdentifier): Nothing = { + throw new NoSuchFunctionException( + db = name.database.getOrElse(getCurrentDatabase), func = name.funcName) } /** @@ -1128,7 +1129,7 @@ class SessionCatalog( qualifiedName.database.orNull, qualifiedName.identifier) } else { - failFunctionLookup(name.funcName) + failFunctionLookup(name) } } } @@ -1158,8 +1159,8 @@ class SessionCatalog( } // If the name itself is not qualified, add the current database to it. - val database = name.database.orElse(Some(currentDb)).map(formatDatabaseName) - val qualifiedName = name.copy(database = database) + val database = formatDatabaseName(name.database.getOrElse(getCurrentDatabase)) + val qualifiedName = name.copy(database = Some(database)) if (functionRegistry.functionExists(qualifiedName.unquotedString)) { // This function has been already loaded into the function registry. @@ -1172,10 +1173,10 @@ class SessionCatalog( // in the metastore). We need to first put the function in the FunctionRegistry. // TODO: why not just check whether the function exists first? val catalogFunction = try { - externalCatalog.getFunction(currentDb, name.funcName) + externalCatalog.getFunction(database, name.funcName) } catch { - case e: AnalysisException => failFunctionLookup(name.funcName) - case e: NoSuchPermanentFunctionException => failFunctionLookup(name.funcName) + case _: AnalysisException => failFunctionLookup(name) + case _: NoSuchPermanentFunctionException => failFunctionLookup(name) } loadFunctionResources(catalogFunction.resources) // Please note that qualifiedName is provided by the user. However, diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala index 377d4f2473c58..6227e780c0409 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala @@ -140,7 +140,7 @@ private[sql] class HiveSessionCatalog( // Hive is case insensitive. val functionName = funcName.unquotedString.toLowerCase(Locale.ROOT) if (!hiveFunctions.contains(functionName)) { - failFunctionLookup(funcName.unquotedString) + failFunctionLookup(funcName) } // TODO: Remove this fallback path once we implement the list of fallback functions @@ -148,12 +148,12 @@ private[sql] class HiveSessionCatalog( val functionInfo = { try { Option(HiveFunctionRegistry.getFunctionInfo(functionName)).getOrElse( - failFunctionLookup(funcName.unquotedString)) + failFunctionLookup(funcName)) } catch { // If HiveFunctionRegistry.getFunctionInfo throws an exception, // we are failing to load a Hive builtin function, which means that // the given function is not a Hive builtin function. - case NonFatal(e) => failFunctionLookup(funcName.unquotedString) + case NonFatal(e) => failFunctionLookup(funcName) } } val className = functionInfo.getFunctionClass.getName diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala index 4446af2e75e00..8fcbad58350f4 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala @@ -34,6 +34,7 @@ import org.apache.spark.sql.{AnalysisException, QueryTest, Row} import org.apache.spark.sql.catalyst.plans.logical.Project import org.apache.spark.sql.functions.max import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.util.Utils @@ -590,6 +591,25 @@ class HiveUDFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils { } } } + + test("Call the function registered in the not-current database") { + Seq("true", "false").foreach { caseSensitive => + withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive) { + withDatabase("dAtABaSe1") { + sql("CREATE DATABASE dAtABaSe1") + withUserDefinedFunction("dAtABaSe1.test_avg" -> false) { + sql(s"CREATE FUNCTION dAtABaSe1.test_avg AS '${classOf[GenericUDAFAverage].getName}'") + checkAnswer(sql("SELECT dAtABaSe1.test_avg(1)"), Row(1.0)) + } + val message = intercept[AnalysisException] { + sql("SELECT dAtABaSe1.unknownFunc(1)") + }.getMessage + assert(message.contains("Undefined function: 'unknownFunc'") && + message.contains("nor a permanent function registered in the database 'dAtABaSe1'")) + } + } + } + } } class TestPair(x: Int, y: Int) extends Writable with Serializable { From 0896845cf65fb4ccdaebda40f21745d3fa0bd3ba Mon Sep 17 00:00:00 2001 From: actuaryzhang Date: Tue, 30 May 2017 15:06:30 -0700 Subject: [PATCH 041/133] update namespace --- R/pkg/NAMESPACE | 1 + 1 file changed, 1 insertion(+) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 5c074d3c0fd40..6dc26b3433250 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -356,6 +356,7 @@ exportMethods("%<=>%", "to_utc_timestamp", "translate", "trim", + "trunc”, "unbase64", "unhex", "unix_timestamp", From 7af4490794e1e2dc654ae5c37d0805f249931589 Mon Sep 17 00:00:00 2001 From: actuaryzhang Date: Tue, 30 May 2017 15:08:45 -0700 Subject: [PATCH 042/133] fix encoding issue --- R/pkg/NAMESPACE | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 6dc26b3433250..2f297d1a4efd5 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -356,7 +356,7 @@ exportMethods("%<=>%", "to_utc_timestamp", "translate", "trim", - "trunc”, + "trunc", "unbase64", "unhex", "unix_timestamp", From fa757ee1d41396ad8734a3f2dd045bb09bc82a2e Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 30 May 2017 15:33:06 -0700 Subject: [PATCH 043/133] [SPARK-20883][SPARK-20376][SS] Refactored StateStore APIs and added conf to choose implementation ## What changes were proposed in this pull request? A bunch of changes to the StateStore APIs and implementation. Current state store API has a bunch of problems that causes too many transient objects causing memory pressure. - `StateStore.get(): Option` forces creation of Some/None objects for every get. Changed this to return the row or null. - `StateStore.iterator(): (UnsafeRow, UnsafeRow)` forces creation of new tuple for each record returned. Changed this to return a UnsafeRowTuple which can be reused across records. - `StateStore.updates()` requires the implementation to keep track of updates, while this is used minimally (only by Append mode in streaming aggregations). Removed updates() and updated StateStoreSaveExec accordingly. - `StateStore.filter(condition)` and `StateStore.remove(condition)` has been merge into a single API `getRange(start, end)` which allows a state store to do optimized range queries (i.e. avoid full scans). Stateful operators have been updated accordingly. - Removed a lot of unnecessary row copies Each operator copied rows before calling StateStore.put() even if the implementation does not require it to be copied. It is left up to the implementation on whether to copy the row or not. Additionally, - Added a name to the StateStoreId so that each operator+partition can use multiple state stores (different names) - Added a configuration that allows the user to specify which implementation to use. - Added new metrics to understand the time taken to update keys, remove keys and commit all changes to the state store. These metrics will be visible on the plan diagram in the SQL tab of the UI. - Refactored unit tests such that they can be reused to test any implementation of StateStore. ## How was this patch tested? Old and new unit tests Author: Tathagata Das Closes #18107 from tdas/SPARK-20376. --- .../apache/spark/sql/internal/SQLConf.scala | 11 + .../FlatMapGroupsWithStateExec.scala | 39 +- .../state/HDFSBackedStateStoreProvider.scala | 218 +++---- .../streaming/state/StateStore.scala | 163 ++++-- .../streaming/state/StateStoreConf.scala | 28 +- .../streaming/state/StateStoreRDD.scala | 11 +- .../execution/streaming/state/package.scala | 11 +- .../streaming/statefulOperators.scala | 142 +++-- .../streaming/state/StateStoreRDDSuite.scala | 41 +- .../streaming/state/StateStoreSuite.scala | 534 ++++++++---------- .../FlatMapGroupsWithStateSuite.scala | 40 +- .../spark/sql/streaming/StreamSuite.scala | 45 ++ 12 files changed, 695 insertions(+), 588 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index c5d69c204642e..c6f5cf641b8d5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -552,6 +552,15 @@ object SQLConf { .booleanConf .createWithDefault(true) + val STATE_STORE_PROVIDER_CLASS = + buildConf("spark.sql.streaming.stateStore.providerClass") + .internal() + .doc( + "The class used to manage state data in stateful streaming queries. This class must " + + "be a subclass of StateStoreProvider, and must have a zero-arg constructor.") + .stringConf + .createOptional + val STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT = buildConf("spark.sql.streaming.stateStore.minDeltasForSnapshot") .internal() @@ -828,6 +837,8 @@ class SQLConf extends Serializable with Logging { def optimizerInSetConversionThreshold: Int = getConf(OPTIMIZER_INSET_CONVERSION_THRESHOLD) + def stateStoreProviderClass: Option[String] = getConf(STATE_STORE_PROVIDER_CLASS) + def stateStoreMinDeltasForSnapshot: Int = getConf(STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT) def checkpointLocation: Option[String] = getConf(CHECKPOINT_LOCATION) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala index 3ceb4cf84a413..2aad8701a4eca 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala @@ -109,9 +109,11 @@ case class FlatMapGroupsWithStateExec( child.execute().mapPartitionsWithStateStore[InternalRow]( getStateId.checkpointLocation, getStateId.operatorId, + storeName = "default", getStateId.batchId, groupingAttributes.toStructType, stateAttributes.toStructType, + indexOrdinal = None, sqlContext.sessionState, Some(sqlContext.streams.stateStoreCoordinator)) { case (store, iter) => val updater = new StateStoreUpdater(store) @@ -191,12 +193,12 @@ case class FlatMapGroupsWithStateExec( throw new IllegalStateException( s"Cannot filter timed out keys for $timeoutConf") } - val timingOutKeys = store.filter { case (_, stateRow) => - val timeoutTimestamp = getTimeoutTimestamp(stateRow) + val timingOutKeys = store.getRange(None, None).filter { rowPair => + val timeoutTimestamp = getTimeoutTimestamp(rowPair.value) timeoutTimestamp != NO_TIMESTAMP && timeoutTimestamp < timeoutThreshold } - timingOutKeys.flatMap { case (keyRow, stateRow) => - callFunctionAndUpdateState(keyRow, Iterator.empty, Some(stateRow), hasTimedOut = true) + timingOutKeys.flatMap { rowPair => + callFunctionAndUpdateState(rowPair.key, Iterator.empty, rowPair.value, hasTimedOut = true) } } else Iterator.empty } @@ -205,18 +207,23 @@ case class FlatMapGroupsWithStateExec( * Call the user function on a key's data, update the state store, and return the return data * iterator. Note that the store updating is lazy, that is, the store will be updated only * after the returned iterator is fully consumed. + * + * @param keyRow Row representing the key, cannot be null + * @param valueRowIter Iterator of values as rows, cannot be null, but can be empty + * @param prevStateRow Row representing the previous state, can be null + * @param hasTimedOut Whether this function is being called for a key timeout */ private def callFunctionAndUpdateState( keyRow: UnsafeRow, valueRowIter: Iterator[InternalRow], - prevStateRowOption: Option[UnsafeRow], + prevStateRow: UnsafeRow, hasTimedOut: Boolean): Iterator[InternalRow] = { val keyObj = getKeyObj(keyRow) // convert key to objects val valueObjIter = valueRowIter.map(getValueObj.apply) // convert value rows to objects - val stateObjOption = getStateObj(prevStateRowOption) + val stateObj = getStateObj(prevStateRow) val keyedState = GroupStateImpl.createForStreaming( - stateObjOption, + Option(stateObj), batchTimestampMs.getOrElse(NO_TIMESTAMP), eventTimeWatermark.getOrElse(NO_TIMESTAMP), timeoutConf, @@ -249,14 +256,11 @@ case class FlatMapGroupsWithStateExec( numUpdatedStateRows += 1 } else { - val previousTimeoutTimestamp = prevStateRowOption match { - case Some(row) => getTimeoutTimestamp(row) - case None => NO_TIMESTAMP - } + val previousTimeoutTimestamp = getTimeoutTimestamp(prevStateRow) val stateRowToWrite = if (keyedState.hasUpdated) { getStateRow(keyedState.get) } else { - prevStateRowOption.orNull + prevStateRow } val hasTimeoutChanged = currentTimeoutTimestamp != previousTimeoutTimestamp @@ -269,7 +273,7 @@ case class FlatMapGroupsWithStateExec( throw new IllegalStateException("Attempting to write empty state") } setTimeoutTimestamp(stateRowToWrite, currentTimeoutTimestamp) - store.put(keyRow.copy(), stateRowToWrite.copy()) + store.put(keyRow, stateRowToWrite) numUpdatedStateRows += 1 } } @@ -280,18 +284,21 @@ case class FlatMapGroupsWithStateExec( } /** Returns the state as Java object if defined */ - def getStateObj(stateRowOption: Option[UnsafeRow]): Option[Any] = { - stateRowOption.map(getStateObjFromRow) + def getStateObj(stateRow: UnsafeRow): Any = { + if (stateRow != null) getStateObjFromRow(stateRow) else null } /** Returns the row for an updated state */ def getStateRow(obj: Any): UnsafeRow = { + assert(obj != null) getStateRowFromObj(obj) } /** Returns the timeout timestamp of a state row is set */ def getTimeoutTimestamp(stateRow: UnsafeRow): Long = { - if (isTimeoutEnabled) stateRow.getLong(timeoutTimestampIndex) else NO_TIMESTAMP + if (isTimeoutEnabled && stateRow != null) { + stateRow.getLong(timeoutTimestampIndex) + } else NO_TIMESTAMP } /** Set the timestamp in a state row */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala index fb2bf47d6e83b..67d86daf10812 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala @@ -67,13 +67,7 @@ import org.apache.spark.util.Utils * to ensure re-executed RDD operations re-apply updates on the correct past version of the * store. */ -private[state] class HDFSBackedStateStoreProvider( - val id: StateStoreId, - keySchema: StructType, - valueSchema: StructType, - storeConf: StateStoreConf, - hadoopConf: Configuration - ) extends StateStoreProvider with Logging { +private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider with Logging { // ConcurrentHashMap is used because it generates fail-safe iterators on filtering // - The iterator is weakly consistent with the map, i.e., iterator's data reflect the values in @@ -95,92 +89,36 @@ private[state] class HDFSBackedStateStoreProvider( private val newVersion = version + 1 private val tempDeltaFile = new Path(baseDir, s"temp-${Random.nextLong}") private lazy val tempDeltaFileStream = compressStream(fs.create(tempDeltaFile, true)) - private val allUpdates = new java.util.HashMap[UnsafeRow, StoreUpdate]() - @volatile private var state: STATE = UPDATING @volatile private var finalDeltaFile: Path = null override def id: StateStoreId = HDFSBackedStateStoreProvider.this.id - override def get(key: UnsafeRow): Option[UnsafeRow] = { - Option(mapToUpdate.get(key)) - } - - override def filter( - condition: (UnsafeRow, UnsafeRow) => Boolean): Iterator[(UnsafeRow, UnsafeRow)] = { - mapToUpdate - .entrySet - .asScala - .iterator - .filter { entry => condition(entry.getKey, entry.getValue) } - .map { entry => (entry.getKey, entry.getValue) } + override def get(key: UnsafeRow): UnsafeRow = { + mapToUpdate.get(key) } override def put(key: UnsafeRow, value: UnsafeRow): Unit = { verify(state == UPDATING, "Cannot put after already committed or aborted") - - val isNewKey = !mapToUpdate.containsKey(key) - mapToUpdate.put(key, value) - - Option(allUpdates.get(key)) match { - case Some(ValueAdded(_, _)) => - // Value did not exist in previous version and was added already, keep it marked as added - allUpdates.put(key, ValueAdded(key, value)) - case Some(ValueUpdated(_, _)) | Some(ValueRemoved(_, _)) => - // Value existed in previous version and updated/removed, mark it as updated - allUpdates.put(key, ValueUpdated(key, value)) - case None => - // There was no prior update, so mark this as added or updated according to its presence - // in previous version. - val update = if (isNewKey) ValueAdded(key, value) else ValueUpdated(key, value) - allUpdates.put(key, update) - } - writeToDeltaFile(tempDeltaFileStream, ValueUpdated(key, value)) + val keyCopy = key.copy() + val valueCopy = value.copy() + mapToUpdate.put(keyCopy, valueCopy) + writeUpdateToDeltaFile(tempDeltaFileStream, keyCopy, valueCopy) } - /** Remove keys that match the following condition */ - override def remove(condition: UnsafeRow => Boolean): Unit = { + override def remove(key: UnsafeRow): Unit = { verify(state == UPDATING, "Cannot remove after already committed or aborted") - val entryIter = mapToUpdate.entrySet().iterator() - while (entryIter.hasNext) { - val entry = entryIter.next - if (condition(entry.getKey)) { - val value = entry.getValue - val key = entry.getKey - entryIter.remove() - - Option(allUpdates.get(key)) match { - case Some(ValueUpdated(_, _)) | None => - // Value existed in previous version and maybe was updated, mark removed - allUpdates.put(key, ValueRemoved(key, value)) - case Some(ValueAdded(_, _)) => - // Value did not exist in previous version and was added, should not appear in updates - allUpdates.remove(key) - case Some(ValueRemoved(_, _)) => - // Remove already in update map, no need to change - } - writeToDeltaFile(tempDeltaFileStream, ValueRemoved(key, value)) - } + val prevValue = mapToUpdate.remove(key) + if (prevValue != null) { + writeRemoveToDeltaFile(tempDeltaFileStream, key) } } - /** Remove a single key. */ - override def remove(key: UnsafeRow): Unit = { - verify(state == UPDATING, "Cannot remove after already committed or aborted") - if (mapToUpdate.containsKey(key)) { - val value = mapToUpdate.remove(key) - Option(allUpdates.get(key)) match { - case Some(ValueUpdated(_, _)) | None => - // Value existed in previous version and maybe was updated, mark removed - allUpdates.put(key, ValueRemoved(key, value)) - case Some(ValueAdded(_, _)) => - // Value did not exist in previous version and was added, should not appear in updates - allUpdates.remove(key) - case Some(ValueRemoved(_, _)) => - // Remove already in update map, no need to change - } - writeToDeltaFile(tempDeltaFileStream, ValueRemoved(key, value)) - } + override def getRange( + start: Option[UnsafeRow], + end: Option[UnsafeRow]): Iterator[UnsafeRowPair] = { + verify(state == UPDATING, "Cannot getRange after already committed or aborted") + iterator() } /** Commit all the updates that have been made to the store, and return the new version. */ @@ -227,20 +165,11 @@ private[state] class HDFSBackedStateStoreProvider( * Get an iterator of all the store data. * This can be called only after committing all the updates made in the current thread. */ - override def iterator(): Iterator[(UnsafeRow, UnsafeRow)] = { - verify(state == COMMITTED, - "Cannot get iterator of store data before committing or after aborting") - HDFSBackedStateStoreProvider.this.iterator(newVersion) - } - - /** - * Get an iterator of all the updates made to the store in the current version. - * This can be called only after committing all the updates made in the current thread. - */ - override def updates(): Iterator[StoreUpdate] = { - verify(state == COMMITTED, - "Cannot get iterator of updates before committing or after aborting") - allUpdates.values().asScala.toIterator + override def iterator(): Iterator[UnsafeRowPair] = { + val unsafeRowPair = new UnsafeRowPair() + mapToUpdate.entrySet.asScala.iterator.map { entry => + unsafeRowPair.withRows(entry.getKey, entry.getValue) + } } override def numKeys(): Long = mapToUpdate.size() @@ -269,6 +198,23 @@ private[state] class HDFSBackedStateStoreProvider( store } + override def init( + stateStoreId: StateStoreId, + keySchema: StructType, + valueSchema: StructType, + indexOrdinal: Option[Int], // for sorting the data + storeConf: StateStoreConf, + hadoopConf: Configuration): Unit = { + this.stateStoreId = stateStoreId + this.keySchema = keySchema + this.valueSchema = valueSchema + this.storeConf = storeConf + this.hadoopConf = hadoopConf + fs.mkdirs(baseDir) + } + + override def id: StateStoreId = stateStoreId + /** Do maintenance backing data files, including creating snapshots and cleaning up old files */ override def doMaintenance(): Unit = { try { @@ -280,19 +226,27 @@ private[state] class HDFSBackedStateStoreProvider( } } + override def close(): Unit = { + loadedMaps.values.foreach(_.clear()) + } + override def toString(): String = { s"HDFSStateStoreProvider[id = (op=${id.operatorId}, part=${id.partitionId}), dir = $baseDir]" } - /* Internal classes and methods */ + /* Internal fields and methods */ - private val loadedMaps = new mutable.HashMap[Long, MapType] - private val baseDir = - new Path(id.checkpointLocation, s"${id.operatorId}/${id.partitionId.toString}") - private val fs = baseDir.getFileSystem(hadoopConf) - private val sparkConf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf) + @volatile private var stateStoreId: StateStoreId = _ + @volatile private var keySchema: StructType = _ + @volatile private var valueSchema: StructType = _ + @volatile private var storeConf: StateStoreConf = _ + @volatile private var hadoopConf: Configuration = _ - initialize() + private lazy val loadedMaps = new mutable.HashMap[Long, MapType] + private lazy val baseDir = + new Path(id.checkpointLocation, s"${id.operatorId}/${id.partitionId.toString}") + private lazy val fs = baseDir.getFileSystem(hadoopConf) + private lazy val sparkConf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf) private case class StoreFile(version: Long, path: Path, isSnapshot: Boolean) @@ -323,35 +277,18 @@ private[state] class HDFSBackedStateStoreProvider( * Get iterator of all the data of the latest version of the store. * Note that this will look up the files to determined the latest known version. */ - private[state] def latestIterator(): Iterator[(UnsafeRow, UnsafeRow)] = synchronized { + private[state] def latestIterator(): Iterator[UnsafeRowPair] = synchronized { val versionsInFiles = fetchFiles().map(_.version).toSet val versionsLoaded = loadedMaps.keySet val allKnownVersions = versionsInFiles ++ versionsLoaded + val unsafeRowTuple = new UnsafeRowPair() if (allKnownVersions.nonEmpty) { - loadMap(allKnownVersions.max).entrySet().iterator().asScala.map { x => - (x.getKey, x.getValue) + loadMap(allKnownVersions.max).entrySet().iterator().asScala.map { entry => + unsafeRowTuple.withRows(entry.getKey, entry.getValue) } } else Iterator.empty } - /** Get iterator of a specific version of the store */ - private[state] def iterator(version: Long): Iterator[(UnsafeRow, UnsafeRow)] = synchronized { - loadMap(version).entrySet().iterator().asScala.map { x => - (x.getKey, x.getValue) - } - } - - /** Initialize the store provider */ - private def initialize(): Unit = { - try { - fs.mkdirs(baseDir) - } catch { - case e: IOException => - throw new IllegalStateException( - s"Cannot use ${id.checkpointLocation} for storing state data for $this: $e ", e) - } - } - /** Load the required version of the map data from the backing files */ private def loadMap(version: Long): MapType = { if (version <= 0) return new MapType @@ -367,32 +304,23 @@ private[state] class HDFSBackedStateStoreProvider( } } - private def writeToDeltaFile(output: DataOutputStream, update: StoreUpdate): Unit = { - - def writeUpdate(key: UnsafeRow, value: UnsafeRow): Unit = { - val keyBytes = key.getBytes() - val valueBytes = value.getBytes() - output.writeInt(keyBytes.size) - output.write(keyBytes) - output.writeInt(valueBytes.size) - output.write(valueBytes) - } - - def writeRemove(key: UnsafeRow): Unit = { - val keyBytes = key.getBytes() - output.writeInt(keyBytes.size) - output.write(keyBytes) - output.writeInt(-1) - } + private def writeUpdateToDeltaFile( + output: DataOutputStream, + key: UnsafeRow, + value: UnsafeRow): Unit = { + val keyBytes = key.getBytes() + val valueBytes = value.getBytes() + output.writeInt(keyBytes.size) + output.write(keyBytes) + output.writeInt(valueBytes.size) + output.write(valueBytes) + } - update match { - case ValueAdded(key, value) => - writeUpdate(key, value) - case ValueUpdated(key, value) => - writeUpdate(key, value) - case ValueRemoved(key, value) => - writeRemove(key) - } + private def writeRemoveToDeltaFile(output: DataOutputStream, key: UnsafeRow): Unit = { + val keyBytes = key.getBytes() + output.writeInt(keyBytes.size) + output.write(keyBytes) + output.writeInt(-1) } private def finalizeDeltaFile(output: DataOutputStream): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index eaa558eb6d0ed..29c456f86e1ed 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -29,15 +29,12 @@ import org.apache.spark.SparkEnv import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.types.StructType -import org.apache.spark.util.ThreadUtils - - -/** Unique identifier for a [[StateStore]] */ -case class StateStoreId(checkpointLocation: String, operatorId: Long, partitionId: Int) +import org.apache.spark.util.{ThreadUtils, Utils} /** - * Base trait for a versioned key-value store used for streaming aggregations + * Base trait for a versioned key-value store. Each instance of a `StateStore` represents a specific + * version of state data, and such instances are created through a [[StateStoreProvider]]. */ trait StateStore { @@ -47,50 +44,54 @@ trait StateStore { /** Version of the data in this store before committing updates. */ def version: Long - /** Get the current value of a key. */ - def get(key: UnsafeRow): Option[UnsafeRow] - /** - * Return an iterator of key-value pairs that satisfy a certain condition. - * Note that the iterator must be fail-safe towards modification to the store, that is, - * it must be based on the snapshot of store the time of this call, and any change made to the - * store while iterating through iterator should not cause the iterator to fail or have - * any affect on the values in the iterator. + * Get the current value of a non-null key. + * @return a non-null row if the key exists in the store, otherwise null. */ - def filter(condition: (UnsafeRow, UnsafeRow) => Boolean): Iterator[(UnsafeRow, UnsafeRow)] + def get(key: UnsafeRow): UnsafeRow - /** Put a new value for a key. */ + /** + * Put a new value for a non-null key. Implementations must be aware that the UnsafeRows in + * the params can be reused, and must make copies of the data as needed for persistence. + */ def put(key: UnsafeRow, value: UnsafeRow): Unit /** - * Remove keys that match the following condition. + * Remove a single non-null key. */ - def remove(condition: UnsafeRow => Boolean): Unit + def remove(key: UnsafeRow): Unit /** - * Remove a single key. + * Get key value pairs with optional approximate `start` and `end` extents. + * If the State Store implementation maintains indices for the data based on the optional + * `keyIndexOrdinal` over fields `keySchema` (see `StateStoreProvider.init()`), then it can use + * `start` and `end` to make a best-effort scan over the data. Default implementation returns + * the full data scan iterator, which is correct but inefficient. Custom implementations must + * ensure that updates (puts, removes) can be made while iterating over this iterator. + * + * @param start UnsafeRow having the `keyIndexOrdinal` column set with appropriate starting value. + * @param end UnsafeRow having the `keyIndexOrdinal` column set with appropriate ending value. + * @return An iterator of key-value pairs that is guaranteed not miss any key between start and + * end, both inclusive. */ - def remove(key: UnsafeRow): Unit + def getRange(start: Option[UnsafeRow], end: Option[UnsafeRow]): Iterator[UnsafeRowPair] = { + iterator() + } /** * Commit all the updates that have been made to the store, and return the new version. + * Implementations should ensure that no more updates (puts, removes) can be after a commit in + * order to avoid incorrect usage. */ def commit(): Long - /** Abort all the updates that have been made to the store. */ - def abort(): Unit - /** - * Iterator of store data after a set of updates have been committed. - * This can be called only after committing all the updates made in the current thread. + * Abort all the updates that have been made to the store. Implementations should ensure that + * no more updates (puts, removes) can be after an abort in order to avoid incorrect usage. */ - def iterator(): Iterator[(UnsafeRow, UnsafeRow)] + def abort(): Unit - /** - * Iterator of the updates that have been committed. - * This can be called only after committing all the updates made in the current thread. - */ - def updates(): Iterator[StoreUpdate] + def iterator(): Iterator[UnsafeRowPair] /** Number of keys in the state store */ def numKeys(): Long @@ -102,28 +103,98 @@ trait StateStore { } -/** Trait representing a provider of a specific version of a [[StateStore]]. */ +/** + * Trait representing a provider that provide [[StateStore]] instances representing + * versions of state data. + * + * The life cycle of a provider and its provide stores are as follows. + * + * - A StateStoreProvider is created in a executor for each unique [[StateStoreId]] when + * the first batch of a streaming query is executed on the executor. All subsequent batches reuse + * this provider instance until the query is stopped. + * + * - Every batch of streaming data request a specific version of the state data by invoking + * `getStore(version)` which returns an instance of [[StateStore]] through which the required + * version of the data can be accessed. It is the responsible of the provider to populate + * this store with context information like the schema of keys and values, etc. + * + * - After the streaming query is stopped, the created provider instances are lazily disposed off. + */ trait StateStoreProvider { - /** Get the store with the existing version. */ + /** + * Initialize the provide with more contextual information from the SQL operator. + * This method will be called first after creating an instance of the StateStoreProvider by + * reflection. + * + * @param stateStoreId Id of the versioned StateStores that this provider will generate + * @param keySchema Schema of keys to be stored + * @param valueSchema Schema of value to be stored + * @param keyIndexOrdinal Optional column (represent as the ordinal of the field in keySchema) by + * which the StateStore implementation could index the data. + * @param storeConfs Configurations used by the StateStores + * @param hadoopConf Hadoop configuration that could be used by StateStore to save state data + */ + def init( + stateStoreId: StateStoreId, + keySchema: StructType, + valueSchema: StructType, + keyIndexOrdinal: Option[Int], // for sorting the data by their keys + storeConfs: StateStoreConf, + hadoopConf: Configuration): Unit + + /** + * Return the id of the StateStores this provider will generate. + * Should be the same as the one passed in init(). + */ + def id: StateStoreId + + /** Called when the provider instance is unloaded from the executor */ + def close(): Unit + + /** Return an instance of [[StateStore]] representing state data of the given version */ def getStore(version: Long): StateStore - /** Optional method for providers to allow for background maintenance */ + /** Optional method for providers to allow for background maintenance (e.g. compactions) */ def doMaintenance(): Unit = { } } - -/** Trait representing updates made to a [[StateStore]]. */ -sealed trait StoreUpdate { - def key: UnsafeRow - def value: UnsafeRow +object StateStoreProvider { + /** + * Return a provider instance of the given provider class. + * The instance will be already initialized. + */ + def instantiate( + stateStoreId: StateStoreId, + keySchema: StructType, + valueSchema: StructType, + indexOrdinal: Option[Int], // for sorting the data + storeConf: StateStoreConf, + hadoopConf: Configuration): StateStoreProvider = { + val providerClass = storeConf.providerClass.map(Utils.classForName) + .getOrElse(classOf[HDFSBackedStateStoreProvider]) + val provider = providerClass.newInstance().asInstanceOf[StateStoreProvider] + provider.init(stateStoreId, keySchema, valueSchema, indexOrdinal, storeConf, hadoopConf) + provider + } } -case class ValueAdded(key: UnsafeRow, value: UnsafeRow) extends StoreUpdate -case class ValueUpdated(key: UnsafeRow, value: UnsafeRow) extends StoreUpdate +/** Unique identifier for a bunch of keyed state data. */ +case class StateStoreId( + checkpointLocation: String, + operatorId: Long, + partitionId: Int, + name: String = "") -case class ValueRemoved(key: UnsafeRow, value: UnsafeRow) extends StoreUpdate +/** Mutable, and reusable class for representing a pair of UnsafeRows. */ +class UnsafeRowPair(var key: UnsafeRow = null, var value: UnsafeRow = null) { + def withRows(key: UnsafeRow, value: UnsafeRow): UnsafeRowPair = { + this.key = key + this.value = value + this + } +} /** @@ -185,6 +256,7 @@ object StateStore extends Logging { storeId: StateStoreId, keySchema: StructType, valueSchema: StructType, + indexOrdinal: Option[Int], version: Long, storeConf: StateStoreConf, hadoopConf: Configuration): StateStore = { @@ -193,7 +265,9 @@ object StateStore extends Logging { startMaintenanceIfNeeded() val provider = loadedProviders.getOrElseUpdate( storeId, - new HDFSBackedStateStoreProvider(storeId, keySchema, valueSchema, storeConf, hadoopConf)) + StateStoreProvider.instantiate( + storeId, keySchema, valueSchema, indexOrdinal, storeConf, hadoopConf) + ) reportActiveStoreInstance(storeId) provider } @@ -202,7 +276,7 @@ object StateStore extends Logging { /** Unload a state store provider */ def unload(storeId: StateStoreId): Unit = loadedProviders.synchronized { - loadedProviders.remove(storeId) + loadedProviders.remove(storeId).foreach(_.close()) } /** Whether a state store provider is loaded or not */ @@ -216,6 +290,7 @@ object StateStore extends Logging { /** Unload and stop all state store providers */ def stop(): Unit = loadedProviders.synchronized { + loadedProviders.keySet.foreach { key => unload(key) } loadedProviders.clear() _coordRef = null if (maintenanceTask != null) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala index acfaa8e5eb3c4..bab297c7df594 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala @@ -20,16 +20,34 @@ package org.apache.spark.sql.execution.streaming.state import org.apache.spark.sql.internal.SQLConf /** A class that contains configuration parameters for [[StateStore]]s. */ -private[streaming] class StateStoreConf(@transient private val conf: SQLConf) extends Serializable { +class StateStoreConf(@transient private val sqlConf: SQLConf) + extends Serializable { def this() = this(new SQLConf) - val minDeltasForSnapshot = conf.stateStoreMinDeltasForSnapshot - - val minVersionsToRetain = conf.minBatchesToRetain + /** + * Minimum number of delta files in a chain after which HDFSBackedStateStore will + * consider generating a snapshot. + */ + val minDeltasForSnapshot: Int = sqlConf.stateStoreMinDeltasForSnapshot + + /** Minimum versions a State Store implementation should retain to allow rollbacks */ + val minVersionsToRetain: Int = sqlConf.minBatchesToRetain + + /** + * Optional fully qualified name of the subclass of [[StateStoreProvider]] + * managing state data. That is, the implementation of the State Store to use. + */ + val providerClass: Option[String] = sqlConf.stateStoreProviderClass + + /** + * Additional configurations related to state store. This will capture all configs in + * SQLConf that start with `spark.sql.streaming.stateStore.` */ + val confs: Map[String, String] = + sqlConf.getAllConfs.filter(_._1.startsWith("spark.sql.streaming.stateStore.")) } -private[streaming] object StateStoreConf { +object StateStoreConf { val empty = new StateStoreConf() def apply(conf: SQLConf): StateStoreConf = new StateStoreConf(conf) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala index e16dda8a5b564..b744c25dc97a8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala @@ -35,9 +35,11 @@ class StateStoreRDD[T: ClassTag, U: ClassTag]( storeUpdateFunction: (StateStore, Iterator[T]) => Iterator[U], checkpointLocation: String, operatorId: Long, + storeName: String, storeVersion: Long, keySchema: StructType, valueSchema: StructType, + indexOrdinal: Option[Int], sessionState: SessionState, @transient private val storeCoordinator: Option[StateStoreCoordinatorRef]) extends RDD[U](dataRDD) { @@ -45,21 +47,22 @@ class StateStoreRDD[T: ClassTag, U: ClassTag]( private val storeConf = new StateStoreConf(sessionState.conf) // A Hadoop Configuration can be about 10 KB, which is pretty big, so broadcast it - private val confBroadcast = dataRDD.context.broadcast( + private val hadoopConfBroadcast = dataRDD.context.broadcast( new SerializableConfiguration(sessionState.newHadoopConf())) override protected def getPartitions: Array[Partition] = dataRDD.partitions override def getPreferredLocations(partition: Partition): Seq[String] = { - val storeId = StateStoreId(checkpointLocation, operatorId, partition.index) + val storeId = StateStoreId(checkpointLocation, operatorId, partition.index, storeName) storeCoordinator.flatMap(_.getLocation(storeId)).toSeq } override def compute(partition: Partition, ctxt: TaskContext): Iterator[U] = { var store: StateStore = null - val storeId = StateStoreId(checkpointLocation, operatorId, partition.index) + val storeId = StateStoreId(checkpointLocation, operatorId, partition.index, storeName) store = StateStore.get( - storeId, keySchema, valueSchema, storeVersion, storeConf, confBroadcast.value.value) + storeId, keySchema, valueSchema, indexOrdinal, storeVersion, + storeConf, hadoopConfBroadcast.value.value) val inputIter = dataRDD.iterator(partition, ctxt) storeUpdateFunction(store, inputIter) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala index 589042afb1e52..228fe86d59940 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala @@ -34,17 +34,21 @@ package object state { sqlContext: SQLContext, checkpointLocation: String, operatorId: Long, + storeName: String, storeVersion: Long, keySchema: StructType, - valueSchema: StructType)( + valueSchema: StructType, + indexOrdinal: Option[Int])( storeUpdateFunction: (StateStore, Iterator[T]) => Iterator[U]): StateStoreRDD[T, U] = { mapPartitionsWithStateStore( checkpointLocation, operatorId, + storeName, storeVersion, keySchema, valueSchema, + indexOrdinal, sqlContext.sessionState, Some(sqlContext.streams.stateStoreCoordinator))( storeUpdateFunction) @@ -54,9 +58,11 @@ package object state { private[streaming] def mapPartitionsWithStateStore[U: ClassTag]( checkpointLocation: String, operatorId: Long, + storeName: String, storeVersion: Long, keySchema: StructType, valueSchema: StructType, + indexOrdinal: Option[Int], sessionState: SessionState, storeCoordinator: Option[StateStoreCoordinatorRef])( storeUpdateFunction: (StateStore, Iterator[T]) => Iterator[U]): StateStoreRDD[T, U] = { @@ -69,14 +75,17 @@ package object state { }) cleanedF(store, iter) } + new StateStoreRDD( dataRDD, wrappedF, checkpointLocation, operatorId, + storeName, storeVersion, keySchema, valueSchema, + indexOrdinal, sessionState, storeCoordinator) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala index 8dbda298c87bc..3e57f3fbada32 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala @@ -17,21 +17,22 @@ package org.apache.spark.sql.execution.streaming +import java.util.concurrent.TimeUnit._ + import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateUnsafeProjection, Predicate} -import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, LogicalGroupState, ProcessingTimeTimeout} +import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution, Partitioning} import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ import org.apache.spark.sql.execution._ -import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} +import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.execution.streaming.state._ -import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode} +import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types._ -import org.apache.spark.util.CompletionIterator +import org.apache.spark.util.{CompletionIterator, NextIterator} /** Used to identify the state store for a given operator. */ @@ -61,11 +62,24 @@ trait StateStoreReader extends StatefulOperator { } /** An operator that writes to a StateStore. */ -trait StateStoreWriter extends StatefulOperator { +trait StateStoreWriter extends StatefulOperator { self: SparkPlan => + override lazy val metrics = Map( "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"), "numTotalStateRows" -> SQLMetrics.createMetric(sparkContext, "number of total state rows"), - "numUpdatedStateRows" -> SQLMetrics.createMetric(sparkContext, "number of updated state rows")) + "numUpdatedStateRows" -> SQLMetrics.createMetric(sparkContext, "number of updated state rows"), + "allUpdatesTimeMs" -> SQLMetrics.createTimingMetric(sparkContext, "total time to update rows"), + "allRemovalsTimeMs" -> SQLMetrics.createTimingMetric(sparkContext, "total time to remove rows"), + "commitTimeMs" -> SQLMetrics.createTimingMetric(sparkContext, "time to commit changes") + ) + + /** Records the duration of running `body` for the next query progress update. */ + protected def timeTakenMs(body: => Unit): Long = { + val startTime = System.nanoTime() + val result = body + val endTime = System.nanoTime() + math.max(NANOSECONDS.toMillis(endTime - startTime), 0) + } } /** An operator that supports watermark. */ @@ -108,6 +122,16 @@ trait WatermarkSupport extends UnaryExecNode { /** Predicate based on the child output that matches data older than the watermark. */ lazy val watermarkPredicateForData: Option[Predicate] = watermarkExpression.map(newPredicate(_, child.output)) + + protected def removeKeysOlderThanWatermark(store: StateStore): Unit = { + if (watermarkPredicateForKeys.nonEmpty) { + store.getRange(None, None).foreach { rowPair => + if (watermarkPredicateForKeys.get.eval(rowPair.key)) { + store.remove(rowPair.key) + } + } + } + } } /** @@ -126,9 +150,11 @@ case class StateStoreRestoreExec( child.execute().mapPartitionsWithStateStore( getStateId.checkpointLocation, operatorId = getStateId.operatorId, + storeName = "default", storeVersion = getStateId.batchId, keyExpressions.toStructType, child.output.toStructType, + indexOrdinal = None, sqlContext.sessionState, Some(sqlContext.streams.stateStoreCoordinator)) { case (store, iter) => val getKey = GenerateUnsafeProjection.generate(keyExpressions, child.output) @@ -136,7 +162,7 @@ case class StateStoreRestoreExec( val key = getKey(row) val savedState = store.get(key) numOutputRows += 1 - row +: savedState.toSeq + row +: Option(savedState).toSeq } } } @@ -165,54 +191,88 @@ case class StateStoreSaveExec( child.execute().mapPartitionsWithStateStore( getStateId.checkpointLocation, getStateId.operatorId, + storeName = "default", getStateId.batchId, keyExpressions.toStructType, child.output.toStructType, + indexOrdinal = None, sqlContext.sessionState, Some(sqlContext.streams.stateStoreCoordinator)) { (store, iter) => val getKey = GenerateUnsafeProjection.generate(keyExpressions, child.output) val numOutputRows = longMetric("numOutputRows") val numTotalStateRows = longMetric("numTotalStateRows") val numUpdatedStateRows = longMetric("numUpdatedStateRows") + val allUpdatesTimeMs = longMetric("allUpdatesTimeMs") + val allRemovalsTimeMs = longMetric("allRemovalsTimeMs") + val commitTimeMs = longMetric("commitTimeMs") outputMode match { // Update and output all rows in the StateStore. case Some(Complete) => - while (iter.hasNext) { - val row = iter.next().asInstanceOf[UnsafeRow] - val key = getKey(row) - store.put(key.copy(), row.copy()) - numUpdatedStateRows += 1 + allUpdatesTimeMs += timeTakenMs { + while (iter.hasNext) { + val row = iter.next().asInstanceOf[UnsafeRow] + val key = getKey(row) + store.put(key, row) + numUpdatedStateRows += 1 + } + } + allRemovalsTimeMs += 0 + commitTimeMs += timeTakenMs { + store.commit() } - store.commit() numTotalStateRows += store.numKeys() - store.iterator().map { case (k, v) => + store.iterator().map { rowPair => numOutputRows += 1 - v.asInstanceOf[InternalRow] + rowPair.value } // Update and output only rows being evicted from the StateStore + // Assumption: watermark predicates must be non-empty if append mode is allowed case Some(Append) => - while (iter.hasNext) { - val row = iter.next().asInstanceOf[UnsafeRow] - val key = getKey(row) - store.put(key.copy(), row.copy()) - numUpdatedStateRows += 1 + allUpdatesTimeMs += timeTakenMs { + val filteredIter = iter.filter(row => !watermarkPredicateForData.get.eval(row)) + while (filteredIter.hasNext) { + val row = filteredIter.next().asInstanceOf[UnsafeRow] + val key = getKey(row) + store.put(key, row) + numUpdatedStateRows += 1 + } } - // Assumption: Append mode can be done only when watermark has been specified - store.remove(watermarkPredicateForKeys.get.eval _) - store.commit() + val removalStartTimeNs = System.nanoTime + val rangeIter = store.getRange(None, None) + + new NextIterator[InternalRow] { + override protected def getNext(): InternalRow = { + var removedValueRow: InternalRow = null + while(rangeIter.hasNext && removedValueRow == null) { + val rowPair = rangeIter.next() + if (watermarkPredicateForKeys.get.eval(rowPair.key)) { + store.remove(rowPair.key) + removedValueRow = rowPair.value + } + } + if (removedValueRow == null) { + finished = true + null + } else { + removedValueRow + } + } - numTotalStateRows += store.numKeys() - store.updates().filter(_.isInstanceOf[ValueRemoved]).map { removed => - numOutputRows += 1 - removed.value.asInstanceOf[InternalRow] + override protected def close(): Unit = { + allRemovalsTimeMs += NANOSECONDS.toMillis(System.nanoTime - removalStartTimeNs) + commitTimeMs += timeTakenMs { store.commit() } + numTotalStateRows += store.numKeys() + } } // Update and output modified rows from the StateStore. case Some(Update) => + val updatesStartTimeNs = System.nanoTime + new Iterator[InternalRow] { // Filter late date using watermark if specified @@ -223,11 +283,11 @@ case class StateStoreSaveExec( override def hasNext: Boolean = { if (!baseIterator.hasNext) { + allUpdatesTimeMs += NANOSECONDS.toMillis(System.nanoTime - updatesStartTimeNs) + // Remove old aggregates if watermark specified - if (watermarkPredicateForKeys.nonEmpty) { - store.remove(watermarkPredicateForKeys.get.eval _) - } - store.commit() + allRemovalsTimeMs += timeTakenMs { removeKeysOlderThanWatermark(store) } + commitTimeMs += timeTakenMs { store.commit() } numTotalStateRows += store.numKeys() false } else { @@ -238,7 +298,7 @@ case class StateStoreSaveExec( override def next(): InternalRow = { val row = baseIterator.next().asInstanceOf[UnsafeRow] val key = getKey(row) - store.put(key.copy(), row.copy()) + store.put(key, row) numOutputRows += 1 numUpdatedStateRows += 1 row @@ -273,27 +333,34 @@ case class StreamingDeduplicateExec( child.execute().mapPartitionsWithStateStore( getStateId.checkpointLocation, getStateId.operatorId, + storeName = "default", getStateId.batchId, keyExpressions.toStructType, child.output.toStructType, + indexOrdinal = None, sqlContext.sessionState, Some(sqlContext.streams.stateStoreCoordinator)) { (store, iter) => val getKey = GenerateUnsafeProjection.generate(keyExpressions, child.output) val numOutputRows = longMetric("numOutputRows") val numTotalStateRows = longMetric("numTotalStateRows") val numUpdatedStateRows = longMetric("numUpdatedStateRows") + val allUpdatesTimeMs = longMetric("allUpdatesTimeMs") + val allRemovalsTimeMs = longMetric("allRemovalsTimeMs") + val commitTimeMs = longMetric("commitTimeMs") val baseIterator = watermarkPredicateForData match { case Some(predicate) => iter.filter(row => !predicate.eval(row)) case None => iter } + val updatesStartTimeNs = System.nanoTime + val result = baseIterator.filter { r => val row = r.asInstanceOf[UnsafeRow] val key = getKey(row) val value = store.get(key) - if (value.isEmpty) { - store.put(key.copy(), StreamingDeduplicateExec.EMPTY_ROW) + if (value == null) { + store.put(key, StreamingDeduplicateExec.EMPTY_ROW) numUpdatedStateRows += 1 numOutputRows += 1 true @@ -304,8 +371,9 @@ case class StreamingDeduplicateExec( } CompletionIterator[InternalRow, Iterator[InternalRow]](result, { - watermarkPredicateForKeys.foreach(f => store.remove(f.eval _)) - store.commit() + allUpdatesTimeMs += NANOSECONDS.toMillis(System.nanoTime - updatesStartTimeNs) + allRemovalsTimeMs += timeTakenMs { removeKeysOlderThanWatermark(store) } + commitTimeMs += timeTakenMs { store.commit() } numTotalStateRows += store.numKeys() }) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala index bd197be655d58..4a1a089af54c2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala @@ -38,13 +38,13 @@ import org.apache.spark.util.{CompletionIterator, Utils} class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAndAfterAll { + import StateStoreTestsHelper._ + private val sparkConf = new SparkConf().setMaster("local").setAppName(this.getClass.getSimpleName) - private var tempDir = Files.createTempDirectory("StateStoreRDDSuite").toString + private val tempDir = Files.createTempDirectory("StateStoreRDDSuite").toString private val keySchema = StructType(Seq(StructField("key", StringType, true))) private val valueSchema = StructType(Seq(StructField("value", IntegerType, true))) - import StateStoreSuite._ - after { StateStore.stop() } @@ -60,13 +60,14 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn val opId = 0 val rdd1 = makeRDD(spark.sparkContext, Seq("a", "b", "a")).mapPartitionsWithStateStore( - spark.sqlContext, path, opId, storeVersion = 0, keySchema, valueSchema)( + spark.sqlContext, path, opId, "name", storeVersion = 0, keySchema, valueSchema, None)( increment) assert(rdd1.collect().toSet === Set("a" -> 2, "b" -> 1)) // Generate next version of stores val rdd2 = makeRDD(spark.sparkContext, Seq("a", "c")).mapPartitionsWithStateStore( - spark.sqlContext, path, opId, storeVersion = 1, keySchema, valueSchema)(increment) + spark.sqlContext, path, opId, "name", storeVersion = 1, keySchema, valueSchema, None)( + increment) assert(rdd2.collect().toSet === Set("a" -> 3, "b" -> 1, "c" -> 1)) // Make sure the previous RDD still has the same data. @@ -84,7 +85,7 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn storeVersion: Int): RDD[(String, Int)] = { implicit val sqlContext = spark.sqlContext makeRDD(spark.sparkContext, Seq("a")).mapPartitionsWithStateStore( - sqlContext, path, opId, storeVersion, keySchema, valueSchema)(increment) + sqlContext, path, opId, "name", storeVersion, keySchema, valueSchema, None)(increment) } // Generate RDDs and state store data @@ -110,7 +111,7 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn def iteratorOfPuts(store: StateStore, iter: Iterator[String]): Iterator[(String, Int)] = { val resIterator = iter.map { s => val key = stringToRow(s) - val oldValue = store.get(key).map(rowToInt).getOrElse(0) + val oldValue = Option(store.get(key)).map(rowToInt).getOrElse(0) val newValue = oldValue + 1 store.put(key, intToRow(newValue)) (s, newValue) @@ -125,21 +126,24 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn iter: Iterator[String]): Iterator[(String, Option[Int])] = { iter.map { s => val key = stringToRow(s) - val value = store.get(key).map(rowToInt) + val value = Option(store.get(key)).map(rowToInt) (s, value) } } val rddOfGets1 = makeRDD(spark.sparkContext, Seq("a", "b", "c")).mapPartitionsWithStateStore( - spark.sqlContext, path, opId, storeVersion = 0, keySchema, valueSchema)(iteratorOfGets) + spark.sqlContext, path, opId, "name", storeVersion = 0, keySchema, valueSchema, None)( + iteratorOfGets) assert(rddOfGets1.collect().toSet === Set("a" -> None, "b" -> None, "c" -> None)) val rddOfPuts = makeRDD(spark.sparkContext, Seq("a", "b", "a")).mapPartitionsWithStateStore( - sqlContext, path, opId, storeVersion = 0, keySchema, valueSchema)(iteratorOfPuts) + sqlContext, path, opId, "name", storeVersion = 0, keySchema, valueSchema, None)( + iteratorOfPuts) assert(rddOfPuts.collect().toSet === Set("a" -> 1, "a" -> 2, "b" -> 1)) val rddOfGets2 = makeRDD(spark.sparkContext, Seq("a", "b", "c")).mapPartitionsWithStateStore( - sqlContext, path, opId, storeVersion = 1, keySchema, valueSchema)(iteratorOfGets) + sqlContext, path, opId, "name", storeVersion = 1, keySchema, valueSchema, None)( + iteratorOfGets) assert(rddOfGets2.collect().toSet === Set("a" -> Some(2), "b" -> Some(1), "c" -> None)) } } @@ -152,15 +156,16 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn withSparkSession(SparkSession.builder.config(sparkConf).getOrCreate()) { spark => implicit val sqlContext = spark.sqlContext val coordinatorRef = sqlContext.streams.stateStoreCoordinator - coordinatorRef.reportActiveInstance(StateStoreId(path, opId, 0), "host1", "exec1") - coordinatorRef.reportActiveInstance(StateStoreId(path, opId, 1), "host2", "exec2") + coordinatorRef.reportActiveInstance(StateStoreId(path, opId, 0, "name"), "host1", "exec1") + coordinatorRef.reportActiveInstance(StateStoreId(path, opId, 1, "name"), "host2", "exec2") assert( - coordinatorRef.getLocation(StateStoreId(path, opId, 0)) === + coordinatorRef.getLocation(StateStoreId(path, opId, 0, "name")) === Some(ExecutorCacheTaskLocation("host1", "exec1").toString)) val rdd = makeRDD(spark.sparkContext, Seq("a", "b", "a")).mapPartitionsWithStateStore( - sqlContext, path, opId, storeVersion = 0, keySchema, valueSchema)(increment) + sqlContext, path, opId, "name", storeVersion = 0, keySchema, valueSchema, None)( + increment) require(rdd.partitions.length === 2) assert( @@ -187,12 +192,12 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString val opId = 0 val rdd1 = makeRDD(spark.sparkContext, Seq("a", "b", "a")).mapPartitionsWithStateStore( - sqlContext, path, opId, storeVersion = 0, keySchema, valueSchema)(increment) + sqlContext, path, opId, "name", storeVersion = 0, keySchema, valueSchema, None)(increment) assert(rdd1.collect().toSet === Set("a" -> 2, "b" -> 1)) // Generate next version of stores val rdd2 = makeRDD(spark.sparkContext, Seq("a", "c")).mapPartitionsWithStateStore( - sqlContext, path, opId, storeVersion = 1, keySchema, valueSchema)(increment) + sqlContext, path, opId, "name", storeVersion = 1, keySchema, valueSchema, None)(increment) assert(rdd2.collect().toSet === Set("a" -> 3, "b" -> 1, "c" -> 1)) // Make sure the previous RDD still has the same data. @@ -208,7 +213,7 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn private val increment = (store: StateStore, iter: Iterator[String]) => { iter.foreach { s => val key = stringToRow(s) - val oldValue = store.get(key).map(rowToInt).getOrElse(0) + val oldValue = Option(store.get(key)).map(rowToInt).getOrElse(0) store.put(key, intToRow(oldValue + 1)) } store.commit() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala index cc09b2d5b7763..af2b9f1c11fb6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala @@ -40,15 +40,15 @@ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils -class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMethodTester { +class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider] + with BeforeAndAfter with PrivateMethodTester { type MapType = mutable.HashMap[UnsafeRow, UnsafeRow] import StateStoreCoordinatorSuite._ - import StateStoreSuite._ + import StateStoreTestsHelper._ - private val tempDir = Utils.createTempDir().toString - private val keySchema = StructType(Seq(StructField("key", StringType, true))) - private val valueSchema = StructType(Seq(StructField("value", IntegerType, true))) + val keySchema = StructType(Seq(StructField("key", StringType, true))) + val valueSchema = StructType(Seq(StructField("value", IntegerType, true))) before { StateStore.stop() @@ -60,186 +60,8 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth require(!StateStore.isMaintenanceRunning) } - test("get, put, remove, commit, and all data iterator") { - val provider = newStoreProvider() - - // Verify state before starting a new set of updates - assert(provider.latestIterator().isEmpty) - - val store = provider.getStore(0) - assert(!store.hasCommitted) - intercept[IllegalStateException] { - store.iterator() - } - intercept[IllegalStateException] { - store.updates() - } - - // Verify state after updating - put(store, "a", 1) - assert(store.numKeys() === 1) - intercept[IllegalStateException] { - store.iterator() - } - intercept[IllegalStateException] { - store.updates() - } - assert(provider.latestIterator().isEmpty) - - // Make updates, commit and then verify state - put(store, "b", 2) - put(store, "aa", 3) - assert(store.numKeys() === 3) - remove(store, _.startsWith("a")) - assert(store.numKeys() === 1) - assert(store.commit() === 1) - - assert(store.hasCommitted) - assert(rowsToSet(store.iterator()) === Set("b" -> 2)) - assert(rowsToSet(provider.latestIterator()) === Set("b" -> 2)) - assert(fileExists(provider, version = 1, isSnapshot = false)) - - assert(getDataFromFiles(provider) === Set("b" -> 2)) - - // Trying to get newer versions should fail - intercept[Exception] { - provider.getStore(2) - } - intercept[Exception] { - getDataFromFiles(provider, 2) - } - - // New updates to the reloaded store with new version, and does not change old version - val reloadedProvider = new HDFSBackedStateStoreProvider( - store.id, keySchema, valueSchema, StateStoreConf.empty, new Configuration) - val reloadedStore = reloadedProvider.getStore(1) - assert(reloadedStore.numKeys() === 1) - put(reloadedStore, "c", 4) - assert(reloadedStore.numKeys() === 2) - assert(reloadedStore.commit() === 2) - assert(rowsToSet(reloadedStore.iterator()) === Set("b" -> 2, "c" -> 4)) - assert(getDataFromFiles(provider) === Set("b" -> 2, "c" -> 4)) - assert(getDataFromFiles(provider, version = 1) === Set("b" -> 2)) - assert(getDataFromFiles(provider, version = 2) === Set("b" -> 2, "c" -> 4)) - } - - test("filter and concurrent updates") { - val provider = newStoreProvider() - - // Verify state before starting a new set of updates - assert(provider.latestIterator.isEmpty) - val store = provider.getStore(0) - put(store, "a", 1) - put(store, "b", 2) - - // Updates should work while iterating of filtered entries - val filtered = store.filter { case (keyRow, _) => rowToString(keyRow) == "a" } - filtered.foreach { case (keyRow, valueRow) => - store.put(keyRow, intToRow(rowToInt(valueRow) + 1)) - } - assert(get(store, "a") === Some(2)) - - // Removes should work while iterating of filtered entries - val filtered2 = store.filter { case (keyRow, _) => rowToString(keyRow) == "b" } - filtered2.foreach { case (keyRow, _) => - store.remove(keyRow) - } - assert(get(store, "b") === None) - } - - test("updates iterator with all combos of updates and removes") { - val provider = newStoreProvider() - var currentVersion: Int = 0 - - def withStore(body: StateStore => Unit): Unit = { - val store = provider.getStore(currentVersion) - body(store) - currentVersion += 1 - } - - // New data should be seen in updates as value added, even if they had multiple updates - withStore { store => - put(store, "a", 1) - put(store, "aa", 1) - put(store, "aa", 2) - store.commit() - assert(updatesToSet(store.updates()) === Set(Added("a", 1), Added("aa", 2))) - assert(rowsToSet(store.iterator()) === Set("a" -> 1, "aa" -> 2)) - } - - // Multiple updates to same key should be collapsed in the updates as a single value update - // Keys that have not been updated should not appear in the updates - withStore { store => - put(store, "a", 4) - put(store, "a", 6) - store.commit() - assert(updatesToSet(store.updates()) === Set(Updated("a", 6))) - assert(rowsToSet(store.iterator()) === Set("a" -> 6, "aa" -> 2)) - } - - // Keys added, updated and finally removed before commit should not appear in updates - withStore { store => - put(store, "b", 4) // Added, finally removed - put(store, "bb", 5) // Added, updated, finally removed - put(store, "bb", 6) - remove(store, _.startsWith("b")) - store.commit() - assert(updatesToSet(store.updates()) === Set.empty) - assert(rowsToSet(store.iterator()) === Set("a" -> 6, "aa" -> 2)) - } - - // Removed data should be seen in updates as a key removed - // Removed, but re-added data should be seen in updates as a value update - withStore { store => - remove(store, _.startsWith("a")) - put(store, "a", 10) - store.commit() - assert(updatesToSet(store.updates()) === Set(Updated("a", 10), Removed("aa"))) - assert(rowsToSet(store.iterator()) === Set("a" -> 10)) - } - } - - test("cancel") { - val provider = newStoreProvider() - val store = provider.getStore(0) - put(store, "a", 1) - store.commit() - assert(rowsToSet(store.iterator()) === Set("a" -> 1)) - - // cancelUpdates should not change the data in the files - val store1 = provider.getStore(1) - put(store1, "b", 1) - store1.abort() - assert(getDataFromFiles(provider) === Set("a" -> 1)) - } - - test("getStore with unexpected versions") { - val provider = newStoreProvider() - - intercept[IllegalArgumentException] { - provider.getStore(-1) - } - - // Prepare some data in the store - val store = provider.getStore(0) - put(store, "a", 1) - assert(store.commit() === 1) - assert(rowsToSet(store.iterator()) === Set("a" -> 1)) - - intercept[IllegalStateException] { - provider.getStore(2) - } - - // Update store version with some data - val store1 = provider.getStore(1) - put(store1, "b", 1) - assert(store1.commit() === 2) - assert(rowsToSet(store1.iterator()) === Set("a" -> 1, "b" -> 1)) - assert(getDataFromFiles(provider) === Set("a" -> 1, "b" -> 1)) - } - test("snapshotting") { - val provider = newStoreProvider(minDeltasForSnapshot = 5) + val provider = newStoreProvider(opId = Random.nextInt, partition = 0, minDeltasForSnapshot = 5) var currentVersion = 0 def updateVersionTo(targetVersion: Int): Unit = { @@ -253,9 +75,9 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth } updateVersionTo(2) - require(getDataFromFiles(provider) === Set("a" -> 2)) + require(getData(provider) === Set("a" -> 2)) provider.doMaintenance() // should not generate snapshot files - assert(getDataFromFiles(provider) === Set("a" -> 2)) + assert(getData(provider) === Set("a" -> 2)) for (i <- 1 to currentVersion) { assert(fileExists(provider, i, isSnapshot = false)) // all delta files present @@ -264,22 +86,22 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth // After version 6, snapshotting should generate one snapshot file updateVersionTo(6) - require(getDataFromFiles(provider) === Set("a" -> 6), "store not updated correctly") + require(getData(provider) === Set("a" -> 6), "store not updated correctly") provider.doMaintenance() // should generate snapshot files val snapshotVersion = (0 to 6).find(version => fileExists(provider, version, isSnapshot = true)) assert(snapshotVersion.nonEmpty, "snapshot file not generated") deleteFilesEarlierThanVersion(provider, snapshotVersion.get) assert( - getDataFromFiles(provider, snapshotVersion.get) === Set("a" -> snapshotVersion.get), + getData(provider, snapshotVersion.get) === Set("a" -> snapshotVersion.get), "snapshotting messed up the data of the snapshotted version") assert( - getDataFromFiles(provider) === Set("a" -> 6), + getData(provider) === Set("a" -> 6), "snapshotting messed up the data of the final version") // After version 20, snapshotting should generate newer snapshot files updateVersionTo(20) - require(getDataFromFiles(provider) === Set("a" -> 20), "store not updated correctly") + require(getData(provider) === Set("a" -> 20), "store not updated correctly") provider.doMaintenance() // do snapshot val latestSnapshotVersion = (0 to 20).filter(version => @@ -288,11 +110,11 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth assert(latestSnapshotVersion.get > snapshotVersion.get, "newer snapshot not generated") deleteFilesEarlierThanVersion(provider, latestSnapshotVersion.get) - assert(getDataFromFiles(provider) === Set("a" -> 20), "snapshotting messed up the data") + assert(getData(provider) === Set("a" -> 20), "snapshotting messed up the data") } test("cleaning") { - val provider = newStoreProvider(minDeltasForSnapshot = 5) + val provider = newStoreProvider(opId = Random.nextInt, partition = 0, minDeltasForSnapshot = 5) for (i <- 1 to 20) { val store = provider.getStore(i - 1) @@ -307,8 +129,8 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth assert(!fileExists(provider, version = 1, isSnapshot = false)) // first file should be deleted // last couple of versions should be retrievable - assert(getDataFromFiles(provider, 20) === Set("a" -> 20)) - assert(getDataFromFiles(provider, 19) === Set("a" -> 19)) + assert(getData(provider, 20) === Set("a" -> 20)) + assert(getData(provider, 19) === Set("a" -> 19)) } test("SPARK-19677: Committing a delta file atop an existing one should not fail on HDFS") { @@ -316,7 +138,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth conf.set("fs.fake.impl", classOf[RenameLikeHDFSFileSystem].getName) conf.set("fs.defaultFS", "fake:///") - val provider = newStoreProvider(hadoopConf = conf) + val provider = newStoreProvider(opId = Random.nextInt, partition = 0, hadoopConf = conf) provider.getStore(0).commit() provider.getStore(0).commit() @@ -327,7 +149,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth } test("corrupted file handling") { - val provider = newStoreProvider(minDeltasForSnapshot = 5) + val provider = newStoreProvider(opId = Random.nextInt, partition = 0, minDeltasForSnapshot = 5) for (i <- 1 to 6) { val store = provider.getStore(i - 1) put(store, "a", i) @@ -338,62 +160,75 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth fileExists(provider, version, isSnapshot = true)).getOrElse(fail("snapshot file not found")) // Corrupt snapshot file and verify that it throws error - assert(getDataFromFiles(provider, snapshotVersion) === Set("a" -> snapshotVersion)) + assert(getData(provider, snapshotVersion) === Set("a" -> snapshotVersion)) corruptFile(provider, snapshotVersion, isSnapshot = true) intercept[Exception] { - getDataFromFiles(provider, snapshotVersion) + getData(provider, snapshotVersion) } // Corrupt delta file and verify that it throws error - assert(getDataFromFiles(provider, snapshotVersion - 1) === Set("a" -> (snapshotVersion - 1))) + assert(getData(provider, snapshotVersion - 1) === Set("a" -> (snapshotVersion - 1))) corruptFile(provider, snapshotVersion - 1, isSnapshot = false) intercept[Exception] { - getDataFromFiles(provider, snapshotVersion - 1) + getData(provider, snapshotVersion - 1) } // Delete delta file and verify that it throws error deleteFilesEarlierThanVersion(provider, snapshotVersion) intercept[Exception] { - getDataFromFiles(provider, snapshotVersion - 1) + getData(provider, snapshotVersion - 1) } } test("StateStore.get") { quietly { - val dir = Utils.createDirectory(tempDir, Random.nextString(5)).toString + val dir = newDir() val storeId = StateStoreId(dir, 0, 0) val storeConf = StateStoreConf.empty val hadoopConf = new Configuration() - // Verify that trying to get incorrect versions throw errors intercept[IllegalArgumentException] { - StateStore.get(storeId, keySchema, valueSchema, -1, storeConf, hadoopConf) + StateStore.get( + storeId, keySchema, valueSchema, None, -1, storeConf, hadoopConf) } assert(!StateStore.isLoaded(storeId)) // version -1 should not attempt to load the store intercept[IllegalStateException] { - StateStore.get(storeId, keySchema, valueSchema, 1, storeConf, hadoopConf) + StateStore.get( + storeId, keySchema, valueSchema, None, 1, storeConf, hadoopConf) } - // Increase version of the store - val store0 = StateStore.get(storeId, keySchema, valueSchema, 0, storeConf, hadoopConf) + // Increase version of the store and try to get again + val store0 = StateStore.get( + storeId, keySchema, valueSchema, None, 0, storeConf, hadoopConf) assert(store0.version === 0) put(store0, "a", 1) store0.commit() - assert(StateStore.get(storeId, keySchema, valueSchema, 1, storeConf, hadoopConf).version == 1) - assert(StateStore.get(storeId, keySchema, valueSchema, 0, storeConf, hadoopConf).version == 0) + val store1 = StateStore.get( + storeId, keySchema, valueSchema, None, 1, storeConf, hadoopConf) + assert(StateStore.isLoaded(storeId)) + assert(store1.version === 1) + assert(rowsToSet(store1.iterator()) === Set("a" -> 1)) + + // Verify that you can also load older version + val store0reloaded = StateStore.get( + storeId, keySchema, valueSchema, None, 0, storeConf, hadoopConf) + assert(store0reloaded.version === 0) + assert(rowsToSet(store0reloaded.iterator()) === Set.empty) // Verify that you can remove the store and still reload and use it StateStore.unload(storeId) assert(!StateStore.isLoaded(storeId)) - val store1 = StateStore.get(storeId, keySchema, valueSchema, 1, storeConf, hadoopConf) + val store1reloaded = StateStore.get( + storeId, keySchema, valueSchema, None, 1, storeConf, hadoopConf) assert(StateStore.isLoaded(storeId)) - put(store1, "a", 2) - assert(store1.commit() === 2) - assert(rowsToSet(store1.iterator()) === Set("a" -> 2)) + assert(store1reloaded.version === 1) + put(store1reloaded, "a", 2) + assert(store1reloaded.commit() === 2) + assert(rowsToSet(store1reloaded.iterator()) === Set("a" -> 2)) } } @@ -407,21 +242,20 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth // fails to talk to the StateStoreCoordinator and unloads all the StateStores .set("spark.rpc.numRetries", "1") val opId = 0 - val dir = Utils.createDirectory(tempDir, Random.nextString(5)).toString + val dir = newDir() val storeId = StateStoreId(dir, opId, 0) val sqlConf = new SQLConf() sqlConf.setConf(SQLConf.MIN_BATCHES_TO_RETAIN, 2) val storeConf = StateStoreConf(sqlConf) val hadoopConf = new Configuration() - val provider = new HDFSBackedStateStoreProvider( - storeId, keySchema, valueSchema, storeConf, hadoopConf) + val provider = newStoreProvider(storeId) var latestStoreVersion = 0 def generateStoreVersions() { for (i <- 1 to 20) { - val store = StateStore.get( - storeId, keySchema, valueSchema, latestStoreVersion, storeConf, hadoopConf) + val store = StateStore.get(storeId, keySchema, valueSchema, None, + latestStoreVersion, storeConf, hadoopConf) put(store, "a", i) store.commit() latestStoreVersion += 1 @@ -469,7 +303,8 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth } // Reload the store and verify - StateStore.get(storeId, keySchema, valueSchema, latestStoreVersion, storeConf, hadoopConf) + StateStore.get(storeId, keySchema, valueSchema, indexOrdinal = None, + latestStoreVersion, storeConf, hadoopConf) assert(StateStore.isLoaded(storeId)) // If some other executor loads the store, then this instance should be unloaded @@ -479,7 +314,8 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth } // Reload the store and verify - StateStore.get(storeId, keySchema, valueSchema, latestStoreVersion, storeConf, hadoopConf) + StateStore.get(storeId, keySchema, valueSchema, indexOrdinal = None, + latestStoreVersion, storeConf, hadoopConf) assert(StateStore.isLoaded(storeId)) } } @@ -495,10 +331,11 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth test("SPARK-18342: commit fails when rename fails") { import RenameReturnsFalseFileSystem._ - val dir = scheme + "://" + Utils.createDirectory(tempDir, Random.nextString(5)).toURI.getPath + val dir = scheme + "://" + newDir() val conf = new Configuration() conf.set(s"fs.$scheme.impl", classOf[RenameReturnsFalseFileSystem].getName) - val provider = newStoreProvider(dir = dir, hadoopConf = conf) + val provider = newStoreProvider( + opId = Random.nextInt, partition = 0, dir = dir, hadoopConf = conf) val store = provider.getStore(0) put(store, "a", 0) val e = intercept[IllegalStateException](store.commit()) @@ -506,7 +343,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth } test("SPARK-18416: do not create temp delta file until the store is updated") { - val dir = Utils.createDirectory(tempDir, Random.nextString(5)).toString + val dir = newDir() val storeId = StateStoreId(dir, 0, 0) val storeConf = StateStoreConf.empty val hadoopConf = new Configuration() @@ -533,7 +370,8 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth // Getting the store should not create temp file val store0 = shouldNotCreateTempFile { - StateStore.get(storeId, keySchema, valueSchema, 0, storeConf, hadoopConf) + StateStore.get( + storeId, keySchema, valueSchema, indexOrdinal = None, version = 0, storeConf, hadoopConf) } // Put should create a temp file @@ -548,7 +386,8 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth // Remove should create a temp file val store1 = shouldNotCreateTempFile { - StateStore.get(storeId, keySchema, valueSchema, 1, storeConf, hadoopConf) + StateStore.get( + storeId, keySchema, valueSchema, indexOrdinal = None, version = 1, storeConf, hadoopConf) } remove(store1, _ == "a") assert(numTempFiles === 1) @@ -561,31 +400,55 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth // Commit without any updates should create a delta file val store2 = shouldNotCreateTempFile { - StateStore.get(storeId, keySchema, valueSchema, 2, storeConf, hadoopConf) + StateStore.get( + storeId, keySchema, valueSchema, indexOrdinal = None, version = 2, storeConf, hadoopConf) } store2.commit() assert(numTempFiles === 0) assert(numDeltaFiles === 3) } - def getDataFromFiles( - provider: HDFSBackedStateStoreProvider, + override def newStoreProvider(): HDFSBackedStateStoreProvider = { + newStoreProvider(opId = Random.nextInt(), partition = 0) + } + + override def newStoreProvider(storeId: StateStoreId): HDFSBackedStateStoreProvider = { + newStoreProvider(storeId.operatorId, storeId.partitionId, dir = storeId.checkpointLocation) + } + + override def getLatestData(storeProvider: HDFSBackedStateStoreProvider): Set[(String, Int)] = { + getData(storeProvider) + } + + override def getData( + provider: HDFSBackedStateStoreProvider, version: Int = -1): Set[(String, Int)] = { - val reloadedProvider = new HDFSBackedStateStoreProvider( - provider.id, keySchema, valueSchema, StateStoreConf.empty, new Configuration) + val reloadedProvider = newStoreProvider(provider.id) if (version < 0) { reloadedProvider.latestIterator().map(rowsToStringInt).toSet } else { - reloadedProvider.iterator(version).map(rowsToStringInt).toSet + reloadedProvider.getStore(version).iterator().map(rowsToStringInt).toSet } } - def assertMap( - testMapOption: Option[MapType], - expectedMap: Map[String, Int]): Unit = { - assert(testMapOption.nonEmpty, "no map present") - val convertedMap = testMapOption.get.map(rowsToStringInt) - assert(convertedMap === expectedMap) + def newStoreProvider( + opId: Long, + partition: Int, + dir: String = newDir(), + minDeltasForSnapshot: Int = SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.defaultValue.get, + hadoopConf: Configuration = new Configuration): HDFSBackedStateStoreProvider = { + val sqlConf = new SQLConf() + sqlConf.setConf(SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT, minDeltasForSnapshot) + sqlConf.setConf(SQLConf.MIN_BATCHES_TO_RETAIN, 2) + val provider = new HDFSBackedStateStoreProvider() + provider.init( + StateStoreId(dir, opId, partition), + keySchema, + valueSchema, + indexOrdinal = None, + new StateStoreConf(sqlConf), + hadoopConf) + provider } def fileExists( @@ -622,56 +485,150 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth filePath.delete() filePath.createNewFile() } +} - def storeLoaded(storeId: StateStoreId): Boolean = { - val method = PrivateMethod[mutable.HashMap[StateStoreId, StateStore]]('loadedStores) - val loadedStores = StateStore invokePrivate method() - loadedStores.contains(storeId) - } +abstract class StateStoreSuiteBase[ProviderClass <: StateStoreProvider] + extends SparkFunSuite { + import StateStoreTestsHelper._ - def unloadStore(storeId: StateStoreId): Boolean = { - val method = PrivateMethod('remove) - StateStore invokePrivate method(storeId) - } + test("get, put, remove, commit, and all data iterator") { + val provider = newStoreProvider() - def newStoreProvider( - opId: Long = Random.nextLong, - partition: Int = 0, - minDeltasForSnapshot: Int = SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.defaultValue.get, - dir: String = Utils.createDirectory(tempDir, Random.nextString(5)).toString, - hadoopConf: Configuration = new Configuration() - ): HDFSBackedStateStoreProvider = { - val sqlConf = new SQLConf() - sqlConf.setConf(SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT, minDeltasForSnapshot) - sqlConf.setConf(SQLConf.MIN_BATCHES_TO_RETAIN, 2) - new HDFSBackedStateStoreProvider( - StateStoreId(dir, opId, partition), - keySchema, - valueSchema, - new StateStoreConf(sqlConf), - hadoopConf) + // Verify state before starting a new set of updates + assert(getLatestData(provider).isEmpty) + + val store = provider.getStore(0) + assert(!store.hasCommitted) + assert(get(store, "a") === None) + assert(store.iterator().isEmpty) + assert(store.numKeys() === 0) + + // Verify state after updating + put(store, "a", 1) + assert(get(store, "a") === Some(1)) + assert(store.numKeys() === 1) + + assert(store.iterator().nonEmpty) + assert(getLatestData(provider).isEmpty) + + // Make updates, commit and then verify state + put(store, "b", 2) + put(store, "aa", 3) + assert(store.numKeys() === 3) + remove(store, _.startsWith("a")) + assert(store.numKeys() === 1) + assert(store.commit() === 1) + + assert(store.hasCommitted) + assert(rowsToSet(store.iterator()) === Set("b" -> 2)) + assert(getLatestData(provider) === Set("b" -> 2)) + + // Trying to get newer versions should fail + intercept[Exception] { + provider.getStore(2) + } + intercept[Exception] { + getData(provider, 2) + } + + // New updates to the reloaded store with new version, and does not change old version + val reloadedProvider = newStoreProvider(store.id) + val reloadedStore = reloadedProvider.getStore(1) + assert(reloadedStore.numKeys() === 1) + put(reloadedStore, "c", 4) + assert(reloadedStore.numKeys() === 2) + assert(reloadedStore.commit() === 2) + assert(rowsToSet(reloadedStore.iterator()) === Set("b" -> 2, "c" -> 4)) + assert(getLatestData(provider) === Set("b" -> 2, "c" -> 4)) + assert(getData(provider, version = 1) === Set("b" -> 2)) } - def remove(store: StateStore, condition: String => Boolean): Unit = { - store.remove(row => condition(rowToString(row))) + test("removing while iterating") { + val provider = newStoreProvider() + + // Verify state before starting a new set of updates + assert(getLatestData(provider).isEmpty) + val store = provider.getStore(0) + put(store, "a", 1) + put(store, "b", 2) + + // Updates should work while iterating of filtered entries + val filtered = store.iterator.filter { tuple => rowToString(tuple.key) == "a" } + filtered.foreach { tuple => + store.put(tuple.key, intToRow(rowToInt(tuple.value) + 1)) + } + assert(get(store, "a") === Some(2)) + + // Removes should work while iterating of filtered entries + val filtered2 = store.iterator.filter { tuple => rowToString(tuple.key) == "b" } + filtered2.foreach { tuple => store.remove(tuple.key) } + assert(get(store, "b") === None) } - private def put(store: StateStore, key: String, value: Int): Unit = { - store.put(stringToRow(key), intToRow(value)) + test("abort") { + val provider = newStoreProvider() + val store = provider.getStore(0) + put(store, "a", 1) + store.commit() + assert(rowsToSet(store.iterator()) === Set("a" -> 1)) + + // cancelUpdates should not change the data in the files + val store1 = provider.getStore(1) + put(store1, "b", 1) + store1.abort() } - private def get(store: StateStore, key: String): Option[Int] = { - store.get(stringToRow(key)).map(rowToInt) + test("getStore with invalid versions") { + val provider = newStoreProvider() + + def checkInvalidVersion(version: Int): Unit = { + intercept[Exception] { + provider.getStore(version) + } + } + + checkInvalidVersion(-1) + checkInvalidVersion(1) + + val store = provider.getStore(0) + put(store, "a", 1) + assert(store.commit() === 1) + assert(rowsToSet(store.iterator()) === Set("a" -> 1)) + + val store1_ = provider.getStore(1) + assert(rowsToSet(store1_.iterator()) === Set("a" -> 1)) + + checkInvalidVersion(-1) + checkInvalidVersion(2) + + // Update store version with some data + val store1 = provider.getStore(1) + assert(rowsToSet(store1.iterator()) === Set("a" -> 1)) + put(store1, "b", 1) + assert(store1.commit() === 2) + assert(rowsToSet(store1.iterator()) === Set("a" -> 1, "b" -> 1)) + + checkInvalidVersion(-1) + checkInvalidVersion(3) } -} -private[state] object StateStoreSuite { + /** Return a new provider with a random id */ + def newStoreProvider(): ProviderClass + + /** Return a new provider with the given id */ + def newStoreProvider(storeId: StateStoreId): ProviderClass + + /** Get the latest data referred to by the given provider but not using this provider */ + def getLatestData(storeProvider: ProviderClass): Set[(String, Int)] + + /** + * Get a specific version of data referred to by the given provider but not using + * this provider + */ + def getData(storeProvider: ProviderClass, version: Int): Set[(String, Int)] +} - /** Trait and classes mirroring [[StoreUpdate]] for testing store updates iterator */ - trait TestUpdate - case class Added(key: String, value: Int) extends TestUpdate - case class Updated(key: String, value: Int) extends TestUpdate - case class Removed(key: String) extends TestUpdate +object StateStoreTestsHelper { val strProj = UnsafeProjection.create(Array[DataType](StringType)) val intProj = UnsafeProjection.create(Array[DataType](IntegerType)) @@ -692,26 +649,29 @@ private[state] object StateStoreSuite { row.getInt(0) } - def rowsToIntInt(row: (UnsafeRow, UnsafeRow)): (Int, Int) = { - (rowToInt(row._1), rowToInt(row._2)) + def rowsToStringInt(row: UnsafeRowPair): (String, Int) = { + (rowToString(row.key), rowToInt(row.value)) } + def rowsToSet(iterator: Iterator[UnsafeRowPair]): Set[(String, Int)] = { + iterator.map(rowsToStringInt).toSet + } - def rowsToStringInt(row: (UnsafeRow, UnsafeRow)): (String, Int) = { - (rowToString(row._1), rowToInt(row._2)) + def remove(store: StateStore, condition: String => Boolean): Unit = { + store.getRange(None, None).foreach { rowPair => + if (condition(rowToString(rowPair.key))) store.remove(rowPair.key) + } } - def rowsToSet(iterator: Iterator[(UnsafeRow, UnsafeRow)]): Set[(String, Int)] = { - iterator.map(rowsToStringInt).toSet + def put(store: StateStore, key: String, value: Int): Unit = { + store.put(stringToRow(key), intToRow(value)) } - def updatesToSet(iterator: Iterator[StoreUpdate]): Set[TestUpdate] = { - iterator.map { - case ValueAdded(key, value) => Added(rowToString(key), rowToInt(value)) - case ValueUpdated(key, value) => Updated(rowToString(key), rowToInt(value)) - case ValueRemoved(key, _) => Removed(rowToString(key)) - }.toSet + def get(store: StateStore, key: String): Option[Int] = { + Option(store.get(stringToRow(key))).map(rowToInt) } + + def newDir(): String = Utils.createTempDir().toString } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala index 6bb9408ce99ed..0d9ca81349be5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.plans.physical.UnknownPartitioning import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ import org.apache.spark.sql.execution.RDDScanExec import org.apache.spark.sql.execution.streaming.{FlatMapGroupsWithStateExec, GroupStateImpl, MemoryStream} -import org.apache.spark.sql.execution.streaming.state.{StateStore, StateStoreId, StoreUpdate} +import org.apache.spark.sql.execution.streaming.state.{StateStore, StateStoreId, UnsafeRowPair} import org.apache.spark.sql.streaming.FlatMapGroupsWithStateSuite.MemoryStateStore import org.apache.spark.sql.streaming.util.StreamManualClock import org.apache.spark.sql.types.{DataType, IntegerType} @@ -508,22 +508,6 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf expectedState = Some(5), // state should change expectedTimeoutTimestamp = 5000) // timestamp should change - test("StateStoreUpdater - rows are cloned before writing to StateStore") { - // function for running count - val func = (key: Int, values: Iterator[Int], state: GroupState[Int]) => { - state.update(state.getOption.getOrElse(0) + values.size) - Iterator.empty - } - val store = newStateStore() - val plan = newFlatMapGroupsWithStateExec(func) - val updater = new plan.StateStoreUpdater(store) - val data = Seq(1, 1, 2) - val returnIter = updater.updateStateForKeysWithData(data.iterator.map(intToRow)) - returnIter.size // consume the iterator to force store updates - val storeData = store.iterator.map { case (k, v) => (rowToInt(k), rowToInt(v)) }.toSet - assert(storeData === Set((1, 2), (2, 1))) - } - test("flatMapGroupsWithState - streaming") { // Function to maintain running count up to 2, and then remove the count // Returns the data and the count if state is defined, otherwise does not return anything @@ -1016,11 +1000,11 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf callFunction() val updatedStateRow = store.get(key) assert( - updater.getStateObj(updatedStateRow).map(_.toString.toInt) === expectedState, + Option(updater.getStateObj(updatedStateRow)).map(_.toString.toInt) === expectedState, "final state not as expected") - if (updatedStateRow.nonEmpty) { + if (updatedStateRow != null) { assert( - updater.getTimeoutTimestamp(updatedStateRow.get) === expectedTimeoutTimestamp, + updater.getTimeoutTimestamp(updatedStateRow) === expectedTimeoutTimestamp, "final timeout timestamp not as expected") } } @@ -1080,25 +1064,19 @@ object FlatMapGroupsWithStateSuite { import scala.collection.JavaConverters._ private val map = new ConcurrentHashMap[UnsafeRow, UnsafeRow] - override def iterator(): Iterator[(UnsafeRow, UnsafeRow)] = { - map.entrySet.iterator.asScala.map { case e => (e.getKey, e.getValue) } + override def iterator(): Iterator[UnsafeRowPair] = { + map.entrySet.iterator.asScala.map { case e => new UnsafeRowPair(e.getKey, e.getValue) } } - override def filter(c: (UnsafeRow, UnsafeRow) => Boolean): Iterator[(UnsafeRow, UnsafeRow)] = { - iterator.filter { case (k, v) => c(k, v) } + override def get(key: UnsafeRow): UnsafeRow = map.get(key) + override def put(key: UnsafeRow, newValue: UnsafeRow): Unit = { + map.put(key.copy(), newValue.copy()) } - - override def get(key: UnsafeRow): Option[UnsafeRow] = Option(map.get(key)) - override def put(key: UnsafeRow, newValue: UnsafeRow): Unit = map.put(key, newValue) override def remove(key: UnsafeRow): Unit = { map.remove(key) } - override def remove(condition: (UnsafeRow) => Boolean): Unit = { - iterator.map(_._1).filter(condition).foreach(map.remove) - } override def commit(): Long = version + 1 override def abort(): Unit = { } override def id: StateStoreId = null override def version: Long = 0 - override def updates(): Iterator[StoreUpdate] = { throw new UnsupportedOperationException } override def numKeys(): Long = map.size override def hasCommitted: Boolean = true } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala index 1fc062974e185..280f2dc27b4a7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala @@ -24,6 +24,7 @@ import scala.reflect.ClassTag import scala.util.control.ControlThrowable import org.apache.commons.io.FileUtils +import org.apache.hadoop.conf.Configuration import org.apache.spark.SparkContext import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart} @@ -31,6 +32,7 @@ import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.streaming.InternalOutputModes import org.apache.spark.sql.execution.command.ExplainCommand import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.execution.streaming.state.{StateStore, StateStoreConf, StateStoreId, StateStoreProvider} import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.StreamSourceProvider @@ -614,6 +616,30 @@ class StreamSuite extends StreamTest { assertDescContainsQueryNameAnd(batch = 2) query.stop() } + + testQuietly("specify custom state store provider") { + val queryName = "memStream" + val providerClassName = classOf[TestStateStoreProvider].getCanonicalName + withSQLConf("spark.sql.streaming.stateStore.providerClass" -> providerClassName) { + val input = MemoryStream[Int] + val query = input + .toDS() + .groupBy() + .count() + .writeStream + .outputMode("complete") + .format("memory") + .queryName(queryName) + .start() + input.addData(1, 2, 3) + val e = intercept[Exception] { + query.awaitTermination() + } + + assert(e.getMessage.contains(providerClassName)) + assert(e.getMessage.contains("instantiated")) + } + } } abstract class FakeSource extends StreamSourceProvider { @@ -719,3 +745,22 @@ object ThrowingInterruptedIOException { */ @volatile var createSourceLatch: CountDownLatch = null } + +class TestStateStoreProvider extends StateStoreProvider { + + override def init( + stateStoreId: StateStoreId, + keySchema: StructType, + valueSchema: StructType, + indexOrdinal: Option[Int], + storeConfs: StateStoreConf, + hadoopConf: Configuration): Unit = { + throw new Exception("Successfully instantiated") + } + + override def id: StateStoreId = null + + override def close(): Unit = { } + + override def getStore(version: Long): StateStore = null +} From 10e526e7e63bbf19e464bde2f6c4e581cf6c7c45 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 30 May 2017 20:12:32 -0700 Subject: [PATCH 044/133] [SPARK-20213][SQL] Fix DataFrameWriter operations in SQL UI tab ## What changes were proposed in this pull request? Currently the `DataFrameWriter` operations have several problems: 1. non-file-format data source writing action doesn't show up in the SQL tab in Spark UI 2. file-format data source writing action shows a scan node in the SQL tab, without saying anything about writing. (streaming also have this issue, but not fixed in this PR) 3. Spark SQL CLI actions don't show up in the SQL tab. This PR fixes all of them, by refactoring the `ExecuteCommandExec` to make it have children. close https://github.com/apache/spark/pull/17540 ## How was this patch tested? existing tests. Also test the UI manually. For a simple command: `Seq(1 -> "a").toDF("i", "j").write.parquet("/tmp/qwe")` before this PR: qq20170523-035840 2x after this PR: qq20170523-035708 2x Author: Wenchen Fan Closes #18064 from cloud-fan/execution. --- .../spark/sql/kafka010/KafkaWriter.scala | 10 +- .../spark/sql/catalyst/plans/QueryPlan.scala | 2 +- .../sql/catalyst/plans/logical/Command.scala | 3 +- .../catalyst/plans/logical/Statistics.scala | 15 ++- .../apache/spark/sql/DataFrameWriter.scala | 12 +-- .../scala/org/apache/spark/sql/Dataset.scala | 48 +++++++-- .../spark/sql/execution/QueryExecution.scala | 7 +- .../spark/sql/execution/SQLExecution.scala | 13 +++ .../spark/sql/execution/SparkStrategies.scala | 2 +- .../execution/columnar/InMemoryRelation.scala | 2 +- .../columnar/InMemoryTableScanExec.scala | 2 +- .../command/AnalyzeColumnCommand.scala | 7 +- .../command/AnalyzeTableCommand.scala | 2 +- .../spark/sql/execution/command/cache.scala | 10 +- .../sql/execution/command/commands.scala | 24 +++-- .../command/createDataSourceTables.scala | 4 +- .../spark/sql/execution/command/views.scala | 4 +- .../execution/datasources/DataSource.scala | 61 +++++------- .../datasources/FileFormatWriter.scala | 98 +++++++++---------- .../InsertIntoDataSourceCommand.scala | 2 +- .../InsertIntoHadoopFsRelationCommand.scala | 10 +- .../SaveIntoDataSourceCommand.scala | 13 +-- .../datasources/csv/CSVDataSource.scala | 3 +- .../spark/sql/execution/datasources/ddl.scala | 2 +- .../datasources/jdbc/JDBCRelation.scala | 14 +-- .../datasources/jdbc/JdbcUtils.scala | 2 +- .../execution/streaming/FileStreamSink.scala | 2 +- .../execution/streaming/StreamExecution.scala | 6 +- .../sql/execution/streaming/console.scala | 4 +- .../sql/execution/streaming/memory.scala | 4 +- .../execution/metric/SQLMetricsSuite.scala | 5 +- .../sql/util/DataFrameCallbackSuite.scala | 27 ++--- .../hive/thriftserver/SparkSQLDriver.scala | 6 +- .../hive/execution/InsertIntoHiveTable.scala | 11 ++- .../apache/spark/sql/hive/test/TestHive.scala | 54 +++++----- .../hive/execution/HiveComparisonTest.scala | 6 +- .../sql/hive/execution/SQLQuerySuite.scala | 20 ++-- 37 files changed, 299 insertions(+), 218 deletions(-) diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala index 0ed9d4e84d54d..5e9ae35b3f008 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala @@ -85,12 +85,10 @@ private[kafka010] object KafkaWriter extends Logging { topic: Option[String] = None): Unit = { val schema = queryExecution.analyzed.output validateQuery(queryExecution, kafkaParameters, topic) - SQLExecution.withNewExecutionId(sparkSession, queryExecution) { - queryExecution.toRdd.foreachPartition { iter => - val writeTask = new KafkaWriteTask(kafkaParameters, schema, topic) - Utils.tryWithSafeFinally(block = writeTask.execute(iter))( - finallyBlock = writeTask.close()) - } + queryExecution.toRdd.foreachPartition { iter => + val writeTask = new KafkaWriteTask(kafkaParameters, schema, topic) + Utils.tryWithSafeFinally(block = writeTask.execute(iter))( + finallyBlock = writeTask.close()) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index d3f822bf7eb0e..5ba043e17a128 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -357,7 +357,7 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT }) } - override protected def innerChildren: Seq[QueryPlan[_]] = subqueries + override def innerChildren: Seq[QueryPlan[_]] = subqueries /** * Returns a plan where a best effort attempt has been made to transform `this` in a way diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Command.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Command.scala index 38f47081b6f55..ec5766e1f67f2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Command.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Command.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.expressions.Attribute * commands can be used by parsers to represent DDL operations. Commands, unlike queries, are * eagerly executed. */ -trait Command extends LeafNode { +trait Command extends LogicalPlan { override def output: Seq[Attribute] = Seq.empty + override def children: Seq[LogicalPlan] = Seq.empty } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala index a64562b5dbd93..ae5f1d1fc4f83 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala @@ -22,7 +22,8 @@ import java.math.{MathContext, RoundingMode} import scala.util.control.NonFatal import org.apache.spark.internal.Logging -import org.apache.spark.sql.{AnalysisException, Row} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.util.DateTimeUtils @@ -243,9 +244,9 @@ object ColumnStat extends Logging { } col.dataType match { - case _: IntegralType => fixedLenTypeStruct(LongType) + case dt: IntegralType => fixedLenTypeStruct(dt) case _: DecimalType => fixedLenTypeStruct(col.dataType) - case DoubleType | FloatType => fixedLenTypeStruct(DoubleType) + case dt @ (DoubleType | FloatType) => fixedLenTypeStruct(dt) case BooleanType => fixedLenTypeStruct(col.dataType) case DateType => fixedLenTypeStruct(col.dataType) case TimestampType => fixedLenTypeStruct(col.dataType) @@ -264,14 +265,12 @@ object ColumnStat extends Logging { } /** Convert a struct for column stats (defined in statExprs) into [[ColumnStat]]. */ - def rowToColumnStat(row: Row, attr: Attribute): ColumnStat = { + def rowToColumnStat(row: InternalRow, attr: Attribute): ColumnStat = { ColumnStat( distinctCount = BigInt(row.getLong(0)), // for string/binary min/max, get should return null - min = Option(row.get(1)) - .map(v => fromExternalString(v.toString, attr.name, attr.dataType)).flatMap(Option.apply), - max = Option(row.get(2)) - .map(v => fromExternalString(v.toString, attr.name, attr.dataType)).flatMap(Option.apply), + min = Option(row.get(1, attr.dataType)), + max = Option(row.get(2, attr.dataType)), nullCount = BigInt(row.getLong(3)), avgLen = row.getLong(4), maxLen = row.getLong(5) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index b71c5eb843eec..255c4064eb574 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.{EliminateSubqueryAliases, UnresolvedRelation} import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogRelation, CatalogTable, CatalogTableType} import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoTable, LogicalPlan} +import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.execution.datasources.{CreateTable, DataSource, LogicalRelation, SaveIntoDataSourceCommand} import org.apache.spark.sql.sources.BaseRelation @@ -231,12 +232,11 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { assertNotBucketed("save") runCommand(df.sparkSession, "save") { - SaveIntoDataSourceCommand( - query = df.logicalPlan, - provider = source, + DataSource( + sparkSession = df.sparkSession, + className = source, partitionColumns = partitioningColumns.getOrElse(Nil), - options = extraOptions.toMap, - mode = mode) + options = extraOptions.toMap).planForWriting(mode, df.logicalPlan) } } @@ -607,7 +607,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { try { val start = System.nanoTime() // call `QueryExecution.toRDD` to trigger the execution of commands. - qe.toRdd + SQLExecution.withNewExecutionId(session, qe)(qe.toRdd) val end = System.nanoTime() session.listenerManager.onSuccess(name, qe, end - start) } catch { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 1cd6fda5edc87..d5b4c82c3558b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -179,9 +179,9 @@ class Dataset[T] private[sql]( // to happen right away to let these side effects take place eagerly. queryExecution.analyzed match { case c: Command => - LocalRelation(c.output, queryExecution.executedPlan.executeCollect()) + LocalRelation(c.output, withAction("command", queryExecution)(_.executeCollect())) case u @ Union(children) if children.forall(_.isInstanceOf[Command]) => - LocalRelation(u.output, queryExecution.executedPlan.executeCollect()) + LocalRelation(u.output, withAction("command", queryExecution)(_.executeCollect())) case _ => queryExecution.analyzed } @@ -248,8 +248,13 @@ class Dataset[T] private[sql]( _numRows: Int, truncate: Int = 20, vertical: Boolean = false): String = { val numRows = _numRows.max(0) val takeResult = toDF().take(numRows + 1) - val hasMoreData = takeResult.length > numRows - val data = takeResult.take(numRows) + showString(takeResult, numRows, truncate, vertical) + } + + private def showString( + dataWithOneMoreRow: Array[Row], numRows: Int, truncate: Int, vertical: Boolean): String = { + val hasMoreData = dataWithOneMoreRow.length > numRows + val data = dataWithOneMoreRow.take(numRows) lazy val timeZone = DateTimeUtils.getTimeZone(sparkSession.sessionState.conf.sessionLocalTimeZone) @@ -684,6 +689,18 @@ class Dataset[T] private[sql]( } else { println(showString(numRows, truncate = 0)) } + + // An internal version of `show`, which won't set execution id and trigger listeners. + private[sql] def showInternal(_numRows: Int, truncate: Boolean): Unit = { + val numRows = _numRows.max(0) + val takeResult = toDF().takeInternal(numRows + 1) + + if (truncate) { + println(showString(takeResult, numRows, truncate = 20, vertical = false)) + } else { + println(showString(takeResult, numRows, truncate = 0, vertical = false)) + } + } // scalastyle:on println /** @@ -2453,6 +2470,11 @@ class Dataset[T] private[sql]( */ def take(n: Int): Array[T] = head(n) + // An internal version of `take`, which won't set execution id and trigger listeners. + private[sql] def takeInternal(n: Int): Array[T] = { + collectFromPlan(limit(n).queryExecution.executedPlan) + } + /** * Returns the first `n` rows in the Dataset as a list. * @@ -2477,6 +2499,11 @@ class Dataset[T] private[sql]( */ def collect(): Array[T] = withAction("collect", queryExecution)(collectFromPlan) + // An internal version of `collect`, which won't set execution id and trigger listeners. + private[sql] def collectInternal(): Array[T] = { + collectFromPlan(queryExecution.executedPlan) + } + /** * Returns a Java list that contains all rows in this Dataset. * @@ -2518,6 +2545,11 @@ class Dataset[T] private[sql]( plan.executeCollect().head.getLong(0) } + // An internal version of `count`, which won't set execution id and trigger listeners. + private[sql] def countInternal(): Long = { + groupBy().count().queryExecution.executedPlan.executeCollect().head.getLong(0) + } + /** * Returns a new Dataset that has exactly `numPartitions` partitions. * @@ -2763,7 +2795,7 @@ class Dataset[T] private[sql]( createTempViewCommand(viewName, replace = true, global = true) } - private def createTempViewCommand( + private[spark] def createTempViewCommand( viewName: String, replace: Boolean, global: Boolean): CreateViewCommand = { @@ -2954,17 +2986,17 @@ class Dataset[T] private[sql]( } /** A convenient function to wrap a logical plan and produce a DataFrame. */ - @inline private def withPlan(logicalPlan: => LogicalPlan): DataFrame = { + @inline private def withPlan(logicalPlan: LogicalPlan): DataFrame = { Dataset.ofRows(sparkSession, logicalPlan) } /** A convenient function to wrap a logical plan and produce a Dataset. */ - @inline private def withTypedPlan[U : Encoder](logicalPlan: => LogicalPlan): Dataset[U] = { + @inline private def withTypedPlan[U : Encoder](logicalPlan: LogicalPlan): Dataset[U] = { Dataset(sparkSession, logicalPlan) } /** A convenient function to wrap a set based logical plan and produce a Dataset. */ - @inline private def withSetOperator[U : Encoder](logicalPlan: => LogicalPlan): Dataset[U] = { + @inline private def withSetOperator[U : Encoder](logicalPlan: LogicalPlan): Dataset[U] = { if (classTag.runtimeClass.isAssignableFrom(classOf[Row])) { // Set operators widen types (change the schema), so we cannot reuse the row encoder. Dataset.ofRows(sparkSession, logicalPlan).asInstanceOf[Dataset[U]] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index 2e05e5d65923c..1ba9a79446aad 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -113,10 +113,11 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) { /** - * Returns the result as a hive compatible sequence of strings. This is for testing only. + * Returns the result as a hive compatible sequence of strings. This is used in tests and + * `SparkSQLDriver` for CLI applications. */ def hiveResultString(): Seq[String] = executedPlan match { - case ExecutedCommandExec(desc: DescribeTableCommand) => + case ExecutedCommandExec(desc: DescribeTableCommand, _) => // If it is a describe command for a Hive table, we want to have the output format // be similar with Hive. desc.run(sparkSession).map { @@ -127,7 +128,7 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) { .mkString("\t") } // SHOW TABLES in Hive only output table names, while ours output database, table name, isTemp. - case command @ ExecutedCommandExec(s: ShowTablesCommand) if !s.isExtended => + case command @ ExecutedCommandExec(s: ShowTablesCommand, _) if !s.isExtended => command.executeCollect().map(_.getString(1)) case other => val result: Seq[Seq[Any]] = other.executeCollectPublic().map(_.toSeq).toSeq diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala index be35916e3447e..bb206e84325fd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala @@ -39,6 +39,19 @@ object SQLExecution { executionIdToQueryExecution.get(executionId) } + private val testing = sys.props.contains("spark.testing") + + private[sql] def checkSQLExecutionId(sparkSession: SparkSession): Unit = { + // only throw an exception during tests. a missing execution ID should not fail a job. + if (testing && sparkSession.sparkContext.getLocalProperty(EXECUTION_ID_KEY) == null) { + // Attention testers: when a test fails with this exception, it means that the action that + // started execution of a query didn't call withNewExecutionId. The execution ID should be + // set by calling withNewExecutionId in the action that begins execution, like + // Dataset.collect or DataFrameWriter.insertInto. + throw new IllegalStateException("Execution ID should be set") + } + } + /** * Wrap an action that will execute "queryExecution" to track all Spark jobs in the body so that * we can connect them with an execution. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 843ce63161220..f13294c925e36 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -346,7 +346,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { // Can we automate these 'pass through' operations? object BasicOperators extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case r: RunnableCommand => ExecutedCommandExec(r) :: Nil + case r: RunnableCommand => ExecutedCommandExec(r, r.children.map(planLater)) :: Nil case MemoryPlan(sink, output) => val encoder = RowEncoder(sink.schema) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala index 3486a6bce8180..456a8f3b20f30 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala @@ -64,7 +64,7 @@ case class InMemoryRelation( val batchStats: LongAccumulator = child.sqlContext.sparkContext.longAccumulator) extends logical.LeafNode with MultiInstanceRelation { - override protected def innerChildren: Seq[SparkPlan] = Seq(child) + override def innerChildren: Seq[SparkPlan] = Seq(child) override def producedAttributes: AttributeSet = outputSet diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala index 7063b08f7c644..1d601374de135 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala @@ -34,7 +34,7 @@ case class InMemoryTableScanExec( @transient relation: InMemoryRelation) extends LeafExecNode { - override protected def innerChildren: Seq[QueryPlan[_]] = Seq(relation) ++ super.innerChildren + override def innerChildren: Seq[QueryPlan[_]] = Seq(relation) ++ super.innerChildren override lazy val metrics = Map( "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala index 0d8db2ff5d5a0..2de14c90ec757 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.catalog.{CatalogStatistics, CatalogTableTyp import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.execution.QueryExecution /** @@ -96,11 +97,13 @@ case class AnalyzeColumnCommand( attributesToAnalyze.map(ColumnStat.statExprs(_, ndvMaxErr)) val namedExpressions = expressions.map(e => Alias(e, e.toString)()) - val statsRow = Dataset.ofRows(sparkSession, Aggregate(Nil, namedExpressions, relation)).head() + val statsRow = new QueryExecution(sparkSession, Aggregate(Nil, namedExpressions, relation)) + .executedPlan.executeTake(1).head val rowCount = statsRow.getLong(0) val columnStats = attributesToAnalyze.zipWithIndex.map { case (attr, i) => - (attr.name, ColumnStat.rowToColumnStat(statsRow.getStruct(i + 1), attr)) + // according to `ColumnStat.statExprs`, the stats struct always have 6 fields. + (attr.name, ColumnStat.rowToColumnStat(statsRow.getStruct(i + 1, 6), attr)) }.toMap (rowCount, columnStats) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala index d2ea0cdf61aa6..3183c7911b1fb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala @@ -56,7 +56,7 @@ case class AnalyzeTableCommand( // 2. when total size is changed, `oldRowCount` becomes invalid. // This is to make sure that we only record the right statistics. if (!noscan) { - val newRowCount = sparkSession.table(tableIdentWithDB).count() + val newRowCount = sparkSession.table(tableIdentWithDB).countInternal() if (newRowCount >= 0 && newRowCount != oldRowCount) { newStats = if (newStats.isDefined) { newStats.map(_.copy(rowCount = Some(BigInt(newRowCount)))) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala index 336f14dd97aea..184d0387ebfa9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala @@ -30,19 +30,19 @@ case class CacheTableCommand( require(plan.isEmpty || tableIdent.database.isEmpty, "Database name is not allowed in CACHE TABLE AS SELECT") - override protected def innerChildren: Seq[QueryPlan[_]] = { - plan.toSeq - } + override def innerChildren: Seq[QueryPlan[_]] = plan.toSeq override def run(sparkSession: SparkSession): Seq[Row] = { plan.foreach { logicalPlan => - Dataset.ofRows(sparkSession, logicalPlan).createTempView(tableIdent.quotedString) + Dataset.ofRows(sparkSession, logicalPlan) + .createTempViewCommand(tableIdent.quotedString, replace = false, global = false) + .run(sparkSession) } sparkSession.catalog.cacheTable(tableIdent.quotedString) if (!isLazy) { // Performs eager caching - sparkSession.table(tableIdent).count() + sparkSession.table(tableIdent).countInternal() } Seq.empty[Row] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala index 41d91d877d4c2..99d81c49f1e3b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala @@ -22,8 +22,7 @@ import org.apache.spark.sql.{Row, SparkSession} import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} -import org.apache.spark.sql.catalyst.plans.QueryPlan -import org.apache.spark.sql.catalyst.plans.logical +import org.apache.spark.sql.catalyst.plans.{logical, QueryPlan} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.debug._ @@ -36,14 +35,20 @@ import org.apache.spark.sql.types._ * wrapped in `ExecutedCommand` during execution. */ trait RunnableCommand extends logical.Command { - def run(sparkSession: SparkSession): Seq[Row] + def run(sparkSession: SparkSession, children: Seq[SparkPlan]): Seq[Row] = { + throw new NotImplementedError + } + + def run(sparkSession: SparkSession): Seq[Row] = { + throw new NotImplementedError + } } /** * A physical operator that executes the run method of a `RunnableCommand` and * saves the result to prevent multiple executions. */ -case class ExecutedCommandExec(cmd: RunnableCommand) extends SparkPlan { +case class ExecutedCommandExec(cmd: RunnableCommand, children: Seq[SparkPlan]) extends SparkPlan { /** * A concrete command should override this lazy field to wrap up any side effects caused by the * command or any other computation that should be evaluated exactly once. The value of this field @@ -55,14 +60,19 @@ case class ExecutedCommandExec(cmd: RunnableCommand) extends SparkPlan { */ protected[sql] lazy val sideEffectResult: Seq[InternalRow] = { val converter = CatalystTypeConverters.createToCatalystConverter(schema) - cmd.run(sqlContext.sparkSession).map(converter(_).asInstanceOf[InternalRow]) + val rows = if (children.isEmpty) { + cmd.run(sqlContext.sparkSession) + } else { + cmd.run(sqlContext.sparkSession, children) + } + rows.map(converter(_).asInstanceOf[InternalRow]) } - override protected def innerChildren: Seq[QueryPlan[_]] = cmd :: Nil + override def innerChildren: Seq[QueryPlan[_]] = cmd.innerChildren override def output: Seq[Attribute] = cmd.output - override def children: Seq[SparkPlan] = Nil + override def nodeName: String = cmd.nodeName override def executeCollect(): Array[InternalRow] = sideEffectResult.toArray diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala index 2d890118ae0a5..729bd39d821c9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala @@ -122,7 +122,7 @@ case class CreateDataSourceTableAsSelectCommand( query: LogicalPlan) extends RunnableCommand { - override protected def innerChildren: Seq[LogicalPlan] = Seq(query) + override def innerChildren: Seq[LogicalPlan] = Seq(query) override def run(sparkSession: SparkSession): Seq[Row] = { assert(table.tableType != CatalogTableType.VIEW) @@ -195,7 +195,7 @@ case class CreateDataSourceTableAsSelectCommand( catalogTable = if (tableExists) Some(table) else None) try { - dataSource.writeAndRead(mode, Dataset.ofRows(session, query)) + dataSource.writeAndRead(mode, query) } catch { case ex: AnalysisException => logError(s"Failed to write to table ${table.identifier.unquotedString}", ex) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala index 00f0acab21aa2..1945d68241343 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala @@ -97,7 +97,7 @@ case class CreateViewCommand( import ViewHelper._ - override protected def innerChildren: Seq[QueryPlan[_]] = Seq(child) + override def innerChildren: Seq[QueryPlan[_]] = Seq(child) if (viewType == PersistedView) { require(originalText.isDefined, "'originalText' must be provided to create permanent view") @@ -264,7 +264,7 @@ case class AlterViewAsCommand( import ViewHelper._ - override protected def innerChildren: Seq[QueryPlan[_]] = Seq(query) + override def innerChildren: Seq[QueryPlan[_]] = Seq(query) override def run(session: SparkSession): Seq[Row] = { // If the plan cannot be analyzed, throw an exception and don't proceed. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index 9fce29b06b9d8..958715eefa0a2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -28,8 +28,9 @@ import org.apache.hadoop.fs.Path import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStorageFormat, CatalogTable, CatalogUtils} -import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat import org.apache.spark.sql.execution.datasources.jdbc.JdbcRelationProvider @@ -388,9 +389,10 @@ case class DataSource( } /** - * Writes the given [[DataFrame]] out in this [[FileFormat]]. + * Writes the given [[LogicalPlan]] out in this [[FileFormat]]. */ - private def writeInFileFormat(format: FileFormat, mode: SaveMode, data: DataFrame): Unit = { + private def planForWritingFileFormat( + format: FileFormat, mode: SaveMode, data: LogicalPlan): LogicalPlan = { // Don't glob path for the write path. The contracts here are: // 1. Only one output path can be specified on the write path; // 2. Output path must be a legal HDFS style file system path; @@ -408,16 +410,6 @@ case class DataSource( val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis PartitioningUtils.validatePartitionColumn(data.schema, partitionColumns, caseSensitive) - // SPARK-17230: Resolve the partition columns so InsertIntoHadoopFsRelationCommand does - // not need to have the query as child, to avoid to analyze an optimized query, - // because InsertIntoHadoopFsRelationCommand will be optimized first. - val partitionAttributes = partitionColumns.map { name => - val plan = data.logicalPlan - plan.resolve(name :: Nil, data.sparkSession.sessionState.analyzer.resolver).getOrElse { - throw new AnalysisException( - s"Unable to resolve $name given [${plan.output.map(_.name).mkString(", ")}]") - }.asInstanceOf[Attribute] - } val fileIndex = catalogTable.map(_.identifier).map { tableIdent => sparkSession.table(tableIdent).queryExecution.analyzed.collect { case LogicalRelation(t: HadoopFsRelation, _, _) => t.location @@ -426,36 +418,35 @@ case class DataSource( // For partitioned relation r, r.schema's column ordering can be different from the column // ordering of data.logicalPlan (partition columns are all moved after data column). This // will be adjusted within InsertIntoHadoopFsRelation. - val plan = - InsertIntoHadoopFsRelationCommand( - outputPath = outputPath, - staticPartitions = Map.empty, - ifPartitionNotExists = false, - partitionColumns = partitionAttributes, - bucketSpec = bucketSpec, - fileFormat = format, - options = options, - query = data.logicalPlan, - mode = mode, - catalogTable = catalogTable, - fileIndex = fileIndex) - sparkSession.sessionState.executePlan(plan).toRdd + InsertIntoHadoopFsRelationCommand( + outputPath = outputPath, + staticPartitions = Map.empty, + ifPartitionNotExists = false, + partitionColumns = partitionColumns.map(UnresolvedAttribute.quoted), + bucketSpec = bucketSpec, + fileFormat = format, + options = options, + query = data, + mode = mode, + catalogTable = catalogTable, + fileIndex = fileIndex) } /** - * Writes the given [[DataFrame]] out to this [[DataSource]] and returns a [[BaseRelation]] for + * Writes the given [[LogicalPlan]] out to this [[DataSource]] and returns a [[BaseRelation]] for * the following reading. */ - def writeAndRead(mode: SaveMode, data: DataFrame): BaseRelation = { + def writeAndRead(mode: SaveMode, data: LogicalPlan): BaseRelation = { if (data.schema.map(_.dataType).exists(_.isInstanceOf[CalendarIntervalType])) { throw new AnalysisException("Cannot save interval data type into external storage.") } providingClass.newInstance() match { case dataSource: CreatableRelationProvider => - dataSource.createRelation(sparkSession.sqlContext, mode, caseInsensitiveOptions, data) + dataSource.createRelation( + sparkSession.sqlContext, mode, caseInsensitiveOptions, Dataset.ofRows(sparkSession, data)) case format: FileFormat => - writeInFileFormat(format, mode, data) + sparkSession.sessionState.executePlan(planForWritingFileFormat(format, mode, data)).toRdd // Replace the schema with that of the DataFrame we just wrote out to avoid re-inferring copy(userSpecifiedSchema = Some(data.schema.asNullable)).resolveRelation() case _ => @@ -464,18 +455,18 @@ case class DataSource( } /** - * Writes the given [[DataFrame]] out to this [[DataSource]]. + * Returns a logical plan to write the given [[LogicalPlan]] out to this [[DataSource]]. */ - def write(mode: SaveMode, data: DataFrame): Unit = { + def planForWriting(mode: SaveMode, data: LogicalPlan): LogicalPlan = { if (data.schema.map(_.dataType).exists(_.isInstanceOf[CalendarIntervalType])) { throw new AnalysisException("Cannot save interval data type into external storage.") } providingClass.newInstance() match { case dataSource: CreatableRelationProvider => - dataSource.createRelation(sparkSession.sqlContext, mode, caseInsensitiveOptions, data) + SaveIntoDataSourceCommand(data, dataSource, caseInsensitiveOptions, mode) case format: FileFormat => - writeInFileFormat(format, mode, data) + planForWritingFileFormat(format, mode, data) case _ => sys.error(s"${providingClass.getCanonicalName} does not allow create table as select.") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala index 4ec09bff429c5..afe454f714c47 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala @@ -38,8 +38,8 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils} -import org.apache.spark.sql.execution.{QueryExecution, SortExec, SQLExecution} -import org.apache.spark.sql.types.{StringType, StructType} +import org.apache.spark.sql.execution.{SortExec, SparkPlan, SQLExecution} +import org.apache.spark.sql.types.StringType import org.apache.spark.util.{SerializableConfiguration, Utils} @@ -96,7 +96,7 @@ object FileFormatWriter extends Logging { */ def write( sparkSession: SparkSession, - queryExecution: QueryExecution, + plan: SparkPlan, fileFormat: FileFormat, committer: FileCommitProtocol, outputSpec: OutputSpec, @@ -111,9 +111,9 @@ object FileFormatWriter extends Logging { job.setOutputValueClass(classOf[InternalRow]) FileOutputFormat.setOutputPath(job, new Path(outputSpec.outputPath)) - val allColumns = queryExecution.logical.output + val allColumns = plan.output val partitionSet = AttributeSet(partitionColumns) - val dataColumns = queryExecution.logical.output.filterNot(partitionSet.contains) + val dataColumns = allColumns.filterNot(partitionSet.contains) val bucketIdExpression = bucketSpec.map { spec => val bucketColumns = spec.bucketColumnNames.map(c => dataColumns.find(_.name == c).get) @@ -151,7 +151,7 @@ object FileFormatWriter extends Logging { // We should first sort by partition columns, then bucket id, and finally sorting columns. val requiredOrdering = partitionColumns ++ bucketIdExpression ++ sortColumns // the sort order doesn't matter - val actualOrdering = queryExecution.executedPlan.outputOrdering.map(_.child) + val actualOrdering = plan.outputOrdering.map(_.child) val orderingMatched = if (requiredOrdering.length > actualOrdering.length) { false } else { @@ -161,50 +161,50 @@ object FileFormatWriter extends Logging { } } - SQLExecution.withNewExecutionId(sparkSession, queryExecution) { - // This call shouldn't be put into the `try` block below because it only initializes and - // prepares the job, any exception thrown from here shouldn't cause abortJob() to be called. - committer.setupJob(job) - - try { - val rdd = if (orderingMatched) { - queryExecution.toRdd - } else { - SortExec( - requiredOrdering.map(SortOrder(_, Ascending)), - global = false, - child = queryExecution.executedPlan).execute() - } - val ret = new Array[WriteTaskResult](rdd.partitions.length) - sparkSession.sparkContext.runJob( - rdd, - (taskContext: TaskContext, iter: Iterator[InternalRow]) => { - executeTask( - description = description, - sparkStageId = taskContext.stageId(), - sparkPartitionId = taskContext.partitionId(), - sparkAttemptNumber = taskContext.attemptNumber(), - committer, - iterator = iter) - }, - 0 until rdd.partitions.length, - (index, res: WriteTaskResult) => { - committer.onTaskCommit(res.commitMsg) - ret(index) = res - }) - - val commitMsgs = ret.map(_.commitMsg) - val updatedPartitions = ret.flatMap(_.updatedPartitions) - .distinct.map(PartitioningUtils.parsePathFragment) - - committer.commitJob(job, commitMsgs) - logInfo(s"Job ${job.getJobID} committed.") - refreshFunction(updatedPartitions) - } catch { case cause: Throwable => - logError(s"Aborting job ${job.getJobID}.", cause) - committer.abortJob(job) - throw new SparkException("Job aborted.", cause) + SQLExecution.checkSQLExecutionId(sparkSession) + + // This call shouldn't be put into the `try` block below because it only initializes and + // prepares the job, any exception thrown from here shouldn't cause abortJob() to be called. + committer.setupJob(job) + + try { + val rdd = if (orderingMatched) { + plan.execute() + } else { + SortExec( + requiredOrdering.map(SortOrder(_, Ascending)), + global = false, + child = plan).execute() } + val ret = new Array[WriteTaskResult](rdd.partitions.length) + sparkSession.sparkContext.runJob( + rdd, + (taskContext: TaskContext, iter: Iterator[InternalRow]) => { + executeTask( + description = description, + sparkStageId = taskContext.stageId(), + sparkPartitionId = taskContext.partitionId(), + sparkAttemptNumber = taskContext.attemptNumber(), + committer, + iterator = iter) + }, + 0 until rdd.partitions.length, + (index, res: WriteTaskResult) => { + committer.onTaskCommit(res.commitMsg) + ret(index) = res + }) + + val commitMsgs = ret.map(_.commitMsg) + val updatedPartitions = ret.flatMap(_.updatedPartitions) + .distinct.map(PartitioningUtils.parsePathFragment) + + committer.commitJob(job, commitMsgs) + logInfo(s"Job ${job.getJobID} committed.") + refreshFunction(updatedPartitions) + } catch { case cause: Throwable => + logError(s"Aborting job ${job.getJobID}.", cause) + committer.abortJob(job) + throw new SparkException("Job aborted.", cause) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSourceCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSourceCommand.scala index a813829d50cb1..08b2f4f31170f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSourceCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSourceCommand.scala @@ -33,7 +33,7 @@ case class InsertIntoDataSourceCommand( overwrite: Boolean) extends RunnableCommand { - override protected def innerChildren: Seq[QueryPlan[_]] = Seq(query) + override def innerChildren: Seq[QueryPlan[_]] = Seq(query) override def run(sparkSession: SparkSession): Seq[Row] = { val relation = logicalRelation.relation.asInstanceOf[InsertableRelation] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala index c9d31449d3629..00aa1240886e4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogTable, CatalogT import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.command._ /** @@ -53,12 +54,13 @@ case class InsertIntoHadoopFsRelationCommand( catalogTable: Option[CatalogTable], fileIndex: Option[FileIndex]) extends RunnableCommand { - import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils.escapePathName - override protected def innerChildren: Seq[LogicalPlan] = query :: Nil + override def children: Seq[LogicalPlan] = query :: Nil + + override def run(sparkSession: SparkSession, children: Seq[SparkPlan]): Seq[Row] = { + assert(children.length == 1) - override def run(sparkSession: SparkSession): Seq[Row] = { // Most formats don't do well with duplicate columns, so lets not allow that if (query.schema.fieldNames.length != query.schema.fieldNames.distinct.length) { val duplicateColumns = query.schema.fieldNames.groupBy(identity).collect { @@ -144,7 +146,7 @@ case class InsertIntoHadoopFsRelationCommand( FileFormatWriter.write( sparkSession = sparkSession, - queryExecution = Dataset.ofRows(sparkSession, query).queryExecution, + plan = children.head, fileFormat = fileFormat, committer = committer, outputSpec = FileFormatWriter.OutputSpec( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommand.scala index 6f19ea195c0cd..5eb6a8471be0d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommand.scala @@ -21,6 +21,7 @@ import org.apache.spark.sql.{Dataset, Row, SaveMode, SparkSession} import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.command.RunnableCommand +import org.apache.spark.sql.sources.CreatableRelationProvider /** * Saves the results of `query` in to a data source. @@ -33,19 +34,15 @@ import org.apache.spark.sql.execution.command.RunnableCommand */ case class SaveIntoDataSourceCommand( query: LogicalPlan, - provider: String, - partitionColumns: Seq[String], + dataSource: CreatableRelationProvider, options: Map[String, String], mode: SaveMode) extends RunnableCommand { - override protected def innerChildren: Seq[QueryPlan[_]] = Seq(query) + override def innerChildren: Seq[QueryPlan[_]] = Seq(query) override def run(sparkSession: SparkSession): Seq[Row] = { - DataSource( - sparkSession, - className = provider, - partitionColumns = partitionColumns, - options = options).write(mode, Dataset.ofRows(sparkSession, query)) + dataSource.createRelation( + sparkSession.sqlContext, mode, options, Dataset.ofRows(sparkSession, query)) Seq.empty[Row] } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala index 83bdf6fe224be..76f121c0c955f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala @@ -144,7 +144,8 @@ object TextInputCSVDataSource extends CSVDataSource { inputPaths: Seq[FileStatus], parsedOptions: CSVOptions): StructType = { val csv = createBaseDataset(sparkSession, inputPaths, parsedOptions) - val maybeFirstLine = CSVUtils.filterCommentAndEmpty(csv, parsedOptions).take(1).headOption + val maybeFirstLine = + CSVUtils.filterCommentAndEmpty(csv, parsedOptions).takeInternal(1).headOption inferFromDataset(sparkSession, csv, maybeFirstLine, parsedOptions) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala index f8d4a9bb5b81a..fdc5e85f3c2ea 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala @@ -76,7 +76,7 @@ case class CreateTempViewUsing( CatalogUtils.maskCredentials(options) } - def run(sparkSession: SparkSession): Seq[Row] = { + override def run(sparkSession: SparkSession): Seq[Row] = { if (provider.toLowerCase(Locale.ROOT) == DDLUtils.HIVE_PROVIDER) { throw new AnalysisException("Hive data source can only be used with tables, " + "you can't use it with CREATE TEMP VIEW USING") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala index 8b45dba04d29e..a06f1ce3287e6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala @@ -129,12 +129,14 @@ private[sql] case class JDBCRelation( } override def insert(data: DataFrame, overwrite: Boolean): Unit = { - val url = jdbcOptions.url - val table = jdbcOptions.table - val properties = jdbcOptions.asProperties - data.write - .mode(if (overwrite) SaveMode.Overwrite else SaveMode.Append) - .jdbc(url, table, properties) + import scala.collection.JavaConverters._ + + val options = jdbcOptions.asProperties.asScala + + ("url" -> jdbcOptions.url, "dbtable" -> jdbcOptions.table) + val mode = if (overwrite) SaveMode.Overwrite else SaveMode.Append + + new JdbcRelationProvider().createRelation( + data.sparkSession.sqlContext, mode, options.toMap, data) } override def toString: String = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala index 71eaab119d75d..ca61c2efe2ddf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala @@ -788,7 +788,7 @@ object JdbcUtils extends Logging { case Some(n) if n < df.rdd.getNumPartitions => df.coalesce(n) case _ => df } - repartitionedDF.foreachPartition(iterator => savePartition( + repartitionedDF.rdd.foreachPartition(iterator => savePartition( getConnection, table, iterator, rddSchema, insertStmt, batchSize, dialect, isolationLevel) ) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala index 6885d0bf67ccb..96225ecffad48 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala @@ -122,7 +122,7 @@ class FileStreamSink( FileFormatWriter.write( sparkSession = sparkSession, - queryExecution = data.queryExecution, + plan = data.queryExecution.executedPlan, fileFormat = fileFormat, committer = committer, outputSpec = FileFormatWriter.OutputSpec(path, Map.empty), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index b6ddf7437ea13..ab8608563c4fb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -34,7 +34,7 @@ import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, CurrentBatchTimestamp, CurrentDate, CurrentTimestamp} import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} -import org.apache.spark.sql.execution.QueryExecution +import org.apache.spark.sql.execution.{QueryExecution, SQLExecution} import org.apache.spark.sql.execution.command.StreamingExplainCommand import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming._ @@ -655,7 +655,9 @@ class StreamExecution( new Dataset(sparkSessionToRunBatch, lastExecution, RowEncoder(lastExecution.analyzed.schema)) reportTimeTaken("addBatch") { - sink.addBatch(currentBatchId, nextBatch) + SQLExecution.withNewExecutionId(sparkSessionToRunBatch, lastExecution) { + sink.addBatch(currentBatchId, nextBatch) + } } awaitBatchLock.lock() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala index e8b9712d19cd5..38c63191106d0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala @@ -46,8 +46,8 @@ class ConsoleSink(options: Map[String, String]) extends Sink with Logging { println("-------------------------------------------") // scalastyle:off println data.sparkSession.createDataFrame( - data.sparkSession.sparkContext.parallelize(data.collect()), data.schema) - .show(numRowsToShow, isTruncated) + data.sparkSession.sparkContext.parallelize(data.collectInternal()), data.schema) + .showInternal(numRowsToShow, isTruncated) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala index 971ce5afb1778..7eaa803a9ecb4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala @@ -196,11 +196,11 @@ class MemorySink(val schema: StructType, outputMode: OutputMode) extends Sink wi logDebug(s"Committing batch $batchId to $this") outputMode match { case Append | Update => - val rows = AddedData(batchId, data.collect()) + val rows = AddedData(batchId, data.collectInternal()) synchronized { batches += rows } case Complete => - val rows = AddedData(batchId, data.collect()) + val rows = AddedData(batchId, data.collectInternal()) synchronized { batches.clear() batches += rows diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala index e544245588f46..a4e62f1d16792 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala @@ -290,10 +290,13 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { test("save metrics") { withTempPath { file => + // person creates a temporary view. get the DF before listing previous execution IDs + val data = person.select('name) + sparkContext.listenerBus.waitUntilEmpty(10000) val previousExecutionIds = spark.sharedState.listener.executionIdToData.keySet // Assume the execution plan is // PhysicalRDD(nodeId = 0) - person.select('name).write.format("json").save(file.getAbsolutePath) + data.write.format("json").save(file.getAbsolutePath) sparkContext.listenerBus.waitUntilEmpty(10000) val executionIds = spark.sharedState.listener.executionIdToData.keySet.diff(previousExecutionIds) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala index 7c9ea7d393630..a239e39d9c5a3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala @@ -24,7 +24,8 @@ import org.apache.spark.sql.{functions, AnalysisException, QueryTest} import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, InsertIntoTable, LogicalPlan, Project} import org.apache.spark.sql.execution.{QueryExecution, WholeStageCodegenExec} -import org.apache.spark.sql.execution.datasources.{CreateTable, SaveIntoDataSourceCommand} +import org.apache.spark.sql.execution.datasources.{CreateTable, InsertIntoHadoopFsRelationCommand} +import org.apache.spark.sql.execution.datasources.json.JsonFileFormat import org.apache.spark.sql.test.SharedSQLContext class DataFrameCallbackSuite extends QueryTest with SharedSQLContext { @@ -178,26 +179,28 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext { spark.range(10).write.format("json").save(path.getCanonicalPath) assert(commands.length == 1) assert(commands.head._1 == "save") - assert(commands.head._2.isInstanceOf[SaveIntoDataSourceCommand]) - assert(commands.head._2.asInstanceOf[SaveIntoDataSourceCommand].provider == "json") + assert(commands.head._2.isInstanceOf[InsertIntoHadoopFsRelationCommand]) + assert(commands.head._2.asInstanceOf[InsertIntoHadoopFsRelationCommand] + .fileFormat.isInstanceOf[JsonFileFormat]) } withTable("tab") { - sql("CREATE TABLE tab(i long) using parquet") + sql("CREATE TABLE tab(i long) using parquet") // adds commands(1) via onSuccess spark.range(10).write.insertInto("tab") - assert(commands.length == 2) - assert(commands(1)._1 == "insertInto") - assert(commands(1)._2.isInstanceOf[InsertIntoTable]) - assert(commands(1)._2.asInstanceOf[InsertIntoTable].table + assert(commands.length == 3) + assert(commands(2)._1 == "insertInto") + assert(commands(2)._2.isInstanceOf[InsertIntoTable]) + assert(commands(2)._2.asInstanceOf[InsertIntoTable].table .asInstanceOf[UnresolvedRelation].tableIdentifier.table == "tab") } + // exiting withTable adds commands(3) via onSuccess (drops tab) withTable("tab") { spark.range(10).select($"id", $"id" % 5 as "p").write.partitionBy("p").saveAsTable("tab") - assert(commands.length == 3) - assert(commands(2)._1 == "saveAsTable") - assert(commands(2)._2.isInstanceOf[CreateTable]) - assert(commands(2)._2.asInstanceOf[CreateTable].tableDesc.partitionColumnNames == Seq("p")) + assert(commands.length == 5) + assert(commands(4)._1 == "saveAsTable") + assert(commands(4)._2.isInstanceOf[CreateTable]) + assert(commands(4)._2.asInstanceOf[CreateTable].tableDesc.partitionColumnNames == Seq("p")) } withTable("tab") { diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala index 0d5dc7af5f522..6775902173444 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala @@ -28,7 +28,7 @@ import org.apache.hadoop.hive.ql.processors.CommandProcessorResponse import org.apache.spark.internal.Logging import org.apache.spark.sql.{AnalysisException, SQLContext} -import org.apache.spark.sql.execution.QueryExecution +import org.apache.spark.sql.execution.{QueryExecution, SQLExecution} private[hive] class SparkSQLDriver(val context: SQLContext = SparkSQLEnv.sqlContext) @@ -60,7 +60,9 @@ private[hive] class SparkSQLDriver(val context: SQLContext = SparkSQLEnv.sqlCont try { context.sparkContext.setJobDescription(command) val execution = context.sessionState.executePlan(context.sql(command).logicalPlan) - hiveResponse = execution.hiveResultString() + hiveResponse = SQLExecution.withNewExecutionId(context.sparkSession, execution) { + execution.hiveResultString() + } tableSchema = getResultSetSchema(execution) new CommandProcessorResponse(0) } catch { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala index 10ce8e3730a0d..392b7cfaa8eff 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala @@ -32,10 +32,11 @@ import org.apache.hadoop.hive.ql.ErrorMsg import org.apache.hadoop.hive.ql.plan.TableDesc import org.apache.spark.internal.io.FileCommitProtocol -import org.apache.spark.sql.{AnalysisException, Dataset, Row, SparkSession} +import org.apache.spark.sql.{AnalysisException, Row, SparkSession} import org.apache.spark.sql.catalyst.catalog.CatalogTable import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.command.RunnableCommand import org.apache.spark.sql.execution.datasources.FileFormatWriter import org.apache.spark.sql.hive._ @@ -81,7 +82,7 @@ case class InsertIntoHiveTable( overwrite: Boolean, ifPartitionNotExists: Boolean) extends RunnableCommand { - override protected def innerChildren: Seq[LogicalPlan] = query :: Nil + override def children: Seq[LogicalPlan] = query :: Nil var createdTempDir: Option[Path] = None @@ -230,7 +231,9 @@ case class InsertIntoHiveTable( * `org.apache.hadoop.hive.serde2.SerDe` and the * `org.apache.hadoop.mapred.OutputFormat` provided by the table definition. */ - override def run(sparkSession: SparkSession): Seq[Row] = { + override def run(sparkSession: SparkSession, children: Seq[SparkPlan]): Seq[Row] = { + assert(children.length == 1) + val sessionState = sparkSession.sessionState val externalCatalog = sparkSession.sharedState.externalCatalog val hiveVersion = externalCatalog.asInstanceOf[HiveExternalCatalog].client.version @@ -344,7 +347,7 @@ case class InsertIntoHiveTable( FileFormatWriter.write( sparkSession = sparkSession, - queryExecution = Dataset.ofRows(sparkSession, query).queryExecution, + plan = children.head, fileFormat = new HiveFileFormat(fileSinkConf), committer = committer, outputSpec = FileFormatWriter.OutputSpec(tmpLocation.toString, Map.empty), diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index e1534c797d55b..4e1792321c89b 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -34,8 +34,8 @@ import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.internal.Logging import org.apache.spark.sql.{SparkSession, SQLContext} import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.execution.QueryExecution +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, OneRowRelation} +import org.apache.spark.sql.execution.{QueryExecution, SQLExecution} import org.apache.spark.sql.execution.command.CacheTableCommand import org.apache.spark.sql.hive._ import org.apache.spark.sql.hive.client.HiveClient @@ -294,23 +294,23 @@ private[hive] class TestHiveSparkSession( "CREATE TABLE src1 (key INT, value STRING)".cmd, s"LOAD DATA LOCAL INPATH '${quoteHiveFile("data/files/kv3.txt")}' INTO TABLE src1".cmd), TestTable("srcpart", () => { - sql( - "CREATE TABLE srcpart (key INT, value STRING) PARTITIONED BY (ds STRING, hr STRING)") + "CREATE TABLE srcpart (key INT, value STRING) PARTITIONED BY (ds STRING, hr STRING)" + .cmd.apply() for (ds <- Seq("2008-04-08", "2008-04-09"); hr <- Seq("11", "12")) { - sql( - s"""LOAD DATA LOCAL INPATH '${quoteHiveFile("data/files/kv1.txt")}' - |OVERWRITE INTO TABLE srcpart PARTITION (ds='$ds',hr='$hr') - """.stripMargin) + s""" + |LOAD DATA LOCAL INPATH '${quoteHiveFile("data/files/kv1.txt")}' + |OVERWRITE INTO TABLE srcpart PARTITION (ds='$ds',hr='$hr') + """.stripMargin.cmd.apply() } }), TestTable("srcpart1", () => { - sql( - "CREATE TABLE srcpart1 (key INT, value STRING) PARTITIONED BY (ds STRING, hr INT)") + "CREATE TABLE srcpart1 (key INT, value STRING) PARTITIONED BY (ds STRING, hr INT)" + .cmd.apply() for (ds <- Seq("2008-04-08", "2008-04-09"); hr <- 11 to 12) { - sql( - s"""LOAD DATA LOCAL INPATH '${quoteHiveFile("data/files/kv1.txt")}' - |OVERWRITE INTO TABLE srcpart1 PARTITION (ds='$ds',hr='$hr') - """.stripMargin) + s""" + |LOAD DATA LOCAL INPATH '${quoteHiveFile("data/files/kv1.txt")}' + |OVERWRITE INTO TABLE srcpart1 PARTITION (ds='$ds',hr='$hr') + """.stripMargin.cmd.apply() } }), TestTable("src_thrift", () => { @@ -318,8 +318,7 @@ private[hive] class TestHiveSparkSession( import org.apache.hadoop.mapred.{SequenceFileInputFormat, SequenceFileOutputFormat} import org.apache.thrift.protocol.TBinaryProtocol - sql( - s""" + s""" |CREATE TABLE src_thrift(fake INT) |ROW FORMAT SERDE '${classOf[ThriftDeserializer].getName}' |WITH SERDEPROPERTIES( @@ -329,13 +328,12 @@ private[hive] class TestHiveSparkSession( |STORED AS |INPUTFORMAT '${classOf[SequenceFileInputFormat[_, _]].getName}' |OUTPUTFORMAT '${classOf[SequenceFileOutputFormat[_, _]].getName}' - """.stripMargin) + """.stripMargin.cmd.apply() - sql( - s""" - |LOAD DATA LOCAL INPATH '${quoteHiveFile("data/files/complex.seq")}' - |INTO TABLE src_thrift - """.stripMargin) + s""" + |LOAD DATA LOCAL INPATH '${quoteHiveFile("data/files/complex.seq")}' + |INTO TABLE src_thrift + """.stripMargin.cmd.apply() }), TestTable("serdeins", s"""CREATE TABLE serdeins (key INT, value STRING) @@ -458,7 +456,17 @@ private[hive] class TestHiveSparkSession( logDebug(s"Loading test table $name") val createCmds = testTables.get(name).map(_.commands).getOrElse(sys.error(s"Unknown test table $name")) - createCmds.foreach(_()) + + // test tables are loaded lazily, so they may be loaded in the middle a query execution which + // has already set the execution id. + if (sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) == null) { + // We don't actually have a `QueryExecution` here, use a fake one instead. + SQLExecution.withNewExecutionId(this, new QueryExecution(this, OneRowRelation)) { + createCmds.foreach(_()) + } + } else { + createCmds.foreach(_()) + } if (cacheTables) { new SQLContext(self).cacheTable(name) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala index 98aa92a9bb88f..cee82cda4628a 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala @@ -31,6 +31,7 @@ import org.apache.spark.sql.Dataset import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.command._ import org.apache.spark.sql.hive.test.{TestHive, TestHiveQueryExecution} @@ -341,7 +342,10 @@ abstract class HiveComparisonTest // Run w/ catalyst val catalystResults = queryList.zip(hiveResults).map { case (queryString, hive) => val query = new TestHiveQueryExecution(queryString.replace("../../data", testDataPath)) - try { (query, prepareAnswer(query, query.hiveResultString())) } catch { + def getResult(): Seq[String] = { + SQLExecution.withNewExecutionId(query.sparkSession, query)(query.hiveResultString()) + } + try { (query, prepareAnswer(query, getResult())) } catch { case e: Throwable => val errorMessage = s""" diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index c944f28d10ef4..da7a0645dbbeb 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -965,14 +965,20 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { } test("sanity test for SPARK-6618") { - (1 to 100).par.map { i => - val tableName = s"SPARK_6618_table_$i" - sql(s"CREATE TABLE $tableName (col1 string)") - sessionState.catalog.lookupRelation(TableIdentifier(tableName)) - table(tableName) - tables() - sql(s"DROP TABLE $tableName") + val threads: Seq[Thread] = (1 to 10).map { i => + new Thread("test-thread-" + i) { + override def run(): Unit = { + val tableName = s"SPARK_6618_table_$i" + sql(s"CREATE TABLE $tableName (col1 string)") + sessionState.catalog.lookupRelation(TableIdentifier(tableName)) + table(tableName) + tables() + sql(s"DROP TABLE $tableName") + } + } } + threads.foreach(_.start()) + threads.foreach(_.join(10000)) } test("SPARK-5203 union with different decimal precision") { From 52ed9b289d169219f7257795cbedc56565a39c71 Mon Sep 17 00:00:00 2001 From: jerryshao Date: Tue, 30 May 2017 20:24:43 -0700 Subject: [PATCH 045/133] [SPARK-20275][UI] Do not display "Completed" column for in-progress applications ## What changes were proposed in this pull request? Current HistoryServer will display completed date of in-progress application as `1969-12-31 23:59:59`, which is not so meaningful. Instead of unnecessarily showing this incorrect completed date, here propose to make this column invisible for in-progress applications. The purpose of only making this column invisible rather than deleting this field is that: this data is fetched through REST API, and in the REST API the format is like below shows, in which `endTime` matches `endTimeEpoch`. So instead of changing REST API to break backward compatibility, here choosing a simple solution to only make this column invisible. ``` [ { "id" : "local-1491805439678", "name" : "Spark shell", "attempts" : [ { "startTime" : "2017-04-10T06:23:57.574GMT", "endTime" : "1969-12-31T23:59:59.999GMT", "lastUpdated" : "2017-04-10T06:23:57.574GMT", "duration" : 0, "sparkUser" : "", "completed" : false, "startTimeEpoch" : 1491805437574, "endTimeEpoch" : -1, "lastUpdatedEpoch" : 1491805437574 } ] } ]% ``` Here is UI before changed: screen shot 2017-04-10 at 3 45 57 pm And after: screen shot 2017-04-10 at 4 02 35 pm ## How was this patch tested? Manual verification. Author: jerryshao Closes #17588 from jerryshao/SPARK-20275. --- .../org/apache/spark/ui/static/historypage-template.html | 4 ++-- .../resources/org/apache/spark/ui/static/historypage.js | 7 +++++++ 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/core/src/main/resources/org/apache/spark/ui/static/historypage-template.html b/core/src/main/resources/org/apache/spark/ui/static/historypage-template.html index 6ba3b092dc658..c2afa993b2f20 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/historypage-template.html +++ b/core/src/main/resources/org/apache/spark/ui/static/historypage-template.html @@ -39,7 +39,7 @@ Started - + Completed @@ -73,7 +73,7 @@ {{#attempts}} {{attemptId}} {{startTime}} - {{endTime}} + {{endTime}} {{duration}} {{sparkUser}} {{lastUpdated}} diff --git a/core/src/main/resources/org/apache/spark/ui/static/historypage.js b/core/src/main/resources/org/apache/spark/ui/static/historypage.js index 1f89306403cd5..7db8c27e8f7c9 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/historypage.js +++ b/core/src/main/resources/org/apache/spark/ui/static/historypage.js @@ -177,6 +177,13 @@ $(document).ready(function() { } } + if (requestedIncomplete) { + var completedCells = document.getElementsByClassName("completedColumn"); + for (i = 0; i < completedCells.length; i++) { + completedCells[i].style.display='none'; + } + } + var durationCells = document.getElementsByClassName("durationClass"); for (i = 0; i < durationCells.length; i++) { var timeInMilliseconds = parseInt(durationCells[i].title); From 1f5dddffa3f065dff2b0a6b0fe7e463edfa4a5f1 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 30 May 2017 21:14:55 -0700 Subject: [PATCH 046/133] Revert "[SPARK-20392][SQL] Set barrier to prevent re-entering a tree" This reverts commit 8ce0d8ffb68bd9e89c23d3a026308dcc039a1b1d. --- .../sql/catalyst/analysis/Analyzer.scala | 75 ++++++--------- .../catalyst/analysis/DecimalPrecision.scala | 2 +- .../ResolveTableValuedFunctions.scala | 2 +- .../sql/catalyst/analysis/TypeCoercion.scala | 22 ++--- .../catalyst/analysis/timeZoneAnalysis.scala | 2 +- .../spark/sql/catalyst/analysis/view.scala | 2 +- .../sql/catalyst/optimizer/subquery.scala | 2 +- .../catalyst/plans/logical/LogicalPlan.scala | 35 +++++++ .../plans/logical/basicLogicalOperators.scala | 9 -- .../sql/catalyst/analysis/AnalysisSuite.scala | 14 --- .../sql/catalyst/plans/LogicalPlanSuite.scala | 26 +++--- .../scala/org/apache/spark/sql/Dataset.scala | 92 +++++++++---------- .../execution/datasources/DataSource.scala | 2 +- .../sql/execution/datasources/rules.scala | 2 +- .../spark/sql/execution/PlannerSuite.scala | 2 +- .../spark/sql/hive/HiveStrategies.scala | 6 +- 16 files changed, 144 insertions(+), 151 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 8818404094eb1..29183fd131292 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -166,15 +166,14 @@ class Analyzer( Batch("Subquery", Once, UpdateOuterReferences), Batch("Cleanup", fixedPoint, - CleanupAliases, - EliminateBarriers) + CleanupAliases) ) /** * Analyze cte definitions and substitute child plan with analyzed cte definitions. */ object CTESubstitution extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case With(child, relations) => substituteCTE(child, relations.foldLeft(Seq.empty[(String, LogicalPlan)]) { case (resolved, (name, relation)) => @@ -202,7 +201,7 @@ class Analyzer( * Substitute child plan with WindowSpecDefinitions. */ object WindowsSubstitution extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { // Lookup WindowSpecDefinitions. This rule works with unresolved children. case WithWindowDefinition(windowDefinitions, child) => child.transform { @@ -244,7 +243,7 @@ class Analyzer( private def hasUnresolvedAlias(exprs: Seq[NamedExpression]) = exprs.exists(_.find(_.isInstanceOf[UnresolvedAlias]).isDefined) - def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case Aggregate(groups, aggs, child) if child.resolved && hasUnresolvedAlias(aggs) => Aggregate(groups, assignAliases(aggs), child) @@ -634,7 +633,7 @@ class Analyzer( case _ => plan } - def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case i @ InsertIntoTable(u: UnresolvedRelation, parts, child, _, _) if child.resolved => EliminateSubqueryAliases(lookupTableFromCatalog(u)) match { case v: View => @@ -689,9 +688,7 @@ class Analyzer( * Generate a new logical plan for the right child with different expression IDs * for all conflicting attributes. */ - private def dedupRight (left: LogicalPlan, oriRight: LogicalPlan): LogicalPlan = { - // Remove analysis barrier if any. - val right = EliminateBarriers(oriRight) + private def dedupRight (left: LogicalPlan, right: LogicalPlan): LogicalPlan = { val conflictingAttributes = left.outputSet.intersect(right.outputSet) logDebug(s"Conflicting attributes ${conflictingAttributes.mkString(",")} " + s"between $left and $right") @@ -734,7 +731,7 @@ class Analyzer( * that this rule cannot handle. When that is the case, there must be another rule * that resolves these conflicts. Otherwise, the analysis will fail. */ - oriRight + right case Some((oldRelation, newRelation)) => val attributeRewrites = AttributeMap(oldRelation.output.zip(newRelation.output)) val newRight = right transformUp { @@ -747,7 +744,7 @@ class Analyzer( s.withNewPlan(dedupOuterReferencesInSubquery(s.plan, attributeRewrites)) } } - AnalysisBarrier(newRight) + newRight } } @@ -808,7 +805,7 @@ class Analyzer( } } - def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case p: LogicalPlan if !p.childrenResolved => p // If the projection list contains Stars, expand it. @@ -982,7 +979,7 @@ class Analyzer( * have no effect on the results. */ object ResolveOrdinalInOrderByAndGroupBy extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case p if !p.childrenResolved => p // Replace the index with the related attribute for ORDER BY, // which is a 1-base position of the projection list. @@ -1038,7 +1035,7 @@ class Analyzer( }} } - override def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { + override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case agg @ Aggregate(groups, aggs, child) if conf.groupByAliases && child.resolved && aggs.forall(_.resolved) && groups.exists(!_.resolved) => @@ -1062,13 +1059,11 @@ class Analyzer( * The HAVING clause could also used a grouping columns that is not presented in the SELECT. */ object ResolveMissingReferences extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { // Skip sort with aggregate. This will be handled in ResolveAggregateFunctions - case sa @ Sort(_, _, AnalysisBarrier(child: Aggregate)) => sa case sa @ Sort(_, _, child: Aggregate) => sa - case s @ Sort(order, _, orgChild) if !s.resolved && orgChild.resolved => - val child = EliminateBarriers(orgChild) + case s @ Sort(order, _, child) if !s.resolved && child.resolved => try { val newOrder = order.map(resolveExpressionRecursively(_, child).asInstanceOf[SortOrder]) val requiredAttrs = AttributeSet(newOrder).filter(_.resolved) @@ -1089,8 +1084,7 @@ class Analyzer( case ae: AnalysisException => s } - case f @ Filter(cond, orgChild) if !f.resolved && orgChild.resolved => - val child = EliminateBarriers(orgChild) + case f @ Filter(cond, child) if !f.resolved && child.resolved => try { val newCond = resolveExpressionRecursively(cond, child) val requiredAttrs = newCond.references.filter(_.resolved) @@ -1117,7 +1111,7 @@ class Analyzer( */ private def addMissingAttr(plan: LogicalPlan, missingAttrs: AttributeSet): LogicalPlan = { if (missingAttrs.isEmpty) { - return AnalysisBarrier(plan) + return plan } plan match { case p: Project => @@ -1189,7 +1183,7 @@ class Analyzer( * Replaces [[UnresolvedFunction]]s with concrete [[Expression]]s. */ object ResolveFunctions extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case q: LogicalPlan => q transformExpressions { case u if !u.childrenResolved => u // Skip until children are resolved. @@ -1528,7 +1522,7 @@ class Analyzer( /** * Resolve and rewrite all subqueries in an operator tree.. */ - def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { // In case of HAVING (a filter after an aggregate) we use both the aggregate and // its child for resolution. case f @ Filter(_, a: Aggregate) if f.childrenResolved => @@ -1543,7 +1537,7 @@ class Analyzer( * Turns projections that contain aggregate expressions into aggregations. */ object GlobalAggregates extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case Project(projectList, child) if containsAggregates(projectList) => Aggregate(Nil, projectList, child) } @@ -1569,9 +1563,7 @@ class Analyzer( * underlying aggregate operator and then projected away after the original operator. */ object ResolveAggregateFunctions extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { - case filter @ Filter(havingCondition, AnalysisBarrier(aggregate: Aggregate)) => - apply(Filter(havingCondition, aggregate)).mapChildren(AnalysisBarrier) + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case filter @ Filter(havingCondition, aggregate @ Aggregate(grouping, originalAggExprs, child)) if aggregate.resolved => @@ -1631,8 +1623,6 @@ class Analyzer( case ae: AnalysisException => filter } - case sort @ Sort(sortOrder, global, AnalysisBarrier(aggregate: Aggregate)) => - apply(Sort(sortOrder, global, aggregate)).mapChildren(AnalysisBarrier) case sort @ Sort(sortOrder, global, aggregate: Aggregate) if aggregate.resolved => // Try resolving the ordering as though it is in the aggregate clause. @@ -1745,7 +1735,7 @@ class Analyzer( } } - def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case Project(projectList, _) if projectList.exists(hasNestedGenerator) => val nestedGenerator = projectList.find(hasNestedGenerator).get throw new AnalysisException("Generators are not supported when it's nested in " + @@ -1803,7 +1793,7 @@ class Analyzer( * that wrap the [[Generator]]. */ object ResolveGenerate extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case g: Generate if !g.child.resolved || !g.generator.resolved => g case g: Generate if !g.resolved => g.copy(generatorOutput = makeGeneratorOutput(g.generator, g.generatorOutput.map(_.name))) @@ -2120,7 +2110,7 @@ class Analyzer( * put them into an inner Project and finally project them away at the outer Project. */ object PullOutNondeterministic extends Rule[LogicalPlan] { - override def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { + override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case p if !p.resolved => p // Skip unresolved nodes. case p: Project => p case f: Filter => f @@ -2165,7 +2155,7 @@ class Analyzer( * and we should return null if the input is null. */ object HandleNullInputsForUDF extends Rule[LogicalPlan] { - override def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { + override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case p if !p.resolved => p // Skip unresolved nodes. case p => p transformExpressionsUp { @@ -2230,7 +2220,7 @@ class Analyzer( * Then apply a Project on a normal Join to eliminate natural or using join. */ object ResolveNaturalAndUsingJoin extends Rule[LogicalPlan] { - override def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { + override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case j @ Join(left, right, UsingJoin(joinType, usingCols), condition) if left.resolved && right.resolved && j.duplicateResolved => commonNaturalJoinProcessing(left, right, joinType, usingCols, None) @@ -2295,7 +2285,7 @@ class Analyzer( * to the given input attributes. */ object ResolveDeserializer extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case p if !p.childrenResolved => p case p if p.resolved => p @@ -2381,7 +2371,7 @@ class Analyzer( * constructed is an inner class. */ object ResolveNewInstance extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case p if !p.childrenResolved => p case p if p.resolved => p @@ -2415,7 +2405,7 @@ class Analyzer( "type of the field in the target object") } - def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case p if !p.childrenResolved => p case p if p.resolved => p @@ -2469,7 +2459,7 @@ object CleanupAliases extends Rule[LogicalPlan] { case other => trimAliases(other) } - override def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { + override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case Project(projectList, child) => val cleanedProjectList = projectList.map(trimNonTopLevelAliases(_).asInstanceOf[NamedExpression]) @@ -2498,13 +2488,6 @@ object CleanupAliases extends Rule[LogicalPlan] { } } -/** Remove the barrier nodes of analysis */ -object EliminateBarriers extends Rule[LogicalPlan] { - override def apply(plan: LogicalPlan): LogicalPlan = plan transformDown { - case AnalysisBarrier(child) => child - } -} - /** * Ignore event time watermark in batch query, which is only supported in Structured Streaming. * TODO: add this rule into analyzer rule list. @@ -2554,7 +2537,7 @@ object TimeWindowing extends Rule[LogicalPlan] { * @return the logical plan that will generate the time windows using the Expand operator, with * the Filter operator for correctness and Project for usability. */ - def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case p: LogicalPlan if p.children.size == 1 => val child = p.children.head val windowExpressions = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala index ac72bc4ef4200..9c38dd2ee4e53 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala @@ -78,7 +78,7 @@ object DecimalPrecision extends Rule[LogicalPlan] { PromotePrecision(Cast(e, dataType)) } - def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { // fix decimal precision for expressions case q => q.transformExpressions( decimalAndDecimal.orElse(integralAndDecimalLiteral).orElse(nondecimalAndDecimal)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala index a214e59302cd9..7358f9ee36921 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala @@ -103,7 +103,7 @@ object ResolveTableValuedFunctions extends Rule[LogicalPlan] { }) ) - override def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { + override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case u: UnresolvedTableValuedFunction if u.functionArgs.forall(_.resolved) => val resolvedFunc = builtinFunctions.get(u.functionName.toLowerCase(Locale.ROOT)) match { case Some(tvf) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index c3645170589c8..e1dd010d37a95 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -206,7 +206,7 @@ object TypeCoercion { * instances higher in the query tree. */ object PropagateTypes extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { // No propagation required for leaf nodes. case q: LogicalPlan if q.children.isEmpty => q @@ -261,7 +261,7 @@ object TypeCoercion { */ object WidenSetOperationTypes extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case p if p.analyzed => p case s @ SetOperation(left, right) if s.childrenResolved && @@ -335,7 +335,7 @@ object TypeCoercion { } } - def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e @@ -391,7 +391,7 @@ object TypeCoercion { } } - def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e @@ -449,7 +449,7 @@ object TypeCoercion { private val trueValues = Seq(1.toByte, 1.toShort, 1, 1L, Decimal.ONE) private val falseValues = Seq(0.toByte, 0.toShort, 0, 0L, Decimal.ZERO) - def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e @@ -490,7 +490,7 @@ object TypeCoercion { * This ensure that the types for various functions are as expected. */ object FunctionArgumentConversion extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e @@ -580,7 +580,7 @@ object TypeCoercion { * converted to fractional types. */ object Division extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { // Skip nodes who has not been resolved yet, // as this is an extra rule which should be applied at last. case e if !e.childrenResolved => e @@ -602,7 +602,7 @@ object TypeCoercion { * Coerces the type of different branches of a CASE WHEN statement to a common type. */ object CaseWhenCoercion extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { case c: CaseWhen if c.childrenResolved && !c.valueTypesEqual => val maybeCommonType = findWiderCommonType(c.valueTypes) maybeCommonType.map { commonType => @@ -632,7 +632,7 @@ object TypeCoercion { * Coerces the type of different branches of If statement to a common type. */ object IfCoercion extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { case e if !e.childrenResolved => e // Find tightest common type for If, if the true value and false value have different types. case i @ If(pred, left, right) if left.dataType != right.dataType => @@ -656,7 +656,7 @@ object TypeCoercion { private val acceptedTypes = Seq(DateType, TimestampType, StringType) - def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e @@ -673,7 +673,7 @@ object TypeCoercion { * Casts types according to the expected input types for [[Expression]]s. */ object ImplicitTypeCasts extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/timeZoneAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/timeZoneAnalysis.scala index af1f9165b0044..a27aa845bf0ae 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/timeZoneAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/timeZoneAnalysis.scala @@ -38,7 +38,7 @@ case class ResolveTimeZone(conf: SQLConf) extends Rule[LogicalPlan] { } override def apply(plan: LogicalPlan): LogicalPlan = - plan.transformAllExpressions(transformTimeZoneExprs) + plan.resolveExpressions(transformTimeZoneExprs) def resolveTimeZones(e: Expression): Expression = e.transform(transformTimeZoneExprs) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/view.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/view.scala index 3bbe41cf8f15e..ea46dd7282401 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/view.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/view.scala @@ -48,7 +48,7 @@ import org.apache.spark.sql.internal.SQLConf * completely resolved during the batch of Resolution. */ case class AliasViewChild(conf: SQLConf) extends Rule[LogicalPlan] with CastSupport { - override def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { + override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case v @ View(desc, output, child) if child.resolved && output != child.output => val resolver = conf.resolver val queryColumnNames = desc.viewQueryColumnNames diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala index 46d1aac1857d7..2a3e07aebe709 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala @@ -236,7 +236,7 @@ object PullupCorrelatedPredicates extends Rule[LogicalPlan] with PredicateHelper /** * Pull up the correlated predicates and rewrite all subqueries in an operator tree.. */ - def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case f @ Filter(_, a: Aggregate) => rewriteSubQueries(f, Seq(a, a.child)) // Only a few unary nodes (Project/Filter/Aggregate) can contain subqueries. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index 23520eb82b043..2ebb2ff323c6b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -46,6 +46,41 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { /** Returns true if this subtree contains any streaming data sources. */ def isStreaming: Boolean = children.exists(_.isStreaming == true) + /** + * Returns a copy of this node where `rule` has been recursively applied first to all of its + * children and then itself (post-order). When `rule` does not apply to a given node, it is left + * unchanged. This function is similar to `transformUp`, but skips sub-trees that have already + * been marked as analyzed. + * + * @param rule the function use to transform this nodes children + */ + def resolveOperators(rule: PartialFunction[LogicalPlan, LogicalPlan]): LogicalPlan = { + if (!analyzed) { + val afterRuleOnChildren = mapChildren(_.resolveOperators(rule)) + if (this fastEquals afterRuleOnChildren) { + CurrentOrigin.withOrigin(origin) { + rule.applyOrElse(this, identity[LogicalPlan]) + } + } else { + CurrentOrigin.withOrigin(origin) { + rule.applyOrElse(afterRuleOnChildren, identity[LogicalPlan]) + } + } + } else { + this + } + } + + /** + * Recursively transforms the expressions of a tree, skipping nodes that have already + * been analyzed. + */ + def resolveExpressions(r: PartialFunction[Expression, Expression]): LogicalPlan = { + this resolveOperators { + case p => p.transformExpressions(r) + } + } + /** A cache for the estimated statistics, such that it will only be computed once. */ private var statsCache: Option[Statistics] = None diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index b8c2f7670d7b1..6878b6b179c3a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -23,7 +23,6 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.statsEstimation._ -import org.apache.spark.sql.catalyst.trees.CurrentOrigin import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -897,11 +896,3 @@ case class Deduplicate( override def output: Seq[Attribute] = child.output } - -/** A logical plan for setting a barrier of analysis */ -case class AnalysisBarrier(child: LogicalPlan) extends LeafNode { - override def output: Seq[Attribute] = child.output - override def analyzed: Boolean = true - override def isStreaming: Boolean = child.isStreaming - override lazy val canonicalized: LogicalPlan = child.canonicalized -} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index 5393786891e07..be26b1b26f175 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -441,20 +441,6 @@ class AnalysisSuite extends AnalysisTest with ShouldMatchers { checkAnalysis(SubqueryAlias("tbl", testRelation).as("tbl2"), testRelation) } - test("analysis barrier") { - // [[AnalysisBarrier]] will be removed after analysis - checkAnalysis( - Project(Seq(UnresolvedAttribute("tbl.a")), - AnalysisBarrier(SubqueryAlias("tbl", testRelation))), - Project(testRelation.output, SubqueryAlias("tbl", testRelation))) - - // Verify we won't go through a plan wrapped in a barrier. - // Since we wrap an unresolved plan and analyzer won't go through it. It remains unresolved. - val barrier = AnalysisBarrier(Project(Seq(UnresolvedAttribute("tbl.b")), - SubqueryAlias("tbl", testRelation))) - assertAnalysisError(barrier, Seq("cannot resolve '`tbl.b`'")) - } - test("SPARK-20311 range(N) as alias") { def rangeWithAliases(args: Seq[Int], outputNames: Seq[String]): LogicalPlan = { SubqueryAlias("t", UnresolvedTableValuedFunction("range", args.map(Literal(_)), outputNames)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala index 215db848383eb..cc86f1f6e2f48 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala @@ -23,8 +23,8 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.types.IntegerType /** - * This suite is used to test [[LogicalPlan]]'s `transformUp` plus analysis barrier and make sure - * it can correctly skip sub-trees that have already been marked as analyzed. + * This suite is used to test [[LogicalPlan]]'s `resolveOperators` and make sure it can correctly + * skips sub-trees that have already been marked as analyzed. */ class LogicalPlanSuite extends SparkFunSuite { private var invocationCount = 0 @@ -36,35 +36,37 @@ class LogicalPlanSuite extends SparkFunSuite { private val testRelation = LocalRelation() - test("transformUp runs on operators") { + test("resolveOperator runs on operators") { invocationCount = 0 val plan = Project(Nil, testRelation) - plan transformUp function + plan resolveOperators function assert(invocationCount === 1) } - test("transformUp runs on operators recursively") { + test("resolveOperator runs on operators recursively") { invocationCount = 0 val plan = Project(Nil, Project(Nil, testRelation)) - plan transformUp function + plan resolveOperators function assert(invocationCount === 2) } - test("transformUp skips all ready resolved plans wrapped in analysis barrier") { + test("resolveOperator skips all ready resolved plans") { invocationCount = 0 - val plan = AnalysisBarrier(Project(Nil, Project(Nil, testRelation))) - plan transformUp function + val plan = Project(Nil, Project(Nil, testRelation)) + plan.foreach(_.setAnalyzed()) + plan resolveOperators function assert(invocationCount === 0) } - test("transformUp skips partially resolved plans wrapped in analysis barrier") { + test("resolveOperator skips partially resolved plans") { invocationCount = 0 - val plan1 = AnalysisBarrier(Project(Nil, testRelation)) + val plan1 = Project(Nil, testRelation) val plan2 = Project(Nil, plan1) - plan2 transformUp function + plan1.foreach(_.setAnalyzed()) + plan2 resolveOperators function assert(invocationCount === 1) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index d5b4c82c3558b..5ffe32f61ee09 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -187,9 +187,6 @@ class Dataset[T] private[sql]( } } - // Wraps analyzed logical plans with an analysis barrier so we won't traverse/resolve it again. - @transient private val planWithBarrier = AnalysisBarrier(logicalPlan) - /** * Currently [[ExpressionEncoder]] is the only implementation of [[Encoder]], here we turn the * passed in encoder to [[ExpressionEncoder]] explicitly, and mark it implicit so that we can use @@ -421,7 +418,7 @@ class Dataset[T] private[sql]( */ @Experimental @InterfaceStability.Evolving - def as[U : Encoder]: Dataset[U] = Dataset[U](sparkSession, planWithBarrier) + def as[U : Encoder]: Dataset[U] = Dataset[U](sparkSession, logicalPlan) /** * Converts this strongly typed collection of data to generic `DataFrame` with columns renamed. @@ -624,7 +621,7 @@ class Dataset[T] private[sql]( require(parsedDelay.milliseconds >= 0 && parsedDelay.months >= 0, s"delay threshold ($delayThreshold) should not be negative.") EliminateEventTimeWatermark( - EventTimeWatermark(UnresolvedAttribute(eventTime), parsedDelay, planWithBarrier)) + EventTimeWatermark(UnresolvedAttribute(eventTime), parsedDelay, logicalPlan)) } /** @@ -810,7 +807,7 @@ class Dataset[T] private[sql]( * @since 2.0.0 */ def join(right: Dataset[_]): DataFrame = withPlan { - Join(planWithBarrier, right.planWithBarrier, joinType = Inner, None) + Join(logicalPlan, right.logicalPlan, joinType = Inner, None) } /** @@ -888,7 +885,7 @@ class Dataset[T] private[sql]( // Analyze the self join. The assumption is that the analyzer will disambiguate left vs right // by creating a new instance for one of the branch. val joined = sparkSession.sessionState.executePlan( - Join(planWithBarrier, right.planWithBarrier, joinType = JoinType(joinType), None)) + Join(logicalPlan, right.logicalPlan, joinType = JoinType(joinType), None)) .analyzed.asInstanceOf[Join] withPlan { @@ -949,7 +946,7 @@ class Dataset[T] private[sql]( // Trigger analysis so in the case of self-join, the analyzer will clone the plan. // After the cloning, left and right side will have distinct expression ids. val plan = withPlan( - Join(planWithBarrier, right.planWithBarrier, JoinType(joinType), Some(joinExprs.expr))) + Join(logicalPlan, right.logicalPlan, JoinType(joinType), Some(joinExprs.expr))) .queryExecution.analyzed.asInstanceOf[Join] // If auto self join alias is disabled, return the plan. @@ -958,8 +955,8 @@ class Dataset[T] private[sql]( } // If left/right have no output set intersection, return the plan. - val lanalyzed = withPlan(this.planWithBarrier).queryExecution.analyzed - val ranalyzed = withPlan(right.planWithBarrier).queryExecution.analyzed + val lanalyzed = withPlan(this.logicalPlan).queryExecution.analyzed + val ranalyzed = withPlan(right.logicalPlan).queryExecution.analyzed if (lanalyzed.outputSet.intersect(ranalyzed.outputSet).isEmpty) { return withPlan(plan) } @@ -991,7 +988,7 @@ class Dataset[T] private[sql]( * @since 2.1.0 */ def crossJoin(right: Dataset[_]): DataFrame = withPlan { - Join(planWithBarrier, right.planWithBarrier, joinType = Cross, None) + Join(logicalPlan, right.logicalPlan, joinType = Cross, None) } /** @@ -1023,8 +1020,8 @@ class Dataset[T] private[sql]( // etc. val joined = sparkSession.sessionState.executePlan( Join( - this.planWithBarrier, - other.planWithBarrier, + this.logicalPlan, + other.logicalPlan, JoinType(joinType), Some(condition.expr))).analyzed.asInstanceOf[Join] @@ -1194,7 +1191,7 @@ class Dataset[T] private[sql]( */ @scala.annotation.varargs def hint(name: String, parameters: String*): Dataset[T] = withTypedPlan { - UnresolvedHint(name, parameters, planWithBarrier) + UnresolvedHint(name, parameters, logicalPlan) } /** @@ -1220,7 +1217,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ def as(alias: String): Dataset[T] = withTypedPlan { - SubqueryAlias(alias, planWithBarrier) + SubqueryAlias(alias, logicalPlan) } /** @@ -1258,7 +1255,7 @@ class Dataset[T] private[sql]( */ @scala.annotation.varargs def select(cols: Column*): DataFrame = withPlan { - Project(cols.map(_.named), planWithBarrier) + Project(cols.map(_.named), logicalPlan) } /** @@ -1313,8 +1310,8 @@ class Dataset[T] private[sql]( @InterfaceStability.Evolving def select[U1](c1: TypedColumn[T, U1]): Dataset[U1] = { implicit val encoder = c1.encoder - val project = Project(c1.withInputType(exprEnc, planWithBarrier.output).named :: Nil, - planWithBarrier) + val project = Project(c1.withInputType(exprEnc, logicalPlan.output).named :: Nil, + logicalPlan) if (encoder.flat) { new Dataset[U1](sparkSession, project, encoder) @@ -1332,8 +1329,8 @@ class Dataset[T] private[sql]( protected def selectUntyped(columns: TypedColumn[_, _]*): Dataset[_] = { val encoders = columns.map(_.encoder) val namedColumns = - columns.map(_.withInputType(exprEnc, planWithBarrier.output).named) - val execution = new QueryExecution(sparkSession, Project(namedColumns, planWithBarrier)) + columns.map(_.withInputType(exprEnc, logicalPlan.output).named) + val execution = new QueryExecution(sparkSession, Project(namedColumns, logicalPlan)) new Dataset(sparkSession, execution, ExpressionEncoder.tuple(encoders)) } @@ -1409,7 +1406,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ def filter(condition: Column): Dataset[T] = withTypedPlan { - Filter(condition.expr, planWithBarrier) + Filter(condition.expr, logicalPlan) } /** @@ -1586,7 +1583,7 @@ class Dataset[T] private[sql]( @Experimental @InterfaceStability.Evolving def groupByKey[K: Encoder](func: T => K): KeyValueGroupedDataset[K, T] = { - val inputPlan = planWithBarrier + val inputPlan = logicalPlan val withGroupingKey = AppendColumns(func, inputPlan) val executed = sparkSession.sessionState.executePlan(withGroupingKey) @@ -1732,7 +1729,7 @@ class Dataset[T] private[sql]( * @since 2.0.0 */ def limit(n: Int): Dataset[T] = withTypedPlan { - Limit(Literal(n), planWithBarrier) + Limit(Literal(n), logicalPlan) } /** @@ -1761,7 +1758,7 @@ class Dataset[T] private[sql]( def union(other: Dataset[T]): Dataset[T] = withSetOperator { // This breaks caching, but it's usually ok because it addresses a very specific use case: // using union to union many files or partitions. - CombineUnions(Union(logicalPlan, other.logicalPlan)).mapChildren(AnalysisBarrier) + CombineUnions(Union(logicalPlan, other.logicalPlan)) } /** @@ -1775,7 +1772,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ def intersect(other: Dataset[T]): Dataset[T] = withSetOperator { - Intersect(planWithBarrier, other.planWithBarrier) + Intersect(logicalPlan, other.logicalPlan) } /** @@ -1789,7 +1786,7 @@ class Dataset[T] private[sql]( * @since 2.0.0 */ def except(other: Dataset[T]): Dataset[T] = withSetOperator { - Except(planWithBarrier, other.planWithBarrier) + Except(logicalPlan, other.logicalPlan) } /** @@ -1810,7 +1807,7 @@ class Dataset[T] private[sql]( s"Fraction must be nonnegative, but got ${fraction}") withTypedPlan { - Sample(0.0, fraction, withReplacement, seed, planWithBarrier)() + Sample(0.0, fraction, withReplacement, seed, logicalPlan)() } } @@ -1852,15 +1849,15 @@ class Dataset[T] private[sql]( // overlapping splits. To prevent this, we explicitly sort each input partition to make the // ordering deterministic. Note that MapTypes cannot be sorted and are explicitly pruned out // from the sort order. - val sortOrder = planWithBarrier.output + val sortOrder = logicalPlan.output .filter(attr => RowOrdering.isOrderable(attr.dataType)) .map(SortOrder(_, Ascending)) val plan = if (sortOrder.nonEmpty) { - Sort(sortOrder, global = false, planWithBarrier) + Sort(sortOrder, global = false, logicalPlan) } else { // SPARK-12662: If sort order is empty, we materialize the dataset to guarantee determinism cache() - planWithBarrier + logicalPlan } val sum = weights.sum val normalizedCumWeights = weights.map(_ / sum).scanLeft(0.0d)(_ + _) @@ -1944,7 +1941,7 @@ class Dataset[T] private[sql]( withPlan { Generate(generator, join = true, outer = false, - qualifier = None, generatorOutput = Nil, planWithBarrier) + qualifier = None, generatorOutput = Nil, logicalPlan) } } @@ -1985,7 +1982,7 @@ class Dataset[T] private[sql]( withPlan { Generate(generator, join = true, outer = false, - qualifier = None, generatorOutput = Nil, planWithBarrier) + qualifier = None, generatorOutput = Nil, logicalPlan) } } @@ -2100,7 +2097,7 @@ class Dataset[T] private[sql]( u.name, sparkSession.sessionState.analyzer.resolver).getOrElse(u) case Column(expr: Expression) => expr } - val attrs = this.planWithBarrier.output + val attrs = this.logicalPlan.output val colsAfterDrop = attrs.filter { attr => attr != expression }.map(attr => Column(attr)) @@ -2148,7 +2145,7 @@ class Dataset[T] private[sql]( } cols } - Deduplicate(groupCols, planWithBarrier, isStreaming) + Deduplicate(groupCols, logicalPlan, isStreaming) } /** @@ -2297,7 +2294,7 @@ class Dataset[T] private[sql]( @Experimental @InterfaceStability.Evolving def filter(func: T => Boolean): Dataset[T] = { - withTypedPlan(TypedFilter(func, planWithBarrier)) + withTypedPlan(TypedFilter(func, logicalPlan)) } /** @@ -2311,7 +2308,7 @@ class Dataset[T] private[sql]( @Experimental @InterfaceStability.Evolving def filter(func: FilterFunction[T]): Dataset[T] = { - withTypedPlan(TypedFilter(func, planWithBarrier)) + withTypedPlan(TypedFilter(func, logicalPlan)) } /** @@ -2325,7 +2322,7 @@ class Dataset[T] private[sql]( @Experimental @InterfaceStability.Evolving def map[U : Encoder](func: T => U): Dataset[U] = withTypedPlan { - MapElements[T, U](func, planWithBarrier) + MapElements[T, U](func, logicalPlan) } /** @@ -2340,7 +2337,7 @@ class Dataset[T] private[sql]( @InterfaceStability.Evolving def map[U](func: MapFunction[T, U], encoder: Encoder[U]): Dataset[U] = { implicit val uEnc = encoder - withTypedPlan(MapElements[T, U](func, planWithBarrier)) + withTypedPlan(MapElements[T, U](func, logicalPlan)) } /** @@ -2356,7 +2353,7 @@ class Dataset[T] private[sql]( def mapPartitions[U : Encoder](func: Iterator[T] => Iterator[U]): Dataset[U] = { new Dataset[U]( sparkSession, - MapPartitions[T, U](func, planWithBarrier), + MapPartitions[T, U](func, logicalPlan), implicitly[Encoder[U]]) } @@ -2387,7 +2384,7 @@ class Dataset[T] private[sql]( val rowEncoder = encoder.asInstanceOf[ExpressionEncoder[Row]] Dataset.ofRows( sparkSession, - MapPartitionsInR(func, packageNames, broadcastVars, schema, rowEncoder, planWithBarrier)) + MapPartitionsInR(func, packageNames, broadcastVars, schema, rowEncoder, logicalPlan)) } /** @@ -2557,7 +2554,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ def repartition(numPartitions: Int): Dataset[T] = withTypedPlan { - Repartition(numPartitions, shuffle = true, planWithBarrier) + Repartition(numPartitions, shuffle = true, logicalPlan) } /** @@ -2571,7 +2568,7 @@ class Dataset[T] private[sql]( */ @scala.annotation.varargs def repartition(numPartitions: Int, partitionExprs: Column*): Dataset[T] = withTypedPlan { - RepartitionByExpression(partitionExprs.map(_.expr), planWithBarrier, numPartitions) + RepartitionByExpression(partitionExprs.map(_.expr), logicalPlan, numPartitions) } /** @@ -2587,8 +2584,7 @@ class Dataset[T] private[sql]( @scala.annotation.varargs def repartition(partitionExprs: Column*): Dataset[T] = withTypedPlan { RepartitionByExpression( - partitionExprs.map(_.expr), planWithBarrier, - sparkSession.sessionState.conf.numShufflePartitions) + partitionExprs.map(_.expr), logicalPlan, sparkSession.sessionState.conf.numShufflePartitions) } /** @@ -2609,7 +2605,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ def coalesce(numPartitions: Int): Dataset[T] = withTypedPlan { - Repartition(numPartitions, shuffle = false, planWithBarrier) + Repartition(numPartitions, shuffle = false, logicalPlan) } /** @@ -2698,7 +2694,7 @@ class Dataset[T] private[sql]( */ lazy val rdd: RDD[T] = { val objectType = exprEnc.deserializer.dataType - val deserialized = CatalystSerde.deserialize[T](planWithBarrier) + val deserialized = CatalystSerde.deserialize[T](logicalPlan) sparkSession.sessionState.executePlan(deserialized).toRdd.mapPartitions { rows => rows.map(_.get(0, objectType).asInstanceOf[T]) } @@ -2812,7 +2808,7 @@ class Dataset[T] private[sql]( comment = None, properties = Map.empty, originalText = None, - child = planWithBarrier, + child = logicalPlan, allowExisting = false, replace = replace, viewType = viewType) @@ -2981,7 +2977,7 @@ class Dataset[T] private[sql]( } } withTypedPlan { - Sort(sortOrder, global = global, planWithBarrier) + Sort(sortOrder, global = global, logicalPlan) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index 958715eefa0a2..08c78e6e326af 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -416,7 +416,7 @@ case class DataSource( }.head } // For partitioned relation r, r.schema's column ordering can be different from the column - // ordering of data.logicalPlan (partition columns are all moved after data column). This + // ordering of data.logicalPlan (partition columns are all moved after data column). This // will be adjusted within InsertIntoHadoopFsRelation. InsertIntoHadoopFsRelationCommand( outputPath = outputPath, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala index 5f65898f5312e..3f4a78580f1eb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala @@ -38,7 +38,7 @@ class ResolveSQLOnFile(sparkSession: SparkSession) extends Rule[LogicalPlan] { sparkSession.sessionState.conf.runSQLonFile && u.tableIdentifier.database.isDefined } - def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case u: UnresolvedRelation if maybeSQLFile(u) => try { val dataSource = DataSource( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index d02c8ffe33f0f..4d155d538d637 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -241,7 +241,7 @@ class PlannerSuite extends SharedSQLContext { test("collapse adjacent repartitions") { val doubleRepartitioned = testData.repartition(10).repartition(20).coalesce(5) def countRepartitions(plan: LogicalPlan): Int = plan.collect { case r: Repartition => r }.length - assert(countRepartitions(doubleRepartitioned.queryExecution.analyzed) === 3) + assert(countRepartitions(doubleRepartitioned.queryExecution.logical) === 3) assert(countRepartitions(doubleRepartitioned.queryExecution.optimizedPlan) === 2) doubleRepartitioned.queryExecution.optimizedPlan match { case Repartition (numPartitions, shuffle, Repartition(_, shuffleChild, _)) => diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index 662fc80661513..9c60d22d35ce1 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -88,7 +88,7 @@ class ResolveHiveSerdeTable(session: SparkSession) extends Rule[LogicalPlan] { } } - override def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { + override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case c @ CreateTable(t, _, query) if DDLUtils.isHiveTable(t) => // Finds the database name if the name does not exist. val dbName = t.identifier.database.getOrElse(session.catalog.currentDatabase) @@ -115,7 +115,7 @@ class ResolveHiveSerdeTable(session: SparkSession) extends Rule[LogicalPlan] { } class DetermineTableStats(session: SparkSession) extends Rule[LogicalPlan] { - override def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { + override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case relation: CatalogRelation if DDLUtils.isHiveTable(relation.tableMeta) && relation.tableMeta.stats.isEmpty => val table = relation.tableMeta @@ -146,7 +146,7 @@ class DetermineTableStats(session: SparkSession) extends Rule[LogicalPlan] { * `PreprocessTableInsertion`. */ object HiveAnalysis extends Rule[LogicalPlan] { - override def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { + override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case InsertIntoTable(r: CatalogRelation, partSpec, query, overwrite, ifPartitionNotExists) if DDLUtils.isHiveTable(r.tableMeta) => InsertIntoHiveTable(r.tableMeta, partSpec, query, overwrite, ifPartitionNotExists) From 382fefd1879e4670f3e9e8841ec243e3eb11c578 Mon Sep 17 00:00:00 2001 From: Felix Cheung Date: Tue, 30 May 2017 22:33:29 -0700 Subject: [PATCH 047/133] [SPARK-20877][SPARKR][WIP] add timestamps to test runs ## What changes were proposed in this pull request? to investigate how long they run ## How was this patch tested? Jenkins, AppVeyor Author: Felix Cheung Closes #18104 from felixcheung/rtimetest. --- R/pkg/inst/tests/testthat/test_Windows.R | 3 + .../testthat/test_mllib_classification.R | 4 + .../tests/testthat/test_mllib_clustering.R | 2 + R/pkg/inst/tests/testthat/test_mllib_tree.R | 82 +++++++++++-------- R/pkg/inst/tests/testthat/test_sparkSQL.R | 15 ++++ R/pkg/inst/tests/testthat/test_utils.R | 3 + R/pkg/tests/run-all.R | 6 ++ 7 files changed, 81 insertions(+), 34 deletions(-) diff --git a/R/pkg/inst/tests/testthat/test_Windows.R b/R/pkg/inst/tests/testthat/test_Windows.R index 919b063bf0693..00d684e1a49ef 100644 --- a/R/pkg/inst/tests/testthat/test_Windows.R +++ b/R/pkg/inst/tests/testthat/test_Windows.R @@ -27,3 +27,6 @@ test_that("sparkJars tag in SparkContext", { abcPath <- testOutput[1] expect_equal(abcPath, "a\\b\\c") }) + +message("--- End test (Windows) ", as.POSIXct(Sys.time(), tz = "GMT")) +message("elapsed ", (proc.time() - timer_ptm)[3]) diff --git a/R/pkg/inst/tests/testthat/test_mllib_classification.R b/R/pkg/inst/tests/testthat/test_mllib_classification.R index c1c746828d24b..82e588dc460d0 100644 --- a/R/pkg/inst/tests/testthat/test_mllib_classification.R +++ b/R/pkg/inst/tests/testthat/test_mllib_classification.R @@ -28,6 +28,8 @@ absoluteSparkPath <- function(x) { } test_that("spark.svmLinear", { + skip_on_cran() + df <- suppressWarnings(createDataFrame(iris)) training <- df[df$Species %in% c("versicolor", "virginica"), ] model <- spark.svmLinear(training, Species ~ ., regParam = 0.01, maxIter = 10) @@ -226,6 +228,8 @@ test_that("spark.logit", { }) test_that("spark.mlp", { + skip_on_cran() + df <- read.df(absoluteSparkPath("data/mllib/sample_multiclass_classification_data.txt"), source = "libsvm") model <- spark.mlp(df, label ~ features, blockSize = 128, layers = c(4, 5, 4, 3), diff --git a/R/pkg/inst/tests/testthat/test_mllib_clustering.R b/R/pkg/inst/tests/testthat/test_mllib_clustering.R index 8f71de1cbc7b5..e827e961ab4c1 100644 --- a/R/pkg/inst/tests/testthat/test_mllib_clustering.R +++ b/R/pkg/inst/tests/testthat/test_mllib_clustering.R @@ -28,6 +28,8 @@ absoluteSparkPath <- function(x) { } test_that("spark.bisectingKmeans", { + skip_on_cran() + newIris <- iris newIris$Species <- NULL training <- suppressWarnings(createDataFrame(newIris)) diff --git a/R/pkg/inst/tests/testthat/test_mllib_tree.R b/R/pkg/inst/tests/testthat/test_mllib_tree.R index 5fd6a38ecb4fa..31427ee52a5e9 100644 --- a/R/pkg/inst/tests/testthat/test_mllib_tree.R +++ b/R/pkg/inst/tests/testthat/test_mllib_tree.R @@ -28,6 +28,8 @@ absoluteSparkPath <- function(x) { } test_that("spark.gbt", { + skip_on_cran() + # regression data <- suppressWarnings(createDataFrame(longley)) model <- spark.gbt(data, Employed ~ ., "regression", maxDepth = 5, maxBins = 16, seed = 123) @@ -103,10 +105,12 @@ test_that("spark.gbt", { expect_equal(stats$maxDepth, 5) # spark.gbt classification can work on libsvm data - data <- read.df(absoluteSparkPath("data/mllib/sample_binary_classification_data.txt"), - source = "libsvm") - model <- spark.gbt(data, label ~ features, "classification") - expect_equal(summary(model)$numFeatures, 692) + if (not_cran_or_windows_with_hadoop()) { + data <- read.df(absoluteSparkPath("data/mllib/sample_binary_classification_data.txt"), + source = "libsvm") + model <- spark.gbt(data, label ~ features, "classification") + expect_equal(summary(model)$numFeatures, 692) + } }) test_that("spark.randomForest", { @@ -211,13 +215,17 @@ test_that("spark.randomForest", { expect_equal(length(grep("2.0", predictions)), 50) # spark.randomForest classification can work on libsvm data - data <- read.df(absoluteSparkPath("data/mllib/sample_multiclass_classification_data.txt"), - source = "libsvm") - model <- spark.randomForest(data, label ~ features, "classification") - expect_equal(summary(model)$numFeatures, 4) + if (not_cran_or_windows_with_hadoop()) { + data <- read.df(absoluteSparkPath("data/mllib/sample_multiclass_classification_data.txt"), + source = "libsvm") + model <- spark.randomForest(data, label ~ features, "classification") + expect_equal(summary(model)$numFeatures, 4) + } }) test_that("spark.decisionTree", { + skip_on_cran() + # regression data <- suppressWarnings(createDataFrame(longley)) model <- spark.decisionTree(data, Employed ~ ., "regression", maxDepth = 5, maxBins = 16) @@ -234,19 +242,21 @@ test_that("spark.decisionTree", { expect_error(capture.output(stats), NA) expect_true(length(capture.output(stats)) > 6) - modelPath <- tempfile(pattern = "spark-decisionTreeRegression", fileext = ".tmp") - write.ml(model, modelPath) - expect_error(write.ml(model, modelPath)) - write.ml(model, modelPath, overwrite = TRUE) - model2 <- read.ml(modelPath) - stats2 <- summary(model2) - expect_equal(stats$formula, stats2$formula) - expect_equal(stats$numFeatures, stats2$numFeatures) - expect_equal(stats$features, stats2$features) - expect_equal(stats$featureImportances, stats2$featureImportances) - expect_equal(stats$maxDepth, stats2$maxDepth) + if (not_cran_or_windows_with_hadoop()) { + modelPath <- tempfile(pattern = "spark-decisionTreeRegression", fileext = ".tmp") + write.ml(model, modelPath) + expect_error(write.ml(model, modelPath)) + write.ml(model, modelPath, overwrite = TRUE) + model2 <- read.ml(modelPath) + stats2 <- summary(model2) + expect_equal(stats$formula, stats2$formula) + expect_equal(stats$numFeatures, stats2$numFeatures) + expect_equal(stats$features, stats2$features) + expect_equal(stats$featureImportances, stats2$featureImportances) + expect_equal(stats$maxDepth, stats2$maxDepth) - unlink(modelPath) + unlink(modelPath) + } # classification data <- suppressWarnings(createDataFrame(iris)) @@ -263,17 +273,19 @@ test_that("spark.decisionTree", { expect_equal(length(grep("setosa", predictions)), 50) expect_equal(length(grep("versicolor", predictions)), 50) - modelPath <- tempfile(pattern = "spark-decisionTreeClassification", fileext = ".tmp") - write.ml(model, modelPath) - expect_error(write.ml(model, modelPath)) - write.ml(model, modelPath, overwrite = TRUE) - model2 <- read.ml(modelPath) - stats2 <- summary(model2) - expect_equal(stats$depth, stats2$depth) - expect_equal(stats$numNodes, stats2$numNodes) - expect_equal(stats$numClasses, stats2$numClasses) + if (not_cran_or_windows_with_hadoop()) { + modelPath <- tempfile(pattern = "spark-decisionTreeClassification", fileext = ".tmp") + write.ml(model, modelPath) + expect_error(write.ml(model, modelPath)) + write.ml(model, modelPath, overwrite = TRUE) + model2 <- read.ml(modelPath) + stats2 <- summary(model2) + expect_equal(stats$depth, stats2$depth) + expect_equal(stats$numNodes, stats2$numNodes) + expect_equal(stats$numClasses, stats2$numClasses) - unlink(modelPath) + unlink(modelPath) + } # Test numeric response variable labelToIndex <- function(species) { @@ -297,10 +309,12 @@ test_that("spark.decisionTree", { expect_equal(length(grep("2.0", predictions)), 50) # spark.decisionTree classification can work on libsvm data - data <- read.df(absoluteSparkPath("data/mllib/sample_multiclass_classification_data.txt"), - source = "libsvm") - model <- spark.decisionTree(data, label ~ features, "classification") - expect_equal(summary(model)$numFeatures, 4) + if (not_cran_or_windows_with_hadoop()) { + data <- read.df(absoluteSparkPath("data/mllib/sample_multiclass_classification_data.txt"), + source = "libsvm") + model <- spark.decisionTree(data, label ~ features, "classification") + expect_equal(summary(model)$numFeatures, 4) + } }) sparkR.session.stop() diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index 9fc6e5dabecc3..c790d02b107be 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -1395,6 +1395,8 @@ test_that("column operators", { }) test_that("column functions", { + skip_on_cran() + c <- column("a") c1 <- abs(c) + acos(c) + approxCountDistinct(c) + ascii(c) + asin(c) + atan(c) c2 <- avg(c) + base64(c) + bin(c) + bitwiseNOT(c) + cbrt(c) + ceil(c) + cos(c) @@ -1780,6 +1782,8 @@ test_that("when(), otherwise() and ifelse() with column on a DataFrame", { }) test_that("group by, agg functions", { + skip_on_cran() + df <- read.json(jsonPath) df1 <- agg(df, name = "max", age = "sum") expect_equal(1, count(df1)) @@ -2121,6 +2125,8 @@ test_that("filter() on a DataFrame", { }) test_that("join(), crossJoin() and merge() on a DataFrame", { + skip_on_cran() + df <- read.json(jsonPath) mockLines2 <- c("{\"name\":\"Michael\", \"test\": \"yes\"}", @@ -2978,6 +2984,7 @@ test_that("dapply() and dapplyCollect() on a DataFrame", { }) test_that("dapplyCollect() on DataFrame with a binary column", { + skip_on_cran() df <- data.frame(key = 1:3) df$bytes <- lapply(df$key, serialize, connection = NULL) @@ -2999,6 +3006,8 @@ test_that("dapplyCollect() on DataFrame with a binary column", { }) test_that("repartition by columns on DataFrame", { + skip_on_cran() + df <- createDataFrame( list(list(1L, 1, "1", 0.1), list(1L, 2, "2", 0.2), list(3L, 3, "3", 0.3)), c("a", "b", "c", "d")) @@ -3037,6 +3046,8 @@ test_that("repartition by columns on DataFrame", { }) test_that("coalesce, repartition, numPartitions", { + skip_on_cran() + df <- as.DataFrame(cars, numPartitions = 5) expect_equal(getNumPartitions(df), 5) expect_equal(getNumPartitions(coalesce(df, 3)), 3) @@ -3056,6 +3067,8 @@ test_that("coalesce, repartition, numPartitions", { }) test_that("gapply() and gapplyCollect() on a DataFrame", { + skip_on_cran() + df <- createDataFrame ( list(list(1L, 1, "1", 0.1), list(1L, 2, "1", 0.2), list(3L, 3, "3", 0.3)), c("a", "b", "c", "d")) @@ -3208,6 +3221,8 @@ test_that("createDataFrame sqlContext parameter backward compatibility", { }) test_that("randomSplit", { + skip_on_cran() + num <- 4000 df <- createDataFrame(data.frame(id = 1:num)) weights <- c(2, 3, 5) diff --git a/R/pkg/inst/tests/testthat/test_utils.R b/R/pkg/inst/tests/testthat/test_utils.R index 02691f0f64314..6197ae7569879 100644 --- a/R/pkg/inst/tests/testthat/test_utils.R +++ b/R/pkg/inst/tests/testthat/test_utils.R @@ -243,3 +243,6 @@ test_that("basenameSansExtFromUrl", { }) sparkR.session.stop() + +message("--- End test (utils) ", as.POSIXct(Sys.time(), tz = "GMT")) +message("elapsed ", (proc.time() - timer_ptm)[3]) diff --git a/R/pkg/tests/run-all.R b/R/pkg/tests/run-all.R index 9c6cba535d118..f0bef4f6d2662 100644 --- a/R/pkg/tests/run-all.R +++ b/R/pkg/tests/run-all.R @@ -21,6 +21,12 @@ library(SparkR) # Turn all warnings into errors options("warn" = 2) +if (.Platform$OS.type == "windows") { + Sys.setenv(TZ = "GMT") +} +message("--- Start test ", as.POSIXct(Sys.time(), tz = "GMT")) +timer_ptm <- proc.time() + # Setup global test environment # Install Spark first to set SPARK_HOME install.spark() From beed5e20af0a3935ef42beb3431c8630599bf27f Mon Sep 17 00:00:00 2001 From: Jacek Laskowski Date: Wed, 31 May 2017 11:24:37 +0100 Subject: [PATCH 048/133] [DOCS][MINOR] Scaladoc fixes (aka typo hunting) ## What changes were proposed in this pull request? Minor changes to scaladoc ## How was this patch tested? Local build Author: Jacek Laskowski Closes #18074 from jaceklaskowski/scaladoc-fixes. --- .../spark/sql/catalyst/ScalaReflection.scala | 6 ++-- .../sql/catalyst/analysis/Analyzer.scala | 5 ++-- .../catalyst/encoders/ExpressionEncoder.scala | 5 ++-- .../expressions/codegen/CodeGenerator.scala | 2 +- .../codegen/GenerateUnsafeProjection.scala | 6 ++-- .../expressions/windowExpressions.scala | 8 +++--- .../sql/catalyst/planning/QueryPlanner.scala | 14 ++++++---- .../spark/sql/catalyst/trees/TreeNode.scala | 2 +- .../scala/org/apache/spark/sql/Column.scala | 4 +-- .../spark/sql/RelationalGroupedDataset.scala | 24 ++++++++-------- .../spark/sql/execution/SparkPlan.scala | 28 ++++++++++--------- .../sql/execution/WholeStageCodegenExec.scala | 14 +++++----- .../execution/window/AggregateProcessor.scala | 17 +++++------ .../sql/execution/window/WindowExec.scala | 4 +-- .../apache/spark/sql/expressions/Window.scala | 4 +-- .../spark/sql/expressions/WindowSpec.scala | 2 +- .../org/apache/spark/sql/functions.scala | 4 +-- .../internal/BaseSessionStateBuilder.scala | 2 +- .../spark/sql/internal/SessionState.scala | 2 +- .../apache/spark/sql/sources/interfaces.scala | 2 +- 20 files changed, 82 insertions(+), 73 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 6d1d019cc4743..87130532c89bc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -88,8 +88,10 @@ object ScalaReflection extends ScalaReflection { } /** - * Given a type `T` this function constructs and ObjectType that holds a class of type - * Array[T]. Special handling is performed for primitive types to map them back to their raw + * Given a type `T` this function constructs `ObjectType` that holds a class of type + * `Array[T]`. + * + * Special handling is performed for primitive types to map them back to their raw * JVM form instead of the Scala Array that handles auto boxing. */ private def arrayClassFor(tpe: `Type`): ObjectType = ScalaReflectionLock.synchronized { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 29183fd131292..196b4a9bada3c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -85,8 +85,7 @@ object AnalysisContext { /** * Provides a logical query plan analyzer, which translates [[UnresolvedAttribute]]s and - * [[UnresolvedRelation]]s into fully typed objects using information in a - * [[SessionCatalog]] and a [[FunctionRegistry]]. + * [[UnresolvedRelation]]s into fully typed objects using information in a [[SessionCatalog]]. */ class Analyzer( catalog: SessionCatalog, @@ -1900,7 +1899,7 @@ class Analyzer( * `[Sum(_w0) OVER (PARTITION BY _w1 ORDER BY _w2)]` and the second returned value will be * [col1, col2 + col3 as _w0, col4 as _w1, col5 as _w2]. * - * @return (seq of expressions containing at lease one window expressions, + * @return (seq of expressions containing at least one window expression, * seq of non-window expressions) */ private def extract( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala index ec003cdc17b89..efc2882f0a3d3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -208,7 +208,8 @@ object ExpressionEncoder { } /** - * A generic encoder for JVM objects. + * A generic encoder for JVM objects that uses Catalyst Expressions for a `serializer` + * and a `deserializer`. * * @param schema The schema after converting `T` to a Spark SQL row. * @param serializer A set of expressions, one for each top-level field that can be used to @@ -235,7 +236,7 @@ case class ExpressionEncoder[T]( assert(serializer.flatMap { ser => val boundRefs = ser.collect { case b: BoundReference => b } assert(boundRefs.nonEmpty, - "each serializer expression should contains at least one `BoundReference`") + "each serializer expression should contain at least one `BoundReference`") boundRefs }.distinct.length <= 1, "all serializer expressions must use the same BoundReference.") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index f8da78b5f5e3e..fd9780245fcfb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -800,7 +800,7 @@ class CodegenContext { /** * Generates code for expressions. If doSubexpressionElimination is true, subexpression - * elimination will be performed. Subexpression elimination assumes that the code will for each + * elimination will be performed. Subexpression elimination assumes that the code for each * expression will be combined in the `expressions` order. */ def generateExpressions(expressions: Seq[Expression], diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index b358102d914bd..efbbc038bd33b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -23,11 +23,11 @@ import org.apache.spark.sql.types._ /** * Generates a [[Projection]] that returns an [[UnsafeRow]]. * - * It generates the code for all the expressions, compute the total length for all the columns - * (can be accessed via variables), and then copy the data into a scratch buffer space in the + * It generates the code for all the expressions, computes the total length for all the columns + * (can be accessed via variables), and then copies the data into a scratch buffer space in the * form of UnsafeRow (the scratch buffer will grow as needed). * - * Note: The returned UnsafeRow will be pointed to a scratch buffer inside the projection. + * @note The returned UnsafeRow will be pointed to a scratch buffer inside the projection. */ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafeProjection] { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala index 37190429fc423..88afd43223d1d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala @@ -113,7 +113,7 @@ sealed trait FrameType * or a [[ValueFollowing]] is used as its [[FrameBoundary]], the value is considered * as a physical offset. * For example, `ROW BETWEEN 1 PRECEDING AND 1 FOLLOWING` represents a 3-row frame, - * from the row precedes the current row to the row follows the current row. + * from the row that precedes the current row to the row that follows the current row. */ case object RowFrame extends FrameType @@ -126,7 +126,7 @@ case object RowFrame extends FrameType * `RANGE BETWEEN 1 PRECEDING AND 1 FOLLOWING` represents a frame containing rows whose values * `expr` are in the range of [v-1, v+1]. * - * If `ORDER BY` clause is not defined, all rows in the partition is considered as peers + * If `ORDER BY` clause is not defined, all rows in the partition are considered as peers * of the current row. */ case object RangeFrame extends FrameType @@ -217,11 +217,11 @@ case object UnboundedFollowing extends FrameBoundary { } /** - * The trait used to represent the a Window Frame. + * Represents a window frame. */ sealed trait WindowFrame -/** Used as a place holder when a frame specification is not defined. */ +/** Used as a placeholder when a frame specification is not defined. */ case object UnspecifiedFrame extends WindowFrame /** A specified Window Frame. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/QueryPlanner.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/QueryPlanner.scala index 5f694f44b6e8a..bc41dd0465e34 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/QueryPlanner.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/QueryPlanner.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.trees.TreeNode /** * Given a [[LogicalPlan]], returns a list of `PhysicalPlan`s that can - * be used for execution. If this strategy does not apply to the give logical operation then an + * be used for execution. If this strategy does not apply to the given logical operation then an * empty list should be returned. */ abstract class GenericStrategy[PhysicalPlan <: TreeNode[PhysicalPlan]] extends Logging { @@ -42,9 +42,10 @@ abstract class GenericStrategy[PhysicalPlan <: TreeNode[PhysicalPlan]] extends L * Abstract class for transforming [[LogicalPlan]]s into physical plans. * Child classes are responsible for specifying a list of [[GenericStrategy]] objects that * each of which can return a list of possible physical plan options. - * If a given strategy is unable to plan all - * of the remaining operators in the tree, it can call [[planLater]], which returns a placeholder - * object that will be filled in using other available strategies. + * If a given strategy is unable to plan all of the remaining operators in the tree, + * it can call [[GenericStrategy#planLater planLater]], which returns a placeholder + * object that will be [[collectPlaceholders collected]] and filled in + * using other available strategies. * * TODO: RIGHT NOW ONLY ONE PLAN IS RETURNED EVER... * PLAN SPACE EXPLORATION WILL BE IMPLEMENTED LATER. @@ -93,7 +94,10 @@ abstract class QueryPlanner[PhysicalPlan <: TreeNode[PhysicalPlan]] { pruned } - /** Collects placeholders marked as [[planLater]] by strategy and its [[LogicalPlan]]s */ + /** + * Collects placeholders marked using [[GenericStrategy#planLater planLater]] + * by [[strategies]]. + */ protected def collectPlaceholders(plan: PhysicalPlan): Seq[(PhysicalPlan, LogicalPlan)] /** Prunes bad plans to prevent combinatorial explosion. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala index 2109c1c23b706..df66f9a082aee 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala @@ -519,7 +519,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { protected def innerChildren: Seq[TreeNode[_]] = Seq.empty /** - * Appends the string represent of this node and its children to the given StringBuilder. + * Appends the string representation of this node and its children to the given StringBuilder. * * The `i`-th element in `lastChildren` indicates whether the ancestor of the current node at * depth `i + 1` is the last child of its own parent node. The depth of the root node is 0, and diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index b23ab1fa3514a..7e1f1d83cb3de 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -1152,7 +1152,7 @@ class Column(val expr: Expression) extends Logging { def bitwiseXOR(other: Any): Column = withExpr { BitwiseXor(expr, lit(other).expr) } /** - * Define a windowing column. + * Defines a windowing column. * * {{{ * val w = Window.partitionBy("name").orderBy("id") @@ -1168,7 +1168,7 @@ class Column(val expr: Expression) extends Logging { def over(window: expressions.WindowSpec): Column = window.withAggregate(this) /** - * Define a empty analytic clause. In this case the analytic function is applied + * Defines an empty analytic clause. In this case the analytic function is applied * and presented for all rows in the result set. * * {{{ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala index 64755434784a0..147b549964913 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala @@ -35,12 +35,13 @@ import org.apache.spark.sql.types.NumericType import org.apache.spark.sql.types.StructType /** - * A set of methods for aggregations on a `DataFrame`, created by `Dataset.groupBy`. + * A set of methods for aggregations on a `DataFrame`, created by [[Dataset#groupBy groupBy]], + * [[Dataset#cube cube]] or [[Dataset#rollup rollup]] (and also `pivot`). * - * The main method is the agg function, which has multiple variants. This class also contains - * convenience some first order statistics such as mean, sum for convenience. + * The main method is the `agg` function, which has multiple variants. This class also contains + * some first-order statistics such as `mean`, `sum` for convenience. * - * This class was named `GroupedData` in Spark 1.x. + * @note This class was named `GroupedData` in Spark 1.x. * * @since 2.0.0 */ @@ -297,8 +298,9 @@ class RelationalGroupedDataset protected[sql]( } /** - * Pivots a column of the current `DataFrame` and perform the specified aggregation. - * There are two versions of pivot function: one that requires the caller to specify the list + * Pivots a column of the current `DataFrame` and performs the specified aggregation. + * + * There are two versions of `pivot` function: one that requires the caller to specify the list * of distinct values to pivot on, and one that does not. The latter is more concise but less * efficient, because Spark needs to first compute the list of distinct values internally. * @@ -337,7 +339,7 @@ class RelationalGroupedDataset protected[sql]( } /** - * Pivots a column of the current `DataFrame` and perform the specified aggregation. + * Pivots a column of the current `DataFrame` and performs the specified aggregation. * There are two versions of pivot function: one that requires the caller to specify the list * of distinct values to pivot on, and one that does not. The latter is more concise but less * efficient, because Spark needs to first compute the list of distinct values internally. @@ -369,7 +371,9 @@ class RelationalGroupedDataset protected[sql]( } /** - * Pivots a column of the current `DataFrame` and perform the specified aggregation. + * (Java-specific) Pivots a column of the current `DataFrame` and performs the specified + * aggregation. + * * There are two versions of pivot function: one that requires the caller to specify the list * of distinct values to pivot on, and one that does not. The latter is more concise but less * efficient, because Spark needs to first compute the list of distinct values internally. @@ -433,10 +437,6 @@ class RelationalGroupedDataset protected[sql]( } } - -/** - * Companion object for GroupedData. - */ private[sql] object RelationalGroupedDataset { def apply( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index c4ed96640eb19..db975614c961a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -72,24 +72,24 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ } /** - * Return all metadata that describes more details of this SparkPlan. + * @return Metadata that describes more details of this SparkPlan. */ def metadata: Map[String, String] = Map.empty /** - * Return all metrics containing metrics of this SparkPlan. + * @return All metrics containing metrics of this SparkPlan. */ def metrics: Map[String, SQLMetric] = Map.empty /** - * Reset all the metrics. + * Resets all the metrics. */ def resetMetrics(): Unit = { metrics.valuesIterator.foreach(_.reset()) } /** - * Return a LongSQLMetric according to the name. + * @return [[SQLMetric]] for the `name`. */ def longMetric(name: String): SQLMetric = metrics(name) @@ -128,7 +128,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ } /** - * Execute a query after preparing the query and adding query plan information to created RDDs + * Executes a query after preparing the query and adding query plan information to created RDDs * for visualization. */ protected final def executeQuery[T](query: => T): T = { @@ -176,7 +176,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ private var prepared = false /** - * Prepare a SparkPlan for execution. It's idempotent. + * Prepares this SparkPlan for execution. It's idempotent. */ final def prepare(): Unit = { // doPrepare() may depend on it's children, we should call prepare() on all the children first. @@ -195,22 +195,24 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ * `execute` of SparkPlan. This is helpful if we want to set up some state before executing the * query, e.g., `BroadcastHashJoin` uses it to broadcast asynchronously. * - * Note: the prepare method has already walked down the tree, so the implementation doesn't need - * to call children's prepare methods. + * @note `prepare` method has already walked down the tree, so the implementation doesn't have + * to call children's `prepare` methods. * * This will only be called once, protected by `this`. */ protected def doPrepare(): Unit = {} /** + * Produces the result of the query as an `RDD[InternalRow]` + * * Overridden by concrete implementations of SparkPlan. - * Produces the result of the query as an RDD[InternalRow] */ protected def doExecute(): RDD[InternalRow] /** - * Overridden by concrete implementations of SparkPlan. * Produces the result of the query as a broadcast variable. + * + * Overridden by concrete implementations of SparkPlan. */ protected[sql] def doExecuteBroadcast[T](): broadcast.Broadcast[T] = { throw new UnsupportedOperationException(s"$nodeName does not implement doExecuteBroadcast") @@ -245,7 +247,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ } /** - * Decode the byte arrays back to UnsafeRows and put them into buffer. + * Decodes the byte arrays back to UnsafeRows and put them into buffer. */ private def decodeUnsafeRows(bytes: Array[Byte]): Iterator[InternalRow] = { val nFields = schema.length @@ -284,7 +286,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ /** * Runs this query returning the result as an iterator of InternalRow. * - * Note: this will trigger multiple jobs (one for each partition). + * @note Triggers multiple jobs (one for each partition). */ def executeToIterator(): Iterator[InternalRow] = { getByteArrayRdd().toLocalIterator.flatMap(decodeUnsafeRows) @@ -301,7 +303,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ /** * Runs this query returning the first `n` rows as an array. * - * This is modeled after RDD.take but never runs any job locally on the driver. + * This is modeled after `RDD.take` but never runs any job locally on the driver. */ def executeTake(n: Int): Array[InternalRow] = { if (n == 0) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala index c1e1a631c677e..ac30b11557adb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala @@ -70,7 +70,7 @@ trait CodegenSupport extends SparkPlan { /** * Returns all the RDDs of InternalRow which generates the input rows. * - * Note: right now we support up to two RDDs. + * @note Right now we support up to two RDDs */ def inputRDDs(): Seq[RDD[InternalRow]] @@ -227,7 +227,7 @@ trait CodegenSupport extends SparkPlan { /** - * InputAdapter is used to hide a SparkPlan from a subtree that support codegen. + * InputAdapter is used to hide a SparkPlan from a subtree that supports codegen. * * This is the leaf node of a tree with WholeStageCodegen that is used to generate code * that consumes an RDD iterator of InternalRow. @@ -282,10 +282,10 @@ object WholeStageCodegenExec { } /** - * WholeStageCodegen compile a subtree of plans that support codegen together into single Java + * WholeStageCodegen compiles a subtree of plans that support codegen together into single Java * function. * - * Here is the call graph of to generate Java source (plan A support codegen, but plan B does not): + * Here is the call graph of to generate Java source (plan A supports codegen, but plan B does not): * * WholeStageCodegen Plan A FakeInput Plan B * ========================================================================= @@ -304,10 +304,10 @@ object WholeStageCodegenExec { * | * doConsume() <-------- consume() * - * SparkPlan A should override doProduce() and doConsume(). + * SparkPlan A should override `doProduce()` and `doConsume()`. * - * doCodeGen() will create a CodeGenContext, which will hold a list of variables for input, - * used to generated code for BoundReference. + * `doCodeGen()` will create a `CodeGenContext`, which will hold a list of variables for input, + * used to generated code for [[BoundReference]]. */ case class WholeStageCodegenExec(child: SparkPlan) extends UnaryExecNode with CodegenSupport { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/AggregateProcessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/AggregateProcessor.scala index c9f5d3b3d92d7..bc141b36e63b4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/AggregateProcessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/AggregateProcessor.scala @@ -26,17 +26,17 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._ /** * This class prepares and manages the processing of a number of [[AggregateFunction]]s within a - * single frame. The [[WindowFunctionFrame]] takes care of processing the frame in the correct way, - * this reduces the processing of a [[AggregateWindowFunction]] to processing the underlying + * single frame. The [[WindowFunctionFrame]] takes care of processing the frame in the correct way + * that reduces the processing of a [[AggregateWindowFunction]] to processing the underlying * [[AggregateFunction]]. All [[AggregateFunction]]s are processed in [[Complete]] mode. * * [[SizeBasedWindowFunction]]s are initialized in a slightly different way. These functions - * require the size of the partition processed, this value is exposed to them when the processor is - * constructed. + * require the size of the partition processed and this value is exposed to them when + * the processor is constructed. * * Processing of distinct aggregates is currently not supported. * - * The implementation is split into an object which takes care of construction, and a the actual + * The implementation is split into an object which takes care of construction, and the actual * processor class. */ private[window] object AggregateProcessor { @@ -90,7 +90,7 @@ private[window] object AggregateProcessor { updateExpressions ++= noOps evaluateExpressions += imperative case other => - sys.error(s"Unsupported Aggregate Function: $other") + sys.error(s"Unsupported aggregate function: $other") } // Create the projections. @@ -154,6 +154,7 @@ private[window] final class AggregateProcessor( } /** Evaluate buffer. */ - def evaluate(target: InternalRow): Unit = - evaluateProjection.target(target)(buffer) + def evaluate(target: InternalRow): Unit = { + evaluateProjection.target(target)(buffer) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala index 950a6794a74a3..1820cb0ef540b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala @@ -153,8 +153,8 @@ case class WindowExec( } /** - * Collection containing an entry for each window frame to process. Each entry contains a frames' - * WindowExpressions and factory function for the WindowFrameFunction. + * Collection containing an entry for each window frame to process. Each entry contains a frame's + * [[WindowExpression]]s and factory function for the WindowFrameFunction. */ private[this] lazy val windowFrameExpressionFactoryPairs = { type FrameKey = (String, FrameType, Option[Int], Option[Int]) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala index 00053485e614c..cd79128d8f375 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala @@ -170,7 +170,7 @@ object Window { * and `Window.currentRow` to specify special boundary values, rather than using integral * values directly. * - * A range based boundary is based on the actual value of the ORDER BY + * A range-based boundary is based on the actual value of the ORDER BY * expression(s). An offset is used to alter the value of the ORDER BY expression, for * instance if the current order by expression has a value of 10 and the lower bound offset * is -3, the resulting lower bound for the current row will be 10 - 3 = 7. This however puts a @@ -184,7 +184,7 @@ object Window { * val df = Seq((1, "a"), (1, "a"), (2, "a"), (1, "b"), (2, "b"), (3, "b")) * .toDF("id", "category") * val byCategoryOrderedById = - * Window.partitionBy('category).orderBy('id).rowsBetween(Window.currentRow, 1) + * Window.partitionBy('category).orderBy('id).rangeBetween(Window.currentRow, 1) * df.withColumn("sum", sum('id) over byCategoryOrderedById).show() * * +---+--------+---+ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala index 6279d48c94de5..f653890f6c7ba 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala @@ -137,7 +137,7 @@ class WindowSpec private[sql]( * and `Window.currentRow` to specify special boundary values, rather than using integral * values directly. * - * A range based boundary is based on the actual value of the ORDER BY + * A range-based boundary is based on the actual value of the ORDER BY * expression(s). An offset is used to alter the value of the ORDER BY expression, for * instance if the current order by expression has a value of 10 and the lower bound offset * is -3, the resulting lower bound for the current row will be 10 - 3 = 7. This however puts a diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index a347991d8490b..67ec1325b321e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -1266,7 +1266,7 @@ object functions { /** * Parses the expression string into the column that it represents, similar to - * DataFrame.selectExpr + * [[Dataset#selectExpr]]. * {{{ * // get the number of words of each length * df.groupBy(expr("length(word)")).count() @@ -2386,7 +2386,7 @@ object functions { def rtrim(e: Column): Column = withExpr { StringTrimRight(e.expr) } /** - * * Return the soundex code for the specified expression. + * Returns the soundex code for the specified expression. * * @group string_funcs * @since 1.5.0 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala index 2a801d87b12eb..2532b2ddb72df 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala @@ -57,7 +57,7 @@ abstract class BaseSessionStateBuilder( type NewBuilder = (SparkSession, Option[SessionState]) => BaseSessionStateBuilder /** - * Function that produces a new instance of the SessionStateBuilder. This is used by the + * Function that produces a new instance of the `BaseSessionStateBuilder`. This is used by the * [[SessionState]]'s clone functionality. Make sure to override this when implementing your own * [[SessionStateBuilder]]. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala index 1b341a12fc609..ac013ecf12ce0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala @@ -109,7 +109,7 @@ private[sql] object SessionState { } /** - * Concrete implementation of a [[SessionStateBuilder]]. + * Concrete implementation of a [[BaseSessionStateBuilder]]. */ @Experimental @InterfaceStability.Unstable diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index 86eeb2f7dd419..6057a795c8bf5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -91,7 +91,7 @@ trait RelationProvider { * * The difference between a [[RelationProvider]] and a [[SchemaRelationProvider]] is that * users need to provide a schema when using a [[SchemaRelationProvider]]. - * A relation provider can inherits both [[RelationProvider]] and [[SchemaRelationProvider]] + * A relation provider can inherit both [[RelationProvider]] and [[SchemaRelationProvider]] * if it can support both schema inference and user-specified schemas. * * @since 1.3.0 From d52f636228e833db89045bc7a0c17b72da13f138 Mon Sep 17 00:00:00 2001 From: David Eis Date: Wed, 31 May 2017 13:52:55 +0100 Subject: [PATCH 049/133] [SPARK-20790][MLLIB] Correctly handle negative values for implicit feedback in ALS ## What changes were proposed in this pull request? Revert the handling of negative values in ALS with implicit feedback, so that the confidence is the absolute value of the rating and the preference is 0 for negative ratings. This was the original behavior. ## How was this patch tested? This patch was tested with the existing unit tests and an added unit test to ensure that negative ratings are not ignored. mengxr Author: David Eis Closes #18022 from davideis/bugfix/negative-rating. --- .../apache/spark/ml/recommendation/ALS.scala | 22 ++++---- .../spark/ml/recommendation/ALSSuite.scala | 50 ++++++++++++++++++- 2 files changed, 62 insertions(+), 10 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala index 0955d3e6e1f8f..3d5fd1794de23 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala @@ -763,11 +763,15 @@ object ALS extends DefaultParamsReadable[ALS] with Logging { /** * Representing a normal equation to solve the following weighted least squares problem: * - * minimize \sum,,i,, c,,i,, (a,,i,,^T^ x - b,,i,,)^2^ + lambda * x^T^ x. + * minimize \sum,,i,, c,,i,, (a,,i,,^T^ x - d,,i,,)^2^ + lambda * x^T^ x. * * Its normal equation is given by * - * \sum,,i,, c,,i,, (a,,i,, a,,i,,^T^ x - b,,i,, a,,i,,) + lambda * x = 0. + * \sum,,i,, c,,i,, (a,,i,, a,,i,,^T^ x - d,,i,, a,,i,,) + lambda * x = 0. + * + * Distributing and letting b,,i,, = c,,i,, * d,,i,, + * + * \sum,,i,, c,,i,, a,,i,, a,,i,,^T^ x - b,,i,, a,,i,, + lambda * x = 0. */ private[recommendation] class NormalEquation(val k: Int) extends Serializable { @@ -796,7 +800,7 @@ object ALS extends DefaultParamsReadable[ALS] with Logging { copyToDouble(a) blas.dspr(upper, k, c, da, 1, ata) if (b != 0.0) { - blas.daxpy(k, c * b, da, 1, atb, 1) + blas.daxpy(k, b, da, 1, atb, 1) } this } @@ -1624,15 +1628,15 @@ object ALS extends DefaultParamsReadable[ALS] with Logging { val srcFactor = sortedSrcFactors(blockId)(localIndex) val rating = ratings(i) if (implicitPrefs) { - // Extension to the original paper to handle b < 0. confidence is a function of |b| - // instead so that it is never negative. c1 is confidence - 1.0. + // Extension to the original paper to handle rating < 0. confidence is a function + // of |rating| instead so that it is never negative. c1 is confidence - 1. val c1 = alpha * math.abs(rating) - // For rating <= 0, the corresponding preference is 0. So the term below is only added - // for rating > 0. Because YtY is already added, we need to adjust the scaling here. - if (rating > 0) { + // For rating <= 0, the corresponding preference is 0. So the second argument of add + // is only there for rating > 0. + if (rating > 0.0) { numExplicits += 1 - ls.add(srcFactor, (c1 + 1.0) / c1, c1) } + ls.add(srcFactor, if (rating > 0.0) 1.0 + c1 else 0.0, c1) } else { ls.add(srcFactor, rating) numExplicits += 1 diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala index 9d31e792633cd..701040f2d6041 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala @@ -37,6 +37,7 @@ import org.apache.spark.ml.recommendation.ALS._ import org.apache.spark.ml.recommendation.ALS.Rating import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ +import org.apache.spark.mllib.recommendation.MatrixFactorizationModelSuite import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.rdd.RDD import org.apache.spark.scheduler.{SparkListener, SparkListenerStageCompleted} @@ -78,7 +79,7 @@ class ALSSuite val k = 2 val ne0 = new NormalEquation(k) .add(Array(1.0f, 2.0f), 3.0) - .add(Array(4.0f, 5.0f), 6.0, 2.0) // weighted + .add(Array(4.0f, 5.0f), 12.0, 2.0) // weighted assert(ne0.k === k) assert(ne0.triK === k * (k + 1) / 2) // NumPy code that computes the expected values: @@ -347,6 +348,37 @@ class ALSSuite ALSSuite.genFactors(size, rank, random, a, b) } + /** + * Train ALS using the given training set and parameters + * @param training training dataset + * @param rank rank of the matrix factorization + * @param maxIter max number of iterations + * @param regParam regularization constant + * @param implicitPrefs whether to use implicit preference + * @param numUserBlocks number of user blocks + * @param numItemBlocks number of item blocks + * @return a trained ALSModel + */ + def trainALS( + training: RDD[Rating[Int]], + rank: Int, + maxIter: Int, + regParam: Double, + implicitPrefs: Boolean = false, + numUserBlocks: Int = 2, + numItemBlocks: Int = 3): ALSModel = { + val spark = this.spark + import spark.implicits._ + val als = new ALS() + .setRank(rank) + .setRegParam(regParam) + .setImplicitPrefs(implicitPrefs) + .setNumUserBlocks(numUserBlocks) + .setNumItemBlocks(numItemBlocks) + .setSeed(0) + als.fit(training.toDF()) + } + /** * Test ALS using the given training/test splits and parameters. * @param training training dataset @@ -455,6 +487,22 @@ class ALSSuite targetRMSE = 0.3) } + test("implicit feedback regression") { + val trainingWithNeg = sc.parallelize(Array(Rating(0, 0, 1), Rating(1, 1, 1), Rating(0, 1, -3))) + val trainingWithZero = sc.parallelize(Array(Rating(0, 0, 1), Rating(1, 1, 1), Rating(0, 1, 0))) + val modelWithNeg = + trainALS(trainingWithNeg, rank = 1, maxIter = 5, regParam = 0.01, implicitPrefs = true) + val modelWithZero = + trainALS(trainingWithZero, rank = 1, maxIter = 5, regParam = 0.01, implicitPrefs = true) + val userFactorsNeg = modelWithNeg.userFactors + val itemFactorsNeg = modelWithNeg.itemFactors + val userFactorsZero = modelWithZero.userFactors + val itemFactorsZero = modelWithZero.itemFactors + userFactorsNeg.collect().foreach(arr => logInfo(s"implicit test " + arr.mkString(" "))) + userFactorsZero.collect().foreach(arr => logInfo(s"implicit test " + arr.mkString(" "))) + assert(userFactorsNeg.intersect(userFactorsZero).count() == 0) + assert(itemFactorsNeg.intersect(itemFactorsZero).count() == 0) + } test("using generic ID types") { val (ratings, _) = genImplicitTestData(numUsers = 20, numItems = 40, rank = 2, noiseStd = 0.01) From ac7fc3075b7323261346fb4cd38c26f3b8f08bc2 Mon Sep 17 00:00:00 2001 From: jinxing Date: Wed, 31 May 2017 10:46:23 -0500 Subject: [PATCH 050/133] [SPARK-20288] Avoid generating the MapStatus by stageId in BasicSchedulerIntegrationSuite ## What changes were proposed in this pull request? ShuffleId is determined before job submitted. But it's hard to predict stageId by shuffleId. Stage is created in DAGScheduler( https://github.com/apache/spark/blob/master/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala#L381), but the order is n ot determined in `HashSet`. I added a log(println(s"Creating ShufflMapStage-$id on shuffle-${shuffleDep.shuffleId}")) after (https://github.com/apache/spark/blob/master/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala#L331), when testing BasicSchedulerIntegrationSuite:"multi-stage job". It will print: Creating ShufflMapStage-0 on shuffle-0 Creating ShufflMapStage-1 on shuffle-2 Creating ShufflMapStage-2 on shuffle-1 Creating ShufflMapStage-3 on shuffle-3 or Creating ShufflMapStage-0 on shuffle-1 Creating ShufflMapStage-1 on shuffle-3 Creating ShufflMapStage-2 on shuffle-0 Creating ShufflMapStage-3 on shuffle-2 It might be better to avoid generating the MapStatus by stageId. Author: jinxing Closes #17603 from jinxing64/SPARK-20288. --- .../spark/scheduler/SchedulerIntegrationSuite.scala | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/scheduler/SchedulerIntegrationSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SchedulerIntegrationSuite.scala index 37b08980db877..a8249e123fa00 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/SchedulerIntegrationSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/SchedulerIntegrationSuite.scala @@ -553,10 +553,10 @@ class BasicSchedulerIntegrationSuite extends SchedulerIntegrationSuite[SingleCor */ testScheduler("multi-stage job") { - def stageToOutputParts(stageId: Int): Int = { - stageId match { + def shuffleIdToOutputParts(shuffleId: Int): Int = { + shuffleId match { case 0 => 10 - case 2 => 20 + case 1 => 20 case _ => 30 } } @@ -577,11 +577,12 @@ class BasicSchedulerIntegrationSuite extends SchedulerIntegrationSuite[SingleCor // b/c the stage numbering is non-deterministic, so stage number alone doesn't tell // us what to check } - (task.stageId, task.stageAttemptId, task.partitionId) match { case (stage, 0, _) if stage < 4 => + val shuffleId = + scheduler.stageIdToStage(stage).asInstanceOf[ShuffleMapStage].shuffleDep.shuffleId backend.taskSuccess(taskDescription, - DAGSchedulerSuite.makeMapStatus("hostA", stageToOutputParts(stage))) + DAGSchedulerSuite.makeMapStatus("hostA", shuffleIdToOutputParts(shuffleId))) case (4, 0, partition) => backend.taskSuccess(taskDescription, 4321 + partition) } From d0f36bcb10c3f424e87a6a38def0c0a3b60c03d1 Mon Sep 17 00:00:00 2001 From: Liu Shaohui Date: Wed, 31 May 2017 10:53:31 -0500 Subject: [PATCH 051/133] [SPARK-20633][SQL] FileFormatWriter should not wrap FetchFailedException ## What changes were proposed in this pull request? Explicitly handle the FetchFailedException in FileFormatWriter, so it does not get wrapped. Note that this is no longer strictly necessary after SPARK-19276, but it improves error messages and also will help avoid others stumbling across this in the future. ## How was this patch tested? Existing unit tests. Closes https://github.com/apache/spark/pull/17893 Author: Liu Shaohui Closes #18145 from squito/SPARK-20633. --- .../spark/sql/execution/datasources/FileFormatWriter.scala | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala index afe454f714c47..0daffa93b4747 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala @@ -31,6 +31,7 @@ import org.apache.spark._ import org.apache.spark.internal.Logging import org.apache.spark.internal.io.{FileCommitProtocol, SparkHadoopWriterUtils} import org.apache.spark.internal.io.FileCommitProtocol.TaskCommitMessage +import org.apache.spark.shuffle.FetchFailedException import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.catalog.{BucketSpec, ExternalCatalogUtils} import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec @@ -259,8 +260,10 @@ object FileFormatWriter extends Logging { } }) } catch { + case e: FetchFailedException => + throw e case t: Throwable => - throw new SparkException("Task failed while writing rows", t) + throw new SparkException("Task failed while writing rows.", t) } } From de934e6718f86dc12ddbbcc1d174527979f0bb25 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Wed, 31 May 2017 11:38:43 -0700 Subject: [PATCH 052/133] [SPARK-19236][SQL][FOLLOW-UP] Added createOrReplaceGlobalTempView method ### What changes were proposed in this pull request? This PR does the following tasks: - Added since - Added the Python API - Added test cases ### How was this patch tested? Added test cases to both Scala and Python Author: gatorsmile Closes #18147 from gatorsmile/createOrReplaceGlobalTempView. --- python/pyspark/sql/dataframe.py | 17 ++++++ .../scala/org/apache/spark/sql/Dataset.scala | 1 + .../sql/execution/GlobalTempViewSuite.scala | 60 +++++++++++-------- 3 files changed, 52 insertions(+), 26 deletions(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index fbe66f18a3613..8d8b9384783e6 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -191,6 +191,23 @@ def createGlobalTempView(self, name): """ self._jdf.createGlobalTempView(name) + @since(2.2) + def createOrReplaceGlobalTempView(self, name): + """Creates or replaces a global temporary view using the given name. + + The lifetime of this temporary view is tied to this Spark application. + + >>> df.createOrReplaceGlobalTempView("people") + >>> df2 = df.filter(df.age > 3) + >>> df2.createOrReplaceGlobalTempView("people") + >>> df3 = spark.sql("select * from global_temp.people") + >>> sorted(df3.collect()) == sorted(df2.collect()) + True + >>> spark.catalog.dropGlobalTempView("people") + + """ + self._jdf.createOrReplaceGlobalTempView(name) + @property @since(1.4) def write(self): diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 5ffe32f61ee09..a9e487f464948 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -2786,6 +2786,7 @@ class Dataset[T] private[sql]( * view, e.g. `SELECT * FROM _global_temp.view1`. * * @group basic + * @since 2.2.0 */ def createOrReplaceGlobalTempView(viewName: String): Unit = withPlan { createTempViewCommand(viewName, replace = true, global = true) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/GlobalTempViewSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/GlobalTempViewSuite.scala index 5c63c6a414f93..a3d75b221ec3e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/GlobalTempViewSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/GlobalTempViewSuite.scala @@ -35,39 +35,47 @@ class GlobalTempViewSuite extends QueryTest with SharedSQLContext { private var globalTempDB: String = _ test("basic semantic") { - sql("CREATE GLOBAL TEMP VIEW src AS SELECT 1, 'a'") + try { + sql("CREATE GLOBAL TEMP VIEW src AS SELECT 1, 'a'") + + // If there is no database in table name, we should try local temp view first, if not found, + // try table/view in current database, which is "default" in this case. So we expect + // NoSuchTableException here. + intercept[NoSuchTableException](spark.table("src")) - // If there is no database in table name, we should try local temp view first, if not found, - // try table/view in current database, which is "default" in this case. So we expect - // NoSuchTableException here. - intercept[NoSuchTableException](spark.table("src")) + // Use qualified name to refer to the global temp view explicitly. + checkAnswer(spark.table(s"$globalTempDB.src"), Row(1, "a")) - // Use qualified name to refer to the global temp view explicitly. - checkAnswer(spark.table(s"$globalTempDB.src"), Row(1, "a")) + // Table name without database will never refer to a global temp view. + intercept[NoSuchTableException](sql("DROP VIEW src")) - // Table name without database will never refer to a global temp view. - intercept[NoSuchTableException](sql("DROP VIEW src")) + sql(s"DROP VIEW $globalTempDB.src") + // The global temp view should be dropped successfully. + intercept[NoSuchTableException](spark.table(s"$globalTempDB.src")) - sql(s"DROP VIEW $globalTempDB.src") - // The global temp view should be dropped successfully. - intercept[NoSuchTableException](spark.table(s"$globalTempDB.src")) + // We can also use Dataset API to create global temp view + Seq(1 -> "a").toDF("i", "j").createGlobalTempView("src") + checkAnswer(spark.table(s"$globalTempDB.src"), Row(1, "a")) - // We can also use Dataset API to create global temp view - Seq(1 -> "a").toDF("i", "j").createGlobalTempView("src") - checkAnswer(spark.table(s"$globalTempDB.src"), Row(1, "a")) + // Use qualified name to rename a global temp view. + sql(s"ALTER VIEW $globalTempDB.src RENAME TO src2") + intercept[NoSuchTableException](spark.table(s"$globalTempDB.src")) + checkAnswer(spark.table(s"$globalTempDB.src2"), Row(1, "a")) - // Use qualified name to rename a global temp view. - sql(s"ALTER VIEW $globalTempDB.src RENAME TO src2") - intercept[NoSuchTableException](spark.table(s"$globalTempDB.src")) - checkAnswer(spark.table(s"$globalTempDB.src2"), Row(1, "a")) + // Use qualified name to alter a global temp view. + sql(s"ALTER VIEW $globalTempDB.src2 AS SELECT 2, 'b'") + checkAnswer(spark.table(s"$globalTempDB.src2"), Row(2, "b")) - // Use qualified name to alter a global temp view. - sql(s"ALTER VIEW $globalTempDB.src2 AS SELECT 2, 'b'") - checkAnswer(spark.table(s"$globalTempDB.src2"), Row(2, "b")) + // We can also use Catalog API to drop global temp view + spark.catalog.dropGlobalTempView("src2") + intercept[NoSuchTableException](spark.table(s"$globalTempDB.src2")) - // We can also use Catalog API to drop global temp view - spark.catalog.dropGlobalTempView("src2") - intercept[NoSuchTableException](spark.table(s"$globalTempDB.src2")) + // We can also use Dataset API to replace global temp view + Seq(2 -> "b").toDF("i", "j").createOrReplaceGlobalTempView("src") + checkAnswer(spark.table(s"$globalTempDB.src"), Row(2, "b")) + } finally { + spark.catalog.dropGlobalTempView("src") + } } test("global temp view is shared among all sessions") { @@ -106,7 +114,7 @@ class GlobalTempViewSuite extends QueryTest with SharedSQLContext { test("CREATE TABLE LIKE should work for global temp view") { try { sql("CREATE GLOBAL TEMP VIEW src AS SELECT 1 AS a, '2' AS b") - sql(s"CREATE TABLE cloned LIKE ${globalTempDB}.src") + sql(s"CREATE TABLE cloned LIKE $globalTempDB.src") val tableMeta = spark.sessionState.catalog.getTableMetadata(TableIdentifier("cloned")) assert(tableMeta.schema == new StructType().add("a", "int", false).add("b", "string", false)) } finally { From 2bc3272880515649d3e10eba135831a2ed0e3465 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Wed, 31 May 2017 17:24:37 -0700 Subject: [PATCH 053/133] [SPARK-20894][SS] Resolve the checkpoint location in driver and use the resolved path in state store ## What changes were proposed in this pull request? When the user runs a Structured Streaming query in a cluster, if the driver uses the local file system, StateStore running in executors will throw a file-not-found exception. However, the current error is not obvious. This PR makes StreamExecution resolve the path in driver and uses the full path including the scheme part (such as `hdfs:/`, `file:/`) in StateStore. Then if the above error happens, StateStore will throw an error with this full path which starts with `file:/`, and it makes this error obvious: the checkpoint location is on the local file system. One potential minor issue is that the user cannot use different default file system settings in driver and executors (e.g., use a public HDFS address in driver and a private HDFS address in executors) after this change. However, since the batch query also has this issue (See https://github.com/apache/spark/blob/4bb6a53ebd06de3de97139a2dbc7c85fc3aa3e66/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala#L402), it doesn't make things worse. ## How was this patch tested? The new added test. Author: Shixiong Zhu Closes #18149 from zsxwing/SPARK-20894. --- .../execution/streaming/StreamExecution.scala | 16 +++++++++++----- .../spark/sql/streaming/StreamSuite.scala | 19 +++++++++++++++++++ .../sql/streaming/StreamingQuerySuite.scala | 4 ++-- .../test/DataStreamReaderWriterSuite.scala | 8 ++++---- 4 files changed, 36 insertions(+), 11 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index ab8608563c4fb..74f0f509bbf85 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -58,7 +58,7 @@ case object TERMINATED extends State class StreamExecution( override val sparkSession: SparkSession, override val name: String, - val checkpointRoot: String, + private val checkpointRoot: String, analyzedPlan: LogicalPlan, val sink: Sink, val trigger: Trigger, @@ -84,6 +84,12 @@ class StreamExecution( private val startLatch = new CountDownLatch(1) private val terminationLatch = new CountDownLatch(1) + val resolvedCheckpointRoot = { + val checkpointPath = new Path(checkpointRoot) + val fs = checkpointPath.getFileSystem(sparkSession.sessionState.newHadoopConf()) + checkpointPath.makeQualified(fs.getUri(), fs.getWorkingDirectory()).toUri.toString + } + /** * Tracks how much data we have processed and committed to the sink or state store from each * input source. @@ -154,7 +160,7 @@ class StreamExecution( case streamingRelation@StreamingRelation(dataSource, _, output) => toExecutionRelationMap.getOrElseUpdate(streamingRelation, { // Materialize source to avoid creating it in every batch - val metadataPath = s"$checkpointRoot/sources/$nextSourceId" + val metadataPath = s"$resolvedCheckpointRoot/sources/$nextSourceId" val source = dataSource.createSource(metadataPath) nextSourceId += 1 // We still need to use the previous `output` instead of `source.schema` as attributes in @@ -233,14 +239,14 @@ class StreamExecution( /** Returns the path of a file with `name` in the checkpoint directory. */ private def checkpointFile(name: String): String = - new Path(new Path(checkpointRoot), name).toUri.toString + new Path(new Path(resolvedCheckpointRoot), name).toUri.toString /** * Starts the execution. This returns only after the thread has started and [[QueryStartedEvent]] * has been posted to all the listeners. */ def start(): Unit = { - logInfo(s"Starting $prettyIdString. Use $checkpointRoot to store the query checkpoint.") + logInfo(s"Starting $prettyIdString. Use $resolvedCheckpointRoot to store the query checkpoint.") microBatchThread.setDaemon(true) microBatchThread.start() startLatch.await() // Wait until thread started and QueryStart event has been posted @@ -374,7 +380,7 @@ class StreamExecution( // Delete the temp checkpoint only when the query didn't fail if (deleteCheckpointOnStop && exception.isEmpty) { - val checkpointPath = new Path(checkpointRoot) + val checkpointPath = new Path(resolvedCheckpointRoot) try { val fs = checkpointPath.getFileSystem(sparkSession.sessionState.newHadoopConf()) fs.delete(checkpointPath, true) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala index 280f2dc27b4a7..4ede4fd9a035e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala @@ -617,6 +617,25 @@ class StreamSuite extends StreamTest { query.stop() } + test("should resolve the checkpoint path") { + withTempDir { dir => + val checkpointLocation = dir.getCanonicalPath + assert(!checkpointLocation.startsWith("file:/")) + val query = MemoryStream[Int].toDF + .writeStream + .option("checkpointLocation", checkpointLocation) + .format("console") + .start() + try { + val resolvedCheckpointDir = + query.asInstanceOf[StreamingQueryWrapper].streamingQuery.resolvedCheckpointRoot + assert(resolvedCheckpointDir.startsWith("file:/")) + } finally { + query.stop() + } + } + } + testQuietly("specify custom state store provider") { val queryName = "memStream" val providerClassName = classOf[TestStateStoreProvider].getCanonicalName diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala index b69536ed37463..0925646beb869 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala @@ -466,7 +466,7 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi CheckAnswer(6, 3, 6, 3, 1, 1), AssertOnQuery("metadata log should contain only two files") { q => - val metadataLogDir = new java.io.File(q.offsetLog.metadataPath.toString) + val metadataLogDir = new java.io.File(q.offsetLog.metadataPath.toUri) val logFileNames = metadataLogDir.listFiles().toSeq.map(_.getName()) val toTest = logFileNames.filter(!_.endsWith(".crc")).sorted // Workaround for SPARK-17475 assert(toTest.size == 2 && toTest.head == "1") @@ -492,7 +492,7 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi CheckAnswer(1, 2, 1, 2, 3, 4, 5, 6, 7, 8), AssertOnQuery("metadata log should contain three files") { q => - val metadataLogDir = new java.io.File(q.offsetLog.metadataPath.toString) + val metadataLogDir = new java.io.File(q.offsetLog.metadataPath.toUri) val logFileNames = metadataLogDir.listFiles().toSeq.map(_.getName()) val toTest = logFileNames.filter(!_.endsWith(".crc")).sorted // Workaround for SPARK-17475 assert(toTest.size == 3 && toTest.head == "2") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala index dc2506a48ad00..b5f1e28d7396a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala @@ -378,14 +378,14 @@ class DataStreamReaderWriterSuite extends StreamTest with BeforeAndAfter { verify(LastOptions.mockStreamSourceProvider).createSource( any(), - meq(s"$checkpointLocationURI/sources/0"), + meq(s"${makeQualifiedPath(checkpointLocationURI.toString)}/sources/0"), meq(None), meq("org.apache.spark.sql.streaming.test"), meq(Map.empty)) verify(LastOptions.mockStreamSourceProvider).createSource( any(), - meq(s"$checkpointLocationURI/sources/1"), + meq(s"${makeQualifiedPath(checkpointLocationURI.toString)}/sources/1"), meq(None), meq("org.apache.spark.sql.streaming.test"), meq(Map.empty)) @@ -642,7 +642,7 @@ class DataStreamReaderWriterSuite extends StreamTest with BeforeAndAfter { import testImplicits._ val query = MemoryStream[Int].toDS.writeStream.format("console").start() val checkpointDir = new Path( - query.asInstanceOf[StreamingQueryWrapper].streamingQuery.checkpointRoot) + query.asInstanceOf[StreamingQueryWrapper].streamingQuery.resolvedCheckpointRoot) val fs = checkpointDir.getFileSystem(spark.sessionState.newHadoopConf()) assert(fs.exists(checkpointDir)) query.stop() @@ -654,7 +654,7 @@ class DataStreamReaderWriterSuite extends StreamTest with BeforeAndAfter { val input = MemoryStream[Int] val query = input.toDS.map(_ / 0).writeStream.format("console").start() val checkpointDir = new Path( - query.asInstanceOf[StreamingQueryWrapper].streamingQuery.checkpointRoot) + query.asInstanceOf[StreamingQueryWrapper].streamingQuery.resolvedCheckpointRoot) val fs = checkpointDir.getFileSystem(spark.sessionState.newHadoopConf()) assert(fs.exists(checkpointDir)) input.addData(1) From 24db35826a81960f08e3eb68556b0f51781144e1 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Wed, 31 May 2017 17:26:18 -0700 Subject: [PATCH 054/133] [SPARK-20940][CORE] Replace IllegalAccessError with IllegalStateException ## What changes were proposed in this pull request? `IllegalAccessError` is a fatal error (a subclass of LinkageError) and its meaning is `Thrown if an application attempts to access or modify a field, or to call a method that it does not have access to`. Throwing a fatal error for AccumulatorV2 is not necessary and is pretty bad because it usually will just kill executors or SparkContext ([SPARK-20666](https://issues.apache.org/jira/browse/SPARK-20666) is an example of killing SparkContext due to `IllegalAccessError`). I think the correct type of exception in AccumulatorV2 should be `IllegalStateException`. ## How was this patch tested? Jenkins Author: Shixiong Zhu Closes #18168 from zsxwing/SPARK-20940. --- core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala | 4 ++-- core/src/test/scala/org/apache/spark/AccumulatorSuite.scala | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala b/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala index 1a9a6929541aa..603c23abb6895 100644 --- a/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala +++ b/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala @@ -68,7 +68,7 @@ abstract class AccumulatorV2[IN, OUT] extends Serializable { private def assertMetadataNotNull(): Unit = { if (metadata == null) { - throw new IllegalAccessError("The metadata of this accumulator has not been assigned yet.") + throw new IllegalStateException("The metadata of this accumulator has not been assigned yet.") } } @@ -265,7 +265,7 @@ private[spark] object AccumulatorContext { // Since we are storing weak references, we must check whether the underlying data is valid. val acc = ref.get if (acc eq null) { - throw new IllegalAccessError(s"Attempted to access garbage collected accumulator $id") + throw new IllegalStateException(s"Attempted to access garbage collected accumulator $id") } acc } diff --git a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala index ddbcb2d19dcbb..3990ee1ec326d 100644 --- a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala @@ -210,7 +210,7 @@ class AccumulatorSuite extends SparkFunSuite with Matchers with LocalSparkContex assert(ref.get.isEmpty) // Getting a garbage collected accum should throw error - intercept[IllegalAccessError] { + intercept[IllegalStateException] { AccumulatorContext.get(accId) } From 5854f77ce1d3b9491e2a6bd1f352459da294e369 Mon Sep 17 00:00:00 2001 From: jerryshao Date: Wed, 31 May 2017 22:34:53 -0700 Subject: [PATCH 055/133] [SPARK-20244][CORE] Handle incorrect bytesRead metrics when using PySpark ## What changes were proposed in this pull request? Hadoop FileSystem's statistics in based on thread local variables, this is ok if the RDD computation chain is running in the same thread. But if child RDD creates another thread to consume the iterator got from Hadoop RDDs, the bytesRead computation will be error, because now the iterator's `next()` and `close()` may run in different threads. This could be happened when using PySpark with PythonRDD. So here building a map to track the `bytesRead` for different thread and add them together. This method will be used in three RDDs, `HadoopRDD`, `NewHadoopRDD` and `FileScanRDD`. I assume `FileScanRDD` cannot be called directly, so I only fixed `HadoopRDD` and `NewHadoopRDD`. ## How was this patch tested? Unit test and local cluster verification. Author: jerryshao Closes #17617 from jerryshao/SPARK-20244. --- .../apache/spark/deploy/SparkHadoopUtil.scala | 28 +++++++++++++---- .../org/apache/spark/rdd/HadoopRDD.scala | 8 ++++- .../org/apache/spark/rdd/NewHadoopRDD.scala | 8 ++++- .../metrics/InputOutputMetricsSuite.scala | 31 ++++++++++++++++++- 4 files changed, 66 insertions(+), 9 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala index 9cc321af4bde2..6afe58bff5229 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala @@ -23,6 +23,7 @@ import java.text.DateFormat import java.util.{Arrays, Comparator, Date, Locale} import scala.collection.JavaConverters._ +import scala.collection.mutable import scala.util.control.NonFatal import com.google.common.primitives.Longs @@ -143,14 +144,29 @@ class SparkHadoopUtil extends Logging { * Returns a function that can be called to find Hadoop FileSystem bytes read. If * getFSBytesReadOnThreadCallback is called from thread r at time t, the returned callback will * return the bytes read on r since t. - * - * @return None if the required method can't be found. */ private[spark] def getFSBytesReadOnThreadCallback(): () => Long = { - val threadStats = FileSystem.getAllStatistics.asScala.map(_.getThreadStatistics) - val f = () => threadStats.map(_.getBytesRead).sum - val baselineBytesRead = f() - () => f() - baselineBytesRead + val f = () => FileSystem.getAllStatistics.asScala.map(_.getThreadStatistics.getBytesRead).sum + val baseline = (Thread.currentThread().getId, f()) + + /** + * This function may be called in both spawned child threads and parent task thread (in + * PythonRDD), and Hadoop FileSystem uses thread local variables to track the statistics. + * So we need a map to track the bytes read from the child threads and parent thread, + * summing them together to get the bytes read of this task. + */ + new Function0[Long] { + private val bytesReadMap = new mutable.HashMap[Long, Long]() + + override def apply(): Long = { + bytesReadMap.synchronized { + bytesReadMap.put(Thread.currentThread().getId, f()) + bytesReadMap.map { case (k, v) => + v - (if (k == baseline._1) baseline._2 else 0) + }.sum + } + } + } } /** diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala index 4bf8ecc383542..76ea8b86c53d2 100644 --- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala @@ -251,7 +251,13 @@ class HadoopRDD[K, V]( null } // Register an on-task-completion callback to close the input stream. - context.addTaskCompletionListener{ context => closeIfNeeded() } + context.addTaskCompletionListener { context => + // Update the bytes read before closing is to make sure lingering bytesRead statistics in + // this thread get correctly added. + updateBytesRead() + closeIfNeeded() + } + private val key: K = if (reader == null) null.asInstanceOf[K] else reader.createKey() private val value: V = if (reader == null) null.asInstanceOf[V] else reader.createValue() diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala index ce3a9a2a1e2a8..482875e6c1ac5 100644 --- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala @@ -191,7 +191,13 @@ class NewHadoopRDD[K, V]( } // Register an on-task-completion callback to close the input stream. - context.addTaskCompletionListener(context => close()) + context.addTaskCompletionListener { context => + // Update the bytesRead before closing is to make sure lingering bytesRead statistics in + // this thread get correctly added. + updateBytesRead() + close() + } + private var havePair = false private var recordsSinceMetricsUpdate = 0 diff --git a/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala b/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala index 5d522189a0c29..6f4203da1d866 100644 --- a/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala +++ b/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala @@ -34,7 +34,7 @@ import org.scalatest.BeforeAndAfter import org.apache.spark.{SharedSparkContext, SparkFunSuite} import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd} -import org.apache.spark.util.Utils +import org.apache.spark.util.{ThreadUtils, Utils} class InputOutputMetricsSuite extends SparkFunSuite with SharedSparkContext with BeforeAndAfter { @@ -319,6 +319,35 @@ class InputOutputMetricsSuite extends SparkFunSuite with SharedSparkContext } assert(bytesRead >= tmpFile.length()) } + + test("input metrics with old Hadoop API in different thread") { + val bytesRead = runAndReturnBytesRead { + sc.textFile(tmpFilePath, 4).mapPartitions { iter => + val buf = new ArrayBuffer[String]() + ThreadUtils.runInNewThread("testThread", false) { + iter.flatMap(_.split(" ")).foreach(buf.append(_)) + } + + buf.iterator + }.count() + } + assert(bytesRead >= tmpFile.length()) + } + + test("input metrics with new Hadoop API in different thread") { + val bytesRead = runAndReturnBytesRead { + sc.newAPIHadoopFile(tmpFilePath, classOf[NewTextInputFormat], classOf[LongWritable], + classOf[Text]).mapPartitions { iter => + val buf = new ArrayBuffer[String]() + ThreadUtils.runInNewThread("testThread", false) { + iter.map(_._2.toString).flatMap(_.split(" ")).foreach(buf.append(_)) + } + + buf.iterator + }.count() + } + assert(bytesRead >= tmpFile.length()) + } } /** From 34661d8a5acbeecae9b034a2a6a737f16d8738bb Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Wed, 31 May 2017 22:39:25 -0700 Subject: [PATCH 056/133] [SPARK-20708][CORE] Make `addExclusionRules` up-to-date ## What changes were proposed in this pull request? Since [SPARK-9263](https://issues.apache.org/jira/browse/SPARK-9263), `resolveMavenCoordinates` ignores Spark and Spark's dependencies by using `addExclusionRules`. This PR aims to make [addExclusionRules](https://github.com/apache/spark/blob/master/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala#L956-L974) up-to-date to neglect correctly because it fails to neglect some components like the following. **mllib (correct)** ``` $ bin/spark-shell --packages org.apache.spark:spark-mllib_2.11:2.1.1 ... --------------------------------------------------------------------- | | modules || artifacts | | conf | number| search|dwnlded|evicted|| number|dwnlded| --------------------------------------------------------------------- | default | 0 | 0 | 0 | 0 || 0 | 0 | --------------------------------------------------------------------- ``` **mllib-local (wrong)** ``` $ bin/spark-shell --packages org.apache.spark:spark-mllib-local_2.11:2.1.1 ... --------------------------------------------------------------------- | | modules || artifacts | | conf | number| search|dwnlded|evicted|| number|dwnlded| --------------------------------------------------------------------- | default | 15 | 2 | 2 | 0 || 15 | 2 | --------------------------------------------------------------------- ``` ## How was this patch tested? Pass the Jenkins with a updated test case. Author: Dongjoon Hyun Closes #17947 from dongjoon-hyun/SPARK-20708. --- .../org/apache/spark/deploy/SparkSubmit.scala | 17 ++++++++++------- .../spark/deploy/SparkSubmitUtilsSuite.scala | 9 +++------ 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index c60a2a1706d5a..d13fb4193970b 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -879,6 +879,15 @@ private[spark] object SparkSubmitUtils { // Exposed for testing var printStream = SparkSubmit.printStream + // Exposed for testing. + // These components are used to make the default exclusion rules for Spark dependencies. + // We need to specify each component explicitly, otherwise we miss spark-streaming-kafka-0-8 and + // other spark-streaming utility components. Underscore is there to differentiate between + // spark-streaming_2.1x and spark-streaming-kafka-0-8-assembly_2.1x + val IVY_DEFAULT_EXCLUDES = Seq("catalyst_", "core_", "graphx_", "launcher_", "mllib_", + "mllib-local_", "network-common_", "network-shuffle_", "repl_", "sketch_", "sql_", "streaming_", + "tags_", "unsafe_") + /** * Represents a Maven Coordinate * @param groupId the groupId of the coordinate @@ -1007,13 +1016,7 @@ private[spark] object SparkSubmitUtils { // Add scala exclusion rule md.addExcludeRule(createExclusion("*:scala-library:*", ivySettings, ivyConfName)) - // We need to specify each component explicitly, otherwise we miss spark-streaming-kafka-0-8 and - // other spark-streaming utility components. Underscore is there to differentiate between - // spark-streaming_2.1x and spark-streaming-kafka-0-8-assembly_2.1x - val components = Seq("catalyst_", "core_", "graphx_", "hive_", "mllib_", "repl_", - "sql_", "streaming_", "yarn_", "network-common_", "network-shuffle_", "network-yarn_") - - components.foreach { comp => + IVY_DEFAULT_EXCLUDES.foreach { comp => md.addExcludeRule(createExclusion(s"org.apache.spark:spark-$comp*:*", ivySettings, ivyConfName)) } diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala index 266c9d33b5a96..57024786b95e3 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala @@ -187,12 +187,9 @@ class SparkSubmitUtilsSuite extends SparkFunSuite with BeforeAndAfterAll { } test("neglects Spark and Spark's dependencies") { - val components = Seq("catalyst_", "core_", "graphx_", "hive_", "mllib_", "repl_", - "sql_", "streaming_", "yarn_", "network-common_", "network-shuffle_", "network-yarn_") - - val coordinates = - components.map(comp => s"org.apache.spark:spark-${comp}2.10:1.2.0").mkString(",") + - ",org.apache.spark:spark-core_fake:1.2.0" + val coordinates = SparkSubmitUtils.IVY_DEFAULT_EXCLUDES + .map(comp => s"org.apache.spark:spark-${comp}2.11:2.1.1") + .mkString(",") + ",org.apache.spark:spark-core_fake:1.2.0" val path = SparkSubmitUtils.resolveMavenCoordinates( coordinates, From c8045f8b482e347eccf2583e0952e1d8bcb6cb96 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Wed, 31 May 2017 23:17:15 -0700 Subject: [PATCH 057/133] [MINOR][SQL] Fix a few function description error. ## What changes were proposed in this pull request? Fix a few function description error. ## How was this patch tested? manual tests ![descissues](https://cloud.githubusercontent.com/assets/5399861/26619392/d547736c-4610-11e7-85d7-aeeb09c02cc8.gif) Author: Yuming Wang Closes #18157 from wangyum/DescIssues. --- .../apache/spark/sql/catalyst/expressions/mathExpressions.scala | 2 +- .../spark/sql/catalyst/expressions/stringExpressions.scala | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala index 7b64568c69659..615256243ae2a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala @@ -283,7 +283,7 @@ case class Cosh(child: Expression) extends UnaryMathExpression(math.cosh, "COSH" > SELECT _FUNC_('100', 2, 10); 4 > SELECT _FUNC_(-10, 16, -10); - 16 + -16 """) case class Conv(numExpr: Expression, fromBaseExpr: Expression, toBaseExpr: Expression) extends TernaryExpression with ImplicitCastInputTypes { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index cc4d465c5d701..035a1afe8b782 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -1047,7 +1047,7 @@ case class FormatString(children: Expression*) extends Expression with ImplicitC """, extended = """ Examples: - > SELECT initcap('sPark sql'); + > SELECT _FUNC_('sPark sql'); Spark Sql """) case class InitCap(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { From 6d05c1c1da9104c903099a6790b39427c867ed2b Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Thu, 1 Jun 2017 16:15:24 +0900 Subject: [PATCH 058/133] [SPARK-20910][SQL] Add build-in SQL function - UUID ## What changes were proposed in this pull request? Add build-int SQL function - UUID. ## How was this patch tested? unit tests Author: Yuming Wang Closes #18136 from wangyum/SPARK-20910. --- .../catalyst/analysis/FunctionRegistry.scala | 1 + .../spark/sql/catalyst/expressions/misc.scala | 28 +++++++++++++++++++ .../expressions/MiscExpressionsSuite.scala | 5 ++++ .../sql-tests/inputs/string-functions.sql | 3 ++ .../results/string-functions.sql.out | 10 ++++++- 5 files changed, 46 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 8081036bed8a6..116b26f612e02 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -380,6 +380,7 @@ object FunctionRegistry { expression[AssertTrue]("assert_true"), expression[Crc32]("crc32"), expression[Md5]("md5"), + expression[Uuid]("uuid"), expression[Murmur3Hash]("hash"), expression[Sha1]("sha"), expression[Sha1]("sha1"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index bb9368cf6d774..3fc4bb7041636 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -17,9 +17,12 @@ package org.apache.spark.sql.catalyst.expressions +import java.util.UUID + import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String /** * Print the result of an expression to stderr (used for debugging codegen). @@ -104,3 +107,28 @@ case class CurrentDatabase() extends LeafExpression with Unevaluable { override def nullable: Boolean = false override def prettyName: String = "current_database" } + +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_() - Returns an universally unique identifier (UUID) string. The value is returned as a canonical UUID 36-character string.", + extended = """ + Examples: + > SELECT _FUNC_(); + 46707d92-02f4-4817-8116-a4c3b23e6266 + """) +// scalastyle:on line.size.limit +case class Uuid() extends LeafExpression { + + override def deterministic: Boolean = false + + override def nullable: Boolean = false + + override def dataType: DataType = StringType + + override def eval(input: InternalRow): Any = UTF8String.fromString(UUID.randomUUID().toString) + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + ev.copy(code = s"final UTF8String ${ev.value} = " + + s"UTF8String.fromString(java.util.UUID.randomUUID().toString());", isNull = "false") + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala index a26d070a99c52..4fe7b436982b1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala @@ -39,4 +39,9 @@ class MiscExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(AssertTrue(Cast(Literal(1), BooleanType)), null) } + test("uuid") { + checkEvaluation(Length(Uuid()), 36) + assert(evaluate(Uuid()) !== evaluate(Uuid())) + } + } diff --git a/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql index e6dcea4972c18..d82df11251c5b 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql @@ -12,3 +12,6 @@ FROM (SELECT id col1, id col2, id col3, id col4 FROM range(10)) t; -- replace function select replace('abc', 'b', '123'); select replace('abc', 'b'); + +-- uuid +select length(uuid()), (uuid() <> uuid()); diff --git a/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out index abf0cc44d6e42..4093a7b9fc820 100644 --- a/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 6 +-- Number of queries: 7 -- !query 0 @@ -70,3 +70,11 @@ select replace('abc', 'b') struct -- !query 5 output ac + + +-- !query 6 +select length(uuid()), (uuid() <> uuid()) +-- !query 6 schema +struct +-- !query 6 output +36 true From 0975019cd475f990585b1b436d25373501cfadf7 Mon Sep 17 00:00:00 2001 From: John Compitello Date: Thu, 1 Jun 2017 05:42:42 -0400 Subject: [PATCH 059/133] [SPARK-20109][MLLIB] Rewrote toBlockMatrix method on IndexedRowMatrix ## What changes were proposed in this pull request? - ~~I added the method `toBlockMatrixDense` to the IndexedRowMatrix class. The current implementation of `toBlockMatrix` is insufficient for users with relatively dense IndexedRowMatrix objects, since it assumes sparsity.~~ EDIT: Ended up deciding that there should be just a single `toBlockMatrix` method, which creates a BlockMatrix whose blocks may be dense or sparse depending on the sparsity of the rows. This method will work better on any current use case of `toBlockMatrix` and doesn't go through `CoordinateMatrix` like the old method. ## How was this patch tested? ~~I used the same tests already written for `toBlockMatrix()` to test this method. I also added a new additional unit test for an edge case that was not adequately tested by current test suite.~~ I ran the original `IndexedRowMatrix` tests, plus wrote more to better handle edge cases ignored by original tests. Author: John Compitello Closes #17459 from johnc1231/johnc-fix-ir-to-block. --- .../linalg/distributed/CoordinateMatrix.scala | 7 ++ .../linalg/distributed/IndexedRowMatrix.scala | 70 +++++++++++++- .../distributed/IndexedRowMatrixSuite.scala | 91 +++++++++++++++++-- 3 files changed, 157 insertions(+), 11 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrix.scala index 26ca1ef9be870..0d223de9b6f7e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrix.scala @@ -125,6 +125,13 @@ class CoordinateMatrix @Since("1.0.0") ( s"colsPerBlock needs to be greater than 0. colsPerBlock: $colsPerBlock") val m = numRows() val n = numCols() + + // Since block matrices require an integer row and col index + require(math.ceil(m.toDouble / rowsPerBlock) <= Int.MaxValue, + "Number of rows divided by rowsPerBlock cannot exceed maximum integer.") + require(math.ceil(n.toDouble / colsPerBlock) <= Int.MaxValue, + "Number of cols divided by colsPerBlock cannot exceed maximum integer.") + val numRowBlocks = math.ceil(m.toDouble / rowsPerBlock).toInt val numColBlocks = math.ceil(n.toDouble / colsPerBlock).toInt val partitioner = GridPartitioner(numRowBlocks, numColBlocks, entries.partitions.length) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala index d7255d527f036..8890662d99b52 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala @@ -91,7 +91,7 @@ class IndexedRowMatrix @Since("1.0.0") ( } /** - * Converts to BlockMatrix. Creates blocks of `SparseMatrix` with size 1024 x 1024. + * Converts to BlockMatrix. Creates blocks with size 1024 x 1024. */ @Since("1.3.0") def toBlockMatrix(): BlockMatrix = { @@ -99,7 +99,7 @@ class IndexedRowMatrix @Since("1.0.0") ( } /** - * Converts to BlockMatrix. Creates blocks of `SparseMatrix`. + * Converts to BlockMatrix. Blocks may be sparse or dense depending on the sparsity of the rows. * @param rowsPerBlock The number of rows of each block. The blocks at the bottom edge may have * a smaller value. Must be an integer value greater than 0. * @param colsPerBlock The number of columns of each block. The blocks at the right edge may have @@ -108,8 +108,70 @@ class IndexedRowMatrix @Since("1.0.0") ( */ @Since("1.3.0") def toBlockMatrix(rowsPerBlock: Int, colsPerBlock: Int): BlockMatrix = { - // TODO: This implementation may be optimized - toCoordinateMatrix().toBlockMatrix(rowsPerBlock, colsPerBlock) + require(rowsPerBlock > 0, + s"rowsPerBlock needs to be greater than 0. rowsPerBlock: $rowsPerBlock") + require(colsPerBlock > 0, + s"colsPerBlock needs to be greater than 0. colsPerBlock: $colsPerBlock") + + val m = numRows() + val n = numCols() + + // Since block matrices require an integer row index + require(math.ceil(m.toDouble / rowsPerBlock) <= Int.MaxValue, + "Number of rows divided by rowsPerBlock cannot exceed maximum integer.") + + // The remainder calculations only matter when m % rowsPerBlock != 0 or n % colsPerBlock != 0 + val remainderRowBlockIndex = m / rowsPerBlock + val remainderColBlockIndex = n / colsPerBlock + val remainderRowBlockSize = (m % rowsPerBlock).toInt + val remainderColBlockSize = (n % colsPerBlock).toInt + val numRowBlocks = math.ceil(m.toDouble / rowsPerBlock).toInt + val numColBlocks = math.ceil(n.toDouble / colsPerBlock).toInt + + val blocks = rows.flatMap { ir: IndexedRow => + val blockRow = ir.index / rowsPerBlock + val rowInBlock = ir.index % rowsPerBlock + + ir.vector match { + case SparseVector(size, indices, values) => + indices.zip(values).map { case (index, value) => + val blockColumn = index / colsPerBlock + val columnInBlock = index % colsPerBlock + ((blockRow.toInt, blockColumn.toInt), (rowInBlock.toInt, Array((value, columnInBlock)))) + } + case DenseVector(values) => + values.grouped(colsPerBlock) + .zipWithIndex + .map { case (values, blockColumn) => + ((blockRow.toInt, blockColumn), (rowInBlock.toInt, values.zipWithIndex)) + } + } + }.groupByKey(GridPartitioner(numRowBlocks, numColBlocks, rows.getNumPartitions)).map { + case ((blockRow, blockColumn), itr) => + val actualNumRows = + if (blockRow == remainderRowBlockIndex) remainderRowBlockSize else rowsPerBlock + val actualNumColumns = + if (blockColumn == remainderColBlockIndex) remainderColBlockSize else colsPerBlock + + val arraySize = actualNumRows * actualNumColumns + val matrixAsArray = new Array[Double](arraySize) + var countForValues = 0 + itr.foreach { case (rowWithinBlock, valuesWithColumns) => + valuesWithColumns.foreach { case (value, columnWithinBlock) => + matrixAsArray.update(columnWithinBlock * actualNumRows + rowWithinBlock, value) + countForValues += 1 + } + } + val denseMatrix = new DenseMatrix(actualNumRows, actualNumColumns, matrixAsArray) + val finalMatrix = if (countForValues / arraySize.toDouble >= 0.1) { + denseMatrix + } else { + denseMatrix.toSparse + } + + ((blockRow, blockColumn), finalMatrix) + } + new BlockMatrix(blocks, rowsPerBlock, colsPerBlock, m, n) } /** diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala index 99af5fa10d999..566ce95be084a 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.mllib.linalg.distributed import breeze.linalg.{diag => brzDiag, DenseMatrix => BDM, DenseVector => BDV} import org.apache.spark.SparkFunSuite -import org.apache.spark.mllib.linalg.{Matrices, Vectors} +import org.apache.spark.mllib.linalg._ import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.rdd.RDD @@ -87,19 +87,96 @@ class IndexedRowMatrixSuite extends SparkFunSuite with MLlibTestSparkContext { assert(coordMat.toBreeze() === idxRowMat.toBreeze()) } - test("toBlockMatrix") { - val idxRowMat = new IndexedRowMatrix(indexedRows) - val blockMat = idxRowMat.toBlockMatrix(2, 2) + test("toBlockMatrix dense backing") { + val idxRowMatDense = new IndexedRowMatrix(indexedRows) + + // Tests when n % colsPerBlock != 0 + val blockMat = idxRowMatDense.toBlockMatrix(2, 2) assert(blockMat.numRows() === m) assert(blockMat.numCols() === n) - assert(blockMat.toBreeze() === idxRowMat.toBreeze()) + assert(blockMat.toBreeze() === idxRowMatDense.toBreeze()) + + // Tests when m % rowsPerBlock != 0 + val blockMat2 = idxRowMatDense.toBlockMatrix(3, 1) + assert(blockMat2.numRows() === m) + assert(blockMat2.numCols() === n) + assert(blockMat2.toBreeze() === idxRowMatDense.toBreeze()) intercept[IllegalArgumentException] { - idxRowMat.toBlockMatrix(-1, 2) + idxRowMatDense.toBlockMatrix(-1, 2) } intercept[IllegalArgumentException] { - idxRowMat.toBlockMatrix(2, 0) + idxRowMatDense.toBlockMatrix(2, 0) } + + assert(blockMat.blocks.map { case (_, matrix: Matrix) => + matrix.isInstanceOf[DenseMatrix] + }.reduce(_ && _)) + assert(blockMat2.blocks.map { case (_, matrix: Matrix) => + matrix.isInstanceOf[DenseMatrix] + }.reduce(_ && _)) + } + + test("toBlockMatrix sparse backing") { + val sparseData = Seq( + (15L, Vectors.sparse(12, Seq((0, 4.0)))) + ).map(x => IndexedRow(x._1, x._2)) + + // Gonna make m and n larger here so the matrices can easily be completely sparse: + val m = 16 + val n = 12 + + val idxRowMatSparse = new IndexedRowMatrix(sc.parallelize(sparseData)) + + // Tests when n % colsPerBlock != 0 + val blockMat = idxRowMatSparse.toBlockMatrix(8, 8) + assert(blockMat.numRows() === m) + assert(blockMat.numCols() === n) + assert(blockMat.toBreeze() === idxRowMatSparse.toBreeze()) + + // Tests when m % rowsPerBlock != 0 + val blockMat2 = idxRowMatSparse.toBlockMatrix(6, 6) + assert(blockMat2.numRows() === m) + assert(blockMat2.numCols() === n) + assert(blockMat2.toBreeze() === idxRowMatSparse.toBreeze()) + + assert(blockMat.blocks.collect().forall{ case (_, matrix: Matrix) => + matrix.isInstanceOf[SparseMatrix] + }) + assert(blockMat2.blocks.collect().forall{ case (_, matrix: Matrix) => + matrix.isInstanceOf[SparseMatrix] + }) + } + + test("toBlockMatrix mixed backing") { + val m = 24 + val n = 18 + + val mixedData = Seq( + (0L, Vectors.dense((0 to 17).map(_.toDouble).toArray)), + (1L, Vectors.dense((0 to 17).map(_.toDouble).toArray)), + (23L, Vectors.sparse(18, Seq((0, 4.0))))) + .map(x => IndexedRow(x._1, x._2)) + + val idxRowMatMixed = new IndexedRowMatrix( + sc.parallelize(mixedData)) + + // Tests when n % colsPerBlock != 0 + val blockMat = idxRowMatMixed.toBlockMatrix(12, 12) + assert(blockMat.numRows() === m) + assert(blockMat.numCols() === n) + assert(blockMat.toBreeze() === idxRowMatMixed.toBreeze()) + + // Tests when m % rowsPerBlock != 0 + val blockMat2 = idxRowMatMixed.toBlockMatrix(18, 6) + assert(blockMat2.numRows() === m) + assert(blockMat2.numCols() === n) + assert(blockMat2.toBreeze() === idxRowMatMixed.toBreeze()) + + val blocks = blockMat.blocks.collect() + + assert(blocks.forall { case((row, col), matrix) => + if (row == 0) matrix.isInstanceOf[DenseMatrix] else matrix.isInstanceOf[SparseMatrix]}) } test("multiply a local matrix") { From f7cf2096fdecb8edab61c8973c07c6fc877ee32d Mon Sep 17 00:00:00 2001 From: Xiao Li Date: Thu, 1 Jun 2017 09:52:18 -0700 Subject: [PATCH 060/133] [SPARK-20941][SQL] Fix SubqueryExec Reuse ### What changes were proposed in this pull request? Before this PR, Subquery reuse does not work. Below are three issues: - Subquery reuse does not work. - It is sharing the same `SQLConf` (`spark.sql.exchange.reuse`) with the one for Exchange Reuse. - No test case covers the rule Subquery reuse. This PR is to fix the above three issues. - Ignored the physical operator `SubqueryExec` when comparing two plans. - Added a dedicated conf `spark.sql.subqueries.reuse` for controlling Subquery Reuse - Added a test case for verifying the behavior ### How was this patch tested? N/A Author: Xiao Li Closes #18169 from gatorsmile/subqueryReuse. --- .../apache/spark/sql/internal/SQLConf.scala | 8 +++++ .../execution/basicPhysicalOperators.scala | 3 ++ .../apache/spark/sql/execution/subquery.scala | 2 +- .../org/apache/spark/sql/SQLQuerySuite.scala | 35 +++++++++++++++++++ 4 files changed, 47 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index c6f5cf641b8d5..1739b0cfa2761 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -552,6 +552,12 @@ object SQLConf { .booleanConf .createWithDefault(true) + val SUBQUERY_REUSE_ENABLED = buildConf("spark.sql.subquery.reuse") + .internal() + .doc("When true, the planner will try to find out duplicated subqueries and re-use them.") + .booleanConf + .createWithDefault(true) + val STATE_STORE_PROVIDER_CLASS = buildConf("spark.sql.streaming.stateStore.providerClass") .internal() @@ -932,6 +938,8 @@ class SQLConf extends Serializable with Logging { def exchangeReuseEnabled: Boolean = getConf(EXCHANGE_REUSE_ENABLED) + def subqueryReuseEnabled: Boolean = getConf(SUBQUERY_REUSE_ENABLED) + def caseSensitiveAnalysis: Boolean = getConf(SQLConf.CASE_SENSITIVE) def constraintPropagationEnabled: Boolean = getConf(CONSTRAINT_PROPAGATION_ENABLED) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala index 85096dcc40f5d..f69a688555bbf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala @@ -595,6 +595,9 @@ case class OutputFakerExec(output: Seq[Attribute], child: SparkPlan) extends Spa */ case class SubqueryExec(name: String, child: SparkPlan) extends UnaryExecNode { + // Ignore this wrapper for canonicalizing. + override lazy val canonicalized: SparkPlan = child.canonicalized + override lazy val metrics = Map( "dataSize" -> SQLMetrics.createMetric(sparkContext, "data size (bytes)"), "collectTime" -> SQLMetrics.createMetric(sparkContext, "time to collect (ms)")) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala index d11045fb6ac8c..2abeadfe45362 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala @@ -156,7 +156,7 @@ case class PlanSubqueries(sparkSession: SparkSession) extends Rule[SparkPlan] { case class ReuseSubquery(conf: SQLConf) extends Rule[SparkPlan] { def apply(plan: SparkPlan): SparkPlan = { - if (!conf.exchangeReuseEnabled) { + if (!conf.subqueryReuseEnabled) { return plan } // Build a hash map using schema of subqueries to avoid O(N*N) sameResult calls. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index b525c9e80ba42..41e9e2c92ca8e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -23,9 +23,12 @@ import java.net.{MalformedURLException, URL} import java.sql.Timestamp import java.util.concurrent.atomic.AtomicBoolean +import scala.collection.mutable.ArrayBuffer + import org.apache.spark.{AccumulatorSuite, SparkException} import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart} import org.apache.spark.sql.catalyst.util.StringUtils +import org.apache.spark.sql.execution.{ScalarSubquery, SubqueryExec} import org.apache.spark.sql.execution.aggregate import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, CartesianProductExec, SortMergeJoinExec} import org.apache.spark.sql.functions._ @@ -700,6 +703,38 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { row => Seq.fill(16)(Row.merge(row, row))).collect().toSeq) } + test("Verify spark.sql.subquery.reuse") { + Seq(true, false).foreach { reuse => + withSQLConf(SQLConf.SUBQUERY_REUSE_ENABLED.key -> reuse.toString) { + val df = sql( + """ + |SELECT key, (SELECT avg(key) FROM testData) + |FROM testData + |WHERE key > (SELECT avg(key) FROM testData) + |ORDER BY key + |LIMIT 3 + """.stripMargin) + + checkAnswer(df, Row(51, 50.5) :: Row(52, 50.5) :: Row(53, 50.5) :: Nil) + + val subqueries = ArrayBuffer.empty[SubqueryExec] + df.queryExecution.executedPlan.transformAllExpressions { + case s @ ScalarSubquery(plan: SubqueryExec, _) => + subqueries += plan + s + } + + assert(subqueries.size == 2, "Two ScalarSubquery are expected in the plan") + + if (reuse) { + assert(subqueries.distinct.size == 1, "Only one ScalarSubquery exists in the plan") + } else { + assert(subqueries.distinct.size == 2, "Reuse is not expected") + } + } + } + } + test("cartesian product join") { withSQLConf(SQLConf.CROSS_JOINS_ENABLED.key -> "true") { checkAnswer( From 640afa49aa349c7ebe35d365eec3ef9bb7710b1d Mon Sep 17 00:00:00 2001 From: Li Yichao Date: Thu, 1 Jun 2017 14:39:57 -0700 Subject: [PATCH 061/133] [SPARK-20365][YARN] Remove local scheme when add path to ClassPath. In Spark on YARN, when configuring "spark.yarn.jars" with local jars (jars started with "local" scheme), we will get inaccurate classpath for AM and containers. This is because we don't remove "local" scheme when concatenating classpath. It is OK to run because classpath is separated with ":" and java treat "local" as a separate jar. But we could improve it to remove the scheme. Updated `ClientSuite` to check "local" is not in the classpath. cc jerryshao Author: Li Yichao Author: Li Yichao Closes #18129 from liyichao/SPARK-20365. --- .../src/main/scala/org/apache/spark/deploy/yarn/Client.scala | 3 ++- .../test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala | 1 + 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index 9956071fd6e38..1fb7edf2a6e30 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -1275,7 +1275,8 @@ private object Client extends Logging { if (sparkConf.get(SPARK_ARCHIVE).isEmpty) { sparkConf.get(SPARK_JARS).foreach { jars => jars.filter(isLocalUri).foreach { jar => - addClasspathEntry(getClusterPath(sparkConf, jar), env) + val uri = new URI(jar) + addClasspathEntry(getClusterPath(sparkConf, uri.getPath()), env) } } } diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala index 3a11787aa57dc..6cf68427921fd 100644 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala @@ -122,6 +122,7 @@ class ClientSuite extends SparkFunSuite with Matchers with BeforeAndAfterAll cp should not contain (uri.getPath()) } }) + cp should not contain ("local") cp should contain(PWD) cp should contain (s"$PWD${Path.SEPARATOR}${LOCALIZED_CONF_DIR}") cp should not contain (APP_JAR) From 8efc6e986554ae66eab93cd64a9035d716adbab0 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Thu, 1 Jun 2017 14:44:34 -0700 Subject: [PATCH 062/133] [SPARK-20922][CORE] Add whitelist of classes that can be deserialized by the launcher. Blindly deserializing classes using Java serialization opens the code up to issues in other libraries, since just deserializing data from a stream may end up execution code (think readObject()). Since the launcher protocol is pretty self-contained, there's just a handful of classes it legitimately needs to deserialize, and they're in just two packages, so add a filter that throws errors if classes from any other package show up in the stream. This also maintains backwards compatibility (the updated launcher code can still communicate with the backend code in older Spark releases). Tested with new and existing unit tests. Author: Marcelo Vanzin Closes #18166 from vanzin/SPARK-20922. --- .../launcher/FilteredObjectInputStream.java | 53 +++++++++++ .../spark/launcher/LauncherConnection.java | 3 +- .../spark/launcher/LauncherServerSuite.java | 92 ++++++++++++++----- 3 files changed, 121 insertions(+), 27 deletions(-) create mode 100644 launcher/src/main/java/org/apache/spark/launcher/FilteredObjectInputStream.java diff --git a/launcher/src/main/java/org/apache/spark/launcher/FilteredObjectInputStream.java b/launcher/src/main/java/org/apache/spark/launcher/FilteredObjectInputStream.java new file mode 100644 index 0000000000000..4d254a0c4c9fe --- /dev/null +++ b/launcher/src/main/java/org/apache/spark/launcher/FilteredObjectInputStream.java @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.launcher; + +import java.io.InputStream; +import java.io.IOException; +import java.io.ObjectInputStream; +import java.io.ObjectStreamClass; +import java.util.Arrays; +import java.util.List; + +/** + * An object input stream that only allows classes used by the launcher protocol to be in the + * serialized stream. See SPARK-20922. + */ +class FilteredObjectInputStream extends ObjectInputStream { + + private static final List ALLOWED_PACKAGES = Arrays.asList( + "org.apache.spark.launcher.", + "java.lang."); + + FilteredObjectInputStream(InputStream is) throws IOException { + super(is); + } + + @Override + protected Class resolveClass(ObjectStreamClass desc) + throws IOException, ClassNotFoundException { + + boolean isValid = ALLOWED_PACKAGES.stream().anyMatch(p -> desc.getName().startsWith(p)); + if (!isValid) { + throw new IllegalArgumentException( + String.format("Unexpected class in stream: %s", desc.getName())); + } + return super.resolveClass(desc); + } + +} diff --git a/launcher/src/main/java/org/apache/spark/launcher/LauncherConnection.java b/launcher/src/main/java/org/apache/spark/launcher/LauncherConnection.java index eec264909bbb6..b4a8719e26053 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/LauncherConnection.java +++ b/launcher/src/main/java/org/apache/spark/launcher/LauncherConnection.java @@ -20,7 +20,6 @@ import java.io.Closeable; import java.io.EOFException; import java.io.IOException; -import java.io.ObjectInputStream; import java.io.ObjectOutputStream; import java.net.Socket; import java.util.logging.Level; @@ -53,7 +52,7 @@ abstract class LauncherConnection implements Closeable, Runnable { @Override public void run() { try { - ObjectInputStream in = new ObjectInputStream(socket.getInputStream()); + FilteredObjectInputStream in = new FilteredObjectInputStream(socket.getInputStream()); while (!closed) { Message msg = (Message) in.readObject(); handle(msg); diff --git a/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java b/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java index 12f1a0ce2d1b4..03c2934e2692e 100644 --- a/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java +++ b/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java @@ -19,8 +19,11 @@ import java.io.Closeable; import java.io.IOException; +import java.io.ObjectInputStream; import java.net.InetAddress; import java.net.Socket; +import java.util.Arrays; +import java.util.List; import java.util.concurrent.BlockingQueue; import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.Semaphore; @@ -120,31 +123,7 @@ public void testTimeout() throws Exception { Socket s = new Socket(InetAddress.getLoopbackAddress(), LauncherServer.getServerInstance().getPort()); client = new TestClient(s); - - // Try a few times since the client-side socket may not reflect the server-side close - // immediately. - boolean helloSent = false; - int maxTries = 10; - for (int i = 0; i < maxTries; i++) { - try { - if (!helloSent) { - client.send(new Hello(handle.getSecret(), "1.4.0")); - helloSent = true; - } else { - client.send(new SetAppId("appId")); - } - fail("Expected exception caused by connection timeout."); - } catch (IllegalStateException | IOException e) { - // Expected. - break; - } catch (AssertionError e) { - if (i < maxTries - 1) { - Thread.sleep(100); - } else { - throw new AssertionError("Test failed after " + maxTries + " attempts.", e); - } - } - } + waitForError(client, handle.getSecret()); } finally { SparkLauncher.launcherConfig.remove(SparkLauncher.CHILD_CONNECTION_TIMEOUT); kill(handle); @@ -183,6 +162,25 @@ public void infoChanged(SparkAppHandle handle) { } } + @Test + public void testStreamFiltering() throws Exception { + ChildProcAppHandle handle = LauncherServer.newAppHandle(); + TestClient client = null; + try { + Socket s = new Socket(InetAddress.getLoopbackAddress(), + LauncherServer.getServerInstance().getPort()); + + client = new TestClient(s); + client.send(new EvilPayload()); + waitForError(client, handle.getSecret()); + assertEquals(0, EvilPayload.EVIL_BIT); + } finally { + kill(handle); + close(client); + client.clientThread.join(); + } + } + private void kill(SparkAppHandle handle) { if (handle != null) { handle.kill(); @@ -199,6 +197,35 @@ private void close(Closeable c) { } } + /** + * Try a few times to get a client-side error, since the client-side socket may not reflect the + * server-side close immediately. + */ + private void waitForError(TestClient client, String secret) throws Exception { + boolean helloSent = false; + int maxTries = 10; + for (int i = 0; i < maxTries; i++) { + try { + if (!helloSent) { + client.send(new Hello(secret, "1.4.0")); + helloSent = true; + } else { + client.send(new SetAppId("appId")); + } + fail("Expected error but message went through."); + } catch (IllegalStateException | IOException e) { + // Expected. + break; + } catch (AssertionError e) { + if (i < maxTries - 1) { + Thread.sleep(100); + } else { + throw new AssertionError("Test failed after " + maxTries + " attempts.", e); + } + } + } + } + private static class TestClient extends LauncherConnection { final BlockingQueue inbound; @@ -220,4 +247,19 @@ protected void handle(Message msg) throws IOException { } + private static class EvilPayload extends LauncherProtocol.Message { + + static int EVIL_BIT = 0; + + // This field should cause the launcher server to throw an error and not deserialize the + // message. + private List notAllowedField = Arrays.asList("disallowed"); + + private void readObject(ObjectInputStream stream) throws IOException, ClassNotFoundException { + stream.defaultReadObject(); + EVIL_BIT = 1; + } + + } + } From 2134196a9c0aca82bc3e203c09e776a8bd064d65 Mon Sep 17 00:00:00 2001 From: Bogdan Raducanu Date: Thu, 1 Jun 2017 15:50:40 -0700 Subject: [PATCH 063/133] [SPARK-20854][SQL] Extend hint syntax to support expressions ## What changes were proposed in this pull request? SQL hint syntax: * support expressions such as strings, numbers, etc. instead of only identifiers as it is currently. * support multiple hints, which was missing compared to the DataFrame syntax. DataFrame API: * support any parameters in DataFrame.hint instead of just strings ## How was this patch tested? Existing tests. New tests in PlanParserSuite. New suite DataFrameHintSuite. Author: Bogdan Raducanu Closes #18086 from bogdanrdc/SPARK-20854. --- .../spark/sql/catalyst/parser/SqlBase.g4 | 6 +- .../sql/catalyst/analysis/ResolveHints.scala | 8 +- .../spark/sql/catalyst/dsl/package.scala | 3 + .../sql/catalyst/parser/AstBuilder.scala | 11 +- .../sql/catalyst/plans/logical/hints.scala | 6 +- .../sql/catalyst/analysis/DSLHintSuite.scala | 53 ++++++++++ .../sql/catalyst/parser/PlanParserSuite.scala | 100 +++++++++++++++--- .../scala/org/apache/spark/sql/Dataset.scala | 2 +- .../apache/spark/sql/DataFrameHintSuite.scala | 62 +++++++++++ 9 files changed, 225 insertions(+), 26 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DSLHintSuite.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/DataFrameHintSuite.scala diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index 4584aea6196a6..43f7ff5cb4a36 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -371,7 +371,7 @@ querySpecification (RECORDREADER recordReader=STRING)? fromClause? (WHERE where=booleanExpression)?) - | ((kind=SELECT hint? setQuantifier? namedExpressionSeq fromClause? + | ((kind=SELECT (hints+=hint)* setQuantifier? namedExpressionSeq fromClause? | fromClause (kind=SELECT setQuantifier? namedExpressionSeq)?) lateralView* (WHERE where=booleanExpression)? @@ -381,12 +381,12 @@ querySpecification ; hint - : '/*+' hintStatement '*/' + : '/*+' hintStatements+=hintStatement (','? hintStatements+=hintStatement)* '*/' ; hintStatement : hintName=identifier - | hintName=identifier '(' parameters+=identifier (',' parameters+=identifier)* ')' + | hintName=identifier '(' parameters+=primaryExpression (',' parameters+=primaryExpression)* ')' ; fromClause diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala index 86c788aaa828a..62a3482d9fac1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.analysis import java.util.Locale +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.trees.CurrentOrigin @@ -91,7 +92,12 @@ object ResolveHints { ResolvedHint(h.child, HintInfo(isBroadcastable = Option(true))) } else { // Otherwise, find within the subtree query plans that should be broadcasted. - applyBroadcastHint(h.child, h.parameters.toSet) + applyBroadcastHint(h.child, h.parameters.map { + case tableName: String => tableName + case tableId: UnresolvedAttribute => tableId.name + case unsupported => throw new AnalysisException("Broadcast hint parameter should be " + + s"an identifier or string but was $unsupported (${unsupported.getClass}") + }.toSet) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index ed423e7e334b6..beee93d906f0f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -381,6 +381,9 @@ package object dsl { def analyze: LogicalPlan = EliminateSubqueryAliases(analysis.SimpleAnalyzer.execute(logicalPlan)) + + def hint(name: String, parameters: Any*): LogicalPlan = + UnresolvedHint(name, parameters, logicalPlan) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 4eb5560155781..a16611af28a7d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -407,7 +407,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging val withWindow = withDistinct.optionalMap(windows)(withWindows) // Hint - withWindow.optionalMap(hint)(withHints) + hints.asScala.foldRight(withWindow)(withHints) } } @@ -533,13 +533,16 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging } /** - * Add a [[UnresolvedHint]] to a logical plan. + * Add [[UnresolvedHint]]s to a logical plan. */ private def withHints( ctx: HintContext, query: LogicalPlan): LogicalPlan = withOrigin(ctx) { - val stmt = ctx.hintStatement - UnresolvedHint(stmt.hintName.getText, stmt.parameters.asScala.map(_.getText), query) + var plan = query + ctx.hintStatements.asScala.reverse.foreach { case stmt => + plan = UnresolvedHint(stmt.hintName.getText, stmt.parameters.asScala.map(expression), plan) + } + plan } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala index 5fe6d2d8da064..d16fae56b3d4a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala @@ -23,9 +23,11 @@ import org.apache.spark.sql.internal.SQLConf /** * A general hint for the child that is not yet resolved. This node is generated by the parser and * should be removed This node will be eliminated post analysis. - * A pair of (name, parameters). + * @param name the name of the hint + * @param parameters the parameters of the hint + * @param child the [[LogicalPlan]] on which this hint applies */ -case class UnresolvedHint(name: String, parameters: Seq[String], child: LogicalPlan) +case class UnresolvedHint(name: String, parameters: Seq[Any], child: LogicalPlan) extends UnaryNode { override lazy val resolved: Boolean = false diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DSLHintSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DSLHintSuite.scala new file mode 100644 index 0000000000000..48a3ca2ccfb0b --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DSLHintSuite.scala @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.sql.catalyst.analysis.AnalysisTest +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.logical._ + +class DSLHintSuite extends AnalysisTest { + lazy val a = 'a.int + lazy val b = 'b.string + lazy val c = 'c.string + lazy val r1 = LocalRelation(a, b, c) + + test("various hint parameters") { + comparePlans( + r1.hint("hint1"), + UnresolvedHint("hint1", Seq(), r1) + ) + + comparePlans( + r1.hint("hint1", 1, "a"), + UnresolvedHint("hint1", Seq(1, "a"), r1) + ) + + comparePlans( + r1.hint("hint1", 1, $"a"), + UnresolvedHint("hint1", Seq(1, $"a"), r1) + ) + + comparePlans( + r1.hint("hint1", Seq(1, 2, 3), Seq($"a", $"b", $"c")), + UnresolvedHint("hint1", Seq(Seq(1, 2, 3), Seq($"a", $"b", $"c")), r1) + ) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala index 3a26adaef9db0..d004d04569772 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.parser import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} -import org.apache.spark.sql.catalyst.analysis.{UnresolvedGenerator, UnresolvedInlineTable, UnresolvedRelation, UnresolvedTableValuedFunction} +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedGenerator, UnresolvedInlineTable, UnresolvedRelation, UnresolvedTableValuedFunction} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ @@ -527,19 +527,13 @@ class PlanParserSuite extends PlanTest { val m = intercept[ParseException] { parsePlan("SELECT /*+ HINT() */ * FROM t") }.getMessage - assert(m.contains("no viable alternative at input")) - - // Hive compatibility: No database. - val m2 = intercept[ParseException] { - parsePlan("SELECT /*+ MAPJOIN(default.t) */ * from default.t") - }.getMessage - assert(m2.contains("mismatched input '.' expecting {')', ','}")) + assert(m.contains("mismatched input")) // Disallow space as the delimiter. val m3 = intercept[ParseException] { parsePlan("SELECT /*+ INDEX(a b c) */ * from default.t") }.getMessage - assert(m3.contains("mismatched input 'b' expecting {')', ','}")) + assert(m3.contains("mismatched input 'b' expecting")) comparePlans( parsePlan("SELECT /*+ HINT */ * FROM t"), @@ -547,27 +541,103 @@ class PlanParserSuite extends PlanTest { comparePlans( parsePlan("SELECT /*+ BROADCASTJOIN(u) */ * FROM t"), - UnresolvedHint("BROADCASTJOIN", Seq("u"), table("t").select(star()))) + UnresolvedHint("BROADCASTJOIN", Seq($"u"), table("t").select(star()))) comparePlans( parsePlan("SELECT /*+ MAPJOIN(u) */ * FROM t"), - UnresolvedHint("MAPJOIN", Seq("u"), table("t").select(star()))) + UnresolvedHint("MAPJOIN", Seq($"u"), table("t").select(star()))) comparePlans( parsePlan("SELECT /*+ STREAMTABLE(a,b,c) */ * FROM t"), - UnresolvedHint("STREAMTABLE", Seq("a", "b", "c"), table("t").select(star()))) + UnresolvedHint("STREAMTABLE", Seq($"a", $"b", $"c"), table("t").select(star()))) comparePlans( parsePlan("SELECT /*+ INDEX(t, emp_job_ix) */ * FROM t"), - UnresolvedHint("INDEX", Seq("t", "emp_job_ix"), table("t").select(star()))) + UnresolvedHint("INDEX", Seq($"t", $"emp_job_ix"), table("t").select(star()))) comparePlans( parsePlan("SELECT /*+ MAPJOIN(`default.t`) */ * from `default.t`"), - UnresolvedHint("MAPJOIN", Seq("default.t"), table("default.t").select(star()))) + UnresolvedHint("MAPJOIN", Seq(UnresolvedAttribute.quoted("default.t")), + table("default.t").select(star()))) comparePlans( parsePlan("SELECT /*+ MAPJOIN(t) */ a from t where true group by a order by a"), - UnresolvedHint("MAPJOIN", Seq("t"), + UnresolvedHint("MAPJOIN", Seq($"t"), table("t").where(Literal(true)).groupBy('a)('a)).orderBy('a.asc)) } + + test("SPARK-20854: select hint syntax with expressions") { + comparePlans( + parsePlan("SELECT /*+ HINT1(a, array(1, 2, 3)) */ * from t"), + UnresolvedHint("HINT1", Seq($"a", + UnresolvedFunction("array", Literal(1) :: Literal(2) :: Literal(3) :: Nil, false)), + table("t").select(star()) + ) + ) + + comparePlans( + parsePlan("SELECT /*+ HINT1(a, array(1, 2, 3)) */ * from t"), + UnresolvedHint("HINT1", Seq($"a", + UnresolvedFunction("array", Literal(1) :: Literal(2) :: Literal(3) :: Nil, false)), + table("t").select(star()) + ) + ) + + comparePlans( + parsePlan("SELECT /*+ HINT1(a, 5, 'a', b) */ * from t"), + UnresolvedHint("HINT1", Seq($"a", Literal(5), Literal("a"), $"b"), + table("t").select(star()) + ) + ) + + comparePlans( + parsePlan("SELECT /*+ HINT1('a', (b, c), (1, 2)) */ * from t"), + UnresolvedHint("HINT1", + Seq(Literal("a"), + CreateStruct($"b" :: $"c" :: Nil), + CreateStruct(Literal(1) :: Literal(2) :: Nil)), + table("t").select(star()) + ) + ) + } + + test("SPARK-20854: multiple hints") { + comparePlans( + parsePlan("SELECT /*+ HINT1(a, 1) hint2(b, 2) */ * from t"), + UnresolvedHint("HINT1", Seq($"a", Literal(1)), + UnresolvedHint("hint2", Seq($"b", Literal(2)), + table("t").select(star()) + ) + ) + ) + + comparePlans( + parsePlan("SELECT /*+ HINT1(a, 1),hint2(b, 2) */ * from t"), + UnresolvedHint("HINT1", Seq($"a", Literal(1)), + UnresolvedHint("hint2", Seq($"b", Literal(2)), + table("t").select(star()) + ) + ) + ) + + comparePlans( + parsePlan("SELECT /*+ HINT1(a, 1) */ /*+ hint2(b, 2) */ * from t"), + UnresolvedHint("HINT1", Seq($"a", Literal(1)), + UnresolvedHint("hint2", Seq($"b", Literal(2)), + table("t").select(star()) + ) + ) + ) + + comparePlans( + parsePlan("SELECT /*+ HINT1(a, 1), hint2(b, 2) */ /*+ hint3(c, 3) */ * from t"), + UnresolvedHint("HINT1", Seq($"a", Literal(1)), + UnresolvedHint("hint2", Seq($"b", Literal(2)), + UnresolvedHint("hint3", Seq($"c", Literal(3)), + table("t").select(star()) + ) + ) + ) + ) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index a9e487f464948..8abec85ee102a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -1190,7 +1190,7 @@ class Dataset[T] private[sql]( * @since 2.2.0 */ @scala.annotation.varargs - def hint(name: String, parameters: String*): Dataset[T] = withTypedPlan { + def hint(name: String, parameters: Any*): Dataset[T] = withTypedPlan { UnresolvedHint(name, parameters, logicalPlan) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameHintSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameHintSuite.scala new file mode 100644 index 0000000000000..60f6f23860ed9 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameHintSuite.scala @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.test.SharedSQLContext + +class DataFrameHintSuite extends PlanTest with SharedSQLContext { + import testImplicits._ + lazy val df = spark.range(10) + + private def check(df: Dataset[_], expected: LogicalPlan) = { + comparePlans( + df.queryExecution.logical, + expected + ) + } + + test("various hint parameters") { + check( + df.hint("hint1"), + UnresolvedHint("hint1", Seq(), + df.logicalPlan + ) + ) + + check( + df.hint("hint1", 1, "a"), + UnresolvedHint("hint1", Seq(1, "a"), df.logicalPlan) + ) + + check( + df.hint("hint1", 1, $"a"), + UnresolvedHint("hint1", Seq(1, $"a"), + df.logicalPlan + ) + ) + + check( + df.hint("hint1", Seq(1, 2, 3), Seq($"a", $"b", $"c")), + UnresolvedHint("hint1", Seq(Seq(1, 2, 3), Seq($"a", $"b", $"c")), + df.logicalPlan + ) + ) + } +} From 0e31e28d483918c1de26f78068e78c2ca3cf7f3c Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Fri, 2 Jun 2017 14:25:38 +0100 Subject: [PATCH 064/133] [MINOR][PYTHON] Ignore pep8 on test scripts generated in tests in work directory ## What changes were proposed in this pull request? Currently, if we run `./python/run-tests.py` and they are aborted without cleaning up this directory, it fails pep8 check due to some Python scripts generated. For example, https://github.com/apache/spark/blob/7387126f83dc0489eb1df734bfeba705709b7861/python/pyspark/tests.py#L1955-L1968 ``` PEP8 checks failed. ./work/app-20170531190857-0000/0/test.py:5:55: W292 no newline at end of file ./work/app-20170531190909-0000/0/test.py:5:55: W292 no newline at end of file ./work/app-20170531190924-0000/0/test.py:3:1: E302 expected 2 blank lines, found 1 ./work/app-20170531190924-0000/0/test.py:7:52: W292 no newline at end of file ./work/app-20170531191016-0000/0/test.py:5:55: W292 no newline at end of file ./work/app-20170531191030-0000/0/test.py:5:55: W292 no newline at end of file ./work/app-20170531191045-0000/0/test.py:3:1: E302 expected 2 blank lines, found 1 ./work/app-20170531191045-0000/0/test.py:7:52: W292 no newline at end of file ``` For me, it is sometimes a bit annoying. This PR proposes to exclude these (assuming we want to skip per https://github.com/apache/spark/blob/master/.gitignore#L73). Also, it moves other pep8 configurations in the script into ini configuration file in pep8. ## How was this patch tested? Manually tested via `./dev/lint-python`. Author: hyukjinkwon Closes #18161 from HyukjinKwon/work-exclude-pep8. --- dev/lint-python | 4 ++-- dev/tox.ini | 3 ++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/dev/lint-python b/dev/lint-python index c6f3fbfab84ed..07e2606d45143 100755 --- a/dev/lint-python +++ b/dev/lint-python @@ -20,7 +20,7 @@ SCRIPT_DIR="$( cd "$( dirname "$0" )" && pwd )" SPARK_ROOT_DIR="$(dirname "$SCRIPT_DIR")" # Exclude auto-geneated configuration file. -PATHS_TO_CHECK="$( cd "$SPARK_ROOT_DIR" && find . -name "*.py" -not -path "*python/docs/conf.py" )" +PATHS_TO_CHECK="$( cd "$SPARK_ROOT_DIR" && find . -name "*.py" )" PEP8_REPORT_PATH="$SPARK_ROOT_DIR/dev/pep8-report.txt" PYLINT_REPORT_PATH="$SPARK_ROOT_DIR/dev/pylint-report.txt" PYLINT_INSTALL_INFO="$SPARK_ROOT_DIR/dev/pylint-info.txt" @@ -64,7 +64,7 @@ export "PATH=$PYTHONPATH:$PATH" #+ first, but we do so so that the check status can #+ be output before the report, like with the #+ scalastyle and RAT checks. -python "$PEP8_SCRIPT_PATH" --ignore=E402,E731,E241,W503,E226 --config=dev/tox.ini $PATHS_TO_CHECK >> "$PEP8_REPORT_PATH" +python "$PEP8_SCRIPT_PATH" --config=dev/tox.ini $PATHS_TO_CHECK >> "$PEP8_REPORT_PATH" pep8_status="${PIPESTATUS[0]}" if [ "$compile_status" -eq 0 -a "$pep8_status" -eq 0 ]; then diff --git a/dev/tox.ini b/dev/tox.ini index 76e3f42cde62d..eeeb637460cfb 100644 --- a/dev/tox.ini +++ b/dev/tox.ini @@ -14,5 +14,6 @@ # limitations under the License. [pep8] +ignore=E402,E731,E241,W503,E226 max-line-length=100 -exclude=cloudpickle.py,heapq3.py,shared.py +exclude=cloudpickle.py,heapq3.py,shared.py,python/docs/conf.py,work/*/*.py From 625cebfde632361122e0db3452c4cc38147f696f Mon Sep 17 00:00:00 2001 From: guoxiaolong Date: Fri, 2 Jun 2017 14:38:00 +0100 Subject: [PATCH 065/133] [SPARK-20942][WEB-UI] The title style about field is error in the history server web ui. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? 1.The title style about field is error. fix before: ![before](https://cloud.githubusercontent.com/assets/26266482/26661987/a7bed018-46b3-11e7-8a54-a5152d2df0f4.png) fix after: ![fix](https://cloud.githubusercontent.com/assets/26266482/26662000/ba6cc814-46b3-11e7-8f33-cfd4cc2c60fe.png) ![fix1](https://cloud.githubusercontent.com/assets/26266482/26662080/3c732e3e-46b4-11e7-8768-20b5a6aeadcb.png) executor-page style: ![executor_page](https://cloud.githubusercontent.com/assets/26266482/26662384/167cbd10-46b6-11e7-9e07-bf391dbc6e08.png) 2.Title text description, 'the application' should be changed to 'this application'. 3.Analysis of code: $('#history-summary [data-toggle="tooltip"]').tooltip(); The id of 'history-summary' is not there. We only contain id of 'history-summary-table'. ## How was this patch tested? manual tests Please review http://spark.apache.org/contributing.html before opening a pull request. Author: guoxiaolong Author: 郭小龙 10207633 Author: guoxiaolongzte Closes #18170 from guoxiaolongzte/SPARK-20942. --- .../spark/ui/static/historypage-template.html | 18 +++++++++--------- .../org/apache/spark/ui/static/historypage.js | 2 +- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/core/src/main/resources/org/apache/spark/ui/static/historypage-template.html b/core/src/main/resources/org/apache/spark/ui/static/historypage-template.html index c2afa993b2f20..bfe31aae555ba 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/historypage-template.html +++ b/core/src/main/resources/org/apache/spark/ui/static/historypage-template.html @@ -20,47 +20,47 @@ - + App ID - + App Name - + Attempt ID - + Started - + Completed - + Duration - + Spark User - + Last Updated - + Event Log diff --git a/core/src/main/resources/org/apache/spark/ui/static/historypage.js b/core/src/main/resources/org/apache/spark/ui/static/historypage.js index 7db8c27e8f7c9..5ec1ce15a2127 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/historypage.js +++ b/core/src/main/resources/org/apache/spark/ui/static/historypage.js @@ -195,7 +195,7 @@ $(document).ready(function() { } $(selector).DataTable(conf); - $('#hisotry-summary [data-toggle="tooltip"]').tooltip(); + $('#history-summary [data-toggle="tooltip"]').tooltip(); }); }); }); From d1b80ab9220d83e5fdaf33c513cc811dd17d0de1 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 2 Jun 2017 09:58:01 -0700 Subject: [PATCH 066/133] [SPARK-20967][SQL] SharedState.externalCatalog is not really lazy ## What changes were proposed in this pull request? `SharedState.externalCatalog` is marked as a `lazy val` but actually it's not lazy. We access `externalCatalog` while initializing `SharedState` and thus eliminate the effort of `lazy val`. When creating `ExternalCatalog` we will try to connect to the metastore and may throw an error, so it makes sense to make it a `lazy val` in `SharedState`. ## How was this patch tested? existing tests. Author: Wenchen Fan Closes #18187 from cloud-fan/minor. --- .../spark/sql/internal/SharedState.scala | 26 +++++++++---------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala index a93b701146077..7202f1222d10f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala @@ -90,38 +90,38 @@ private[sql] class SharedState(val sparkContext: SparkContext) extends Logging { /** * A catalog that interacts with external systems. */ - lazy val externalCatalog: ExternalCatalog = - SharedState.reflect[ExternalCatalog, SparkConf, Configuration]( + lazy val externalCatalog: ExternalCatalog = { + val externalCatalog = SharedState.reflect[ExternalCatalog, SparkConf, Configuration]( SharedState.externalCatalogClassName(sparkContext.conf), sparkContext.conf, sparkContext.hadoopConfiguration) - // Create the default database if it doesn't exist. - { val defaultDbDefinition = CatalogDatabase( SessionCatalog.DEFAULT_DATABASE, "default database", CatalogUtils.stringToURI(warehousePath), Map()) - // Initialize default database if it doesn't exist + // Create default database if it doesn't exist if (!externalCatalog.databaseExists(SessionCatalog.DEFAULT_DATABASE)) { // There may be another Spark application creating default database at the same time, here we // set `ignoreIfExists = true` to avoid `DatabaseAlreadyExists` exception. externalCatalog.createDatabase(defaultDbDefinition, ignoreIfExists = true) } - } - // Make sure we propagate external catalog events to the spark listener bus - externalCatalog.addListener(new ExternalCatalogEventListener { - override def onEvent(event: ExternalCatalogEvent): Unit = { - sparkContext.listenerBus.post(event) - } - }) + // Make sure we propagate external catalog events to the spark listener bus + externalCatalog.addListener(new ExternalCatalogEventListener { + override def onEvent(event: ExternalCatalogEvent): Unit = { + sparkContext.listenerBus.post(event) + } + }) + + externalCatalog + } /** * A manager for global temporary views. */ - val globalTempViewManager: GlobalTempViewManager = { + lazy val globalTempViewManager: GlobalTempViewManager = { // System preserved database should not exists in metastore. However it's hard to guarantee it // for every session, because case-sensitivity differs. Here we always lowercase it to make our // life easier. From e11d90bf8deb553fd41b8837e3856c11486c2503 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 2 Jun 2017 10:05:05 -0700 Subject: [PATCH 067/133] [SPARK-20946][SQL] simplify the config setting logic in SparkSession.getOrCreate ## What changes were proposed in this pull request? The current conf setting logic is a little complex and has duplication, this PR simplifies it. ## How was this patch tested? existing tests. Author: Wenchen Fan Closes #18172 from cloud-fan/session. --- .../spark/ml/recommendation/ALSSuite.scala | 4 +-- .../apache/spark/ml/tree/impl/TreeTests.scala | 2 -- .../org/apache/spark/sql/SparkSession.scala | 25 +++++++------------ 3 files changed, 10 insertions(+), 21 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala index 701040f2d6041..23f22566cc8e7 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala @@ -820,15 +820,13 @@ class ALSCleanerSuite extends SparkFunSuite { FileUtils.listFiles(localDir, TrueFileFilter.INSTANCE, TrueFileFilter.INSTANCE).asScala.toSet try { conf.set("spark.local.dir", localDir.getAbsolutePath) - val sc = new SparkContext("local[2]", "test", conf) + val sc = new SparkContext("local[2]", "ALSCleanerSuite", conf) try { sc.setCheckpointDir(checkpointDir.getAbsolutePath) // Generate test data val (training, _) = ALSSuite.genImplicitTestData(sc, 20, 5, 1, 0.2, 0) // Implicitly test the cleaning of parents during ALS training val spark = SparkSession.builder - .master("local[2]") - .appName("ALSCleanerSuite") .sparkContext(sc) .getOrCreate() import spark.implicits._ diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala index 92a236928e90b..b6894b30b0c2b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala @@ -43,8 +43,6 @@ private[ml] object TreeTests extends SparkFunSuite { categoricalFeatures: Map[Int, Int], numClasses: Int): DataFrame = { val spark = SparkSession.builder() - .master("local[2]") - .appName("TreeTests") .sparkContext(data.sparkContext) .getOrCreate() import spark.implicits._ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index d2bf350711936..bf37b76087473 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -757,6 +757,8 @@ object SparkSession { private[this] var userSuppliedContext: Option[SparkContext] = None + // The `SparkConf` inside the given `SparkContext` may get changed if you specify some options + // for this builder. private[spark] def sparkContext(sparkContext: SparkContext): Builder = synchronized { userSuppliedContext = Option(sparkContext) this @@ -854,7 +856,7 @@ object SparkSession { * * @since 2.2.0 */ - def withExtensions(f: SparkSessionExtensions => Unit): Builder = { + def withExtensions(f: SparkSessionExtensions => Unit): Builder = synchronized { f(extensions) this } @@ -899,22 +901,14 @@ object SparkSession { // No active nor global default session. Create a new one. val sparkContext = userSuppliedContext.getOrElse { - // set app name if not given - val randomAppName = java.util.UUID.randomUUID().toString val sparkConf = new SparkConf() - options.foreach { case (k, v) => sparkConf.set(k, v) } - if (!sparkConf.contains("spark.app.name")) { - sparkConf.setAppName(randomAppName) - } - val sc = SparkContext.getOrCreate(sparkConf) - // maybe this is an existing SparkContext, update its SparkConf which maybe used - // by SparkSession - options.foreach { case (k, v) => sc.conf.set(k, v) } - if (!sc.conf.contains("spark.app.name")) { - sc.conf.setAppName(randomAppName) - } - sc + options.get("spark.master").foreach(sparkConf.setMaster) + // set a random app name if not given. + sparkConf.setAppName(options.getOrElse("spark.app.name", + java.util.UUID.randomUUID().toString)) + SparkContext.getOrCreate(sparkConf) } + options.foreach { case (k, v) => sparkContext.conf.set(k, v) } // Initialize extensions if the user has defined a configurator class. val extensionConfOption = sparkContext.conf.get(StaticSQLConf.SPARK_SESSION_EXTENSIONS) @@ -935,7 +929,6 @@ object SparkSession { } session = new SparkSession(sparkContext, None, None, extensions) - options.foreach { case (k, v) => session.sessionState.conf.setConfString(k, v) } defaultSession.set(session) // Register a successfully instantiated context to the singleton. This should be at the From 16186cdcbce1a2ec8f839c550e6b571bf5dc2692 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Fri, 2 Jun 2017 10:33:21 -0700 Subject: [PATCH 068/133] [SPARK-20955][CORE] Intern "executorId" to reduce the memory usage ## What changes were proposed in this pull request? In [this line](https://github.com/apache/spark/blob/f7cf2096fdecb8edab61c8973c07c6fc877ee32d/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala#L128), it uses the `executorId` string received from executors and finally it will go into `TaskUIData`. As deserializing the `executorId` string will always create a new instance, we have a lot of duplicated string instances. This PR does a String interning for TaskUIData to reduce the memory usage. ## How was this patch tested? Manually test using `bin/spark-shell --master local-cluster[6,1,1024]`. Test codes: ``` for (_ <- 1 to 10) { sc.makeRDD(1 to 1000, 1000).count() } Thread.sleep(2000) val l = sc.getClass.getMethod("jobProgressListener").invoke(sc).asInstanceOf[org.apache.spark.ui.jobs.JobProgressListener] org.apache.spark.util.SizeEstimator.estimate(l.stageIdToData) ``` This PR reduces the size of `stageIdToData` from 3487280 to 3009744 (86.3%) in the above case. Author: Shixiong Zhu Closes #18177 from zsxwing/SPARK-20955. --- .../scala/org/apache/spark/ui/jobs/UIData.scala | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala b/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala index 048c4ad0146e2..6764daa0df529 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala @@ -20,6 +20,8 @@ package org.apache.spark.ui.jobs import scala.collection.mutable import scala.collection.mutable.{HashMap, LinkedHashMap} +import com.google.common.collect.Interners + import org.apache.spark.JobExecutionStatus import org.apache.spark.executor._ import org.apache.spark.scheduler.{AccumulableInfo, TaskInfo} @@ -141,6 +143,14 @@ private[spark] object UIData { } object TaskUIData { + + private val stringInterner = Interners.newWeakInterner[String]() + + /** String interning to reduce the memory usage. */ + private def weakIntern(s: String): String = { + stringInterner.intern(s) + } + def apply(taskInfo: TaskInfo): TaskUIData = { new TaskUIData(dropInternalAndSQLAccumulables(taskInfo)) } @@ -155,8 +165,8 @@ private[spark] object UIData { index = taskInfo.index, attemptNumber = taskInfo.attemptNumber, launchTime = taskInfo.launchTime, - executorId = taskInfo.executorId, - host = taskInfo.host, + executorId = weakIntern(taskInfo.executorId), + host = weakIntern(taskInfo.host), taskLocality = taskInfo.taskLocality, speculative = taskInfo.speculative ) From 2a780ac7fe21df7c336885f8e814c1b866e04285 Mon Sep 17 00:00:00 2001 From: Xiao Li Date: Fri, 2 Jun 2017 12:58:29 -0700 Subject: [PATCH 069/133] [MINOR][SQL] Update the description of spark.sql.files.ignoreCorruptFiles and spark.sql.columnNameOfCorruptRecord ### What changes were proposed in this pull request? 1. The description of `spark.sql.files.ignoreCorruptFiles` is not accurate. When the file does not exist, we will issue the error message. ``` org.apache.spark.sql.AnalysisException: Path does not exist: file:/nonexist/path; ``` 2. `spark.sql.columnNameOfCorruptRecord` also affects the CSV format. The current description only mentions JSON format. ### How was this patch tested? N/A Author: Xiao Li Closes #18184 from gatorsmile/updateMessage. --- .../main/scala/org/apache/spark/sql/internal/SQLConf.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 1739b0cfa2761..54bee02e44e43 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -345,7 +345,8 @@ object SQLConf { .createWithDefault(true) val COLUMN_NAME_OF_CORRUPT_RECORD = buildConf("spark.sql.columnNameOfCorruptRecord") - .doc("The name of internal column for storing raw/un-parsed JSON records that fail to parse.") + .doc("The name of internal column for storing raw/un-parsed JSON and CSV records that fail " + + "to parse.") .stringConf .createWithDefault("_corrupt_record") @@ -535,8 +536,7 @@ object SQLConf { val IGNORE_CORRUPT_FILES = buildConf("spark.sql.files.ignoreCorruptFiles") .doc("Whether to ignore corrupt files. If true, the Spark jobs will continue to run when " + - "encountering corrupted or non-existing and contents that have been read will still be " + - "returned.") + "encountering corrupted files and the contents that have been read will still be returned.") .booleanConf .createWithDefault(false) From 0eb1fc6cd512f19d94758643c512cd6db036aaab Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Fri, 2 Jun 2017 15:36:21 -0700 Subject: [PATCH 070/133] Revert "[SPARK-20946][SQL] simplify the config setting logic in SparkSession.getOrCreate" This reverts commit e11d90bf8deb553fd41b8837e3856c11486c2503. --- .../spark/ml/recommendation/ALSSuite.scala | 4 ++- .../apache/spark/ml/tree/impl/TreeTests.scala | 2 ++ .../org/apache/spark/sql/SparkSession.scala | 25 ++++++++++++------- 3 files changed, 21 insertions(+), 10 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala index 23f22566cc8e7..701040f2d6041 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala @@ -820,13 +820,15 @@ class ALSCleanerSuite extends SparkFunSuite { FileUtils.listFiles(localDir, TrueFileFilter.INSTANCE, TrueFileFilter.INSTANCE).asScala.toSet try { conf.set("spark.local.dir", localDir.getAbsolutePath) - val sc = new SparkContext("local[2]", "ALSCleanerSuite", conf) + val sc = new SparkContext("local[2]", "test", conf) try { sc.setCheckpointDir(checkpointDir.getAbsolutePath) // Generate test data val (training, _) = ALSSuite.genImplicitTestData(sc, 20, 5, 1, 0.2, 0) // Implicitly test the cleaning of parents during ALS training val spark = SparkSession.builder + .master("local[2]") + .appName("ALSCleanerSuite") .sparkContext(sc) .getOrCreate() import spark.implicits._ diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala index b6894b30b0c2b..92a236928e90b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala @@ -43,6 +43,8 @@ private[ml] object TreeTests extends SparkFunSuite { categoricalFeatures: Map[Int, Int], numClasses: Int): DataFrame = { val spark = SparkSession.builder() + .master("local[2]") + .appName("TreeTests") .sparkContext(data.sparkContext) .getOrCreate() import spark.implicits._ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index bf37b76087473..d2bf350711936 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -757,8 +757,6 @@ object SparkSession { private[this] var userSuppliedContext: Option[SparkContext] = None - // The `SparkConf` inside the given `SparkContext` may get changed if you specify some options - // for this builder. private[spark] def sparkContext(sparkContext: SparkContext): Builder = synchronized { userSuppliedContext = Option(sparkContext) this @@ -856,7 +854,7 @@ object SparkSession { * * @since 2.2.0 */ - def withExtensions(f: SparkSessionExtensions => Unit): Builder = synchronized { + def withExtensions(f: SparkSessionExtensions => Unit): Builder = { f(extensions) this } @@ -901,14 +899,22 @@ object SparkSession { // No active nor global default session. Create a new one. val sparkContext = userSuppliedContext.getOrElse { + // set app name if not given + val randomAppName = java.util.UUID.randomUUID().toString val sparkConf = new SparkConf() - options.get("spark.master").foreach(sparkConf.setMaster) - // set a random app name if not given. - sparkConf.setAppName(options.getOrElse("spark.app.name", - java.util.UUID.randomUUID().toString)) - SparkContext.getOrCreate(sparkConf) + options.foreach { case (k, v) => sparkConf.set(k, v) } + if (!sparkConf.contains("spark.app.name")) { + sparkConf.setAppName(randomAppName) + } + val sc = SparkContext.getOrCreate(sparkConf) + // maybe this is an existing SparkContext, update its SparkConf which maybe used + // by SparkSession + options.foreach { case (k, v) => sc.conf.set(k, v) } + if (!sc.conf.contains("spark.app.name")) { + sc.conf.setAppName(randomAppName) + } + sc } - options.foreach { case (k, v) => sparkContext.conf.set(k, v) } // Initialize extensions if the user has defined a configurator class. val extensionConfOption = sparkContext.conf.get(StaticSQLConf.SPARK_SESSION_EXTENSIONS) @@ -929,6 +935,7 @@ object SparkSession { } session = new SparkSession(sparkContext, None, None, extensions) + options.foreach { case (k, v) => session.sessionState.conf.setConfString(k, v) } defaultSession.set(session) // Register a successfully instantiated context to the singleton. This should be at the From 6de41e951fd6172ab7d603474abded0ee7417cde Mon Sep 17 00:00:00 2001 From: Zhenhua Wang Date: Fri, 2 Jun 2017 17:36:00 -0700 Subject: [PATCH 071/133] [SPARK-17078][SQL][FOLLOWUP] Simplify explain cost command ## What changes were proposed in this pull request? Usually when using explain cost command, users want to see the stats of plan. Since stats is only showed in optimized plan, it is more direct and convenient to include only optimized plan and physical plan in the output. ## How was this patch tested? Enhanced existing test. Author: Zhenhua Wang Closes #18190 from wzhfy/simplifyExplainCost. --- .../spark/sql/execution/QueryExecution.scala | 28 +++++++++---------- .../sql/execution/command/commands.scala | 2 +- .../sql/hive/execution/HiveExplainSuite.scala | 6 ++++ 3 files changed, 21 insertions(+), 15 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index 1ba9a79446aad..34998cbd61552 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -200,11 +200,7 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) { """.stripMargin.trim } - override def toString: String = completeString(appendStats = false) - - def toStringWithStats: String = completeString(appendStats = true) - - private def completeString(appendStats: Boolean): String = { + override def toString: String = { def output = Utils.truncatedString( analyzed.output.map(o => s"${o.name}: ${o.dataType.simpleString}"), ", ") val analyzedPlan = Seq( @@ -212,25 +208,29 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) { stringOrError(analyzed.treeString(verbose = true)) ).filter(_.nonEmpty).mkString("\n") - val optimizedPlanString = if (appendStats) { - // trigger to compute stats for logical plans - optimizedPlan.stats(sparkSession.sessionState.conf) - optimizedPlan.treeString(verbose = true, addSuffix = true) - } else { - optimizedPlan.treeString(verbose = true) - } - s"""== Parsed Logical Plan == |${stringOrError(logical.treeString(verbose = true))} |== Analyzed Logical Plan == |$analyzedPlan |== Optimized Logical Plan == - |${stringOrError(optimizedPlanString)} + |${stringOrError(optimizedPlan.treeString(verbose = true))} |== Physical Plan == |${stringOrError(executedPlan.treeString(verbose = true))} """.stripMargin.trim } + def stringWithStats: String = { + // trigger to compute stats for logical plans + optimizedPlan.stats(sparkSession.sessionState.conf) + + // only show optimized logical plan and physical plan + s"""== Optimized Logical Plan == + |${stringOrError(optimizedPlan.treeString(verbose = true, addSuffix = true))} + |== Physical Plan == + |${stringOrError(executedPlan.treeString(verbose = true))} + """.stripMargin.trim + } + /** A special namespace for commands that can be used to debug query execution. */ // scalastyle:off object debug { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala index 99d81c49f1e3b..2d82fcf4da6e9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala @@ -127,7 +127,7 @@ case class ExplainCommand( } else if (extended) { queryExecution.toString } else if (cost) { - queryExecution.toStringWithStats + queryExecution.stringWithStats } else { queryExecution.simpleString } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala index aa1ca2909074f..3066a4f305f00 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala @@ -29,6 +29,12 @@ class HiveExplainSuite extends QueryTest with SQLTestUtils with TestHiveSingleto import testImplicits._ test("show cost in explain command") { + // For readability, we only show optimized plan and physical plan in explain cost command + checkKeywordsExist(sql("EXPLAIN COST SELECT * FROM src "), + "Optimized Logical Plan", "Physical Plan") + checkKeywordsNotExist(sql("EXPLAIN COST SELECT * FROM src "), + "Parsed Logical Plan", "Analyzed Logical Plan") + // Only has sizeInBytes before ANALYZE command checkKeywordsExist(sql("EXPLAIN COST SELECT * FROM src "), "sizeInBytes") checkKeywordsNotExist(sql("EXPLAIN COST SELECT * FROM src "), "rowCount") From 864d94fe879a32de324da65a844e62a0260b222d Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 2 Jun 2017 21:59:52 -0700 Subject: [PATCH 072/133] [SPARK-20974][BUILD] we should run REPL tests if SQL module has code changes ## What changes were proposed in this pull request? REPL module depends on SQL module, so we should run REPL tests if SQL module has code changes. ## How was this patch tested? N/A Author: Wenchen Fan Closes #18191 from cloud-fan/test. --- dev/run-tests.py | 2 +- dev/sparktestsupport/modules.py | 13 +++++++++++++ 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/dev/run-tests.py b/dev/run-tests.py index 818a0c9f48419..72d148d7ea0fb 100755 --- a/dev/run-tests.py +++ b/dev/run-tests.py @@ -111,7 +111,7 @@ def determine_modules_to_test(changed_modules): >>> x = [x.name for x in determine_modules_to_test([modules.sql])] >>> x # doctest: +NORMALIZE_WHITESPACE ['sql', 'hive', 'mllib', 'sql-kafka-0-10', 'examples', 'hive-thriftserver', - 'pyspark-sql', 'sparkr', 'pyspark-mllib', 'pyspark-ml'] + 'pyspark-sql', 'repl', 'sparkr', 'pyspark-mllib', 'pyspark-ml'] """ modules_to_test = set() for module in changed_modules: diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index 78b5b8b0f4b59..2971e0db40496 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -123,6 +123,7 @@ def __hash__(self): ], ) + hive = Module( name="hive", dependencies=[sql], @@ -142,6 +143,18 @@ def __hash__(self): ) +repl = Module( + name="repl", + dependencies=[hive], + source_file_regexes=[ + "repl/", + ], + sbt_test_goals=[ + "repl/test", + ], +) + + hive_thriftserver = Module( name="hive-thriftserver", dependencies=[hive], From 6cbc61d1070584ffbc34b1f53df352c9162f414a Mon Sep 17 00:00:00 2001 From: Ruben Berenguel Montoro Date: Sat, 3 Jun 2017 14:56:42 +0900 Subject: [PATCH 073/133] [SPARK-19732][SQL][PYSPARK] Add fill functions for nulls in bool fields of datasets ## What changes were proposed in this pull request? Allow fill/replace of NAs with booleans, both in Python and Scala ## How was this patch tested? Unit tests, doctests This PR is original work from me and I license this work to the Spark project Author: Ruben Berenguel Montoro Author: Ruben Berenguel Closes #18164 from rberenguel/SPARK-19732-fillna-bools. --- python/pyspark/sql/dataframe.py | 23 ++++++++++--- python/pyspark/sql/tests.py | 34 ++++++++++++++----- .../spark/sql/DataFrameNaFunctions.scala | 30 ++++++++++++++-- .../spark/sql/DataFrameNaFunctionsSuite.scala | 21 ++++++++++++ 4 files changed, 94 insertions(+), 14 deletions(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 8d8b9384783e6..99abfcc556dff 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -1289,7 +1289,7 @@ def fillna(self, value, subset=None): """Replace null values, alias for ``na.fill()``. :func:`DataFrame.fillna` and :func:`DataFrameNaFunctions.fill` are aliases of each other. - :param value: int, long, float, string, or dict. + :param value: int, long, float, string, bool or dict. Value to replace null values with. If the value is a dict, then `subset` is ignored and `value` must be a mapping from column name (string) to replacement value. The replacement value must be @@ -1309,6 +1309,15 @@ def fillna(self, value, subset=None): | 50| 50| null| +---+------+-----+ + >>> df5.na.fill(False).show() + +----+-------+-----+ + | age| name| spy| + +----+-------+-----+ + | 10| Alice|false| + | 5| Bob|false| + |null|Mallory| true| + +----+-------+-----+ + >>> df4.na.fill({'age': 50, 'name': 'unknown'}).show() +---+------+-------+ |age|height| name| @@ -1319,10 +1328,13 @@ def fillna(self, value, subset=None): | 50| null|unknown| +---+------+-------+ """ - if not isinstance(value, (float, int, long, basestring, dict)): - raise ValueError("value should be a float, int, long, string, or dict") + if not isinstance(value, (float, int, long, basestring, bool, dict)): + raise ValueError("value should be a float, int, long, string, bool or dict") + + # Note that bool validates isinstance(int), but we don't want to + # convert bools to floats - if isinstance(value, (int, long)): + if not isinstance(value, bool) and isinstance(value, (int, long)): value = float(value) if isinstance(value, dict): @@ -1819,6 +1831,9 @@ def _test(): Row(name='Bob', age=5, height=None), Row(name='Tom', age=None, height=None), Row(name=None, age=None, height=None)]).toDF() + globs['df5'] = sc.parallelize([Row(name='Alice', spy=False, age=10), + Row(name='Bob', spy=None, age=5), + Row(name='Mallory', spy=True, age=None)]).toDF() globs['sdf'] = sc.parallelize([Row(name='Tom', time=1479441846), Row(name='Bob', time=1479442946)]).toDF() diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index acea9113ee858..845e1c7619cc4 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -1697,40 +1697,58 @@ def test_fillna(self): schema = StructType([ StructField("name", StringType(), True), StructField("age", IntegerType(), True), - StructField("height", DoubleType(), True)]) + StructField("height", DoubleType(), True), + StructField("spy", BooleanType(), True)]) # fillna shouldn't change non-null values - row = self.spark.createDataFrame([(u'Alice', 10, 80.1)], schema).fillna(50).first() + row = self.spark.createDataFrame([(u'Alice', 10, 80.1, True)], schema).fillna(50).first() self.assertEqual(row.age, 10) # fillna with int - row = self.spark.createDataFrame([(u'Alice', None, None)], schema).fillna(50).first() + row = self.spark.createDataFrame([(u'Alice', None, None, None)], schema).fillna(50).first() self.assertEqual(row.age, 50) self.assertEqual(row.height, 50.0) # fillna with double - row = self.spark.createDataFrame([(u'Alice', None, None)], schema).fillna(50.1).first() + row = self.spark.createDataFrame( + [(u'Alice', None, None, None)], schema).fillna(50.1).first() self.assertEqual(row.age, 50) self.assertEqual(row.height, 50.1) + # fillna with bool + row = self.spark.createDataFrame( + [(u'Alice', None, None, None)], schema).fillna(True).first() + self.assertEqual(row.age, None) + self.assertEqual(row.spy, True) + # fillna with string - row = self.spark.createDataFrame([(None, None, None)], schema).fillna("hello").first() + row = self.spark.createDataFrame([(None, None, None, None)], schema).fillna("hello").first() self.assertEqual(row.name, u"hello") self.assertEqual(row.age, None) # fillna with subset specified for numeric cols row = self.spark.createDataFrame( - [(None, None, None)], schema).fillna(50, subset=['name', 'age']).first() + [(None, None, None, None)], schema).fillna(50, subset=['name', 'age']).first() self.assertEqual(row.name, None) self.assertEqual(row.age, 50) self.assertEqual(row.height, None) + self.assertEqual(row.spy, None) - # fillna with subset specified for numeric cols + # fillna with subset specified for string cols row = self.spark.createDataFrame( - [(None, None, None)], schema).fillna("haha", subset=['name', 'age']).first() + [(None, None, None, None)], schema).fillna("haha", subset=['name', 'age']).first() self.assertEqual(row.name, "haha") self.assertEqual(row.age, None) self.assertEqual(row.height, None) + self.assertEqual(row.spy, None) + + # fillna with subset specified for bool cols + row = self.spark.createDataFrame( + [(None, None, None, None)], schema).fillna(True, subset=['name', 'spy']).first() + self.assertEqual(row.name, None) + self.assertEqual(row.age, None) + self.assertEqual(row.height, None) + self.assertEqual(row.spy, True) # fillna with dictionary for boolean types row = self.spark.createDataFrame([Row(a=None), Row(a=True)]).fillna({"a": True}).first() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala index 052d85ad33bd6..ee949e78fa3ba 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala @@ -195,6 +195,30 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { */ def fill(value: String, cols: Seq[String]): DataFrame = fillValue(value, cols) + /** + * Returns a new `DataFrame` that replaces null values in boolean columns with `value`. + * + * @since 2.3.0 + */ + def fill(value: Boolean): DataFrame = fill(value, df.columns) + + /** + * (Scala-specific) Returns a new `DataFrame` that replaces null values in specified + * boolean columns. If a specified column is not a boolean column, it is ignored. + * + * @since 2.3.0 + */ + def fill(value: Boolean, cols: Seq[String]): DataFrame = fillValue(value, cols) + + /** + * Returns a new `DataFrame` that replaces null values in specified boolean columns. + * If a specified column is not a boolean column, it is ignored. + * + * @since 2.3.0 + */ + def fill(value: Boolean, cols: Array[String]): DataFrame = fill(value, cols.toSeq) + + /** * Returns a new `DataFrame` that replaces null values. * @@ -440,8 +464,8 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { /** * Returns a new `DataFrame` that replaces null or NaN values in specified - * numeric, string columns. If a specified column is not a numeric, string column, - * it is ignored. + * numeric, string columns. If a specified column is not a numeric, string + * or boolean column it is ignored. */ private def fillValue[T](value: T, cols: Seq[String]): DataFrame = { // the fill[T] which T is Long/Double, @@ -452,6 +476,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { val targetType = value match { case _: Double | _: Long => NumericType case _: String => StringType + case _: Boolean => BooleanType case _ => throw new IllegalArgumentException( s"Unsupported value type ${value.getClass.getName} ($value).") } @@ -461,6 +486,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { val typeMatches = (targetType, f.dataType) match { case (NumericType, dt) => dt.isInstanceOf[NumericType] case (StringType, dt) => dt == StringType + case (BooleanType, dt) => dt == BooleanType } // Only fill if the column is part of the cols list. if (typeMatches && cols.exists(col => columnEquals(f.name, col))) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala index aa237d0619ac3..e63c5cb194d68 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala @@ -104,6 +104,13 @@ class DataFrameNaFunctionsSuite extends QueryTest with SharedSQLContext { test("fill") { val input = createDF() + val boolInput = Seq[(String, java.lang.Boolean)]( + ("Bob", false), + ("Alice", null), + ("Mallory", true), + (null, null) + ).toDF("name", "spy") + val fillNumeric = input.na.fill(50.6) checkAnswer( fillNumeric, @@ -124,6 +131,12 @@ class DataFrameNaFunctionsSuite extends QueryTest with SharedSQLContext { Row("Nina") :: Row("Amy") :: Row("unknown") :: Nil) assert(input.na.fill("unknown").columns.toSeq === input.columns.toSeq) + // boolean + checkAnswer( + boolInput.na.fill(true).select("spy"), + Row(false) :: Row(true) :: Row(true) :: Row(true) :: Nil) + assert(boolInput.na.fill(true).columns.toSeq === boolInput.columns.toSeq) + // fill double with subset columns checkAnswer( input.na.fill(50.6, "age" :: Nil).select("name", "age"), @@ -134,6 +147,14 @@ class DataFrameNaFunctionsSuite extends QueryTest with SharedSQLContext { Row("Amy", 50) :: Row(null, 50) :: Nil) + // fill boolean with subset columns + checkAnswer( + boolInput.na.fill(true, "spy" :: Nil).select("name", "spy"), + Row("Bob", false) :: + Row("Alice", true) :: + Row("Mallory", true) :: + Row(null, true) :: Nil) + // fill string with subset columns checkAnswer( Seq[(String, String)]((null, null)).toDF("col1", "col2").na.fill("test", "col1" :: Nil), From 96e6ba6c2aaddd885ba6a842fbfcd73c5537f99e Mon Sep 17 00:00:00 2001 From: David Eis Date: Sat, 3 Jun 2017 09:48:10 +0100 Subject: [PATCH 074/133] [SPARK-20790][MLLIB] Remove extraneous logging in test ## What changes were proposed in this pull request? Remove extraneous logging. ## How was this patch tested? Unit tests pass. Author: David Eis Closes #18188 from davideis/fix-test. --- .../scala/org/apache/spark/ml/recommendation/ALSSuite.scala | 2 -- 1 file changed, 2 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala index 701040f2d6041..3094f52ba1bc5 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala @@ -498,8 +498,6 @@ class ALSSuite val itemFactorsNeg = modelWithNeg.itemFactors val userFactorsZero = modelWithZero.userFactors val itemFactorsZero = modelWithZero.itemFactors - userFactorsNeg.collect().foreach(arr => logInfo(s"implicit test " + arr.mkString(" "))) - userFactorsZero.collect().foreach(arr => logInfo(s"implicit test " + arr.mkString(" "))) assert(userFactorsNeg.intersect(userFactorsZero).count() == 0) assert(itemFactorsNeg.intersect(itemFactorsZero).count() == 0) } From 887cf0ec33ccf5bb7936c48cb07e21d60945bee5 Mon Sep 17 00:00:00 2001 From: zuotingbing Date: Sat, 3 Jun 2017 09:56:35 +0100 Subject: [PATCH 075/133] [SPARK-20936][CORE] Lack of an important case about the test of resolveURI in UtilsSuite, and add it as needed. ## What changes were proposed in this pull request? 1. add `assert(resolve(before) === after)` to check before and after in test of resolveURI. the function `assertResolves(before: String, after: String)` have two params, it means we should check the before value whether equals the after value which we want. e.g. the after value of Utils.resolveURI("hdfs:///root/spark.jar#app.jar").toString should be "hdfs:///root/spark.jar#app.jar" rather than "hdfs:/root/spark.jar#app.jar". we need `assert(resolve(before) === after)` to make it more safe. 2. identify the cases between resolveURI and resolveURIs. 3. delete duplicate cases and some small fix make this suit more clear. ## How was this patch tested? unit tests Author: zuotingbing Closes #18158 from zuotingbing/spark-UtilsSuite. --- .../org/apache/spark/util/UtilsSuite.scala | 21 +++++++++---------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala index 3339d5b35d3b2..f7bc8f888b0d5 100644 --- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala @@ -461,19 +461,17 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging { def assertResolves(before: String, after: String): Unit = { // This should test only single paths assume(before.split(",").length === 1) - // Repeated invocations of resolveURI should yield the same result def resolve(uri: String): String = Utils.resolveURI(uri).toString + assert(resolve(before) === after) assert(resolve(after) === after) + // Repeated invocations of resolveURI should yield the same result assert(resolve(resolve(after)) === after) assert(resolve(resolve(resolve(after))) === after) - // Also test resolveURIs with single paths - assert(new URI(Utils.resolveURIs(before)) === new URI(after)) - assert(new URI(Utils.resolveURIs(after)) === new URI(after)) } val rawCwd = System.getProperty("user.dir") val cwd = if (Utils.isWindows) s"/$rawCwd".replace("\\", "/") else rawCwd assertResolves("hdfs:/root/spark.jar", "hdfs:/root/spark.jar") - assertResolves("hdfs:///root/spark.jar#app.jar", "hdfs:/root/spark.jar#app.jar") + assertResolves("hdfs:///root/spark.jar#app.jar", "hdfs:///root/spark.jar#app.jar") assertResolves("spark.jar", s"file:$cwd/spark.jar") assertResolves("spark.jar#app.jar", s"file:$cwd/spark.jar#app.jar") assertResolves("path to/file.txt", s"file:$cwd/path%20to/file.txt") @@ -482,20 +480,19 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging { assertResolves("C:\\path to\\file.txt", "file:/C:/path%20to/file.txt") } assertResolves("file:/C:/path/to/file.txt", "file:/C:/path/to/file.txt") - assertResolves("file:///C:/path/to/file.txt", "file:/C:/path/to/file.txt") + assertResolves("file:///C:/path/to/file.txt", "file:///C:/path/to/file.txt") assertResolves("file:/C:/file.txt#alias.txt", "file:/C:/file.txt#alias.txt") - assertResolves("file:foo", s"file:foo") - assertResolves("file:foo:baby", s"file:foo:baby") + assertResolves("file:foo", "file:foo") + assertResolves("file:foo:baby", "file:foo:baby") } test("resolveURIs with multiple paths") { def assertResolves(before: String, after: String): Unit = { assume(before.split(",").length > 1) - assert(Utils.resolveURIs(before) === after) - assert(Utils.resolveURIs(after) === after) - // Repeated invocations of resolveURIs should yield the same result def resolve(uri: String): String = Utils.resolveURIs(uri) + assert(resolve(before) === after) assert(resolve(after) === after) + // Repeated invocations of resolveURIs should yield the same result assert(resolve(resolve(after)) === after) assert(resolve(resolve(resolve(after))) === after) } @@ -511,6 +508,8 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging { s"hdfs:/jar1,file:/jar2,file:$cwd/jar3,file:/C:/pi.py%23py.pi,file:/C:/path%20to/jar4") } assertResolves(",jar1,jar2", s"file:$cwd/jar1,file:$cwd/jar2") + // Also test resolveURIs with single paths + assertResolves("hdfs:/root/spark.jar", "hdfs:/root/spark.jar") } test("nonLocalPaths") { From c70c38eb930569cafe41066e99c58f735277391d Mon Sep 17 00:00:00 2001 From: Wieland Hoffmann Date: Sat, 3 Jun 2017 10:12:37 +0100 Subject: [PATCH 076/133] [DOCS] Fix a typo in Encoder.clsTag ## What changes were proposed in this pull request? Fixes a typo: `and` -> `an` ## How was this patch tested? Not at all. Author: Wieland Hoffmann Closes #17759 from mineo/patch-1. --- sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala index 68ea47cedac9a..ccdb6bc5d4b7c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala @@ -78,7 +78,7 @@ trait Encoder[T] extends Serializable { def schema: StructType /** - * A ClassTag that can be used to construct and Array to contain a collection of `T`. + * A ClassTag that can be used to construct an Array to contain a collection of `T`. */ def clsTag: ClassTag[T] } From dec9aa3b37c01454065a4d8899859991f43d4c66 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Sun, 4 Jun 2017 13:43:51 -0700 Subject: [PATCH 077/133] [SPARK-20961][SQL] generalize the dictionary in ColumnVector ## What changes were proposed in this pull request? As the first step of https://issues.apache.org/jira/browse/SPARK-20960 , to make `ColumnVector` public, this PR generalize `ColumnVector.dictionary` to not couple with parquet. ## How was this patch tested? existing tests Author: Wenchen Fan Closes #18183 from cloud-fan/dictionary. --- .../parquet/ParquetDictionary.java | 53 +++++++++++++++++++ .../parquet/VectorizedColumnReader.java | 2 +- .../VectorizedParquetRecordReader.java | 17 +++--- .../execution/vectorized/ColumnVector.java | 15 ++---- .../sql/execution/vectorized/Dictionary.java | 34 ++++++++++++ 5 files changed, 100 insertions(+), 21 deletions(-) create mode 100644 sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetDictionary.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/Dictionary.java diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetDictionary.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetDictionary.java new file mode 100644 index 0000000000000..0930edeb352dc --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetDictionary.java @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.parquet; + +import org.apache.spark.sql.execution.vectorized.Dictionary; + +public final class ParquetDictionary implements Dictionary { + private org.apache.parquet.column.Dictionary dictionary; + + public ParquetDictionary(org.apache.parquet.column.Dictionary dictionary) { + this.dictionary = dictionary; + } + + @Override + public int decodeToInt(int id) { + return dictionary.decodeToInt(id); + } + + @Override + public long decodeToLong(int id) { + return dictionary.decodeToLong(id); + } + + @Override + public float decodeToFloat(int id) { + return dictionary.decodeToFloat(id); + } + + @Override + public double decodeToDouble(int id) { + return dictionary.decodeToDouble(id); + } + + @Override + public byte[] decodeToBinary(int id) { + return dictionary.decodeToBinary(id).getBytes(); + } +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java index 9d641b528723a..fd8db1727212f 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java @@ -169,7 +169,7 @@ void readBatch(int total, ColumnVector column) throws IOException { // Column vector supports lazy decoding of dictionary values so just set the dictionary. // We can't do this if rowId != 0 AND the column doesn't have a dictionary (i.e. some // non-dictionary encoded values have already been added). - column.setDictionary(dictionary); + column.setDictionary(new ParquetDictionary(dictionary)); } else { decodeDictionaryIds(rowId, num, column, dictionaryIds); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java index 51bdf0f0f2291..04f8141d66e9d 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java @@ -154,12 +154,6 @@ public float getProgress() throws IOException, InterruptedException { return (float) rowsReturned / totalRowCount; } - /** - * Returns the ColumnarBatch object that will be used for all rows returned by this reader. - * This object is reused. Calling this enables the vectorized reader. This should be called - * before any calls to nextKeyValue/nextBatch. - */ - // Creates a columnar batch that includes the schema from the data files and the additional // partition columns appended to the end of the batch. // For example, if the data contains two columns, with 2 partition columns: @@ -204,12 +198,17 @@ public void initBatch(StructType partitionColumns, InternalRow partitionValues) initBatch(DEFAULT_MEMORY_MODE, partitionColumns, partitionValues); } + /** + * Returns the ColumnarBatch object that will be used for all rows returned by this reader. + * This object is reused. Calling this enables the vectorized reader. This should be called + * before any calls to nextKeyValue/nextBatch. + */ public ColumnarBatch resultBatch() { if (columnarBatch == null) initBatch(); return columnarBatch; } - /* + /** * Can be called before any rows are returned to enable returning columnar batches directly. */ public void enableReturningBatches() { @@ -237,9 +236,7 @@ public boolean nextBatch() throws IOException { } private void initializeInternal() throws IOException, UnsupportedOperationException { - /** - * Check that the requested schema is supported. - */ + // Check that the requested schema is supported. missingColumns = new boolean[requestedSchema.getFieldCount()]; for (int i = 0; i < requestedSchema.getFieldCount(); ++i) { Type t = requestedSchema.getFields().get(i); diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java index ad267ab0c9c47..24260a60197f2 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java @@ -20,8 +20,6 @@ import java.math.BigInteger; import com.google.common.annotations.VisibleForTesting; -import org.apache.parquet.column.Dictionary; -import org.apache.parquet.io.api.Binary; import org.apache.spark.memory.MemoryMode; import org.apache.spark.sql.catalyst.InternalRow; @@ -313,8 +311,8 @@ private void throwUnsupportedException(int requiredCapacity, Throwable cause) { } /** - * Ensures that there is enough storage to store capcity elements. That is, the put() APIs - * must work for all rowIds < capcity. + * Ensures that there is enough storage to store capacity elements. That is, the put() APIs + * must work for all rowIds < capacity. */ protected abstract void reserveInternal(int capacity); @@ -479,7 +477,6 @@ private void throwUnsupportedException(int requiredCapacity, Throwable cause) { /** * Sets values from [rowId, rowId + count) to [src + srcIndex, src + srcIndex + count) - * src should contain `count` doubles written as ieee format. */ public abstract void putFloats(int rowId, int count, float[] src, int srcIndex); @@ -506,7 +503,6 @@ private void throwUnsupportedException(int requiredCapacity, Throwable cause) { /** * Sets values from [rowId, rowId + count) to [src + srcIndex, src + srcIndex + count) - * src should contain `count` doubles written as ieee format. */ public abstract void putDoubles(int rowId, int count, double[] src, int srcIndex); @@ -628,8 +624,8 @@ public final UTF8String getUTF8String(int rowId) { ColumnVector.Array a = getByteArray(rowId); return UTF8String.fromBytes(a.byteArray, a.byteArrayOffset, a.length); } else { - Binary v = dictionary.decodeToBinary(dictionaryIds.getDictId(rowId)); - return UTF8String.fromBytes(v.getBytes()); + byte[] bytes = dictionary.decodeToBinary(dictionaryIds.getDictId(rowId)); + return UTF8String.fromBytes(bytes); } } @@ -643,8 +639,7 @@ public final byte[] getBinary(int rowId) { System.arraycopy(array.byteArray, array.byteArrayOffset, bytes, 0, bytes.length); return bytes; } else { - Binary v = dictionary.decodeToBinary(dictionaryIds.getDictId(rowId)); - return v.getBytes(); + return dictionary.decodeToBinary(dictionaryIds.getDictId(rowId)); } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/Dictionary.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/Dictionary.java new file mode 100644 index 0000000000000..c698168b4c278 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/Dictionary.java @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.vectorized; + +/** + * The interface for dictionary in ColumnVector to decode dictionary encoded values. + */ +public interface Dictionary { + + int decodeToInt(int id); + + long decodeToLong(int id); + + float decodeToFloat(int id); + + double decodeToDouble(int id); + + byte[] decodeToBinary(int id); +} From 2d39711b052b6a5b7a57951ea0d1514a8cf86c47 Mon Sep 17 00:00:00 2001 From: liupengcheng Date: Mon, 5 Jun 2017 10:23:04 +0100 Subject: [PATCH 078/133] [SPARK-20945] Fix TID key not found in TaskSchedulerImpl ## What changes were proposed in this pull request? This pull request fix the TaskScheulerImpl bug in some condition. Detail see: https://issues.apache.org/jira/browse/SPARK-20945 (Please fill in changes proposed in this fix) ## How was this patch tested? manual tests (Please explain how this patch was tested. E.g. unit tests, integration tests, manual tests) (If this patch involves UI changes, please attach a screenshot; otherwise, remove this) Please review http://spark.apache.org/contributing.html before opening a pull request. Author: liupengcheng Author: PengchengLiu Closes #18171 from liupc/Fix-tid-key-not-found-in-TaskSchedulerImpl. --- .../scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index 1b6bc9139f9c9..f3033e28b47d0 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -240,8 +240,8 @@ private[spark] class TaskSchedulerImpl private[scheduler]( // 2. The task set manager has been created but no tasks has been scheduled. In this case, // simply abort the stage. tsm.runningTasksSet.foreach { tid => - val execId = taskIdToExecutorId(tid) - backend.killTask(tid, execId, interruptThread, reason = "stage cancelled") + taskIdToExecutorId.get(tid).foreach(execId => + backend.killTask(tid, execId, interruptThread, reason = "Stage cancelled")) } tsm.abort("Stage %s cancelled".format(stageId)) logInfo("Stage %d was cancelled".format(stageId)) From 98b5ccd32b909cccc38899efa923ca425b116744 Mon Sep 17 00:00:00 2001 From: Zheng RuiFeng Date: Mon, 5 Jun 2017 10:25:09 +0100 Subject: [PATCH 079/133] [SPARK-20930][ML] Destroy broadcasted centers after computing cost in KMeans ## What changes were proposed in this pull request? Destroy broadcasted centers after computing cost ## How was this patch tested? existing tests Author: Zheng RuiFeng Closes #18152 from zhengruifeng/destroy_kmeans_model. --- .../org/apache/spark/mllib/clustering/KMeansModel.scala | 5 ++++- .../scala/org/apache/spark/mllib/clustering/LDAModel.scala | 4 ++-- .../apache/spark/mllib/optimization/GradientDescent.scala | 1 + 3 files changed, 7 insertions(+), 3 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala index df2a9c0dd5094..3ad08c46d204d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala @@ -85,7 +85,10 @@ class KMeansModel @Since("1.1.0") (@Since("1.0.0") val clusterCenters: Array[Vec @Since("0.8.0") def computeCost(data: RDD[Vector]): Double = { val bcCentersWithNorm = data.context.broadcast(clusterCentersWithNorm) - data.map(p => KMeans.pointCost(bcCentersWithNorm.value, new VectorWithNorm(p))).sum() + val cost = data + .map(p => KMeans.pointCost(bcCentersWithNorm.value, new VectorWithNorm(p))).sum() + bcCentersWithNorm.destroy(blocking = false) + cost } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala index 663f63c25a940..4ab420058f33d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala @@ -320,6 +320,7 @@ class LocalLDAModel private[spark] ( docBound }.sum() + ElogbetaBc.destroy(blocking = false) // Bound component for prob(topic-term distributions): // E[log p(beta | eta) - log q(beta | lambda)] @@ -372,7 +373,6 @@ class LocalLDAModel private[spark] ( */ private[spark] def getTopicDistributionMethod(sc: SparkContext): Vector => Vector = { val expElogbeta = exp(LDAUtils.dirichletExpectation(topicsMatrix.asBreeze.toDenseMatrix.t).t) - val expElogbetaBc = sc.broadcast(expElogbeta) val docConcentrationBrz = this.docConcentration.asBreeze val gammaShape = this.gammaShape val k = this.k @@ -383,7 +383,7 @@ class LocalLDAModel private[spark] ( } else { val (gamma, _, _) = OnlineLDAOptimizer.variationalTopicInference( termCounts, - expElogbetaBc.value, + expElogbeta, docConcentrationBrz, gammaShape, k) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala index 07a67a9e719db..593cdd602fafc 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala @@ -246,6 +246,7 @@ object GradientDescent extends Logging { // c: (grad, loss, count) (c1._1 += c2._1, c1._2 + c2._2, c1._3 + c2._3) }) + bcWeights.destroy(blocking = false) if (miniBatchSize > 0) { /** From 1665b5f724b486068a62c9c72dfd7ed76807c1b3 Mon Sep 17 00:00:00 2001 From: sethah Date: Mon, 5 Jun 2017 10:32:17 +0100 Subject: [PATCH 080/133] [SPARK-19762][ML] Hierarchy for consolidating ML aggregator/loss code ## What changes were proposed in this pull request? JIRA: [SPARK-19762](https://issues.apache.org/jira/browse/SPARK-19762) The larger changes in this patch are: * Adds a `DifferentiableLossAggregator` trait which is intended to be used as a common parent trait to all Spark ML aggregator classes. It factors out the common methods: `merge, gradient, loss, weight` from the aggregator subclasses. * Adds a `RDDLossFunction` which is intended to be the only implementation of Breeze's `DiffFunction` necessary in Spark ML, and can be used by all other algorithms. It takes the aggregator type as a type parameter, and maps the aggregator over an RDD. It additionally takes in a optional regularization loss function for applying the differentiable part of regularization. * Factors out the regularization from the data part of the cost function, and treats regularization as a separate independent cost function which can be evaluated and added to the data cost function. * Changes `LinearRegression` to use this new hierarchy as a proof of concept. * Adds the following new namespaces `o.a.s.ml.optim.loss` and `o.a.s.ml.optim.aggregator` Also note that none of these are public-facing changes. All of these classes are internal to Spark ML and remain that way. **NOTE: The large majority of the "lines added" and "lines deleted" are simply code moving around or unit tests.** BTW, I also converted LinearSVC to this framework as a way to prove that this new hierarchy is flexible enough for the other algorithms, but I backed those changes out because the PR is large enough as is. ## How was this patch tested? Test suites are added for the new components, and some test suites are also added to provide coverage where there wasn't any before. * DifferentiablLossAggregatorSuite * LeastSquaresAggregatorSuite * RDDLossFunctionSuite * DifferentiableRegularizationSuite Below are some performance testing numbers. Run on a 6 node virtual cluster with 44 cores and ~110G RAM, the dataset size is about 37G. These are not "large-scale" tests, but we really want to just make sure the iteration times don't increase with this patch. Notably we are doing the regularization a bit differently than before, but that should cost very little. I think there's very little risk otherwise, and these numbers don't show a difference. Of course I'm happy to add more tests as we think it's necessary, but I think the patch is ready for review now. **Note:** timings are best of 3 runs. | | numFeatures | numPoints | maxIter | regParam | elasticNetParam | SPARK-19762 (sec) | master (sec) | |----|---------------|-------------|-----------|------------|-------------------|---------------------|----------------| | 0 | 5000 | 1e+06 | 30 | 0 | 0 | 129.594 | 131.153 | | 1 | 5000 | 1e+06 | 30 | 0.1 | 0 | 135.54 | 136.327 | | 2 | 5000 | 1e+06 | 30 | 0.01 | 0.5 | 135.148 | 129.771 | | 3 | 50000 | 100000 | 30 | 0 | 0 | 145.764 | 144.096 | ## Follow ups If this design is accepted, we will convert the other ML algorithms that use this aggregator pattern to this new hierarchy in follow up PRs. Author: sethah Author: sethah Closes #17094 from sethah/ml_aggregators. --- .../DifferentiableLossAggregator.scala | 88 +++++ .../aggregator/LeastSquaresAggregator.scala | 224 ++++++++++++ .../loss/DifferentiableRegularization.scala | 71 ++++ .../spark/ml/optim/loss/RDDLossFunction.scala | 72 ++++ .../ml/regression/LinearRegression.scala | 327 +----------------- .../DifferentiableLossAggregatorSuite.scala | 160 +++++++++ .../LeastSquaresAggregatorSuite.scala | 157 +++++++++ .../DifferentiableRegularizationSuite.scala | 61 ++++ .../ml/optim/loss/RDDLossFunctionSuite.scala | 83 +++++ 9 files changed, 930 insertions(+), 313 deletions(-) create mode 100644 mllib/src/main/scala/org/apache/spark/ml/optim/aggregator/DifferentiableLossAggregator.scala create mode 100644 mllib/src/main/scala/org/apache/spark/ml/optim/aggregator/LeastSquaresAggregator.scala create mode 100644 mllib/src/main/scala/org/apache/spark/ml/optim/loss/DifferentiableRegularization.scala create mode 100644 mllib/src/main/scala/org/apache/spark/ml/optim/loss/RDDLossFunction.scala create mode 100644 mllib/src/test/scala/org/apache/spark/ml/optim/aggregator/DifferentiableLossAggregatorSuite.scala create mode 100644 mllib/src/test/scala/org/apache/spark/ml/optim/aggregator/LeastSquaresAggregatorSuite.scala create mode 100644 mllib/src/test/scala/org/apache/spark/ml/optim/loss/DifferentiableRegularizationSuite.scala create mode 100644 mllib/src/test/scala/org/apache/spark/ml/optim/loss/RDDLossFunctionSuite.scala diff --git a/mllib/src/main/scala/org/apache/spark/ml/optim/aggregator/DifferentiableLossAggregator.scala b/mllib/src/main/scala/org/apache/spark/ml/optim/aggregator/DifferentiableLossAggregator.scala new file mode 100644 index 0000000000000..403c28ff732f0 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/optim/aggregator/DifferentiableLossAggregator.scala @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.ml.optim.aggregator + +import org.apache.spark.ml.linalg.{BLAS, Vector, Vectors} + +/** + * A parent trait for aggregators used in fitting MLlib models. This parent trait implements + * some of the common code shared between concrete instances of aggregators. Subclasses of this + * aggregator need only implement the `add` method. + * + * @tparam Datum The type of the instances added to the aggregator to update the loss and gradient. + * @tparam Agg Specialization of [[DifferentiableLossAggregator]]. Classes that subclass this + * type need to use this parameter to specify the concrete type of the aggregator. + */ +private[ml] trait DifferentiableLossAggregator[ + Datum, + Agg <: DifferentiableLossAggregator[Datum, Agg]] extends Serializable { + + self: Agg => // enforce classes that extend this to be the same type as `Agg` + + protected var weightSum: Double = 0.0 + protected var lossSum: Double = 0.0 + + /** The dimension of the gradient array. */ + protected val dim: Int + + /** Array of gradient values that are mutated when new instances are added to the aggregator. */ + protected lazy val gradientSumArray: Array[Double] = Array.ofDim[Double](dim) + + /** Add a single data point to this aggregator. */ + def add(instance: Datum): Agg + + /** Merge two aggregators. The `this` object will be modified in place and returned. */ + def merge(other: Agg): Agg = { + require(dim == other.dim, s"Dimensions mismatch when merging with another " + + s"${getClass.getSimpleName}. Expecting $dim but got ${other.dim}.") + + if (other.weightSum != 0) { + weightSum += other.weightSum + lossSum += other.lossSum + + var i = 0 + val localThisGradientSumArray = this.gradientSumArray + val localOtherGradientSumArray = other.gradientSumArray + while (i < dim) { + localThisGradientSumArray(i) += localOtherGradientSumArray(i) + i += 1 + } + } + this + } + + /** The current weighted averaged gradient. */ + def gradient: Vector = { + require(weightSum > 0.0, s"The effective number of instances should be " + + s"greater than 0.0, but was $weightSum.") + val result = Vectors.dense(gradientSumArray.clone()) + BLAS.scal(1.0 / weightSum, result) + result + } + + /** Weighted count of instances in this aggregator. */ + def weight: Double = weightSum + + /** The current loss value of this aggregator. */ + def loss: Double = { + require(weightSum > 0.0, s"The effective number of instances should be " + + s"greater than 0.0, but was $weightSum.") + lossSum / weightSum + } + +} + diff --git a/mllib/src/main/scala/org/apache/spark/ml/optim/aggregator/LeastSquaresAggregator.scala b/mllib/src/main/scala/org/apache/spark/ml/optim/aggregator/LeastSquaresAggregator.scala new file mode 100644 index 0000000000000..1994b0e40e520 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/optim/aggregator/LeastSquaresAggregator.scala @@ -0,0 +1,224 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.ml.optim.aggregator + +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.ml.feature.Instance +import org.apache.spark.ml.linalg.{BLAS, Vector, Vectors} + +/** + * LeastSquaresAggregator computes the gradient and loss for a Least-squared loss function, + * as used in linear regression for samples in sparse or dense vector in an online fashion. + * + * Two LeastSquaresAggregator can be merged together to have a summary of loss and gradient of + * the corresponding joint dataset. + * + * For improving the convergence rate during the optimization process, and also preventing against + * features with very large variances exerting an overly large influence during model training, + * package like R's GLMNET performs the scaling to unit variance and removing the mean to reduce + * the condition number, and then trains the model in scaled space but returns the coefficients in + * the original scale. See page 9 in http://cran.r-project.org/web/packages/glmnet/glmnet.pdf + * + * However, we don't want to apply the `StandardScaler` on the training dataset, and then cache + * the standardized dataset since it will create a lot of overhead. As a result, we perform the + * scaling implicitly when we compute the objective function. The following is the mathematical + * derivation. + * + * Note that we don't deal with intercept by adding bias here, because the intercept + * can be computed using closed form after the coefficients are converged. + * See this discussion for detail. + * http://stats.stackexchange.com/questions/13617/how-is-the-intercept-computed-in-glmnet + * + * When training with intercept enabled, + * The objective function in the scaled space is given by + * + *
+ * $$ + * L = 1/2n ||\sum_i w_i(x_i - \bar{x_i}) / \hat{x_i} - (y - \bar{y}) / \hat{y}||^2, + * $$ + *
+ * + * where $\bar{x_i}$ is the mean of $x_i$, $\hat{x_i}$ is the standard deviation of $x_i$, + * $\bar{y}$ is the mean of label, and $\hat{y}$ is the standard deviation of label. + * + * If we fitting the intercept disabled (that is forced through 0.0), + * we can use the same equation except we set $\bar{y}$ and $\bar{x_i}$ to 0 instead + * of the respective means. + * + * This can be rewritten as + * + *
+ * $$ + * \begin{align} + * L &= 1/2n ||\sum_i (w_i/\hat{x_i})x_i - \sum_i (w_i/\hat{x_i})\bar{x_i} - y / \hat{y} + * + \bar{y} / \hat{y}||^2 \\ + * &= 1/2n ||\sum_i w_i^\prime x_i - y / \hat{y} + offset||^2 = 1/2n diff^2 + * \end{align} + * $$ + *
+ * + * where $w_i^\prime$ is the effective coefficients defined by $w_i/\hat{x_i}$, offset is + * + *
+ * $$ + * - \sum_i (w_i/\hat{x_i})\bar{x_i} + \bar{y} / \hat{y}. + * $$ + *
+ * + * and diff is + * + *
+ * $$ + * \sum_i w_i^\prime x_i - y / \hat{y} + offset + * $$ + *
+ * + * Note that the effective coefficients and offset don't depend on training dataset, + * so they can be precomputed. + * + * Now, the first derivative of the objective function in scaled space is + * + *
+ * $$ + * \frac{\partial L}{\partial w_i} = diff/N (x_i - \bar{x_i}) / \hat{x_i} + * $$ + *
+ * + * However, $(x_i - \bar{x_i})$ will densify the computation, so it's not + * an ideal formula when the training dataset is sparse format. + * + * This can be addressed by adding the dense $\bar{x_i} / \hat{x_i}$ terms + * in the end by keeping the sum of diff. The first derivative of total + * objective function from all the samples is + * + * + *
+ * $$ + * \begin{align} + * \frac{\partial L}{\partial w_i} &= + * 1/N \sum_j diff_j (x_{ij} - \bar{x_i}) / \hat{x_i} \\ + * &= 1/N ((\sum_j diff_j x_{ij} / \hat{x_i}) - diffSum \bar{x_i} / \hat{x_i}) \\ + * &= 1/N ((\sum_j diff_j x_{ij} / \hat{x_i}) + correction_i) + * \end{align} + * $$ + *
+ * + * where $correction_i = - diffSum \bar{x_i} / \hat{x_i}$ + * + * A simple math can show that diffSum is actually zero, so we don't even + * need to add the correction terms in the end. From the definition of diff, + * + *
+ * $$ + * \begin{align} + * diffSum &= \sum_j (\sum_i w_i(x_{ij} - \bar{x_i}) + * / \hat{x_i} - (y_j - \bar{y}) / \hat{y}) \\ + * &= N * (\sum_i w_i(\bar{x_i} - \bar{x_i}) / \hat{x_i} - (\bar{y} - \bar{y}) / \hat{y}) \\ + * &= 0 + * \end{align} + * $$ + *
+ * + * As a result, the first derivative of the total objective function only depends on + * the training dataset, which can be easily computed in distributed fashion, and is + * sparse format friendly. + * + *
+ * $$ + * \frac{\partial L}{\partial w_i} = 1/N ((\sum_j diff_j x_{ij} / \hat{x_i}) + * $$ + *
+ * + * @note The constructor is curried, since the cost function will repeatedly create new versions + * of this class for different coefficient vectors. + * + * @param labelStd The standard deviation value of the label. + * @param labelMean The mean value of the label. + * @param fitIntercept Whether to fit an intercept term. + * @param bcFeaturesStd The broadcast standard deviation values of the features. + * @param bcFeaturesMean The broadcast mean values of the features. + * @param bcCoefficients The broadcast coefficients corresponding to the features. + */ +private[ml] class LeastSquaresAggregator( + labelStd: Double, + labelMean: Double, + fitIntercept: Boolean, + bcFeaturesStd: Broadcast[Array[Double]], + bcFeaturesMean: Broadcast[Array[Double]])(bcCoefficients: Broadcast[Vector]) + extends DifferentiableLossAggregator[Instance, LeastSquaresAggregator] { + require(labelStd > 0.0, s"${this.getClass.getName} requires the label standard " + + s"deviation to be positive.") + + private val numFeatures = bcFeaturesStd.value.length + protected override val dim: Int = numFeatures + // make transient so we do not serialize between aggregation stages + @transient private lazy val featuresStd = bcFeaturesStd.value + @transient private lazy val effectiveCoefAndOffset = { + val coefficientsArray = bcCoefficients.value.toArray.clone() + val featuresMean = bcFeaturesMean.value + var sum = 0.0 + var i = 0 + val len = coefficientsArray.length + while (i < len) { + if (featuresStd(i) != 0.0) { + coefficientsArray(i) /= featuresStd(i) + sum += coefficientsArray(i) * featuresMean(i) + } else { + coefficientsArray(i) = 0.0 + } + i += 1 + } + val offset = if (fitIntercept) labelMean / labelStd - sum else 0.0 + (Vectors.dense(coefficientsArray), offset) + } + // do not use tuple assignment above because it will circumvent the @transient tag + @transient private lazy val effectiveCoefficientsVector = effectiveCoefAndOffset._1 + @transient private lazy val offset = effectiveCoefAndOffset._2 + + /** + * Add a new training instance to this LeastSquaresAggregator, and update the loss and gradient + * of the objective function. + * + * @param instance The instance of data point to be added. + * @return This LeastSquaresAggregator object. + */ + def add(instance: Instance): LeastSquaresAggregator = { + instance match { case Instance(label, weight, features) => + require(numFeatures == features.size, s"Dimensions mismatch when adding new sample." + + s" Expecting $numFeatures but got ${features.size}.") + require(weight >= 0.0, s"instance weight, $weight has to be >= 0.0") + + if (weight == 0.0) return this + + val diff = BLAS.dot(features, effectiveCoefficientsVector) - label / labelStd + offset + + if (diff != 0) { + val localGradientSumArray = gradientSumArray + val localFeaturesStd = featuresStd + features.foreachActive { (index, value) => + val fStd = localFeaturesStd(index) + if (fStd != 0.0 && value != 0.0) { + localGradientSumArray(index) += weight * diff * value / fStd + } + } + lossSum += weight * diff * diff / 2.0 + } + weightSum += weight + this + } + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/optim/loss/DifferentiableRegularization.scala b/mllib/src/main/scala/org/apache/spark/ml/optim/loss/DifferentiableRegularization.scala new file mode 100644 index 0000000000000..118c0ebfa513e --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/optim/loss/DifferentiableRegularization.scala @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.ml.optim.loss + +import breeze.optimize.DiffFunction + +/** + * A Breeze diff function which represents a cost function for differentiable regularization + * of parameters. e.g. L2 regularization: 1 / 2 regParam * beta dot beta + * + * @tparam T The type of the coefficients being regularized. + */ +private[ml] trait DifferentiableRegularization[T] extends DiffFunction[T] { + + /** Magnitude of the regularization penalty. */ + def regParam: Double + +} + +/** + * A Breeze diff function for computing the L2 regularized loss and gradient of an array of + * coefficients. + * + * @param regParam The magnitude of the regularization. + * @param shouldApply A function (Int => Boolean) indicating whether a given index should have + * regularization applied to it. + * @param featuresStd Option indicating whether the regularization should be scaled by the standard + * deviation of the features. + */ +private[ml] class L2Regularization( + val regParam: Double, + shouldApply: Int => Boolean, + featuresStd: Option[Array[Double]]) extends DifferentiableRegularization[Array[Double]] { + + override def calculate(coefficients: Array[Double]): (Double, Array[Double]) = { + var sum = 0.0 + val gradient = new Array[Double](coefficients.length) + coefficients.indices.filter(shouldApply).foreach { j => + val coef = coefficients(j) + featuresStd match { + case Some(stds) => + val std = stds(j) + if (std != 0.0) { + val temp = coef / (std * std) + sum += coef * temp + gradient(j) = regParam * temp + } else { + 0.0 + } + case None => + sum += coef * coef + gradient(j) = coef * regParam + } + } + (0.5 * sum * regParam, gradient) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/optim/loss/RDDLossFunction.scala b/mllib/src/main/scala/org/apache/spark/ml/optim/loss/RDDLossFunction.scala new file mode 100644 index 0000000000000..3b1618eb0b6fe --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/optim/loss/RDDLossFunction.scala @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.ml.optim.loss + +import scala.reflect.ClassTag + +import breeze.linalg.{DenseVector => BDV} +import breeze.optimize.DiffFunction + +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.ml.feature.Instance +import org.apache.spark.ml.linalg.{BLAS, Vector, Vectors} +import org.apache.spark.ml.optim.aggregator.DifferentiableLossAggregator +import org.apache.spark.rdd.RDD + +/** + * This class computes the gradient and loss of a differentiable loss function by mapping a + * [[DifferentiableLossAggregator]] over an [[RDD]] of [[Instance]]s. The loss function is the + * sum of the loss computed on a single instance across all points in the RDD. Therefore, the actual + * analytical form of the loss function is specified by the aggregator, which computes each points + * contribution to the overall loss. + * + * A differentiable regularization component can also be added by providing a + * [[DifferentiableRegularization]] loss function. + * + * @param instances + * @param getAggregator A function which gets a new loss aggregator in every tree aggregate step. + * @param regularization An option representing the regularization loss function to apply to the + * coefficients. + * @param aggregationDepth The aggregation depth of the tree aggregation step. + * @tparam Agg Specialization of [[DifferentiableLossAggregator]], representing the concrete type + * of the aggregator. + */ +private[ml] class RDDLossFunction[ + T: ClassTag, + Agg <: DifferentiableLossAggregator[T, Agg]: ClassTag]( + instances: RDD[T], + getAggregator: (Broadcast[Vector] => Agg), + regularization: Option[DifferentiableRegularization[Array[Double]]], + aggregationDepth: Int = 2) + extends DiffFunction[BDV[Double]] { + + override def calculate(coefficients: BDV[Double]): (Double, BDV[Double]) = { + val bcCoefficients = instances.context.broadcast(Vectors.fromBreeze(coefficients)) + val thisAgg = getAggregator(bcCoefficients) + val seqOp = (agg: Agg, x: T) => agg.add(x) + val combOp = (agg1: Agg, agg2: Agg) => agg1.merge(agg2) + val newAgg = instances.treeAggregate(thisAgg)(seqOp, combOp, aggregationDepth) + val gradient = newAgg.gradient + val regLoss = regularization.map { regFun => + val (regLoss, regGradient) = regFun.calculate(coefficients.data) + BLAS.axpy(1.0, Vectors.dense(regGradient), gradient) + regLoss + }.getOrElse(0.0) + bcCoefficients.destroy(blocking = false) + (newAgg.loss + regLoss, gradient.asBreeze.toDenseVector) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index eaad54985229e..db5ac4f14bd3b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -20,19 +20,20 @@ package org.apache.spark.ml.regression import scala.collection.mutable import breeze.linalg.{DenseVector => BDV} -import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS, OWLQN => BreezeOWLQN} +import breeze.optimize.{CachedDiffFunction, LBFGS => BreezeLBFGS, OWLQN => BreezeOWLQN} import breeze.stats.distributions.StudentsT import org.apache.hadoop.fs.Path import org.apache.spark.SparkException import org.apache.spark.annotation.{Experimental, Since} -import org.apache.spark.broadcast.Broadcast import org.apache.spark.internal.Logging import org.apache.spark.ml.feature.Instance import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.linalg.BLAS._ import org.apache.spark.ml.optim.WeightedLeastSquares import org.apache.spark.ml.PredictorParams +import org.apache.spark.ml.optim.aggregator.LeastSquaresAggregator +import org.apache.spark.ml.optim.loss.{L2Regularization, RDDLossFunction} import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ @@ -319,8 +320,17 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String val effectiveL1RegParam = $(elasticNetParam) * effectiveRegParam val effectiveL2RegParam = (1.0 - $(elasticNetParam)) * effectiveRegParam - val costFun = new LeastSquaresCostFun(instances, yStd, yMean, $(fitIntercept), - $(standardization), bcFeaturesStd, bcFeaturesMean, effectiveL2RegParam, $(aggregationDepth)) + val getAggregatorFunc = new LeastSquaresAggregator(yStd, yMean, $(fitIntercept), + bcFeaturesStd, bcFeaturesMean)(_) + val regularization = if (effectiveL2RegParam != 0.0) { + val shouldApply = (idx: Int) => idx >= 0 && idx < numFeatures + Some(new L2Regularization(effectiveL2RegParam, shouldApply, + if ($(standardization)) None else Some(featuresStd))) + } else { + None + } + val costFun = new RDDLossFunction(instances, getAggregatorFunc, regularization, + $(aggregationDepth)) val optimizer = if ($(elasticNetParam) == 0.0 || effectiveRegParam == 0.0) { new BreezeLBFGS[BDV[Double]]($(maxIter), 10, $(tol)) @@ -793,312 +803,3 @@ class LinearRegressionSummary private[regression] ( } -/** - * LeastSquaresAggregator computes the gradient and loss for a Least-squared loss function, - * as used in linear regression for samples in sparse or dense vector in an online fashion. - * - * Two LeastSquaresAggregator can be merged together to have a summary of loss and gradient of - * the corresponding joint dataset. - * - * For improving the convergence rate during the optimization process, and also preventing against - * features with very large variances exerting an overly large influence during model training, - * package like R's GLMNET performs the scaling to unit variance and removing the mean to reduce - * the condition number, and then trains the model in scaled space but returns the coefficients in - * the original scale. See page 9 in http://cran.r-project.org/web/packages/glmnet/glmnet.pdf - * - * However, we don't want to apply the `StandardScaler` on the training dataset, and then cache - * the standardized dataset since it will create a lot of overhead. As a result, we perform the - * scaling implicitly when we compute the objective function. The following is the mathematical - * derivation. - * - * Note that we don't deal with intercept by adding bias here, because the intercept - * can be computed using closed form after the coefficients are converged. - * See this discussion for detail. - * http://stats.stackexchange.com/questions/13617/how-is-the-intercept-computed-in-glmnet - * - * When training with intercept enabled, - * The objective function in the scaled space is given by - * - *
- * $$ - * L = 1/2n ||\sum_i w_i(x_i - \bar{x_i}) / \hat{x_i} - (y - \bar{y}) / \hat{y}||^2, - * $$ - *
- * - * where $\bar{x_i}$ is the mean of $x_i$, $\hat{x_i}$ is the standard deviation of $x_i$, - * $\bar{y}$ is the mean of label, and $\hat{y}$ is the standard deviation of label. - * - * If we fitting the intercept disabled (that is forced through 0.0), - * we can use the same equation except we set $\bar{y}$ and $\bar{x_i}$ to 0 instead - * of the respective means. - * - * This can be rewritten as - * - *
- * $$ - * \begin{align} - * L &= 1/2n ||\sum_i (w_i/\hat{x_i})x_i - \sum_i (w_i/\hat{x_i})\bar{x_i} - y / \hat{y} - * + \bar{y} / \hat{y}||^2 \\ - * &= 1/2n ||\sum_i w_i^\prime x_i - y / \hat{y} + offset||^2 = 1/2n diff^2 - * \end{align} - * $$ - *
- * - * where $w_i^\prime$ is the effective coefficients defined by $w_i/\hat{x_i}$, offset is - * - *
- * $$ - * - \sum_i (w_i/\hat{x_i})\bar{x_i} + \bar{y} / \hat{y}. - * $$ - *
- * - * and diff is - * - *
- * $$ - * \sum_i w_i^\prime x_i - y / \hat{y} + offset - * $$ - *
- * - * Note that the effective coefficients and offset don't depend on training dataset, - * so they can be precomputed. - * - * Now, the first derivative of the objective function in scaled space is - * - *
- * $$ - * \frac{\partial L}{\partial w_i} = diff/N (x_i - \bar{x_i}) / \hat{x_i} - * $$ - *
- * - * However, $(x_i - \bar{x_i})$ will densify the computation, so it's not - * an ideal formula when the training dataset is sparse format. - * - * This can be addressed by adding the dense $\bar{x_i} / \hat{x_i}$ terms - * in the end by keeping the sum of diff. The first derivative of total - * objective function from all the samples is - * - * - *
- * $$ - * \begin{align} - * \frac{\partial L}{\partial w_i} &= - * 1/N \sum_j diff_j (x_{ij} - \bar{x_i}) / \hat{x_i} \\ - * &= 1/N ((\sum_j diff_j x_{ij} / \hat{x_i}) - diffSum \bar{x_i} / \hat{x_i}) \\ - * &= 1/N ((\sum_j diff_j x_{ij} / \hat{x_i}) + correction_i) - * \end{align} - * $$ - *
- * - * where $correction_i = - diffSum \bar{x_i} / \hat{x_i}$ - * - * A simple math can show that diffSum is actually zero, so we don't even - * need to add the correction terms in the end. From the definition of diff, - * - *
- * $$ - * \begin{align} - * diffSum &= \sum_j (\sum_i w_i(x_{ij} - \bar{x_i}) - * / \hat{x_i} - (y_j - \bar{y}) / \hat{y}) \\ - * &= N * (\sum_i w_i(\bar{x_i} - \bar{x_i}) / \hat{x_i} - (\bar{y} - \bar{y}) / \hat{y}) \\ - * &= 0 - * \end{align} - * $$ - *
- * - * As a result, the first derivative of the total objective function only depends on - * the training dataset, which can be easily computed in distributed fashion, and is - * sparse format friendly. - * - *
- * $$ - * \frac{\partial L}{\partial w_i} = 1/N ((\sum_j diff_j x_{ij} / \hat{x_i}) - * $$ - *
- * - * @param bcCoefficients The broadcast coefficients corresponding to the features. - * @param labelStd The standard deviation value of the label. - * @param labelMean The mean value of the label. - * @param fitIntercept Whether to fit an intercept term. - * @param bcFeaturesStd The broadcast standard deviation values of the features. - * @param bcFeaturesMean The broadcast mean values of the features. - */ -private class LeastSquaresAggregator( - bcCoefficients: Broadcast[Vector], - labelStd: Double, - labelMean: Double, - fitIntercept: Boolean, - bcFeaturesStd: Broadcast[Array[Double]], - bcFeaturesMean: Broadcast[Array[Double]]) extends Serializable { - - private var totalCnt: Long = 0L - private var weightSum: Double = 0.0 - private var lossSum = 0.0 - - private val dim = bcCoefficients.value.size - // make transient so we do not serialize between aggregation stages - @transient private lazy val featuresStd = bcFeaturesStd.value - @transient private lazy val effectiveCoefAndOffset = { - val coefficientsArray = bcCoefficients.value.toArray.clone() - val featuresMean = bcFeaturesMean.value - var sum = 0.0 - var i = 0 - val len = coefficientsArray.length - while (i < len) { - if (featuresStd(i) != 0.0) { - coefficientsArray(i) /= featuresStd(i) - sum += coefficientsArray(i) * featuresMean(i) - } else { - coefficientsArray(i) = 0.0 - } - i += 1 - } - val offset = if (fitIntercept) labelMean / labelStd - sum else 0.0 - (Vectors.dense(coefficientsArray), offset) - } - // do not use tuple assignment above because it will circumvent the @transient tag - @transient private lazy val effectiveCoefficientsVector = effectiveCoefAndOffset._1 - @transient private lazy val offset = effectiveCoefAndOffset._2 - - private lazy val gradientSumArray = Array.ofDim[Double](dim) - - /** - * Add a new training instance to this LeastSquaresAggregator, and update the loss and gradient - * of the objective function. - * - * @param instance The instance of data point to be added. - * @return This LeastSquaresAggregator object. - */ - def add(instance: Instance): this.type = { - instance match { case Instance(label, weight, features) => - - if (weight == 0.0) return this - - val diff = dot(features, effectiveCoefficientsVector) - label / labelStd + offset - - if (diff != 0) { - val localGradientSumArray = gradientSumArray - val localFeaturesStd = featuresStd - features.foreachActive { (index, value) => - if (localFeaturesStd(index) != 0.0 && value != 0.0) { - localGradientSumArray(index) += weight * diff * value / localFeaturesStd(index) - } - } - lossSum += weight * diff * diff / 2.0 - } - - totalCnt += 1 - weightSum += weight - this - } - } - - /** - * Merge another LeastSquaresAggregator, and update the loss and gradient - * of the objective function. - * (Note that it's in place merging; as a result, `this` object will be modified.) - * - * @param other The other LeastSquaresAggregator to be merged. - * @return This LeastSquaresAggregator object. - */ - def merge(other: LeastSquaresAggregator): this.type = { - - if (other.weightSum != 0) { - totalCnt += other.totalCnt - weightSum += other.weightSum - lossSum += other.lossSum - - var i = 0 - val localThisGradientSumArray = this.gradientSumArray - val localOtherGradientSumArray = other.gradientSumArray - while (i < dim) { - localThisGradientSumArray(i) += localOtherGradientSumArray(i) - i += 1 - } - } - this - } - - def count: Long = totalCnt - - def loss: Double = { - require(weightSum > 0.0, s"The effective number of instances should be " + - s"greater than 0.0, but $weightSum.") - lossSum / weightSum - } - - def gradient: Vector = { - require(weightSum > 0.0, s"The effective number of instances should be " + - s"greater than 0.0, but $weightSum.") - val result = Vectors.dense(gradientSumArray.clone()) - scal(1.0 / weightSum, result) - result - } -} - -/** - * LeastSquaresCostFun implements Breeze's DiffFunction[T] for Least Squares cost. - * It returns the loss and gradient with L2 regularization at a particular point (coefficients). - * It's used in Breeze's convex optimization routines. - */ -private class LeastSquaresCostFun( - instances: RDD[Instance], - labelStd: Double, - labelMean: Double, - fitIntercept: Boolean, - standardization: Boolean, - bcFeaturesStd: Broadcast[Array[Double]], - bcFeaturesMean: Broadcast[Array[Double]], - effectiveL2regParam: Double, - aggregationDepth: Int) extends DiffFunction[BDV[Double]] { - - override def calculate(coefficients: BDV[Double]): (Double, BDV[Double]) = { - val coeffs = Vectors.fromBreeze(coefficients) - val bcCoeffs = instances.context.broadcast(coeffs) - val localFeaturesStd = bcFeaturesStd.value - - val leastSquaresAggregator = { - val seqOp = (c: LeastSquaresAggregator, instance: Instance) => c.add(instance) - val combOp = (c1: LeastSquaresAggregator, c2: LeastSquaresAggregator) => c1.merge(c2) - - instances.treeAggregate( - new LeastSquaresAggregator(bcCoeffs, labelStd, labelMean, fitIntercept, bcFeaturesStd, - bcFeaturesMean))(seqOp, combOp, aggregationDepth) - } - - val totalGradientArray = leastSquaresAggregator.gradient.toArray - bcCoeffs.destroy(blocking = false) - - val regVal = if (effectiveL2regParam == 0.0) { - 0.0 - } else { - var sum = 0.0 - coeffs.foreachActive { (index, value) => - // The following code will compute the loss of the regularization; also - // the gradient of the regularization, and add back to totalGradientArray. - sum += { - if (standardization) { - totalGradientArray(index) += effectiveL2regParam * value - value * value - } else { - if (localFeaturesStd(index) != 0.0) { - // If `standardization` is false, we still standardize the data - // to improve the rate of convergence; as a result, we have to - // perform this reverse standardization by penalizing each component - // differently to get effectively the same objective function when - // the training dataset is not standardized. - val temp = value / (localFeaturesStd(index) * localFeaturesStd(index)) - totalGradientArray(index) += effectiveL2regParam * temp - value * temp - } else { - 0.0 - } - } - } - } - 0.5 * effectiveL2regParam * sum - } - - (leastSquaresAggregator.loss + regVal, new BDV(totalGradientArray)) - } -} diff --git a/mllib/src/test/scala/org/apache/spark/ml/optim/aggregator/DifferentiableLossAggregatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/optim/aggregator/DifferentiableLossAggregatorSuite.scala new file mode 100644 index 0000000000000..7a4faeb1c10bf --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/optim/aggregator/DifferentiableLossAggregatorSuite.scala @@ -0,0 +1,160 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.ml.optim.aggregator + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.feature.Instance +import org.apache.spark.ml.linalg.{BLAS, Vector, Vectors} +import org.apache.spark.ml.util.TestingUtils._ + +class DifferentiableLossAggregatorSuite extends SparkFunSuite { + + import DifferentiableLossAggregatorSuite.TestAggregator + + private val instances1 = Array( + Instance(0.0, 0.1, Vectors.dense(1.0, 2.0)), + Instance(1.0, 0.5, Vectors.dense(1.5, 1.0)), + Instance(2.0, 0.3, Vectors.dense(4.0, 0.5)) + ) + private val instances2 = Seq( + Instance(0.2, 0.4, Vectors.dense(0.8, 2.5)), + Instance(0.8, 0.9, Vectors.dense(2.0, 1.3)), + Instance(1.5, 0.2, Vectors.dense(3.0, 0.2)) + ) + + private def assertEqual[T, Agg <: DifferentiableLossAggregator[T, Agg]]( + agg1: DifferentiableLossAggregator[T, Agg], + agg2: DifferentiableLossAggregator[T, Agg]): Unit = { + assert(agg1.weight === agg2.weight) + assert(agg1.loss === agg2.loss) + assert(agg1.gradient === agg2.gradient) + } + + test("empty aggregator") { + val numFeatures = 5 + val coef = Vectors.dense(Array.fill(numFeatures)(1.0)) + val agg = new TestAggregator(numFeatures)(coef) + withClue("cannot get loss for empty aggregator") { + intercept[IllegalArgumentException] { + agg.loss + } + } + withClue("cannot get gradient for empty aggregator") { + intercept[IllegalArgumentException] { + agg.gradient + } + } + } + + test("aggregator initialization") { + val numFeatures = 3 + val coef = Vectors.dense(Array.fill(numFeatures)(1.0)) + val agg = new TestAggregator(numFeatures)(coef) + agg.add(Instance(1.0, 0.3, Vectors.dense(Array.fill(numFeatures)(1.0)))) + assert(agg.gradient.size === 3) + assert(agg.weight === 0.3) + } + + test("merge aggregators") { + val coefficients = Vectors.dense(0.5, -0.1) + val agg1 = new TestAggregator(2)(coefficients) + val agg2 = new TestAggregator(2)(coefficients) + val aggBadDim = new TestAggregator(1)(Vectors.dense(0.5)) + aggBadDim.add(Instance(1.0, 1.0, Vectors.dense(1.0))) + instances1.foreach(agg1.add) + + // merge incompatible aggregators + withClue("cannot merge aggregators with different dimensions") { + intercept[IllegalArgumentException] { + agg1.merge(aggBadDim) + } + } + + // merge empty other + val mergedEmptyOther = agg1.merge(agg2) + assertEqual(mergedEmptyOther, agg1) + assert(mergedEmptyOther === agg1) + + // merge empty this + val agg3 = new TestAggregator(2)(coefficients) + val mergedEmptyThis = agg3.merge(agg1) + assertEqual(mergedEmptyThis, agg1) + assert(mergedEmptyThis !== agg1) + + instances2.foreach(agg2.add) + val (loss1, weight1, grad1) = (agg1.loss, agg1.weight, agg1.gradient) + val (loss2, weight2, grad2) = (agg2.loss, agg2.weight, agg2.gradient) + val merged = agg1.merge(agg2) + + // check pointers are equal + assert(merged === agg1) + + // loss should be weighted average of the two individual losses + assert(merged.loss === (loss1 * weight1 + loss2 * weight2) / (weight1 + weight2)) + assert(merged.weight === weight1 + weight2) + + // gradient should be weighted average of individual gradients + val addedGradients = Vectors.dense(grad1.toArray.clone()) + BLAS.scal(weight1, addedGradients) + BLAS.axpy(weight2, grad2, addedGradients) + BLAS.scal(1 / (weight1 + weight2), addedGradients) + assert(merged.gradient === addedGradients) + } + + test("loss, gradient, weight") { + val coefficients = Vectors.dense(0.5, -0.1) + val agg = new TestAggregator(2)(coefficients) + instances1.foreach(agg.add) + val errors = instances1.map { case Instance(label, _, features) => + label - BLAS.dot(features, coefficients) + } + val expectedLoss = errors.zip(instances1).map { case (error: Double, instance: Instance) => + instance.weight * error * error / 2.0 + } + val expectedGradient = Vectors.dense(0.0, 0.0) + errors.zip(instances1).foreach { case (error, instance) => + BLAS.axpy(instance.weight * error, instance.features, expectedGradient) + } + BLAS.scal(1.0 / agg.weight, expectedGradient) + val weightSum = instances1.map(_.weight).sum + + assert(agg.weight ~== weightSum relTol 1e-5) + assert(agg.loss ~== expectedLoss.sum / weightSum relTol 1e-5) + assert(agg.gradient ~== expectedGradient relTol 1e-5) + } +} + +object DifferentiableLossAggregatorSuite { + /** + * Dummy aggregator that represents least squares cost with no intercept. + */ + class TestAggregator(numFeatures: Int)(coefficients: Vector) + extends DifferentiableLossAggregator[Instance, TestAggregator] { + + protected override val dim: Int = numFeatures + + override def add(instance: Instance): TestAggregator = { + val error = instance.label - BLAS.dot(coefficients, instance.features) + weightSum += instance.weight + lossSum += instance.weight * error * error / 2.0 + (0 until dim).foreach { j => + gradientSumArray(j) += instance.weight * error * instance.features(j) + } + this + } + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/optim/aggregator/LeastSquaresAggregatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/optim/aggregator/LeastSquaresAggregatorSuite.scala new file mode 100644 index 0000000000000..d1cb0d380e7a5 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/optim/aggregator/LeastSquaresAggregatorSuite.scala @@ -0,0 +1,157 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.ml.optim.aggregator + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.feature.Instance +import org.apache.spark.ml.linalg.{BLAS, Vector, Vectors} +import org.apache.spark.ml.util.TestingUtils._ +import org.apache.spark.mllib.linalg.VectorImplicits._ +import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer +import org.apache.spark.mllib.util.MLlibTestSparkContext + +class LeastSquaresAggregatorSuite extends SparkFunSuite with MLlibTestSparkContext { + + @transient var instances: Array[Instance] = _ + @transient var instancesConstantFeature: Array[Instance] = _ + @transient var instancesConstantLabel: Array[Instance] = _ + + override def beforeAll(): Unit = { + super.beforeAll() + instances = Array( + Instance(0.0, 0.1, Vectors.dense(1.0, 2.0)), + Instance(1.0, 0.5, Vectors.dense(1.5, 1.0)), + Instance(2.0, 0.3, Vectors.dense(4.0, 0.5)) + ) + instancesConstantFeature = Array( + Instance(0.0, 0.1, Vectors.dense(1.0, 2.0)), + Instance(1.0, 0.5, Vectors.dense(1.0, 1.0)), + Instance(2.0, 0.3, Vectors.dense(1.0, 0.5)) + ) + instancesConstantLabel = Array( + Instance(1.0, 0.1, Vectors.dense(1.0, 2.0)), + Instance(1.0, 0.5, Vectors.dense(1.5, 1.0)), + Instance(1.0, 0.3, Vectors.dense(4.0, 0.5)) + ) + } + + /** Get feature and label summarizers for provided data. */ + def getSummarizers( + instances: Array[Instance]): (MultivariateOnlineSummarizer, MultivariateOnlineSummarizer) = { + val seqOp = (c: (MultivariateOnlineSummarizer, MultivariateOnlineSummarizer), + instance: Instance) => + (c._1.add(instance.features, instance.weight), + c._2.add(Vectors.dense(instance.label), instance.weight)) + + val combOp = (c1: (MultivariateOnlineSummarizer, MultivariateOnlineSummarizer), + c2: (MultivariateOnlineSummarizer, MultivariateOnlineSummarizer)) => + (c1._1.merge(c2._1), c1._2.merge(c2._2)) + + instances.aggregate( + new MultivariateOnlineSummarizer, new MultivariateOnlineSummarizer + )(seqOp, combOp) + } + + /** Get summary statistics for some data and create a new LeastSquaresAggregator. */ + def getNewAggregator( + instances: Array[Instance], + coefficients: Vector, + fitIntercept: Boolean): LeastSquaresAggregator = { + val (featuresSummarizer, ySummarizer) = getSummarizers(instances) + val yStd = math.sqrt(ySummarizer.variance(0)) + val yMean = ySummarizer.mean(0) + val featuresStd = featuresSummarizer.variance.toArray.map(math.sqrt) + val bcFeaturesStd = spark.sparkContext.broadcast(featuresStd) + val featuresMean = featuresSummarizer.mean + val bcFeaturesMean = spark.sparkContext.broadcast(featuresMean.toArray) + val bcCoefficients = spark.sparkContext.broadcast(coefficients) + new LeastSquaresAggregator(yStd, yMean, fitIntercept, bcFeaturesStd, + bcFeaturesMean)(bcCoefficients) + } + + test("check sizes") { + val coefficients = Vectors.dense(1.0, 2.0) + val aggIntercept = getNewAggregator(instances, coefficients, fitIntercept = true) + val aggNoIntercept = getNewAggregator(instances, coefficients, fitIntercept = false) + instances.foreach(aggIntercept.add) + instances.foreach(aggNoIntercept.add) + + // least squares agg does not include intercept in its gradient array + assert(aggIntercept.gradient.size === 2) + assert(aggNoIntercept.gradient.size === 2) + } + + test("check correctness") { + /* + Check that the aggregator computes loss/gradient for: + 0.5 * sum_i=1^N ([sum_j=1^D beta_j * ((x_j - x_j,bar) / sigma_j)] - ((y - ybar) / sigma_y))^2 + */ + val coefficients = Vectors.dense(1.0, 2.0) + val numFeatures = coefficients.size + val (featuresSummarizer, ySummarizer) = getSummarizers(instances) + val featuresStd = featuresSummarizer.variance.toArray.map(math.sqrt) + val featuresMean = featuresSummarizer.mean.toArray + val yStd = math.sqrt(ySummarizer.variance(0)) + val yMean = ySummarizer.mean(0) + + val agg = getNewAggregator(instances, coefficients, fitIntercept = true) + instances.foreach(agg.add) + + // compute (y - pred) analytically + val errors = instances.map { case Instance(l, w, f) => + val scaledFeatures = (0 until numFeatures).map { j => + (f.toArray(j) - featuresMean(j)) / featuresStd(j) + }.toArray + val scaledLabel = (l - yMean) / yStd + BLAS.dot(coefficients, Vectors.dense(scaledFeatures)) - scaledLabel + } + + // compute expected loss sum analytically + val expectedLoss = errors.zip(instances).map { case (error, instance) => + instance.weight * error * error / 2.0 + } + + // compute gradient analytically from instances + val expectedGradient = Vectors.dense(0.0, 0.0) + errors.zip(instances).foreach { case (error, instance) => + val scaledFeatures = (0 until numFeatures).map { j => + instance.weight * instance.features.toArray(j) / featuresStd(j) + }.toArray + BLAS.axpy(error, Vectors.dense(scaledFeatures), expectedGradient) + } + + val weightSum = instances.map(_.weight).sum + BLAS.scal(1.0 / weightSum, expectedGradient) + assert(agg.loss ~== (expectedLoss.sum / weightSum) relTol 1e-5) + assert(agg.gradient ~== expectedGradient relTol 1e-5) + } + + test("check with zero standard deviation") { + val coefficients = Vectors.dense(1.0, 2.0) + val aggConstantFeature = getNewAggregator(instancesConstantFeature, coefficients, + fitIntercept = true) + instances.foreach(aggConstantFeature.add) + // constant features should not affect gradient + assert(aggConstantFeature.gradient(0) === 0.0) + + withClue("LeastSquaresAggregator does not support zero standard deviation of the label") { + intercept[IllegalArgumentException] { + getNewAggregator(instancesConstantLabel, coefficients, fitIntercept = true) + } + } + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/optim/loss/DifferentiableRegularizationSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/optim/loss/DifferentiableRegularizationSuite.scala new file mode 100644 index 0000000000000..0794417a8d4bb --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/optim/loss/DifferentiableRegularizationSuite.scala @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.ml.optim.loss + +import org.apache.spark.SparkFunSuite + +class DifferentiableRegularizationSuite extends SparkFunSuite { + + test("L2 regularization") { + val shouldApply = (_: Int) => true + val regParam = 0.3 + val coefficients = Array(1.0, 3.0, -2.0) + val numFeatures = coefficients.size + + // check without features standard + val regFun = new L2Regularization(regParam, shouldApply, None) + val (loss, grad) = regFun.calculate(coefficients) + assert(loss === 0.5 * regParam * coefficients.map(x => x * x).sum) + assert(grad === coefficients.map(_ * regParam)) + + // check with features standard + val featuresStd = Array(0.1, 1.1, 0.5) + val regFunStd = new L2Regularization(regParam, shouldApply, Some(featuresStd)) + val (lossStd, gradStd) = regFunStd.calculate(coefficients) + val expectedLossStd = 0.5 * regParam * (0 until numFeatures).map { j => + coefficients(j) * coefficients(j) / (featuresStd(j) * featuresStd(j)) + }.sum + val expectedGradientStd = (0 until numFeatures).map { j => + regParam * coefficients(j) / (featuresStd(j) * featuresStd(j)) + }.toArray + assert(lossStd === expectedLossStd) + assert(gradStd === expectedGradientStd) + + // check should apply + val shouldApply2 = (i: Int) => i == 1 + val regFunApply = new L2Regularization(regParam, shouldApply2, None) + val (lossApply, gradApply) = regFunApply.calculate(coefficients) + assert(lossApply === 0.5 * regParam * coefficients(1) * coefficients(1)) + assert(gradApply === Array(0.0, coefficients(1) * regParam, 0.0)) + + // check with zero features standard + val featuresStdZero = Array(0.1, 0.0, 0.5) + val regFunStdZero = new L2Regularization(regParam, shouldApply, Some(featuresStdZero)) + val (_, gradStdZero) = regFunStdZero.calculate(coefficients) + assert(gradStdZero(1) == 0.0) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/optim/loss/RDDLossFunctionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/optim/loss/RDDLossFunctionSuite.scala new file mode 100644 index 0000000000000..cd5cebee5f7b8 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/optim/loss/RDDLossFunctionSuite.scala @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.ml.optim.loss + +import org.apache.spark.SparkFunSuite +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.ml.feature.Instance +import org.apache.spark.ml.linalg.{BLAS, Vector, Vectors} +import org.apache.spark.ml.optim.aggregator.DifferentiableLossAggregatorSuite.TestAggregator +import org.apache.spark.ml.util.TestingUtils._ +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.rdd.RDD + +class RDDLossFunctionSuite extends SparkFunSuite with MLlibTestSparkContext { + + @transient var instances: RDD[Instance] = _ + + override def beforeAll(): Unit = { + super.beforeAll() + instances = sc.parallelize(Seq( + Instance(0.0, 0.1, Vectors.dense(1.0, 2.0)), + Instance(1.0, 0.5, Vectors.dense(1.5, 1.0)), + Instance(2.0, 0.3, Vectors.dense(4.0, 0.5)) + )) + } + + test("regularization") { + val coefficients = Vectors.dense(0.5, -0.1) + val regLossFun = new L2Regularization(0.1, (_: Int) => true, None) + val getAgg = (bvec: Broadcast[Vector]) => new TestAggregator(2)(bvec.value) + val lossNoReg = new RDDLossFunction(instances, getAgg, None) + val lossWithReg = new RDDLossFunction(instances, getAgg, Some(regLossFun)) + + val (loss1, grad1) = lossNoReg.calculate(coefficients.asBreeze.toDenseVector) + val (regLoss, regGrad) = regLossFun.calculate(coefficients.toArray) + val (loss2, grad2) = lossWithReg.calculate(coefficients.asBreeze.toDenseVector) + + BLAS.axpy(1.0, Vectors.fromBreeze(grad1), Vectors.dense(regGrad)) + assert(Vectors.dense(regGrad) ~== Vectors.fromBreeze(grad2) relTol 1e-5) + assert(loss1 + regLoss === loss2) + } + + test("empty RDD") { + val rdd = sc.parallelize(Seq.empty[Instance]) + val coefficients = Vectors.dense(0.5, -0.1) + val getAgg = (bv: Broadcast[Vector]) => new TestAggregator(2)(bv.value) + val lossFun = new RDDLossFunction(rdd, getAgg, None) + withClue("cannot calculate cost for empty dataset") { + intercept[IllegalArgumentException]{ + lossFun.calculate(coefficients.asBreeze.toDenseVector) + } + } + } + + test("versus aggregating on an iterable") { + val coefficients = Vectors.dense(0.5, -0.1) + val getAgg = (bv: Broadcast[Vector]) => new TestAggregator(2)(bv.value) + val lossFun = new RDDLossFunction(instances, getAgg, None) + val (loss, grad) = lossFun.calculate(coefficients.asBreeze.toDenseVector) + + // just map the aggregator over the instances array + val agg = new TestAggregator(2)(coefficients) + instances.collect().foreach(agg.add) + + assert(loss === agg.loss) + assert(Vectors.fromBreeze(grad) === agg.gradient) + } + +} From 06c0544113ba77857c5cb1bbf94dcaf21d0b01af Mon Sep 17 00:00:00 2001 From: jerryshao Date: Mon, 5 Jun 2017 11:06:50 -0700 Subject: [PATCH 081/133] [SPARK-20981][SPARKSUBMIT] Add new configuration spark.jars.repositories as equivalence of --repositories ## What changes were proposed in this pull request? In our use case of launching Spark applications via REST APIs (Livy), there's no way for user to specify command line arguments, all Spark configurations are set through configurations map. For "--repositories" because there's no equivalent Spark configuration, so we cannot specify the custom repository through configuration. So here propose to add "--repositories" equivalent configuration in Spark. ## How was this patch tested? New UT added. Author: jerryshao Closes #18201 from jerryshao/SPARK-20981. --- .../spark/deploy/SparkSubmitArguments.scala | 2 ++ .../spark/deploy/SparkSubmitSuite.scala | 20 +++++++++++++++++++ docs/configuration.md | 13 ++++++++++-- 3 files changed, 33 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index 5100a17006e24..b76a3d2bea4c7 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -187,6 +187,8 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S packages = Option(packages).orElse(sparkProperties.get("spark.jars.packages")).orNull packagesExclusions = Option(packagesExclusions) .orElse(sparkProperties.get("spark.jars.excludes")).orNull + repositories = Option(repositories) + .orElse(sparkProperties.get("spark.jars.repositories")).orNull deployMode = Option(deployMode) .orElse(sparkProperties.get("spark.submit.deployMode")) .orElse(env.get("DEPLOY_MODE")) diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index 6e9721c45931a..de719990cf47a 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -477,6 +477,26 @@ class SparkSubmitSuite } } + test("includes jars passed through spark.jars.packages and spark.jars.repositories") { + val unusedJar = TestUtils.createJarWithClasses(Seq.empty) + val main = MavenCoordinate("my.great.lib", "mylib", "0.1") + val dep = MavenCoordinate("my.great.dep", "mylib", "0.1") + // Test using "spark.jars.packages" and "spark.jars.repositories" configurations. + IvyTestUtils.withRepository(main, Some(dep.toString), None) { repo => + val args = Seq( + "--class", JarCreationTest.getClass.getName.stripSuffix("$"), + "--name", "testApp", + "--master", "local-cluster[2,1,1024]", + "--conf", "spark.jars.packages=my.great.lib:mylib:0.1,my.great.dep:mylib:0.1", + "--conf", s"spark.jars.repositories=$repo", + "--conf", "spark.ui.enabled=false", + "--conf", "spark.master.rest.enabled=false", + unusedJar.toString, + "my.great.lib.MyLib", "my.great.dep.MyLib") + runSparkSubmit(args) + } + } + // TODO(SPARK-9603): Building a package is flaky on Jenkins Maven builds. // See https://gist.github.com/shivaram/3a2fecce60768a603dac for a error log ignore("correctly builds R packages included in a jar with --packages") { diff --git a/docs/configuration.md b/docs/configuration.md index 0771e36f80b50..f777811a93f62 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -474,10 +474,19 @@ Apart from these, the following properties are also available, and may be useful Path to an Ivy settings file to customize resolution of jars specified using spark.jars.packages instead of the built-in defaults, such as maven central. Additional repositories given by the command-line - option --repositories will also be included. Useful for allowing Spark to resolve artifacts from behind - a firewall e.g. via an in-house artifact server like Artifactory. Details on the settings file format can be + option --repositories or spark.jars.repositories will also be included. + Useful for allowing Spark to resolve artifacts from behind a firewall e.g. via an in-house + artifact server like Artifactory. Details on the settings file format can be found at http://ant.apache.org/ivy/history/latest-milestone/settings.html + + + spark.jars.repositories + + + Comma-separated list of additional remote repositories to search for the maven coordinates + given with --packages or spark.jars.packages. + spark.pyspark.driver.python From bc537e40ade0658aae7c6b5ddafb4cc038bdae2b Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Mon, 5 Jun 2017 14:34:10 -0700 Subject: [PATCH 082/133] [SPARK-20957][SS][TESTS] Fix o.a.s.sql.streaming.StreamingQueryManagerSuite listing ## What changes were proposed in this pull request? When stopping StreamingQuery, StreamExecution will set `streamDeathCause` then notify StreamingQueryManager to remove this query. So it's possible that when `q2.exception.isDefined` returns `true`, StreamingQueryManager's active list still has `q2`. This PR just puts the checks into `eventually` to fix the flaky test. ## How was this patch tested? Jenkins Author: Shixiong Zhu Closes #18180 from zsxwing/SPARK-20957. --- .../spark/sql/streaming/StreamingQueryManagerSuite.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryManagerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryManagerSuite.scala index b49efa6890236..2986b7f1eecfb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryManagerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryManagerSuite.scala @@ -78,9 +78,9 @@ class StreamingQueryManagerSuite extends StreamTest with BeforeAndAfter { eventually(Timeout(streamingTimeout)) { require(!q2.isActive) require(q2.exception.isDefined) + assert(spark.streams.get(q2.id) === null) + assert(spark.streams.active.toSet === Set(q3)) } - assert(spark.streams.get(q2.id) === null) - assert(spark.streams.active.toSet === Set(q3)) } } From 88a23d3de046c5f22417a0bd679119b876b15568 Mon Sep 17 00:00:00 2001 From: Feng Liu Date: Mon, 5 Jun 2017 17:48:28 -0700 Subject: [PATCH 083/133] [SPARK-20991][SQL] BROADCAST_TIMEOUT conf should be a TimeoutConf ## What changes were proposed in this pull request? The construction of BROADCAST_TIMEOUT conf should take the TimeUnit argument as a TimeoutConf. Author: Feng Liu Closes #18208 from liufengdb/fix_timeout. --- .../main/scala/org/apache/spark/sql/internal/SQLConf.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 54bee02e44e43..3ea808926e10b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -352,7 +352,7 @@ object SQLConf { val BROADCAST_TIMEOUT = buildConf("spark.sql.broadcastTimeout") .doc("Timeout in seconds for the broadcast wait time in broadcast joins.") - .intConf + .timeConf(TimeUnit.SECONDS) .createWithDefault(5 * 60) // This is only used for the thriftserver @@ -991,7 +991,7 @@ class SQLConf extends Serializable with Logging { def columnNameOfCorruptRecord: String = getConf(COLUMN_NAME_OF_CORRUPT_RECORD) - def broadcastTimeout: Int = getConf(BROADCAST_TIMEOUT) + def broadcastTimeout: Long = getConf(BROADCAST_TIMEOUT) def defaultDataSourceName: String = getConf(DEFAULT_DATA_SOURCE_NAME) From 44de108d743ddbec11905f0fc86fb3fccdbac90e Mon Sep 17 00:00:00 2001 From: jinxing Date: Tue, 6 Jun 2017 11:14:39 +0100 Subject: [PATCH 084/133] [SPARK-20985] Stop SparkContext using LocalSparkContext.withSpark ## What changes were proposed in this pull request? SparkContext should always be stopped after using, thus other tests won't complain that there's only one `SparkContext` can exist. Author: jinxing Closes #18204 from jinxing64/SPARK-20985. --- .../org/apache/spark/MapOutputTrackerSuite.scala | 7 ++----- .../KryoSerializerResizableOutputSuite.scala | 14 +++++++------- 2 files changed, 9 insertions(+), 12 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala index 71bedda5ac894..4fe5c5e4fee4a 100644 --- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala @@ -23,6 +23,7 @@ import org.mockito.Matchers.any import org.mockito.Mockito._ import org.apache.spark.broadcast.BroadcastManager +import org.apache.spark.LocalSparkContext._ import org.apache.spark.rpc.{RpcAddress, RpcCallContext, RpcEnv} import org.apache.spark.scheduler.{CompressedMapStatus, MapStatus} import org.apache.spark.shuffle.FetchFailedException @@ -245,8 +246,7 @@ class MapOutputTrackerSuite extends SparkFunSuite { newConf.set("spark.shuffle.mapOutput.minSizeForBroadcast", "10240") // 10 KB << 1MB framesize // needs TorrentBroadcast so need a SparkContext - val sc = new SparkContext("local", "MapOutputTrackerSuite", newConf) - try { + withSpark(new SparkContext("local", "MapOutputTrackerSuite", newConf)) { sc => val masterTracker = sc.env.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster] val rpcEnv = sc.env.rpcEnv val masterEndpoint = new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, newConf) @@ -271,9 +271,6 @@ class MapOutputTrackerSuite extends SparkFunSuite { assert(1 == masterTracker.getNumCachedSerializedBroadcast) masterTracker.unregisterShuffle(20) assert(0 == masterTracker.getNumCachedSerializedBroadcast) - - } finally { - LocalSparkContext.stop(sc) } } diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerResizableOutputSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerResizableOutputSuite.scala index 21251f0b93760..cf01f79f49091 100644 --- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerResizableOutputSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerResizableOutputSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.serializer import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.LocalSparkContext +import org.apache.spark.LocalSparkContext._ import org.apache.spark.SparkContext import org.apache.spark.SparkException @@ -32,9 +32,9 @@ class KryoSerializerResizableOutputSuite extends SparkFunSuite { conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") conf.set("spark.kryoserializer.buffer", "1m") conf.set("spark.kryoserializer.buffer.max", "1m") - val sc = new SparkContext("local", "test", conf) - intercept[SparkException](sc.parallelize(x).collect()) - LocalSparkContext.stop(sc) + withSpark(new SparkContext("local", "test", conf)) { sc => + intercept[SparkException](sc.parallelize(x).collect()) + } } test("kryo with resizable output buffer should succeed on large array") { @@ -42,8 +42,8 @@ class KryoSerializerResizableOutputSuite extends SparkFunSuite { conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") conf.set("spark.kryoserializer.buffer", "1m") conf.set("spark.kryoserializer.buffer.max", "2m") - val sc = new SparkContext("local", "test", conf) - assert(sc.parallelize(x).collect() === x) - LocalSparkContext.stop(sc) + withSpark(new SparkContext("local", "test", conf)) { sc => + assert(sc.parallelize(x).collect() === x) + } } } From b61a401da80860d5137176ec796d77aae096e470 Mon Sep 17 00:00:00 2001 From: Reza Safi Date: Mon, 5 Jun 2017 13:19:01 -0700 Subject: [PATCH 085/133] [SPARK-20926][SQL] Removing exposures to guava library caused by directly accessing SessionCatalog's tableRelationCache There could be test failures because DataStorageStrategy, HiveMetastoreCatalog and also HiveSchemaInferenceSuite were exposed to guava library by directly accessing SessionCatalog's tableRelationCacheg. These failures occur when guava shading is in place. ## What changes were proposed in this pull request? This change removes those guava exposures by introducing new methods in SessionCatalog and also changing DataStorageStrategy, HiveMetastoreCatalog and HiveSchemaInferenceSuite so that they use those proxy methods. ## How was this patch tested? Unit tests passed after applying these changes. Author: Reza Safi Closes #18148 from rezasafi/branch-2.2. (cherry picked from commit 1388fdd70733b92488f32bb31d2eb2ea4155ee62) --- .../sql/catalyst/catalog/SessionCatalog.scala | 31 ++++++++++++++++--- .../datasources/DataSourceStrategy.scala | 4 +-- .../spark/sql/hive/HiveMetastoreCatalog.scala | 16 +++++----- .../sql/hive/HiveSchemaInferenceSuite.scala | 2 +- 4 files changed, 38 insertions(+), 15 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index a78440df4f3e1..57006bfaf9b69 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.catalog import java.net.URI import java.util.Locale +import java.util.concurrent.Callable import javax.annotation.concurrent.GuardedBy import scala.collection.mutable @@ -125,14 +126,36 @@ class SessionCatalog( if (conf.caseSensitiveAnalysis) name else name.toLowerCase(Locale.ROOT) } - /** - * A cache of qualified table names to table relation plans. - */ - val tableRelationCache: Cache[QualifiedTableName, LogicalPlan] = { + private val tableRelationCache: Cache[QualifiedTableName, LogicalPlan] = { val cacheSize = conf.tableRelationCacheSize CacheBuilder.newBuilder().maximumSize(cacheSize).build[QualifiedTableName, LogicalPlan]() } + /** This method provides a way to get a cached plan. */ + def getCachedPlan(t: QualifiedTableName, c: Callable[LogicalPlan]): LogicalPlan = { + tableRelationCache.get(t, c) + } + + /** This method provides a way to get a cached plan if the key exists. */ + def getCachedTable(key: QualifiedTableName): LogicalPlan = { + tableRelationCache.getIfPresent(key) + } + + /** This method provides a way to cache a plan. */ + def cacheTable(t: QualifiedTableName, l: LogicalPlan): Unit = { + tableRelationCache.put(t, l) + } + + /** This method provides a way to invalidate a cached plan. */ + def invalidateCachedTable(key: QualifiedTableName): Unit = { + tableRelationCache.invalidate(key) + } + + /** This method provides a way to invalidate all the cached plans. */ + def invalidateAllCachedTables(): Unit = { + tableRelationCache.invalidateAll() + } + /** * This method is used to make the given path qualified before we * store this path in the underlying external catalog. So, when a path diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 21d75a404911b..e05a8d5f02bd8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -215,9 +215,9 @@ class FindDataSourceTable(sparkSession: SparkSession) extends Rule[LogicalPlan] private def readDataSourceTable(r: CatalogRelation): LogicalPlan = { val table = r.tableMeta val qualifiedTableName = QualifiedTableName(table.database, table.identifier.table) - val cache = sparkSession.sessionState.catalog.tableRelationCache + val catalogProxy = sparkSession.sessionState.catalog - val plan = cache.get(qualifiedTableName, new Callable[LogicalPlan]() { + val plan = catalogProxy.getCachedPlan(qualifiedTableName, new Callable[LogicalPlan]() { override def call(): LogicalPlan = { val pathOption = table.storage.locationUri.map("path" -> CatalogUtils.URIToString(_)) val dataSource = diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 9dd8279efc1f4..ff5afc8e3ce05 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -41,7 +41,7 @@ import org.apache.spark.sql.types._ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Logging { // these are def_s and not val/lazy val since the latter would introduce circular references private def sessionState = sparkSession.sessionState - private def tableRelationCache = sparkSession.sessionState.catalog.tableRelationCache + private def catalogProxy = sparkSession.sessionState.catalog import HiveMetastoreCatalog._ /** These locks guard against multiple attempts to instantiate a table, which wastes memory. */ @@ -61,7 +61,7 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log val key = QualifiedTableName( table.database.getOrElse(sessionState.catalog.getCurrentDatabase).toLowerCase, table.table.toLowerCase) - tableRelationCache.getIfPresent(key) + catalogProxy.getCachedTable(key) } private def getCached( @@ -71,7 +71,7 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log expectedFileFormat: Class[_ <: FileFormat], partitionSchema: Option[StructType]): Option[LogicalRelation] = { - tableRelationCache.getIfPresent(tableIdentifier) match { + catalogProxy.getCachedTable(tableIdentifier) match { case null => None // Cache miss case logical @ LogicalRelation(relation: HadoopFsRelation, _, _) => val cachedRelationFileFormatClass = relation.fileFormat.getClass @@ -92,21 +92,21 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log Some(logical) } else { // If the cached relation is not updated, we invalidate it right away. - tableRelationCache.invalidate(tableIdentifier) + catalogProxy.invalidateCachedTable(tableIdentifier) None } case _ => logWarning(s"Table $tableIdentifier should be stored as $expectedFileFormat. " + s"However, we are getting a ${relation.fileFormat} from the metastore cache. " + "This cached entry will be invalidated.") - tableRelationCache.invalidate(tableIdentifier) + catalogProxy.invalidateCachedTable(tableIdentifier) None } case other => logWarning(s"Table $tableIdentifier should be stored as $expectedFileFormat. " + s"However, we are getting a $other from the metastore cache. " + "This cached entry will be invalidated.") - tableRelationCache.invalidate(tableIdentifier) + catalogProxy.invalidateCachedTable(tableIdentifier) None } } @@ -175,7 +175,7 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log fileFormat = fileFormat, options = options)(sparkSession = sparkSession) val created = LogicalRelation(fsRelation, updatedTable) - tableRelationCache.put(tableIdentifier, created) + catalogProxy.cacheTable(tableIdentifier, created) created } @@ -203,7 +203,7 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log className = fileType).resolveRelation(), table = updatedTable) - tableRelationCache.put(tableIdentifier, created) + catalogProxy.cacheTable(tableIdentifier, created) created } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSchemaInferenceSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSchemaInferenceSuite.scala index b3a06045b5fd4..d271acc63de08 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSchemaInferenceSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSchemaInferenceSuite.scala @@ -46,7 +46,7 @@ class HiveSchemaInferenceSuite override def afterEach(): Unit = { super.afterEach() - spark.sessionState.catalog.tableRelationCache.invalidateAll() + spark.sessionState.catalog.invalidateAllCachedTables() FileStatusCache.resetForTesting() } From 0cba495120bc5a889ceeb8d66713a053d7561be2 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Tue, 6 Jun 2017 13:39:10 -0500 Subject: [PATCH 086/133] [SPARK-20641][CORE] Add key-value store abstraction and LevelDB implementation. This change adds an abstraction and LevelDB implementation for a key-value store that will be used to store UI and SHS data. The interface is described in KVStore.java (see javadoc). Specifics of the LevelDB implementation are discussed in the javadocs of both LevelDB.java and LevelDBTypeInfo.java. Included also are a few small benchmarks just to get some idea of latency. Because they're too slow for regular unit test runs, they're disabled by default. Tested with the included unit tests, and also as part of the overall feature implementation (including running SHS with hundreds of apps). Author: Marcelo Vanzin Closes #17902 from vanzin/shs-ng/M1. --- common/kvstore/pom.xml | 101 ++++ .../org/apache/spark/kvstore/KVIndex.java | 82 +++ .../org/apache/spark/kvstore/KVStore.java | 129 +++++ .../apache/spark/kvstore/KVStoreIterator.java | 47 ++ .../spark/kvstore/KVStoreSerializer.java | 86 +++ .../org/apache/spark/kvstore/KVStoreView.java | 126 +++++ .../org/apache/spark/kvstore/KVTypeInfo.java | 156 ++++++ .../org/apache/spark/kvstore/LevelDB.java | 308 +++++++++++ .../apache/spark/kvstore/LevelDBIterator.java | 278 ++++++++++ .../apache/spark/kvstore/LevelDBTypeInfo.java | 516 ++++++++++++++++++ .../UnsupportedStoreVersionException.java | 27 + .../org/apache/spark/kvstore/CustomType1.java | 63 +++ .../apache/spark/kvstore/DBIteratorSuite.java | 506 +++++++++++++++++ .../spark/kvstore/LevelDBBenchmark.java | 280 ++++++++++ .../spark/kvstore/LevelDBIteratorSuite.java | 48 ++ .../apache/spark/kvstore/LevelDBSuite.java | 312 +++++++++++ .../spark/kvstore/LevelDBTypeInfoSuite.java | 207 +++++++ .../src/test/resources/log4j.properties | 27 + pom.xml | 11 + project/SparkBuild.scala | 6 +- 20 files changed, 3313 insertions(+), 3 deletions(-) create mode 100644 common/kvstore/pom.xml create mode 100644 common/kvstore/src/main/java/org/apache/spark/kvstore/KVIndex.java create mode 100644 common/kvstore/src/main/java/org/apache/spark/kvstore/KVStore.java create mode 100644 common/kvstore/src/main/java/org/apache/spark/kvstore/KVStoreIterator.java create mode 100644 common/kvstore/src/main/java/org/apache/spark/kvstore/KVStoreSerializer.java create mode 100644 common/kvstore/src/main/java/org/apache/spark/kvstore/KVStoreView.java create mode 100644 common/kvstore/src/main/java/org/apache/spark/kvstore/KVTypeInfo.java create mode 100644 common/kvstore/src/main/java/org/apache/spark/kvstore/LevelDB.java create mode 100644 common/kvstore/src/main/java/org/apache/spark/kvstore/LevelDBIterator.java create mode 100644 common/kvstore/src/main/java/org/apache/spark/kvstore/LevelDBTypeInfo.java create mode 100644 common/kvstore/src/main/java/org/apache/spark/kvstore/UnsupportedStoreVersionException.java create mode 100644 common/kvstore/src/test/java/org/apache/spark/kvstore/CustomType1.java create mode 100644 common/kvstore/src/test/java/org/apache/spark/kvstore/DBIteratorSuite.java create mode 100644 common/kvstore/src/test/java/org/apache/spark/kvstore/LevelDBBenchmark.java create mode 100644 common/kvstore/src/test/java/org/apache/spark/kvstore/LevelDBIteratorSuite.java create mode 100644 common/kvstore/src/test/java/org/apache/spark/kvstore/LevelDBSuite.java create mode 100644 common/kvstore/src/test/java/org/apache/spark/kvstore/LevelDBTypeInfoSuite.java create mode 100644 common/kvstore/src/test/resources/log4j.properties diff --git a/common/kvstore/pom.xml b/common/kvstore/pom.xml new file mode 100644 index 0000000000000..d00cf2788b964 --- /dev/null +++ b/common/kvstore/pom.xml @@ -0,0 +1,101 @@ + + + + + 4.0.0 + + org.apache.spark + spark-parent_2.11 + 2.3.0-SNAPSHOT + ../../pom.xml + + + spark-kvstore_2.11 + jar + Spark Project Local DB + http://spark.apache.org/ + + kvstore + + + + + com.google.guava + guava + + + org.fusesource.leveldbjni + leveldbjni-all + + + com.fasterxml.jackson.core + jackson-core + + + com.fasterxml.jackson.core + jackson-databind + + + com.fasterxml.jackson.core + jackson-annotations + + + + commons-io + commons-io + test + + + log4j + log4j + test + + + org.slf4j + slf4j-api + test + + + org.slf4j + slf4j-log4j12 + test + + + io.dropwizard.metrics + metrics-core + test + + + + + org.apache.spark + spark-tags_${scala.binary.version} + test-jar + test + + + + + target/scala-${scala.binary.version}/classes + target/scala-${scala.binary.version}/test-classes + + diff --git a/common/kvstore/src/main/java/org/apache/spark/kvstore/KVIndex.java b/common/kvstore/src/main/java/org/apache/spark/kvstore/KVIndex.java new file mode 100644 index 0000000000000..8b8899023c938 --- /dev/null +++ b/common/kvstore/src/main/java/org/apache/spark/kvstore/KVIndex.java @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.kvstore; + +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +/** + * Tags a field to be indexed when storing an object. + * + *

+ * Types are required to have a natural index that uniquely identifies instances in the store. + * The default value of the annotation identifies the natural index for the type. + *

+ * + *

+ * Indexes allow for more efficient sorting of data read from the store. By annotating a field or + * "getter" method with this annotation, an index will be created that will provide sorting based on + * the string value of that field. + *

+ * + *

+ * Note that creating indices means more space will be needed, and maintenance operations like + * updating or deleting a value will become more expensive. + *

+ * + *

+ * Indices are restricted to String, integral types (byte, short, int, long, boolean), and arrays + * of those values. + *

+ */ +@Retention(RetentionPolicy.RUNTIME) +@Target({ElementType.FIELD, ElementType.METHOD}) +public @interface KVIndex { + + public static final String NATURAL_INDEX_NAME = "__main__"; + + /** + * The name of the index to be created for the annotated entity. Must be unique within + * the class. Index names are not allowed to start with an underscore (that's reserved for + * internal use). The default value is the natural index name (which is always a copy index + * regardless of the annotation's values). + */ + String value() default NATURAL_INDEX_NAME; + + /** + * The name of the parent index of this index. By default there is no parent index, so the + * generated data can be retrieved without having to provide a parent value. + * + *

+ * If a parent index is defined, iterating over the data using the index will require providing + * a single value for the parent index. This serves as a rudimentary way to provide relationships + * between entities in the store. + *

+ */ + String parent() default ""; + + /** + * Whether to copy the instance's data to the index, instead of just storing a pointer to the + * data. The default behavior is to just store a reference; that saves disk space but is slower + * to read, since there's a level of indirection. + */ + boolean copy() default false; + +} diff --git a/common/kvstore/src/main/java/org/apache/spark/kvstore/KVStore.java b/common/kvstore/src/main/java/org/apache/spark/kvstore/KVStore.java new file mode 100644 index 0000000000000..3be4b829b4d8d --- /dev/null +++ b/common/kvstore/src/main/java/org/apache/spark/kvstore/KVStore.java @@ -0,0 +1,129 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.kvstore; + +import java.io.Closeable; +import java.util.Iterator; +import java.util.Map; +import java.util.NoSuchElementException; + +/** + * Abstraction for a local key/value store for storing app data. + * + *

+ * There are two main features provided by the implementations of this interface: + *

+ * + *

Serialization

+ * + *

+ * If the underlying data store requires serialization, data will be serialized to and deserialized + * using a {@link KVStoreSerializer}, which can be customized by the application. The serializer is + * based on Jackson, so it supports all the Jackson annotations for controlling the serialization of + * app-defined types. + *

+ * + *

+ * Data is also automatically compressed to save disk space. + *

+ * + *

Automatic Key Management

+ * + *

+ * When using the built-in key management, the implementation will automatically create unique + * keys for each type written to the store. Keys are based on the type name, and always start + * with the "+" prefix character (so that it's easy to use both manual and automatic key + * management APIs without conflicts). + *

+ * + *

+ * Another feature of automatic key management is indexing; by annotating fields or methods of + * objects written to the store with {@link KVIndex}, indices are created to sort the data + * by the values of those properties. This makes it possible to provide sorting without having + * to load all instances of those types from the store. + *

+ * + *

+ * KVStore instances are thread-safe for both reads and writes. + *

+ */ +public interface KVStore extends Closeable { + + /** + * Returns app-specific metadata from the store, or null if it's not currently set. + * + *

+ * The metadata type is application-specific. This is a convenience method so that applications + * don't need to define their own keys for this information. + *

+ */ + T getMetadata(Class klass) throws Exception; + + /** + * Writes the given value in the store metadata key. + */ + void setMetadata(Object value) throws Exception; + + /** + * Read a specific instance of an object. + * + * @param naturalKey The object's "natural key", which uniquely identifies it. Null keys + * are not allowed. + * @throws NoSuchElementException If an element with the given key does not exist. + */ + T read(Class klass, Object naturalKey) throws Exception; + + /** + * Writes the given object to the store, including indexed fields. Indices are updated based + * on the annotated fields of the object's class. + * + *

+ * Writes may be slower when the object already exists in the store, since it will involve + * updating existing indices. + *

+ * + * @param value The object to write. + */ + void write(Object value) throws Exception; + + /** + * Removes an object and all data related to it, like index entries, from the store. + * + * @param type The object's type. + * @param naturalKey The object's "natural key", which uniquely identifies it. Null keys + * are not allowed. + * @throws NoSuchElementException If an element with the given key does not exist. + */ + void delete(Class type, Object naturalKey) throws Exception; + + /** + * Returns a configurable view for iterating over entities of the given type. + */ + KVStoreView view(Class type) throws Exception; + + /** + * Returns the number of items of the given type currently in the store. + */ + long count(Class type) throws Exception; + + /** + * Returns the number of items of the given type which match the given indexed value. + */ + long count(Class type, String index, Object indexedValue) throws Exception; + +} diff --git a/common/kvstore/src/main/java/org/apache/spark/kvstore/KVStoreIterator.java b/common/kvstore/src/main/java/org/apache/spark/kvstore/KVStoreIterator.java new file mode 100644 index 0000000000000..3efdec9ed32be --- /dev/null +++ b/common/kvstore/src/main/java/org/apache/spark/kvstore/KVStoreIterator.java @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.kvstore; + +import java.util.Iterator; +import java.util.List; + +/** + * An iterator for KVStore. + * + *

+ * Iterators may keep references to resources that need to be closed. It's recommended that users + * explicitly close iterators after they're used. + *

+ */ +public interface KVStoreIterator extends Iterator, AutoCloseable { + + /** + * Retrieve multiple elements from the store. + * + * @param max Maximum number of elements to retrieve. + */ + List next(int max); + + /** + * Skip in the iterator. + * + * @return Whether there are items left after skipping. + */ + boolean skip(long n); + +} diff --git a/common/kvstore/src/main/java/org/apache/spark/kvstore/KVStoreSerializer.java b/common/kvstore/src/main/java/org/apache/spark/kvstore/KVStoreSerializer.java new file mode 100644 index 0000000000000..b84ec91cf67a0 --- /dev/null +++ b/common/kvstore/src/main/java/org/apache/spark/kvstore/KVStoreSerializer.java @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.kvstore; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.util.zip.GZIPInputStream; +import java.util.zip.GZIPOutputStream; +import static java.nio.charset.StandardCharsets.UTF_8; + +import com.fasterxml.jackson.databind.ObjectMapper; + +/** + * Serializer used to translate between app-defined types and the LevelDB store. + * + *

+ * The serializer is based on Jackson, so values are written as JSON. It also allows "naked strings" + * and integers to be written as values directly, which will be written as UTF-8 strings. + *

+ */ +public class KVStoreSerializer { + + /** + * Object mapper used to process app-specific types. If an application requires a specific + * configuration of the mapper, it can subclass this serializer and add custom configuration + * to this object. + */ + protected final ObjectMapper mapper; + + public KVStoreSerializer() { + this.mapper = new ObjectMapper(); + } + + public final byte[] serialize(Object o) throws Exception { + if (o instanceof String) { + return ((String) o).getBytes(UTF_8); + } else { + ByteArrayOutputStream bytes = new ByteArrayOutputStream(); + GZIPOutputStream out = new GZIPOutputStream(bytes); + try { + mapper.writeValue(out, o); + } finally { + out.close(); + } + return bytes.toByteArray(); + } + } + + @SuppressWarnings("unchecked") + public final T deserialize(byte[] data, Class klass) throws Exception { + if (klass.equals(String.class)) { + return (T) new String(data, UTF_8); + } else { + GZIPInputStream in = new GZIPInputStream(new ByteArrayInputStream(data)); + try { + return mapper.readValue(in, klass); + } finally { + in.close(); + } + } + } + + final byte[] serialize(long value) { + return String.valueOf(value).getBytes(UTF_8); + } + + final long deserializeLong(byte[] data) { + return Long.parseLong(new String(data, UTF_8)); + } + +} diff --git a/common/kvstore/src/main/java/org/apache/spark/kvstore/KVStoreView.java b/common/kvstore/src/main/java/org/apache/spark/kvstore/KVStoreView.java new file mode 100644 index 0000000000000..b761640e6da8b --- /dev/null +++ b/common/kvstore/src/main/java/org/apache/spark/kvstore/KVStoreView.java @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.kvstore; + +import java.util.Iterator; +import java.util.Map; + +import com.google.common.base.Preconditions; + +/** + * A configurable view that allows iterating over values in a {@link KVStore}. + * + *

+ * The different methods can be used to configure the behavior of the iterator. Calling the same + * method multiple times is allowed; the most recent value will be used. + *

+ * + *

+ * The iterators returned by this view are of type {@link KVStoreIterator}; they auto-close + * when used in a for loop that exhausts their contents, but when used manually, they need + * to be closed explicitly unless all elements are read. + *

+ */ +public abstract class KVStoreView implements Iterable { + + final Class type; + + boolean ascending = true; + String index = KVIndex.NATURAL_INDEX_NAME; + Object first = null; + Object last = null; + Object parent = null; + long skip = 0L; + long max = Long.MAX_VALUE; + + public KVStoreView(Class type) { + this.type = type; + } + + /** + * Reverses the order of iteration. By default, iterates in ascending order. + */ + public KVStoreView reverse() { + ascending = !ascending; + return this; + } + + /** + * Iterates according to the given index. + */ + public KVStoreView index(String name) { + this.index = Preconditions.checkNotNull(name); + return this; + } + + /** + * Defines the value of the parent index when iterating over a child index. Only elements that + * match the parent index's value will be included in the iteration. + * + *

+ * Required for iterating over child indices, will generate an error if iterating over a + * parent-less index. + *

+ */ + public KVStoreView parent(Object value) { + this.parent = value; + return this; + } + + /** + * Iterates starting at the given value of the chosen index (inclusive). + */ + public KVStoreView first(Object value) { + this.first = value; + return this; + } + + /** + * Stops iteration at the given value of the chosen index (inclusive). + */ + public KVStoreView last(Object value) { + this.last = value; + return this; + } + + /** + * Stops iteration after a number of elements has been retrieved. + */ + public KVStoreView max(long max) { + Preconditions.checkArgument(max > 0L, "max must be positive."); + this.max = max; + return this; + } + + /** + * Skips a number of elements at the start of iteration. Skipped elements are not accounted + * when using {@link #max(long)}. + */ + public KVStoreView skip(long n) { + this.skip = n; + return this; + } + + /** + * Returns an iterator for the current configuration. + */ + public KVStoreIterator closeableIterator() throws Exception { + return (KVStoreIterator) iterator(); + } + +} diff --git a/common/kvstore/src/main/java/org/apache/spark/kvstore/KVTypeInfo.java b/common/kvstore/src/main/java/org/apache/spark/kvstore/KVTypeInfo.java new file mode 100644 index 0000000000000..90f2ff0079b8a --- /dev/null +++ b/common/kvstore/src/main/java/org/apache/spark/kvstore/KVTypeInfo.java @@ -0,0 +1,156 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.kvstore; + +import java.lang.reflect.Field; +import java.lang.reflect.Method; +import java.util.ArrayList; +import java.util.Collection; +import java.util.HashMap; +import java.util.Map; +import java.util.stream.Stream; + +import com.google.common.base.Preconditions; + +/** + * Wrapper around types managed in a KVStore, providing easy access to their indexed fields. + */ +public class KVTypeInfo { + + private final Class type; + private final Map indices; + private final Map accessors; + + public KVTypeInfo(Class type) throws Exception { + this.type = type; + this.accessors = new HashMap<>(); + this.indices = new HashMap<>(); + + for (Field f : type.getDeclaredFields()) { + KVIndex idx = f.getAnnotation(KVIndex.class); + if (idx != null) { + checkIndex(idx, indices); + indices.put(idx.value(), idx); + f.setAccessible(true); + accessors.put(idx.value(), new FieldAccessor(f)); + } + } + + for (Method m : type.getDeclaredMethods()) { + KVIndex idx = m.getAnnotation(KVIndex.class); + if (idx != null) { + checkIndex(idx, indices); + Preconditions.checkArgument(m.getParameterTypes().length == 0, + "Annotated method %s::%s should not have any parameters.", type.getName(), m.getName()); + indices.put(idx.value(), idx); + m.setAccessible(true); + accessors.put(idx.value(), new MethodAccessor(m)); + } + } + + Preconditions.checkArgument(indices.containsKey(KVIndex.NATURAL_INDEX_NAME), + "No natural index defined for type %s.", type.getName()); + Preconditions.checkArgument(indices.get(KVIndex.NATURAL_INDEX_NAME).parent().isEmpty(), + "Natural index of %s cannot have a parent.", type.getName()); + + for (KVIndex idx : indices.values()) { + if (!idx.parent().isEmpty()) { + KVIndex parent = indices.get(idx.parent()); + Preconditions.checkArgument(parent != null, + "Cannot find parent %s of index %s.", idx.parent(), idx.value()); + Preconditions.checkArgument(parent.parent().isEmpty(), + "Parent index %s of index %s cannot be itself a child index.", idx.parent(), idx.value()); + } + } + } + + private void checkIndex(KVIndex idx, Map indices) { + Preconditions.checkArgument(idx.value() != null && !idx.value().isEmpty(), + "No name provided for index in type %s.", type.getName()); + Preconditions.checkArgument( + !idx.value().startsWith("_") || idx.value().equals(KVIndex.NATURAL_INDEX_NAME), + "Index name %s (in type %s) is not allowed.", idx.value(), type.getName()); + Preconditions.checkArgument(idx.parent().isEmpty() || !idx.parent().equals(idx.value()), + "Index %s cannot be parent of itself.", idx.value()); + Preconditions.checkArgument(!indices.containsKey(idx.value()), + "Duplicate index %s for type %s.", idx.value(), type.getName()); + } + + public Class getType() { + return type; + } + + public Object getIndexValue(String indexName, Object instance) throws Exception { + return getAccessor(indexName).get(instance); + } + + public Stream indices() { + return indices.values().stream(); + } + + Accessor getAccessor(String indexName) { + Accessor a = accessors.get(indexName); + Preconditions.checkArgument(a != null, "No index %s.", indexName); + return a; + } + + Accessor getParentAccessor(String indexName) { + KVIndex index = indices.get(indexName); + return index.parent().isEmpty() ? null : getAccessor(index.parent()); + } + + /** + * Abstracts the difference between invoking a Field and a Method. + */ + interface Accessor { + + Object get(Object instance) throws Exception; + + } + + private class FieldAccessor implements Accessor { + + private final Field field; + + FieldAccessor(Field field) { + this.field = field; + } + + @Override + public Object get(Object instance) throws Exception { + return field.get(instance); + } + + } + + private class MethodAccessor implements Accessor { + + private final Method method; + + MethodAccessor(Method method) { + this.method = method; + } + + @Override + public Object get(Object instance) throws Exception { + return method.invoke(instance); + } + + } + +} diff --git a/common/kvstore/src/main/java/org/apache/spark/kvstore/LevelDB.java b/common/kvstore/src/main/java/org/apache/spark/kvstore/LevelDB.java new file mode 100644 index 0000000000000..08b22fd8265d8 --- /dev/null +++ b/common/kvstore/src/main/java/org/apache/spark/kvstore/LevelDB.java @@ -0,0 +1,308 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.kvstore; + +import java.io.File; +import java.io.IOException; +import java.util.HashMap; +import java.util.Iterator; +import java.util.Map; +import java.util.NoSuchElementException; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.atomic.AtomicReference; +import static java.nio.charset.StandardCharsets.UTF_8; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Objects; +import com.google.common.base.Preconditions; +import com.google.common.base.Throwables; +import org.fusesource.leveldbjni.JniDBFactory; +import org.iq80.leveldb.DB; +import org.iq80.leveldb.Options; +import org.iq80.leveldb.WriteBatch; + +/** + * Implementation of KVStore that uses LevelDB as the underlying data store. + */ +public class LevelDB implements KVStore { + + @VisibleForTesting + static final long STORE_VERSION = 1L; + + @VisibleForTesting + static final byte[] STORE_VERSION_KEY = "__version__".getBytes(UTF_8); + + /** DB key where app metadata is stored. */ + private static final byte[] METADATA_KEY = "__meta__".getBytes(UTF_8); + + /** DB key where type aliases are stored. */ + private static final byte[] TYPE_ALIASES_KEY = "__types__".getBytes(UTF_8); + + final AtomicReference _db; + final KVStoreSerializer serializer; + + /** + * Keep a mapping of class names to a shorter, unique ID managed by the store. This serves two + * purposes: make the keys stored on disk shorter, and spread out the keys, since class names + * will often have a long, redundant prefix (think "org.apache.spark."). + */ + private final ConcurrentMap typeAliases; + private final ConcurrentMap, LevelDBTypeInfo> types; + + public LevelDB(File path) throws Exception { + this(path, new KVStoreSerializer()); + } + + public LevelDB(File path, KVStoreSerializer serializer) throws Exception { + this.serializer = serializer; + this.types = new ConcurrentHashMap<>(); + + Options options = new Options(); + options.createIfMissing(!path.exists()); + this._db = new AtomicReference<>(JniDBFactory.factory.open(path, options)); + + byte[] versionData = db().get(STORE_VERSION_KEY); + if (versionData != null) { + long version = serializer.deserializeLong(versionData); + if (version != STORE_VERSION) { + throw new UnsupportedStoreVersionException(); + } + } else { + db().put(STORE_VERSION_KEY, serializer.serialize(STORE_VERSION)); + } + + Map aliases; + try { + aliases = get(TYPE_ALIASES_KEY, TypeAliases.class).aliases; + } catch (NoSuchElementException e) { + aliases = new HashMap<>(); + } + typeAliases = new ConcurrentHashMap<>(aliases); + } + + @Override + public T getMetadata(Class klass) throws Exception { + try { + return get(METADATA_KEY, klass); + } catch (NoSuchElementException nsee) { + return null; + } + } + + @Override + public void setMetadata(Object value) throws Exception { + if (value != null) { + put(METADATA_KEY, value); + } else { + db().delete(METADATA_KEY); + } + } + + T get(byte[] key, Class klass) throws Exception { + byte[] data = db().get(key); + if (data == null) { + throw new NoSuchElementException(new String(key, UTF_8)); + } + return serializer.deserialize(data, klass); + } + + private void put(byte[] key, Object value) throws Exception { + Preconditions.checkArgument(value != null, "Null values are not allowed."); + db().put(key, serializer.serialize(value)); + } + + @Override + public T read(Class klass, Object naturalKey) throws Exception { + Preconditions.checkArgument(naturalKey != null, "Null keys are not allowed."); + byte[] key = getTypeInfo(klass).naturalIndex().start(null, naturalKey); + return get(key, klass); + } + + @Override + public void write(Object value) throws Exception { + Preconditions.checkArgument(value != null, "Null values are not allowed."); + LevelDBTypeInfo ti = getTypeInfo(value.getClass()); + + try (WriteBatch batch = db().createWriteBatch()) { + byte[] data = serializer.serialize(value); + synchronized (ti) { + Object existing; + try { + existing = get(ti.naturalIndex().entityKey(null, value), value.getClass()); + } catch (NoSuchElementException e) { + existing = null; + } + + PrefixCache cache = new PrefixCache(value); + byte[] naturalKey = ti.naturalIndex().toKey(ti.naturalIndex().getValue(value)); + for (LevelDBTypeInfo.Index idx : ti.indices()) { + byte[] prefix = cache.getPrefix(idx); + idx.add(batch, value, existing, data, naturalKey, prefix); + } + db().write(batch); + } + } + } + + @Override + public void delete(Class type, Object naturalKey) throws Exception { + Preconditions.checkArgument(naturalKey != null, "Null keys are not allowed."); + try (WriteBatch batch = db().createWriteBatch()) { + LevelDBTypeInfo ti = getTypeInfo(type); + byte[] key = ti.naturalIndex().start(null, naturalKey); + synchronized (ti) { + byte[] data = db().get(key); + if (data != null) { + Object existing = serializer.deserialize(data, type); + PrefixCache cache = new PrefixCache(existing); + byte[] keyBytes = ti.naturalIndex().toKey(ti.naturalIndex().getValue(existing)); + for (LevelDBTypeInfo.Index idx : ti.indices()) { + idx.remove(batch, existing, keyBytes, cache.getPrefix(idx)); + } + db().write(batch); + } + } + } catch (NoSuchElementException nse) { + // Ignore. + } + } + + @Override + public KVStoreView view(Class type) throws Exception { + return new KVStoreView(type) { + @Override + public Iterator iterator() { + try { + return new LevelDBIterator<>(LevelDB.this, this); + } catch (Exception e) { + throw Throwables.propagate(e); + } + } + }; + } + + @Override + public long count(Class type) throws Exception { + LevelDBTypeInfo.Index idx = getTypeInfo(type).naturalIndex(); + return idx.getCount(idx.end(null)); + } + + @Override + public long count(Class type, String index, Object indexedValue) throws Exception { + LevelDBTypeInfo.Index idx = getTypeInfo(type).index(index); + return idx.getCount(idx.end(null, indexedValue)); + } + + @Override + public void close() throws IOException { + DB _db = this._db.getAndSet(null); + if (_db == null) { + return; + } + + try { + _db.close(); + } catch (IOException ioe) { + throw ioe; + } catch (Exception e) { + throw new IOException(e.getMessage(), e); + } + } + + /** Returns metadata about indices for the given type. */ + LevelDBTypeInfo getTypeInfo(Class type) throws Exception { + LevelDBTypeInfo ti = types.get(type); + if (ti == null) { + LevelDBTypeInfo tmp = new LevelDBTypeInfo(this, type, getTypeAlias(type)); + ti = types.putIfAbsent(type, tmp); + if (ti == null) { + ti = tmp; + } + } + return ti; + } + + /** + * Try to avoid use-after close since that has the tendency of crashing the JVM. This doesn't + * prevent methods that retrieved the instance from using it after close, but hopefully will + * catch most cases; otherwise, we'll need some kind of locking. + */ + DB db() { + DB _db = this._db.get(); + if (_db == null) { + throw new IllegalStateException("DB is closed."); + } + return _db; + } + + private byte[] getTypeAlias(Class klass) throws Exception { + byte[] alias = typeAliases.get(klass.getName()); + if (alias == null) { + synchronized (typeAliases) { + byte[] tmp = String.valueOf(typeAliases.size()).getBytes(UTF_8); + alias = typeAliases.putIfAbsent(klass.getName(), tmp); + if (alias == null) { + alias = tmp; + put(TYPE_ALIASES_KEY, new TypeAliases(typeAliases)); + } + } + } + return alias; + } + + /** Needs to be public for Jackson. */ + public static class TypeAliases { + + public Map aliases; + + TypeAliases(Map aliases) { + this.aliases = aliases; + } + + TypeAliases() { + this(null); + } + + } + + private static class PrefixCache { + + private final Object entity; + private final Map prefixes; + + PrefixCache(Object entity) { + this.entity = entity; + this.prefixes = new HashMap<>(); + } + + byte[] getPrefix(LevelDBTypeInfo.Index idx) throws Exception { + byte[] prefix = null; + if (idx.isChild()) { + prefix = prefixes.get(idx.parent()); + if (prefix == null) { + prefix = idx.parent().childPrefix(idx.parent().getValue(entity)); + prefixes.put(idx.parent(), prefix); + } + } + return prefix; + } + + } + +} diff --git a/common/kvstore/src/main/java/org/apache/spark/kvstore/LevelDBIterator.java b/common/kvstore/src/main/java/org/apache/spark/kvstore/LevelDBIterator.java new file mode 100644 index 0000000000000..a5d0f9f4fb373 --- /dev/null +++ b/common/kvstore/src/main/java/org/apache/spark/kvstore/LevelDBIterator.java @@ -0,0 +1,278 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.kvstore; + +import java.io.IOException; +import java.util.Arrays; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.NoSuchElementException; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Preconditions; +import com.google.common.base.Throwables; +import org.iq80.leveldb.DBIterator; + +class LevelDBIterator implements KVStoreIterator { + + private final LevelDB db; + private final boolean ascending; + private final DBIterator it; + private final Class type; + private final LevelDBTypeInfo ti; + private final LevelDBTypeInfo.Index index; + private final byte[] indexKeyPrefix; + private final byte[] end; + private final long max; + + private boolean checkedNext; + private byte[] next; + private boolean closed; + private long count; + + LevelDBIterator(LevelDB db, KVStoreView params) throws Exception { + this.db = db; + this.ascending = params.ascending; + this.it = db.db().iterator(); + this.type = params.type; + this.ti = db.getTypeInfo(type); + this.index = ti.index(params.index); + this.max = params.max; + + Preconditions.checkArgument(!index.isChild() || params.parent != null, + "Cannot iterate over child index %s without parent value.", params.index); + byte[] parent = index.isChild() ? index.parent().childPrefix(params.parent) : null; + + this.indexKeyPrefix = index.keyPrefix(parent); + + byte[] firstKey; + if (params.first != null) { + if (ascending) { + firstKey = index.start(parent, params.first); + } else { + firstKey = index.end(parent, params.first); + } + } else if (ascending) { + firstKey = index.keyPrefix(parent); + } else { + firstKey = index.end(parent); + } + it.seek(firstKey); + + byte[] end = null; + if (ascending) { + if (params.last != null) { + end = index.end(parent, params.last); + } else { + end = index.end(parent); + } + } else { + if (params.last != null) { + end = index.start(parent, params.last); + } + if (it.hasNext()) { + // When descending, the caller may have set up the start of iteration at a non-existant + // entry that is guaranteed to be after the desired entry. For example, if you have a + // compound key (a, b) where b is a, integer, you may seek to the end of the elements that + // have the same "a" value by specifying Integer.MAX_VALUE for "b", and that value may not + // exist in the database. So need to check here whether the next value actually belongs to + // the set being returned by the iterator before advancing. + byte[] nextKey = it.peekNext().getKey(); + if (compare(nextKey, indexKeyPrefix) <= 0) { + it.next(); + } + } + } + this.end = end; + + if (params.skip > 0) { + skip(params.skip); + } + } + + @Override + public boolean hasNext() { + if (!checkedNext && !closed) { + next = loadNext(); + checkedNext = true; + } + if (!closed && next == null) { + try { + close(); + } catch (IOException ioe) { + throw Throwables.propagate(ioe); + } + } + return next != null; + } + + @Override + public T next() { + if (!hasNext()) { + throw new NoSuchElementException(); + } + checkedNext = false; + + try { + T ret; + if (index == null || index.isCopy()) { + ret = db.serializer.deserialize(next, type); + } else { + byte[] key = ti.buildKey(false, ti.naturalIndex().keyPrefix(null), next); + ret = db.get(key, type); + } + next = null; + return ret; + } catch (Exception e) { + throw Throwables.propagate(e); + } + } + + @Override + public void remove() { + throw new UnsupportedOperationException(); + } + + @Override + public List next(int max) { + List list = new ArrayList<>(max); + while (hasNext() && list.size() < max) { + list.add(next()); + } + return list; + } + + @Override + public boolean skip(long n) { + long skipped = 0; + while (skipped < n) { + if (next != null) { + checkedNext = false; + next = null; + skipped++; + continue; + } + + boolean hasNext = ascending ? it.hasNext() : it.hasPrev(); + if (!hasNext) { + checkedNext = true; + return false; + } + + Map.Entry e = ascending ? it.next() : it.prev(); + if (!isEndMarker(e.getKey())) { + skipped++; + } + } + + return hasNext(); + } + + @Override + public synchronized void close() throws IOException { + if (!closed) { + it.close(); + closed = true; + } + } + + private byte[] loadNext() { + if (count >= max) { + return null; + } + + try { + while (true) { + boolean hasNext = ascending ? it.hasNext() : it.hasPrev(); + if (!hasNext) { + return null; + } + + Map.Entry nextEntry; + try { + // Avoid races if another thread is updating the DB. + nextEntry = ascending ? it.next() : it.prev(); + } catch (NoSuchElementException e) { + return null; + } + + byte[] nextKey = nextEntry.getKey(); + // Next key is not part of the index, stop. + if (!startsWith(nextKey, indexKeyPrefix)) { + return null; + } + + // If the next key is an end marker, then skip it. + if (isEndMarker(nextKey)) { + continue; + } + + // If there's a known end key and iteration has gone past it, stop. + if (end != null) { + int comp = compare(nextKey, end) * (ascending ? 1 : -1); + if (comp > 0) { + return null; + } + } + + count++; + + // Next element is part of the iteration, return it. + return nextEntry.getValue(); + } + } catch (Exception e) { + throw Throwables.propagate(e); + } + } + + @VisibleForTesting + static boolean startsWith(byte[] key, byte[] prefix) { + if (key.length < prefix.length) { + return false; + } + + for (int i = 0; i < prefix.length; i++) { + if (key[i] != prefix[i]) { + return false; + } + } + + return true; + } + + private boolean isEndMarker(byte[] key) { + return (key.length > 2 && + key[key.length - 2] == LevelDBTypeInfo.KEY_SEPARATOR && + key[key.length - 1] == LevelDBTypeInfo.END_MARKER[0]); + } + + static int compare(byte[] a, byte[] b) { + int diff = 0; + int minLen = Math.min(a.length, b.length); + for (int i = 0; i < minLen; i++) { + diff += (a[i] - b[i]); + if (diff != 0) { + return diff; + } + } + + return a.length - b.length; + } + +} diff --git a/common/kvstore/src/main/java/org/apache/spark/kvstore/LevelDBTypeInfo.java b/common/kvstore/src/main/java/org/apache/spark/kvstore/LevelDBTypeInfo.java new file mode 100644 index 0000000000000..3ab17dbd03ca7 --- /dev/null +++ b/common/kvstore/src/main/java/org/apache/spark/kvstore/LevelDBTypeInfo.java @@ -0,0 +1,516 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.kvstore; + +import java.lang.reflect.Array; +import java.lang.reflect.Field; +import java.lang.reflect.Method; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.util.Collection; +import java.util.HashMap; +import java.util.Map; +import static java.nio.charset.StandardCharsets.UTF_8; + +import com.google.common.base.Preconditions; +import com.google.common.base.Throwables; +import org.iq80.leveldb.WriteBatch; + +/** + * Holds metadata about app-specific types stored in LevelDB. Serves as a cache for data collected + * via reflection, to make it cheaper to access it multiple times. + * + *

+ * The hierarchy of keys stored in LevelDB looks roughly like the following. This hierarchy ensures + * that iteration over indices is easy, and that updating values in the store is not overly + * expensive. Of note, indices choose using more disk space (one value per key) instead of keeping + * lists of pointers, which would be more expensive to update at runtime. + *

+ * + *

+ * Indentation defines when a sub-key lives under a parent key. In LevelDB, this means the full + * key would be the concatenation of everything up to that point in the hierarchy, with each + * component separated by a NULL byte. + *

+ * + *
+ * +TYPE_NAME
+ *   NATURAL_INDEX
+ *     +NATURAL_KEY
+ *     -
+ *   -NATURAL_INDEX
+ *   INDEX_NAME
+ *     +INDEX_VALUE
+ *       +NATURAL_KEY
+ *     -INDEX_VALUE
+ *     .INDEX_VALUE
+ *       CHILD_INDEX_NAME
+ *         +CHILD_INDEX_VALUE
+ *           NATURAL_KEY_OR_DATA
+ *         -
+ *   -INDEX_NAME
+ * 
+ * + *

+ * Entity data (either the entity's natural key or a copy of the data) is stored in all keys + * that end with "+". A count of all objects that match a particular top-level index + * value is kept at the end marker ("-"). A count is also kept at the natural index's end + * marker, to make it easy to retrieve the number of all elements of a particular type. + *

+ * + *

+ * To illustrate, given a type "Foo", with a natural index and a second index called "bar", you'd + * have these keys and values in the store for two instances, one with natural key "key1" and the + * other "key2", both with value "yes" for "bar": + *

+ * + *
+ * Foo __main__ +key1   [data for instance 1]
+ * Foo __main__ +key2   [data for instance 2]
+ * Foo __main__ -       [count of all Foo]
+ * Foo bar +yes +key1   [instance 1 key or data, depending on index type]
+ * Foo bar +yes +key2   [instance 2 key or data, depending on index type]
+ * Foo bar +yes -       [count of all Foo with "bar=yes" ]
+ * 
+ * + *

+ * Note that all indexed values are prepended with "+", even if the index itself does not have an + * explicit end marker. This allows for easily skipping to the end of an index by telling LevelDB + * to seek to the "phantom" end marker of the index. Throughout the code and comments, this part + * of the full LevelDB key is generally referred to as the "index value" of the entity. + *

+ * + *

+ * Child indices are stored after their parent index. In the example above, let's assume there is + * a child index "child", whose parent is "bar". If both instances have value "no" for this field, + * the data in the store would look something like the following: + *

+ * + *
+ * ...
+ * Foo bar +yes -
+ * Foo bar .yes .child +no +key1   [instance 1 key or data, depending on index type]
+ * Foo bar .yes .child +no +key2   [instance 2 key or data, depending on index type]
+ * ...
+ * 
+ */ +class LevelDBTypeInfo { + + static final byte[] END_MARKER = new byte[] { '-' }; + static final byte ENTRY_PREFIX = (byte) '+'; + static final byte KEY_SEPARATOR = 0x0; + static byte TRUE = (byte) '1'; + static byte FALSE = (byte) '0'; + + private static final byte SECONDARY_IDX_PREFIX = (byte) '.'; + private static final byte POSITIVE_MARKER = (byte) '='; + private static final byte NEGATIVE_MARKER = (byte) '*'; + private static final byte[] HEX_BYTES = new byte[] { + '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'a', 'b', 'c', 'd', 'e', 'f' + }; + + private final LevelDB db; + private final Class type; + private final Map indices; + private final byte[] typePrefix; + + LevelDBTypeInfo(LevelDB db, Class type, byte[] alias) throws Exception { + this.db = db; + this.type = type; + this.indices = new HashMap<>(); + + KVTypeInfo ti = new KVTypeInfo(type); + + // First create the parent indices, then the child indices. + ti.indices().forEach(idx -> { + if (idx.parent().isEmpty()) { + indices.put(idx.value(), new Index(idx, ti.getAccessor(idx.value()), null)); + } + }); + ti.indices().forEach(idx -> { + if (!idx.parent().isEmpty()) { + indices.put(idx.value(), new Index(idx, ti.getAccessor(idx.value()), + indices.get(idx.parent()))); + } + }); + + this.typePrefix = alias; + } + + Class type() { + return type; + } + + byte[] keyPrefix() { + return typePrefix; + } + + Index naturalIndex() { + return index(KVIndex.NATURAL_INDEX_NAME); + } + + Index index(String name) { + Index i = indices.get(name); + Preconditions.checkArgument(i != null, "Index %s does not exist for type %s.", name, + type.getName()); + return i; + } + + Collection indices() { + return indices.values(); + } + + byte[] buildKey(byte[]... components) { + return buildKey(true, components); + } + + byte[] buildKey(boolean addTypePrefix, byte[]... components) { + int len = 0; + if (addTypePrefix) { + len += typePrefix.length + 1; + } + for (byte[] comp : components) { + len += comp.length; + } + len += components.length - 1; + + byte[] dest = new byte[len]; + int written = 0; + + if (addTypePrefix) { + System.arraycopy(typePrefix, 0, dest, 0, typePrefix.length); + dest[typePrefix.length] = KEY_SEPARATOR; + written += typePrefix.length + 1; + } + + for (byte[] comp : components) { + System.arraycopy(comp, 0, dest, written, comp.length); + written += comp.length; + if (written < dest.length) { + dest[written] = KEY_SEPARATOR; + written++; + } + } + + return dest; + } + + /** + * Models a single index in LevelDB. See top-level class's javadoc for a description of how the + * keys are generated. + */ + class Index { + + private final boolean copy; + private final boolean isNatural; + private final byte[] name; + private final KVTypeInfo.Accessor accessor; + private final Index parent; + + private Index(KVIndex self, KVTypeInfo.Accessor accessor, Index parent) { + byte[] name = self.value().getBytes(UTF_8); + if (parent != null) { + byte[] child = new byte[name.length + 1]; + child[0] = SECONDARY_IDX_PREFIX; + System.arraycopy(name, 0, child, 1, name.length); + } + + this.name = name; + this.isNatural = self.value().equals(KVIndex.NATURAL_INDEX_NAME); + this.copy = isNatural || self.copy(); + this.accessor = accessor; + this.parent = parent; + } + + boolean isCopy() { + return copy; + } + + boolean isChild() { + return parent != null; + } + + Index parent() { + return parent; + } + + /** + * Creates a key prefix for child indices of this index. This allows the prefix to be + * calculated only once, avoiding redundant work when multiple child indices of the + * same parent index exist. + */ + byte[] childPrefix(Object value) throws Exception { + Preconditions.checkState(parent == null, "Not a parent index."); + return buildKey(name, toParentKey(value)); + } + + /** + * Gets the index value for a particular entity (which is the value of the field or method + * tagged with the index annotation). This is used as part of the LevelDB key where the + * entity (or its id) is stored. + */ + Object getValue(Object entity) throws Exception { + return accessor.get(entity); + } + + private void checkParent(byte[] prefix) { + if (prefix != null) { + Preconditions.checkState(parent != null, "Parent prefix provided for parent index."); + } else { + Preconditions.checkState(parent == null, "Parent prefix missing for child index."); + } + } + + /** The prefix for all keys that belong to this index. */ + byte[] keyPrefix(byte[] prefix) { + checkParent(prefix); + return (parent != null) ? buildKey(false, prefix, name) : buildKey(name); + } + + /** + * The key where to start ascending iteration for entities whose value for the indexed field + * match the given value. + */ + byte[] start(byte[] prefix, Object value) { + checkParent(prefix); + return (parent != null) ? buildKey(false, prefix, name, toKey(value)) + : buildKey(name, toKey(value)); + } + + /** The key for the index's end marker. */ + byte[] end(byte[] prefix) { + checkParent(prefix); + return (parent != null) ? buildKey(false, prefix, name, END_MARKER) + : buildKey(name, END_MARKER); + } + + /** The key for the end marker for entries with the given value. */ + byte[] end(byte[] prefix, Object value) throws Exception { + checkParent(prefix); + return (parent != null) ? buildKey(false, prefix, name, toKey(value), END_MARKER) + : buildKey(name, toKey(value), END_MARKER); + } + + /** The full key in the index that identifies the given entity. */ + byte[] entityKey(byte[] prefix, Object entity) throws Exception { + Object indexValue = getValue(entity); + Preconditions.checkNotNull(indexValue, "Null index value for %s in type %s.", + name, type.getName()); + byte[] entityKey = start(prefix, indexValue); + if (!isNatural) { + entityKey = buildKey(false, entityKey, toKey(naturalIndex().getValue(entity))); + } + return entityKey; + } + + private void updateCount(WriteBatch batch, byte[] key, long delta) throws Exception { + long updated = getCount(key) + delta; + if (updated > 0) { + batch.put(key, db.serializer.serialize(updated)); + } else { + batch.delete(key); + } + } + + private void addOrRemove( + WriteBatch batch, + Object entity, + Object existing, + byte[] data, + byte[] naturalKey, + byte[] prefix) throws Exception { + Object indexValue = getValue(entity); + Preconditions.checkNotNull(indexValue, "Null index value for %s in type %s.", + name, type.getName()); + + byte[] entityKey = start(prefix, indexValue); + if (!isNatural) { + entityKey = buildKey(false, entityKey, naturalKey); + } + + boolean needCountUpdate = (existing == null); + + // Check whether there's a need to update the index. The index needs to be updated in two + // cases: + // + // - There is no existing value for the entity, so a new index value will be added. + // - If there is a previously stored value for the entity, and the index value for the + // current index does not match the new value, the old entry needs to be deleted and + // the new one added. + // + // Natural indices don't need to be checked, because by definition both old and new entities + // will have the same key. The put() call is all that's needed in that case. + // + // Also check whether we need to update the counts. If the indexed value is changing, we + // need to decrement the count at the old index value, and the new indexed value count needs + // to be incremented. + if (existing != null && !isNatural) { + byte[] oldPrefix = null; + Object oldIndexedValue = getValue(existing); + boolean removeExisting = !indexValue.equals(oldIndexedValue); + if (!removeExisting && isChild()) { + oldPrefix = parent().childPrefix(parent().getValue(existing)); + removeExisting = LevelDBIterator.compare(prefix, oldPrefix) != 0; + } + + if (removeExisting) { + if (oldPrefix == null && isChild()) { + oldPrefix = parent().childPrefix(parent().getValue(existing)); + } + + byte[] oldKey = entityKey(oldPrefix, existing); + batch.delete(oldKey); + + // If the indexed value has changed, we need to update the counts at the old and new + // end markers for the indexed value. + if (!isChild()) { + byte[] oldCountKey = end(null, oldIndexedValue); + updateCount(batch, oldCountKey, -1L); + needCountUpdate = true; + } + } + } + + if (data != null) { + byte[] stored = copy ? data : naturalKey; + batch.put(entityKey, stored); + } else { + batch.delete(entityKey); + } + + if (needCountUpdate && !isChild()) { + long delta = data != null ? 1L : -1L; + byte[] countKey = isNatural ? end(prefix) : end(prefix, indexValue); + updateCount(batch, countKey, delta); + } + } + + /** + * Add an entry to the index. + * + * @param batch Write batch with other related changes. + * @param entity The entity being added to the index. + * @param existing The entity being replaced in the index, or null. + * @param data Serialized entity to store (when storing the entity, not a reference). + * @param naturalKey The value's natural key (to avoid re-computing it for every index). + * @param prefix The parent index prefix, if this is a child index. + */ + void add( + WriteBatch batch, + Object entity, + Object existing, + byte[] data, + byte[] naturalKey, + byte[] prefix) throws Exception { + addOrRemove(batch, entity, existing, data, naturalKey, prefix); + } + + /** + * Remove a value from the index. + * + * @param batch Write batch with other related changes. + * @param entity The entity being removed, to identify the index entry to modify. + * @param naturalKey The value's natural key (to avoid re-computing it for every index). + * @param prefix The parent index prefix, if this is a child index. + */ + void remove( + WriteBatch batch, + Object entity, + byte[] naturalKey, + byte[] prefix) throws Exception { + addOrRemove(batch, entity, null, null, naturalKey, prefix); + } + + long getCount(byte[] key) throws Exception { + byte[] data = db.db().get(key); + return data != null ? db.serializer.deserializeLong(data) : 0; + } + + byte[] toParentKey(Object value) { + return toKey(value, SECONDARY_IDX_PREFIX); + } + + byte[] toKey(Object value) { + return toKey(value, ENTRY_PREFIX); + } + + /** + * Translates a value to be used as part of the store key. + * + * Integral numbers are encoded as a string in a way that preserves lexicographical + * ordering. The string is prepended with a marker telling whether the number is negative + * or positive ("*" for negative and "=" for positive are used since "-" and "+" have the + * opposite of the desired order), and then the number is encoded into a hex string (so + * it occupies twice the number of bytes as the original type). + * + * Arrays are encoded by encoding each element separately, separated by KEY_SEPARATOR. + */ + byte[] toKey(Object value, byte prefix) { + final byte[] result; + + if (value instanceof String) { + byte[] str = ((String) value).getBytes(UTF_8); + result = new byte[str.length + 1]; + result[0] = prefix; + System.arraycopy(str, 0, result, 1, str.length); + } else if (value instanceof Boolean) { + result = new byte[] { prefix, (Boolean) value ? TRUE : FALSE }; + } else if (value.getClass().isArray()) { + int length = Array.getLength(value); + byte[][] components = new byte[length][]; + for (int i = 0; i < length; i++) { + components[i] = toKey(Array.get(value, i)); + } + result = buildKey(false, components); + } else { + int bytes; + + if (value instanceof Integer) { + bytes = Integer.SIZE; + } else if (value instanceof Long) { + bytes = Long.SIZE; + } else if (value instanceof Short) { + bytes = Short.SIZE; + } else if (value instanceof Byte) { + bytes = Byte.SIZE; + } else { + throw new IllegalArgumentException(String.format("Type %s not allowed as key.", + value.getClass().getName())); + } + + bytes = bytes / Byte.SIZE; + + byte[] key = new byte[bytes * 2 + 2]; + long longValue = ((Number) value).longValue(); + key[0] = prefix; + key[1] = longValue > 0 ? POSITIVE_MARKER : NEGATIVE_MARKER; + + for (int i = 0; i < key.length - 2; i++) { + int masked = (int) ((longValue >>> (4 * i)) & 0xF); + key[key.length - i - 1] = HEX_BYTES[masked]; + } + + result = key; + } + + return result; + } + + } + +} diff --git a/common/kvstore/src/main/java/org/apache/spark/kvstore/UnsupportedStoreVersionException.java b/common/kvstore/src/main/java/org/apache/spark/kvstore/UnsupportedStoreVersionException.java new file mode 100644 index 0000000000000..2ed246e4f4c97 --- /dev/null +++ b/common/kvstore/src/main/java/org/apache/spark/kvstore/UnsupportedStoreVersionException.java @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.kvstore; + +import java.io.IOException; + +/** + * Exception thrown when the store implementation is not compatible with the underlying data. + */ +public class UnsupportedStoreVersionException extends IOException { + +} diff --git a/common/kvstore/src/test/java/org/apache/spark/kvstore/CustomType1.java b/common/kvstore/src/test/java/org/apache/spark/kvstore/CustomType1.java new file mode 100644 index 0000000000000..afb72b8689223 --- /dev/null +++ b/common/kvstore/src/test/java/org/apache/spark/kvstore/CustomType1.java @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.kvstore; + +import com.google.common.base.Objects; + +public class CustomType1 { + + @KVIndex + public String key; + + @KVIndex("id") + public String id; + + @KVIndex(value = "name", copy = true) + public String name; + + @KVIndex("int") + public int num; + + @KVIndex(value = "child", parent = "id") + public String child; + + @Override + public boolean equals(Object o) { + if (o instanceof CustomType1) { + CustomType1 other = (CustomType1) o; + return id.equals(other.id) && name.equals(other.name); + } + return false; + } + + @Override + public int hashCode() { + return id.hashCode(); + } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("key", key) + .add("id", id) + .add("name", name) + .add("num", num) + .toString(); + } + +} diff --git a/common/kvstore/src/test/java/org/apache/spark/kvstore/DBIteratorSuite.java b/common/kvstore/src/test/java/org/apache/spark/kvstore/DBIteratorSuite.java new file mode 100644 index 0000000000000..8549712213393 --- /dev/null +++ b/common/kvstore/src/test/java/org/apache/spark/kvstore/DBIteratorSuite.java @@ -0,0 +1,506 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.kvstore; + +import java.util.Arrays; +import java.util.ArrayList; +import java.util.Collections; +import java.util.Comparator; +import java.util.Iterator; +import java.util.List; +import java.util.Random; + +import com.google.common.base.Predicate; +import com.google.common.collect.Iterables; +import com.google.common.collect.Iterators; +import com.google.common.collect.Lists; +import org.apache.commons.io.FileUtils; +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import static org.junit.Assert.*; + +public abstract class DBIteratorSuite { + + private static final Logger LOG = LoggerFactory.getLogger(DBIteratorSuite.class); + + private static final int MIN_ENTRIES = 42; + private static final int MAX_ENTRIES = 1024; + private static final Random RND = new Random(); + + private static List allEntries; + private static List clashingEntries; + private static KVStore db; + + private static interface BaseComparator extends Comparator { + /** + * Returns a comparator that falls back to natural order if this comparator's ordering + * returns equality for two elements. Used to mimic how the index sorts things internally. + */ + default BaseComparator fallback() { + return (t1, t2) -> { + int result = BaseComparator.this.compare(t1, t2); + if (result != 0) { + return result; + } + + return t1.key.compareTo(t2.key); + }; + } + + /** Reverses the order of this comparator. */ + default BaseComparator reverse() { + return (t1, t2) -> -BaseComparator.this.compare(t1, t2); + } + } + + private static final BaseComparator NATURAL_ORDER = (t1, t2) -> t1.key.compareTo(t2.key); + private static final BaseComparator REF_INDEX_ORDER = (t1, t2) -> t1.id.compareTo(t2.id); + private static final BaseComparator COPY_INDEX_ORDER = (t1, t2) -> t1.name.compareTo(t2.name); + private static final BaseComparator NUMERIC_INDEX_ORDER = (t1, t2) -> t1.num - t2.num; + private static final BaseComparator CHILD_INDEX_ORDER = (t1, t2) -> t1.child.compareTo(t2.child); + + /** + * Implementations should override this method; it is called only once, before all tests are + * run. Any state can be safely stored in static variables and cleaned up in a @AfterClass + * handler. + */ + protected abstract KVStore createStore() throws Exception; + + @BeforeClass + public static void setupClass() { + long seed = RND.nextLong(); + LOG.info("Random seed: {}", seed); + RND.setSeed(seed); + } + + @AfterClass + public static void cleanupData() throws Exception { + allEntries = null; + db = null; + } + + @Before + public void setup() throws Exception { + if (db != null) { + return; + } + + db = createStore(); + + int count = RND.nextInt(MAX_ENTRIES) + MIN_ENTRIES; + + allEntries = new ArrayList<>(count); + for (int i = 0; i < count; i++) { + CustomType1 t = new CustomType1(); + t.key = "key" + i; + t.id = "id" + i; + t.name = "name" + RND.nextInt(MAX_ENTRIES); + t.num = RND.nextInt(MAX_ENTRIES); + t.child = "child" + (i % MIN_ENTRIES); + allEntries.add(t); + } + + // Shuffle the entries to avoid the insertion order matching the natural ordering. Just in case. + Collections.shuffle(allEntries, RND); + for (CustomType1 e : allEntries) { + db.write(e); + } + + // Pick the first generated value, and forcefully create a few entries that will clash + // with the indexed values (id and name), to make sure the index behaves correctly when + // multiple entities are indexed by the same value. + // + // This also serves as a test for the test code itself, to make sure it's sorting indices + // the same way the store is expected to. + CustomType1 first = allEntries.get(0); + clashingEntries = new ArrayList<>(); + + int clashCount = RND.nextInt(MIN_ENTRIES) + 1; + for (int i = 0; i < clashCount; i++) { + CustomType1 t = new CustomType1(); + t.key = "n-key" + (count + i); + t.id = first.id; + t.name = first.name; + t.num = first.num; + t.child = first.child; + allEntries.add(t); + clashingEntries.add(t); + db.write(t); + } + + // Create another entry that could cause problems: take the first entry, and make its indexed + // name be an extension of the existing ones, to make sure the implementation sorts these + // correctly even considering the separator character (shorter strings first). + CustomType1 t = new CustomType1(); + t.key = "extended-key-0"; + t.id = first.id; + t.name = first.name + "a"; + t.num = first.num; + t.child = first.child; + allEntries.add(t); + db.write(t); + } + + @Test + public void naturalIndex() throws Exception { + testIteration(NATURAL_ORDER, view(), null, null); + } + + @Test + public void refIndex() throws Exception { + testIteration(REF_INDEX_ORDER, view().index("id"), null, null); + } + + @Test + public void copyIndex() throws Exception { + testIteration(COPY_INDEX_ORDER, view().index("name"), null, null); + } + + @Test + public void numericIndex() throws Exception { + testIteration(NUMERIC_INDEX_ORDER, view().index("int"), null, null); + } + + @Test + public void childIndex() throws Exception { + CustomType1 any = pickLimit(); + testIteration(CHILD_INDEX_ORDER, view().index("child").parent(any.id), null, null); + } + + @Test + public void naturalIndexDescending() throws Exception { + testIteration(NATURAL_ORDER, view().reverse(), null, null); + } + + @Test + public void refIndexDescending() throws Exception { + testIteration(REF_INDEX_ORDER, view().index("id").reverse(), null, null); + } + + @Test + public void copyIndexDescending() throws Exception { + testIteration(COPY_INDEX_ORDER, view().index("name").reverse(), null, null); + } + + @Test + public void numericIndexDescending() throws Exception { + testIteration(NUMERIC_INDEX_ORDER, view().index("int").reverse(), null, null); + } + + @Test + public void childIndexDescending() throws Exception { + CustomType1 any = pickLimit(); + testIteration(CHILD_INDEX_ORDER, view().index("child").parent(any.id).reverse(), null, null); + } + + @Test + public void naturalIndexWithStart() throws Exception { + CustomType1 first = pickLimit(); + testIteration(NATURAL_ORDER, view().first(first.key), first, null); + } + + @Test + public void refIndexWithStart() throws Exception { + CustomType1 first = pickLimit(); + testIteration(REF_INDEX_ORDER, view().index("id").first(first.id), first, null); + } + + @Test + public void copyIndexWithStart() throws Exception { + CustomType1 first = pickLimit(); + testIteration(COPY_INDEX_ORDER, view().index("name").first(first.name), first, null); + } + + @Test + public void numericIndexWithStart() throws Exception { + CustomType1 first = pickLimit(); + testIteration(NUMERIC_INDEX_ORDER, view().index("int").first(first.num), first, null); + } + + @Test + public void childIndexWithStart() throws Exception { + CustomType1 any = pickLimit(); + testIteration(CHILD_INDEX_ORDER, view().index("child").parent(any.id).first(any.child), null, + null); + } + + @Test + public void naturalIndexDescendingWithStart() throws Exception { + CustomType1 first = pickLimit(); + testIteration(NATURAL_ORDER, view().reverse().first(first.key), first, null); + } + + @Test + public void refIndexDescendingWithStart() throws Exception { + CustomType1 first = pickLimit(); + testIteration(REF_INDEX_ORDER, view().reverse().index("id").first(first.id), first, null); + } + + @Test + public void copyIndexDescendingWithStart() throws Exception { + CustomType1 first = pickLimit(); + testIteration(COPY_INDEX_ORDER, view().reverse().index("name").first(first.name), first, null); + } + + @Test + public void numericIndexDescendingWithStart() throws Exception { + CustomType1 first = pickLimit(); + testIteration(NUMERIC_INDEX_ORDER, view().reverse().index("int").first(first.num), first, null); + } + + @Test + public void childIndexDescendingWithStart() throws Exception { + CustomType1 any = pickLimit(); + testIteration(CHILD_INDEX_ORDER, + view().index("child").parent(any.id).first(any.child).reverse(), null, null); + } + + @Test + public void naturalIndexWithSkip() throws Exception { + testIteration(NATURAL_ORDER, view().skip(pickCount()), null, null); + } + + @Test + public void refIndexWithSkip() throws Exception { + testIteration(REF_INDEX_ORDER, view().index("id").skip(pickCount()), null, null); + } + + @Test + public void copyIndexWithSkip() throws Exception { + testIteration(COPY_INDEX_ORDER, view().index("name").skip(pickCount()), null, null); + } + + @Test + public void childIndexWithSkip() throws Exception { + CustomType1 any = pickLimit(); + testIteration(CHILD_INDEX_ORDER, view().index("child").parent(any.id).skip(pickCount()), + null, null); + } + + @Test + public void naturalIndexWithMax() throws Exception { + testIteration(NATURAL_ORDER, view().max(pickCount()), null, null); + } + + @Test + public void copyIndexWithMax() throws Exception { + testIteration(COPY_INDEX_ORDER, view().index("name").max(pickCount()), null, null); + } + + @Test + public void childIndexWithMax() throws Exception { + CustomType1 any = pickLimit(); + testIteration(CHILD_INDEX_ORDER, view().index("child").parent(any.id).max(pickCount()), null, + null); + } + + @Test + public void naturalIndexWithLast() throws Exception { + CustomType1 last = pickLimit(); + testIteration(NATURAL_ORDER, view().last(last.key), null, last); + } + + @Test + public void refIndexWithLast() throws Exception { + CustomType1 last = pickLimit(); + testIteration(REF_INDEX_ORDER, view().index("id").last(last.id), null, last); + } + + @Test + public void copyIndexWithLast() throws Exception { + CustomType1 last = pickLimit(); + testIteration(COPY_INDEX_ORDER, view().index("name").last(last.name), null, last); + } + + @Test + public void numericIndexWithLast() throws Exception { + CustomType1 last = pickLimit(); + testIteration(NUMERIC_INDEX_ORDER, view().index("int").last(last.num), null, last); + } + + @Test + public void childIndexWithLast() throws Exception { + CustomType1 any = pickLimit(); + testIteration(CHILD_INDEX_ORDER, view().index("child").parent(any.id).last(any.child), null, + null); + } + + @Test + public void naturalIndexDescendingWithLast() throws Exception { + CustomType1 last = pickLimit(); + testIteration(NATURAL_ORDER, view().reverse().last(last.key), null, last); + } + + @Test + public void refIndexDescendingWithLast() throws Exception { + CustomType1 last = pickLimit(); + testIteration(REF_INDEX_ORDER, view().reverse().index("id").last(last.id), null, last); + } + + @Test + public void copyIndexDescendingWithLast() throws Exception { + CustomType1 last = pickLimit(); + testIteration(COPY_INDEX_ORDER, view().reverse().index("name").last(last.name), + null, last); + } + + @Test + public void numericIndexDescendingWithLast() throws Exception { + CustomType1 last = pickLimit(); + testIteration(NUMERIC_INDEX_ORDER, view().reverse().index("int").last(last.num), + null, last); + } + + @Test + public void childIndexDescendingWithLast() throws Exception { + CustomType1 any = pickLimit(); + testIteration(CHILD_INDEX_ORDER, view().index("child").parent(any.id).last(any.child).reverse(), + null, null); + } + + @Test + public void testRefWithIntNaturalKey() throws Exception { + LevelDBSuite.IntKeyType i = new LevelDBSuite.IntKeyType(); + i.key = 1; + i.id = "1"; + i.values = Arrays.asList("1"); + + db.write(i); + + try(KVStoreIterator it = db.view(i.getClass()).closeableIterator()) { + Object read = it.next(); + assertEquals(i, read); + } + } + + private CustomType1 pickLimit() { + // Picks an element that has clashes with other elements in the given index. + return clashingEntries.get(RND.nextInt(clashingEntries.size())); + } + + private int pickCount() { + int count = RND.nextInt(allEntries.size() / 2); + return Math.max(count, 1); + } + + /** + * Compares the two values and falls back to comparing the natural key of CustomType1 + * if they're the same, to mimic the behavior of the indexing code. + */ + private > int compareWithFallback( + T v1, + T v2, + CustomType1 ct1, + CustomType1 ct2) { + int result = v1.compareTo(v2); + if (result != 0) { + return result; + } + + return ct1.key.compareTo(ct2.key); + } + + private void testIteration( + final BaseComparator order, + final KVStoreView params, + final CustomType1 first, + final CustomType1 last) throws Exception { + List indexOrder = sortBy(order.fallback()); + if (!params.ascending) { + indexOrder = Lists.reverse(indexOrder); + } + + Iterable expected = indexOrder; + BaseComparator expectedOrder = params.ascending ? order : order.reverse(); + + if (params.parent != null) { + expected = Iterables.filter(expected, v -> params.parent.equals(v.id)); + } + + if (first != null) { + expected = Iterables.filter(expected, v -> expectedOrder.compare(first, v) <= 0); + } + + if (last != null) { + expected = Iterables.filter(expected, v -> expectedOrder.compare(v, last) <= 0); + } + + if (params.skip > 0) { + expected = Iterables.skip(expected, (int) params.skip); + } + + if (params.max != Long.MAX_VALUE) { + expected = Iterables.limit(expected, (int) params.max); + } + + List actual = collect(params); + compareLists(expected, actual); + } + + /** Could use assertEquals(), but that creates hard to read errors for large lists. */ + private void compareLists(Iterable expected, List actual) { + Iterator expectedIt = expected.iterator(); + Iterator actualIt = actual.iterator(); + + int count = 0; + while (expectedIt.hasNext()) { + if (!actualIt.hasNext()) { + break; + } + count++; + assertEquals(expectedIt.next(), actualIt.next()); + } + + String message; + Object[] remaining; + int expectedCount = count; + int actualCount = count; + + if (expectedIt.hasNext()) { + remaining = Iterators.toArray(expectedIt, Object.class); + expectedCount += remaining.length; + message = "missing"; + } else { + remaining = Iterators.toArray(actualIt, Object.class); + actualCount += remaining.length; + message = "stray"; + } + + assertEquals(String.format("Found %s elements: %s", message, Arrays.asList(remaining)), + expectedCount, actualCount); + } + + private KVStoreView view() throws Exception { + return db.view(CustomType1.class); + } + + private List collect(KVStoreView view) throws Exception { + return Arrays.asList(Iterables.toArray(view, CustomType1.class)); + } + + private List sortBy(Comparator comp) { + List copy = new ArrayList<>(allEntries); + Collections.sort(copy, comp); + return copy; + } + +} diff --git a/common/kvstore/src/test/java/org/apache/spark/kvstore/LevelDBBenchmark.java b/common/kvstore/src/test/java/org/apache/spark/kvstore/LevelDBBenchmark.java new file mode 100644 index 0000000000000..5e33606b12dd4 --- /dev/null +++ b/common/kvstore/src/test/java/org/apache/spark/kvstore/LevelDBBenchmark.java @@ -0,0 +1,280 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.kvstore; + +import java.io.File; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.concurrent.atomic.AtomicInteger; + +import com.codahale.metrics.MetricRegistry; +import com.codahale.metrics.Slf4jReporter; +import com.codahale.metrics.Snapshot; +import com.codahale.metrics.Timer; +import org.apache.commons.io.FileUtils; +import org.junit.After; +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.Ignore; +import org.junit.Test; +import org.slf4j.LoggerFactory; +import static org.junit.Assert.*; + +/** + * A set of small benchmarks for the LevelDB implementation. + * + * The benchmarks are run over two different types (one with just a natural index, and one + * with a ref index), over a set of 2^20 elements, and the following tests are performed: + * + * - write (then update) elements in sequential natural key order + * - write (then update) elements in random natural key order + * - iterate over natural index, ascending and descending + * - iterate over ref index, ascending and descending + */ +@Ignore +public class LevelDBBenchmark { + + private static final int COUNT = 1024; + private static final AtomicInteger IDGEN = new AtomicInteger(); + private static final MetricRegistry metrics = new MetricRegistry(); + private static final Timer dbCreation = metrics.timer("dbCreation"); + private static final Timer dbClose = metrics.timer("dbClose"); + + private LevelDB db; + private File dbpath; + + @Before + public void setup() throws Exception { + dbpath = File.createTempFile("test.", ".ldb"); + dbpath.delete(); + try(Timer.Context ctx = dbCreation.time()) { + db = new LevelDB(dbpath); + } + } + + @After + public void cleanup() throws Exception { + if (db != null) { + try(Timer.Context ctx = dbClose.time()) { + db.close(); + } + } + if (dbpath != null) { + FileUtils.deleteQuietly(dbpath); + } + } + + @AfterClass + public static void report() { + if (metrics.getTimers().isEmpty()) { + return; + } + + int headingPrefix = 0; + for (Map.Entry e : metrics.getTimers().entrySet()) { + headingPrefix = Math.max(e.getKey().length(), headingPrefix); + } + headingPrefix += 4; + + StringBuilder heading = new StringBuilder(); + for (int i = 0; i < headingPrefix; i++) { + heading.append(" "); + } + heading.append("\tcount"); + heading.append("\tmean"); + heading.append("\tmin"); + heading.append("\tmax"); + heading.append("\t95th"); + System.out.println(heading); + + for (Map.Entry e : metrics.getTimers().entrySet()) { + StringBuilder row = new StringBuilder(); + row.append(e.getKey()); + for (int i = 0; i < headingPrefix - e.getKey().length(); i++) { + row.append(" "); + } + + Snapshot s = e.getValue().getSnapshot(); + row.append("\t").append(e.getValue().getCount()); + row.append("\t").append(toMs(s.getMean())); + row.append("\t").append(toMs(s.getMin())); + row.append("\t").append(toMs(s.getMax())); + row.append("\t").append(toMs(s.get95thPercentile())); + + System.out.println(row); + } + + Slf4jReporter.forRegistry(metrics).outputTo(LoggerFactory.getLogger(LevelDBBenchmark.class)) + .build().report(); + } + + private static String toMs(double nanos) { + return String.format("%.3f", nanos / 1000 / 1000); + } + + @Test + public void sequentialWritesNoIndex() throws Exception { + List entries = createSimpleType(); + writeAll(entries, "sequentialWritesNoIndex"); + writeAll(entries, "sequentialUpdatesNoIndex"); + deleteNoIndex(entries, "sequentialDeleteNoIndex"); + } + + @Test + public void randomWritesNoIndex() throws Exception { + List entries = createSimpleType(); + + Collections.shuffle(entries); + writeAll(entries, "randomWritesNoIndex"); + + Collections.shuffle(entries); + writeAll(entries, "randomUpdatesNoIndex"); + + Collections.shuffle(entries); + deleteNoIndex(entries, "randomDeletesNoIndex"); + } + + @Test + public void sequentialWritesIndexedType() throws Exception { + List entries = createIndexedType(); + writeAll(entries, "sequentialWritesIndexed"); + writeAll(entries, "sequentialUpdatesIndexed"); + deleteIndexed(entries, "sequentialDeleteIndexed"); + } + + @Test + public void randomWritesIndexedTypeAndIteration() throws Exception { + List entries = createIndexedType(); + + Collections.shuffle(entries); + writeAll(entries, "randomWritesIndexed"); + + Collections.shuffle(entries); + writeAll(entries, "randomUpdatesIndexed"); + + // Run iteration benchmarks here since we've gone through the trouble of writing all + // the data already. + KVStoreView view = db.view(IndexedType.class); + iterate(view, "naturalIndex"); + iterate(view.reverse(), "naturalIndexDescending"); + iterate(view.index("name"), "refIndex"); + iterate(view.index("name").reverse(), "refIndexDescending"); + + Collections.shuffle(entries); + deleteIndexed(entries, "randomDeleteIndexed"); + } + + private void iterate(KVStoreView view, String name) throws Exception { + Timer create = metrics.timer(name + "CreateIterator"); + Timer iter = metrics.timer(name + "Iteration"); + KVStoreIterator it = null; + { + // Create the iterator several times, just to have multiple data points. + for (int i = 0; i < 1024; i++) { + if (it != null) { + it.close(); + } + try(Timer.Context ctx = create.time()) { + it = view.closeableIterator(); + } + } + } + + for (; it.hasNext(); ) { + try(Timer.Context ctx = iter.time()) { + it.next(); + } + } + } + + private void writeAll(List entries, String timerName) throws Exception { + Timer timer = newTimer(timerName); + for (Object o : entries) { + try(Timer.Context ctx = timer.time()) { + db.write(o); + } + } + } + + private void deleteNoIndex(List entries, String timerName) throws Exception { + Timer delete = newTimer(timerName); + for (SimpleType i : entries) { + try(Timer.Context ctx = delete.time()) { + db.delete(i.getClass(), i.key); + } + } + } + + private void deleteIndexed(List entries, String timerName) throws Exception { + Timer delete = newTimer(timerName); + for (IndexedType i : entries) { + try(Timer.Context ctx = delete.time()) { + db.delete(i.getClass(), i.key); + } + } + } + + private List createSimpleType() { + List entries = new ArrayList<>(); + for (int i = 0; i < COUNT; i++) { + SimpleType t = new SimpleType(); + t.key = IDGEN.getAndIncrement(); + t.name = "name" + (t.key % 1024); + entries.add(t); + } + return entries; + } + + private List createIndexedType() { + List entries = new ArrayList<>(); + for (int i = 0; i < COUNT; i++) { + IndexedType t = new IndexedType(); + t.key = IDGEN.getAndIncrement(); + t.name = "name" + (t.key % 1024); + entries.add(t); + } + return entries; + } + + private Timer newTimer(String name) { + assertNull("Timer already exists: " + name, metrics.getTimers().get(name)); + return metrics.timer(name); + } + + public static class SimpleType { + + @KVIndex + public int key; + + public String name; + + } + + public static class IndexedType { + + @KVIndex + public int key; + + @KVIndex("name") + public String name; + + } + +} diff --git a/common/kvstore/src/test/java/org/apache/spark/kvstore/LevelDBIteratorSuite.java b/common/kvstore/src/test/java/org/apache/spark/kvstore/LevelDBIteratorSuite.java new file mode 100644 index 0000000000000..93409712986ca --- /dev/null +++ b/common/kvstore/src/test/java/org/apache/spark/kvstore/LevelDBIteratorSuite.java @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.kvstore; + +import java.io.File; + +import org.apache.commons.io.FileUtils; +import org.junit.AfterClass; + +public class LevelDBIteratorSuite extends DBIteratorSuite { + + private static File dbpath; + private static LevelDB db; + + @AfterClass + public static void cleanup() throws Exception { + if (db != null) { + db.close(); + } + if (dbpath != null) { + FileUtils.deleteQuietly(dbpath); + } + } + + @Override + protected KVStore createStore() throws Exception { + dbpath = File.createTempFile("test.", ".ldb"); + dbpath.delete(); + db = new LevelDB(dbpath); + return db; + } + +} diff --git a/common/kvstore/src/test/java/org/apache/spark/kvstore/LevelDBSuite.java b/common/kvstore/src/test/java/org/apache/spark/kvstore/LevelDBSuite.java new file mode 100644 index 0000000000000..ee1c397c08573 --- /dev/null +++ b/common/kvstore/src/test/java/org/apache/spark/kvstore/LevelDBSuite.java @@ -0,0 +1,312 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.kvstore; + +import java.io.File; +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.NoSuchElementException; +import static java.nio.charset.StandardCharsets.UTF_8; + +import org.apache.commons.io.FileUtils; +import org.iq80.leveldb.DBIterator; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import static org.junit.Assert.*; + +public class LevelDBSuite { + + private LevelDB db; + private File dbpath; + + @After + public void cleanup() throws Exception { + if (db != null) { + db.close(); + } + if (dbpath != null) { + FileUtils.deleteQuietly(dbpath); + } + } + + @Before + public void setup() throws Exception { + dbpath = File.createTempFile("test.", ".ldb"); + dbpath.delete(); + db = new LevelDB(dbpath); + } + + @Test + public void testReopenAndVersionCheckDb() throws Exception { + db.close(); + db = null; + assertTrue(dbpath.exists()); + + db = new LevelDB(dbpath); + assertEquals(LevelDB.STORE_VERSION, + db.serializer.deserializeLong(db.db().get(LevelDB.STORE_VERSION_KEY))); + db.db().put(LevelDB.STORE_VERSION_KEY, db.serializer.serialize(LevelDB.STORE_VERSION + 1)); + db.close(); + db = null; + + try { + db = new LevelDB(dbpath); + fail("Should have failed version check."); + } catch (UnsupportedStoreVersionException e) { + // Expected. + } + } + + @Test + public void testObjectWriteReadDelete() throws Exception { + CustomType1 t = new CustomType1(); + t.key = "key"; + t.id = "id"; + t.name = "name"; + t.child = "child"; + + try { + db.read(CustomType1.class, t.key); + fail("Expected exception for non-existant object."); + } catch (NoSuchElementException nsee) { + // Expected. + } + + db.write(t); + assertEquals(t, db.read(t.getClass(), t.key)); + assertEquals(1L, db.count(t.getClass())); + + db.delete(t.getClass(), t.key); + try { + db.read(t.getClass(), t.key); + fail("Expected exception for deleted object."); + } catch (NoSuchElementException nsee) { + // Expected. + } + + // Look into the actual DB and make sure that all the keys related to the type have been + // removed. + assertEquals(0, countKeys(t.getClass())); + } + + @Test + public void testMultipleObjectWriteReadDelete() throws Exception { + CustomType1 t1 = new CustomType1(); + t1.key = "key1"; + t1.id = "id"; + t1.name = "name1"; + t1.child = "child1"; + + CustomType1 t2 = new CustomType1(); + t2.key = "key2"; + t2.id = "id"; + t2.name = "name2"; + t2.child = "child2"; + + db.write(t1); + db.write(t2); + + assertEquals(t1, db.read(t1.getClass(), t1.key)); + assertEquals(t2, db.read(t2.getClass(), t2.key)); + assertEquals(2L, db.count(t1.getClass())); + + // There should be one "id" index entry with two values. + assertEquals(2, db.count(t1.getClass(), "id", t1.id)); + + // Delete the first entry; now there should be 3 remaining keys, since one of the "name" + // index entries should have been removed. + db.delete(t1.getClass(), t1.key); + + // Make sure there's a single entry in the "id" index now. + assertEquals(1, db.count(t2.getClass(), "id", t2.id)); + + // Delete the remaining entry, make sure all data is gone. + db.delete(t2.getClass(), t2.key); + assertEquals(0, countKeys(t2.getClass())); + } + + @Test + public void testMultipleTypesWriteReadDelete() throws Exception { + CustomType1 t1 = new CustomType1(); + t1.key = "1"; + t1.id = "id"; + t1.name = "name1"; + t1.child = "child1"; + + IntKeyType t2 = new IntKeyType(); + t2.key = 2; + t2.id = "2"; + t2.values = Arrays.asList("value1", "value2"); + + ArrayKeyIndexType t3 = new ArrayKeyIndexType(); + t3.key = new int[] { 42, 84 }; + t3.id = new String[] { "id1", "id2" }; + + db.write(t1); + db.write(t2); + db.write(t3); + + assertEquals(t1, db.read(t1.getClass(), t1.key)); + assertEquals(t2, db.read(t2.getClass(), t2.key)); + assertEquals(t3, db.read(t3.getClass(), t3.key)); + + // There should be one "id" index with a single entry for each type. + assertEquals(1, db.count(t1.getClass(), "id", t1.id)); + assertEquals(1, db.count(t2.getClass(), "id", t2.id)); + assertEquals(1, db.count(t3.getClass(), "id", t3.id)); + + // Delete the first entry; this should not affect the entries for the second type. + db.delete(t1.getClass(), t1.key); + assertEquals(0, countKeys(t1.getClass())); + assertEquals(1, db.count(t2.getClass(), "id", t2.id)); + assertEquals(1, db.count(t3.getClass(), "id", t3.id)); + + // Delete the remaining entries, make sure all data is gone. + db.delete(t2.getClass(), t2.key); + assertEquals(0, countKeys(t2.getClass())); + + db.delete(t3.getClass(), t3.key); + assertEquals(0, countKeys(t3.getClass())); + } + + @Test + public void testMetadata() throws Exception { + assertNull(db.getMetadata(CustomType1.class)); + + CustomType1 t = new CustomType1(); + t.id = "id"; + t.name = "name"; + t.child = "child"; + + db.setMetadata(t); + assertEquals(t, db.getMetadata(CustomType1.class)); + + db.setMetadata(null); + assertNull(db.getMetadata(CustomType1.class)); + } + + @Test + public void testUpdate() throws Exception { + CustomType1 t = new CustomType1(); + t.key = "key"; + t.id = "id"; + t.name = "name"; + t.child = "child"; + + db.write(t); + + t.name = "anotherName"; + + db.write(t); + + assertEquals(1, db.count(t.getClass())); + assertEquals(1, db.count(t.getClass(), "name", "anotherName")); + assertEquals(0, db.count(t.getClass(), "name", "name")); + } + + @Test + public void testSkip() throws Exception { + for (int i = 0; i < 10; i++) { + CustomType1 t = new CustomType1(); + t.key = "key" + i; + t.id = "id" + i; + t.name = "name" + i; + t.child = "child" + i; + + db.write(t); + } + + KVStoreIterator it = db.view(CustomType1.class).closeableIterator(); + assertTrue(it.hasNext()); + assertTrue(it.skip(5)); + assertEquals("key5", it.next().key); + assertTrue(it.skip(3)); + assertEquals("key9", it.next().key); + assertFalse(it.hasNext()); + } + + private int countKeys(Class type) throws Exception { + byte[] prefix = db.getTypeInfo(type).keyPrefix(); + int count = 0; + + DBIterator it = db.db().iterator(); + it.seek(prefix); + + while (it.hasNext()) { + byte[] key = it.next().getKey(); + if (LevelDBIterator.startsWith(key, prefix)) { + count++; + } + } + + return count; + } + + public static class IntKeyType { + + @KVIndex + public int key; + + @KVIndex("id") + public String id; + + public List values; + + @Override + public boolean equals(Object o) { + if (o instanceof IntKeyType) { + IntKeyType other = (IntKeyType) o; + return key == other.key && id.equals(other.id) && values.equals(other.values); + } + return false; + } + + @Override + public int hashCode() { + return id.hashCode(); + } + + } + + public static class ArrayKeyIndexType { + + @KVIndex + public int[] key; + + @KVIndex("id") + public String[] id; + + @Override + public boolean equals(Object o) { + if (o instanceof ArrayKeyIndexType) { + ArrayKeyIndexType other = (ArrayKeyIndexType) o; + return Arrays.equals(key, other.key) && Arrays.equals(id, other.id); + } + return false; + } + + @Override + public int hashCode() { + return key.hashCode(); + } + + } + +} diff --git a/common/kvstore/src/test/java/org/apache/spark/kvstore/LevelDBTypeInfoSuite.java b/common/kvstore/src/test/java/org/apache/spark/kvstore/LevelDBTypeInfoSuite.java new file mode 100644 index 0000000000000..8e6196506c6a8 --- /dev/null +++ b/common/kvstore/src/test/java/org/apache/spark/kvstore/LevelDBTypeInfoSuite.java @@ -0,0 +1,207 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.kvstore; + +import static java.nio.charset.StandardCharsets.UTF_8; + +import org.junit.Test; +import static org.junit.Assert.*; + +public class LevelDBTypeInfoSuite { + + @Test + public void testIndexAnnotation() throws Exception { + KVTypeInfo ti = new KVTypeInfo(CustomType1.class); + assertEquals(5, ti.indices().count()); + + CustomType1 t1 = new CustomType1(); + t1.key = "key"; + t1.id = "id"; + t1.name = "name"; + t1.num = 42; + t1.child = "child"; + + assertEquals(t1.key, ti.getIndexValue(KVIndex.NATURAL_INDEX_NAME, t1)); + assertEquals(t1.id, ti.getIndexValue("id", t1)); + assertEquals(t1.name, ti.getIndexValue("name", t1)); + assertEquals(t1.num, ti.getIndexValue("int", t1)); + assertEquals(t1.child, ti.getIndexValue("child", t1)); + } + + @Test(expected = IllegalArgumentException.class) + public void testNoNaturalIndex() throws Exception { + newTypeInfo(NoNaturalIndex.class); + } + + @Test(expected = IllegalArgumentException.class) + public void testNoNaturalIndex2() throws Exception { + newTypeInfo(NoNaturalIndex2.class); + } + + @Test(expected = IllegalArgumentException.class) + public void testDuplicateIndex() throws Exception { + newTypeInfo(DuplicateIndex.class); + } + + @Test(expected = IllegalArgumentException.class) + public void testEmptyIndexName() throws Exception { + newTypeInfo(EmptyIndexName.class); + } + + @Test(expected = IllegalArgumentException.class) + public void testIllegalIndexName() throws Exception { + newTypeInfo(IllegalIndexName.class); + } + + @Test(expected = IllegalArgumentException.class) + public void testIllegalIndexMethod() throws Exception { + newTypeInfo(IllegalIndexMethod.class); + } + + @Test + public void testKeyClashes() throws Exception { + LevelDBTypeInfo ti = newTypeInfo(CustomType1.class); + + CustomType1 t1 = new CustomType1(); + t1.key = "key1"; + t1.name = "a"; + + CustomType1 t2 = new CustomType1(); + t2.key = "key2"; + t2.name = "aa"; + + CustomType1 t3 = new CustomType1(); + t3.key = "key3"; + t3.name = "aaa"; + + // Make sure entries with conflicting names are sorted correctly. + assertBefore(ti.index("name").entityKey(null, t1), ti.index("name").entityKey(null, t2)); + assertBefore(ti.index("name").entityKey(null, t1), ti.index("name").entityKey(null, t3)); + assertBefore(ti.index("name").entityKey(null, t2), ti.index("name").entityKey(null, t3)); + } + + @Test + public void testNumEncoding() throws Exception { + LevelDBTypeInfo.Index idx = newTypeInfo(CustomType1.class).indices().iterator().next(); + + assertEquals("+=00000001", new String(idx.toKey(1), UTF_8)); + assertEquals("+=00000010", new String(idx.toKey(16), UTF_8)); + assertEquals("+=7fffffff", new String(idx.toKey(Integer.MAX_VALUE), UTF_8)); + + assertBefore(idx.toKey(1), idx.toKey(2)); + assertBefore(idx.toKey(-1), idx.toKey(2)); + assertBefore(idx.toKey(-11), idx.toKey(2)); + assertBefore(idx.toKey(-11), idx.toKey(-1)); + assertBefore(idx.toKey(1), idx.toKey(11)); + assertBefore(idx.toKey(Integer.MIN_VALUE), idx.toKey(Integer.MAX_VALUE)); + + assertBefore(idx.toKey(1L), idx.toKey(2L)); + assertBefore(idx.toKey(-1L), idx.toKey(2L)); + assertBefore(idx.toKey(Long.MIN_VALUE), idx.toKey(Long.MAX_VALUE)); + + assertBefore(idx.toKey((short) 1), idx.toKey((short) 2)); + assertBefore(idx.toKey((short) -1), idx.toKey((short) 2)); + assertBefore(idx.toKey(Short.MIN_VALUE), idx.toKey(Short.MAX_VALUE)); + + assertBefore(idx.toKey((byte) 1), idx.toKey((byte) 2)); + assertBefore(idx.toKey((byte) -1), idx.toKey((byte) 2)); + assertBefore(idx.toKey(Byte.MIN_VALUE), idx.toKey(Byte.MAX_VALUE)); + + byte prefix = LevelDBTypeInfo.ENTRY_PREFIX; + assertSame(new byte[] { prefix, LevelDBTypeInfo.FALSE }, idx.toKey(false)); + assertSame(new byte[] { prefix, LevelDBTypeInfo.TRUE }, idx.toKey(true)); + } + + @Test + public void testArrayIndices() throws Exception { + LevelDBTypeInfo.Index idx = newTypeInfo(CustomType1.class).indices().iterator().next(); + + assertBefore(idx.toKey(new String[] { "str1" }), idx.toKey(new String[] { "str2" })); + assertBefore(idx.toKey(new String[] { "str1", "str2" }), + idx.toKey(new String[] { "str1", "str3" })); + + assertBefore(idx.toKey(new int[] { 1 }), idx.toKey(new int[] { 2 })); + assertBefore(idx.toKey(new int[] { 1, 2 }), idx.toKey(new int[] { 1, 3 })); + } + + private LevelDBTypeInfo newTypeInfo(Class type) throws Exception { + return new LevelDBTypeInfo(null, type, type.getName().getBytes(UTF_8)); + } + + private void assertBefore(byte[] key1, byte[] key2) { + assertBefore(new String(key1, UTF_8), new String(key2, UTF_8)); + } + + private void assertBefore(String str1, String str2) { + assertTrue(String.format("%s < %s failed", str1, str2), str1.compareTo(str2) < 0); + } + + private void assertSame(byte[] key1, byte[] key2) { + assertEquals(new String(key1, UTF_8), new String(key2, UTF_8)); + } + + public static class NoNaturalIndex { + + public String id; + + } + + public static class NoNaturalIndex2 { + + @KVIndex("id") + public String id; + + } + + public static class DuplicateIndex { + + @KVIndex + public String key; + + @KVIndex("id") + public String id; + + @KVIndex("id") + public String id2; + + } + + public static class EmptyIndexName { + + @KVIndex("") + public String id; + + } + + public static class IllegalIndexName { + + @KVIndex("__invalid") + public String id; + + } + + public static class IllegalIndexMethod { + + @KVIndex("id") + public String id(boolean illegalParam) { + return null; + } + + } + +} diff --git a/common/kvstore/src/test/resources/log4j.properties b/common/kvstore/src/test/resources/log4j.properties new file mode 100644 index 0000000000000..e8da774f7ca9e --- /dev/null +++ b/common/kvstore/src/test/resources/log4j.properties @@ -0,0 +1,27 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Set everything to be logged to the file target/unit-tests.log +log4j.rootCategory=DEBUG, file +log4j.appender.file=org.apache.log4j.FileAppender +log4j.appender.file.append=true +log4j.appender.file.file=target/unit-tests.log +log4j.appender.file.layout=org.apache.log4j.PatternLayout +log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n + +# Silence verbose logs from 3rd-party libraries. +log4j.logger.io.netty=INFO diff --git a/pom.xml b/pom.xml index 0533a8dcf2e0a..6835ea14cd42b 100644 --- a/pom.xml +++ b/pom.xml @@ -83,6 +83,7 @@ common/sketch + common/kvstore common/network-common common/network-shuffle common/unsafe @@ -441,6 +442,11 @@ httpcore ${commons.httpcore.version} + + org.fusesource.leveldbjni + leveldbjni-all + 1.8 + org.seleniumhq.selenium selenium-java @@ -588,6 +594,11 @@ metrics-graphite ${codahale.metrics.version} + + com.fasterxml.jackson.core + jackson-core + ${fasterxml.jackson.version} + com.fasterxml.jackson.core jackson-databind diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index b5362ec1ae452..89b0c7a3ab7b0 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -50,10 +50,10 @@ object BuildCommons { ).map(ProjectRef(buildLocation, _)) val allProjects@Seq( - core, graphx, mllib, mllibLocal, repl, networkCommon, networkShuffle, launcher, unsafe, tags, sketch, _* + core, graphx, mllib, mllibLocal, repl, networkCommon, networkShuffle, launcher, unsafe, tags, sketch, kvstore, _* ) = Seq( "core", "graphx", "mllib", "mllib-local", "repl", "network-common", "network-shuffle", "launcher", "unsafe", - "tags", "sketch" + "tags", "sketch", "kvstore" ).map(ProjectRef(buildLocation, _)) ++ sqlProjects ++ streamingProjects val optionallyEnabledProjects@Seq(mesos, yarn, sparkGangliaLgpl, @@ -310,7 +310,7 @@ object SparkBuild extends PomBuild { val mimaProjects = allProjects.filterNot { x => Seq( spark, hive, hiveThriftServer, catalyst, repl, networkCommon, networkShuffle, networkYarn, - unsafe, tags, sqlKafka010 + unsafe, tags, sqlKafka010, kvstore ).contains(x) } From c92949ac23652e2c3a0c97fdf3d6e016f9d01dda Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 6 Jun 2017 22:50:06 -0700 Subject: [PATCH 087/133] [SPARK-20972][SQL] rename HintInfo.isBroadcastable to broadcast ## What changes were proposed in this pull request? `HintInfo.isBroadcastable` is actually not an accurate name, it's used to force the planner to broadcast a plan no matter what the data size is, via the hint mechanism. I think `forceBroadcast` is a better name. And `isBroadcastable` only have 2 possible values: `Some(true)` and `None`, so we can just use boolean type for it. ## How was this patch tested? existing tests. Author: Wenchen Fan Closes #18189 from cloud-fan/stats. --- .../sql/catalyst/analysis/ResolveHints.scala | 6 +++--- .../sql/catalyst/plans/logical/hints.scala | 16 +++++++-------- .../catalyst/analysis/ResolveHintsSuite.scala | 20 +++++++++---------- .../BasicStatsEstimationSuite.scala | 6 +++--- .../spark/sql/execution/SparkStrategies.scala | 2 +- .../org/apache/spark/sql/functions.scala | 2 +- 6 files changed, 25 insertions(+), 27 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala index 62a3482d9fac1..f068bce3e9b69 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala @@ -58,9 +58,9 @@ object ResolveHints { val newNode = CurrentOrigin.withOrigin(plan.origin) { plan match { case u: UnresolvedRelation if toBroadcast.exists(resolver(_, u.tableIdentifier.table)) => - ResolvedHint(plan, HintInfo(isBroadcastable = Option(true))) + ResolvedHint(plan, HintInfo(broadcast = true)) case r: SubqueryAlias if toBroadcast.exists(resolver(_, r.alias)) => - ResolvedHint(plan, HintInfo(isBroadcastable = Option(true))) + ResolvedHint(plan, HintInfo(broadcast = true)) case _: ResolvedHint | _: View | _: With | _: SubqueryAlias => // Don't traverse down these nodes. @@ -89,7 +89,7 @@ object ResolveHints { case h: UnresolvedHint if BROADCAST_HINT_NAMES.contains(h.name.toUpperCase(Locale.ROOT)) => if (h.parameters.isEmpty) { // If there is no table alias specified, turn the entire subtree into a BroadcastHint. - ResolvedHint(h.child, HintInfo(isBroadcastable = Option(true))) + ResolvedHint(h.child, HintInfo(broadcast = true)) } else { // Otherwise, find within the subtree query plans that should be broadcasted. applyBroadcastHint(h.child, h.parameters.map { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala index d16fae56b3d4a..e49970df80457 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala @@ -51,19 +51,17 @@ case class ResolvedHint(child: LogicalPlan, hints: HintInfo = HintInfo()) } -case class HintInfo( - isBroadcastable: Option[Boolean] = None) { +case class HintInfo(broadcast: Boolean = false) { /** Must be called when computing stats for a join operator to reset hints. */ - def resetForJoin(): HintInfo = copy( - isBroadcastable = None - ) + def resetForJoin(): HintInfo = copy(broadcast = false) override def toString: String = { - if (productIterator.forall(_.asInstanceOf[Option[_]].isEmpty)) { - "none" - } else { - isBroadcastable.map(x => s"isBroadcastable=$x").getOrElse("") + val hints = scala.collection.mutable.ArrayBuffer.empty[String] + if (broadcast) { + hints += "broadcast" } + + if (hints.isEmpty) "none" else hints.mkString("(", ", ", ")") } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala index 3d5148008c628..9782b5fb0d266 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala @@ -36,17 +36,17 @@ class ResolveHintsSuite extends AnalysisTest { test("case-sensitive or insensitive parameters") { checkAnalysis( UnresolvedHint("MAPJOIN", Seq("TaBlE"), table("TaBlE")), - ResolvedHint(testRelation, HintInfo(isBroadcastable = Option(true))), + ResolvedHint(testRelation, HintInfo(broadcast = true)), caseSensitive = false) checkAnalysis( UnresolvedHint("MAPJOIN", Seq("table"), table("TaBlE")), - ResolvedHint(testRelation, HintInfo(isBroadcastable = Option(true))), + ResolvedHint(testRelation, HintInfo(broadcast = true)), caseSensitive = false) checkAnalysis( UnresolvedHint("MAPJOIN", Seq("TaBlE"), table("TaBlE")), - ResolvedHint(testRelation, HintInfo(isBroadcastable = Option(true))), + ResolvedHint(testRelation, HintInfo(broadcast = true)), caseSensitive = true) checkAnalysis( @@ -58,28 +58,28 @@ class ResolveHintsSuite extends AnalysisTest { test("multiple broadcast hint aliases") { checkAnalysis( UnresolvedHint("MAPJOIN", Seq("table", "table2"), table("table").join(table("table2"))), - Join(ResolvedHint(testRelation, HintInfo(isBroadcastable = Option(true))), - ResolvedHint(testRelation2, HintInfo(isBroadcastable = Option(true))), Inner, None), + Join(ResolvedHint(testRelation, HintInfo(broadcast = true)), + ResolvedHint(testRelation2, HintInfo(broadcast = true)), Inner, None), caseSensitive = false) } test("do not traverse past existing broadcast hints") { checkAnalysis( UnresolvedHint("MAPJOIN", Seq("table"), - ResolvedHint(table("table").where('a > 1), HintInfo(isBroadcastable = Option(true)))), - ResolvedHint(testRelation.where('a > 1), HintInfo(isBroadcastable = Option(true))).analyze, + ResolvedHint(table("table").where('a > 1), HintInfo(broadcast = true))), + ResolvedHint(testRelation.where('a > 1), HintInfo(broadcast = true)).analyze, caseSensitive = false) } test("should work for subqueries") { checkAnalysis( UnresolvedHint("MAPJOIN", Seq("tableAlias"), table("table").as("tableAlias")), - ResolvedHint(testRelation, HintInfo(isBroadcastable = Option(true))), + ResolvedHint(testRelation, HintInfo(broadcast = true)), caseSensitive = false) checkAnalysis( UnresolvedHint("MAPJOIN", Seq("tableAlias"), table("table").subquery('tableAlias)), - ResolvedHint(testRelation, HintInfo(isBroadcastable = Option(true))), + ResolvedHint(testRelation, HintInfo(broadcast = true)), caseSensitive = false) // Negative case: if the alias doesn't match, don't match the original table name. @@ -104,7 +104,7 @@ class ResolveHintsSuite extends AnalysisTest { |SELECT /*+ BROADCAST(ctetable) */ * FROM ctetable """.stripMargin ), - ResolvedHint(testRelation.where('a > 1).select('a), HintInfo(isBroadcastable = Option(true))) + ResolvedHint(testRelation.where('a > 1).select('a), HintInfo(broadcast = true)) .select('a).analyze, caseSensitive = false) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala index 2afea6dd3d37c..833f5a71994f7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala @@ -45,11 +45,11 @@ class BasicStatsEstimationSuite extends StatsEstimationTestBase { expectedStatsCboOn = filterStatsCboOn, expectedStatsCboOff = filterStatsCboOff) - val broadcastHint = ResolvedHint(filter, HintInfo(isBroadcastable = Option(true))) + val broadcastHint = ResolvedHint(filter, HintInfo(broadcast = true)) checkStats( broadcastHint, - expectedStatsCboOn = filterStatsCboOn.copy(hints = HintInfo(isBroadcastable = Option(true))), - expectedStatsCboOff = filterStatsCboOff.copy(hints = HintInfo(isBroadcastable = Option(true))) + expectedStatsCboOn = filterStatsCboOn.copy(hints = HintInfo(broadcast = true)), + expectedStatsCboOff = filterStatsCboOff.copy(hints = HintInfo(broadcast = true)) ) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index f13294c925e36..ea86f6e00fefa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -114,7 +114,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { * Matches a plan whose output should be small enough to be used in broadcast join. */ private def canBroadcast(plan: LogicalPlan): Boolean = { - plan.stats(conf).hints.isBroadcastable.getOrElse(false) || + plan.stats(conf).hints.broadcast || (plan.stats(conf).sizeInBytes >= 0 && plan.stats(conf).sizeInBytes <= conf.autoBroadcastJoinThreshold) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 67ec1325b321e..8d0a8c2178803 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -1020,7 +1020,7 @@ object functions { */ def broadcast[T](df: Dataset[T]): Dataset[T] = { Dataset[T](df.sparkSession, - ResolvedHint(df.logicalPlan, HintInfo(isBroadcastable = Option(true))))(df.exprEnc) + ResolvedHint(df.logicalPlan, HintInfo(broadcast = true)))(df.exprEnc) } /** From cb83ca1433c865cb0aef973df2b872a83671acfd Mon Sep 17 00:00:00 2001 From: Bogdan Raducanu Date: Tue, 6 Jun 2017 22:51:10 -0700 Subject: [PATCH 088/133] [SPARK-20854][TESTS] Removing duplicate test case ## What changes were proposed in this pull request? Removed a duplicate case in "SPARK-20854: select hint syntax with expressions" ## How was this patch tested? Existing tests. Author: Bogdan Raducanu Closes #18217 from bogdanrdc/SPARK-20854-2. --- .../spark/sql/catalyst/parser/PlanParserSuite.scala | 8 -------- 1 file changed, 8 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala index d004d04569772..fef39a5b6a32f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala @@ -575,14 +575,6 @@ class PlanParserSuite extends PlanTest { ) ) - comparePlans( - parsePlan("SELECT /*+ HINT1(a, array(1, 2, 3)) */ * from t"), - UnresolvedHint("HINT1", Seq($"a", - UnresolvedFunction("array", Literal(1) :: Literal(2) :: Literal(3) :: Nil, false)), - table("t").select(star()) - ) - ) - comparePlans( parsePlan("SELECT /*+ HINT1(a, 5, 'a', b) */ * from t"), UnresolvedHint("HINT1", Seq($"a", Literal(5), Literal("a"), $"b"), From 3218505a0b4fc37026a71a262c931f06a60c7bf6 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Wed, 7 Jun 2017 08:50:36 +0100 Subject: [PATCH 089/133] [MINOR][DOC] Update deprecation notes on Python/Hadoop/Scala. ## What changes were proposed in this pull request? We had better update the deprecation notes about Python 2.6, Hadoop (before 2.6.5) and Scala 2.10 in [2.2.0-RC4](http://people.apache.org/~pwendell/spark-releases/spark-2.2.0-rc4-docs/) documentation. Since this is a doc only update, I think we can update the doc during publishing. **BEFORE (2.2.0-RC4)** ![before](https://cloud.githubusercontent.com/assets/9700541/26799758/aea0dc06-49eb-11e7-8ca3-ed8ce1cc6147.png) **AFTER** ![after](https://cloud.githubusercontent.com/assets/9700541/26799761/b3fef818-49eb-11e7-83c5-334f0e4768ed.png) ## How was this patch tested? Manual. ``` SKIP_API=1 jekyll build ``` Author: Dongjoon Hyun Closes #18207 from dongjoon-hyun/minor_doc_deprecation. --- docs/index.md | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/docs/index.md b/docs/index.md index 960b968454d0e..f7b5863957ce2 100644 --- a/docs/index.md +++ b/docs/index.md @@ -26,15 +26,13 @@ Spark runs on both Windows and UNIX-like systems (e.g. Linux, Mac OS). It's easy locally on one machine --- all you need is to have `java` installed on your system `PATH`, or the `JAVA_HOME` environment variable pointing to a Java installation. -Spark runs on Java 8+, Python 2.6+/3.4+ and R 3.1+. For the Scala API, Spark {{site.SPARK_VERSION}} +Spark runs on Java 8+, Python 2.7+/3.4+ and R 3.1+. For the Scala API, Spark {{site.SPARK_VERSION}} uses Scala {{site.SCALA_BINARY_VERSION}}. You will need to use a compatible Scala version ({{site.SCALA_BINARY_VERSION}}.x). -Note that support for Java 7 was removed as of Spark 2.2.0. +Note that support for Java 7, Python 2.6 and old Hadoop versions before 2.6.5 were removed as of Spark 2.2.0. -Note that support for Python 2.6 is deprecated as of Spark 2.0.0, and support for -Scala 2.10 and versions of Hadoop before 2.6 are deprecated as of Spark 2.1.0, and may be -removed in Spark 2.2.0. +Note that support for Scala 2.10 is deprecated as of Spark 2.1.0, and may be removed in Spark 2.3.0. # Running the Examples and Shell From 0ca69c4ccf9cd5934d9c73d15c0224342385d333 Mon Sep 17 00:00:00 2001 From: guoxiaolong Date: Wed, 7 Jun 2017 10:18:40 +0100 Subject: [PATCH 090/133] [SPARK-20966][WEB-UI][SQL] Table data is not sorted by startTime time desc, time is not formatted and redundant code in JDBC/ODBC Server page. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? 1. Question 1 : Table data is not sorted by startTime time desc in JDBC/ODBC Server page. fix before : ![2](https://cloud.githubusercontent.com/assets/26266482/26718483/bf4a0fa8-47b3-11e7-9a27-dc6a67165b16.png) fix after : ![21](https://cloud.githubusercontent.com/assets/26266482/26718544/eb7376c8-47b3-11e7-9117-1bc68dfec92c.png) 2. Question 2 : time is not formatted in JDBC/ODBC Server page. fix before : ![1](https://cloud.githubusercontent.com/assets/26266482/26718573/0497d86a-47b4-11e7-945b-582aaa103949.png) fix after : ![11](https://cloud.githubusercontent.com/assets/26266482/26718602/21371ad0-47b4-11e7-9587-c5114d10ab2c.png) 3. Question 3 : Redundant code in the ThriftServerSessionPage.scala. The function of 'generateSessionStatsTable' has not been used ## How was this patch tested? manual tests Please review http://spark.apache.org/contributing.html before opening a pull request. Author: guoxiaolong Author: 郭小龙 10207633 Author: guoxiaolongzte Closes #18186 from guoxiaolongzte/SPARK-20966. --- .../thriftserver/ui/ThriftServerPage.scala | 4 +- .../ui/ThriftServerSessionPage.scala | 38 +------------------ 2 files changed, 3 insertions(+), 39 deletions(-) diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala index 2e0fa1ef77f88..17589cf44b998 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala @@ -72,7 +72,7 @@ private[ui] class ThriftServerPage(parent: ThriftServerTab) extends WebUIPage("" val table = if (numStatement > 0) { val headerRow = Seq("User", "JobID", "GroupID", "Start Time", "Finish Time", "Duration", "Statement", "State", "Detail") - val dataRows = listener.getExecutionList + val dataRows = listener.getExecutionList.sortBy(_.startTimestamp).reverse def generateDataRow(info: ExecutionInfo): Seq[Node] = { val jobLink = info.jobId.map { id: String => @@ -142,7 +142,7 @@ private[ui] class ThriftServerPage(parent: ThriftServerTab) extends WebUIPage("" val sessionList = listener.getSessionList val numBatches = sessionList.size val table = if (numBatches > 0) { - val dataRows = sessionList + val dataRows = sessionList.sortBy(_.startTimestamp).reverse val headerRow = Seq("User", "IP", "Session ID", "Start Time", "Finish Time", "Duration", "Total Execute") def generateDataRow(session: SessionInfo): Seq[Node] = { diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerSessionPage.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerSessionPage.scala index 38b8605745752..5cd2fdf6437c2 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerSessionPage.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerSessionPage.scala @@ -66,7 +66,7 @@ private[ui] class ThriftServerSessionPage(parent: ThriftServerTab) val timeSinceStart = System.currentTimeMillis() - startTime.getTime
  • - Started at: {startTime.toString} + Started at: {formatDate(startTime)}
  • Time since start: {formatDurationVerbose(timeSinceStart)} @@ -147,42 +147,6 @@ private[ui] class ThriftServerSessionPage(parent: ThriftServerTab) {errorSummary}{details} } - /** Generate stats of batch sessions of the thrift server program */ - private def generateSessionStatsTable(): Seq[Node] = { - val sessionList = listener.getSessionList - val numBatches = sessionList.size - val table = if (numBatches > 0) { - val dataRows = - sessionList.sortBy(_.startTimestamp).reverse.map ( session => - Seq( - session.userName, - session.ip, - session.sessionId, - formatDate(session.startTimestamp), - formatDate(session.finishTimestamp), - formatDurationOption(Some(session.totalTime)), - session.totalExecution.toString - ) - ).toSeq - val headerRow = Seq("User", "IP", "Session ID", "Start Time", "Finish Time", "Duration", - "Total Execute") - Some(listingTable(headerRow, dataRows)) - } else { - None - } - - val content = -
    Session Statistics
    ++ -
    -
      - {table.getOrElse("No statistics have been generated yet.")} -
    -
    - - content - } - - /** * Returns a human-readable string representing a duration such as "5 second 35 ms" */ From 847efe12656756f9ad6a4dc14bd183ac1a0760a6 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Thu, 8 Jun 2017 10:56:23 +0100 Subject: [PATCH 091/133] [SPARK-20914][DOCS] Javadoc contains code that is invalid ## What changes were proposed in this pull request? Fix Java, Scala Dataset examples in scaladoc, which didn't compile. ## How was this patch tested? Existing compilation/test Author: Sean Owen Closes #18215 from srowen/SPARK-20914. --- .../src/main/scala/org/apache/spark/sql/Dataset.scala | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 8abec85ee102a..f7637e005f317 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -131,7 +131,7 @@ private[sql] object Dataset { * * people.filter("age > 30") * .join(department, people("deptId") === department("id")) - * .groupBy(department("name"), "gender") + * .groupBy(department("name"), people("gender")) * .agg(avg(people("salary")), max(people("age"))) * }}} * @@ -141,9 +141,9 @@ private[sql] object Dataset { * Dataset people = spark.read().parquet("..."); * Dataset department = spark.read().parquet("..."); * - * people.filter("age".gt(30)) - * .join(department, people.col("deptId").equalTo(department("id"))) - * .groupBy(department.col("name"), "gender") + * people.filter(people.col("age").gt(30)) + * .join(department, people.col("deptId").equalTo(department.col("id"))) + * .groupBy(department.col("name"), people.col("gender")) * .agg(avg(people.col("salary")), max(people.col("age"))); * }}} * From 9be7945861a39974172826fc2d27ba8f5b8c3827 Mon Sep 17 00:00:00 2001 From: 10087686 Date: Thu, 8 Jun 2017 10:58:09 +0100 Subject: [PATCH 092/133] [SPARK-21006][TESTS] Create rpcEnv and run later needs shutdown and awaitTermination Signed-off-by: 10087686 ## What changes were proposed in this pull request? When run test("port conflict") case, we need run anotherEnv.shutdown() and anotherEnv.awaitTermination() for free resource. (Please fill in changes proposed in this fix) ## How was this patch tested? run RpcEnvSuit.scala Utest (Please explain how this patch was tested. E.g. unit tests, integration tests, manual tests) (If this patch involves UI changes, please attach a screenshot; otherwise, remove this) Please review http://spark.apache.org/contributing.html before opening a pull request. Author: 10087686 Closes #18226 from wangjiaochun/master. --- core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala index 31d9dd3de8acc..59d8c14d74e30 100644 --- a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala @@ -633,7 +633,12 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { test("port conflict") { val anotherEnv = createRpcEnv(new SparkConf(), "remote", env.address.port) - assert(anotherEnv.address.port != env.address.port) + try { + assert(anotherEnv.address.port != env.address.port) + } finally { + anotherEnv.shutdown() + anotherEnv.awaitTermination() + } } private def testSend(conf: SparkConf): Unit = { From b771fed73f68826c6247d900307edf56c53d6522 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Thu, 8 Jun 2017 11:14:42 +0100 Subject: [PATCH 093/133] [INFRA] Close stale PRs # What changes were proposed in this pull request? This PR proposes to close stale PRs, mostly the same instances with https://github.com/apache/spark/pull/18017 Closes #11459 Closes #13833 Closes #13720 Closes #12506 Closes #12456 Closes #12252 Closes #17689 Closes #17791 Closes #18163 Closes #17640 Closes #17926 Closes #18163 Closes #12506 Closes #18044 Closes #14036 Closes #15831 Closes #14461 Closes #17638 Closes #18222 Added: Closes #18045 Closes #18061 Closes #18010 Closes #18041 Closes #18124 Closes #18130 Closes #12217 Added: Closes #16291 Closes #17480 Closes #14995 Added: Closes #12835 Closes #17141 ## How was this patch tested? N/A Author: hyukjinkwon Closes #18223 from HyukjinKwon/close-stale-prs. From 55b8cfe6e6a6759d65bf219ff570fd6154197ec4 Mon Sep 17 00:00:00 2001 From: Mark Grover Date: Thu, 8 Jun 2017 09:55:43 -0700 Subject: [PATCH 094/133] [SPARK-19185][DSTREAM] Make Kafka consumer cache configurable ## What changes were proposed in this pull request? Add a new property `spark.streaming.kafka.consumer.cache.enabled` that allows users to enable or disable the cache for Kafka consumers. This property can be especially handy in cases where issues like SPARK-19185 get hit, for which there isn't a solution committed yet. By default, the cache is still on, so this change doesn't change any out-of-box behavior. ## How was this patch tested? Running unit tests Author: Mark Grover Author: Mark Grover Closes #18234 from markgrover/spark-19185. --- docs/streaming-kafka-0-10-integration.md | 4 +++- .../streaming/kafka010/DirectKafkaInputDStream.scala | 8 +++++--- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/docs/streaming-kafka-0-10-integration.md b/docs/streaming-kafka-0-10-integration.md index 92c296a9e6bd3..386066a85749f 100644 --- a/docs/streaming-kafka-0-10-integration.md +++ b/docs/streaming-kafka-0-10-integration.md @@ -91,7 +91,9 @@ The new Kafka consumer API will pre-fetch messages into buffers. Therefore it i In most cases, you should use `LocationStrategies.PreferConsistent` as shown above. This will distribute partitions evenly across available executors. If your executors are on the same hosts as your Kafka brokers, use `PreferBrokers`, which will prefer to schedule partitions on the Kafka leader for that partition. Finally, if you have a significant skew in load among partitions, use `PreferFixed`. This allows you to specify an explicit mapping of partitions to hosts (any unspecified partitions will use a consistent location). -The cache for consumers has a default maximum size of 64. If you expect to be handling more than (64 * number of executors) Kafka partitions, you can change this setting via `spark.streaming.kafka.consumer.cache.maxCapacity` +The cache for consumers has a default maximum size of 64. If you expect to be handling more than (64 * number of executors) Kafka partitions, you can change this setting via `spark.streaming.kafka.consumer.cache.maxCapacity`. + +If you would like to disable the caching for Kafka consumers, you can set `spark.streaming.kafka.consumer.cache.enabled` to `false`. Disabling the cache may be needed to workaround the problem described in SPARK-19185. This property may be removed in later versions of Spark, once SPARK-19185 is resolved. The cache is keyed by topicpartition and group.id, so use a **separate** `group.id` for each call to `createDirectStream`. diff --git a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala index 6d6983c4bd419..9a4a1cf32a480 100644 --- a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala +++ b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala @@ -213,8 +213,10 @@ private[spark] class DirectKafkaInputDStream[K, V]( val fo = currentOffsets(tp) OffsetRange(tp.topic, tp.partition, fo, uo) } - val rdd = new KafkaRDD[K, V]( - context.sparkContext, executorKafkaParams, offsetRanges.toArray, getPreferredHosts, true) + val useConsumerCache = context.conf.getBoolean("spark.streaming.kafka.consumer.cache.enabled", + true) + val rdd = new KafkaRDD[K, V](context.sparkContext, executorKafkaParams, offsetRanges.toArray, + getPreferredHosts, useConsumerCache) // Report the record number and metadata of this batch interval to InputInfoTracker. val description = offsetRanges.filter { offsetRange => @@ -316,7 +318,7 @@ private[spark] class DirectKafkaInputDStream[K, V]( b.map(OffsetRange(_)), getPreferredHosts, // during restore, it's possible same partition will be consumed from multiple - // threads, so dont use cache + // threads, so do not use cache. false ) } From 1a527bde49753535e6b86c18751f50c19a55f0d0 Mon Sep 17 00:00:00 2001 From: Xiao Li Date: Thu, 8 Jun 2017 12:10:31 -0700 Subject: [PATCH 095/133] [SPARK-20976][SQL] Unify Error Messages for FAILFAST mode ### What changes were proposed in this pull request? Before 2.2, we indicate the job was terminated because of `FAILFAST` mode. ``` Malformed line in FAILFAST mode: {"a":{, b:3} ``` If possible, we should keep it. This PR is to unify the error messages. ### How was this patch tested? Modified the existing messages. Author: Xiao Li Closes #18196 from gatorsmile/messFailFast. --- .../sql/catalyst/json/JacksonParser.scala | 2 +- .../datasources/FailureSafeParser.scala | 4 +++- .../datasources/json/JsonInferSchema.scala | 9 ++++++--- .../datasources/json/JsonSuite.scala | 20 ++++++++++--------- 4 files changed, 21 insertions(+), 14 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala index 4ed6728994193..bd144c9575c72 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala @@ -278,7 +278,7 @@ class JacksonParser( // We cannot parse this token based on the given data type. So, we throw a // RuntimeException and this exception will be caught by `parse` method. throw new RuntimeException( - s"Failed to parse a value for data type $dataType (current token: $token).") + s"Failed to parse a value for data type ${dataType.catalogString} (current token: $token).") } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FailureSafeParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FailureSafeParser.scala index 159aef220be15..43591a9ff524a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FailureSafeParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FailureSafeParser.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.execution.datasources +import org.apache.spark.SparkException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.GenericInternalRow import org.apache.spark.sql.catalyst.util._ @@ -65,7 +66,8 @@ class FailureSafeParser[IN]( case DropMalformedMode => Iterator.empty case FailFastMode => - throw e.cause + throw new SparkException("Malformed records are detected in record parsing. " + + s"Parse Mode: ${FailFastMode.name}.", e.cause) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala index fb632cf2bb70e..a270a6451d5dd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala @@ -21,6 +21,7 @@ import java.util.Comparator import com.fasterxml.jackson.core._ +import org.apache.spark.SparkException import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.analysis.TypeCoercion import org.apache.spark.sql.catalyst.json.JacksonUtils.nextUntil @@ -61,7 +62,8 @@ private[sql] object JsonInferSchema { case DropMalformedMode => None case FailFastMode => - throw e + throw new SparkException("Malformed records are detected in schema inference. " + + s"Parse Mode: ${FailFastMode.name}.", e) } } } @@ -231,8 +233,9 @@ private[sql] object JsonInferSchema { case FailFastMode => // If `other` is not struct type, consider it as malformed one and throws an exception. - throw new RuntimeException("Failed to infer a common schema. Struct types are expected" + - s" but ${other.catalogString} was found.") + throw new SparkException("Malformed records are detected in schema inference. " + + s"Parse Mode: ${FailFastMode.name}. Reasons: Failed to infer a common schema. " + + s"Struct types are expected, but `${other.catalogString}` was found.") } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index e66a60d7503f3..65472cda9c1c0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -1036,24 +1036,24 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { } test("Corrupt records: FAILFAST mode") { - val schema = StructType( - StructField("a", StringType, true) :: Nil) // `FAILFAST` mode should throw an exception for corrupt records. val exceptionOne = intercept[SparkException] { spark.read .option("mode", "FAILFAST") .json(corruptRecords) - } - assert(exceptionOne.getMessage.contains("JsonParseException")) + }.getMessage + assert(exceptionOne.contains( + "Malformed records are detected in schema inference. Parse Mode: FAILFAST.")) val exceptionTwo = intercept[SparkException] { spark.read .option("mode", "FAILFAST") - .schema(schema) + .schema("a string") .json(corruptRecords) .collect() - } - assert(exceptionTwo.getMessage.contains("JsonParseException")) + }.getMessage + assert(exceptionTwo.contains( + "Malformed records are detected in record parsing. Parse Mode: FAILFAST.")) } test("Corrupt records: DROPMALFORMED mode") { @@ -1944,7 +1944,8 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { .option("mode", "FAILFAST") .json(path) } - assert(exceptionOne.getMessage.contains("Failed to infer a common schema")) + assert(exceptionOne.getMessage.contains("Malformed records are detected in schema " + + "inference. Parse Mode: FAILFAST.")) val exceptionTwo = intercept[SparkException] { spark.read @@ -1954,7 +1955,8 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { .json(path) .collect() } - assert(exceptionTwo.getMessage.contains("Failed to parse a value")) + assert(exceptionTwo.getMessage.contains("Malformed records are detected in record " + + "parsing. Parse Mode: FAILFAST.")) } } From 6e95897e8806b359e82841eb8de20146fed4f3f9 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Thu, 8 Jun 2017 16:46:56 -0700 Subject: [PATCH 096/133] [SPARK-20954][SQL] `DESCRIBE [EXTENDED]` result should be compatible with previous Spark ## What changes were proposed in this pull request? After [SPARK-20067](https://issues.apache.org/jira/browse/SPARK-20067), `DESCRIBE` and `DESCRIBE EXTENDED` shows the following result. This is incompatible with Spark 2.1.1. This PR removes the column header line in case of those command. **MASTER** and **BRANCH-2.2** ```scala scala> sql("desc t").show(false) +----------+---------+-------+ |col_name |data_type|comment| +----------+---------+-------+ |# col_name|data_type|comment| |a |int |null | +----------+---------+-------+ ``` **SPARK 2.1.1** and **this PR** ```scala scala> sql("desc t").show(false) +--------+---------+-------+ |col_name|data_type|comment| +--------+---------+-------+ |a |int |null | +--------+---------+-------+ ``` ## How was this patch tested? Pass the Jenkins with the updated test suites. Author: Dongjoon Hyun Closes #18203 from dongjoon-hyun/SPARK-20954. --- .../spark/sql/execution/command/tables.scala | 17 +++++++++++------ .../sql-tests/results/change-column.sql.out | 9 --------- .../describe-table-after-alter-table.sql.out | 5 ----- .../sql-tests/results/describe.sql.out | 17 ----------------- .../spark/sql/hive/execution/HiveDDLSuite.scala | 2 +- 5 files changed, 12 insertions(+), 38 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala index 9ccd6792e5da4..b937a8a9f375b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala @@ -522,15 +522,15 @@ case class DescribeTableCommand( throw new AnalysisException( s"DESC PARTITION is not allowed on a temporary view: ${table.identifier}") } - describeSchema(catalog.lookupRelation(table).schema, result) + describeSchema(catalog.lookupRelation(table).schema, result, header = false) } else { val metadata = catalog.getTableMetadata(table) if (metadata.schema.isEmpty) { // In older version(prior to 2.1) of Spark, the table schema can be empty and should be // inferred at runtime. We should still support it. - describeSchema(sparkSession.table(metadata.identifier).schema, result) + describeSchema(sparkSession.table(metadata.identifier).schema, result, header = false) } else { - describeSchema(metadata.schema, result) + describeSchema(metadata.schema, result, header = false) } describePartitionInfo(metadata, result) @@ -550,7 +550,7 @@ case class DescribeTableCommand( private def describePartitionInfo(table: CatalogTable, buffer: ArrayBuffer[Row]): Unit = { if (table.partitionColumnNames.nonEmpty) { append(buffer, "# Partition Information", "", "") - describeSchema(table.partitionSchema, buffer) + describeSchema(table.partitionSchema, buffer, header = true) } } @@ -601,8 +601,13 @@ case class DescribeTableCommand( table.storage.toLinkedHashMap.foreach(s => append(buffer, s._1, s._2, "")) } - private def describeSchema(schema: StructType, buffer: ArrayBuffer[Row]): Unit = { - append(buffer, s"# ${output.head.name}", output(1).name, output(2).name) + private def describeSchema( + schema: StructType, + buffer: ArrayBuffer[Row], + header: Boolean): Unit = { + if (header) { + append(buffer, s"# ${output.head.name}", output(1).name, output(2).name) + } schema.foreach { column => append(buffer, column.name, column.dataType.simpleString, column.getComment().orNull) } diff --git a/sql/core/src/test/resources/sql-tests/results/change-column.sql.out b/sql/core/src/test/resources/sql-tests/results/change-column.sql.out index 678a3f0f0a3c6..ba8bc936f0c79 100644 --- a/sql/core/src/test/resources/sql-tests/results/change-column.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/change-column.sql.out @@ -15,7 +15,6 @@ DESC test_change -- !query 1 schema struct -- !query 1 output -# col_name data_type comment a int b string c int @@ -35,7 +34,6 @@ DESC test_change -- !query 3 schema struct -- !query 3 output -# col_name data_type comment a int b string c int @@ -55,7 +53,6 @@ DESC test_change -- !query 5 schema struct -- !query 5 output -# col_name data_type comment a int b string c int @@ -94,7 +91,6 @@ DESC test_change -- !query 8 schema struct -- !query 8 output -# col_name data_type comment a int b string c int @@ -129,7 +125,6 @@ DESC test_change -- !query 12 schema struct -- !query 12 output -# col_name data_type comment a int this is column a b string #*02?` c int @@ -148,7 +143,6 @@ DESC test_change -- !query 14 schema struct -- !query 14 output -# col_name data_type comment a int this is column a b string #*02?` c int @@ -168,7 +162,6 @@ DESC test_change -- !query 16 schema struct -- !query 16 output -# col_name data_type comment a int this is column a b string #*02?` c int @@ -193,7 +186,6 @@ DESC test_change -- !query 18 schema struct -- !query 18 output -# col_name data_type comment a int this is column a b string #*02?` c int @@ -237,7 +229,6 @@ DESC test_change -- !query 23 schema struct -- !query 23 output -# col_name data_type comment a int this is column A b string #*02?` c int diff --git a/sql/core/src/test/resources/sql-tests/results/describe-table-after-alter-table.sql.out b/sql/core/src/test/resources/sql-tests/results/describe-table-after-alter-table.sql.out index 1cc11c475bc40..eece00d603db4 100644 --- a/sql/core/src/test/resources/sql-tests/results/describe-table-after-alter-table.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/describe-table-after-alter-table.sql.out @@ -15,7 +15,6 @@ DESC FORMATTED table_with_comment -- !query 1 schema struct -- !query 1 output -# col_name data_type comment a string b int c string @@ -45,7 +44,6 @@ DESC FORMATTED table_with_comment -- !query 3 schema struct -- !query 3 output -# col_name data_type comment a string b int c string @@ -84,7 +82,6 @@ DESC FORMATTED table_comment -- !query 6 schema struct -- !query 6 output -# col_name data_type comment a string b int @@ -111,7 +108,6 @@ DESC formatted table_comment -- !query 8 schema struct -- !query 8 output -# col_name data_type comment a string b int @@ -139,7 +135,6 @@ DESC FORMATTED table_comment -- !query 10 schema struct -- !query 10 output -# col_name data_type comment a string b int diff --git a/sql/core/src/test/resources/sql-tests/results/describe.sql.out b/sql/core/src/test/resources/sql-tests/results/describe.sql.out index de10b29f3c65b..46d32bbc52247 100644 --- a/sql/core/src/test/resources/sql-tests/results/describe.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/describe.sql.out @@ -54,7 +54,6 @@ DESCRIBE t -- !query 5 schema struct -- !query 5 output -# col_name data_type comment a string b int c string @@ -70,7 +69,6 @@ DESC default.t -- !query 6 schema struct -- !query 6 output -# col_name data_type comment a string b int c string @@ -86,7 +84,6 @@ DESC TABLE t -- !query 7 schema struct -- !query 7 output -# col_name data_type comment a string b int c string @@ -102,7 +99,6 @@ DESC FORMATTED t -- !query 8 schema struct -- !query 8 output -# col_name data_type comment a string b int c string @@ -132,7 +128,6 @@ DESC EXTENDED t -- !query 9 schema struct -- !query 9 output -# col_name data_type comment a string b int c string @@ -162,7 +157,6 @@ DESC t PARTITION (c='Us', d=1) -- !query 10 schema struct -- !query 10 output -# col_name data_type comment a string b int c string @@ -178,7 +172,6 @@ DESC EXTENDED t PARTITION (c='Us', d=1) -- !query 11 schema struct -- !query 11 output -# col_name data_type comment a string b int c string @@ -206,7 +199,6 @@ DESC FORMATTED t PARTITION (c='Us', d=1) -- !query 12 schema struct -- !query 12 output -# col_name data_type comment a string b int c string @@ -268,7 +260,6 @@ DESC temp_v -- !query 16 schema struct -- !query 16 output -# col_name data_type comment a string b int c string @@ -280,7 +271,6 @@ DESC TABLE temp_v -- !query 17 schema struct -- !query 17 output -# col_name data_type comment a string b int c string @@ -292,7 +282,6 @@ DESC FORMATTED temp_v -- !query 18 schema struct -- !query 18 output -# col_name data_type comment a string b int c string @@ -304,7 +293,6 @@ DESC EXTENDED temp_v -- !query 19 schema struct -- !query 19 output -# col_name data_type comment a string b int c string @@ -316,7 +304,6 @@ DESC temp_Data_Source_View -- !query 20 schema struct -- !query 20 output -# col_name data_type comment intType int test comment test1 stringType string dateType date @@ -349,7 +336,6 @@ DESC v -- !query 22 schema struct -- !query 22 output -# col_name data_type comment a string b int c string @@ -361,7 +347,6 @@ DESC TABLE v -- !query 23 schema struct -- !query 23 output -# col_name data_type comment a string b int c string @@ -373,7 +358,6 @@ DESC FORMATTED v -- !query 24 schema struct -- !query 24 output -# col_name data_type comment a string b int c string @@ -396,7 +380,6 @@ DESC EXTENDED v -- !query 25 schema struct -- !query 25 output -# col_name data_type comment a string b int c string diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala index ab931b94987d3..aca964907d4cd 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala @@ -806,7 +806,7 @@ class HiveDDLSuite checkAnswer( sql(s"DESC $tabName").select("col_name", "data_type", "comment"), - Row("# col_name", "data_type", "comment") :: Row("a", "int", "test") :: Nil + Row("a", "int", "test") :: Nil ) } } From 2a23cdd078a7409d0bb92cf27718995766c41b1d Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Thu, 8 Jun 2017 18:08:25 -0700 Subject: [PATCH 097/133] [SPARK-20863] Add metrics/instrumentation to LiveListenerBus ## What changes were proposed in this pull request? This patch adds Coda Hale metrics for instrumenting the `LiveListenerBus` in order to track the number of events received, dropped, and processed. In addition, it adds per-SparkListener-subclass timers to track message processing time. This is useful for identifying when slow third-party SparkListeners cause performance bottlenecks. See the new `LiveListenerBusMetrics` for a complete description of the new metrics. ## How was this patch tested? New tests in SparkListenerSuite, including a test to ensure proper counting of dropped listener events. Author: Josh Rosen Closes #18083 from JoshRosen/listener-bus-metrics. --- .../scala/org/apache/spark/SparkContext.scala | 7 +- .../spark/internal/config/package.scala | 6 ++ .../spark/scheduler/LiveListenerBus.scala | 101 ++++++++++++++++-- .../org/apache/spark/util/ListenerBus.scala | 33 +++++- .../scheduler/EventLoggingListenerSuite.scala | 7 +- .../spark/scheduler/SparkListenerSuite.scala | 93 +++++++++++++--- .../BlockManagerReplicationSuite.scala | 2 +- .../spark/storage/BlockManagerSuite.scala | 2 +- .../spark/ui/storage/StorageTabSuite.scala | 5 +- .../streaming/ReceivedBlockHandlerSuite.scala | 4 +- 10 files changed, 220 insertions(+), 40 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 1a2443f7ee78d..b2a26c51d4de1 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -195,6 +195,7 @@ class SparkContext(config: SparkConf) extends Logging { private var _conf: SparkConf = _ private var _eventLogDir: Option[URI] = None private var _eventLogCodec: Option[String] = None + private var _listenerBus: LiveListenerBus = _ private var _env: SparkEnv = _ private var _jobProgressListener: JobProgressListener = _ private var _statusTracker: SparkStatusTracker = _ @@ -247,7 +248,7 @@ class SparkContext(config: SparkConf) extends Logging { def isStopped: Boolean = stopped.get() // An asynchronous listener bus for Spark events - private[spark] val listenerBus = new LiveListenerBus(this) + private[spark] def listenerBus: LiveListenerBus = _listenerBus // This function allows components created by SparkEnv to be mocked in unit tests: private[spark] def createSparkEnv( @@ -423,6 +424,8 @@ class SparkContext(config: SparkConf) extends Logging { if (master == "yarn" && deployMode == "client") System.setProperty("SPARK_YARN_MODE", "true") + _listenerBus = new LiveListenerBus(_conf) + // "_jobProgressListener" should be set up before creating SparkEnv because when creating // "SparkEnv", some messages will be posted to "listenerBus" and we should not miss them. _jobProgressListener = new JobProgressListener(_conf) @@ -2388,7 +2391,7 @@ class SparkContext(config: SparkConf) extends Logging { } } - listenerBus.start() + listenerBus.start(this, _env.metricsSystem) _listenerBusStarted = true } diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 4ad04b04c312d..7827e6760f355 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -158,6 +158,12 @@ package object config { .checkValue(_ > 0, "The capacity of listener bus event queue must not be negative") .createWithDefault(10000) + private[spark] val LISTENER_BUS_METRICS_MAX_LISTENER_CLASSES_TIMED = + ConfigBuilder("spark.scheduler.listenerbus.metrics.maxListenerClassesTimed") + .internal() + .intConf + .createWithDefault(128) + // This property sets the root namespace for metrics reporting private[spark] val METRICS_NAMESPACE = ConfigBuilder("spark.metrics.namespace") .stringConf diff --git a/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala index 801dfaa62306a..f0887e090b956 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala @@ -20,10 +20,16 @@ package org.apache.spark.scheduler import java.util.concurrent._ import java.util.concurrent.atomic.{AtomicBoolean, AtomicLong} +import scala.collection.mutable import scala.util.DynamicVariable -import org.apache.spark.SparkContext +import com.codahale.metrics.{Counter, Gauge, MetricRegistry, Timer} + +import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ +import org.apache.spark.metrics.MetricsSystem +import org.apache.spark.metrics.source.Source import org.apache.spark.util.Utils /** @@ -33,15 +39,20 @@ import org.apache.spark.util.Utils * has started will events be actually propagated to all attached listeners. This listener bus * is stopped when `stop()` is called, and it will drop further events after stopping. */ -private[spark] class LiveListenerBus(val sparkContext: SparkContext) extends SparkListenerBus { +private[spark] class LiveListenerBus(conf: SparkConf) extends SparkListenerBus { + self => import LiveListenerBus._ + private var sparkContext: SparkContext = _ + // Cap the capacity of the event queue so we get an explicit error (rather than // an OOM exception) if it's perpetually being added to more quickly than it's being drained. - private lazy val eventQueue = new LinkedBlockingQueue[SparkListenerEvent]( - sparkContext.conf.get(LISTENER_BUS_EVENT_QUEUE_CAPACITY)) + private val eventQueue = + new LinkedBlockingQueue[SparkListenerEvent](conf.get(LISTENER_BUS_EVENT_QUEUE_CAPACITY)) + + private[spark] val metrics = new LiveListenerBusMetrics(conf, eventQueue) // Indicate if `start()` is called private val started = new AtomicBoolean(false) @@ -67,6 +78,7 @@ private[spark] class LiveListenerBus(val sparkContext: SparkContext) extends Spa setDaemon(true) override def run(): Unit = Utils.tryOrStopSparkContext(sparkContext) { LiveListenerBus.withinListenerThread.withValue(true) { + val timer = metrics.eventProcessingTime while (true) { eventLock.acquire() self.synchronized { @@ -82,7 +94,12 @@ private[spark] class LiveListenerBus(val sparkContext: SparkContext) extends Spa } return } - postToAll(event) + val timerContext = timer.time() + try { + postToAll(event) + } finally { + timerContext.stop() + } } finally { self.synchronized { processingEvent = false @@ -93,6 +110,10 @@ private[spark] class LiveListenerBus(val sparkContext: SparkContext) extends Spa } } + override protected def getTimer(listener: SparkListenerInterface): Option[Timer] = { + metrics.getTimerForListenerClass(listener.getClass.asSubclass(classOf[SparkListenerInterface])) + } + /** * Start sending events to attached listeners. * @@ -100,9 +121,12 @@ private[spark] class LiveListenerBus(val sparkContext: SparkContext) extends Spa * listens for any additional events asynchronously while the listener bus is still running. * This should only be called once. * + * @param sc Used to stop the SparkContext in case the listener thread dies. */ - def start(): Unit = { + def start(sc: SparkContext, metricsSystem: MetricsSystem): Unit = { if (started.compareAndSet(false, true)) { + sparkContext = sc + metricsSystem.registerSource(metrics) listenerThread.start() } else { throw new IllegalStateException(s"$name already started!") @@ -115,12 +139,12 @@ private[spark] class LiveListenerBus(val sparkContext: SparkContext) extends Spa logError(s"$name has already stopped! Dropping event $event") return } + metrics.numEventsPosted.inc() val eventAdded = eventQueue.offer(event) if (eventAdded) { eventLock.release() } else { onDropEvent(event) - droppedEventsCounter.incrementAndGet() } val droppedEvents = droppedEventsCounter.get @@ -200,6 +224,8 @@ private[spark] class LiveListenerBus(val sparkContext: SparkContext) extends Spa * Note: `onDropEvent` can be called in any thread. */ def onDropEvent(event: SparkListenerEvent): Unit = { + metrics.numDroppedEvents.inc() + droppedEventsCounter.incrementAndGet() if (logDroppedEvent.compareAndSet(false, true)) { // Only log the following message once to avoid duplicated annoying logs. logError("Dropping SparkListenerEvent because no remaining room in event queue. " + @@ -217,3 +243,64 @@ private[spark] object LiveListenerBus { val name = "SparkListenerBus" } +private[spark] class LiveListenerBusMetrics( + conf: SparkConf, + queue: LinkedBlockingQueue[_]) + extends Source with Logging { + + override val sourceName: String = "LiveListenerBus" + override val metricRegistry: MetricRegistry = new MetricRegistry + + /** + * The total number of events posted to the LiveListenerBus. This is a count of the total number + * of events which have been produced by the application and sent to the listener bus, NOT a + * count of the number of events which have been processed and delivered to listeners (or dropped + * without being delivered). + */ + val numEventsPosted: Counter = metricRegistry.counter(MetricRegistry.name("numEventsPosted")) + + /** + * The total number of events that were dropped without being delivered to listeners. + */ + val numDroppedEvents: Counter = metricRegistry.counter(MetricRegistry.name("numEventsDropped")) + + /** + * The amount of time taken to post a single event to all listeners. + */ + val eventProcessingTime: Timer = metricRegistry.timer(MetricRegistry.name("eventProcessingTime")) + + /** + * The number of messages waiting in the queue. + */ + val queueSize: Gauge[Int] = { + metricRegistry.register(MetricRegistry.name("queueSize"), new Gauge[Int]{ + override def getValue: Int = queue.size() + }) + } + + // Guarded by synchronization. + private val perListenerClassTimers = mutable.Map[String, Timer]() + + /** + * Returns a timer tracking the processing time of the given listener class. + * events processed by that listener. This method is thread-safe. + */ + def getTimerForListenerClass(cls: Class[_ <: SparkListenerInterface]): Option[Timer] = { + synchronized { + val className = cls.getName + val maxTimed = conf.get(LISTENER_BUS_METRICS_MAX_LISTENER_CLASSES_TIMED) + perListenerClassTimers.get(className).orElse { + if (perListenerClassTimers.size == maxTimed) { + logError(s"Not measuring processing time for listener class $className because a " + + s"maximum of $maxTimed listener classes are already timed.") + None + } else { + perListenerClassTimers(className) = + metricRegistry.timer(MetricRegistry.name("listenerProcessingTime", className)) + perListenerClassTimers.get(className) + } + } + } + } +} + diff --git a/core/src/main/scala/org/apache/spark/util/ListenerBus.scala b/core/src/main/scala/org/apache/spark/util/ListenerBus.scala index fa5ad4e8d81e1..76a56298aaebc 100644 --- a/core/src/main/scala/org/apache/spark/util/ListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/util/ListenerBus.scala @@ -23,6 +23,8 @@ import scala.collection.JavaConverters._ import scala.reflect.ClassTag import scala.util.control.NonFatal +import com.codahale.metrics.Timer + import org.apache.spark.internal.Logging /** @@ -30,14 +32,22 @@ import org.apache.spark.internal.Logging */ private[spark] trait ListenerBus[L <: AnyRef, E] extends Logging { + private[this] val listenersPlusTimers = new CopyOnWriteArrayList[(L, Option[Timer])] + // Marked `private[spark]` for access in tests. - private[spark] val listeners = new CopyOnWriteArrayList[L] + private[spark] def listeners = listenersPlusTimers.asScala.map(_._1).asJava + + /** + * Returns a CodaHale metrics Timer for measuring the listener's event processing time. + * This method is intended to be overridden by subclasses. + */ + protected def getTimer(listener: L): Option[Timer] = None /** * Add a listener to listen events. This method is thread-safe and can be called in any thread. */ final def addListener(listener: L): Unit = { - listeners.add(listener) + listenersPlusTimers.add((listener, getTimer(listener))) } /** @@ -45,7 +55,9 @@ private[spark] trait ListenerBus[L <: AnyRef, E] extends Logging { * in any thread. */ final def removeListener(listener: L): Unit = { - listeners.remove(listener) + listenersPlusTimers.asScala.find(_._1 eq listener).foreach { listenerAndTimer => + listenersPlusTimers.remove(listenerAndTimer) + } } /** @@ -56,14 +68,25 @@ private[spark] trait ListenerBus[L <: AnyRef, E] extends Logging { // JavaConverters can create a JIterableWrapper if we use asScala. // However, this method will be called frequently. To avoid the wrapper cost, here we use // Java Iterator directly. - val iter = listeners.iterator + val iter = listenersPlusTimers.iterator while (iter.hasNext) { - val listener = iter.next() + val listenerAndMaybeTimer = iter.next() + val listener = listenerAndMaybeTimer._1 + val maybeTimer = listenerAndMaybeTimer._2 + val maybeTimerContext = if (maybeTimer.isDefined) { + maybeTimer.get.time() + } else { + null + } try { doPostEvent(listener, event) } catch { case NonFatal(e) => logError(s"Listener ${Utils.getFormattedClassName(listener)} threw an exception", e) + } finally { + if (maybeTimerContext != null) { + maybeTimerContext.stop() + } } } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala index 4c3d0b102152c..4cae6c61118a8 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala @@ -25,12 +25,14 @@ import scala.io.Source import org.apache.hadoop.fs.Path import org.json4s.jackson.JsonMethods._ +import org.mockito.Mockito import org.scalatest.BeforeAndAfter import org.apache.spark._ import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging import org.apache.spark.io._ +import org.apache.spark.metrics.MetricsSystem import org.apache.spark.util.{JsonProtocol, Utils} /** @@ -155,17 +157,18 @@ class EventLoggingListenerSuite extends SparkFunSuite with LocalSparkContext wit extraConf.foreach { case (k, v) => conf.set(k, v) } val logName = compressionCodec.map("test-" + _).getOrElse("test") val eventLogger = new EventLoggingListener(logName, None, testDirPath.toUri(), conf) - val listenerBus = new LiveListenerBus(sc) + val listenerBus = new LiveListenerBus(conf) val applicationStart = SparkListenerApplicationStart("Greatest App (N)ever", None, 125L, "Mickey", None) val applicationEnd = SparkListenerApplicationEnd(1000L) // A comprehensive test on JSON de/serialization of all events is in JsonProtocolSuite eventLogger.start() - listenerBus.start() + listenerBus.start(Mockito.mock(classOf[SparkContext]), Mockito.mock(classOf[MetricsSystem])) listenerBus.addListener(eventLogger) listenerBus.postToAll(applicationStart) listenerBus.postToAll(applicationEnd) + listenerBus.stop() eventLogger.stop() // Verify file contains exactly the two events logged diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala index 80c7e0bfee6ef..f3d0bc19675fc 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala @@ -22,10 +22,13 @@ import java.util.concurrent.Semaphore import scala.collection.mutable import scala.collection.JavaConverters._ +import org.mockito.Mockito import org.scalatest.Matchers import org.apache.spark._ import org.apache.spark.executor.TaskMetrics +import org.apache.spark.internal.config.LISTENER_BUS_EVENT_QUEUE_CAPACITY +import org.apache.spark.metrics.MetricsSystem import org.apache.spark.util.{ResetSystemProperties, RpcUtils} class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Matchers @@ -36,14 +39,17 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match val jobCompletionTime = 1421191296660L + private val mockSparkContext: SparkContext = Mockito.mock(classOf[SparkContext]) + private val mockMetricsSystem: MetricsSystem = Mockito.mock(classOf[MetricsSystem]) + test("don't call sc.stop in listener") { sc = new SparkContext("local", "SparkListenerSuite", new SparkConf()) val listener = new SparkContextStoppingListener(sc) - val bus = new LiveListenerBus(sc) + val bus = new LiveListenerBus(sc.conf) bus.addListener(listener) // Starting listener bus should flush all buffered events - bus.start() + bus.start(sc, sc.env.metricsSystem) bus.post(SparkListenerJobEnd(0, jobCompletionTime, JobSucceeded)) bus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) @@ -52,35 +58,54 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match } test("basic creation and shutdown of LiveListenerBus") { - sc = new SparkContext("local", "SparkListenerSuite", new SparkConf()) + val conf = new SparkConf() val counter = new BasicJobCounter - val bus = new LiveListenerBus(sc) + val bus = new LiveListenerBus(conf) bus.addListener(counter) - // Listener bus hasn't started yet, so posting events should not increment counter + // Metrics are initially empty. + assert(bus.metrics.numEventsPosted.getCount === 0) + assert(bus.metrics.numDroppedEvents.getCount === 0) + assert(bus.metrics.queueSize.getValue === 0) + assert(bus.metrics.eventProcessingTime.getCount === 0) + + // Post five events: (1 to 5).foreach { _ => bus.post(SparkListenerJobEnd(0, jobCompletionTime, JobSucceeded)) } + + // Five messages should be marked as received and queued, but no messages should be posted to + // listeners yet because the the listener bus hasn't been started. + assert(bus.metrics.numEventsPosted.getCount === 5) + assert(bus.metrics.queueSize.getValue === 5) assert(counter.count === 0) // Starting listener bus should flush all buffered events - bus.start() + bus.start(mockSparkContext, mockMetricsSystem) + Mockito.verify(mockMetricsSystem).registerSource(bus.metrics) bus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) assert(counter.count === 5) + assert(bus.metrics.queueSize.getValue === 0) + assert(bus.metrics.eventProcessingTime.getCount === 5) // After listener bus has stopped, posting events should not increment counter bus.stop() (1 to 5).foreach { _ => bus.post(SparkListenerJobEnd(0, jobCompletionTime, JobSucceeded)) } assert(counter.count === 5) + assert(bus.metrics.numEventsPosted.getCount === 5) + + // Make sure per-listener-class timers were created: + assert(bus.metrics.getTimerForListenerClass( + classOf[BasicJobCounter].asSubclass(classOf[SparkListenerInterface])).get.getCount == 5) // Listener bus must not be started twice intercept[IllegalStateException] { - val bus = new LiveListenerBus(sc) - bus.start() - bus.start() + val bus = new LiveListenerBus(conf) + bus.start(mockSparkContext, mockMetricsSystem) + bus.start(mockSparkContext, mockMetricsSystem) } // ... or stopped before starting intercept[IllegalStateException] { - val bus = new LiveListenerBus(sc) + val bus = new LiveListenerBus(conf) bus.stop() } } @@ -107,12 +132,11 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match drained = true } } - sc = new SparkContext("local", "SparkListenerSuite", new SparkConf()) - val bus = new LiveListenerBus(sc) + val bus = new LiveListenerBus(new SparkConf()) val blockingListener = new BlockingListener bus.addListener(blockingListener) - bus.start() + bus.start(mockSparkContext, mockMetricsSystem) bus.post(SparkListenerJobEnd(0, jobCompletionTime, JobSucceeded)) listenerStarted.acquire() @@ -138,6 +162,44 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match assert(drained) } + test("metrics for dropped listener events") { + val bus = new LiveListenerBus(new SparkConf().set(LISTENER_BUS_EVENT_QUEUE_CAPACITY, 1)) + + val listenerStarted = new Semaphore(0) + val listenerWait = new Semaphore(0) + + bus.addListener(new SparkListener { + override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = { + listenerStarted.release() + listenerWait.acquire() + } + }) + + bus.start(mockSparkContext, mockMetricsSystem) + + // Post a message to the listener bus and wait for processing to begin: + bus.post(SparkListenerJobEnd(0, jobCompletionTime, JobSucceeded)) + listenerStarted.acquire() + assert(bus.metrics.queueSize.getValue === 0) + assert(bus.metrics.numDroppedEvents.getCount === 0) + + // If we post an additional message then it should remain in the queue because the listener is + // busy processing the first event: + bus.post(SparkListenerJobEnd(0, jobCompletionTime, JobSucceeded)) + assert(bus.metrics.queueSize.getValue === 1) + assert(bus.metrics.numDroppedEvents.getCount === 0) + + // The queue is now full, so any additional events posted to the listener will be dropped: + bus.post(SparkListenerJobEnd(0, jobCompletionTime, JobSucceeded)) + assert(bus.metrics.queueSize.getValue === 1) + assert(bus.metrics.numDroppedEvents.getCount === 1) + + + // Allow the the remaining events to be processed so we can stop the listener bus: + listenerWait.release(2) + bus.stop() + } + test("basic creation of StageInfo") { sc = new SparkContext("local", "SparkListenerSuite") val listener = new SaveStageAndTaskInfo @@ -354,14 +416,13 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match val badListener = new BadListener val jobCounter1 = new BasicJobCounter val jobCounter2 = new BasicJobCounter - sc = new SparkContext("local", "SparkListenerSuite", new SparkConf()) - val bus = new LiveListenerBus(sc) + val bus = new LiveListenerBus(new SparkConf()) // Propagate events to bad listener first bus.addListener(badListener) bus.addListener(jobCounter1) bus.addListener(jobCounter2) - bus.start() + bus.start(mockSparkContext, mockMetricsSystem) // Post events to all listeners, and wait until the queue is drained (1 to 5).foreach { _ => bus.post(SparkListenerJobEnd(0, jobCompletionTime, JobSucceeded)) } diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala index c100803279eaf..dd61dcd11bcda 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala @@ -100,7 +100,7 @@ trait BlockManagerReplicationBehavior extends SparkFunSuite sc = new SparkContext("local", "test", conf) master = new BlockManagerMaster(rpcEnv.setupEndpoint("blockmanager", new BlockManagerMasterEndpoint(rpcEnv, true, conf, - new LiveListenerBus(sc))), conf, true) + new LiveListenerBus(conf))), conf, true) allStores.clear() } diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index 0d2912ba8c5fb..9d52b488b223e 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -125,7 +125,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE when(sc.conf).thenReturn(conf) master = new BlockManagerMaster(rpcEnv.setupEndpoint("blockmanager", new BlockManagerMasterEndpoint(rpcEnv, true, conf, - new LiveListenerBus(sc))), conf, true) + new LiveListenerBus(conf))), conf, true) val initialize = PrivateMethod[Unit]('initialize) SizeEstimator invokePrivate initialize() diff --git a/core/src/test/scala/org/apache/spark/ui/storage/StorageTabSuite.scala b/core/src/test/scala/org/apache/spark/ui/storage/StorageTabSuite.scala index f6c8418ba3ac4..66dda382eb653 100644 --- a/core/src/test/scala/org/apache/spark/ui/storage/StorageTabSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/storage/StorageTabSuite.scala @@ -26,7 +26,7 @@ import org.apache.spark.storage._ /** * Test various functionality in the StorageListener that supports the StorageTab. */ -class StorageTabSuite extends SparkFunSuite with LocalSparkContext with BeforeAndAfter { +class StorageTabSuite extends SparkFunSuite with BeforeAndAfter { private var bus: LiveListenerBus = _ private var storageStatusListener: StorageStatusListener = _ private var storageListener: StorageListener = _ @@ -43,8 +43,7 @@ class StorageTabSuite extends SparkFunSuite with LocalSparkContext with BeforeAn before { val conf = new SparkConf() - sc = new SparkContext("local", "test", conf) - bus = new LiveListenerBus(sc) + bus = new LiveListenerBus(conf) storageStatusListener = new StorageStatusListener(conf) storageListener = new StorageListener(storageStatusListener) bus.addListener(storageStatusListener) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala index 3c4a2716caf90..fe65353b9d502 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala @@ -50,7 +50,6 @@ abstract class BaseReceivedBlockHandlerSuite(enableEncryption: Boolean) extends SparkFunSuite with BeforeAndAfter with Matchers - with LocalSparkContext with Logging { import WriteAheadLogBasedBlockHandler._ @@ -89,10 +88,9 @@ abstract class BaseReceivedBlockHandlerSuite(enableEncryption: Boolean) rpcEnv = RpcEnv.create("test", "localhost", 0, conf, securityMgr) conf.set("spark.driver.port", rpcEnv.address.port.toString) - sc = new SparkContext("local", "test", conf) blockManagerMaster = new BlockManagerMaster(rpcEnv.setupEndpoint("blockmanager", new BlockManagerMasterEndpoint(rpcEnv, true, conf, - new LiveListenerBus(sc))), conf, true) + new LiveListenerBus(conf))), conf, true) storageLevel = StorageLevel.MEMORY_ONLY_SER blockManager = createBlockManager(blockManagerSize, conf) From 5a3371883acf8ac8f94a71cbffa75166605c91bc Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Fri, 9 Jun 2017 08:53:18 +0100 Subject: [PATCH 098/133] [SPARK-14408][CORE] Changed RDD.treeAggregate to use fold instead of reduce ## What changes were proposed in this pull request? Previously, `RDD.treeAggregate` used `reduceByKey` and `reduce` in its implementation, neither of which technically allows the `seq`/`combOps` to modify and return their first arguments. This PR uses `foldByKey` and `fold` instead and notes that `aggregate` and `treeAggregate` are semantically identical in the Scala doc. Note that this had some test failures by unknown reasons. This was actually fixed in https://github.com/apache/spark/commit/e3554605b36bdce63ac180cc66dbdee5c1528ec7. The root cause was, the `zeroValue` now becomes `AFTAggregator` and it compares `totalCnt` (where the value is actually 0). It starts merging one by one and it keeps returning `this` where `totalCnt` is 0. So, this looks not the bug in the current change. This is now fixed in the commit. So, this should pass the tests. ## How was this patch tested? Test case added in `RDDSuite`. Closes #12217 Author: Joseph K. Bradley Author: hyukjinkwon Closes #18198 from HyukjinKwon/SPARK-14408. --- .../main/scala/org/apache/spark/rdd/RDD.scala | 9 +++--- .../scala/org/apache/spark/rdd/RDDSuite.scala | 31 ++++++++++++++++++- 2 files changed, 35 insertions(+), 5 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 63a87e7f09d85..2985c90119468 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -1118,9 +1118,9 @@ abstract class RDD[T: ClassTag]( /** * Aggregates the elements of this RDD in a multi-level tree pattern. + * This method is semantically identical to [[org.apache.spark.rdd.RDD#aggregate]]. * * @param depth suggested depth of the tree (default: 2) - * @see [[org.apache.spark.rdd.RDD#aggregate]] */ def treeAggregate[U: ClassTag](zeroValue: U)( seqOp: (U, T) => U, @@ -1134,7 +1134,7 @@ abstract class RDD[T: ClassTag]( val cleanCombOp = context.clean(combOp) val aggregatePartition = (it: Iterator[T]) => it.aggregate(zeroValue)(cleanSeqOp, cleanCombOp) - var partiallyAggregated = mapPartitions(it => Iterator(aggregatePartition(it))) + var partiallyAggregated: RDD[U] = mapPartitions(it => Iterator(aggregatePartition(it))) var numPartitions = partiallyAggregated.partitions.length val scale = math.max(math.ceil(math.pow(numPartitions, 1.0 / depth)).toInt, 2) // If creating an extra level doesn't help reduce @@ -1146,9 +1146,10 @@ abstract class RDD[T: ClassTag]( val curNumPartitions = numPartitions partiallyAggregated = partiallyAggregated.mapPartitionsWithIndex { (i, iter) => iter.map((i % curNumPartitions, _)) - }.reduceByKey(new HashPartitioner(curNumPartitions), cleanCombOp).values + }.foldByKey(zeroValue, new HashPartitioner(curNumPartitions))(cleanCombOp).values } - partiallyAggregated.reduce(cleanCombOp) + val copiedZeroValue = Utils.clone(zeroValue, sc.env.closureSerializer.newInstance()) + partiallyAggregated.fold(copiedZeroValue)(cleanCombOp) } } diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala index 8d06f5468f4f1..386c0060f9c41 100644 --- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala @@ -192,6 +192,23 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext { assert(ser.serialize(union.partitions.head).limit() < 2000) } + test("fold") { + val rdd = sc.makeRDD(-1000 until 1000, 10) + def op: (Int, Int) => Int = (c: Int, x: Int) => c + x + val sum = rdd.fold(0)(op) + assert(sum === -1000) + } + + test("fold with op modifying first arg") { + val rdd = sc.makeRDD(-1000 until 1000, 10).map(x => Array(x)) + def op: (Array[Int], Array[Int]) => Array[Int] = { (c: Array[Int], x: Array[Int]) => + c(0) += x(0) + c + } + val sum = rdd.fold(Array(0))(op) + assert(sum(0) === -1000) + } + test("aggregate") { val pairs = sc.makeRDD(Array(("a", 1), ("b", 2), ("a", 2), ("c", 5), ("a", 3))) type StringMap = HashMap[String, Int] @@ -218,7 +235,19 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext { def combOp: (Long, Long) => Long = (c1: Long, c2: Long) => c1 + c2 for (depth <- 1 until 10) { val sum = rdd.treeAggregate(0L)(seqOp, combOp, depth) - assert(sum === -1000L) + assert(sum === -1000) + } + } + + test("treeAggregate with ops modifying first args") { + val rdd = sc.makeRDD(-1000 until 1000, 10).map(x => Array(x)) + def op: (Array[Int], Array[Int]) => Array[Int] = { (c: Array[Int], x: Array[Int]) => + c(0) += x(0) + c + } + for (depth <- 1 until 10) { + val sum = rdd.treeAggregate(Array(0))(op, op, depth) + assert(sum(0) === -1000) } } From bdcd6e4c680ebd3ddf5c1baaeba31134b143dfb4 Mon Sep 17 00:00:00 2001 From: guoxiaolong Date: Fri, 9 Jun 2017 09:26:30 +0100 Subject: [PATCH 099/133] [SPARK-20995][CORE] Spark-env.sh.template' should add 'YARN_CONF_DIR' configuration instructions. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? Ensure that `HADOOP_CONF_DIR` or `YARN_CONF_DIR` points to the directory which contains the (client side) configuration files for the Hadoop cluster. These configs are used to write to HDFS and connect to the YARN ResourceManager. The configuration contained in this directory will be distributed to the YARN cluster so that all containers used by the application use the same configuration. Sometimes, `HADOOP_CONF_DIR` is set to the hdfs configuration file path. So, YARN_CONF_DIR should be set to the yarn configuration file path. My project configuration item of 'spark-env.sh ' is as follows: ![1](https://cloud.githubusercontent.com/assets/26266482/26819987/d4acb814-4ad3-11e7-8458-a21aea57a53d.png) 'HADOOP_CONF_DIR' configuration file path. List the relevant documents below: ![3](https://cloud.githubusercontent.com/assets/26266482/26820116/47b6b9fe-4ad4-11e7-8131-fe07c8d8bc21.png) 'YARN_CONF_DIR' configuration file path. List the relevant documents below: ![2](https://cloud.githubusercontent.com/assets/26266482/26820078/274ad79a-4ad4-11e7-83d4-ff359dbb397c.png) So, 'Spark-env.sh.template' should add 'YARN_CONF_DIR' configuration instructions. ## How was this patch tested? manual tests Please review http://spark.apache.org/contributing.html before opening a pull request. Author: guoxiaolong Author: 郭小龙 10207633 Author: guoxiaolongzte Closes #18212 from guoxiaolongzte/SPARK-20995. --- conf/spark-env.sh.template | 1 + 1 file changed, 1 insertion(+) diff --git a/conf/spark-env.sh.template b/conf/spark-env.sh.template index b7c985ace69cf..b9aab5a3712c4 100755 --- a/conf/spark-env.sh.template +++ b/conf/spark-env.sh.template @@ -34,6 +34,7 @@ # Options read in YARN client mode # - HADOOP_CONF_DIR, to point Spark towards Hadoop configuration files +# - YARN_CONF_DIR, to point Spark towards YARN configuration files when you use YARN # - SPARK_EXECUTOR_CORES, Number of cores for the executors (Default: 1). # - SPARK_EXECUTOR_MEMORY, Memory per Executor (e.g. 1000M, 2G) (Default: 1G) # - SPARK_DRIVER_MEMORY, Memory for Driver (e.g. 1000M, 2G) (Default: 1G) From 033839559eab280760e3c5687940d8c62cc9c048 Mon Sep 17 00:00:00 2001 From: Corey Woodfield Date: Fri, 9 Jun 2017 10:24:49 +0100 Subject: [PATCH 100/133] Fixed broken link ## What changes were proposed in this pull request? I fixed some incorrect formatting on a link in the docs ## How was this patch tested? I looked at the markdown preview before and after, and the link was fixed Before: screen shot 2017-06-08 at 6 37 32 pm After: screen shot 2017-06-08 at 6 37 44 pm Author: Corey Woodfield Closes #18246 from coreywoodfield/master. --- docs/running-on-mesos.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/running-on-mesos.md b/docs/running-on-mesos.md index c1344ad99a7d2..8745e76d127ae 100644 --- a/docs/running-on-mesos.md +++ b/docs/running-on-mesos.md @@ -156,7 +156,7 @@ passing in the Mesos master URL (e.g: mesos://host:5050). This starts the `Mesos If you like to run the `MesosClusterDispatcher` with Marathon, you need to run the `MesosClusterDispatcher` in the foreground (i.e: `bin/spark-class org.apache.spark.deploy.mesos.MesosClusterDispatcher`). Note that the `MesosClusterDispatcher` not yet supports multiple instances for HA. The `MesosClusterDispatcher` also supports writing recovery state into Zookeeper. This will allow the `MesosClusterDispatcher` to be able to recover all submitted and running containers on relaunch. In order to enable this recovery mode, you can set SPARK_DAEMON_JAVA_OPTS in spark-env by configuring `spark.deploy.recoveryMode` and related spark.deploy.zookeeper.* configurations. -For more information about these configurations please refer to the configurations (doc)[configurations.html#deploy]. +For more information about these configurations please refer to the configurations [doc](configurations.html#deploy). From the client, you can submit a job to Mesos cluster by running `spark-submit` and specifying the master URL to the URL of the `MesosClusterDispatcher` (e.g: mesos://dispatcher:7077). You can view driver statuses on the From 6491cbf065254e28bca61c9ef55b84f4009ac36c Mon Sep 17 00:00:00 2001 From: junzhi lu <452756565@qq.com> Date: Fri, 9 Jun 2017 10:49:04 +0100 Subject: [PATCH 101/133] Fix bug in JavaRegressionMetricsExample. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit the original code cant visit the last element of the"parts" array. so the v[v.length–1] always equals 0 ## What changes were proposed in this pull request? change the recycle range from (1 to parts.length-1) to (1 to parts.length) ## How was this patch tested? debug it in eclipse (´〜`*) zzz. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: junzhi lu <452756565@qq.com> Closes #18237 from masterwugui/patch-1. --- .../spark/examples/mllib/JavaRegressionMetricsExample.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaRegressionMetricsExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaRegressionMetricsExample.java index 7bb9993b84168..00033b5730a3d 100644 --- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaRegressionMetricsExample.java +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaRegressionMetricsExample.java @@ -40,7 +40,7 @@ public static void main(String[] args) { JavaRDD parsedData = data.map(line -> { String[] parts = line.split(" "); double[] v = new double[parts.length - 1]; - for (int i = 1; i < parts.length - 1; i++) { + for (int i = 1; i < parts.length; i++) { v[i - 1] = Double.parseDouble(parts[i].split(":")[1]); } return new LabeledPoint(Double.parseDouble(parts[0]), Vectors.dense(v)); From 82faacd791d1d62bb8ac186a2a3290e160a20bd5 Mon Sep 17 00:00:00 2001 From: guoxiaolong Date: Fri, 9 Jun 2017 14:26:54 +0100 Subject: [PATCH 102/133] [SPARK-20997][CORE] driver-cores' standalone or Mesos or YARN in Cluster deploy mode only. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? '--driver-cores' standalone or Mesos or YARN in Cluster deploy mode only.So The description of spark-submit about it is not very accurate. ## How was this patch tested? manual tests Please review http://spark.apache.org/contributing.html before opening a pull request. Author: guoxiaolong Author: 郭小龙 10207633 Author: guoxiaolongzte Closes #18241 from guoxiaolongzte/SPARK-20997. --- .../org/apache/spark/deploy/SparkSubmitArguments.scala | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index b76a3d2bea4c7..3d9a14c51618b 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -558,8 +558,9 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S | --verbose, -v Print additional debug output. | --version, Print the version of current Spark. | - | Spark standalone with cluster deploy mode only: - | --driver-cores NUM Cores for driver (Default: 1). + | Cluster deploy mode only: + | --driver-cores NUM Number of cores used by the driver, only in cluster mode + | (Default: 1). | | Spark standalone or Mesos with cluster deploy mode only: | --supervise If given, restarts the driver on failure. @@ -574,8 +575,6 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S | or all available cores on the worker in standalone mode) | | YARN-only: - | --driver-cores NUM Number of cores used by the driver, only in cluster mode - | (Default: 1). | --queue QUEUE_NAME The YARN queue to submit to (Default: "default"). | --num-executors NUM Number of executors to launch (Default: 2). | If dynamic allocation is enabled, the initial number of From 571635488d6e16eee82f09ae0247c2f6ad5b7541 Mon Sep 17 00:00:00 2001 From: Xiao Li Date: Fri, 9 Jun 2017 10:16:30 -0700 Subject: [PATCH 103/133] [SPARK-20918][SQL] Use FunctionIdentifier as function identifiers in FunctionRegistry ### What changes were proposed in this pull request? Currently, the unquoted string of a function identifier is being used as the function identifier in the function registry. This could cause the incorrect the behavior when users use `.` in the function names. This PR is to take the `FunctionIdentifier` as the identifier in the function registry. - Add one new function `createOrReplaceTempFunction` to `FunctionRegistry` ```Scala final def createOrReplaceTempFunction(name: String, builder: FunctionBuilder): Unit ``` ### How was this patch tested? Add extra test cases to verify the inclusive bug fixes. Author: Xiao Li Author: gatorsmile Closes #18142 from gatorsmile/fuctionRegistry. --- .../catalyst/analysis/FunctionRegistry.scala | 97 +++++++++++------ .../sql/catalyst/catalog/SessionCatalog.scala | 37 ++++--- .../catalog/SessionCatalogSuite.scala | 2 +- .../apache/spark/sql/UDFRegistration.scala | 100 +++++++++--------- .../sql/execution/command/functions.scala | 2 +- .../spark/sql/GeneratorFunctionSuite.scala | 3 +- .../org/apache/spark/sql/SQLQuerySuite.scala | 2 +- .../apache/spark/sql/SessionStateSuite.scala | 9 +- .../python/BatchEvalPythonExecSuite.scala | 5 +- .../spark/sql/internal/CatalogSuite.scala | 4 + .../spark/sql/hive/HiveSessionCatalog.scala | 4 +- .../sql/hive/execution/HiveUDFSuite.scala | 13 ++- .../sql/hive/execution/SQLQuerySuite.scala | 2 +- 13 files changed, 162 insertions(+), 118 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 116b26f612e02..4245b70892d1c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -17,51 +17,68 @@ package org.apache.spark.sql.catalyst.analysis -import java.lang.reflect.Modifier +import java.util.Locale +import javax.annotation.concurrent.GuardedBy +import scala.collection.mutable import scala.language.existentials import scala.reflect.ClassTag import scala.util.{Failure, Success, Try} import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.FunctionIdentifier import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.expressions.xml._ -import org.apache.spark.sql.catalyst.util.StringKeyHashMap import org.apache.spark.sql.types._ /** * A catalog for looking up user defined functions, used by an [[Analyzer]]. * - * Note: The implementation should be thread-safe to allow concurrent access. + * Note: + * 1) The implementation should be thread-safe to allow concurrent access. + * 2) the database name is always case-sensitive here, callers are responsible to + * format the database name w.r.t. case-sensitive config. */ trait FunctionRegistry { - final def registerFunction(name: String, builder: FunctionBuilder): Unit = { - registerFunction(name, new ExpressionInfo(builder.getClass.getCanonicalName, name), builder) + final def registerFunction(name: FunctionIdentifier, builder: FunctionBuilder): Unit = { + val info = new ExpressionInfo( + builder.getClass.getCanonicalName, name.database.orNull, name.funcName) + registerFunction(name, info, builder) } - def registerFunction(name: String, info: ExpressionInfo, builder: FunctionBuilder): Unit + def registerFunction( + name: FunctionIdentifier, + info: ExpressionInfo, + builder: FunctionBuilder): Unit + + /* Create or replace a temporary function. */ + final def createOrReplaceTempFunction(name: String, builder: FunctionBuilder): Unit = { + registerFunction( + FunctionIdentifier(name), + builder) + } @throws[AnalysisException]("If function does not exist") - def lookupFunction(name: String, children: Seq[Expression]): Expression + def lookupFunction(name: FunctionIdentifier, children: Seq[Expression]): Expression /* List all of the registered function names. */ - def listFunction(): Seq[String] + def listFunction(): Seq[FunctionIdentifier] /* Get the class of the registered function by specified name. */ - def lookupFunction(name: String): Option[ExpressionInfo] + def lookupFunction(name: FunctionIdentifier): Option[ExpressionInfo] /* Get the builder of the registered function by specified name. */ - def lookupFunctionBuilder(name: String): Option[FunctionBuilder] + def lookupFunctionBuilder(name: FunctionIdentifier): Option[FunctionBuilder] /** Drop a function and return whether the function existed. */ - def dropFunction(name: String): Boolean + def dropFunction(name: FunctionIdentifier): Boolean /** Checks if a function with a given name exists. */ - def functionExists(name: String): Boolean = lookupFunction(name).isDefined + def functionExists(name: FunctionIdentifier): Boolean = lookupFunction(name).isDefined /** Clear all registered functions. */ def clear(): Unit @@ -72,39 +89,47 @@ trait FunctionRegistry { class SimpleFunctionRegistry extends FunctionRegistry { - protected val functionBuilders = - StringKeyHashMap[(ExpressionInfo, FunctionBuilder)](caseSensitive = false) + @GuardedBy("this") + private val functionBuilders = + new mutable.HashMap[FunctionIdentifier, (ExpressionInfo, FunctionBuilder)] + + // Resolution of the function name is always case insensitive, but the database name + // depends on the caller + private def normalizeFuncName(name: FunctionIdentifier): FunctionIdentifier = { + FunctionIdentifier(name.funcName.toLowerCase(Locale.ROOT), name.database) + } override def registerFunction( - name: String, + name: FunctionIdentifier, info: ExpressionInfo, builder: FunctionBuilder): Unit = synchronized { - functionBuilders.put(name, (info, builder)) + functionBuilders.put(normalizeFuncName(name), (info, builder)) } - override def lookupFunction(name: String, children: Seq[Expression]): Expression = { + override def lookupFunction(name: FunctionIdentifier, children: Seq[Expression]): Expression = { val func = synchronized { - functionBuilders.get(name).map(_._2).getOrElse { + functionBuilders.get(normalizeFuncName(name)).map(_._2).getOrElse { throw new AnalysisException(s"undefined function $name") } } func(children) } - override def listFunction(): Seq[String] = synchronized { - functionBuilders.iterator.map(_._1).toList.sorted + override def listFunction(): Seq[FunctionIdentifier] = synchronized { + functionBuilders.iterator.map(_._1).toList } - override def lookupFunction(name: String): Option[ExpressionInfo] = synchronized { - functionBuilders.get(name).map(_._1) + override def lookupFunction(name: FunctionIdentifier): Option[ExpressionInfo] = synchronized { + functionBuilders.get(normalizeFuncName(name)).map(_._1) } - override def lookupFunctionBuilder(name: String): Option[FunctionBuilder] = synchronized { - functionBuilders.get(name).map(_._2) + override def lookupFunctionBuilder( + name: FunctionIdentifier): Option[FunctionBuilder] = synchronized { + functionBuilders.get(normalizeFuncName(name)).map(_._2) } - override def dropFunction(name: String): Boolean = synchronized { - functionBuilders.remove(name).isDefined + override def dropFunction(name: FunctionIdentifier): Boolean = synchronized { + functionBuilders.remove(normalizeFuncName(name)).isDefined } override def clear(): Unit = synchronized { @@ -125,28 +150,28 @@ class SimpleFunctionRegistry extends FunctionRegistry { * functions are already filled in and the analyzer needs only to resolve attribute references. */ object EmptyFunctionRegistry extends FunctionRegistry { - override def registerFunction(name: String, info: ExpressionInfo, builder: FunctionBuilder) - : Unit = { + override def registerFunction( + name: FunctionIdentifier, info: ExpressionInfo, builder: FunctionBuilder): Unit = { throw new UnsupportedOperationException } - override def lookupFunction(name: String, children: Seq[Expression]): Expression = { + override def lookupFunction(name: FunctionIdentifier, children: Seq[Expression]): Expression = { throw new UnsupportedOperationException } - override def listFunction(): Seq[String] = { + override def listFunction(): Seq[FunctionIdentifier] = { throw new UnsupportedOperationException } - override def lookupFunction(name: String): Option[ExpressionInfo] = { + override def lookupFunction(name: FunctionIdentifier): Option[ExpressionInfo] = { throw new UnsupportedOperationException } - override def lookupFunctionBuilder(name: String): Option[FunctionBuilder] = { + override def lookupFunctionBuilder(name: FunctionIdentifier): Option[FunctionBuilder] = { throw new UnsupportedOperationException } - override def dropFunction(name: String): Boolean = { + override def dropFunction(name: FunctionIdentifier): Boolean = { throw new UnsupportedOperationException } @@ -455,11 +480,13 @@ object FunctionRegistry { val builtin: SimpleFunctionRegistry = { val fr = new SimpleFunctionRegistry - expressions.foreach { case (name, (info, builder)) => fr.registerFunction(name, info, builder) } + expressions.foreach { + case (name, (info, builder)) => fr.registerFunction(FunctionIdentifier(name), info, builder) + } fr } - val functionSet: Set[String] = builtin.listFunction().toSet + val functionSet: Set[FunctionIdentifier] = builtin.listFunction().toSet /** See usage above. */ private def expression[T <: Expression](name: String) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index 57006bfaf9b69..b6744a7f53a54 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -1029,13 +1029,12 @@ class SessionCatalog( requireDbExists(db) val identifier = name.copy(database = Some(db)) if (functionExists(identifier)) { - // TODO: registry should just take in FunctionIdentifier for type safety - if (functionRegistry.functionExists(identifier.unquotedString)) { + if (functionRegistry.functionExists(identifier)) { // If we have loaded this function into the FunctionRegistry, // also drop it from there. // For a permanent function, because we loaded it to the FunctionRegistry // when it's first used, we also need to drop it from the FunctionRegistry. - functionRegistry.dropFunction(identifier.unquotedString) + functionRegistry.dropFunction(identifier) } externalCatalog.dropFunction(db, name.funcName) } else if (!ignoreIfNotExists) { @@ -1061,7 +1060,7 @@ class SessionCatalog( def functionExists(name: FunctionIdentifier): Boolean = { val db = formatDatabaseName(name.database.getOrElse(getCurrentDatabase)) requireDbExists(db) - functionRegistry.functionExists(name.unquotedString) || + functionRegistry.functionExists(name) || externalCatalog.functionExists(db, name.funcName) } @@ -1095,20 +1094,20 @@ class SessionCatalog( ignoreIfExists: Boolean, functionBuilder: Option[FunctionBuilder] = None): Unit = { val func = funcDefinition.identifier - if (functionRegistry.functionExists(func.unquotedString) && !ignoreIfExists) { + if (functionRegistry.functionExists(func) && !ignoreIfExists) { throw new AnalysisException(s"Function $func already exists") } val info = new ExpressionInfo(funcDefinition.className, func.database.orNull, func.funcName) val builder = functionBuilder.getOrElse(makeFunctionBuilder(func.unquotedString, funcDefinition.className)) - functionRegistry.registerFunction(func.unquotedString, info, builder) + functionRegistry.registerFunction(func, info, builder) } /** * Drop a temporary function. */ def dropTempFunction(name: String, ignoreIfNotExists: Boolean): Unit = { - if (!functionRegistry.dropFunction(name) && !ignoreIfNotExists) { + if (!functionRegistry.dropFunction(FunctionIdentifier(name)) && !ignoreIfNotExists) { throw new NoSuchTempFunctionException(name) } } @@ -1123,8 +1122,8 @@ class SessionCatalog( // A temporary function is a function that has been registered in functionRegistry // without a database name, and is neither a built-in function nor a Hive function name.database.isEmpty && - functionRegistry.functionExists(name.funcName) && - !FunctionRegistry.builtin.functionExists(name.funcName) && + functionRegistry.functionExists(name) && + !FunctionRegistry.builtin.functionExists(name) && !hiveFunctions.contains(name.funcName.toLowerCase(Locale.ROOT)) } @@ -1140,8 +1139,8 @@ class SessionCatalog( // TODO: just make function registry take in FunctionIdentifier instead of duplicating this val database = name.database.orElse(Some(currentDb)).map(formatDatabaseName) val qualifiedName = name.copy(database = database) - functionRegistry.lookupFunction(name.funcName) - .orElse(functionRegistry.lookupFunction(qualifiedName.unquotedString)) + functionRegistry.lookupFunction(name) + .orElse(functionRegistry.lookupFunction(qualifiedName)) .getOrElse { val db = qualifiedName.database.get requireDbExists(db) @@ -1176,19 +1175,19 @@ class SessionCatalog( // Note: the implementation of this function is a little bit convoluted. // We probably shouldn't use a single FunctionRegistry to register all three kinds of functions // (built-in, temp, and external). - if (name.database.isEmpty && functionRegistry.functionExists(name.funcName)) { + if (name.database.isEmpty && functionRegistry.functionExists(name)) { // This function has been already loaded into the function registry. - return functionRegistry.lookupFunction(name.funcName, children) + return functionRegistry.lookupFunction(name, children) } // If the name itself is not qualified, add the current database to it. val database = formatDatabaseName(name.database.getOrElse(getCurrentDatabase)) val qualifiedName = name.copy(database = Some(database)) - if (functionRegistry.functionExists(qualifiedName.unquotedString)) { + if (functionRegistry.functionExists(qualifiedName)) { // This function has been already loaded into the function registry. // Unlike the above block, we find this function by using the qualified name. - return functionRegistry.lookupFunction(qualifiedName.unquotedString, children) + return functionRegistry.lookupFunction(qualifiedName, children) } // The function has not been loaded to the function registry, which means @@ -1209,7 +1208,7 @@ class SessionCatalog( // At here, we preserve the input from the user. registerFunction(catalogFunction.copy(identifier = qualifiedName), ignoreIfExists = false) // Now, we need to create the Expression. - functionRegistry.lookupFunction(qualifiedName.unquotedString, children) + functionRegistry.lookupFunction(qualifiedName, children) } /** @@ -1229,8 +1228,8 @@ class SessionCatalog( requireDbExists(dbName) val dbFunctions = externalCatalog.listFunctions(dbName, pattern).map { f => FunctionIdentifier(f, Some(dbName)) } - val loadedFunctions = - StringUtils.filterPattern(functionRegistry.listFunction(), pattern).map { f => + val loadedFunctions = StringUtils + .filterPattern(functionRegistry.listFunction().map(_.unquotedString), pattern).map { f => // In functionRegistry, function names are stored as an unquoted format. Try(parser.parseFunctionIdentifier(f)) match { case Success(e) => e @@ -1243,7 +1242,7 @@ class SessionCatalog( // The session catalog caches some persistent functions in the FunctionRegistry // so there can be duplicates. functions.map { - case f if FunctionRegistry.functionSet.contains(f.funcName) => (f, "SYSTEM") + case f if FunctionRegistry.functionSet.contains(f) => (f, "SYSTEM") case f => (f, "USER") }.distinct } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala index be8903000a0d1..5afeb0e8ca032 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala @@ -1209,7 +1209,7 @@ abstract class SessionCatalogSuite extends PlanTest { assert(!catalog.isTemporaryFunction(FunctionIdentifier("func1"))) // Returns false when the function is built-in or hive - assert(FunctionRegistry.builtin.functionExists("sum")) + assert(FunctionRegistry.builtin.functionExists(FunctionIdentifier("sum"))) assert(!catalog.isTemporaryFunction(FunctionIdentifier("sum"))) assert(!catalog.isTemporaryFunction(FunctionIdentifier("histogram_numeric"))) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala index 1bceac41b9de7..ad01b889429c7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala @@ -61,7 +61,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends | dataType: ${udf.dataType} """.stripMargin) - functionRegistry.registerFunction(name, udf.builder) + functionRegistry.createOrReplaceTempFunction(name, udf.builder) } /** @@ -75,7 +75,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register(name: String, udaf: UserDefinedAggregateFunction): UserDefinedAggregateFunction = { def builder(children: Seq[Expression]) = ScalaUDAF(children, udaf) - functionRegistry.registerFunction(name, builder) + functionRegistry.createOrReplaceTempFunction(name, builder) udaf } @@ -91,7 +91,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register(name: String, udf: UserDefinedFunction): UserDefinedFunction = { def builder(children: Seq[Expression]) = udf.apply(children.map(Column.apply) : _*).expr - functionRegistry.registerFunction(name, builder) + functionRegistry.createOrReplaceTempFunction(name, builder) udf } @@ -113,7 +113,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputTypes = Try($inputTypes).toOption def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) - functionRegistry.registerFunction(name, builder) + functionRegistry.createOrReplaceTempFunction(name, builder) UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) }""") } @@ -130,7 +130,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends | */ |def register(name: String, f: UDF$i[$extTypeArgs, _], returnType: DataType): Unit = { | val func = f$anyCast.call($anyParams) - | functionRegistry.registerFunction( + | functionRegistry.createOrReplaceTempFunction( | name, | (e: Seq[Expression]) => ScalaUDF(func, returnType, e)) |}""".stripMargin) @@ -146,7 +146,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputTypes = Try(Nil).toOption def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) - functionRegistry.registerFunction(name, builder) + functionRegistry.createOrReplaceTempFunction(name, builder) UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } @@ -159,7 +159,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: Nil).toOption def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) - functionRegistry.registerFunction(name, builder) + functionRegistry.createOrReplaceTempFunction(name, builder) UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } @@ -172,7 +172,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: Nil).toOption def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) - functionRegistry.registerFunction(name, builder) + functionRegistry.createOrReplaceTempFunction(name, builder) UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } @@ -185,7 +185,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: Nil).toOption def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) - functionRegistry.registerFunction(name, builder) + functionRegistry.createOrReplaceTempFunction(name, builder) UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } @@ -198,7 +198,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: Nil).toOption def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) - functionRegistry.registerFunction(name, builder) + functionRegistry.createOrReplaceTempFunction(name, builder) UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } @@ -211,7 +211,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: Nil).toOption def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) - functionRegistry.registerFunction(name, builder) + functionRegistry.createOrReplaceTempFunction(name, builder) UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } @@ -224,7 +224,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: Nil).toOption def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) - functionRegistry.registerFunction(name, builder) + functionRegistry.createOrReplaceTempFunction(name, builder) UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } @@ -237,7 +237,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: Nil).toOption def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) - functionRegistry.registerFunction(name, builder) + functionRegistry.createOrReplaceTempFunction(name, builder) UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } @@ -250,7 +250,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: Nil).toOption def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) - functionRegistry.registerFunction(name, builder) + functionRegistry.createOrReplaceTempFunction(name, builder) UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } @@ -263,7 +263,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: Nil).toOption def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) - functionRegistry.registerFunction(name, builder) + functionRegistry.createOrReplaceTempFunction(name, builder) UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } @@ -276,7 +276,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: Nil).toOption def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) - functionRegistry.registerFunction(name, builder) + functionRegistry.createOrReplaceTempFunction(name, builder) UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } @@ -289,7 +289,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: Nil).toOption def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) - functionRegistry.registerFunction(name, builder) + functionRegistry.createOrReplaceTempFunction(name, builder) UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } @@ -302,7 +302,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: Nil).toOption def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) - functionRegistry.registerFunction(name, builder) + functionRegistry.createOrReplaceTempFunction(name, builder) UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } @@ -315,7 +315,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: Nil).toOption def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) - functionRegistry.registerFunction(name, builder) + functionRegistry.createOrReplaceTempFunction(name, builder) UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } @@ -328,7 +328,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: Nil).toOption def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) - functionRegistry.registerFunction(name, builder) + functionRegistry.createOrReplaceTempFunction(name, builder) UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } @@ -341,7 +341,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: Nil).toOption def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) - functionRegistry.registerFunction(name, builder) + functionRegistry.createOrReplaceTempFunction(name, builder) UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } @@ -354,7 +354,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: Nil).toOption def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) - functionRegistry.registerFunction(name, builder) + functionRegistry.createOrReplaceTempFunction(name, builder) UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } @@ -367,7 +367,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: Nil).toOption def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) - functionRegistry.registerFunction(name, builder) + functionRegistry.createOrReplaceTempFunction(name, builder) UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } @@ -380,7 +380,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: ScalaReflection.schemaFor[A18].dataType :: Nil).toOption def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) - functionRegistry.registerFunction(name, builder) + functionRegistry.createOrReplaceTempFunction(name, builder) UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } @@ -393,7 +393,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: ScalaReflection.schemaFor[A18].dataType :: ScalaReflection.schemaFor[A19].dataType :: Nil).toOption def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) - functionRegistry.registerFunction(name, builder) + functionRegistry.createOrReplaceTempFunction(name, builder) UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } @@ -406,7 +406,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: ScalaReflection.schemaFor[A18].dataType :: ScalaReflection.schemaFor[A19].dataType :: ScalaReflection.schemaFor[A20].dataType :: Nil).toOption def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) - functionRegistry.registerFunction(name, builder) + functionRegistry.createOrReplaceTempFunction(name, builder) UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } @@ -419,7 +419,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: ScalaReflection.schemaFor[A18].dataType :: ScalaReflection.schemaFor[A19].dataType :: ScalaReflection.schemaFor[A20].dataType :: ScalaReflection.schemaFor[A21].dataType :: Nil).toOption def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) - functionRegistry.registerFunction(name, builder) + functionRegistry.createOrReplaceTempFunction(name, builder) UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } @@ -432,7 +432,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: ScalaReflection.schemaFor[A18].dataType :: ScalaReflection.schemaFor[A19].dataType :: ScalaReflection.schemaFor[A20].dataType :: ScalaReflection.schemaFor[A21].dataType :: ScalaReflection.schemaFor[A22].dataType :: Nil).toOption def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) - functionRegistry.registerFunction(name, builder) + functionRegistry.createOrReplaceTempFunction(name, builder) UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } @@ -510,7 +510,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register(name: String, f: UDF1[_, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF1[Any, Any]].call(_: Any) - functionRegistry.registerFunction( + functionRegistry.createOrReplaceTempFunction( name, (e: Seq[Expression]) => ScalaUDF(func, returnType, e)) } @@ -521,7 +521,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register(name: String, f: UDF2[_, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF2[Any, Any, Any]].call(_: Any, _: Any) - functionRegistry.registerFunction( + functionRegistry.createOrReplaceTempFunction( name, (e: Seq[Expression]) => ScalaUDF(func, returnType, e)) } @@ -532,7 +532,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register(name: String, f: UDF3[_, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF3[Any, Any, Any, Any]].call(_: Any, _: Any, _: Any) - functionRegistry.registerFunction( + functionRegistry.createOrReplaceTempFunction( name, (e: Seq[Expression]) => ScalaUDF(func, returnType, e)) } @@ -543,7 +543,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register(name: String, f: UDF4[_, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF4[Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any) - functionRegistry.registerFunction( + functionRegistry.createOrReplaceTempFunction( name, (e: Seq[Expression]) => ScalaUDF(func, returnType, e)) } @@ -554,7 +554,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register(name: String, f: UDF5[_, _, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF5[Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any) - functionRegistry.registerFunction( + functionRegistry.createOrReplaceTempFunction( name, (e: Seq[Expression]) => ScalaUDF(func, returnType, e)) } @@ -565,7 +565,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register(name: String, f: UDF6[_, _, _, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF6[Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - functionRegistry.registerFunction( + functionRegistry.createOrReplaceTempFunction( name, (e: Seq[Expression]) => ScalaUDF(func, returnType, e)) } @@ -576,7 +576,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register(name: String, f: UDF7[_, _, _, _, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF7[Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - functionRegistry.registerFunction( + functionRegistry.createOrReplaceTempFunction( name, (e: Seq[Expression]) => ScalaUDF(func, returnType, e)) } @@ -587,7 +587,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register(name: String, f: UDF8[_, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF8[Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - functionRegistry.registerFunction( + functionRegistry.createOrReplaceTempFunction( name, (e: Seq[Expression]) => ScalaUDF(func, returnType, e)) } @@ -598,7 +598,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register(name: String, f: UDF9[_, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF9[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - functionRegistry.registerFunction( + functionRegistry.createOrReplaceTempFunction( name, (e: Seq[Expression]) => ScalaUDF(func, returnType, e)) } @@ -609,7 +609,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register(name: String, f: UDF10[_, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF10[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - functionRegistry.registerFunction( + functionRegistry.createOrReplaceTempFunction( name, (e: Seq[Expression]) => ScalaUDF(func, returnType, e)) } @@ -620,7 +620,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register(name: String, f: UDF11[_, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF11[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - functionRegistry.registerFunction( + functionRegistry.createOrReplaceTempFunction( name, (e: Seq[Expression]) => ScalaUDF(func, returnType, e)) } @@ -631,7 +631,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register(name: String, f: UDF12[_, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF12[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - functionRegistry.registerFunction( + functionRegistry.createOrReplaceTempFunction( name, (e: Seq[Expression]) => ScalaUDF(func, returnType, e)) } @@ -642,7 +642,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register(name: String, f: UDF13[_, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF13[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - functionRegistry.registerFunction( + functionRegistry.createOrReplaceTempFunction( name, (e: Seq[Expression]) => ScalaUDF(func, returnType, e)) } @@ -653,7 +653,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register(name: String, f: UDF14[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF14[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - functionRegistry.registerFunction( + functionRegistry.createOrReplaceTempFunction( name, (e: Seq[Expression]) => ScalaUDF(func, returnType, e)) } @@ -664,7 +664,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register(name: String, f: UDF15[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF15[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - functionRegistry.registerFunction( + functionRegistry.createOrReplaceTempFunction( name, (e: Seq[Expression]) => ScalaUDF(func, returnType, e)) } @@ -675,7 +675,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register(name: String, f: UDF16[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF16[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - functionRegistry.registerFunction( + functionRegistry.createOrReplaceTempFunction( name, (e: Seq[Expression]) => ScalaUDF(func, returnType, e)) } @@ -686,7 +686,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register(name: String, f: UDF17[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF17[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - functionRegistry.registerFunction( + functionRegistry.createOrReplaceTempFunction( name, (e: Seq[Expression]) => ScalaUDF(func, returnType, e)) } @@ -697,7 +697,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register(name: String, f: UDF18[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF18[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - functionRegistry.registerFunction( + functionRegistry.createOrReplaceTempFunction( name, (e: Seq[Expression]) => ScalaUDF(func, returnType, e)) } @@ -708,7 +708,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register(name: String, f: UDF19[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF19[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - functionRegistry.registerFunction( + functionRegistry.createOrReplaceTempFunction( name, (e: Seq[Expression]) => ScalaUDF(func, returnType, e)) } @@ -719,7 +719,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register(name: String, f: UDF20[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF20[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - functionRegistry.registerFunction( + functionRegistry.createOrReplaceTempFunction( name, (e: Seq[Expression]) => ScalaUDF(func, returnType, e)) } @@ -730,7 +730,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register(name: String, f: UDF21[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF21[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - functionRegistry.registerFunction( + functionRegistry.createOrReplaceTempFunction( name, (e: Seq[Expression]) => ScalaUDF(func, returnType, e)) } @@ -741,7 +741,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register(name: String, f: UDF22[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF22[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - functionRegistry.registerFunction( + functionRegistry.createOrReplaceTempFunction( name, (e: Seq[Expression]) => ScalaUDF(func, returnType, e)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/functions.scala index 545082324f0d3..f39a3269efaf1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/functions.scala @@ -160,7 +160,7 @@ case class DropFunctionCommand( throw new AnalysisException(s"Specifying a database in DROP TEMPORARY FUNCTION " + s"is not allowed: '${databaseName.get}'") } - if (FunctionRegistry.builtin.functionExists(functionName)) { + if (FunctionRegistry.builtin.functionExists(FunctionIdentifier(functionName))) { throw new AnalysisException(s"Cannot drop native function '$functionName'") } catalog.dropTempFunction(functionName, ifExists) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala index b9871afd59e4f..539c63d3cb288 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala @@ -297,7 +297,8 @@ class GeneratorFunctionSuite extends QueryTest with SharedSQLContext { } test("outer generator()") { - spark.sessionState.functionRegistry.registerFunction("empty_gen", _ => EmptyGenerator()) + spark.sessionState.functionRegistry + .createOrReplaceTempFunction("empty_gen", _ => EmptyGenerator()) checkAnswer( sql("select * from values 1, 2 lateral view outer empty_gen() a as b"), Row(1, null) :: Row(2, null) :: Nil) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 41e9e2c92ca8e..a7efcafa0166a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -109,7 +109,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { test("SPARK-14415: All functions should have own descriptions") { for (f <- spark.sessionState.functionRegistry.listFunction()) { - if (!Seq("cube", "grouping", "grouping_id", "rollup", "window").contains(f)) { + if (!Seq("cube", "grouping", "grouping_id", "rollup", "window").contains(f.unquotedString)) { checkKeywordsNotExist(sql(s"describe function `$f`"), "N/A.") } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SessionStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SessionStateSuite.scala index 5638c8eeda842..c01666770720c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SessionStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SessionStateSuite.scala @@ -22,6 +22,7 @@ import org.scalatest.BeforeAndAfterEach import scala.collection.mutable.ArrayBuffer import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.FunctionIdentifier import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.QueryExecution @@ -71,10 +72,10 @@ class SessionStateSuite extends SparkFunSuite } test("fork new session and inherit function registry and udf") { - val testFuncName1 = "strlenScala" - val testFuncName2 = "addone" + val testFuncName1 = FunctionIdentifier("strlenScala") + val testFuncName2 = FunctionIdentifier("addone") try { - activeSession.udf.register(testFuncName1, (_: String).length + (_: Int)) + activeSession.udf.register(testFuncName1.funcName, (_: String).length + (_: Int)) val forkedSession = activeSession.cloneSession() // inheritance @@ -86,7 +87,7 @@ class SessionStateSuite extends SparkFunSuite // independence forkedSession.sessionState.functionRegistry.dropFunction(testFuncName1) assert(activeSession.sessionState.functionRegistry.lookupFunction(testFuncName1).nonEmpty) - activeSession.udf.register(testFuncName2, (_: Int) + 1) + activeSession.udf.register(testFuncName2.funcName, (_: Int) + 1) assert(forkedSession.sessionState.functionRegistry.lookupFunction(testFuncName2).isEmpty) } finally { activeSession.sessionState.functionRegistry.dropFunction(testFuncName1) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala index 2a3d1cf0b298a..80ef4eb75ca53 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala @@ -21,7 +21,8 @@ import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer import org.apache.spark.api.python.PythonFunction -import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, EqualTo, Expression, GreaterThan, In} +import org.apache.spark.sql.catalyst.FunctionIdentifier +import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, GreaterThan, In} import org.apache.spark.sql.execution.{FilterExec, InputAdapter, SparkPlanTest, WholeStageCodegenExec} import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.BooleanType @@ -36,7 +37,7 @@ class BatchEvalPythonExecSuite extends SparkPlanTest with SharedSQLContext { } override def afterAll(): Unit = { - spark.sessionState.functionRegistry.dropFunction("dummyPythonUDF") + spark.sessionState.functionRegistry.dropFunction(FunctionIdentifier("dummyPythonUDF")) super.afterAll() } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala index bc641fd280a15..b2d568ce320e6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala @@ -367,6 +367,7 @@ class CatalogSuite withUserDefinedFunction("fn1" -> true, s"$db.fn2" -> false) { // Try to find non existing functions. intercept[AnalysisException](spark.catalog.getFunction("fn1")) + intercept[AnalysisException](spark.catalog.getFunction(db, "fn1")) intercept[AnalysisException](spark.catalog.getFunction("fn2")) intercept[AnalysisException](spark.catalog.getFunction(db, "fn2")) @@ -379,6 +380,8 @@ class CatalogSuite assert(fn1.name === "fn1") assert(fn1.database === null) assert(fn1.isTemporary) + // Find a temporary function with database + intercept[AnalysisException](spark.catalog.getFunction(db, "fn1")) // Find a qualified function val fn2 = spark.catalog.getFunction(db, "fn2") @@ -455,6 +458,7 @@ class CatalogSuite // Find a temporary function assert(spark.catalog.functionExists("fn1")) + assert(!spark.catalog.functionExists(db, "fn1")) // Find a qualified function assert(spark.catalog.functionExists(db, "fn2")) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala index 6227e780c0409..da87f0218e3ad 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala @@ -129,7 +129,7 @@ private[sql] class HiveSessionCatalog( Try(super.lookupFunction(funcName, children)) match { case Success(expr) => expr case Failure(error) => - if (functionRegistry.functionExists(funcName.unquotedString)) { + if (functionRegistry.functionExists(funcName)) { // If the function actually exists in functionRegistry, it means that there is an // error when we create the Expression using the given children. // We need to throw the original exception. @@ -163,7 +163,7 @@ private[sql] class HiveSessionCatalog( // Put this Hive built-in function to our function registry. registerFunction(func, ignoreIfExists = false) // Now, we need to create the Expression. - functionRegistry.lookupFunction(functionName, children) + functionRegistry.lookupFunction(functionIdentifier, children) } } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala index 8fcbad58350f4..cae338c0ab0ae 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala @@ -194,7 +194,7 @@ class HiveUDFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils { checkAnswer(sql("SELECT percentile_approx(100.0D, array(0.9D, 0.9D)) FROM src LIMIT 1"), sql("SELECT array(100, 100) FROM src LIMIT 1").collect().toSeq) - } + } test("UDFIntegerToString") { val testData = spark.sparkContext.parallelize( @@ -592,6 +592,17 @@ class HiveUDFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils { } } + test("Temp function has dots in the names") { + withUserDefinedFunction("test_avg" -> false, "`default.test_avg`" -> true) { + sql(s"CREATE FUNCTION test_avg AS '${classOf[GenericUDAFAverage].getName}'") + checkAnswer(sql("SELECT test_avg(1)"), Row(1.0)) + // temp function containing dots in the name + spark.udf.register("default.test_avg", () => { Math.random() + 2}) + assert(sql("SELECT `default.test_avg`()").head().getDouble(0) >= 2.0) + checkAnswer(sql("SELECT test_avg(1)"), Row(1.0)) + } + } + test("Call the function registered in the not-current database") { Seq("true", "false").foreach { caseSensitive => withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive) { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index da7a0645dbbeb..a949e5e829e14 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -222,7 +222,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { } test("show functions") { - val allBuiltinFunctions = FunctionRegistry.builtin.listFunction().toSet[String].toList.sorted + val allBuiltinFunctions = FunctionRegistry.builtin.listFunction().map(_.unquotedString) val allFunctions = sql("SHOW functions").collect().map(r => r(0)) allBuiltinFunctions.foreach { f => assert(allFunctions.contains(f)) From b78e3849b20d0d09b7146efd7ce8f203ef67b890 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Fri, 9 Jun 2017 18:29:33 -0700 Subject: [PATCH 104/133] [SPARK-21042][SQL] Document Dataset.union is resolution by position ## What changes were proposed in this pull request? Document Dataset.union is resolution by position, not by name, since this has been a confusing point for a lot of users. ## How was this patch tested? N/A - doc only change. Author: Reynold Xin Closes #18256 from rxin/SPARK-21042. --- R/pkg/R/DataFrame.R | 1 + python/pyspark/sql/dataframe.py | 13 +++++++++---- .../main/scala/org/apache/spark/sql/Dataset.scala | 14 ++++++++------ 3 files changed, 18 insertions(+), 10 deletions(-) diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 166b39813c14e..3b9d42d6e7158 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -2646,6 +2646,7 @@ generateAliasesForIntersectedCols <- function (x, intersectedColNames, suffix) { #' Input SparkDataFrames can have different schemas (names and data types). #' #' Note: This does not remove duplicate rows across the two SparkDataFrames. +#' Also as standard in SQL, this function resolves columns by position (not by name). #' #' @param x A SparkDataFrame #' @param y A SparkDataFrame diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 99abfcc556dff..8541403dfe2f1 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -1175,18 +1175,23 @@ def agg(self, *exprs): @since(2.0) def union(self, other): - """ Return a new :class:`DataFrame` containing union of rows in this - frame and another frame. + """ Return a new :class:`DataFrame` containing union of rows in this and another frame. This is equivalent to `UNION ALL` in SQL. To do a SQL-style set union (that does deduplication of elements), use this function followed by a distinct. + + Also as standard in SQL, this function resolves columns by position (not by name). """ return DataFrame(self._jdf.union(other._jdf), self.sql_ctx) @since(1.3) def unionAll(self, other): - """ Return a new :class:`DataFrame` containing union of rows in this - frame and another frame. + """ Return a new :class:`DataFrame` containing union of rows in this and another frame. + + This is equivalent to `UNION ALL` in SQL. To do a SQL-style set union + (that does deduplication of elements), use this function followed by a distinct. + + Also as standard in SQL, this function resolves columns by position (not by name). .. note:: Deprecated in 2.0, use union instead. """ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index f7637e005f317..d28ff7888d127 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -1734,10 +1734,11 @@ class Dataset[T] private[sql]( /** * Returns a new Dataset containing union of rows in this Dataset and another Dataset. - * This is equivalent to `UNION ALL` in SQL. * - * To do a SQL-style set union (that does deduplication of elements), use this function followed - * by a [[distinct]]. + * This is equivalent to `UNION ALL` in SQL. To do a SQL-style set union (that does + * deduplication of elements), use this function followed by a [[distinct]]. + * + * Also as standard in SQL, this function resolves columns by position (not by name). * * @group typedrel * @since 2.0.0 @@ -1747,10 +1748,11 @@ class Dataset[T] private[sql]( /** * Returns a new Dataset containing union of rows in this Dataset and another Dataset. - * This is equivalent to `UNION ALL` in SQL. * - * To do a SQL-style set union (that does deduplication of elements), use this function followed - * by a [[distinct]]. + * This is equivalent to `UNION ALL` in SQL. To do a SQL-style set union (that does + * deduplication of elements), use this function followed by a [[distinct]]. + * + * Also as standard in SQL, this function resolves columns by position (not by name). * * @group typedrel * @since 2.0.0 From 8e96acf71c7bf1686aeca842f626f66c1cc8117f Mon Sep 17 00:00:00 2001 From: Xiao Li Date: Sat, 10 Jun 2017 10:28:14 -0700 Subject: [PATCH 105/133] [SPARK-20211][SQL] Fix the Precision and Scale of Decimal Values when the Input is BigDecimal between -1.0 and 1.0 ### What changes were proposed in this pull request? The precision and scale of decimal values are wrong when the input is BigDecimal between -1.0 and 1.0. The BigDecimal's precision is the digit count starts from the leftmost nonzero digit based on the [JAVA's BigDecimal definition](https://docs.oracle.com/javase/7/docs/api/java/math/BigDecimal.html). However, our Decimal decision follows the database decimal standard, which is the total number of digits, including both to the left and the right of the decimal point. Thus, this PR is to fix the issue by doing the conversion. Before this PR, the following queries failed: ```SQL select 1 > 0.0001 select floor(0.0001) select ceil(0.0001) ``` ### How was this patch tested? Added test cases. Author: Xiao Li Closes #18244 from gatorsmile/bigdecimal. --- .../org/apache/spark/sql/types/Decimal.scala | 10 +++- .../apache/spark/sql/types/DecimalSuite.scala | 10 ++++ .../resources/sql-tests/inputs/operators.sql | 7 +++ .../sql-tests/results/operators.sql.out | 58 ++++++++++++++++--- 4 files changed, 75 insertions(+), 10 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala index 80916ee9c5379..1f1fb51addfd8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala @@ -126,7 +126,15 @@ final class Decimal extends Ordered[Decimal] with Serializable { def set(decimal: BigDecimal): Decimal = { this.decimalVal = decimal this.longVal = 0L - this._precision = decimal.precision + if (decimal.precision <= decimal.scale) { + // For Decimal, we expect the precision is equal to or large than the scale, however, + // in BigDecimal, the digit count starts from the leftmost nonzero digit of the exact + // result. For example, the precision of 0.01 equals to 1 based on the definition, but + // the scale is 2. The expected precision should be 3. + this._precision = decimal.scale + 1 + } else { + this._precision = decimal.precision + } this._scale = decimal.scale this } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala index 93c231e30b49b..144f3d688d402 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala @@ -32,6 +32,16 @@ class DecimalSuite extends SparkFunSuite with PrivateMethodTester { test("creating decimals") { checkDecimal(new Decimal(), "0", 1, 0) + checkDecimal(Decimal(BigDecimal("0.09")), "0.09", 3, 2) + checkDecimal(Decimal(BigDecimal("0.9")), "0.9", 2, 1) + checkDecimal(Decimal(BigDecimal("0.90")), "0.90", 3, 2) + checkDecimal(Decimal(BigDecimal("0.0")), "0.0", 2, 1) + checkDecimal(Decimal(BigDecimal("0")), "0", 1, 0) + checkDecimal(Decimal(BigDecimal("1.0")), "1.0", 2, 1) + checkDecimal(Decimal(BigDecimal("-0.09")), "-0.09", 3, 2) + checkDecimal(Decimal(BigDecimal("-0.9")), "-0.9", 2, 1) + checkDecimal(Decimal(BigDecimal("-0.90")), "-0.90", 3, 2) + checkDecimal(Decimal(BigDecimal("-1.0")), "-1.0", 2, 1) checkDecimal(Decimal(BigDecimal("10.030")), "10.030", 5, 3) checkDecimal(Decimal(BigDecimal("10.030"), 4, 1), "10.0", 4, 1) checkDecimal(Decimal(BigDecimal("-9.95"), 4, 1), "-10.0", 4, 1) diff --git a/sql/core/src/test/resources/sql-tests/inputs/operators.sql b/sql/core/src/test/resources/sql-tests/inputs/operators.sql index 7e3b86b76a34a..75a0256ad7239 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/operators.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/operators.sql @@ -65,8 +65,15 @@ select ceiling(0); select ceiling(1); select ceil(1234567890123456); select ceiling(1234567890123456); +select ceil(0.01); +select ceiling(-0.10); -- floor select floor(0); select floor(1); select floor(1234567890123456); +select floor(0.01); +select floor(-0.10); + +-- comparison operator +select 1 > 0.00001 diff --git a/sql/core/src/test/resources/sql-tests/results/operators.sql.out b/sql/core/src/test/resources/sql-tests/results/operators.sql.out index 28cfb744193ec..57e8a612fab44 100644 --- a/sql/core/src/test/resources/sql-tests/results/operators.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/operators.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 45 +-- Number of queries: 50 -- !query 0 @@ -351,24 +351,64 @@ struct -- !query 42 -select floor(0) +select ceil(0.01) -- !query 42 schema -struct +struct -- !query 42 output -0 +1 -- !query 43 -select floor(1) +select ceiling(-0.10) -- !query 43 schema -struct +struct -- !query 43 output -1 +0 -- !query 44 -select floor(1234567890123456) +select floor(0) -- !query 44 schema -struct +struct -- !query 44 output +0 + + +-- !query 45 +select floor(1) +-- !query 45 schema +struct +-- !query 45 output +1 + + +-- !query 46 +select floor(1234567890123456) +-- !query 46 schema +struct +-- !query 46 output 1234567890123456 + + +-- !query 47 +select floor(0.01) +-- !query 47 schema +struct +-- !query 47 output +0 + + +-- !query 48 +select floor(-0.10) +-- !query 48 schema +struct +-- !query 48 output +-1 + + +-- !query 49 +select 1 > 0.00001 +-- !query 49 schema +struct<(CAST(1 AS BIGINT) > 0):boolean> +-- !query 49 output +true From 5301a19a0e2df2d2b1a5cb2d44c595423df78cf7 Mon Sep 17 00:00:00 2001 From: liuxian Date: Sat, 10 Jun 2017 10:42:23 -0700 Subject: [PATCH 106/133] [SPARK-20620][TEST] Improve some unit tests for NullExpressionsSuite and TypeCoercionSuite ## What changes were proposed in this pull request? add more datatype for some unit tests ## How was this patch tested? unit tests Author: liuxian Closes #17880 from 10110346/wip_lx_0506. --- .../catalyst/analysis/TypeCoercionSuite.scala | 98 ++++++++++++++----- .../expressions/NullExpressionsSuite.scala | 18 +++- 2 files changed, 93 insertions(+), 23 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala index 2624f5586fd5d..2ac11598e63d1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala @@ -484,24 +484,50 @@ class TypeCoercionSuite extends PlanTest { } test("coalesce casts") { - ruleTest(TypeCoercion.FunctionArgumentConversion, - Coalesce(Literal(1.0) - :: Literal(1) - :: Literal.create(1.0, FloatType) - :: Nil), - Coalesce(Cast(Literal(1.0), DoubleType) - :: Cast(Literal(1), DoubleType) - :: Cast(Literal.create(1.0, FloatType), DoubleType) - :: Nil)) - ruleTest(TypeCoercion.FunctionArgumentConversion, - Coalesce(Literal(1L) - :: Literal(1) - :: Literal(new java.math.BigDecimal("1000000000000000000000")) - :: Nil), - Coalesce(Cast(Literal(1L), DecimalType(22, 0)) - :: Cast(Literal(1), DecimalType(22, 0)) - :: Cast(Literal(new java.math.BigDecimal("1000000000000000000000")), DecimalType(22, 0)) - :: Nil)) + val rule = TypeCoercion.FunctionArgumentConversion + + val intLit = Literal(1) + val longLit = Literal.create(1L) + val doubleLit = Literal(1.0) + val stringLit = Literal.create("c", StringType) + val nullLit = Literal.create(null, NullType) + val floatNullLit = Literal.create(null, FloatType) + val floatLit = Literal.create(1.0f, FloatType) + val timestampLit = Literal.create("2017-04-12", TimestampType) + val decimalLit = Literal(new java.math.BigDecimal("1000000000000000000000")) + + ruleTest(rule, + Coalesce(Seq(doubleLit, intLit, floatLit)), + Coalesce(Seq(Cast(doubleLit, DoubleType), + Cast(intLit, DoubleType), Cast(floatLit, DoubleType)))) + + ruleTest(rule, + Coalesce(Seq(longLit, intLit, decimalLit)), + Coalesce(Seq(Cast(longLit, DecimalType(22, 0)), + Cast(intLit, DecimalType(22, 0)), Cast(decimalLit, DecimalType(22, 0))))) + + ruleTest(rule, + Coalesce(Seq(nullLit, intLit)), + Coalesce(Seq(Cast(nullLit, IntegerType), Cast(intLit, IntegerType)))) + + ruleTest(rule, + Coalesce(Seq(timestampLit, stringLit)), + Coalesce(Seq(Cast(timestampLit, StringType), Cast(stringLit, StringType)))) + + ruleTest(rule, + Coalesce(Seq(nullLit, floatNullLit, intLit)), + Coalesce(Seq(Cast(nullLit, FloatType), Cast(floatNullLit, FloatType), + Cast(intLit, FloatType)))) + + ruleTest(rule, + Coalesce(Seq(nullLit, intLit, decimalLit, doubleLit)), + Coalesce(Seq(Cast(nullLit, DoubleType), Cast(intLit, DoubleType), + Cast(decimalLit, DoubleType), Cast(doubleLit, DoubleType)))) + + ruleTest(rule, + Coalesce(Seq(nullLit, floatNullLit, doubleLit, stringLit)), + Coalesce(Seq(Cast(nullLit, StringType), Cast(floatNullLit, StringType), + Cast(doubleLit, StringType), Cast(stringLit, StringType)))) } test("CreateArray casts") { @@ -675,6 +701,14 @@ class TypeCoercionSuite extends PlanTest { test("type coercion for If") { val rule = TypeCoercion.IfCoercion + val intLit = Literal(1) + val doubleLit = Literal(1.0) + val trueLit = Literal.create(true, BooleanType) + val falseLit = Literal.create(false, BooleanType) + val stringLit = Literal.create("c", StringType) + val floatLit = Literal.create(1.0f, FloatType) + val timestampLit = Literal.create("2017-04-12", TimestampType) + val decimalLit = Literal(new java.math.BigDecimal("1000000000000000000000")) ruleTest(rule, If(Literal(true), Literal(1), Literal(1L)), @@ -685,12 +719,32 @@ class TypeCoercionSuite extends PlanTest { If(Literal.create(null, BooleanType), Literal(1), Literal(1))) ruleTest(rule, - If(AssertTrue(Literal.create(true, BooleanType)), Literal(1), Literal(2)), - If(Cast(AssertTrue(Literal.create(true, BooleanType)), BooleanType), Literal(1), Literal(2))) + If(AssertTrue(trueLit), Literal(1), Literal(2)), + If(Cast(AssertTrue(trueLit), BooleanType), Literal(1), Literal(2))) + + ruleTest(rule, + If(AssertTrue(falseLit), Literal(1), Literal(2)), + If(Cast(AssertTrue(falseLit), BooleanType), Literal(1), Literal(2))) + + ruleTest(rule, + If(trueLit, intLit, doubleLit), + If(trueLit, Cast(intLit, DoubleType), doubleLit)) + + ruleTest(rule, + If(trueLit, floatLit, doubleLit), + If(trueLit, Cast(floatLit, DoubleType), doubleLit)) + + ruleTest(rule, + If(trueLit, floatLit, decimalLit), + If(trueLit, Cast(floatLit, DoubleType), Cast(decimalLit, DoubleType))) + + ruleTest(rule, + If(falseLit, stringLit, doubleLit), + If(falseLit, stringLit, Cast(doubleLit, StringType))) ruleTest(rule, - If(AssertTrue(Literal.create(false, BooleanType)), Literal(1), Literal(2)), - If(Cast(AssertTrue(Literal.create(false, BooleanType)), BooleanType), Literal(1), Literal(2))) + If(trueLit, timestampLit, stringLit), + If(trueLit, Cast(timestampLit, StringType), stringLit)) } test("type coercion for CaseKeyWhen") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala index 5064a1f63f83d..394c0a091e390 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala @@ -97,14 +97,30 @@ class NullExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { val doubleLit = Literal.create(2.2, DoubleType) val stringLit = Literal.create("c", StringType) val nullLit = Literal.create(null, NullType) - + val floatNullLit = Literal.create(null, FloatType) + val floatLit = Literal.create(1.01f, FloatType) + val timestampLit = Literal.create("2017-04-12", TimestampType) + val decimalLit = Literal.create(10.2, DecimalType(20, 2)) + + assert(analyze(new Nvl(decimalLit, stringLit)).dataType == StringType) + assert(analyze(new Nvl(doubleLit, decimalLit)).dataType == DoubleType) + assert(analyze(new Nvl(decimalLit, doubleLit)).dataType == DoubleType) + assert(analyze(new Nvl(decimalLit, floatLit)).dataType == DoubleType) + assert(analyze(new Nvl(floatLit, decimalLit)).dataType == DoubleType) + + assert(analyze(new Nvl(timestampLit, stringLit)).dataType == StringType) assert(analyze(new Nvl(intLit, doubleLit)).dataType == DoubleType) assert(analyze(new Nvl(intLit, stringLit)).dataType == StringType) assert(analyze(new Nvl(stringLit, doubleLit)).dataType == StringType) + assert(analyze(new Nvl(doubleLit, stringLit)).dataType == StringType) assert(analyze(new Nvl(nullLit, intLit)).dataType == IntegerType) assert(analyze(new Nvl(doubleLit, nullLit)).dataType == DoubleType) assert(analyze(new Nvl(nullLit, stringLit)).dataType == StringType) + + assert(analyze(new Nvl(floatLit, stringLit)).dataType == StringType) + assert(analyze(new Nvl(floatLit, doubleLit)).dataType == DoubleType) + assert(analyze(new Nvl(floatNullLit, intLit)).dataType == FloatType) } test("AtLeastNNonNulls") { From dc4c351837879dab26ad8fb471dc51c06832a9e4 Mon Sep 17 00:00:00 2001 From: Felix Cheung Date: Sun, 11 Jun 2017 00:00:33 -0700 Subject: [PATCH 107/133] [SPARK-20877][SPARKR] refactor tests to basic tests only for CRAN ## What changes were proposed in this pull request? Move all existing tests to non-installed directory so that it will never run by installing SparkR package For a follow-up PR: - remove all skip_on_cran() calls in tests - clean up test timer - improve or change basic tests that do run on CRAN (if anyone has suggestion) It looks like `R CMD build pkg` will still put pkg\tests (ie. the full tests) into the source package but `R CMD INSTALL` on such source package does not install these tests (and so `R CMD check` does not run them) ## How was this patch tested? - [x] unit tests, Jenkins - [x] AppVeyor - [x] make a source package, install it, `R CMD check` it - verify the full tests are not installed or run Author: Felix Cheung Closes #18264 from felixcheung/rtestset. --- R/pkg/inst/tests/testthat/test_basic.R | 90 +++++++++++++++++++ .../testthat => tests/fulltests}/jarTest.R | 0 .../fulltests}/packageInAJarTest.R | 0 .../testthat => tests/fulltests}/test_Serde.R | 0 .../fulltests}/test_Windows.R | 0 .../fulltests}/test_binaryFile.R | 0 .../fulltests}/test_binary_function.R | 0 .../fulltests}/test_broadcast.R | 0 .../fulltests}/test_client.R | 0 .../fulltests}/test_context.R | 0 .../fulltests}/test_includePackage.R | 0 .../fulltests}/test_jvm_api.R | 0 .../fulltests}/test_mllib_classification.R | 0 .../fulltests}/test_mllib_clustering.R | 0 .../fulltests}/test_mllib_fpm.R | 0 .../fulltests}/test_mllib_recommendation.R | 0 .../fulltests}/test_mllib_regression.R | 0 .../fulltests}/test_mllib_stat.R | 0 .../fulltests}/test_mllib_tree.R | 0 .../fulltests}/test_parallelize_collect.R | 0 .../testthat => tests/fulltests}/test_rdd.R | 0 .../fulltests}/test_shuffle.R | 0 .../fulltests}/test_sparkR.R | 0 .../fulltests}/test_sparkSQL.R | 0 .../fulltests}/test_streaming.R | 0 .../testthat => tests/fulltests}/test_take.R | 0 .../fulltests}/test_textFile.R | 0 .../testthat => tests/fulltests}/test_utils.R | 0 R/pkg/tests/run-all.R | 8 ++ 29 files changed, 98 insertions(+) create mode 100644 R/pkg/inst/tests/testthat/test_basic.R rename R/pkg/{inst/tests/testthat => tests/fulltests}/jarTest.R (100%) rename R/pkg/{inst/tests/testthat => tests/fulltests}/packageInAJarTest.R (100%) rename R/pkg/{inst/tests/testthat => tests/fulltests}/test_Serde.R (100%) rename R/pkg/{inst/tests/testthat => tests/fulltests}/test_Windows.R (100%) rename R/pkg/{inst/tests/testthat => tests/fulltests}/test_binaryFile.R (100%) rename R/pkg/{inst/tests/testthat => tests/fulltests}/test_binary_function.R (100%) rename R/pkg/{inst/tests/testthat => tests/fulltests}/test_broadcast.R (100%) rename R/pkg/{inst/tests/testthat => tests/fulltests}/test_client.R (100%) rename R/pkg/{inst/tests/testthat => tests/fulltests}/test_context.R (100%) rename R/pkg/{inst/tests/testthat => tests/fulltests}/test_includePackage.R (100%) rename R/pkg/{inst/tests/testthat => tests/fulltests}/test_jvm_api.R (100%) rename R/pkg/{inst/tests/testthat => tests/fulltests}/test_mllib_classification.R (100%) rename R/pkg/{inst/tests/testthat => tests/fulltests}/test_mllib_clustering.R (100%) rename R/pkg/{inst/tests/testthat => tests/fulltests}/test_mllib_fpm.R (100%) rename R/pkg/{inst/tests/testthat => tests/fulltests}/test_mllib_recommendation.R (100%) rename R/pkg/{inst/tests/testthat => tests/fulltests}/test_mllib_regression.R (100%) rename R/pkg/{inst/tests/testthat => tests/fulltests}/test_mllib_stat.R (100%) rename R/pkg/{inst/tests/testthat => tests/fulltests}/test_mllib_tree.R (100%) rename R/pkg/{inst/tests/testthat => tests/fulltests}/test_parallelize_collect.R (100%) rename R/pkg/{inst/tests/testthat => tests/fulltests}/test_rdd.R (100%) rename R/pkg/{inst/tests/testthat => tests/fulltests}/test_shuffle.R (100%) rename R/pkg/{inst/tests/testthat => tests/fulltests}/test_sparkR.R (100%) rename R/pkg/{inst/tests/testthat => tests/fulltests}/test_sparkSQL.R (100%) rename R/pkg/{inst/tests/testthat => tests/fulltests}/test_streaming.R (100%) rename R/pkg/{inst/tests/testthat => tests/fulltests}/test_take.R (100%) rename R/pkg/{inst/tests/testthat => tests/fulltests}/test_textFile.R (100%) rename R/pkg/{inst/tests/testthat => tests/fulltests}/test_utils.R (100%) diff --git a/R/pkg/inst/tests/testthat/test_basic.R b/R/pkg/inst/tests/testthat/test_basic.R new file mode 100644 index 0000000000000..de47162d5325f --- /dev/null +++ b/R/pkg/inst/tests/testthat/test_basic.R @@ -0,0 +1,90 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +context("basic tests for CRAN") + +test_that("create DataFrame from list or data.frame", { + sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) + + i <- 4 + df <- createDataFrame(data.frame(dummy = 1:i)) + expect_equal(count(df), i) + + l <- list(list(a = 1, b = 2), list(a = 3, b = 4)) + df <- createDataFrame(l) + expect_equal(columns(df), c("a", "b")) + + a <- 1:3 + b <- c("a", "b", "c") + ldf <- data.frame(a, b) + df <- createDataFrame(ldf) + expect_equal(columns(df), c("a", "b")) + expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) + expect_equal(count(df), 3) + ldf2 <- collect(df) + expect_equal(ldf$a, ldf2$a) + + mtcarsdf <- createDataFrame(mtcars) + expect_equivalent(collect(mtcarsdf), mtcars) + + bytes <- as.raw(c(1, 2, 3)) + df <- createDataFrame(list(list(bytes))) + expect_equal(collect(df)[[1]][[1]], bytes) + + sparkR.session.stop() +}) + +test_that("spark.glm and predict", { + sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) + + training <- suppressWarnings(createDataFrame(iris)) + # gaussian family + model <- spark.glm(training, Sepal_Width ~ Sepal_Length + Species) + prediction <- predict(model, training) + expect_equal(typeof(take(select(prediction, "prediction"), 1)$prediction), "double") + vals <- collect(select(prediction, "prediction")) + rVals <- predict(glm(Sepal.Width ~ Sepal.Length + Species, data = iris), iris) + expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals) + + # Gamma family + x <- runif(100, -1, 1) + y <- rgamma(100, rate = 10 / exp(0.5 + 1.2 * x), shape = 10) + df <- as.DataFrame(as.data.frame(list(x = x, y = y))) + model <- glm(y ~ x, family = Gamma, df) + out <- capture.output(print(summary(model))) + expect_true(any(grepl("Dispersion parameter for gamma family", out))) + + # tweedie family + model <- spark.glm(training, Sepal_Width ~ Sepal_Length + Species, + family = "tweedie", var.power = 1.2, link.power = 0.0) + prediction <- predict(model, training) + expect_equal(typeof(take(select(prediction, "prediction"), 1)$prediction), "double") + vals <- collect(select(prediction, "prediction")) + + # manual calculation of the R predicted values to avoid dependence on statmod + #' library(statmod) + #' rModel <- glm(Sepal.Width ~ Sepal.Length + Species, data = iris, + #' family = tweedie(var.power = 1.2, link.power = 0.0)) + #' print(coef(rModel)) + + rCoef <- c(0.6455409, 0.1169143, -0.3224752, -0.3282174) + rVals <- exp(as.numeric(model.matrix(Sepal.Width ~ Sepal.Length + Species, + data = iris) %*% rCoef)) + expect_true(all(abs(rVals - vals) < 1e-5), rVals - vals) + + sparkR.session.stop() +}) diff --git a/R/pkg/inst/tests/testthat/jarTest.R b/R/pkg/tests/fulltests/jarTest.R similarity index 100% rename from R/pkg/inst/tests/testthat/jarTest.R rename to R/pkg/tests/fulltests/jarTest.R diff --git a/R/pkg/inst/tests/testthat/packageInAJarTest.R b/R/pkg/tests/fulltests/packageInAJarTest.R similarity index 100% rename from R/pkg/inst/tests/testthat/packageInAJarTest.R rename to R/pkg/tests/fulltests/packageInAJarTest.R diff --git a/R/pkg/inst/tests/testthat/test_Serde.R b/R/pkg/tests/fulltests/test_Serde.R similarity index 100% rename from R/pkg/inst/tests/testthat/test_Serde.R rename to R/pkg/tests/fulltests/test_Serde.R diff --git a/R/pkg/inst/tests/testthat/test_Windows.R b/R/pkg/tests/fulltests/test_Windows.R similarity index 100% rename from R/pkg/inst/tests/testthat/test_Windows.R rename to R/pkg/tests/fulltests/test_Windows.R diff --git a/R/pkg/inst/tests/testthat/test_binaryFile.R b/R/pkg/tests/fulltests/test_binaryFile.R similarity index 100% rename from R/pkg/inst/tests/testthat/test_binaryFile.R rename to R/pkg/tests/fulltests/test_binaryFile.R diff --git a/R/pkg/inst/tests/testthat/test_binary_function.R b/R/pkg/tests/fulltests/test_binary_function.R similarity index 100% rename from R/pkg/inst/tests/testthat/test_binary_function.R rename to R/pkg/tests/fulltests/test_binary_function.R diff --git a/R/pkg/inst/tests/testthat/test_broadcast.R b/R/pkg/tests/fulltests/test_broadcast.R similarity index 100% rename from R/pkg/inst/tests/testthat/test_broadcast.R rename to R/pkg/tests/fulltests/test_broadcast.R diff --git a/R/pkg/inst/tests/testthat/test_client.R b/R/pkg/tests/fulltests/test_client.R similarity index 100% rename from R/pkg/inst/tests/testthat/test_client.R rename to R/pkg/tests/fulltests/test_client.R diff --git a/R/pkg/inst/tests/testthat/test_context.R b/R/pkg/tests/fulltests/test_context.R similarity index 100% rename from R/pkg/inst/tests/testthat/test_context.R rename to R/pkg/tests/fulltests/test_context.R diff --git a/R/pkg/inst/tests/testthat/test_includePackage.R b/R/pkg/tests/fulltests/test_includePackage.R similarity index 100% rename from R/pkg/inst/tests/testthat/test_includePackage.R rename to R/pkg/tests/fulltests/test_includePackage.R diff --git a/R/pkg/inst/tests/testthat/test_jvm_api.R b/R/pkg/tests/fulltests/test_jvm_api.R similarity index 100% rename from R/pkg/inst/tests/testthat/test_jvm_api.R rename to R/pkg/tests/fulltests/test_jvm_api.R diff --git a/R/pkg/inst/tests/testthat/test_mllib_classification.R b/R/pkg/tests/fulltests/test_mllib_classification.R similarity index 100% rename from R/pkg/inst/tests/testthat/test_mllib_classification.R rename to R/pkg/tests/fulltests/test_mllib_classification.R diff --git a/R/pkg/inst/tests/testthat/test_mllib_clustering.R b/R/pkg/tests/fulltests/test_mllib_clustering.R similarity index 100% rename from R/pkg/inst/tests/testthat/test_mllib_clustering.R rename to R/pkg/tests/fulltests/test_mllib_clustering.R diff --git a/R/pkg/inst/tests/testthat/test_mllib_fpm.R b/R/pkg/tests/fulltests/test_mllib_fpm.R similarity index 100% rename from R/pkg/inst/tests/testthat/test_mllib_fpm.R rename to R/pkg/tests/fulltests/test_mllib_fpm.R diff --git a/R/pkg/inst/tests/testthat/test_mllib_recommendation.R b/R/pkg/tests/fulltests/test_mllib_recommendation.R similarity index 100% rename from R/pkg/inst/tests/testthat/test_mllib_recommendation.R rename to R/pkg/tests/fulltests/test_mllib_recommendation.R diff --git a/R/pkg/inst/tests/testthat/test_mllib_regression.R b/R/pkg/tests/fulltests/test_mllib_regression.R similarity index 100% rename from R/pkg/inst/tests/testthat/test_mllib_regression.R rename to R/pkg/tests/fulltests/test_mllib_regression.R diff --git a/R/pkg/inst/tests/testthat/test_mllib_stat.R b/R/pkg/tests/fulltests/test_mllib_stat.R similarity index 100% rename from R/pkg/inst/tests/testthat/test_mllib_stat.R rename to R/pkg/tests/fulltests/test_mllib_stat.R diff --git a/R/pkg/inst/tests/testthat/test_mllib_tree.R b/R/pkg/tests/fulltests/test_mllib_tree.R similarity index 100% rename from R/pkg/inst/tests/testthat/test_mllib_tree.R rename to R/pkg/tests/fulltests/test_mllib_tree.R diff --git a/R/pkg/inst/tests/testthat/test_parallelize_collect.R b/R/pkg/tests/fulltests/test_parallelize_collect.R similarity index 100% rename from R/pkg/inst/tests/testthat/test_parallelize_collect.R rename to R/pkg/tests/fulltests/test_parallelize_collect.R diff --git a/R/pkg/inst/tests/testthat/test_rdd.R b/R/pkg/tests/fulltests/test_rdd.R similarity index 100% rename from R/pkg/inst/tests/testthat/test_rdd.R rename to R/pkg/tests/fulltests/test_rdd.R diff --git a/R/pkg/inst/tests/testthat/test_shuffle.R b/R/pkg/tests/fulltests/test_shuffle.R similarity index 100% rename from R/pkg/inst/tests/testthat/test_shuffle.R rename to R/pkg/tests/fulltests/test_shuffle.R diff --git a/R/pkg/inst/tests/testthat/test_sparkR.R b/R/pkg/tests/fulltests/test_sparkR.R similarity index 100% rename from R/pkg/inst/tests/testthat/test_sparkR.R rename to R/pkg/tests/fulltests/test_sparkR.R diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R similarity index 100% rename from R/pkg/inst/tests/testthat/test_sparkSQL.R rename to R/pkg/tests/fulltests/test_sparkSQL.R diff --git a/R/pkg/inst/tests/testthat/test_streaming.R b/R/pkg/tests/fulltests/test_streaming.R similarity index 100% rename from R/pkg/inst/tests/testthat/test_streaming.R rename to R/pkg/tests/fulltests/test_streaming.R diff --git a/R/pkg/inst/tests/testthat/test_take.R b/R/pkg/tests/fulltests/test_take.R similarity index 100% rename from R/pkg/inst/tests/testthat/test_take.R rename to R/pkg/tests/fulltests/test_take.R diff --git a/R/pkg/inst/tests/testthat/test_textFile.R b/R/pkg/tests/fulltests/test_textFile.R similarity index 100% rename from R/pkg/inst/tests/testthat/test_textFile.R rename to R/pkg/tests/fulltests/test_textFile.R diff --git a/R/pkg/inst/tests/testthat/test_utils.R b/R/pkg/tests/fulltests/test_utils.R similarity index 100% rename from R/pkg/inst/tests/testthat/test_utils.R rename to R/pkg/tests/fulltests/test_utils.R diff --git a/R/pkg/tests/run-all.R b/R/pkg/tests/run-all.R index f0bef4f6d2662..d48e36c880c13 100644 --- a/R/pkg/tests/run-all.R +++ b/R/pkg/tests/run-all.R @@ -43,3 +43,11 @@ if (identical(Sys.getenv("NOT_CRAN"), "true")) { } test_package("SparkR") + +if (identical(Sys.getenv("NOT_CRAN"), "true")) { + # for testthat 1.0.2 later, change reporter from "summary" to default_reporter() + testthat:::run_tests("SparkR", + file.path(sparkRDir, "pkg", "tests", "fulltests"), + NULL, + "summary") +} From 8da3f7041aafa71d7596b531625edb899970fec2 Mon Sep 17 00:00:00 2001 From: Michael Gummelt Date: Sun, 11 Jun 2017 09:49:39 +0100 Subject: [PATCH 108/133] [SPARK-21000][MESOS] Add Mesos labels support to the Spark Dispatcher ## What changes were proposed in this pull request? Add Mesos labels support to the Spark Dispatcher ## How was this patch tested? unit tests Author: Michael Gummelt Closes #18220 from mgummelt/SPARK-21000-dispatcher-labels. --- docs/running-on-mesos.md | 14 ++++- .../apache/spark/deploy/mesos/config.scala | 7 +++ .../cluster/mesos/MesosClusterScheduler.scala | 10 ++-- .../MesosCoarseGrainedSchedulerBackend.scala | 28 ++-------- .../cluster/mesos/MesosProtoUtils.scala | 53 +++++++++++++++++++ .../mesos/MesosClusterSchedulerSuite.scala | 27 ++++++++++ ...osCoarseGrainedSchedulerBackendSuite.scala | 23 -------- .../cluster/mesos/MesosProtoUtilsSuite.scala | 48 +++++++++++++++++ 8 files changed, 157 insertions(+), 53 deletions(-) create mode 100644 resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosProtoUtils.scala create mode 100644 resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosProtoUtilsSuite.scala diff --git a/docs/running-on-mesos.md b/docs/running-on-mesos.md index 8745e76d127ae..ec130c1db8f5f 100644 --- a/docs/running-on-mesos.md +++ b/docs/running-on-mesos.md @@ -382,8 +382,9 @@ See the [configuration page](configuration.html) for information on Spark config (none) Set the Mesos labels to add to each task. Labels are free-form key-value pairs. - Key-value pairs should be separated by a colon, and commas used to list more than one. - Ex. key:value,key2:value2. + Key-value pairs should be separated by a colon, and commas used to + list more than one. If your label includes a colon or comma, you + can escape it with a backslash. Ex. key:value,key2:a\:b. @@ -468,6 +469,15 @@ See the [configuration page](configuration.html) for information on Spark config If unset it will point to Spark's internal web UI. + + spark.mesos.driver.labels + (none) + + Mesos labels to add to the driver. See spark.mesos.task.labels + for formatting information. + + + spark.mesos.driverEnv.[EnvironmentVariableName] (none) diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/config.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/config.scala index 19e253394f1b2..56d697f359614 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/config.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/config.scala @@ -56,4 +56,11 @@ package object config { .stringConf .createOptional + private [spark] val DRIVER_LABELS = + ConfigBuilder("spark.mesos.driver.labels") + .doc("Mesos labels to add to the driver. Labels are free-form key-value pairs. Key-value" + + "pairs should be separated by a colon, and commas used to list more than one." + + "Ex. key:value,key2:value2") + .stringConf + .createOptional } diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala index 1bc6f71860c3f..577f9a876b381 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala @@ -30,11 +30,13 @@ import org.apache.mesos.Protos.Environment.Variable import org.apache.mesos.Protos.TaskStatus.Reason import org.apache.spark.{SecurityManager, SparkConf, SparkException, TaskState} +import org.apache.spark.deploy.mesos.config import org.apache.spark.deploy.mesos.MesosDriverDescription import org.apache.spark.deploy.rest.{CreateSubmissionResponse, KillSubmissionResponse, SubmissionStatusResponse} import org.apache.spark.metrics.MetricsSystem import org.apache.spark.util.Utils + /** * Tracks the current state of a Mesos Task that runs a Spark driver. * @param driverDescription Submitted driver description from @@ -525,15 +527,17 @@ private[spark] class MesosClusterScheduler( offer.remainingResources = finalResources.asJava val appName = desc.conf.get("spark.app.name") - val taskInfo = TaskInfo.newBuilder() + + TaskInfo.newBuilder() .setTaskId(taskId) .setName(s"Driver for ${appName}") .setSlaveId(offer.offer.getSlaveId) .setCommand(buildDriverCommand(desc)) .addAllResources(cpuResourcesToUse.asJava) .addAllResources(memResourcesToUse.asJava) - taskInfo.setContainer(MesosSchedulerBackendUtil.containerInfo(desc.conf)) - taskInfo.build + .setLabels(MesosProtoUtils.mesosLabels(desc.conf.get(config.DRIVER_LABELS).getOrElse(""))) + .setContainer(MesosSchedulerBackendUtil.containerInfo(desc.conf)) + .build } /** diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala index ac7aec7b0a034..871685c6cccc0 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala @@ -419,16 +419,9 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( .setSlaveId(offer.getSlaveId) .setCommand(createCommand(offer, taskCPUs + extraCoresPerExecutor, taskId)) .setName(s"${sc.appName} $taskId") - - taskBuilder.addAllResources(resourcesToUse.asJava) - taskBuilder.setContainer(MesosSchedulerBackendUtil.containerInfo(sc.conf)) - - val labelsBuilder = taskBuilder.getLabelsBuilder - val labels = buildMesosLabels().asJava - - labelsBuilder.addAllLabels(labels) - - taskBuilder.setLabels(labelsBuilder) + .setLabels(MesosProtoUtils.mesosLabels(taskLabels)) + .addAllResources(resourcesToUse.asJava) + .setContainer(MesosSchedulerBackendUtil.containerInfo(sc.conf)) tasks(offer.getId) ::= taskBuilder.build() remainingResources(offerId) = resourcesLeft.asJava @@ -444,21 +437,6 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( tasks.toMap } - private def buildMesosLabels(): List[Label] = { - taskLabels.split(",").flatMap(label => - label.split(":") match { - case Array(key, value) => - Some(Label.newBuilder() - .setKey(key) - .setValue(value) - .build()) - case _ => - logWarning(s"Unable to parse $label into a key:value label for the task.") - None - } - ).toList - } - /** Extracts task needed resources from a list of available resources. */ private def partitionTaskResources( resources: JList[Resource], diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosProtoUtils.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosProtoUtils.scala new file mode 100644 index 0000000000000..fea01c7068c9a --- /dev/null +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosProtoUtils.scala @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler.cluster.mesos + +import scala.collection.JavaConverters._ + +import org.apache.mesos.Protos + +import org.apache.spark.SparkException +import org.apache.spark.internal.Logging + +object MesosProtoUtils extends Logging { + + /** Parses a label string of the format specified in spark.mesos.task.labels. */ + def mesosLabels(labelsStr: String): Protos.Labels.Builder = { + val labels: Seq[Protos.Label] = if (labelsStr == "") { + Seq() + } else { + labelsStr.split("""(? + val parts = labelStr.split("""(? part.replaceAll("""\\,""", ",")) + .map(part => part.replaceAll("""\\:""", ":")) + + Protos.Label.newBuilder() + .setKey(cleanedParts(0)) + .setValue(cleanedParts(1)) + .build() + } + } + + Protos.Labels.newBuilder().addAllLabels(labels.asJava) + } +} diff --git a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterSchedulerSuite.scala b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterSchedulerSuite.scala index 32967b04cd346..0bb47906347d5 100644 --- a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterSchedulerSuite.scala +++ b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterSchedulerSuite.scala @@ -248,6 +248,33 @@ class MesosClusterSchedulerSuite extends SparkFunSuite with LocalSparkContext wi assert(networkInfos.get(0).getName == "test-network-name") } + test("supports spark.mesos.driver.labels") { + setScheduler() + + val mem = 1000 + val cpu = 1 + + val response = scheduler.submitDriver( + new MesosDriverDescription("d1", "jar", mem, cpu, true, + command, + Map("spark.mesos.executor.home" -> "test", + "spark.app.name" -> "test", + "spark.mesos.driver.labels" -> "key:value"), + "s1", + new Date())) + + assert(response.success) + + val offer = Utils.createOffer("o1", "s1", mem, cpu) + scheduler.resourceOffers(driver, List(offer).asJava) + + val launchedTasks = Utils.verifyTaskLaunched(driver, "o1") + val labels = launchedTasks.head.getLabels + assert(labels.getLabelsCount == 1) + assert(labels.getLabels(0).getKey == "key") + assert(labels.getLabels(0).getValue == "value") + } + test("can kill supervised drivers") { val conf = new SparkConf() conf.setMaster("mesos://localhost:5050") diff --git a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala index 0418bfbaa5ed8..7cca5fedb31eb 100644 --- a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala +++ b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala @@ -532,29 +532,6 @@ class MesosCoarseGrainedSchedulerBackendSuite extends SparkFunSuite assert(launchedTasks.head.getLabels.equals(taskLabels)) } - test("mesos ignored invalid labels and sets configurable labels on tasks") { - val taskLabelsString = "mesos:test,label:test,incorrect:label:here" - setBackend(Map( - "spark.mesos.task.labels" -> taskLabelsString - )) - - // Build up the labels - val taskLabels = Protos.Labels.newBuilder() - .addLabels(Protos.Label.newBuilder() - .setKey("mesos").setValue("test").build()) - .addLabels(Protos.Label.newBuilder() - .setKey("label").setValue("test").build()) - .build() - - val offers = List(Resources(backend.executorMemory(sc), 1)) - offerResources(offers) - val launchedTasks = verifyTaskLaunched(driver, "o1") - - val labels = launchedTasks.head.getLabels - - assert(launchedTasks.head.getLabels.equals(taskLabels)) - } - test("mesos supports spark.mesos.network.name") { setBackend(Map( "spark.mesos.network.name" -> "test-network-name" diff --git a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosProtoUtilsSuite.scala b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosProtoUtilsSuite.scala new file mode 100644 index 0000000000000..36a4c1ab1ad25 --- /dev/null +++ b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosProtoUtilsSuite.scala @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler.cluster.mesos + +import org.apache.spark.SparkFunSuite + +class MesosProtoUtilsSuite extends SparkFunSuite { + test("mesosLabels") { + val labels = MesosProtoUtils.mesosLabels("key:value") + assert(labels.getLabelsCount == 1) + val label = labels.getLabels(0) + assert(label.getKey == "key") + assert(label.getValue == "value") + + val labels2 = MesosProtoUtils.mesosLabels("key:value\\:value") + assert(labels2.getLabelsCount == 1) + val label2 = labels2.getLabels(0) + assert(label2.getKey == "key") + assert(label2.getValue == "value:value") + + val labels3 = MesosProtoUtils.mesosLabels("key:value,key2:value2") + assert(labels3.getLabelsCount == 2) + assert(labels3.getLabels(0).getKey == "key") + assert(labels3.getLabels(0).getValue == "value") + assert(labels3.getLabels(1).getKey == "key2") + assert(labels3.getLabels(1).getValue == "value2") + + val labels4 = MesosProtoUtils.mesosLabels("key:value\\,value") + assert(labels4.getLabelsCount == 1) + assert(labels4.getLabels(0).getKey == "key") + assert(labels4.getLabels(0).getValue == "value,value") + } +} From eb3ea3a0831b26d3dc35a97566716b92868a7beb Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Sun, 11 Jun 2017 09:54:57 +0100 Subject: [PATCH 109/133] [SPARK-20935][STREAMING] Always close WriteAheadLog and make it idempotent ## What changes were proposed in this pull request? This PR proposes to stop `ReceiverTracker` to close `WriteAheadLog` whenever it is and make `WriteAheadLog` and its implementations idempotent. ## How was this patch tested? Added a test in `WriteAheadLogSuite`. Note that the added test looks passing even if it closes twice (namely even without the changes in `FileBasedWriteAheadLog` and `BatchedWriteAheadLog`. It looks both are already idempotent but this is a rather sanity check. Author: hyukjinkwon Closes #18224 from HyukjinKwon/streaming-closing. --- .../spark/streaming/util/WriteAheadLog.java | 2 +- .../streaming/scheduler/ReceiverTracker.scala | 27 +++++++------------ .../streaming/util/BatchedWriteAheadLog.scala | 13 +++++---- .../util/FileBasedWriteAheadLog.scala | 8 +++--- .../scheduler/ReceiverTrackerSuite.scala | 2 ++ .../streaming/util/WriteAheadLogSuite.scala | 2 ++ 6 files changed, 26 insertions(+), 28 deletions(-) diff --git a/streaming/src/main/java/org/apache/spark/streaming/util/WriteAheadLog.java b/streaming/src/main/java/org/apache/spark/streaming/util/WriteAheadLog.java index 2803cad8095dd..00c59728748f6 100644 --- a/streaming/src/main/java/org/apache/spark/streaming/util/WriteAheadLog.java +++ b/streaming/src/main/java/org/apache/spark/streaming/util/WriteAheadLog.java @@ -56,7 +56,7 @@ public abstract class WriteAheadLog { public abstract void clean(long threshTime, boolean waitForCompletion); /** - * Close this log and release any resources. + * Close this log and release any resources. It must be idempotent. */ public abstract void close(); } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala index bd7ab0b9bf5eb..6f130c803f310 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala @@ -165,11 +165,11 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false /** Stop the receiver execution thread. */ def stop(graceful: Boolean): Unit = synchronized { - if (isTrackerStarted) { - // First, stop the receivers - trackerState = Stopping + val isStarted: Boolean = isTrackerStarted + trackerState = Stopping + if (isStarted) { if (!skipReceiverLaunch) { - // Send the stop signal to all the receivers + // First, stop the receivers. Send the stop signal to all the receivers endpoint.askSync[Boolean](StopAllReceivers) // Wait for the Spark job that runs the receivers to be over @@ -194,17 +194,13 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false // Finally, stop the endpoint ssc.env.rpcEnv.stop(endpoint) endpoint = null - receivedBlockTracker.stop() - logInfo("ReceiverTracker stopped") - trackerState = Stopped - } else if (isTrackerInitialized) { - trackerState = Stopping - // `ReceivedBlockTracker` is open when this instance is created. We should - // close this even if this `ReceiverTracker` is not started. - receivedBlockTracker.stop() - logInfo("ReceiverTracker stopped") - trackerState = Stopped } + + // `ReceivedBlockTracker` is open when this instance is created. We should + // close this even if this `ReceiverTracker` is not started. + receivedBlockTracker.stop() + logInfo("ReceiverTracker stopped") + trackerState = Stopped } /** Allocate all unallocated blocks to the given batch. */ @@ -453,9 +449,6 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false endpoint.send(StartAllReceivers(receivers)) } - /** Check if tracker has been marked for initiated */ - private def isTrackerInitialized: Boolean = trackerState == Initialized - /** Check if tracker has been marked for starting */ private def isTrackerStarted: Boolean = trackerState == Started diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/BatchedWriteAheadLog.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/BatchedWriteAheadLog.scala index 35f0166ed0cf2..e522bc62d5cac 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/BatchedWriteAheadLog.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/BatchedWriteAheadLog.scala @@ -19,6 +19,7 @@ package org.apache.spark.streaming.util import java.nio.ByteBuffer import java.util.{Iterator => JIterator} +import java.util.concurrent.atomic.AtomicBoolean import java.util.concurrent.LinkedBlockingQueue import scala.collection.JavaConverters._ @@ -60,7 +61,7 @@ private[util] class BatchedWriteAheadLog(val wrappedLog: WriteAheadLog, conf: Sp private val walWriteQueue = new LinkedBlockingQueue[Record]() // Whether the writer thread is active - @volatile private var active: Boolean = true + private val active: AtomicBoolean = new AtomicBoolean(true) private val buffer = new ArrayBuffer[Record]() private val batchedWriterThread = startBatchedWriterThread() @@ -72,7 +73,7 @@ private[util] class BatchedWriteAheadLog(val wrappedLog: WriteAheadLog, conf: Sp override def write(byteBuffer: ByteBuffer, time: Long): WriteAheadLogRecordHandle = { val promise = Promise[WriteAheadLogRecordHandle]() val putSuccessfully = synchronized { - if (active) { + if (active.get()) { walWriteQueue.offer(Record(byteBuffer, time, promise)) true } else { @@ -121,9 +122,7 @@ private[util] class BatchedWriteAheadLog(val wrappedLog: WriteAheadLog, conf: Sp */ override def close(): Unit = { logInfo(s"BatchedWriteAheadLog shutting down at time: ${System.currentTimeMillis()}.") - synchronized { - active = false - } + if (!active.getAndSet(false)) return batchedWriterThread.interrupt() batchedWriterThread.join() while (!walWriteQueue.isEmpty) { @@ -138,7 +137,7 @@ private[util] class BatchedWriteAheadLog(val wrappedLog: WriteAheadLog, conf: Sp private def startBatchedWriterThread(): Thread = { val thread = new Thread(new Runnable { override def run(): Unit = { - while (active) { + while (active.get()) { try { flushRecords() } catch { @@ -166,7 +165,7 @@ private[util] class BatchedWriteAheadLog(val wrappedLog: WriteAheadLog, conf: Sp } try { var segment: WriteAheadLogRecordHandle = null - if (buffer.length > 0) { + if (buffer.nonEmpty) { logDebug(s"Batched ${buffer.length} records for Write Ahead Log write") // threads may not be able to add items in order by time val sortedByTime = buffer.sortBy(_.time) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala index 1e5f18797e152..d6e15cfdd2723 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala @@ -205,10 +205,12 @@ private[streaming] class FileBasedWriteAheadLog( /** Stop the manager, close any open log writer */ def close(): Unit = synchronized { - if (currentLogWriter != null) { - currentLogWriter.close() + if (!executionContext.isShutdown) { + if (currentLogWriter != null) { + currentLogWriter.close() + } + executionContext.shutdown() } - executionContext.shutdown() logInfo("Stopped write ahead log manager") } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala index df122ac090c3e..c206d3169d77e 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala @@ -57,6 +57,8 @@ class ReceiverTrackerSuite extends TestSuiteBase { } } finally { tracker.stop(false) + // Make sure it is idempotent. + tracker.stop(false) } } } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala index 4bec52b9fe4fe..ede15399f0e2f 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala @@ -140,6 +140,8 @@ abstract class CommonWriteAheadLogTests( } } writeAheadLog.close() + // Make sure it is idempotent. + writeAheadLog.close() } test(testPrefix + "handling file errors while reading rotating logs") { From 823f1eef580763048b08b640090519e884f29c47 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Sun, 11 Jun 2017 10:05:47 +0100 Subject: [PATCH 110/133] [SPARK-13933][BUILD] Update hadoop-2.7 profile's curator version to 2.7.1 ## What changes were proposed in this pull request? Update hadoop-2.7 profile's curator version to 2.7.1, more see [SPARK-13933](https://issues.apache.org/jira/browse/SPARK-13933). ## How was this patch tested? manual tests Author: Yuming Wang Closes #18247 from wangyum/SPARK-13933. --- dev/deps/spark-deps-hadoop-2.7 | 6 +++--- pom.xml | 1 + 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7 index ab1de3d3dd8ad..9127413ab6c23 100644 --- a/dev/deps/spark-deps-hadoop-2.7 +++ b/dev/deps/spark-deps-hadoop-2.7 @@ -47,9 +47,9 @@ commons-net-2.2.jar commons-pool-1.5.4.jar compress-lzf-1.0.3.jar core-1.1.2.jar -curator-client-2.6.0.jar -curator-framework-2.6.0.jar -curator-recipes-2.6.0.jar +curator-client-2.7.1.jar +curator-framework-2.7.1.jar +curator-recipes-2.7.1.jar datanucleus-api-jdo-3.2.6.jar datanucleus-core-3.2.10.jar datanucleus-rdbms-3.2.9.jar diff --git a/pom.xml b/pom.xml index 6835ea14cd42b..5f524079495c0 100644 --- a/pom.xml +++ b/pom.xml @@ -2532,6 +2532,7 @@ hadoop-2.7 2.7.3 + 2.7.1 From 9f4ff9552470fb97ca38bb56bbf43be49a9a316c Mon Sep 17 00:00:00 2001 From: Felix Cheung Date: Sun, 11 Jun 2017 03:00:44 -0700 Subject: [PATCH 111/133] [SPARK-20877][SPARKR][FOLLOWUP] clean up after test move ## What changes were proposed in this pull request? clean up after big test move ## How was this patch tested? unit tests, jenkins Author: Felix Cheung Closes #18267 from felixcheung/rtestset2. --- R/pkg/.Rbuildignore | 1 + R/pkg/R/install.R | 2 +- R/pkg/R/utils.R | 8 +- R/pkg/tests/fulltests/test_Serde.R | 6 -- R/pkg/tests/fulltests/test_Windows.R | 7 +- R/pkg/tests/fulltests/test_binaryFile.R | 8 -- R/pkg/tests/fulltests/test_binary_function.R | 6 -- R/pkg/tests/fulltests/test_broadcast.R | 4 - R/pkg/tests/fulltests/test_client.R | 8 -- R/pkg/tests/fulltests/test_context.R | 16 --- R/pkg/tests/fulltests/test_includePackage.R | 4 - .../fulltests/test_mllib_classification.R | 12 +-- R/pkg/tests/fulltests/test_mllib_clustering.R | 14 +-- R/pkg/tests/fulltests/test_mllib_fpm.R | 2 +- .../fulltests/test_mllib_recommendation.R | 2 +- R/pkg/tests/fulltests/test_mllib_regression.R | 16 +-- R/pkg/tests/fulltests/test_mllib_tree.R | 22 ++-- .../fulltests/test_parallelize_collect.R | 8 -- R/pkg/tests/fulltests/test_rdd.R | 102 ------------------ R/pkg/tests/fulltests/test_shuffle.R | 24 ----- R/pkg/tests/fulltests/test_sparkR.R | 2 - R/pkg/tests/fulltests/test_sparkSQL.R | 92 ++-------------- R/pkg/tests/fulltests/test_streaming.R | 14 +-- R/pkg/tests/fulltests/test_take.R | 2 - R/pkg/tests/fulltests/test_textFile.R | 18 ---- R/pkg/tests/fulltests/test_utils.R | 9 -- R/pkg/tests/run-all.R | 2 - 27 files changed, 35 insertions(+), 376 deletions(-) diff --git a/R/pkg/.Rbuildignore b/R/pkg/.Rbuildignore index f12f8c275a989..18b2db69db8f1 100644 --- a/R/pkg/.Rbuildignore +++ b/R/pkg/.Rbuildignore @@ -6,3 +6,4 @@ ^README\.Rmd$ ^src-native$ ^html$ +^tests/fulltests/* diff --git a/R/pkg/R/install.R b/R/pkg/R/install.R index 4ca7aa664e023..ec931befa2854 100644 --- a/R/pkg/R/install.R +++ b/R/pkg/R/install.R @@ -267,7 +267,7 @@ hadoopVersionName <- function(hadoopVersion) { # The implementation refers to appdirs package: https://pypi.python.org/pypi/appdirs and # adapt to Spark context sparkCachePath <- function() { - if (.Platform$OS.type == "windows") { + if (is_windows()) { winAppPath <- Sys.getenv("LOCALAPPDATA", unset = NA) if (is.na(winAppPath)) { stop(paste("%LOCALAPPDATA% not found.", diff --git a/R/pkg/R/utils.R b/R/pkg/R/utils.R index ea45e394500e8..91483a4d23d9b 100644 --- a/R/pkg/R/utils.R +++ b/R/pkg/R/utils.R @@ -908,10 +908,6 @@ isAtomicLengthOne <- function(x) { is.atomic(x) && length(x) == 1 } -is_cran <- function() { - !identical(Sys.getenv("NOT_CRAN"), "true") -} - is_windows <- function() { .Platform$OS.type == "windows" } @@ -920,6 +916,6 @@ hadoop_home_set <- function() { !identical(Sys.getenv("HADOOP_HOME"), "") } -not_cran_or_windows_with_hadoop <- function() { - !is_cran() && (!is_windows() || hadoop_home_set()) +windows_with_hadoop <- function() { + !is_windows() || hadoop_home_set() } diff --git a/R/pkg/tests/fulltests/test_Serde.R b/R/pkg/tests/fulltests/test_Serde.R index 6e160fae1afed..6bbd201bf1d82 100644 --- a/R/pkg/tests/fulltests/test_Serde.R +++ b/R/pkg/tests/fulltests/test_Serde.R @@ -20,8 +20,6 @@ context("SerDe functionality") sparkSession <- sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) test_that("SerDe of primitive types", { - skip_on_cran() - x <- callJStatic("SparkRHandler", "echo", 1L) expect_equal(x, 1L) expect_equal(class(x), "integer") @@ -40,8 +38,6 @@ test_that("SerDe of primitive types", { }) test_that("SerDe of list of primitive types", { - skip_on_cran() - x <- list(1L, 2L, 3L) y <- callJStatic("SparkRHandler", "echo", x) expect_equal(x, y) @@ -69,8 +65,6 @@ test_that("SerDe of list of primitive types", { }) test_that("SerDe of list of lists", { - skip_on_cran() - x <- list(list(1L, 2L, 3L), list(1, 2, 3), list(TRUE, FALSE), list("a", "b", "c")) y <- callJStatic("SparkRHandler", "echo", x) diff --git a/R/pkg/tests/fulltests/test_Windows.R b/R/pkg/tests/fulltests/test_Windows.R index 00d684e1a49ef..b2ec6c67311db 100644 --- a/R/pkg/tests/fulltests/test_Windows.R +++ b/R/pkg/tests/fulltests/test_Windows.R @@ -17,9 +17,7 @@ context("Windows-specific tests") test_that("sparkJars tag in SparkContext", { - skip_on_cran() - - if (.Platform$OS.type != "windows") { + if (!is_windows()) { skip("This test is only for Windows, skipped") } @@ -27,6 +25,3 @@ test_that("sparkJars tag in SparkContext", { abcPath <- testOutput[1] expect_equal(abcPath, "a\\b\\c") }) - -message("--- End test (Windows) ", as.POSIXct(Sys.time(), tz = "GMT")) -message("elapsed ", (proc.time() - timer_ptm)[3]) diff --git a/R/pkg/tests/fulltests/test_binaryFile.R b/R/pkg/tests/fulltests/test_binaryFile.R index 00954fa31b0ee..758b174b8787c 100644 --- a/R/pkg/tests/fulltests/test_binaryFile.R +++ b/R/pkg/tests/fulltests/test_binaryFile.R @@ -24,8 +24,6 @@ sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", mockFile <- c("Spark is pretty.", "Spark is awesome.") test_that("saveAsObjectFile()/objectFile() following textFile() works", { - skip_on_cran() - fileName1 <- tempfile(pattern = "spark-test", fileext = ".tmp") fileName2 <- tempfile(pattern = "spark-test", fileext = ".tmp") writeLines(mockFile, fileName1) @@ -40,8 +38,6 @@ test_that("saveAsObjectFile()/objectFile() following textFile() works", { }) test_that("saveAsObjectFile()/objectFile() works on a parallelized list", { - skip_on_cran() - fileName <- tempfile(pattern = "spark-test", fileext = ".tmp") l <- list(1, 2, 3) @@ -54,8 +50,6 @@ test_that("saveAsObjectFile()/objectFile() works on a parallelized list", { }) test_that("saveAsObjectFile()/objectFile() following RDD transformations works", { - skip_on_cran() - fileName1 <- tempfile(pattern = "spark-test", fileext = ".tmp") fileName2 <- tempfile(pattern = "spark-test", fileext = ".tmp") writeLines(mockFile, fileName1) @@ -80,8 +74,6 @@ test_that("saveAsObjectFile()/objectFile() following RDD transformations works", }) test_that("saveAsObjectFile()/objectFile() works with multiple paths", { - skip_on_cran() - fileName1 <- tempfile(pattern = "spark-test", fileext = ".tmp") fileName2 <- tempfile(pattern = "spark-test", fileext = ".tmp") diff --git a/R/pkg/tests/fulltests/test_binary_function.R b/R/pkg/tests/fulltests/test_binary_function.R index 236cb3885445e..442bed509bb1d 100644 --- a/R/pkg/tests/fulltests/test_binary_function.R +++ b/R/pkg/tests/fulltests/test_binary_function.R @@ -29,8 +29,6 @@ rdd <- parallelize(sc, nums, 2L) mockFile <- c("Spark is pretty.", "Spark is awesome.") test_that("union on two RDDs", { - skip_on_cran() - actual <- collectRDD(unionRDD(rdd, rdd)) expect_equal(actual, as.list(rep(nums, 2))) @@ -53,8 +51,6 @@ test_that("union on two RDDs", { }) test_that("cogroup on two RDDs", { - skip_on_cran() - rdd1 <- parallelize(sc, list(list(1, 1), list(2, 4))) rdd2 <- parallelize(sc, list(list(1, 2), list(1, 3))) cogroup.rdd <- cogroup(rdd1, rdd2, numPartitions = 2L) @@ -73,8 +69,6 @@ test_that("cogroup on two RDDs", { }) test_that("zipPartitions() on RDDs", { - skip_on_cran() - rdd1 <- parallelize(sc, 1:2, 2L) # 1, 2 rdd2 <- parallelize(sc, 1:4, 2L) # 1:2, 3:4 rdd3 <- parallelize(sc, 1:6, 2L) # 1:3, 4:6 diff --git a/R/pkg/tests/fulltests/test_broadcast.R b/R/pkg/tests/fulltests/test_broadcast.R index 2c96740df77bb..fc2c7c2deb825 100644 --- a/R/pkg/tests/fulltests/test_broadcast.R +++ b/R/pkg/tests/fulltests/test_broadcast.R @@ -26,8 +26,6 @@ nums <- 1:2 rrdd <- parallelize(sc, nums, 2L) test_that("using broadcast variable", { - skip_on_cran() - randomMat <- matrix(nrow = 10, ncol = 10, data = rnorm(100)) randomMatBr <- broadcastRDD(sc, randomMat) @@ -40,8 +38,6 @@ test_that("using broadcast variable", { }) test_that("without using broadcast variable", { - skip_on_cran() - randomMat <- matrix(nrow = 10, ncol = 10, data = rnorm(100)) useBroadcast <- function(x) { diff --git a/R/pkg/tests/fulltests/test_client.R b/R/pkg/tests/fulltests/test_client.R index 3d53bebab6300..0cf25fe1dbf39 100644 --- a/R/pkg/tests/fulltests/test_client.R +++ b/R/pkg/tests/fulltests/test_client.R @@ -18,8 +18,6 @@ context("functions in client.R") test_that("adding spark-testing-base as a package works", { - skip_on_cran() - args <- generateSparkSubmitArgs("", "", "", "", "holdenk:spark-testing-base:1.3.0_0.0.5") expect_equal(gsub("[[:space:]]", "", args), @@ -28,22 +26,16 @@ test_that("adding spark-testing-base as a package works", { }) test_that("no package specified doesn't add packages flag", { - skip_on_cran() - args <- generateSparkSubmitArgs("", "", "", "", "") expect_equal(gsub("[[:space:]]", "", args), "") }) test_that("multiple packages don't produce a warning", { - skip_on_cran() - expect_warning(generateSparkSubmitArgs("", "", "", "", c("A", "B")), NA) }) test_that("sparkJars sparkPackages as character vectors", { - skip_on_cran() - args <- generateSparkSubmitArgs("", "", c("one.jar", "two.jar", "three.jar"), "", c("com.databricks:spark-avro_2.10:2.0.1")) expect_match(args, "--jars one.jar,two.jar,three.jar") diff --git a/R/pkg/tests/fulltests/test_context.R b/R/pkg/tests/fulltests/test_context.R index f6d9f5423df02..710485d56685a 100644 --- a/R/pkg/tests/fulltests/test_context.R +++ b/R/pkg/tests/fulltests/test_context.R @@ -18,8 +18,6 @@ context("test functions in sparkR.R") test_that("Check masked functions", { - skip_on_cran() - # Check that we are not masking any new function from base, stats, testthat unexpectedly # NOTE: We should avoid adding entries to *namesOfMaskedCompletely* as masked functions make it # hard for users to use base R functions. Please check when in doubt. @@ -57,8 +55,6 @@ test_that("Check masked functions", { }) test_that("repeatedly starting and stopping SparkR", { - skip_on_cran() - for (i in 1:4) { sc <- suppressWarnings(sparkR.init(master = sparkRTestMaster)) rdd <- parallelize(sc, 1:20, 2L) @@ -77,8 +73,6 @@ test_that("repeatedly starting and stopping SparkSession", { }) test_that("rdd GC across sparkR.stop", { - skip_on_cran() - sc <- sparkR.sparkContext(master = sparkRTestMaster) # sc should get id 0 rdd1 <- parallelize(sc, 1:20, 2L) # rdd1 should get id 1 rdd2 <- parallelize(sc, 1:10, 2L) # rdd2 should get id 2 @@ -102,8 +96,6 @@ test_that("rdd GC across sparkR.stop", { }) test_that("job group functions can be called", { - skip_on_cran() - sc <- sparkR.sparkContext(master = sparkRTestMaster) setJobGroup("groupId", "job description", TRUE) cancelJobGroup("groupId") @@ -116,16 +108,12 @@ test_that("job group functions can be called", { }) test_that("utility function can be called", { - skip_on_cran() - sparkR.sparkContext(master = sparkRTestMaster) setLogLevel("ERROR") sparkR.session.stop() }) test_that("getClientModeSparkSubmitOpts() returns spark-submit args from whitelist", { - skip_on_cran() - e <- new.env() e[["spark.driver.memory"]] <- "512m" ops <- getClientModeSparkSubmitOpts("sparkrmain", e) @@ -153,8 +141,6 @@ test_that("getClientModeSparkSubmitOpts() returns spark-submit args from whiteli }) test_that("sparkJars sparkPackages as comma-separated strings", { - skip_on_cran() - expect_warning(processSparkJars(" a, b ")) jars <- suppressWarnings(processSparkJars(" a, b ")) expect_equal(lapply(jars, basename), list("a", "b")) @@ -182,8 +168,6 @@ test_that("spark.lapply should perform simple transforms", { }) test_that("add and get file to be downloaded with Spark job on every node", { - skip_on_cran() - sparkR.sparkContext(master = sparkRTestMaster) # Test add file. path <- tempfile(pattern = "hello", fileext = ".txt") diff --git a/R/pkg/tests/fulltests/test_includePackage.R b/R/pkg/tests/fulltests/test_includePackage.R index d7d9eeed1575e..f4ea0d1b5cb27 100644 --- a/R/pkg/tests/fulltests/test_includePackage.R +++ b/R/pkg/tests/fulltests/test_includePackage.R @@ -26,8 +26,6 @@ nums <- 1:2 rdd <- parallelize(sc, nums, 2L) test_that("include inside function", { - skip_on_cran() - # Only run the test if plyr is installed. if ("plyr" %in% rownames(installed.packages())) { suppressPackageStartupMessages(library(plyr)) @@ -44,8 +42,6 @@ test_that("include inside function", { }) test_that("use include package", { - skip_on_cran() - # Only run the test if plyr is installed. if ("plyr" %in% rownames(installed.packages())) { suppressPackageStartupMessages(library(plyr)) diff --git a/R/pkg/tests/fulltests/test_mllib_classification.R b/R/pkg/tests/fulltests/test_mllib_classification.R index 82e588dc460d0..726e9d9a20b1c 100644 --- a/R/pkg/tests/fulltests/test_mllib_classification.R +++ b/R/pkg/tests/fulltests/test_mllib_classification.R @@ -28,8 +28,6 @@ absoluteSparkPath <- function(x) { } test_that("spark.svmLinear", { - skip_on_cran() - df <- suppressWarnings(createDataFrame(iris)) training <- df[df$Species %in% c("versicolor", "virginica"), ] model <- spark.svmLinear(training, Species ~ ., regParam = 0.01, maxIter = 10) @@ -51,7 +49,7 @@ test_that("spark.svmLinear", { expect_equal(sort(as.list(take(select(prediction, "prediction"), 10))[[1]]), expected) # Test model save and load - if (not_cran_or_windows_with_hadoop()) { + if (windows_with_hadoop()) { modelPath <- tempfile(pattern = "spark-svm-linear", fileext = ".tmp") write.ml(model, modelPath) expect_error(write.ml(model, modelPath)) @@ -131,7 +129,7 @@ test_that("spark.logit", { expect_true(all(abs(setosaCoefs - setosaCoefs) < 0.1)) # Test model save and load - if (not_cran_or_windows_with_hadoop()) { + if (windows_with_hadoop()) { modelPath <- tempfile(pattern = "spark-logit", fileext = ".tmp") write.ml(model, modelPath) expect_error(write.ml(model, modelPath)) @@ -228,8 +226,6 @@ test_that("spark.logit", { }) test_that("spark.mlp", { - skip_on_cran() - df <- read.df(absoluteSparkPath("data/mllib/sample_multiclass_classification_data.txt"), source = "libsvm") model <- spark.mlp(df, label ~ features, blockSize = 128, layers = c(4, 5, 4, 3), @@ -250,7 +246,7 @@ test_that("spark.mlp", { expect_equal(head(mlpPredictions$prediction, 6), c("1.0", "0.0", "0.0", "0.0", "0.0", "0.0")) # Test model save/load - if (not_cran_or_windows_with_hadoop()) { + if (windows_with_hadoop()) { modelPath <- tempfile(pattern = "spark-mlp", fileext = ".tmp") write.ml(model, modelPath) expect_error(write.ml(model, modelPath)) @@ -363,7 +359,7 @@ test_that("spark.naiveBayes", { "Yes", "Yes", "No", "No")) # Test model save/load - if (not_cran_or_windows_with_hadoop()) { + if (windows_with_hadoop()) { modelPath <- tempfile(pattern = "spark-naiveBayes", fileext = ".tmp") write.ml(m, modelPath) expect_error(write.ml(m, modelPath)) diff --git a/R/pkg/tests/fulltests/test_mllib_clustering.R b/R/pkg/tests/fulltests/test_mllib_clustering.R index e827e961ab4c1..4110e13da4948 100644 --- a/R/pkg/tests/fulltests/test_mllib_clustering.R +++ b/R/pkg/tests/fulltests/test_mllib_clustering.R @@ -28,8 +28,6 @@ absoluteSparkPath <- function(x) { } test_that("spark.bisectingKmeans", { - skip_on_cran() - newIris <- iris newIris$Species <- NULL training <- suppressWarnings(createDataFrame(newIris)) @@ -55,7 +53,7 @@ test_that("spark.bisectingKmeans", { c(0, 1, 2, 3)) # Test model save/load - if (not_cran_or_windows_with_hadoop()) { + if (windows_with_hadoop()) { modelPath <- tempfile(pattern = "spark-bisectingkmeans", fileext = ".tmp") write.ml(model, modelPath) expect_error(write.ml(model, modelPath)) @@ -129,7 +127,7 @@ test_that("spark.gaussianMixture", { expect_equal(p$prediction, c(0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1)) # Test model save/load - if (not_cran_or_windows_with_hadoop()) { + if (windows_with_hadoop()) { modelPath <- tempfile(pattern = "spark-gaussianMixture", fileext = ".tmp") write.ml(model, modelPath) expect_error(write.ml(model, modelPath)) @@ -177,7 +175,7 @@ test_that("spark.kmeans", { expect_true(class(summary.model$coefficients[1, ]) == "numeric") # Test model save/load - if (not_cran_or_windows_with_hadoop()) { + if (windows_with_hadoop()) { modelPath <- tempfile(pattern = "spark-kmeans", fileext = ".tmp") write.ml(model, modelPath) expect_error(write.ml(model, modelPath)) @@ -244,7 +242,7 @@ test_that("spark.lda with libsvm", { expect_true(logPrior <= 0 & !is.na(logPrior)) # Test model save/load - if (not_cran_or_windows_with_hadoop()) { + if (windows_with_hadoop()) { modelPath <- tempfile(pattern = "spark-lda", fileext = ".tmp") write.ml(model, modelPath) expect_error(write.ml(model, modelPath)) @@ -265,8 +263,6 @@ test_that("spark.lda with libsvm", { }) test_that("spark.lda with text input", { - skip_on_cran() - text <- read.text(absoluteSparkPath("data/mllib/sample_lda_data.txt")) model <- spark.lda(text, optimizer = "online", features = "value") @@ -309,8 +305,6 @@ test_that("spark.lda with text input", { }) test_that("spark.posterior and spark.perplexity", { - skip_on_cran() - text <- read.text(absoluteSparkPath("data/mllib/sample_lda_data.txt")) model <- spark.lda(text, features = "value", k = 3) diff --git a/R/pkg/tests/fulltests/test_mllib_fpm.R b/R/pkg/tests/fulltests/test_mllib_fpm.R index 4e10ca1e4f50b..69dda52f0c279 100644 --- a/R/pkg/tests/fulltests/test_mllib_fpm.R +++ b/R/pkg/tests/fulltests/test_mllib_fpm.R @@ -62,7 +62,7 @@ test_that("spark.fpGrowth", { expect_equivalent(expected_predictions, collect(predict(model, new_data))) - if (not_cran_or_windows_with_hadoop()) { + if (windows_with_hadoop()) { modelPath <- tempfile(pattern = "spark-fpm", fileext = ".tmp") write.ml(model, modelPath, overwrite = TRUE) loaded_model <- read.ml(modelPath) diff --git a/R/pkg/tests/fulltests/test_mllib_recommendation.R b/R/pkg/tests/fulltests/test_mllib_recommendation.R index cc8064f88d27a..4d919c9d746b0 100644 --- a/R/pkg/tests/fulltests/test_mllib_recommendation.R +++ b/R/pkg/tests/fulltests/test_mllib_recommendation.R @@ -37,7 +37,7 @@ test_that("spark.als", { tolerance = 1e-4) # Test model save/load - if (not_cran_or_windows_with_hadoop()) { + if (windows_with_hadoop()) { modelPath <- tempfile(pattern = "spark-als", fileext = ".tmp") write.ml(model, modelPath) expect_error(write.ml(model, modelPath)) diff --git a/R/pkg/tests/fulltests/test_mllib_regression.R b/R/pkg/tests/fulltests/test_mllib_regression.R index b05fdd350ca28..82472c92b9965 100644 --- a/R/pkg/tests/fulltests/test_mllib_regression.R +++ b/R/pkg/tests/fulltests/test_mllib_regression.R @@ -23,8 +23,6 @@ context("MLlib regression algorithms, except for tree-based algorithms") sparkSession <- sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) test_that("formula of spark.glm", { - skip_on_cran() - training <- suppressWarnings(createDataFrame(iris)) # directly calling the spark API # dot minus and intercept vs native glm @@ -197,8 +195,6 @@ test_that("spark.glm summary", { }) test_that("spark.glm save/load", { - skip_on_cran() - training <- suppressWarnings(createDataFrame(iris)) m <- spark.glm(training, Sepal_Width ~ Sepal_Length + Species) s <- summary(m) @@ -226,8 +222,6 @@ test_that("spark.glm save/load", { }) test_that("formula of glm", { - skip_on_cran() - training <- suppressWarnings(createDataFrame(iris)) # dot minus and intercept vs native glm model <- glm(Sepal_Width ~ . - Species + 0, data = training) @@ -254,8 +248,6 @@ test_that("formula of glm", { }) test_that("glm and predict", { - skip_on_cran() - training <- suppressWarnings(createDataFrame(iris)) # gaussian family model <- glm(Sepal_Width ~ Sepal_Length + Species, data = training) @@ -300,8 +292,6 @@ test_that("glm and predict", { }) test_that("glm summary", { - skip_on_cran() - # gaussian family training <- suppressWarnings(createDataFrame(iris)) stats <- summary(glm(Sepal_Width ~ Sepal_Length + Species, data = training)) @@ -351,8 +341,6 @@ test_that("glm summary", { }) test_that("glm save/load", { - skip_on_cran() - training <- suppressWarnings(createDataFrame(iris)) m <- glm(Sepal_Width ~ Sepal_Length + Species, data = training) s <- summary(m) @@ -401,7 +389,7 @@ test_that("spark.isoreg", { expect_equal(predict_result$prediction, c(7.0, 7.0, 6.0, 5.5, 5.0, 4.0, 1.0)) # Test model save/load - if (not_cran_or_windows_with_hadoop()) { + if (windows_with_hadoop()) { modelPath <- tempfile(pattern = "spark-isoreg", fileext = ".tmp") write.ml(model, modelPath) expect_error(write.ml(model, modelPath)) @@ -452,7 +440,7 @@ test_that("spark.survreg", { 2.390146, 2.891269, 2.891269), tolerance = 1e-4) # Test model save/load - if (not_cran_or_windows_with_hadoop()) { + if (windows_with_hadoop()) { modelPath <- tempfile(pattern = "spark-survreg", fileext = ".tmp") write.ml(model, modelPath) expect_error(write.ml(model, modelPath)) diff --git a/R/pkg/tests/fulltests/test_mllib_tree.R b/R/pkg/tests/fulltests/test_mllib_tree.R index 31427ee52a5e9..9b3fc8d270b25 100644 --- a/R/pkg/tests/fulltests/test_mllib_tree.R +++ b/R/pkg/tests/fulltests/test_mllib_tree.R @@ -28,8 +28,6 @@ absoluteSparkPath <- function(x) { } test_that("spark.gbt", { - skip_on_cran() - # regression data <- suppressWarnings(createDataFrame(longley)) model <- spark.gbt(data, Employed ~ ., "regression", maxDepth = 5, maxBins = 16, seed = 123) @@ -46,7 +44,7 @@ test_that("spark.gbt", { expect_equal(stats$numFeatures, 6) expect_equal(length(stats$treeWeights), 20) - if (not_cran_or_windows_with_hadoop()) { + if (windows_with_hadoop()) { modelPath <- tempfile(pattern = "spark-gbtRegression", fileext = ".tmp") write.ml(model, modelPath) expect_error(write.ml(model, modelPath)) @@ -80,7 +78,7 @@ test_that("spark.gbt", { expect_equal(length(grep("setosa", predictions)), 50) expect_equal(length(grep("versicolor", predictions)), 50) - if (not_cran_or_windows_with_hadoop()) { + if (windows_with_hadoop()) { modelPath <- tempfile(pattern = "spark-gbtClassification", fileext = ".tmp") write.ml(model, modelPath) expect_error(write.ml(model, modelPath)) @@ -105,7 +103,7 @@ test_that("spark.gbt", { expect_equal(stats$maxDepth, 5) # spark.gbt classification can work on libsvm data - if (not_cran_or_windows_with_hadoop()) { + if (windows_with_hadoop()) { data <- read.df(absoluteSparkPath("data/mllib/sample_binary_classification_data.txt"), source = "libsvm") model <- spark.gbt(data, label ~ features, "classification") @@ -144,7 +142,7 @@ test_that("spark.randomForest", { expect_equal(stats$numTrees, 20) expect_equal(stats$maxDepth, 5) - if (not_cran_or_windows_with_hadoop()) { + if (windows_with_hadoop()) { modelPath <- tempfile(pattern = "spark-randomForestRegression", fileext = ".tmp") write.ml(model, modelPath) expect_error(write.ml(model, modelPath)) @@ -178,7 +176,7 @@ test_that("spark.randomForest", { expect_equal(length(grep("setosa", predictions)), 50) expect_equal(length(grep("versicolor", predictions)), 50) - if (not_cran_or_windows_with_hadoop()) { + if (windows_with_hadoop()) { modelPath <- tempfile(pattern = "spark-randomForestClassification", fileext = ".tmp") write.ml(model, modelPath) expect_error(write.ml(model, modelPath)) @@ -215,7 +213,7 @@ test_that("spark.randomForest", { expect_equal(length(grep("2.0", predictions)), 50) # spark.randomForest classification can work on libsvm data - if (not_cran_or_windows_with_hadoop()) { + if (windows_with_hadoop()) { data <- read.df(absoluteSparkPath("data/mllib/sample_multiclass_classification_data.txt"), source = "libsvm") model <- spark.randomForest(data, label ~ features, "classification") @@ -224,8 +222,6 @@ test_that("spark.randomForest", { }) test_that("spark.decisionTree", { - skip_on_cran() - # regression data <- suppressWarnings(createDataFrame(longley)) model <- spark.decisionTree(data, Employed ~ ., "regression", maxDepth = 5, maxBins = 16) @@ -242,7 +238,7 @@ test_that("spark.decisionTree", { expect_error(capture.output(stats), NA) expect_true(length(capture.output(stats)) > 6) - if (not_cran_or_windows_with_hadoop()) { + if (windows_with_hadoop()) { modelPath <- tempfile(pattern = "spark-decisionTreeRegression", fileext = ".tmp") write.ml(model, modelPath) expect_error(write.ml(model, modelPath)) @@ -273,7 +269,7 @@ test_that("spark.decisionTree", { expect_equal(length(grep("setosa", predictions)), 50) expect_equal(length(grep("versicolor", predictions)), 50) - if (not_cran_or_windows_with_hadoop()) { + if (windows_with_hadoop()) { modelPath <- tempfile(pattern = "spark-decisionTreeClassification", fileext = ".tmp") write.ml(model, modelPath) expect_error(write.ml(model, modelPath)) @@ -309,7 +305,7 @@ test_that("spark.decisionTree", { expect_equal(length(grep("2.0", predictions)), 50) # spark.decisionTree classification can work on libsvm data - if (not_cran_or_windows_with_hadoop()) { + if (windows_with_hadoop()) { data <- read.df(absoluteSparkPath("data/mllib/sample_multiclass_classification_data.txt"), source = "libsvm") model <- spark.decisionTree(data, label ~ features, "classification") diff --git a/R/pkg/tests/fulltests/test_parallelize_collect.R b/R/pkg/tests/fulltests/test_parallelize_collect.R index 52d4c93ed9599..3d122ccaf448f 100644 --- a/R/pkg/tests/fulltests/test_parallelize_collect.R +++ b/R/pkg/tests/fulltests/test_parallelize_collect.R @@ -39,8 +39,6 @@ jsc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", # Tests test_that("parallelize() on simple vectors and lists returns an RDD", { - skip_on_cran() - numVectorRDD <- parallelize(jsc, numVector, 1) numVectorRDD2 <- parallelize(jsc, numVector, 10) numListRDD <- parallelize(jsc, numList, 1) @@ -68,8 +66,6 @@ test_that("parallelize() on simple vectors and lists returns an RDD", { }) test_that("collect(), following a parallelize(), gives back the original collections", { - skip_on_cran() - numVectorRDD <- parallelize(jsc, numVector, 10) expect_equal(collectRDD(numVectorRDD), as.list(numVector)) @@ -90,8 +86,6 @@ test_that("collect(), following a parallelize(), gives back the original collect }) test_that("regression: collect() following a parallelize() does not drop elements", { - skip_on_cran() - # 10 %/% 6 = 1, ceiling(10 / 6) = 2 collLen <- 10 numPart <- 6 @@ -101,8 +95,6 @@ test_that("regression: collect() following a parallelize() does not drop element }) test_that("parallelize() and collect() work for lists of pairs (pairwise data)", { - skip_on_cran() - # use the pairwise logical to indicate pairwise data numPairsRDDD1 <- parallelize(jsc, numPairs, 1) numPairsRDDD2 <- parallelize(jsc, numPairs, 2) diff --git a/R/pkg/tests/fulltests/test_rdd.R b/R/pkg/tests/fulltests/test_rdd.R index fb244e1d49e20..6ee1fceffd822 100644 --- a/R/pkg/tests/fulltests/test_rdd.R +++ b/R/pkg/tests/fulltests/test_rdd.R @@ -29,30 +29,22 @@ intPairs <- list(list(1L, -1), list(2L, 100), list(2L, 1), list(1L, 200)) intRdd <- parallelize(sc, intPairs, 2L) test_that("get number of partitions in RDD", { - skip_on_cran() - expect_equal(getNumPartitionsRDD(rdd), 2) expect_equal(getNumPartitionsRDD(intRdd), 2) }) test_that("first on RDD", { - skip_on_cran() - expect_equal(firstRDD(rdd), 1) newrdd <- lapply(rdd, function(x) x + 1) expect_equal(firstRDD(newrdd), 2) }) test_that("count and length on RDD", { - skip_on_cran() - expect_equal(countRDD(rdd), 10) expect_equal(lengthRDD(rdd), 10) }) test_that("count by values and keys", { - skip_on_cran() - mods <- lapply(rdd, function(x) { x %% 3 }) actual <- countByValue(mods) expected <- list(list(0, 3L), list(1, 4L), list(2, 3L)) @@ -64,40 +56,30 @@ test_that("count by values and keys", { }) test_that("lapply on RDD", { - skip_on_cran() - multiples <- lapply(rdd, function(x) { 2 * x }) actual <- collectRDD(multiples) expect_equal(actual, as.list(nums * 2)) }) test_that("lapplyPartition on RDD", { - skip_on_cran() - sums <- lapplyPartition(rdd, function(part) { sum(unlist(part)) }) actual <- collectRDD(sums) expect_equal(actual, list(15, 40)) }) test_that("mapPartitions on RDD", { - skip_on_cran() - sums <- mapPartitions(rdd, function(part) { sum(unlist(part)) }) actual <- collectRDD(sums) expect_equal(actual, list(15, 40)) }) test_that("flatMap() on RDDs", { - skip_on_cran() - flat <- flatMap(intRdd, function(x) { list(x, x) }) actual <- collectRDD(flat) expect_equal(actual, rep(intPairs, each = 2)) }) test_that("filterRDD on RDD", { - skip_on_cran() - filtered.rdd <- filterRDD(rdd, function(x) { x %% 2 == 0 }) actual <- collectRDD(filtered.rdd) expect_equal(actual, list(2, 4, 6, 8, 10)) @@ -113,8 +95,6 @@ test_that("filterRDD on RDD", { }) test_that("lookup on RDD", { - skip_on_cran() - vals <- lookup(intRdd, 1L) expect_equal(vals, list(-1, 200)) @@ -123,8 +103,6 @@ test_that("lookup on RDD", { }) test_that("several transformations on RDD (a benchmark on PipelinedRDD)", { - skip_on_cran() - rdd2 <- rdd for (i in 1:12) rdd2 <- lapplyPartitionsWithIndex( @@ -139,8 +117,6 @@ test_that("several transformations on RDD (a benchmark on PipelinedRDD)", { }) test_that("PipelinedRDD support actions: cache(), persist(), unpersist(), checkpoint()", { - skip_on_cran() - # RDD rdd2 <- rdd # PipelinedRDD @@ -182,8 +158,6 @@ test_that("PipelinedRDD support actions: cache(), persist(), unpersist(), checkp }) test_that("reduce on RDD", { - skip_on_cran() - sum <- reduce(rdd, "+") expect_equal(sum, 55) @@ -193,8 +167,6 @@ test_that("reduce on RDD", { }) test_that("lapply with dependency", { - skip_on_cran() - fa <- 5 multiples <- lapply(rdd, function(x) { fa * x }) actual <- collectRDD(multiples) @@ -203,8 +175,6 @@ test_that("lapply with dependency", { }) test_that("lapplyPartitionsWithIndex on RDDs", { - skip_on_cran() - func <- function(partIndex, part) { list(partIndex, Reduce("+", part)) } actual <- collectRDD(lapplyPartitionsWithIndex(rdd, func), flatten = FALSE) expect_equal(actual, list(list(0, 15), list(1, 40))) @@ -221,14 +191,10 @@ test_that("lapplyPartitionsWithIndex on RDDs", { }) test_that("sampleRDD() on RDDs", { - skip_on_cran() - expect_equal(unlist(collectRDD(sampleRDD(rdd, FALSE, 1.0, 2014L))), nums) }) test_that("takeSample() on RDDs", { - skip_on_cran() - # ported from RDDSuite.scala, modified seeds data <- parallelize(sc, 1:100, 2L) for (seed in 4:5) { @@ -271,8 +237,6 @@ test_that("takeSample() on RDDs", { }) test_that("mapValues() on pairwise RDDs", { - skip_on_cran() - multiples <- mapValues(intRdd, function(x) { x * 2 }) actual <- collectRDD(multiples) expected <- lapply(intPairs, function(x) { @@ -282,8 +246,6 @@ test_that("mapValues() on pairwise RDDs", { }) test_that("flatMapValues() on pairwise RDDs", { - skip_on_cran() - l <- parallelize(sc, list(list(1, c(1, 2)), list(2, c(3, 4)))) actual <- collectRDD(flatMapValues(l, function(x) { x })) expect_equal(actual, list(list(1, 1), list(1, 2), list(2, 3), list(2, 4))) @@ -296,8 +258,6 @@ test_that("flatMapValues() on pairwise RDDs", { }) test_that("reduceByKeyLocally() on PairwiseRDDs", { - skip_on_cran() - pairs <- parallelize(sc, list(list(1, 2), list(1.1, 3), list(1, 4)), 2L) actual <- reduceByKeyLocally(pairs, "+") expect_equal(sortKeyValueList(actual), @@ -311,8 +271,6 @@ test_that("reduceByKeyLocally() on PairwiseRDDs", { }) test_that("distinct() on RDDs", { - skip_on_cran() - nums.rep2 <- rep(1:10, 2) rdd.rep2 <- parallelize(sc, nums.rep2, 2L) uniques <- distinctRDD(rdd.rep2) @@ -321,29 +279,21 @@ test_that("distinct() on RDDs", { }) test_that("maximum() on RDDs", { - skip_on_cran() - max <- maximum(rdd) expect_equal(max, 10) }) test_that("minimum() on RDDs", { - skip_on_cran() - min <- minimum(rdd) expect_equal(min, 1) }) test_that("sumRDD() on RDDs", { - skip_on_cran() - sum <- sumRDD(rdd) expect_equal(sum, 55) }) test_that("keyBy on RDDs", { - skip_on_cran() - func <- function(x) { x * x } keys <- keyBy(rdd, func) actual <- collectRDD(keys) @@ -351,8 +301,6 @@ test_that("keyBy on RDDs", { }) test_that("repartition/coalesce on RDDs", { - skip_on_cran() - rdd <- parallelize(sc, 1:20, 4L) # each partition contains 5 elements # repartition @@ -374,8 +322,6 @@ test_that("repartition/coalesce on RDDs", { }) test_that("sortBy() on RDDs", { - skip_on_cran() - sortedRdd <- sortBy(rdd, function(x) { x * x }, ascending = FALSE) actual <- collectRDD(sortedRdd) expect_equal(actual, as.list(sort(nums, decreasing = TRUE))) @@ -387,8 +333,6 @@ test_that("sortBy() on RDDs", { }) test_that("takeOrdered() on RDDs", { - skip_on_cran() - l <- list(10, 1, 2, 9, 3, 4, 5, 6, 7) rdd <- parallelize(sc, l) actual <- takeOrdered(rdd, 6L) @@ -401,8 +345,6 @@ test_that("takeOrdered() on RDDs", { }) test_that("top() on RDDs", { - skip_on_cran() - l <- list(10, 1, 2, 9, 3, 4, 5, 6, 7) rdd <- parallelize(sc, l) actual <- top(rdd, 6L) @@ -415,8 +357,6 @@ test_that("top() on RDDs", { }) test_that("fold() on RDDs", { - skip_on_cran() - actual <- fold(rdd, 0, "+") expect_equal(actual, Reduce("+", nums, 0)) @@ -426,8 +366,6 @@ test_that("fold() on RDDs", { }) test_that("aggregateRDD() on RDDs", { - skip_on_cran() - rdd <- parallelize(sc, list(1, 2, 3, 4)) zeroValue <- list(0, 0) seqOp <- function(x, y) { list(x[[1]] + y, x[[2]] + 1) } @@ -441,8 +379,6 @@ test_that("aggregateRDD() on RDDs", { }) test_that("zipWithUniqueId() on RDDs", { - skip_on_cran() - rdd <- parallelize(sc, list("a", "b", "c", "d", "e"), 3L) actual <- collectRDD(zipWithUniqueId(rdd)) expected <- list(list("a", 0), list("b", 1), list("c", 4), @@ -457,8 +393,6 @@ test_that("zipWithUniqueId() on RDDs", { }) test_that("zipWithIndex() on RDDs", { - skip_on_cran() - rdd <- parallelize(sc, list("a", "b", "c", "d", "e"), 3L) actual <- collectRDD(zipWithIndex(rdd)) expected <- list(list("a", 0), list("b", 1), list("c", 2), @@ -473,32 +407,24 @@ test_that("zipWithIndex() on RDDs", { }) test_that("glom() on RDD", { - skip_on_cran() - rdd <- parallelize(sc, as.list(1:4), 2L) actual <- collectRDD(glom(rdd)) expect_equal(actual, list(list(1, 2), list(3, 4))) }) test_that("keys() on RDDs", { - skip_on_cran() - keys <- keys(intRdd) actual <- collectRDD(keys) expect_equal(actual, lapply(intPairs, function(x) { x[[1]] })) }) test_that("values() on RDDs", { - skip_on_cran() - values <- values(intRdd) actual <- collectRDD(values) expect_equal(actual, lapply(intPairs, function(x) { x[[2]] })) }) test_that("pipeRDD() on RDDs", { - skip_on_cran() - actual <- collectRDD(pipeRDD(rdd, "more")) expected <- as.list(as.character(1:10)) expect_equal(actual, expected) @@ -516,8 +442,6 @@ test_that("pipeRDD() on RDDs", { }) test_that("zipRDD() on RDDs", { - skip_on_cran() - rdd1 <- parallelize(sc, 0:4, 2) rdd2 <- parallelize(sc, 1000:1004, 2) actual <- collectRDD(zipRDD(rdd1, rdd2)) @@ -547,8 +471,6 @@ test_that("zipRDD() on RDDs", { }) test_that("cartesian() on RDDs", { - skip_on_cran() - rdd <- parallelize(sc, 1:3) actual <- collectRDD(cartesian(rdd, rdd)) expect_equal(sortKeyValueList(actual), @@ -592,8 +514,6 @@ test_that("cartesian() on RDDs", { }) test_that("subtract() on RDDs", { - skip_on_cran() - l <- list(1, 1, 2, 2, 3, 4) rdd1 <- parallelize(sc, l) @@ -621,8 +541,6 @@ test_that("subtract() on RDDs", { }) test_that("subtractByKey() on pairwise RDDs", { - skip_on_cran() - l <- list(list("a", 1), list("b", 4), list("b", 5), list("a", 2)) rdd1 <- parallelize(sc, l) @@ -652,8 +570,6 @@ test_that("subtractByKey() on pairwise RDDs", { }) test_that("intersection() on RDDs", { - skip_on_cran() - # intersection with self actual <- collectRDD(intersection(rdd, rdd)) expect_equal(sort(as.integer(actual)), nums) @@ -670,8 +586,6 @@ test_that("intersection() on RDDs", { }) test_that("join() on pairwise RDDs", { - skip_on_cran() - rdd1 <- parallelize(sc, list(list(1, 1), list(2, 4))) rdd2 <- parallelize(sc, list(list(1, 2), list(1, 3))) actual <- collectRDD(joinRDD(rdd1, rdd2, 2L)) @@ -696,8 +610,6 @@ test_that("join() on pairwise RDDs", { }) test_that("leftOuterJoin() on pairwise RDDs", { - skip_on_cran() - rdd1 <- parallelize(sc, list(list(1, 1), list(2, 4))) rdd2 <- parallelize(sc, list(list(1, 2), list(1, 3))) actual <- collectRDD(leftOuterJoin(rdd1, rdd2, 2L)) @@ -728,8 +640,6 @@ test_that("leftOuterJoin() on pairwise RDDs", { }) test_that("rightOuterJoin() on pairwise RDDs", { - skip_on_cran() - rdd1 <- parallelize(sc, list(list(1, 2), list(1, 3))) rdd2 <- parallelize(sc, list(list(1, 1), list(2, 4))) actual <- collectRDD(rightOuterJoin(rdd1, rdd2, 2L)) @@ -757,8 +667,6 @@ test_that("rightOuterJoin() on pairwise RDDs", { }) test_that("fullOuterJoin() on pairwise RDDs", { - skip_on_cran() - rdd1 <- parallelize(sc, list(list(1, 2), list(1, 3), list(3, 3))) rdd2 <- parallelize(sc, list(list(1, 1), list(2, 4))) actual <- collectRDD(fullOuterJoin(rdd1, rdd2, 2L)) @@ -790,8 +698,6 @@ test_that("fullOuterJoin() on pairwise RDDs", { }) test_that("sortByKey() on pairwise RDDs", { - skip_on_cran() - numPairsRdd <- map(rdd, function(x) { list (x, x) }) sortedRdd <- sortByKey(numPairsRdd, ascending = FALSE) actual <- collectRDD(sortedRdd) @@ -841,8 +747,6 @@ test_that("sortByKey() on pairwise RDDs", { }) test_that("collectAsMap() on a pairwise RDD", { - skip_on_cran() - rdd <- parallelize(sc, list(list(1, 2), list(3, 4))) vals <- collectAsMap(rdd) expect_equal(vals, list(`1` = 2, `3` = 4)) @@ -861,15 +765,11 @@ test_that("collectAsMap() on a pairwise RDD", { }) test_that("show()", { - skip_on_cran() - rdd <- parallelize(sc, list(1:10)) expect_output(showRDD(rdd), "ParallelCollectionRDD\\[\\d+\\] at parallelize at RRDD\\.scala:\\d+") }) test_that("sampleByKey() on pairwise RDDs", { - skip_on_cran() - rdd <- parallelize(sc, 1:2000) pairsRDD <- lapply(rdd, function(x) { if (x %% 2 == 0) list("a", x) else list("b", x) }) fractions <- list(a = 0.2, b = 0.1) @@ -894,8 +794,6 @@ test_that("sampleByKey() on pairwise RDDs", { }) test_that("Test correct concurrency of RRDD.compute()", { - skip_on_cran() - rdd <- parallelize(sc, 1:1000, 100) jrdd <- getJRDD(lapply(rdd, function(x) { x }), "row") zrdd <- callJMethod(jrdd, "zip", jrdd) diff --git a/R/pkg/tests/fulltests/test_shuffle.R b/R/pkg/tests/fulltests/test_shuffle.R index 18320ea44b389..98300c67c415f 100644 --- a/R/pkg/tests/fulltests/test_shuffle.R +++ b/R/pkg/tests/fulltests/test_shuffle.R @@ -37,8 +37,6 @@ strList <- list("Dexter Morgan: Blood. Sometimes it sets my teeth on edge and ", strListRDD <- parallelize(sc, strList, 4) test_that("groupByKey for integers", { - skip_on_cran() - grouped <- groupByKey(intRdd, 2L) actual <- collectRDD(grouped) @@ -48,8 +46,6 @@ test_that("groupByKey for integers", { }) test_that("groupByKey for doubles", { - skip_on_cran() - grouped <- groupByKey(doubleRdd, 2L) actual <- collectRDD(grouped) @@ -59,8 +55,6 @@ test_that("groupByKey for doubles", { }) test_that("reduceByKey for ints", { - skip_on_cran() - reduced <- reduceByKey(intRdd, "+", 2L) actual <- collectRDD(reduced) @@ -70,8 +64,6 @@ test_that("reduceByKey for ints", { }) test_that("reduceByKey for doubles", { - skip_on_cran() - reduced <- reduceByKey(doubleRdd, "+", 2L) actual <- collectRDD(reduced) @@ -80,8 +72,6 @@ test_that("reduceByKey for doubles", { }) test_that("combineByKey for ints", { - skip_on_cran() - reduced <- combineByKey(intRdd, function(x) { x }, "+", "+", 2L) actual <- collectRDD(reduced) @@ -91,8 +81,6 @@ test_that("combineByKey for ints", { }) test_that("combineByKey for doubles", { - skip_on_cran() - reduced <- combineByKey(doubleRdd, function(x) { x }, "+", "+", 2L) actual <- collectRDD(reduced) @@ -101,8 +89,6 @@ test_that("combineByKey for doubles", { }) test_that("combineByKey for characters", { - skip_on_cran() - stringKeyRDD <- parallelize(sc, list(list("max", 1L), list("min", 2L), list("other", 3L), list("max", 4L)), 2L) @@ -115,8 +101,6 @@ test_that("combineByKey for characters", { }) test_that("aggregateByKey", { - skip_on_cran() - # test aggregateByKey for int keys rdd <- parallelize(sc, list(list(1, 1), list(1, 2), list(2, 3), list(2, 4))) @@ -145,8 +129,6 @@ test_that("aggregateByKey", { }) test_that("foldByKey", { - skip_on_cran() - # test foldByKey for int keys folded <- foldByKey(intRdd, 0, "+", 2L) @@ -190,8 +172,6 @@ test_that("foldByKey", { }) test_that("partitionBy() partitions data correctly", { - skip_on_cran() - # Partition by magnitude partitionByMagnitude <- function(key) { if (key >= 3) 1 else 0 } @@ -207,8 +187,6 @@ test_that("partitionBy() partitions data correctly", { }) test_that("partitionBy works with dependencies", { - skip_on_cran() - kOne <- 1 partitionByParity <- function(key) { if (key %% 2 == kOne) 7 else 4 } @@ -227,8 +205,6 @@ test_that("partitionBy works with dependencies", { }) test_that("test partitionBy with string keys", { - skip_on_cran() - words <- flatMap(strListRDD, function(line) { strsplit(line, " ")[[1]] }) wordCount <- lapply(words, function(word) { list(word, 1L) }) diff --git a/R/pkg/tests/fulltests/test_sparkR.R b/R/pkg/tests/fulltests/test_sparkR.R index a40981c188f7a..f73fc6baeccef 100644 --- a/R/pkg/tests/fulltests/test_sparkR.R +++ b/R/pkg/tests/fulltests/test_sparkR.R @@ -18,8 +18,6 @@ context("functions in sparkR.R") test_that("sparkCheckInstall", { - skip_on_cran() - # "local, yarn-client, mesos-client" mode, SPARK_HOME was set correctly, # and the SparkR job was submitted by "spark-submit" sparkHome <- paste0(tempdir(), "/", "sparkHome") diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R index c790d02b107be..af529067f43e0 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -61,7 +61,7 @@ unsetHiveContext <- function() { # Tests for SparkSQL functions in SparkR filesBefore <- list.files(path = sparkRDir, all.files = TRUE) -sparkSession <- if (not_cran_or_windows_with_hadoop()) { +sparkSession <- if (windows_with_hadoop()) { sparkR.session(master = sparkRTestMaster) } else { sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) @@ -100,26 +100,20 @@ mockLinesMapType <- c("{\"name\":\"Bob\",\"info\":{\"age\":16,\"height\":176.5}} mapTypeJsonPath <- tempfile(pattern = "sparkr-test", fileext = ".tmp") writeLines(mockLinesMapType, mapTypeJsonPath) -if (.Platform$OS.type == "windows") { +if (is_windows()) { Sys.setenv(TZ = "GMT") } test_that("calling sparkRSQL.init returns existing SQL context", { - skip_on_cran() - sqlContext <- suppressWarnings(sparkRSQL.init(sc)) expect_equal(suppressWarnings(sparkRSQL.init(sc)), sqlContext) }) test_that("calling sparkRSQL.init returns existing SparkSession", { - skip_on_cran() - expect_equal(suppressWarnings(sparkRSQL.init(sc)), sparkSession) }) test_that("calling sparkR.session returns existing SparkSession", { - skip_on_cran() - expect_equal(sparkR.session(), sparkSession) }) @@ -217,8 +211,6 @@ test_that("structField type strings", { }) test_that("create DataFrame from RDD", { - skip_on_cran() - rdd <- lapply(parallelize(sc, 1:10), function(x) { list(x, as.character(x)) }) df <- createDataFrame(rdd, list("a", "b")) dfAsDF <- as.DataFrame(rdd, list("a", "b")) @@ -316,8 +308,6 @@ test_that("create DataFrame from RDD", { }) test_that("createDataFrame uses files for large objects", { - skip_on_cran() - # To simulate a large file scenario, we set spark.r.maxAllocationLimit to a smaller value conf <- callJMethod(sparkSession, "conf") callJMethod(conf, "set", "spark.r.maxAllocationLimit", "100") @@ -330,7 +320,7 @@ test_that("createDataFrame uses files for large objects", { }) test_that("read/write csv as DataFrame", { - if (not_cran_or_windows_with_hadoop()) { + if (windows_with_hadoop()) { csvPath <- tempfile(pattern = "sparkr-test", fileext = ".csv") mockLinesCsv <- c("year,make,model,comment,blank", "\"2012\",\"Tesla\",\"S\",\"No comment\",", @@ -380,8 +370,6 @@ test_that("read/write csv as DataFrame", { }) test_that("Support other types for options", { - skip_on_cran() - csvPath <- tempfile(pattern = "sparkr-test", fileext = ".csv") mockLinesCsv <- c("year,make,model,comment,blank", "\"2012\",\"Tesla\",\"S\",\"No comment\",", @@ -436,8 +424,6 @@ test_that("convert NAs to null type in DataFrames", { }) test_that("toDF", { - skip_on_cran() - rdd <- lapply(parallelize(sc, 1:10), function(x) { list(x, as.character(x)) }) df <- toDF(rdd, list("a", "b")) expect_is(df, "SparkDataFrame") @@ -549,8 +535,6 @@ test_that("create DataFrame with complex types", { }) test_that("create DataFrame from a data.frame with complex types", { - skip_on_cran() - ldf <- data.frame(row.names = 1:2) ldf$a_list <- list(list(1, 2), list(3, 4)) ldf$an_envir <- c(as.environment(list(a = 1, b = 2)), as.environment(list(c = 3))) @@ -563,8 +547,6 @@ test_that("create DataFrame from a data.frame with complex types", { }) test_that("Collect DataFrame with complex types", { - skip_on_cran() - # ArrayType df <- read.json(complexTypeJsonPath) ldf <- collect(df) @@ -607,7 +589,7 @@ test_that("Collect DataFrame with complex types", { }) test_that("read/write json files", { - if (not_cran_or_windows_with_hadoop()) { + if (windows_with_hadoop()) { # Test read.df df <- read.df(jsonPath, "json") expect_is(df, "SparkDataFrame") @@ -654,8 +636,6 @@ test_that("read/write json files", { }) test_that("read/write json files - compression option", { - skip_on_cran() - df <- read.df(jsonPath, "json") jsonPath <- tempfile(pattern = "jsonPath", fileext = ".json") @@ -669,8 +649,6 @@ test_that("read/write json files - compression option", { }) test_that("jsonRDD() on a RDD with json string", { - skip_on_cran() - sqlContext <- suppressWarnings(sparkRSQL.init(sc)) rdd <- parallelize(sc, mockLines) expect_equal(countRDD(rdd), 3) @@ -730,8 +708,6 @@ test_that( }) test_that("test cache, uncache and clearCache", { - skip_on_cran() - df <- read.json(jsonPath) createOrReplaceTempView(df, "table1") cacheTable("table1") @@ -744,7 +720,7 @@ test_that("test cache, uncache and clearCache", { }) test_that("insertInto() on a registered table", { - if (not_cran_or_windows_with_hadoop()) { + if (windows_with_hadoop()) { df <- read.df(jsonPath, "json") write.df(df, parquetPath, "parquet", "overwrite") dfParquet <- read.df(parquetPath, "parquet") @@ -787,8 +763,6 @@ test_that("tableToDF() returns a new DataFrame", { }) test_that("toRDD() returns an RRDD", { - skip_on_cran() - df <- read.json(jsonPath) testRDD <- toRDD(df) expect_is(testRDD, "RDD") @@ -796,8 +770,6 @@ test_that("toRDD() returns an RRDD", { }) test_that("union on two RDDs created from DataFrames returns an RRDD", { - skip_on_cran() - df <- read.json(jsonPath) RDD1 <- toRDD(df) RDD2 <- toRDD(df) @@ -808,8 +780,6 @@ test_that("union on two RDDs created from DataFrames returns an RRDD", { }) test_that("union on mixed serialization types correctly returns a byte RRDD", { - skip_on_cran() - # Byte RDD nums <- 1:10 rdd <- parallelize(sc, nums, 2L) @@ -839,8 +809,6 @@ test_that("union on mixed serialization types correctly returns a byte RRDD", { }) test_that("objectFile() works with row serialization", { - skip_on_cran() - objectPath <- tempfile(pattern = "spark-test", fileext = ".tmp") df <- read.json(jsonPath) dfRDD <- toRDD(df) @@ -853,8 +821,6 @@ test_that("objectFile() works with row serialization", { }) test_that("lapply() on a DataFrame returns an RDD with the correct columns", { - skip_on_cran() - df <- read.json(jsonPath) testRDD <- lapply(df, function(row) { row$newCol <- row$age + 5 @@ -923,8 +889,6 @@ test_that("collect() support Unicode characters", { }) test_that("multiple pipeline transformations result in an RDD with the correct values", { - skip_on_cran() - df <- read.json(jsonPath) first <- lapply(df, function(row) { row$age <- row$age + 5 @@ -964,7 +928,7 @@ test_that("cache(), storageLevel(), persist(), and unpersist() on a DataFrame", }) test_that("setCheckpointDir(), checkpoint() on a DataFrame", { - if (not_cran_or_windows_with_hadoop()) { + if (windows_with_hadoop()) { checkpointDir <- file.path(tempdir(), "cproot") expect_true(length(list.files(path = checkpointDir, all.files = TRUE)) == 0) @@ -1341,7 +1305,7 @@ test_that("column calculation", { }) test_that("test HiveContext", { - if (not_cran_or_windows_with_hadoop()) { + if (windows_with_hadoop()) { setHiveContext(sc) schema <- structType(structField("name", "string"), structField("age", "integer"), @@ -1395,8 +1359,6 @@ test_that("column operators", { }) test_that("column functions", { - skip_on_cran() - c <- column("a") c1 <- abs(c) + acos(c) + approxCountDistinct(c) + ascii(c) + asin(c) + atan(c) c2 <- avg(c) + base64(c) + bin(c) + bitwiseNOT(c) + cbrt(c) + ceil(c) + cos(c) @@ -1782,8 +1744,6 @@ test_that("when(), otherwise() and ifelse() with column on a DataFrame", { }) test_that("group by, agg functions", { - skip_on_cran() - df <- read.json(jsonPath) df1 <- agg(df, name = "max", age = "sum") expect_equal(1, count(df1)) @@ -2125,8 +2085,6 @@ test_that("filter() on a DataFrame", { }) test_that("join(), crossJoin() and merge() on a DataFrame", { - skip_on_cran() - df <- read.json(jsonPath) mockLines2 <- c("{\"name\":\"Michael\", \"test\": \"yes\"}", @@ -2400,8 +2358,6 @@ test_that("mutate(), transform(), rename() and names()", { }) test_that("read/write ORC files", { - skip_on_cran() - setHiveContext(sc) df <- read.df(jsonPath, "json") @@ -2423,8 +2379,6 @@ test_that("read/write ORC files", { }) test_that("read/write ORC files - compression option", { - skip_on_cran() - setHiveContext(sc) df <- read.df(jsonPath, "json") @@ -2440,7 +2394,7 @@ test_that("read/write ORC files - compression option", { }) test_that("read/write Parquet files", { - if (not_cran_or_windows_with_hadoop()) { + if (windows_with_hadoop()) { df <- read.df(jsonPath, "json") # Test write.df and read.df write.df(df, parquetPath, "parquet", mode = "overwrite") @@ -2473,8 +2427,6 @@ test_that("read/write Parquet files", { }) test_that("read/write Parquet files - compression option/mode", { - skip_on_cran() - df <- read.df(jsonPath, "json") tempPath <- tempfile(pattern = "tempPath", fileext = ".parquet") @@ -2492,8 +2444,6 @@ test_that("read/write Parquet files - compression option/mode", { }) test_that("read/write text files", { - skip_on_cran() - # Test write.df and read.df df <- read.df(jsonPath, "text") expect_is(df, "SparkDataFrame") @@ -2515,8 +2465,6 @@ test_that("read/write text files", { }) test_that("read/write text files - compression option", { - skip_on_cran() - df <- read.df(jsonPath, "text") textPath <- tempfile(pattern = "textPath", fileext = ".txt") @@ -2750,8 +2698,6 @@ test_that("approxQuantile() on a DataFrame", { }) test_that("SQL error message is returned from JVM", { - skip_on_cran() - retError <- tryCatch(sql("select * from blah"), error = function(e) e) expect_equal(grepl("Table or view not found", retError), TRUE) expect_equal(grepl("blah", retError), TRUE) @@ -2760,8 +2706,6 @@ test_that("SQL error message is returned from JVM", { irisDF <- suppressWarnings(createDataFrame(iris)) test_that("Method as.data.frame as a synonym for collect()", { - skip_on_cran() - expect_equal(as.data.frame(irisDF), collect(irisDF)) irisDF2 <- irisDF[irisDF$Species == "setosa", ] expect_equal(as.data.frame(irisDF2), collect(irisDF2)) @@ -2984,8 +2928,6 @@ test_that("dapply() and dapplyCollect() on a DataFrame", { }) test_that("dapplyCollect() on DataFrame with a binary column", { - skip_on_cran() - df <- data.frame(key = 1:3) df$bytes <- lapply(df$key, serialize, connection = NULL) @@ -3006,8 +2948,6 @@ test_that("dapplyCollect() on DataFrame with a binary column", { }) test_that("repartition by columns on DataFrame", { - skip_on_cran() - df <- createDataFrame( list(list(1L, 1, "1", 0.1), list(1L, 2, "2", 0.2), list(3L, 3, "3", 0.3)), c("a", "b", "c", "d")) @@ -3046,8 +2986,6 @@ test_that("repartition by columns on DataFrame", { }) test_that("coalesce, repartition, numPartitions", { - skip_on_cran() - df <- as.DataFrame(cars, numPartitions = 5) expect_equal(getNumPartitions(df), 5) expect_equal(getNumPartitions(coalesce(df, 3)), 3) @@ -3067,8 +3005,6 @@ test_that("coalesce, repartition, numPartitions", { }) test_that("gapply() and gapplyCollect() on a DataFrame", { - skip_on_cran() - df <- createDataFrame ( list(list(1L, 1, "1", 0.1), list(1L, 2, "1", 0.2), list(3L, 3, "3", 0.3)), c("a", "b", "c", "d")) @@ -3186,8 +3122,6 @@ test_that("Window functions on a DataFrame", { }) test_that("createDataFrame sqlContext parameter backward compatibility", { - skip_on_cran() - sqlContext <- suppressWarnings(sparkRSQL.init(sc)) a <- 1:3 b <- c("a", "b", "c") @@ -3221,8 +3155,6 @@ test_that("createDataFrame sqlContext parameter backward compatibility", { }) test_that("randomSplit", { - skip_on_cran() - num <- 4000 df <- createDataFrame(data.frame(id = 1:num)) weights <- c(2, 3, 5) @@ -3269,8 +3201,6 @@ test_that("Setting and getting config on SparkSession, sparkR.conf(), sparkR.uiW }) test_that("enableHiveSupport on SparkSession", { - skip_on_cran() - setHiveContext(sc) unsetHiveContext() # if we are still here, it must be built with hive @@ -3286,8 +3216,6 @@ test_that("Spark version from SparkSession", { }) test_that("Call DataFrameWriter.save() API in Java without path and check argument types", { - skip_on_cran() - df <- read.df(jsonPath, "json") # This tests if the exception is thrown from JVM not from SparkR side. # It makes sure that we can omit path argument in write.df API and then it calls @@ -3314,8 +3242,6 @@ test_that("Call DataFrameWriter.save() API in Java without path and check argume }) test_that("Call DataFrameWriter.load() API in Java without path and check argument types", { - skip_on_cran() - # This tests if the exception is thrown from JVM not from SparkR side. # It makes sure that we can omit path argument in read.df API and then it calls # DataFrameWriter.load() without path. @@ -3440,8 +3366,6 @@ compare_list <- function(list1, list2) { # This should always be the **very last test** in this test file. test_that("No extra files are created in SPARK_HOME by starting session and making calls", { - skip_on_cran() # skip because when run from R CMD check SPARK_HOME is not the current directory - # Check that it is not creating any extra file. # Does not check the tempdir which would be cleaned up after. filesAfter <- list.files(path = sparkRDir, all.files = TRUE) diff --git a/R/pkg/tests/fulltests/test_streaming.R b/R/pkg/tests/fulltests/test_streaming.R index b20b4312fbaae..d691de7cd725d 100644 --- a/R/pkg/tests/fulltests/test_streaming.R +++ b/R/pkg/tests/fulltests/test_streaming.R @@ -24,7 +24,7 @@ context("Structured Streaming") sparkSession <- sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) jsonSubDir <- file.path("sparkr-test", "json", "") -if (.Platform$OS.type == "windows") { +if (is_windows()) { # file.path removes the empty separator on Windows, adds it back jsonSubDir <- paste0(jsonSubDir, .Platform$file.sep) } @@ -47,8 +47,6 @@ schema <- structType(structField("name", "string"), structField("count", "double")) test_that("read.stream, write.stream, awaitTermination, stopQuery", { - skip_on_cran() - df <- read.stream("json", path = jsonDir, schema = schema, maxFilesPerTrigger = 1) expect_true(isStreaming(df)) counts <- count(group_by(df, "name")) @@ -69,8 +67,6 @@ test_that("read.stream, write.stream, awaitTermination, stopQuery", { }) test_that("print from explain, lastProgress, status, isActive", { - skip_on_cran() - df <- read.stream("json", path = jsonDir, schema = schema) expect_true(isStreaming(df)) counts <- count(group_by(df, "name")) @@ -90,8 +86,6 @@ test_that("print from explain, lastProgress, status, isActive", { }) test_that("Stream other format", { - skip_on_cran() - parquetPath <- tempfile(pattern = "sparkr-test", fileext = ".parquet") df <- read.df(jsonPath, "json", schema) write.df(df, parquetPath, "parquet", "overwrite") @@ -118,8 +112,6 @@ test_that("Stream other format", { }) test_that("Non-streaming DataFrame", { - skip_on_cran() - c <- as.DataFrame(cars) expect_false(isStreaming(c)) @@ -129,8 +121,6 @@ test_that("Non-streaming DataFrame", { }) test_that("Unsupported operation", { - skip_on_cran() - # memory sink without aggregation df <- read.stream("json", path = jsonDir, schema = schema, maxFilesPerTrigger = 1) expect_error(write.stream(df, "memory", queryName = "people", outputMode = "complete"), @@ -139,8 +129,6 @@ test_that("Unsupported operation", { }) test_that("Terminated by error", { - skip_on_cran() - df <- read.stream("json", path = jsonDir, schema = schema, maxFilesPerTrigger = -1) counts <- count(group_by(df, "name")) # This would not fail before returning with a StreamingQuery, diff --git a/R/pkg/tests/fulltests/test_take.R b/R/pkg/tests/fulltests/test_take.R index c00723ba31f4c..8936cc57da227 100644 --- a/R/pkg/tests/fulltests/test_take.R +++ b/R/pkg/tests/fulltests/test_take.R @@ -34,8 +34,6 @@ sparkSession <- sparkR.session(master = sparkRTestMaster, enableHiveSupport = FA sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession) test_that("take() gives back the original elements in correct count and order", { - skip_on_cran() - numVectorRDD <- parallelize(sc, numVector, 10) # case: number of elements to take is less than the size of the first partition expect_equal(takeRDD(numVectorRDD, 1), as.list(head(numVector, n = 1))) diff --git a/R/pkg/tests/fulltests/test_textFile.R b/R/pkg/tests/fulltests/test_textFile.R index e8a961cb3e870..be2d2711ff88e 100644 --- a/R/pkg/tests/fulltests/test_textFile.R +++ b/R/pkg/tests/fulltests/test_textFile.R @@ -24,8 +24,6 @@ sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", mockFile <- c("Spark is pretty.", "Spark is awesome.") test_that("textFile() on a local file returns an RDD", { - skip_on_cran() - fileName <- tempfile(pattern = "spark-test", fileext = ".tmp") writeLines(mockFile, fileName) @@ -38,8 +36,6 @@ test_that("textFile() on a local file returns an RDD", { }) test_that("textFile() followed by a collect() returns the same content", { - skip_on_cran() - fileName <- tempfile(pattern = "spark-test", fileext = ".tmp") writeLines(mockFile, fileName) @@ -50,8 +46,6 @@ test_that("textFile() followed by a collect() returns the same content", { }) test_that("textFile() word count works as expected", { - skip_on_cran() - fileName <- tempfile(pattern = "spark-test", fileext = ".tmp") writeLines(mockFile, fileName) @@ -70,8 +64,6 @@ test_that("textFile() word count works as expected", { }) test_that("several transformations on RDD created by textFile()", { - skip_on_cran() - fileName <- tempfile(pattern = "spark-test", fileext = ".tmp") writeLines(mockFile, fileName) @@ -86,8 +78,6 @@ test_that("several transformations on RDD created by textFile()", { }) test_that("textFile() followed by a saveAsTextFile() returns the same content", { - skip_on_cran() - fileName1 <- tempfile(pattern = "spark-test", fileext = ".tmp") fileName2 <- tempfile(pattern = "spark-test", fileext = ".tmp") writeLines(mockFile, fileName1) @@ -102,8 +92,6 @@ test_that("textFile() followed by a saveAsTextFile() returns the same content", }) test_that("saveAsTextFile() on a parallelized list works as expected", { - skip_on_cran() - fileName <- tempfile(pattern = "spark-test", fileext = ".tmp") l <- list(1, 2, 3) rdd <- parallelize(sc, l, 1L) @@ -115,8 +103,6 @@ test_that("saveAsTextFile() on a parallelized list works as expected", { }) test_that("textFile() and saveAsTextFile() word count works as expected", { - skip_on_cran() - fileName1 <- tempfile(pattern = "spark-test", fileext = ".tmp") fileName2 <- tempfile(pattern = "spark-test", fileext = ".tmp") writeLines(mockFile, fileName1) @@ -142,8 +128,6 @@ test_that("textFile() and saveAsTextFile() word count works as expected", { }) test_that("textFile() on multiple paths", { - skip_on_cran() - fileName1 <- tempfile(pattern = "spark-test", fileext = ".tmp") fileName2 <- tempfile(pattern = "spark-test", fileext = ".tmp") writeLines("Spark is pretty.", fileName1) @@ -157,8 +141,6 @@ test_that("textFile() on multiple paths", { }) test_that("Pipelined operations on RDDs created using textFile", { - skip_on_cran() - fileName <- tempfile(pattern = "spark-test", fileext = ".tmp") writeLines(mockFile, fileName) diff --git a/R/pkg/tests/fulltests/test_utils.R b/R/pkg/tests/fulltests/test_utils.R index 6197ae7569879..af81423aa8dd0 100644 --- a/R/pkg/tests/fulltests/test_utils.R +++ b/R/pkg/tests/fulltests/test_utils.R @@ -23,7 +23,6 @@ sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", test_that("convertJListToRList() gives back (deserializes) the original JLists of strings and integers", { - skip_on_cran() # It's hard to manually create a Java List using rJava, since it does not # support generics well. Instead, we rely on collectRDD() returning a # JList. @@ -41,7 +40,6 @@ test_that("convertJListToRList() gives back (deserializes) the original JLists }) test_that("serializeToBytes on RDD", { - skip_on_cran() # File content mockFile <- c("Spark is pretty.", "Spark is awesome.") fileName <- tempfile(pattern = "spark-test", fileext = ".tmp") @@ -169,8 +167,6 @@ test_that("convertToJSaveMode", { }) test_that("captureJVMException", { - skip_on_cran() - method <- "createStructField" expect_error(tryCatch(callJStatic("org.apache.spark.sql.api.r.SQLUtils", method, "col", "unknown", TRUE), @@ -181,8 +177,6 @@ test_that("captureJVMException", { }) test_that("hashCode", { - skip_on_cran() - expect_error(hashCode("bc53d3605e8a5b7de1e8e271c2317645"), NA) }) @@ -243,6 +237,3 @@ test_that("basenameSansExtFromUrl", { }) sparkR.session.stop() - -message("--- End test (utils) ", as.POSIXct(Sys.time(), tz = "GMT")) -message("elapsed ", (proc.time() - timer_ptm)[3]) diff --git a/R/pkg/tests/run-all.R b/R/pkg/tests/run-all.R index d48e36c880c13..f00a610679752 100644 --- a/R/pkg/tests/run-all.R +++ b/R/pkg/tests/run-all.R @@ -24,8 +24,6 @@ options("warn" = 2) if (.Platform$OS.type == "windows") { Sys.setenv(TZ = "GMT") } -message("--- Start test ", as.POSIXct(Sys.time(), tz = "GMT")) -timer_ptm <- proc.time() # Setup global test environment # Install Spark first to set SPARK_HOME From 3a840048ed3501e06260b7c5df18cc0bbdb1505c Mon Sep 17 00:00:00 2001 From: sujithjay Date: Sun, 11 Jun 2017 18:23:57 +0100 Subject: [PATCH 112/133] Fixed typo in sql.functions ## What changes were proposed in this pull request? I fixed a typo in the Scaladoc for the method `def struct(cols: Column*): Column`. 'retained' was misspelt as 'remained'. ## How was this patch tested? Before: Creates a new struct column. If the input column is a column in a `DataFrame`, or a derived column expression that is named (i.e. aliased), its name would be **remained** as the StructField's name, otherwise, the newly generated StructField's name would be auto generated as `col` with a suffix `index + 1`, i.e. col1, col2, col3, ... After: Creates a new struct column. If the input column is a column in a `DataFrame`, or a derived column expression that is named (i.e. aliased), its name would be **retained** as the StructField's name, otherwise, the newly generated StructField's name would be auto generated as `col` with a suffix `index + 1`, i.e. col1, col2, col3, ... Author: sujithjay Closes #18254 from sujithjay/fix-typo. --- sql/core/src/main/scala/org/apache/spark/sql/functions.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 8d0a8c2178803..8d2e1f32da059 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -1210,7 +1210,7 @@ object functions { /** * Creates a new struct column. * If the input column is a column in a `DataFrame`, or a derived column expression - * that is named (i.e. aliased), its name would be remained as the StructField's name, + * that is named (i.e. aliased), its name would be retained as the StructField's name, * otherwise, the newly generated StructField's name would be auto generated as * `col` with a suffix `index + 1`, i.e. col1, col2, col3, ... * From a7c61c100b6e4380e8d0e588969dd7f2fd58d40c Mon Sep 17 00:00:00 2001 From: Zhenhua Wang Date: Mon, 12 Jun 2017 08:23:04 +0800 Subject: [PATCH 113/133] [SPARK-21031][SQL] Add `alterTableStats` to store spark's stats and let `alterTable` keep existing stats ## What changes were proposed in this pull request? Currently, hive's stats are read into `CatalogStatistics`, while spark's stats are also persisted through `CatalogStatistics`. As a result, hive's stats can be unexpectedly propagated into spark' stats. For example, for a catalog table, we read stats from hive, e.g. "totalSize" and put it into `CatalogStatistics`. Then, by using "ALTER TABLE" command, we will store the stats in `CatalogStatistics` into metastore as spark's stats (because we don't know whether it's from spark or not). But spark's stats should be only generated by "ANALYZE" command. This is unexpected from this command. Secondly, now that we have spark's stats in metastore, after inserting new data, although hive updated "totalSize" in metastore, we still cannot get the right `sizeInBytes` in `CatalogStatistics`, because we respect spark's stats (should not exist) over hive's stats. A running example is shown in [JIRA](https://issues.apache.org/jira/browse/SPARK-21031). To fix this, we add a new method `alterTableStats` to store spark's stats, and let `alterTable` keep existing stats. ## How was this patch tested? Added new tests. Author: Zhenhua Wang Closes #18248 from wzhfy/separateHiveStats. --- .../catalyst/catalog/ExternalCatalog.scala | 2 + .../catalyst/catalog/InMemoryCatalog.scala | 9 +++ .../sql/catalyst/catalog/SessionCatalog.scala | 13 +++ .../catalog/ExternalCatalogSuite.scala | 11 ++- .../catalog/SessionCatalogSuite.scala | 12 +++ .../command/AnalyzeColumnCommand.scala | 2 +- .../command/AnalyzeTableCommand.scala | 2 +- .../spark/sql/hive/HiveExternalCatalog.scala | 68 +++++++++------- .../spark/sql/hive/StatisticsSuite.scala | 80 +++++++++++-------- 9 files changed, 132 insertions(+), 67 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala index 974ef900e2eed..12ba5aedde026 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala @@ -160,6 +160,8 @@ abstract class ExternalCatalog */ def alterTableSchema(db: String, table: String, schema: StructType): Unit + def alterTableStats(db: String, table: String, stats: CatalogStatistics): Unit + def getTable(db: String, table: String): CatalogTable def getTableOption(db: String, table: String): Option[CatalogTable] diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala index 8a5319bebe54e..9820522a230e3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala @@ -312,6 +312,15 @@ class InMemoryCatalog( catalog(db).tables(table).table = origTable.copy(schema = schema) } + override def alterTableStats( + db: String, + table: String, + stats: CatalogStatistics): Unit = synchronized { + requireTableExists(db, table) + val origTable = catalog(db).tables(table).table + catalog(db).tables(table).table = origTable.copy(stats = Some(stats)) + } + override def getTable(db: String, table: String): CatalogTable = synchronized { requireTableExists(db, table) catalog(db).tables(table).table diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index b6744a7f53a54..cf02da8993658 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -376,6 +376,19 @@ class SessionCatalog( schema.fields.map(_.name).exists(conf.resolver(_, colName)) } + /** + * Alter Spark's statistics of an existing metastore table identified by the provided table + * identifier. + */ + def alterTableStats(identifier: TableIdentifier, newStats: CatalogStatistics): Unit = { + val db = formatDatabaseName(identifier.database.getOrElse(getCurrentDatabase)) + val table = formatTableName(identifier.table) + val tableIdentifier = TableIdentifier(table, Some(db)) + requireDbExists(db) + requireTableExists(tableIdentifier) + externalCatalog.alterTableStats(db, table, newStats) + } + /** * Return whether a table/view with the specified name exists. If no database is specified, check * with current database. diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala index 1759ac04c0033..557b0970b54e5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala @@ -245,7 +245,6 @@ abstract class ExternalCatalogSuite extends SparkFunSuite with BeforeAndAfterEac test("alter table schema") { val catalog = newBasicCatalog() - val tbl1 = catalog.getTable("db2", "tbl1") val newSchema = StructType(Seq( StructField("col1", IntegerType), StructField("new_field_2", StringType), @@ -256,6 +255,16 @@ abstract class ExternalCatalogSuite extends SparkFunSuite with BeforeAndAfterEac assert(newTbl1.schema == newSchema) } + test("alter table stats") { + val catalog = newBasicCatalog() + val oldTableStats = catalog.getTable("db2", "tbl1").stats + assert(oldTableStats.isEmpty) + val newStats = CatalogStatistics(sizeInBytes = 1) + catalog.alterTableStats("db2", "tbl1", newStats) + val newTableStats = catalog.getTable("db2", "tbl1").stats + assert(newTableStats.get == newStats) + } + test("get table") { assert(newBasicCatalog().getTable("db2", "tbl1").identifier.table == "tbl1") } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala index 5afeb0e8ca032..dce73b3635e72 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala @@ -448,6 +448,18 @@ abstract class SessionCatalogSuite extends PlanTest { } } + test("alter table stats") { + withBasicCatalog { catalog => + val tableId = TableIdentifier("tbl1", Some("db2")) + val oldTableStats = catalog.getTableMetadata(tableId).stats + assert(oldTableStats.isEmpty) + val newStats = CatalogStatistics(sizeInBytes = 1) + catalog.alterTableStats(tableId, newStats) + val newTableStats = catalog.getTableMetadata(tableId).stats + assert(newTableStats.get == newStats) + } + } + test("alter table add columns") { withBasicCatalog { sessionCatalog => sessionCatalog.createTable(newTable("t1", "default"), ignoreIfExists = false) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala index 2de14c90ec757..2f273b63e8348 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala @@ -54,7 +54,7 @@ case class AnalyzeColumnCommand( // Newly computed column stats should override the existing ones. colStats = tableMeta.stats.map(_.colStats).getOrElse(Map.empty) ++ newColStats) - sessionState.catalog.alterTable(tableMeta.copy(stats = Some(statistics))) + sessionState.catalog.alterTableStats(tableIdentWithDB, statistics) // Refresh the cached data source table in the catalog. sessionState.catalog.refreshTable(tableIdentWithDB) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala index 3183c7911b1fb..3c59b982c2dca 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala @@ -69,7 +69,7 @@ case class AnalyzeTableCommand( // Update the metastore if the above statistics of the table are different from those // recorded in the metastore. if (newStats.isDefined) { - sessionState.catalog.alterTable(tableMeta.copy(stats = newStats)) + sessionState.catalog.alterTableStats(tableIdentWithDB, newStats.get) // Refresh the cached data source table in the catalog. sessionState.catalog.refreshTable(tableIdentWithDB) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala index 918459fe7c246..7fcf06d66b5ea 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala @@ -527,7 +527,8 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat /** * Alter a table whose name that matches the one specified in `tableDefinition`, - * assuming the table exists. + * assuming the table exists. This method does not change the properties for data source and + * statistics. * * Note: As of now, this doesn't support altering table schema, partition column names and bucket * specification. We will ignore them even if users do specify different values for these fields. @@ -538,30 +539,10 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat requireTableExists(db, tableDefinition.identifier.table) verifyTableProperties(tableDefinition) - // convert table statistics to properties so that we can persist them through hive api - val withStatsProps = if (tableDefinition.stats.isDefined) { - val stats = tableDefinition.stats.get - var statsProperties: Map[String, String] = - Map(STATISTICS_TOTAL_SIZE -> stats.sizeInBytes.toString()) - if (stats.rowCount.isDefined) { - statsProperties += STATISTICS_NUM_ROWS -> stats.rowCount.get.toString() - } - val colNameTypeMap: Map[String, DataType] = - tableDefinition.schema.fields.map(f => (f.name, f.dataType)).toMap - stats.colStats.foreach { case (colName, colStat) => - colStat.toMap(colName, colNameTypeMap(colName)).foreach { case (k, v) => - statsProperties += (columnStatKeyPropName(colName, k) -> v) - } - } - tableDefinition.copy(properties = tableDefinition.properties ++ statsProperties) - } else { - tableDefinition - } - if (tableDefinition.tableType == VIEW) { - client.alterTable(withStatsProps) + client.alterTable(tableDefinition) } else { - val oldTableDef = getRawTable(db, withStatsProps.identifier.table) + val oldTableDef = getRawTable(db, tableDefinition.identifier.table) val newStorage = if (DDLUtils.isHiveTable(tableDefinition)) { tableDefinition.storage @@ -611,12 +592,15 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat TABLE_PARTITION_PROVIDER -> TABLE_PARTITION_PROVIDER_FILESYSTEM } - // Sets the `schema`, `partitionColumnNames` and `bucketSpec` from the old table definition, - // to retain the spark specific format if it is. Also add old data source properties to table - // properties, to retain the data source table format. - val oldDataSourceProps = oldTableDef.properties.filter(_._1.startsWith(DATASOURCE_PREFIX)) - val newTableProps = oldDataSourceProps ++ withStatsProps.properties + partitionProviderProp - val newDef = withStatsProps.copy( + // Add old data source properties to table properties, to retain the data source table format. + // Add old stats properties to table properties, to retain spark's stats. + // Set the `schema`, `partitionColumnNames` and `bucketSpec` from the old table definition, + // to retain the spark specific format if it is. + val propsFromOldTable = oldTableDef.properties.filter { case (k, v) => + k.startsWith(DATASOURCE_PREFIX) || k.startsWith(STATISTICS_PREFIX) + } + val newTableProps = propsFromOldTable ++ tableDefinition.properties + partitionProviderProp + val newDef = tableDefinition.copy( storage = newStorage, schema = oldTableDef.schema, partitionColumnNames = oldTableDef.partitionColumnNames, @@ -647,6 +631,32 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat } } + override def alterTableStats( + db: String, + table: String, + stats: CatalogStatistics): Unit = withClient { + requireTableExists(db, table) + val rawTable = getRawTable(db, table) + + // convert table statistics to properties so that we can persist them through hive client + var statsProperties: Map[String, String] = + Map(STATISTICS_TOTAL_SIZE -> stats.sizeInBytes.toString()) + if (stats.rowCount.isDefined) { + statsProperties += STATISTICS_NUM_ROWS -> stats.rowCount.get.toString() + } + val colNameTypeMap: Map[String, DataType] = + rawTable.schema.fields.map(f => (f.name, f.dataType)).toMap + stats.colStats.foreach { case (colName, colStat) => + colStat.toMap(colName, colNameTypeMap(colName)).foreach { case (k, v) => + statsProperties += (columnStatKeyPropName(colName, k) -> v) + } + } + + val oldTableNonStatsProps = rawTable.properties.filterNot(_._1.startsWith(STATISTICS_PREFIX)) + val updatedTable = rawTable.copy(properties = oldTableNonStatsProps ++ statsProperties) + client.alterTable(updatedTable) + } + override def getTable(db: String, table: String): CatalogTable = withClient { restoreTableMetadata(getRawTable(db, table)) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala index 5d52f8baa3b94..001bbc230ff18 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -25,7 +25,7 @@ import scala.util.matching.Regex import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.catalog.{CatalogRelation, CatalogStatistics} +import org.apache.spark.sql.catalyst.catalog.{CatalogRelation, CatalogStatistics, CatalogTable} import org.apache.spark.sql.catalyst.util.StringUtils import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.execution.datasources.LogicalRelation @@ -267,7 +267,7 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto } } - test("get statistics when not analyzed in both Hive and Spark") { + test("get statistics when not analyzed in Hive or Spark") { val tabName = "tab1" withTable(tabName) { createNonPartitionedTable(tabName, analyzedByHive = false, analyzedBySpark = false) @@ -313,60 +313,70 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto } } - test("alter table SET TBLPROPERTIES after analyze table") { - Seq(true, false).foreach { analyzedBySpark => - val tabName = "tab1" - withTable(tabName) { - createNonPartitionedTable(tabName, analyzedByHive = true, analyzedBySpark = analyzedBySpark) - val fetchedStats1 = checkTableStats( - tabName, hasSizeInBytes = true, expectedRowCounts = Some(500)) - sql(s"ALTER TABLE $tabName SET TBLPROPERTIES ('foo' = 'a')") - val fetchedStats2 = checkTableStats( - tabName, hasSizeInBytes = true, expectedRowCounts = Some(500)) - assert(fetchedStats1 == fetchedStats2) - - val describeResult = hiveClient.runSqlHive(s"DESCRIBE FORMATTED $tabName") - - val totalSize = extractStatsPropValues(describeResult, "totalSize") - assert(totalSize.isDefined && totalSize.get > 0, "totalSize is lost") + test("alter table should not have the side effect to store statistics in Spark side") { + def getCatalogTable(tableName: String): CatalogTable = { + spark.sessionState.catalog.getTableMetadata(TableIdentifier(tableName)) + } - // ALTER TABLE SET TBLPROPERTIES invalidates some Hive specific statistics - // This is triggered by the Hive alterTable API - val numRows = extractStatsPropValues(describeResult, "numRows") - assert(numRows.isDefined && numRows.get == -1, "numRows is lost") - val rawDataSize = extractStatsPropValues(describeResult, "rawDataSize") - assert(rawDataSize.isDefined && rawDataSize.get == -1, "rawDataSize is lost") - } + val table = "alter_table_side_effect" + withTable(table) { + sql(s"CREATE TABLE $table (i string, j string)") + sql(s"INSERT INTO TABLE $table SELECT 'a', 'b'") + val catalogTable1 = getCatalogTable(table) + val hiveSize1 = BigInt(catalogTable1.ignoredProperties(StatsSetupConst.TOTAL_SIZE)) + + sql(s"ALTER TABLE $table SET TBLPROPERTIES ('prop1' = 'a')") + + sql(s"INSERT INTO TABLE $table SELECT 'c', 'd'") + val catalogTable2 = getCatalogTable(table) + val hiveSize2 = BigInt(catalogTable2.ignoredProperties(StatsSetupConst.TOTAL_SIZE)) + // After insertion, Hive's stats should be changed. + assert(hiveSize2 > hiveSize1) + // We haven't generate stats in Spark, so we should still use Hive's stats here. + assert(catalogTable2.stats.get.sizeInBytes == hiveSize2) } } - test("alter table UNSET TBLPROPERTIES after analyze table") { + private def testAlterTableProperties(tabName: String, alterTablePropCmd: String): Unit = { Seq(true, false).foreach { analyzedBySpark => - val tabName = "tab1" withTable(tabName) { createNonPartitionedTable(tabName, analyzedByHive = true, analyzedBySpark = analyzedBySpark) - val fetchedStats1 = checkTableStats( - tabName, hasSizeInBytes = true, expectedRowCounts = Some(500)) - sql(s"ALTER TABLE $tabName UNSET TBLPROPERTIES ('prop1')") - val fetchedStats2 = checkTableStats( - tabName, hasSizeInBytes = true, expectedRowCounts = Some(500)) - assert(fetchedStats1 == fetchedStats2) + checkTableStats(tabName, hasSizeInBytes = true, expectedRowCounts = Some(500)) + + // Run ALTER TABLE command + sql(alterTablePropCmd) val describeResult = hiveClient.runSqlHive(s"DESCRIBE FORMATTED $tabName") val totalSize = extractStatsPropValues(describeResult, "totalSize") assert(totalSize.isDefined && totalSize.get > 0, "totalSize is lost") - // ALTER TABLE UNSET TBLPROPERTIES invalidates some Hive specific statistics - // This is triggered by the Hive alterTable API + // ALTER TABLE SET/UNSET TBLPROPERTIES invalidates some Hive specific statistics, but not + // Spark specific statistics. This is triggered by the Hive alterTable API. val numRows = extractStatsPropValues(describeResult, "numRows") assert(numRows.isDefined && numRows.get == -1, "numRows is lost") val rawDataSize = extractStatsPropValues(describeResult, "rawDataSize") assert(rawDataSize.isDefined && rawDataSize.get == -1, "rawDataSize is lost") + + if (analyzedBySpark) { + checkTableStats(tabName, hasSizeInBytes = true, expectedRowCounts = Some(500)) + } else { + checkTableStats(tabName, hasSizeInBytes = true, expectedRowCounts = None) + } } } } + test("alter table SET TBLPROPERTIES after analyze table") { + testAlterTableProperties("set_prop_table", + "ALTER TABLE set_prop_table SET TBLPROPERTIES ('foo' = 'a')") + } + + test("alter table UNSET TBLPROPERTIES after analyze table") { + testAlterTableProperties("unset_prop_table", + "ALTER TABLE unset_prop_table UNSET TBLPROPERTIES ('prop1')") + } + test("add/drop partitions - managed table") { val catalog = spark.sessionState.catalog val managedTable = "partitionedTable" From 0538f3b0ae4b80750ab81b210ad6fe77178337bf Mon Sep 17 00:00:00 2001 From: Michal Senkyr Date: Mon, 12 Jun 2017 08:47:01 +0800 Subject: [PATCH 114/133] [SPARK-18891][SQL] Support for Scala Map collection types MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? Add support for arbitrary Scala `Map` types in deserialization as well as a generic implicit encoder. Used the builder approach as in #16541 to construct any provided `Map` type upon deserialization. Please note that this PR also adds (ignored) tests for issue [SPARK-19104 CompileException with Map and Case Class in Spark 2.1.0](https://issues.apache.org/jira/browse/SPARK-19104) but doesn't solve it. Added support for Java Maps in codegen code (encoders will be added in a different PR) with the following default implementations for interfaces/abstract classes: * `java.util.Map`, `java.util.AbstractMap` => `java.util.HashMap` * `java.util.SortedMap`, `java.util.NavigableMap` => `java.util.TreeMap` * `java.util.concurrent.ConcurrentMap` => `java.util.concurrent.ConcurrentHashMap` * `java.util.concurrent.ConcurrentNavigableMap` => `java.util.concurrent.ConcurrentSkipListMap` Resulting codegen for `Seq(Map(1 -> 2)).toDS().map(identity).queryExecution.debug.codegen`: ``` /* 001 */ public Object generate(Object[] references) { /* 002 */ return new GeneratedIterator(references); /* 003 */ } /* 004 */ /* 005 */ final class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator { /* 006 */ private Object[] references; /* 007 */ private scala.collection.Iterator[] inputs; /* 008 */ private scala.collection.Iterator inputadapter_input; /* 009 */ private boolean CollectObjectsToMap_loopIsNull1; /* 010 */ private int CollectObjectsToMap_loopValue0; /* 011 */ private boolean CollectObjectsToMap_loopIsNull3; /* 012 */ private int CollectObjectsToMap_loopValue2; /* 013 */ private UnsafeRow deserializetoobject_result; /* 014 */ private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder deserializetoobject_holder; /* 015 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter deserializetoobject_rowWriter; /* 016 */ private scala.collection.immutable.Map mapelements_argValue; /* 017 */ private UnsafeRow mapelements_result; /* 018 */ private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder mapelements_holder; /* 019 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter mapelements_rowWriter; /* 020 */ private UnsafeRow serializefromobject_result; /* 021 */ private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder serializefromobject_holder; /* 022 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter serializefromobject_rowWriter; /* 023 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeArrayWriter serializefromobject_arrayWriter; /* 024 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeArrayWriter serializefromobject_arrayWriter1; /* 025 */ /* 026 */ public GeneratedIterator(Object[] references) { /* 027 */ this.references = references; /* 028 */ } /* 029 */ /* 030 */ public void init(int index, scala.collection.Iterator[] inputs) { /* 031 */ partitionIndex = index; /* 032 */ this.inputs = inputs; /* 033 */ wholestagecodegen_init_0(); /* 034 */ wholestagecodegen_init_1(); /* 035 */ /* 036 */ } /* 037 */ /* 038 */ private void wholestagecodegen_init_0() { /* 039 */ inputadapter_input = inputs[0]; /* 040 */ /* 041 */ deserializetoobject_result = new UnsafeRow(1); /* 042 */ this.deserializetoobject_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(deserializetoobject_result, 32); /* 043 */ this.deserializetoobject_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(deserializetoobject_holder, 1); /* 044 */ /* 045 */ mapelements_result = new UnsafeRow(1); /* 046 */ this.mapelements_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(mapelements_result, 32); /* 047 */ this.mapelements_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(mapelements_holder, 1); /* 048 */ serializefromobject_result = new UnsafeRow(1); /* 049 */ this.serializefromobject_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(serializefromobject_result, 32); /* 050 */ this.serializefromobject_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(serializefromobject_holder, 1); /* 051 */ this.serializefromobject_arrayWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeArrayWriter(); /* 052 */ /* 053 */ } /* 054 */ /* 055 */ private void wholestagecodegen_init_1() { /* 056 */ this.serializefromobject_arrayWriter1 = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeArrayWriter(); /* 057 */ /* 058 */ } /* 059 */ /* 060 */ protected void processNext() throws java.io.IOException { /* 061 */ while (inputadapter_input.hasNext() && !stopEarly()) { /* 062 */ InternalRow inputadapter_row = (InternalRow) inputadapter_input.next(); /* 063 */ boolean inputadapter_isNull = inputadapter_row.isNullAt(0); /* 064 */ MapData inputadapter_value = inputadapter_isNull ? null : (inputadapter_row.getMap(0)); /* 065 */ /* 066 */ boolean deserializetoobject_isNull1 = true; /* 067 */ ArrayData deserializetoobject_value1 = null; /* 068 */ if (!inputadapter_isNull) { /* 069 */ deserializetoobject_isNull1 = false; /* 070 */ if (!deserializetoobject_isNull1) { /* 071 */ Object deserializetoobject_funcResult = null; /* 072 */ deserializetoobject_funcResult = inputadapter_value.keyArray(); /* 073 */ if (deserializetoobject_funcResult == null) { /* 074 */ deserializetoobject_isNull1 = true; /* 075 */ } else { /* 076 */ deserializetoobject_value1 = (ArrayData) deserializetoobject_funcResult; /* 077 */ } /* 078 */ /* 079 */ } /* 080 */ deserializetoobject_isNull1 = deserializetoobject_value1 == null; /* 081 */ } /* 082 */ /* 083 */ boolean deserializetoobject_isNull3 = true; /* 084 */ ArrayData deserializetoobject_value3 = null; /* 085 */ if (!inputadapter_isNull) { /* 086 */ deserializetoobject_isNull3 = false; /* 087 */ if (!deserializetoobject_isNull3) { /* 088 */ Object deserializetoobject_funcResult1 = null; /* 089 */ deserializetoobject_funcResult1 = inputadapter_value.valueArray(); /* 090 */ if (deserializetoobject_funcResult1 == null) { /* 091 */ deserializetoobject_isNull3 = true; /* 092 */ } else { /* 093 */ deserializetoobject_value3 = (ArrayData) deserializetoobject_funcResult1; /* 094 */ } /* 095 */ /* 096 */ } /* 097 */ deserializetoobject_isNull3 = deserializetoobject_value3 == null; /* 098 */ } /* 099 */ scala.collection.immutable.Map deserializetoobject_value = null; /* 100 */ /* 101 */ if ((deserializetoobject_isNull1 && !deserializetoobject_isNull3) || /* 102 */ (!deserializetoobject_isNull1 && deserializetoobject_isNull3)) { /* 103 */ throw new RuntimeException("Invalid state: Inconsistent nullability of key-value"); /* 104 */ } /* 105 */ /* 106 */ if (!deserializetoobject_isNull1) { /* 107 */ if (deserializetoobject_value1.numElements() != deserializetoobject_value3.numElements()) { /* 108 */ throw new RuntimeException("Invalid state: Inconsistent lengths of key-value arrays"); /* 109 */ } /* 110 */ int deserializetoobject_dataLength = deserializetoobject_value1.numElements(); /* 111 */ /* 112 */ scala.collection.mutable.Builder CollectObjectsToMap_builderValue5 = scala.collection.immutable.Map$.MODULE$.newBuilder(); /* 113 */ CollectObjectsToMap_builderValue5.sizeHint(deserializetoobject_dataLength); /* 114 */ /* 115 */ int deserializetoobject_loopIndex = 0; /* 116 */ while (deserializetoobject_loopIndex < deserializetoobject_dataLength) { /* 117 */ CollectObjectsToMap_loopValue0 = (int) (deserializetoobject_value1.getInt(deserializetoobject_loopIndex)); /* 118 */ CollectObjectsToMap_loopValue2 = (int) (deserializetoobject_value3.getInt(deserializetoobject_loopIndex)); /* 119 */ CollectObjectsToMap_loopIsNull1 = deserializetoobject_value1.isNullAt(deserializetoobject_loopIndex); /* 120 */ CollectObjectsToMap_loopIsNull3 = deserializetoobject_value3.isNullAt(deserializetoobject_loopIndex); /* 121 */ /* 122 */ if (CollectObjectsToMap_loopIsNull1) { /* 123 */ throw new RuntimeException("Found null in map key!"); /* 124 */ } /* 125 */ /* 126 */ scala.Tuple2 CollectObjectsToMap_loopValue4; /* 127 */ /* 128 */ if (CollectObjectsToMap_loopIsNull3) { /* 129 */ CollectObjectsToMap_loopValue4 = new scala.Tuple2(CollectObjectsToMap_loopValue0, null); /* 130 */ } else { /* 131 */ CollectObjectsToMap_loopValue4 = new scala.Tuple2(CollectObjectsToMap_loopValue0, CollectObjectsToMap_loopValue2); /* 132 */ } /* 133 */ /* 134 */ CollectObjectsToMap_builderValue5.$plus$eq(CollectObjectsToMap_loopValue4); /* 135 */ /* 136 */ deserializetoobject_loopIndex += 1; /* 137 */ } /* 138 */ /* 139 */ deserializetoobject_value = (scala.collection.immutable.Map) CollectObjectsToMap_builderValue5.result(); /* 140 */ } /* 141 */ /* 142 */ boolean mapelements_isNull = true; /* 143 */ scala.collection.immutable.Map mapelements_value = null; /* 144 */ if (!false) { /* 145 */ mapelements_argValue = deserializetoobject_value; /* 146 */ /* 147 */ mapelements_isNull = false; /* 148 */ if (!mapelements_isNull) { /* 149 */ Object mapelements_funcResult = null; /* 150 */ mapelements_funcResult = ((scala.Function1) references[0]).apply(mapelements_argValue); /* 151 */ if (mapelements_funcResult == null) { /* 152 */ mapelements_isNull = true; /* 153 */ } else { /* 154 */ mapelements_value = (scala.collection.immutable.Map) mapelements_funcResult; /* 155 */ } /* 156 */ /* 157 */ } /* 158 */ mapelements_isNull = mapelements_value == null; /* 159 */ } /* 160 */ /* 161 */ MapData serializefromobject_value = null; /* 162 */ if (!mapelements_isNull) { /* 163 */ final int serializefromobject_length = mapelements_value.size(); /* 164 */ final Object[] serializefromobject_convertedKeys = new Object[serializefromobject_length]; /* 165 */ final Object[] serializefromobject_convertedValues = new Object[serializefromobject_length]; /* 166 */ int serializefromobject_index = 0; /* 167 */ final scala.collection.Iterator serializefromobject_entries = mapelements_value.iterator(); /* 168 */ while(serializefromobject_entries.hasNext()) { /* 169 */ final scala.Tuple2 serializefromobject_entry = (scala.Tuple2) serializefromobject_entries.next(); /* 170 */ int ExternalMapToCatalyst_key1 = (Integer) serializefromobject_entry._1(); /* 171 */ int ExternalMapToCatalyst_value1 = (Integer) serializefromobject_entry._2(); /* 172 */ /* 173 */ boolean ExternalMapToCatalyst_value_isNull1 = false; /* 174 */ /* 175 */ if (false) { /* 176 */ throw new RuntimeException("Cannot use null as map key!"); /* 177 */ } else { /* 178 */ serializefromobject_convertedKeys[serializefromobject_index] = (Integer) ExternalMapToCatalyst_key1; /* 179 */ } /* 180 */ /* 181 */ if (false) { /* 182 */ serializefromobject_convertedValues[serializefromobject_index] = null; /* 183 */ } else { /* 184 */ serializefromobject_convertedValues[serializefromobject_index] = (Integer) ExternalMapToCatalyst_value1; /* 185 */ } /* 186 */ /* 187 */ serializefromobject_index++; /* 188 */ } /* 189 */ /* 190 */ serializefromobject_value = new org.apache.spark.sql.catalyst.util.ArrayBasedMapData(new org.apache.spark.sql.catalyst.util.GenericArrayData(serializefromobject_convertedKeys), new org.apache.spark.sql.catalyst.util.GenericArrayData(serializefromobject_convertedValues)); /* 191 */ } /* 192 */ serializefromobject_holder.reset(); /* 193 */ /* 194 */ serializefromobject_rowWriter.zeroOutNullBytes(); /* 195 */ /* 196 */ if (mapelements_isNull) { /* 197 */ serializefromobject_rowWriter.setNullAt(0); /* 198 */ } else { /* 199 */ // Remember the current cursor so that we can calculate how many bytes are /* 200 */ // written later. /* 201 */ final int serializefromobject_tmpCursor = serializefromobject_holder.cursor; /* 202 */ /* 203 */ if (serializefromobject_value instanceof UnsafeMapData) { /* 204 */ final int serializefromobject_sizeInBytes = ((UnsafeMapData) serializefromobject_value).getSizeInBytes(); /* 205 */ // grow the global buffer before writing data. /* 206 */ serializefromobject_holder.grow(serializefromobject_sizeInBytes); /* 207 */ ((UnsafeMapData) serializefromobject_value).writeToMemory(serializefromobject_holder.buffer, serializefromobject_holder.cursor); /* 208 */ serializefromobject_holder.cursor += serializefromobject_sizeInBytes; /* 209 */ /* 210 */ } else { /* 211 */ final ArrayData serializefromobject_keys = serializefromobject_value.keyArray(); /* 212 */ final ArrayData serializefromobject_values = serializefromobject_value.valueArray(); /* 213 */ /* 214 */ // preserve 8 bytes to write the key array numBytes later. /* 215 */ serializefromobject_holder.grow(8); /* 216 */ serializefromobject_holder.cursor += 8; /* 217 */ /* 218 */ // Remember the current cursor so that we can write numBytes of key array later. /* 219 */ final int serializefromobject_tmpCursor1 = serializefromobject_holder.cursor; /* 220 */ /* 221 */ if (serializefromobject_keys instanceof UnsafeArrayData) { /* 222 */ final int serializefromobject_sizeInBytes1 = ((UnsafeArrayData) serializefromobject_keys).getSizeInBytes(); /* 223 */ // grow the global buffer before writing data. /* 224 */ serializefromobject_holder.grow(serializefromobject_sizeInBytes1); /* 225 */ ((UnsafeArrayData) serializefromobject_keys).writeToMemory(serializefromobject_holder.buffer, serializefromobject_holder.cursor); /* 226 */ serializefromobject_holder.cursor += serializefromobject_sizeInBytes1; /* 227 */ /* 228 */ } else { /* 229 */ final int serializefromobject_numElements = serializefromobject_keys.numElements(); /* 230 */ serializefromobject_arrayWriter.initialize(serializefromobject_holder, serializefromobject_numElements, 4); /* 231 */ /* 232 */ for (int serializefromobject_index1 = 0; serializefromobject_index1 < serializefromobject_numElements; serializefromobject_index1++) { /* 233 */ if (serializefromobject_keys.isNullAt(serializefromobject_index1)) { /* 234 */ serializefromobject_arrayWriter.setNullInt(serializefromobject_index1); /* 235 */ } else { /* 236 */ final int serializefromobject_element = serializefromobject_keys.getInt(serializefromobject_index1); /* 237 */ serializefromobject_arrayWriter.write(serializefromobject_index1, serializefromobject_element); /* 238 */ } /* 239 */ } /* 240 */ } /* 241 */ /* 242 */ // Write the numBytes of key array into the first 8 bytes. /* 243 */ Platform.putLong(serializefromobject_holder.buffer, serializefromobject_tmpCursor1 - 8, serializefromobject_holder.cursor - serializefromobject_tmpCursor1); /* 244 */ /* 245 */ if (serializefromobject_values instanceof UnsafeArrayData) { /* 246 */ final int serializefromobject_sizeInBytes2 = ((UnsafeArrayData) serializefromobject_values).getSizeInBytes(); /* 247 */ // grow the global buffer before writing data. /* 248 */ serializefromobject_holder.grow(serializefromobject_sizeInBytes2); /* 249 */ ((UnsafeArrayData) serializefromobject_values).writeToMemory(serializefromobject_holder.buffer, serializefromobject_holder.cursor); /* 250 */ serializefromobject_holder.cursor += serializefromobject_sizeInBytes2; /* 251 */ /* 252 */ } else { /* 253 */ final int serializefromobject_numElements1 = serializefromobject_values.numElements(); /* 254 */ serializefromobject_arrayWriter1.initialize(serializefromobject_holder, serializefromobject_numElements1, 4); /* 255 */ /* 256 */ for (int serializefromobject_index2 = 0; serializefromobject_index2 < serializefromobject_numElements1; serializefromobject_index2++) { /* 257 */ if (serializefromobject_values.isNullAt(serializefromobject_index2)) { /* 258 */ serializefromobject_arrayWriter1.setNullInt(serializefromobject_index2); /* 259 */ } else { /* 260 */ final int serializefromobject_element1 = serializefromobject_values.getInt(serializefromobject_index2); /* 261 */ serializefromobject_arrayWriter1.write(serializefromobject_index2, serializefromobject_element1); /* 262 */ } /* 263 */ } /* 264 */ } /* 265 */ /* 266 */ } /* 267 */ /* 268 */ serializefromobject_rowWriter.setOffsetAndSize(0, serializefromobject_tmpCursor, serializefromobject_holder.cursor - serializefromobject_tmpCursor); /* 269 */ } /* 270 */ serializefromobject_result.setTotalSize(serializefromobject_holder.totalSize()); /* 271 */ append(serializefromobject_result); /* 272 */ if (shouldStop()) return; /* 273 */ } /* 274 */ } /* 275 */ } ``` Codegen for `java.util.Map`: ``` /* 001 */ public Object generate(Object[] references) { /* 002 */ return new GeneratedIterator(references); /* 003 */ } /* 004 */ /* 005 */ final class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator { /* 006 */ private Object[] references; /* 007 */ private scala.collection.Iterator[] inputs; /* 008 */ private scala.collection.Iterator inputadapter_input; /* 009 */ private boolean CollectObjectsToMap_loopIsNull1; /* 010 */ private int CollectObjectsToMap_loopValue0; /* 011 */ private boolean CollectObjectsToMap_loopIsNull3; /* 012 */ private int CollectObjectsToMap_loopValue2; /* 013 */ private UnsafeRow deserializetoobject_result; /* 014 */ private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder deserializetoobject_holder; /* 015 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter deserializetoobject_rowWriter; /* 016 */ private java.util.HashMap mapelements_argValue; /* 017 */ private UnsafeRow mapelements_result; /* 018 */ private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder mapelements_holder; /* 019 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter mapelements_rowWriter; /* 020 */ private UnsafeRow serializefromobject_result; /* 021 */ private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder serializefromobject_holder; /* 022 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter serializefromobject_rowWriter; /* 023 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeArrayWriter serializefromobject_arrayWriter; /* 024 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeArrayWriter serializefromobject_arrayWriter1; /* 025 */ /* 026 */ public GeneratedIterator(Object[] references) { /* 027 */ this.references = references; /* 028 */ } /* 029 */ /* 030 */ public void init(int index, scala.collection.Iterator[] inputs) { /* 031 */ partitionIndex = index; /* 032 */ this.inputs = inputs; /* 033 */ wholestagecodegen_init_0(); /* 034 */ wholestagecodegen_init_1(); /* 035 */ /* 036 */ } /* 037 */ /* 038 */ private void wholestagecodegen_init_0() { /* 039 */ inputadapter_input = inputs[0]; /* 040 */ /* 041 */ deserializetoobject_result = new UnsafeRow(1); /* 042 */ this.deserializetoobject_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(deserializetoobject_result, 32); /* 043 */ this.deserializetoobject_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(deserializetoobject_holder, 1); /* 044 */ /* 045 */ mapelements_result = new UnsafeRow(1); /* 046 */ this.mapelements_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(mapelements_result, 32); /* 047 */ this.mapelements_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(mapelements_holder, 1); /* 048 */ serializefromobject_result = new UnsafeRow(1); /* 049 */ this.serializefromobject_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(serializefromobject_result, 32); /* 050 */ this.serializefromobject_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(serializefromobject_holder, 1); /* 051 */ this.serializefromobject_arrayWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeArrayWriter(); /* 052 */ /* 053 */ } /* 054 */ /* 055 */ private void wholestagecodegen_init_1() { /* 056 */ this.serializefromobject_arrayWriter1 = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeArrayWriter(); /* 057 */ /* 058 */ } /* 059 */ /* 060 */ protected void processNext() throws java.io.IOException { /* 061 */ while (inputadapter_input.hasNext() && !stopEarly()) { /* 062 */ InternalRow inputadapter_row = (InternalRow) inputadapter_input.next(); /* 063 */ boolean inputadapter_isNull = inputadapter_row.isNullAt(0); /* 064 */ MapData inputadapter_value = inputadapter_isNull ? null : (inputadapter_row.getMap(0)); /* 065 */ /* 066 */ boolean deserializetoobject_isNull1 = true; /* 067 */ ArrayData deserializetoobject_value1 = null; /* 068 */ if (!inputadapter_isNull) { /* 069 */ deserializetoobject_isNull1 = false; /* 070 */ if (!deserializetoobject_isNull1) { /* 071 */ Object deserializetoobject_funcResult = null; /* 072 */ deserializetoobject_funcResult = inputadapter_value.keyArray(); /* 073 */ if (deserializetoobject_funcResult == null) { /* 074 */ deserializetoobject_isNull1 = true; /* 075 */ } else { /* 076 */ deserializetoobject_value1 = (ArrayData) deserializetoobject_funcResult; /* 077 */ } /* 078 */ /* 079 */ } /* 080 */ deserializetoobject_isNull1 = deserializetoobject_value1 == null; /* 081 */ } /* 082 */ /* 083 */ boolean deserializetoobject_isNull3 = true; /* 084 */ ArrayData deserializetoobject_value3 = null; /* 085 */ if (!inputadapter_isNull) { /* 086 */ deserializetoobject_isNull3 = false; /* 087 */ if (!deserializetoobject_isNull3) { /* 088 */ Object deserializetoobject_funcResult1 = null; /* 089 */ deserializetoobject_funcResult1 = inputadapter_value.valueArray(); /* 090 */ if (deserializetoobject_funcResult1 == null) { /* 091 */ deserializetoobject_isNull3 = true; /* 092 */ } else { /* 093 */ deserializetoobject_value3 = (ArrayData) deserializetoobject_funcResult1; /* 094 */ } /* 095 */ /* 096 */ } /* 097 */ deserializetoobject_isNull3 = deserializetoobject_value3 == null; /* 098 */ } /* 099 */ java.util.HashMap deserializetoobject_value = null; /* 100 */ /* 101 */ if ((deserializetoobject_isNull1 && !deserializetoobject_isNull3) || /* 102 */ (!deserializetoobject_isNull1 && deserializetoobject_isNull3)) { /* 103 */ throw new RuntimeException("Invalid state: Inconsistent nullability of key-value"); /* 104 */ } /* 105 */ /* 106 */ if (!deserializetoobject_isNull1) { /* 107 */ if (deserializetoobject_value1.numElements() != deserializetoobject_value3.numElements()) { /* 108 */ throw new RuntimeException("Invalid state: Inconsistent lengths of key-value arrays"); /* 109 */ } /* 110 */ int deserializetoobject_dataLength = deserializetoobject_value1.numElements(); /* 111 */ java.util.Map CollectObjectsToMap_builderValue5 = new java.util.HashMap(deserializetoobject_dataLength); /* 112 */ /* 113 */ int deserializetoobject_loopIndex = 0; /* 114 */ while (deserializetoobject_loopIndex < deserializetoobject_dataLength) { /* 115 */ CollectObjectsToMap_loopValue0 = (int) (deserializetoobject_value1.getInt(deserializetoobject_loopIndex)); /* 116 */ CollectObjectsToMap_loopValue2 = (int) (deserializetoobject_value3.getInt(deserializetoobject_loopIndex)); /* 117 */ CollectObjectsToMap_loopIsNull1 = deserializetoobject_value1.isNullAt(deserializetoobject_loopIndex); /* 118 */ CollectObjectsToMap_loopIsNull3 = deserializetoobject_value3.isNullAt(deserializetoobject_loopIndex); /* 119 */ /* 120 */ if (CollectObjectsToMap_loopIsNull1) { /* 121 */ throw new RuntimeException("Found null in map key!"); /* 122 */ } /* 123 */ /* 124 */ CollectObjectsToMap_builderValue5.put(CollectObjectsToMap_loopValue0, CollectObjectsToMap_loopValue2); /* 125 */ /* 126 */ deserializetoobject_loopIndex += 1; /* 127 */ } /* 128 */ /* 129 */ deserializetoobject_value = (java.util.HashMap) CollectObjectsToMap_builderValue5; /* 130 */ } /* 131 */ /* 132 */ boolean mapelements_isNull = true; /* 133 */ java.util.HashMap mapelements_value = null; /* 134 */ if (!false) { /* 135 */ mapelements_argValue = deserializetoobject_value; /* 136 */ /* 137 */ mapelements_isNull = false; /* 138 */ if (!mapelements_isNull) { /* 139 */ Object mapelements_funcResult = null; /* 140 */ mapelements_funcResult = ((scala.Function1) references[0]).apply(mapelements_argValue); /* 141 */ if (mapelements_funcResult == null) { /* 142 */ mapelements_isNull = true; /* 143 */ } else { /* 144 */ mapelements_value = (java.util.HashMap) mapelements_funcResult; /* 145 */ } /* 146 */ /* 147 */ } /* 148 */ mapelements_isNull = mapelements_value == null; /* 149 */ } /* 150 */ /* 151 */ MapData serializefromobject_value = null; /* 152 */ if (!mapelements_isNull) { /* 153 */ final int serializefromobject_length = mapelements_value.size(); /* 154 */ final Object[] serializefromobject_convertedKeys = new Object[serializefromobject_length]; /* 155 */ final Object[] serializefromobject_convertedValues = new Object[serializefromobject_length]; /* 156 */ int serializefromobject_index = 0; /* 157 */ final java.util.Iterator serializefromobject_entries = mapelements_value.entrySet().iterator(); /* 158 */ while(serializefromobject_entries.hasNext()) { /* 159 */ final java.util.Map$Entry serializefromobject_entry = (java.util.Map$Entry) serializefromobject_entries.next(); /* 160 */ int ExternalMapToCatalyst_key1 = (Integer) serializefromobject_entry.getKey(); /* 161 */ int ExternalMapToCatalyst_value1 = (Integer) serializefromobject_entry.getValue(); /* 162 */ /* 163 */ boolean ExternalMapToCatalyst_value_isNull1 = false; /* 164 */ /* 165 */ if (false) { /* 166 */ throw new RuntimeException("Cannot use null as map key!"); /* 167 */ } else { /* 168 */ serializefromobject_convertedKeys[serializefromobject_index] = (Integer) ExternalMapToCatalyst_key1; /* 169 */ } /* 170 */ /* 171 */ if (false) { /* 172 */ serializefromobject_convertedValues[serializefromobject_index] = null; /* 173 */ } else { /* 174 */ serializefromobject_convertedValues[serializefromobject_index] = (Integer) ExternalMapToCatalyst_value1; /* 175 */ } /* 176 */ /* 177 */ serializefromobject_index++; /* 178 */ } /* 179 */ /* 180 */ serializefromobject_value = new org.apache.spark.sql.catalyst.util.ArrayBasedMapData(new org.apache.spark.sql.catalyst.util.GenericArrayData(serializefromobject_convertedKeys), new org.apache.spark.sql.catalyst.util.GenericArrayData(serializefromobject_convertedValues)); /* 181 */ } /* 182 */ serializefromobject_holder.reset(); /* 183 */ /* 184 */ serializefromobject_rowWriter.zeroOutNullBytes(); /* 185 */ /* 186 */ if (mapelements_isNull) { /* 187 */ serializefromobject_rowWriter.setNullAt(0); /* 188 */ } else { /* 189 */ // Remember the current cursor so that we can calculate how many bytes are /* 190 */ // written later. /* 191 */ final int serializefromobject_tmpCursor = serializefromobject_holder.cursor; /* 192 */ /* 193 */ if (serializefromobject_value instanceof UnsafeMapData) { /* 194 */ final int serializefromobject_sizeInBytes = ((UnsafeMapData) serializefromobject_value).getSizeInBytes(); /* 195 */ // grow the global buffer before writing data. /* 196 */ serializefromobject_holder.grow(serializefromobject_sizeInBytes); /* 197 */ ((UnsafeMapData) serializefromobject_value).writeToMemory(serializefromobject_holder.buffer, serializefromobject_holder.cursor); /* 198 */ serializefromobject_holder.cursor += serializefromobject_sizeInBytes; /* 199 */ /* 200 */ } else { /* 201 */ final ArrayData serializefromobject_keys = serializefromobject_value.keyArray(); /* 202 */ final ArrayData serializefromobject_values = serializefromobject_value.valueArray(); /* 203 */ /* 204 */ // preserve 8 bytes to write the key array numBytes later. /* 205 */ serializefromobject_holder.grow(8); /* 206 */ serializefromobject_holder.cursor += 8; /* 207 */ /* 208 */ // Remember the current cursor so that we can write numBytes of key array later. /* 209 */ final int serializefromobject_tmpCursor1 = serializefromobject_holder.cursor; /* 210 */ /* 211 */ if (serializefromobject_keys instanceof UnsafeArrayData) { /* 212 */ final int serializefromobject_sizeInBytes1 = ((UnsafeArrayData) serializefromobject_keys).getSizeInBytes(); /* 213 */ // grow the global buffer before writing data. /* 214 */ serializefromobject_holder.grow(serializefromobject_sizeInBytes1); /* 215 */ ((UnsafeArrayData) serializefromobject_keys).writeToMemory(serializefromobject_holder.buffer, serializefromobject_holder.cursor); /* 216 */ serializefromobject_holder.cursor += serializefromobject_sizeInBytes1; /* 217 */ /* 218 */ } else { /* 219 */ final int serializefromobject_numElements = serializefromobject_keys.numElements(); /* 220 */ serializefromobject_arrayWriter.initialize(serializefromobject_holder, serializefromobject_numElements, 4); /* 221 */ /* 222 */ for (int serializefromobject_index1 = 0; serializefromobject_index1 < serializefromobject_numElements; serializefromobject_index1++) { /* 223 */ if (serializefromobject_keys.isNullAt(serializefromobject_index1)) { /* 224 */ serializefromobject_arrayWriter.setNullInt(serializefromobject_index1); /* 225 */ } else { /* 226 */ final int serializefromobject_element = serializefromobject_keys.getInt(serializefromobject_index1); /* 227 */ serializefromobject_arrayWriter.write(serializefromobject_index1, serializefromobject_element); /* 228 */ } /* 229 */ } /* 230 */ } /* 231 */ /* 232 */ // Write the numBytes of key array into the first 8 bytes. /* 233 */ Platform.putLong(serializefromobject_holder.buffer, serializefromobject_tmpCursor1 - 8, serializefromobject_holder.cursor - serializefromobject_tmpCursor1); /* 234 */ /* 235 */ if (serializefromobject_values instanceof UnsafeArrayData) { /* 236 */ final int serializefromobject_sizeInBytes2 = ((UnsafeArrayData) serializefromobject_values).getSizeInBytes(); /* 237 */ // grow the global buffer before writing data. /* 238 */ serializefromobject_holder.grow(serializefromobject_sizeInBytes2); /* 239 */ ((UnsafeArrayData) serializefromobject_values).writeToMemory(serializefromobject_holder.buffer, serializefromobject_holder.cursor); /* 240 */ serializefromobject_holder.cursor += serializefromobject_sizeInBytes2; /* 241 */ /* 242 */ } else { /* 243 */ final int serializefromobject_numElements1 = serializefromobject_values.numElements(); /* 244 */ serializefromobject_arrayWriter1.initialize(serializefromobject_holder, serializefromobject_numElements1, 4); /* 245 */ /* 246 */ for (int serializefromobject_index2 = 0; serializefromobject_index2 < serializefromobject_numElements1; serializefromobject_index2++) { /* 247 */ if (serializefromobject_values.isNullAt(serializefromobject_index2)) { /* 248 */ serializefromobject_arrayWriter1.setNullInt(serializefromobject_index2); /* 249 */ } else { /* 250 */ final int serializefromobject_element1 = serializefromobject_values.getInt(serializefromobject_index2); /* 251 */ serializefromobject_arrayWriter1.write(serializefromobject_index2, serializefromobject_element1); /* 252 */ } /* 253 */ } /* 254 */ } /* 255 */ /* 256 */ } /* 257 */ /* 258 */ serializefromobject_rowWriter.setOffsetAndSize(0, serializefromobject_tmpCursor, serializefromobject_holder.cursor - serializefromobject_tmpCursor); /* 259 */ } /* 260 */ serializefromobject_result.setTotalSize(serializefromobject_holder.totalSize()); /* 261 */ append(serializefromobject_result); /* 262 */ if (shouldStop()) return; /* 263 */ } /* 264 */ } /* 265 */ } ``` ## How was this patch tested? ``` build/mvn -DskipTests clean package && dev/run-tests ``` Additionally in Spark shell: ``` scala> Seq(collection.mutable.HashMap(1 -> 2, 2 -> 3)).toDS().map(_ += (3 -> 4)).collect() res0: Array[scala.collection.mutable.HashMap[Int,Int]] = Array(Map(2 -> 3, 1 -> 2, 3 -> 4)) ``` Author: Michal Senkyr Author: Michal Šenkýř Closes #16986 from michalsenkyr/dataset-map-builder. --- .../spark/sql/catalyst/ScalaReflection.scala | 33 +--- .../expressions/objects/objects.scala | 169 +++++++++++++++++- .../sql/catalyst/ScalaReflectionSuite.scala | 25 +++ .../org/apache/spark/sql/SQLImplicits.scala | 5 + .../spark/sql/DatasetPrimitiveSuite.scala | 86 +++++++++ 5 files changed, 291 insertions(+), 27 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 87130532c89bc..d580cf4d3391c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst import org.apache.spark.sql.catalyst.analysis.{GetColumnByOrdinal, UnresolvedAttribute, UnresolvedExtractValue} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.objects._ -import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils, GenericArrayData} +import org.apache.spark.sql.catalyst.util.{DateTimeUtils, GenericArrayData} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} @@ -335,31 +335,12 @@ object ScalaReflection extends ScalaReflection { // TODO: add walked type path for map val TypeRef(_, _, Seq(keyType, valueType)) = t - val keyData = - Invoke( - MapObjects( - p => deserializerFor(keyType, Some(p), walkedTypePath), - Invoke(getPath, "keyArray", ArrayType(schemaFor(keyType).dataType), - returnNullable = false), - schemaFor(keyType).dataType), - "array", - ObjectType(classOf[Array[Any]]), returnNullable = false) - - val valueData = - Invoke( - MapObjects( - p => deserializerFor(valueType, Some(p), walkedTypePath), - Invoke(getPath, "valueArray", ArrayType(schemaFor(valueType).dataType), - returnNullable = false), - schemaFor(valueType).dataType), - "array", - ObjectType(classOf[Array[Any]]), returnNullable = false) - - StaticInvoke( - ArrayBasedMapData.getClass, - ObjectType(classOf[scala.collection.immutable.Map[_, _]]), - "toScalaMap", - keyData :: valueData :: Nil) + CollectObjectsToMap( + p => deserializerFor(keyType, Some(p), walkedTypePath), + p => deserializerFor(valueType, Some(p), walkedTypePath), + getPath, + mirror.runtimeClass(t.typeSymbol.asClass) + ) case t if t.typeSymbol.annotations.exists(_.tpe =:= typeOf[SQLUserDefinedType]) => val udt = getClassFromType(t).getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 1a202ecf745c9..79b7b9f3d0e16 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} -import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData} +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData} import org.apache.spark.sql.types._ /** @@ -652,6 +652,173 @@ case class MapObjects private( } } +object CollectObjectsToMap { + private val curId = new java.util.concurrent.atomic.AtomicInteger() + + /** + * Construct an instance of CollectObjectsToMap case class. + * + * @param keyFunction The function applied on the key collection elements. + * @param valueFunction The function applied on the value collection elements. + * @param inputData An expression that when evaluated returns a map object. + * @param collClass The type of the resulting collection. + */ + def apply( + keyFunction: Expression => Expression, + valueFunction: Expression => Expression, + inputData: Expression, + collClass: Class[_]): CollectObjectsToMap = { + val id = curId.getAndIncrement() + val keyLoopValue = s"CollectObjectsToMap_keyLoopValue$id" + val mapType = inputData.dataType.asInstanceOf[MapType] + val keyLoopVar = LambdaVariable(keyLoopValue, "", mapType.keyType, nullable = false) + val valueLoopValue = s"CollectObjectsToMap_valueLoopValue$id" + val valueLoopIsNull = s"CollectObjectsToMap_valueLoopIsNull$id" + val valueLoopVar = LambdaVariable(valueLoopValue, valueLoopIsNull, mapType.valueType) + CollectObjectsToMap( + keyLoopValue, keyFunction(keyLoopVar), + valueLoopValue, valueLoopIsNull, valueFunction(valueLoopVar), + inputData, collClass) + } +} + +/** + * Expression used to convert a Catalyst Map to an external Scala Map. + * The collection is constructed using the associated builder, obtained by calling `newBuilder` + * on the collection's companion object. + * + * @param keyLoopValue the name of the loop variable that is used when iterating over the key + * collection, and which is used as input for the `keyLambdaFunction` + * @param keyLambdaFunction A function that takes the `keyLoopVar` as input, and is used as + * a lambda function to handle collection elements. + * @param valueLoopValue the name of the loop variable that is used when iterating over the value + * collection, and which is used as input for the `valueLambdaFunction` + * @param valueLoopIsNull the nullability of the loop variable that is used when iterating over + * the value collection, and which is used as input for the + * `valueLambdaFunction` + * @param valueLambdaFunction A function that takes the `valueLoopVar` as input, and is used as + * a lambda function to handle collection elements. + * @param inputData An expression that when evaluated returns a map object. + * @param collClass The type of the resulting collection. + */ +case class CollectObjectsToMap private( + keyLoopValue: String, + keyLambdaFunction: Expression, + valueLoopValue: String, + valueLoopIsNull: String, + valueLambdaFunction: Expression, + inputData: Expression, + collClass: Class[_]) extends Expression with NonSQLExpression { + + override def nullable: Boolean = inputData.nullable + + override def children: Seq[Expression] = + keyLambdaFunction :: valueLambdaFunction :: inputData :: Nil + + override def eval(input: InternalRow): Any = + throw new UnsupportedOperationException("Only code-generated evaluation is supported") + + override def dataType: DataType = ObjectType(collClass) + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + // The data with PythonUserDefinedType are actually stored with the data type of its sqlType. + // When we want to apply MapObjects on it, we have to use it. + def inputDataType(dataType: DataType) = dataType match { + case p: PythonUserDefinedType => p.sqlType + case _ => dataType + } + + val mapType = inputDataType(inputData.dataType).asInstanceOf[MapType] + val keyElementJavaType = ctx.javaType(mapType.keyType) + ctx.addMutableState(keyElementJavaType, keyLoopValue, "") + val genKeyFunction = keyLambdaFunction.genCode(ctx) + val valueElementJavaType = ctx.javaType(mapType.valueType) + ctx.addMutableState("boolean", valueLoopIsNull, "") + ctx.addMutableState(valueElementJavaType, valueLoopValue, "") + val genValueFunction = valueLambdaFunction.genCode(ctx) + val genInputData = inputData.genCode(ctx) + val dataLength = ctx.freshName("dataLength") + val loopIndex = ctx.freshName("loopIndex") + val tupleLoopValue = ctx.freshName("tupleLoopValue") + val builderValue = ctx.freshName("builderValue") + + val getLength = s"${genInputData.value}.numElements()" + + val keyArray = ctx.freshName("keyArray") + val valueArray = ctx.freshName("valueArray") + val getKeyArray = + s"${classOf[ArrayData].getName} $keyArray = ${genInputData.value}.keyArray();" + val getKeyLoopVar = ctx.getValue(keyArray, inputDataType(mapType.keyType), loopIndex) + val getValueArray = + s"${classOf[ArrayData].getName} $valueArray = ${genInputData.value}.valueArray();" + val getValueLoopVar = ctx.getValue(valueArray, inputDataType(mapType.valueType), loopIndex) + + // Make a copy of the data if it's unsafe-backed + def makeCopyIfInstanceOf(clazz: Class[_ <: Any], value: String) = + s"$value instanceof ${clazz.getSimpleName}? $value.copy() : $value" + def genFunctionValue(lambdaFunction: Expression, genFunction: ExprCode) = + lambdaFunction.dataType match { + case StructType(_) => makeCopyIfInstanceOf(classOf[UnsafeRow], genFunction.value) + case ArrayType(_, _) => makeCopyIfInstanceOf(classOf[UnsafeArrayData], genFunction.value) + case MapType(_, _, _) => makeCopyIfInstanceOf(classOf[UnsafeMapData], genFunction.value) + case _ => genFunction.value + } + val genKeyFunctionValue = genFunctionValue(keyLambdaFunction, genKeyFunction) + val genValueFunctionValue = genFunctionValue(valueLambdaFunction, genValueFunction) + + val valueLoopNullCheck = s"$valueLoopIsNull = $valueArray.isNullAt($loopIndex);" + + val builderClass = classOf[Builder[_, _]].getName + val constructBuilder = s""" + $builderClass $builderValue = ${collClass.getName}$$.MODULE$$.newBuilder(); + $builderValue.sizeHint($dataLength); + """ + + val tupleClass = classOf[(_, _)].getName + val appendToBuilder = s""" + $tupleClass $tupleLoopValue; + + if (${genValueFunction.isNull}) { + $tupleLoopValue = new $tupleClass($genKeyFunctionValue, null); + } else { + $tupleLoopValue = new $tupleClass($genKeyFunctionValue, $genValueFunctionValue); + } + + $builderValue.$$plus$$eq($tupleLoopValue); + """ + val getBuilderResult = s"${ev.value} = (${collClass.getName}) $builderValue.result();" + + val code = s""" + ${genInputData.code} + ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + + if (!${genInputData.isNull}) { + int $dataLength = $getLength; + $constructBuilder + $getKeyArray + $getValueArray + + int $loopIndex = 0; + while ($loopIndex < $dataLength) { + $keyLoopValue = ($keyElementJavaType) ($getKeyLoopVar); + $valueLoopValue = ($valueElementJavaType) ($getValueLoopVar); + $valueLoopNullCheck + + ${genKeyFunction.code} + ${genValueFunction.code} + + $appendToBuilder + + $loopIndex += 1; + } + + $getBuilderResult + } + """ + ev.copy(code = code, isNull = genInputData.isNull) + } +} + object ExternalMapToCatalyst { private val curId = new java.util.concurrent.atomic.AtomicInteger() diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala index 70ad064f93ebc..ff2414b174acb 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala @@ -314,6 +314,31 @@ class ScalaReflectionSuite extends SparkFunSuite { assert(arrayBufferDeserializer.dataType == ObjectType(classOf[ArrayBuffer[_]])) } + test("serialize and deserialize arbitrary map types") { + val mapSerializer = serializerFor[Map[Int, Int]](BoundReference( + 0, ObjectType(classOf[Map[Int, Int]]), nullable = false)) + assert(mapSerializer.dataType.head.dataType == + MapType(IntegerType, IntegerType, valueContainsNull = false)) + val mapDeserializer = deserializerFor[Map[Int, Int]] + assert(mapDeserializer.dataType == ObjectType(classOf[Map[_, _]])) + + import scala.collection.immutable.HashMap + val hashMapSerializer = serializerFor[HashMap[Int, Int]](BoundReference( + 0, ObjectType(classOf[HashMap[Int, Int]]), nullable = false)) + assert(hashMapSerializer.dataType.head.dataType == + MapType(IntegerType, IntegerType, valueContainsNull = false)) + val hashMapDeserializer = deserializerFor[HashMap[Int, Int]] + assert(hashMapDeserializer.dataType == ObjectType(classOf[HashMap[_, _]])) + + import scala.collection.mutable.{LinkedHashMap => LHMap} + val linkedHashMapSerializer = serializerFor[LHMap[Long, String]](BoundReference( + 0, ObjectType(classOf[LHMap[Long, String]]), nullable = false)) + assert(linkedHashMapSerializer.dataType.head.dataType == + MapType(LongType, StringType, valueContainsNull = true)) + val linkedHashMapDeserializer = deserializerFor[LHMap[Long, String]] + assert(linkedHashMapDeserializer.dataType == ObjectType(classOf[LHMap[_, _]])) + } + private val dataTypeForComplexData = dataTypeFor[ComplexData] private val typeOfComplexData = typeOf[ComplexData] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala index 17671ea8685b9..86574e2f71d92 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql +import scala.collection.Map import scala.language.implicitConversions import scala.reflect.runtime.universe.TypeTag @@ -166,6 +167,10 @@ abstract class SQLImplicits extends LowPrioritySQLImplicits { /** @since 2.2.0 */ implicit def newSequenceEncoder[T <: Seq[_] : TypeTag]: Encoder[T] = ExpressionEncoder() + // Maps + /** @since 2.3.0 */ + implicit def newMapEncoder[T <: Map[_, _] : TypeTag]: Encoder[T] = ExpressionEncoder() + // Arrays /** @since 1.6.1 */ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala index 7e2949ab5aece..4126660b5d102 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql import scala.collection.immutable.Queue +import scala.collection.mutable.{LinkedHashMap => LHMap} import scala.collection.mutable.ArrayBuffer import org.apache.spark.sql.test.SharedSQLContext @@ -30,8 +31,14 @@ case class ListClass(l: List[Int]) case class QueueClass(q: Queue[Int]) +case class MapClass(m: Map[Int, Int]) + +case class LHMapClass(m: LHMap[Int, Int]) + case class ComplexClass(seq: SeqClass, list: ListClass, queue: QueueClass) +case class ComplexMapClass(map: MapClass, lhmap: LHMapClass) + package object packageobject { case class PackageClass(value: Int) } @@ -258,11 +265,90 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext { ListClass(List(1)) -> Queue("test" -> SeqClass(Seq(2)))) } + test("arbitrary maps") { + checkDataset(Seq(Map(1 -> 2)).toDS(), Map(1 -> 2)) + checkDataset(Seq(Map(1.toLong -> 2.toLong)).toDS(), Map(1.toLong -> 2.toLong)) + checkDataset(Seq(Map(1.toDouble -> 2.toDouble)).toDS(), Map(1.toDouble -> 2.toDouble)) + checkDataset(Seq(Map(1.toFloat -> 2.toFloat)).toDS(), Map(1.toFloat -> 2.toFloat)) + checkDataset(Seq(Map(1.toByte -> 2.toByte)).toDS(), Map(1.toByte -> 2.toByte)) + checkDataset(Seq(Map(1.toShort -> 2.toShort)).toDS(), Map(1.toShort -> 2.toShort)) + checkDataset(Seq(Map(true -> false)).toDS(), Map(true -> false)) + checkDataset(Seq(Map("test1" -> "test2")).toDS(), Map("test1" -> "test2")) + checkDataset(Seq(Map(Tuple1(1) -> Tuple1(2))).toDS(), Map(Tuple1(1) -> Tuple1(2))) + checkDataset(Seq(Map(1 -> Tuple1(2))).toDS(), Map(1 -> Tuple1(2))) + checkDataset(Seq(Map("test" -> 2.toLong)).toDS(), Map("test" -> 2.toLong)) + + checkDataset(Seq(LHMap(1 -> 2)).toDS(), LHMap(1 -> 2)) + checkDataset(Seq(LHMap(1.toLong -> 2.toLong)).toDS(), LHMap(1.toLong -> 2.toLong)) + checkDataset(Seq(LHMap(1.toDouble -> 2.toDouble)).toDS(), LHMap(1.toDouble -> 2.toDouble)) + checkDataset(Seq(LHMap(1.toFloat -> 2.toFloat)).toDS(), LHMap(1.toFloat -> 2.toFloat)) + checkDataset(Seq(LHMap(1.toByte -> 2.toByte)).toDS(), LHMap(1.toByte -> 2.toByte)) + checkDataset(Seq(LHMap(1.toShort -> 2.toShort)).toDS(), LHMap(1.toShort -> 2.toShort)) + checkDataset(Seq(LHMap(true -> false)).toDS(), LHMap(true -> false)) + checkDataset(Seq(LHMap("test1" -> "test2")).toDS(), LHMap("test1" -> "test2")) + checkDataset(Seq(LHMap(Tuple1(1) -> Tuple1(2))).toDS(), LHMap(Tuple1(1) -> Tuple1(2))) + checkDataset(Seq(LHMap(1 -> Tuple1(2))).toDS(), LHMap(1 -> Tuple1(2))) + checkDataset(Seq(LHMap("test" -> 2.toLong)).toDS(), LHMap("test" -> 2.toLong)) + } + + ignore("SPARK-19104: map and product combinations") { + // Case classes + checkDataset(Seq(MapClass(Map(1 -> 2))).toDS(), MapClass(Map(1 -> 2))) + checkDataset(Seq(Map(1 -> MapClass(Map(2 -> 3)))).toDS(), Map(1 -> MapClass(Map(2 -> 3)))) + checkDataset(Seq(Map(MapClass(Map(1 -> 2)) -> 3)).toDS(), Map(MapClass(Map(1 -> 2)) -> 3)) + checkDataset(Seq(Map(MapClass(Map(1 -> 2)) -> MapClass(Map(3 -> 4)))).toDS(), + Map(MapClass(Map(1 -> 2)) -> MapClass(Map(3 -> 4)))) + checkDataset(Seq(LHMap(1 -> MapClass(Map(2 -> 3)))).toDS(), LHMap(1 -> MapClass(Map(2 -> 3)))) + checkDataset(Seq(LHMap(MapClass(Map(1 -> 2)) -> 3)).toDS(), LHMap(MapClass(Map(1 -> 2)) -> 3)) + checkDataset(Seq(LHMap(MapClass(Map(1 -> 2)) -> MapClass(Map(3 -> 4)))).toDS(), + LHMap(MapClass(Map(1 -> 2)) -> MapClass(Map(3 -> 4)))) + + checkDataset(Seq(LHMapClass(LHMap(1 -> 2))).toDS(), LHMapClass(LHMap(1 -> 2))) + checkDataset(Seq(Map(1 -> LHMapClass(LHMap(2 -> 3)))).toDS(), + Map(1 -> LHMapClass(LHMap(2 -> 3)))) + checkDataset(Seq(Map(LHMapClass(LHMap(1 -> 2)) -> 3)).toDS(), + Map(LHMapClass(LHMap(1 -> 2)) -> 3)) + checkDataset(Seq(Map(LHMapClass(LHMap(1 -> 2)) -> LHMapClass(LHMap(3 -> 4)))).toDS(), + Map(LHMapClass(LHMap(1 -> 2)) -> LHMapClass(LHMap(3 -> 4)))) + checkDataset(Seq(LHMap(1 -> LHMapClass(LHMap(2 -> 3)))).toDS(), + LHMap(1 -> LHMapClass(LHMap(2 -> 3)))) + checkDataset(Seq(LHMap(LHMapClass(LHMap(1 -> 2)) -> 3)).toDS(), + LHMap(LHMapClass(LHMap(1 -> 2)) -> 3)) + checkDataset(Seq(LHMap(LHMapClass(LHMap(1 -> 2)) -> LHMapClass(LHMap(3 -> 4)))).toDS(), + LHMap(LHMapClass(LHMap(1 -> 2)) -> LHMapClass(LHMap(3 -> 4)))) + + val complex = ComplexMapClass(MapClass(Map(1 -> 2)), LHMapClass(LHMap(3 -> 4))) + checkDataset(Seq(complex).toDS(), complex) + checkDataset(Seq(Map(1 -> complex)).toDS(), Map(1 -> complex)) + checkDataset(Seq(Map(complex -> 5)).toDS(), Map(complex -> 5)) + checkDataset(Seq(Map(complex -> complex)).toDS(), Map(complex -> complex)) + checkDataset(Seq(LHMap(1 -> complex)).toDS(), LHMap(1 -> complex)) + checkDataset(Seq(LHMap(complex -> 5)).toDS(), LHMap(complex -> 5)) + checkDataset(Seq(LHMap(complex -> complex)).toDS(), LHMap(complex -> complex)) + + // Tuples + checkDataset(Seq(Map(1 -> 2) -> Map(3 -> 4)).toDS(), Map(1 -> 2) -> Map(3 -> 4)) + checkDataset(Seq(LHMap(1 -> 2) -> Map(3 -> 4)).toDS(), LHMap(1 -> 2) -> Map(3 -> 4)) + checkDataset(Seq(Map(1 -> 2) -> LHMap(3 -> 4)).toDS(), Map(1 -> 2) -> LHMap(3 -> 4)) + checkDataset(Seq(LHMap(1 -> 2) -> LHMap(3 -> 4)).toDS(), LHMap(1 -> 2) -> LHMap(3 -> 4)) + checkDataset(Seq(LHMap((Map("test1" -> 1) -> 2) -> (3 -> LHMap(4 -> "test2")))).toDS(), + LHMap((Map("test1" -> 1) -> 2) -> (3 -> LHMap(4 -> "test2")))) + + // Complex + checkDataset(Seq(LHMapClass(LHMap(1 -> 2)) -> LHMap("test" -> MapClass(Map(3 -> 4)))).toDS(), + LHMapClass(LHMap(1 -> 2)) -> LHMap("test" -> MapClass(Map(3 -> 4)))) + } + test("nested sequences") { checkDataset(Seq(Seq(Seq(1))).toDS(), Seq(Seq(1))) checkDataset(Seq(List(Queue(1))).toDS(), List(Queue(1))) } + test("nested maps") { + checkDataset(Seq(Map(1 -> LHMap(2 -> 3))).toDS(), Map(1 -> LHMap(2 -> 3))) + checkDataset(Seq(LHMap(Map(1 -> 2) -> 3)).toDS(), LHMap(Map(1 -> 2) -> 3)) + } + test("package objects") { import packageobject._ checkDataset(Seq(PackageClass(1)).toDS(), PackageClass(1)) From f48273c13c9e9fea2d9bb6dda10fcaaaaa50c588 Mon Sep 17 00:00:00 2001 From: Michal Senkyr Date: Mon, 12 Jun 2017 08:53:23 +0800 Subject: [PATCH 115/133] [SPARK-18891][SQL] Support for specific Java List subtypes ## What changes were proposed in this pull request? Add support for specific Java `List` subtypes in deserialization as well as a generic implicit encoder. All `List` subtypes are supported by using either the size-specifying constructor (one `int` parameter) or the default constructor. Interfaces/abstract classes use the following implementations: * `java.util.List`, `java.util.AbstractList` or `java.util.AbstractSequentialList` => `java.util.ArrayList` ## How was this patch tested? ```bash build/mvn -DskipTests clean package && dev/run-tests ``` Additionally in Spark shell: ``` scala> val jlist = new java.util.LinkedList[Int]; jlist.add(1) jlist: java.util.LinkedList[Int] = [1] res0: Boolean = true scala> Seq(jlist).toDS().map(_.element()).collect() res1: Array[Int] = Array(1) ``` Author: Michal Senkyr Closes #18009 from michalsenkyr/dataset-java-lists. --- .../sql/catalyst/JavaTypeInference.scala | 15 ++--- .../expressions/objects/objects.scala | 19 +++++- .../apache/spark/sql/JavaDatasetSuite.java | 61 +++++++++++++++++++ 3 files changed, 83 insertions(+), 12 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala index 86a73a319ec3f..7683ee7074e7d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala @@ -267,16 +267,11 @@ object JavaTypeInference { case c if listType.isAssignableFrom(typeToken) => val et = elementType(typeToken) - val array = - Invoke( - MapObjects( - p => deserializerFor(et, Some(p)), - getPath, - inferDataType(et)._1), - "array", - ObjectType(classOf[Array[Any]])) - - StaticInvoke(classOf[java.util.Arrays], ObjectType(c), "asList", array :: Nil) + MapObjects( + p => deserializerFor(et, Some(p)), + getPath, + inferDataType(et)._1, + customCollectionCls = Some(c)) case _ if mapType.isAssignableFrom(typeToken) => val (keyType, valueType) = mapKeyValueType(typeToken) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 79b7b9f3d0e16..5bb0febc943f2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -22,6 +22,7 @@ import java.lang.reflect.Modifier import scala.collection.mutable.Builder import scala.language.existentials import scala.reflect.ClassTag +import scala.util.Try import org.apache.spark.{SparkConf, SparkEnv} import org.apache.spark.serializer._ @@ -597,8 +598,8 @@ case class MapObjects private( val (initCollection, addElement, getResult): (String, String => String, String) = customCollectionCls match { - case Some(cls) => - // collection + case Some(cls) if classOf[Seq[_]].isAssignableFrom(cls) => + // Scala sequence val getBuilder = s"${cls.getName}$$.MODULE$$.newBuilder()" val builder = ctx.freshName("collectionBuilder") ( @@ -609,6 +610,20 @@ case class MapObjects private( genValue => s"$builder.$$plus$$eq($genValue);", s"(${cls.getName}) $builder.result();" ) + case Some(cls) if classOf[java.util.List[_]].isAssignableFrom(cls) => + // Java list + val builder = ctx.freshName("collectionBuilder") + ( + if (cls == classOf[java.util.List[_]] || cls == classOf[java.util.AbstractList[_]] || + cls == classOf[java.util.AbstractSequentialList[_]]) { + s"${cls.getName} $builder = new java.util.ArrayList($dataLength);" + } else { + val param = Try(cls.getConstructor(Integer.TYPE)).map(_ => dataLength).getOrElse("") + s"${cls.getName} $builder = new ${cls.getName}($param);" + }, + genValue => s"$builder.add($genValue);", + s"$builder;" + ) case None => // array ( diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java index 3ba37addfc8b4..4ca3b6406a328 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java @@ -1399,4 +1399,65 @@ public void testSerializeNull() { ds1.map((MapFunction) b -> b, encoder); Assert.assertEquals(beans, ds2.collectAsList()); } + + @Test + public void testSpecificLists() { + SpecificListsBean bean = new SpecificListsBean(); + ArrayList arrayList = new ArrayList<>(); + arrayList.add(1); + bean.setArrayList(arrayList); + LinkedList linkedList = new LinkedList<>(); + linkedList.add(1); + bean.setLinkedList(linkedList); + bean.setList(Collections.singletonList(1)); + List beans = Collections.singletonList(bean); + Dataset dataset = + spark.createDataset(beans, Encoders.bean(SpecificListsBean.class)); + Assert.assertEquals(beans, dataset.collectAsList()); + } + + public static class SpecificListsBean implements Serializable { + private ArrayList arrayList; + private LinkedList linkedList; + private List list; + + public ArrayList getArrayList() { + return arrayList; + } + + public void setArrayList(ArrayList arrayList) { + this.arrayList = arrayList; + } + + public LinkedList getLinkedList() { + return linkedList; + } + + public void setLinkedList(LinkedList linkedList) { + this.linkedList = linkedList; + } + + public List getList() { + return list; + } + + public void setList(List list) { + this.list = list; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + SpecificListsBean that = (SpecificListsBean) o; + return Objects.equal(arrayList, that.arrayList) && + Objects.equal(linkedList, that.linkedList) && + Objects.equal(list, that.list); + } + + @Override + public int hashCode() { + return Objects.hashCode(arrayList, linkedList, list); + } + } } From 3476390c6e5d0fcfff340410f57e114039b5fbd4 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sun, 11 Jun 2017 18:34:12 -0700 Subject: [PATCH 116/133] [SPARK-20715] Store MapStatuses only in MapOutputTracker, not ShuffleMapStage ## What changes were proposed in this pull request? This PR refactors `ShuffleMapStage` and `MapOutputTracker` in order to simplify the management of `MapStatuses`, reduce driver memory consumption, and remove a potential source of scheduler correctness bugs. ### Background In Spark there are currently two places where MapStatuses are tracked: - The `MapOutputTracker` maintains an `Array[MapStatus]` storing a single location for each map output. This mapping is used by the `DAGScheduler` for determining reduce-task locality preferences (when locality-aware reduce task scheduling is enabled) and is also used to serve map output locations to executors / tasks. - Each `ShuffleMapStage` also contains a mapping of `Array[List[MapStatus]]` which holds the complete set of locations where each map output could be available. This mapping is used to determine which map tasks need to be run when constructing `TaskSets` for the stage. This duplication adds complexity and creates the potential for certain types of correctness bugs. Bad things can happen if these two copies of the map output locations get out of sync. For instance, if the `MapOutputTracker` is missing locations for a map output but `ShuffleMapStage` believes that locations are available then tasks will fail with `MetadataFetchFailedException` but `ShuffleMapStage` will not be updated to reflect the missing map outputs, leading to situations where the stage will be reattempted (because downstream stages experienced fetch failures) but no task sets will be launched (because `ShuffleMapStage` thinks all maps are available). I observed this behavior in a real-world deployment. I'm still not quite sure how the state got out of sync in the first place, but we can completely avoid this class of bug if we eliminate the duplicate state. ### Why we only need to track a single location for each map output I think that storing an `Array[List[MapStatus]]` in `ShuffleMapStage` is unnecessary. First, note that this adds memory/object bloat to the driver we need one extra `List` per task. If you have millions of tasks across all stages then this can add up to be a significant amount of resources. Secondly, I believe that it's extremely uncommon that these lists will ever contain more than one entry. It's not impossible, but is very unlikely given the conditions which must occur for that to happen: - In normal operation (no task failures) we'll only run each task once and thus will have at most one output. - If speculation is enabled then it's possible that we'll have multiple attempts of a task. The TaskSetManager will [kill duplicate attempts of a task](https://github.com/apache/spark/blob/04901dd03a3f8062fd39ea38d585935ff71a9248/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala#L717) after a task finishes successfully, reducing the likelihood that both the original and speculated task will successfully register map outputs. - There is a [comment in `TaskSetManager`](https://github.com/apache/spark/blob/04901dd03a3f8062fd39ea38d585935ff71a9248/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala#L113) which suggests that running tasks are not killed if a task set becomes a zombie. However: - If the task set becomes a zombie due to the job being cancelled then it doesn't matter whether we record map outputs. - If the task set became a zombie because of a stage failure (e.g. the map stage itself had a fetch failure from an upstream match stage) then I believe that the "failedEpoch" will be updated which may cause map outputs from still-running tasks to [be ignored](https://github.com/apache/spark/blob/04901dd03a3f8062fd39ea38d585935ff71a9248/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala#L1213). (I'm not 100% sure on this point, though). - Even if you _do_ manage to record multiple map outputs for a stage, only a single map output is reported to / tracked by the MapOutputTracker. The only situation where the additional output locations could actually be read or used would be if a task experienced a `FetchFailure` exception. The most likely cause of a `FetchFailure` exception is an executor lost, which will have most likely caused the loss of several map tasks' output, so saving on potential re-execution of a single map task isn't a huge win if we're going to have to recompute several other lost map outputs from other tasks which ran on that lost executor. Also note that the re-population of MapOutputTracker state from state in the ShuffleMapTask only happens after the reduce stage has failed; the additional location doesn't help to prevent FetchFailures but, instead, can only reduce the amount of work when recomputing missing parent stages. Given this, this patch chooses to do away with tracking multiple locations for map outputs and instead stores only a single location. This change removes the main distinction between the `ShuffleMapTask` and `MapOutputTracker`'s copies of this state, paving the way for storing it only in the `MapOutputTracker`. ### Overview of other changes - Significantly simplified the cache / lock management inside of the `MapOutputTrackerMaster`: - The old code had several parallel `HashMap`s which had to be guarded by maps of `Object`s which were used as locks. This code was somewhat complicated to follow. - The new code uses a new `ShuffleStatus` class to group together all of the state associated with a particular shuffle, including cached serialized map statuses, significantly simplifying the logic. - Moved more code out of the shared `MapOutputTracker` abstract base class and into the `MapOutputTrackerMaster` and `MapOutputTrackerWorker` subclasses. This makes it easier to reason about which functionality needs to be supported only on the driver or executor. - Removed a bunch of code from the `DAGScheduler` which was used to synchronize information from the `MapOutputTracker` to `ShuffleMapStage`. - Added comments to clarify the role of `MapOutputTrackerMaster`'s `epoch` in invalidating executor-side shuffle map output caches. I will comment on these changes via inline GitHub review comments. /cc hvanhovell and rxin (whom I discussed this with offline), tgravescs (who recently worked on caching of serialized MapOutputStatuses), and kayousterhout and markhamstra (for scheduler changes). ## How was this patch tested? Existing tests. I purposely avoided making interface / API which would require significant updates or modifications to test code. Author: Josh Rosen Closes #17955 from JoshRosen/map-output-tracker-rewrite. --- .../org/apache/spark/MapOutputTracker.scala | 636 ++++++++++-------- .../org/apache/spark/executor/Executor.scala | 10 +- .../apache/spark/scheduler/DAGScheduler.scala | 51 +- .../spark/scheduler/ShuffleMapStage.scala | 76 +-- .../spark/scheduler/TaskSchedulerImpl.scala | 2 +- .../apache/spark/MapOutputTrackerSuite.scala | 6 +- .../scala/org/apache/spark/ShuffleSuite.scala | 3 +- .../scheduler/BlacklistTrackerSuite.scala | 3 +- 8 files changed, 398 insertions(+), 389 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index 4ef6656222455..3e10b9eee4e24 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -34,6 +34,156 @@ import org.apache.spark.shuffle.MetadataFetchFailedException import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockId} import org.apache.spark.util._ +/** + * Helper class used by the [[MapOutputTrackerMaster]] to perform bookkeeping for a single + * ShuffleMapStage. + * + * This class maintains a mapping from mapIds to `MapStatus`. It also maintains a cache of + * serialized map statuses in order to speed up tasks' requests for map output statuses. + * + * All public methods of this class are thread-safe. + */ +private class ShuffleStatus(numPartitions: Int) { + + // All accesses to the following state must be guarded with `this.synchronized`. + + /** + * MapStatus for each partition. The index of the array is the map partition id. + * Each value in the array is the MapStatus for a partition, or null if the partition + * is not available. Even though in theory a task may run multiple times (due to speculation, + * stage retries, etc.), in practice the likelihood of a map output being available at multiple + * locations is so small that we choose to ignore that case and store only a single location + * for each output. + */ + private[this] val mapStatuses = new Array[MapStatus](numPartitions) + + /** + * The cached result of serializing the map statuses array. This cache is lazily populated when + * [[serializedMapStatus]] is called. The cache is invalidated when map outputs are removed. + */ + private[this] var cachedSerializedMapStatus: Array[Byte] = _ + + /** + * Broadcast variable holding serialized map output statuses array. When [[serializedMapStatus]] + * serializes the map statuses array it may detect that the result is too large to send in a + * single RPC, in which case it places the serialized array into a broadcast variable and then + * sends a serialized broadcast variable instead. This variable holds a reference to that + * broadcast variable in order to keep it from being garbage collected and to allow for it to be + * explicitly destroyed later on when the ShuffleMapStage is garbage-collected. + */ + private[this] var cachedSerializedBroadcast: Broadcast[Array[Byte]] = _ + + /** + * Counter tracking the number of partitions that have output. This is a performance optimization + * to avoid having to count the number of non-null entries in the `mapStatuses` array and should + * be equivalent to`mapStatuses.count(_ ne null)`. + */ + private[this] var _numAvailableOutputs: Int = 0 + + /** + * Register a map output. If there is already a registered location for the map output then it + * will be replaced by the new location. + */ + def addMapOutput(mapId: Int, status: MapStatus): Unit = synchronized { + if (mapStatuses(mapId) == null) { + _numAvailableOutputs += 1 + invalidateSerializedMapOutputStatusCache() + } + mapStatuses(mapId) = status + } + + /** + * Remove the map output which was served by the specified block manager. + * This is a no-op if there is no registered map output or if the registered output is from a + * different block manager. + */ + def removeMapOutput(mapId: Int, bmAddress: BlockManagerId): Unit = synchronized { + if (mapStatuses(mapId) != null && mapStatuses(mapId).location == bmAddress) { + _numAvailableOutputs -= 1 + mapStatuses(mapId) = null + invalidateSerializedMapOutputStatusCache() + } + } + + /** + * Removes all map outputs associated with the specified executor. Note that this will also + * remove outputs which are served by an external shuffle server (if one exists), as they are + * still registered with that execId. + */ + def removeOutputsOnExecutor(execId: String): Unit = synchronized { + for (mapId <- 0 until mapStatuses.length) { + if (mapStatuses(mapId) != null && mapStatuses(mapId).location.executorId == execId) { + _numAvailableOutputs -= 1 + mapStatuses(mapId) = null + invalidateSerializedMapOutputStatusCache() + } + } + } + + /** + * Number of partitions that have shuffle outputs. + */ + def numAvailableOutputs: Int = synchronized { + _numAvailableOutputs + } + + /** + * Returns the sequence of partition ids that are missing (i.e. needs to be computed). + */ + def findMissingPartitions(): Seq[Int] = synchronized { + val missing = (0 until numPartitions).filter(id => mapStatuses(id) == null) + assert(missing.size == numPartitions - _numAvailableOutputs, + s"${missing.size} missing, expected ${numPartitions - _numAvailableOutputs}") + missing + } + + /** + * Serializes the mapStatuses array into an efficient compressed format. See the comments on + * `MapOutputTracker.serializeMapStatuses()` for more details on the serialization format. + * + * This method is designed to be called multiple times and implements caching in order to speed + * up subsequent requests. If the cache is empty and multiple threads concurrently attempt to + * serialize the map statuses then serialization will only be performed in a single thread and all + * other threads will block until the cache is populated. + */ + def serializedMapStatus( + broadcastManager: BroadcastManager, + isLocal: Boolean, + minBroadcastSize: Int): Array[Byte] = synchronized { + if (cachedSerializedMapStatus eq null) { + val serResult = MapOutputTracker.serializeMapStatuses( + mapStatuses, broadcastManager, isLocal, minBroadcastSize) + cachedSerializedMapStatus = serResult._1 + cachedSerializedBroadcast = serResult._2 + } + cachedSerializedMapStatus + } + + // Used in testing. + def hasCachedSerializedBroadcast: Boolean = synchronized { + cachedSerializedBroadcast != null + } + + /** + * Helper function which provides thread-safe access to the mapStatuses array. + * The function should NOT mutate the array. + */ + def withMapStatuses[T](f: Array[MapStatus] => T): T = synchronized { + f(mapStatuses) + } + + /** + * Clears the cached serialized map output statuses. + */ + def invalidateSerializedMapOutputStatusCache(): Unit = synchronized { + if (cachedSerializedBroadcast != null) { + cachedSerializedBroadcast.destroy() + cachedSerializedBroadcast = null + } + cachedSerializedMapStatus = null + } +} + private[spark] sealed trait MapOutputTrackerMessage private[spark] case class GetMapOutputStatuses(shuffleId: Int) extends MapOutputTrackerMessage @@ -62,37 +212,26 @@ private[spark] class MapOutputTrackerMasterEndpoint( } /** - * Class that keeps track of the location of the map output of - * a stage. This is abstract because different versions of MapOutputTracker - * (driver and executor) use different HashMap to store its metadata. - */ + * Class that keeps track of the location of the map output of a stage. This is abstract because the + * driver and executor have different versions of the MapOutputTracker. In principle the driver- + * and executor-side classes don't need to share a common base class; the current shared base class + * is maintained primarily for backwards-compatibility in order to avoid having to update existing + * test code. +*/ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging { - /** Set to the MapOutputTrackerMasterEndpoint living on the driver. */ var trackerEndpoint: RpcEndpointRef = _ /** - * This HashMap has different behavior for the driver and the executors. - * - * On the driver, it serves as the source of map outputs recorded from ShuffleMapTasks. - * On the executors, it simply serves as a cache, in which a miss triggers a fetch from the - * driver's corresponding HashMap. - * - * Note: because mapStatuses is accessed concurrently, subclasses should make sure it's a - * thread-safe map. - */ - protected val mapStatuses: Map[Int, Array[MapStatus]] - - /** - * Incremented every time a fetch fails so that client nodes know to clear - * their cache of map output locations if this happens. + * The driver-side counter is incremented every time that a map output is lost. This value is sent + * to executors as part of tasks, where executors compare the new epoch number to the highest + * epoch number that they received in the past. If the new epoch number is higher then executors + * will clear their local caches of map output statuses and will re-fetch (possibly updated) + * statuses from the driver. */ protected var epoch: Long = 0 protected val epochLock = new AnyRef - /** Remembers which map output locations are currently being fetched on an executor. */ - private val fetching = new HashSet[Int] - /** * Send a message to the trackerEndpoint and get its result within a default timeout, or * throw a SparkException if this fails. @@ -116,14 +255,7 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging } } - /** - * Called from executors to get the server URIs and output sizes for each shuffle block that - * needs to be read from a given reduce task. - * - * @return A sequence of 2-item tuples, where the first item in the tuple is a BlockManagerId, - * and the second item is a sequence of (shuffle block id, shuffle block size) tuples - * describing the shuffle blocks that are stored at that block manager. - */ + // For testing def getMapSizesByExecutorId(shuffleId: Int, reduceId: Int) : Seq[(BlockManagerId, Seq[(BlockId, Long)])] = { getMapSizesByExecutorId(shuffleId, reduceId, reduceId + 1) @@ -139,135 +271,31 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging * describing the shuffle blocks that are stored at that block manager. */ def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, endPartition: Int) - : Seq[(BlockManagerId, Seq[(BlockId, Long)])] = { - logDebug(s"Fetching outputs for shuffle $shuffleId, partitions $startPartition-$endPartition") - val statuses = getStatuses(shuffleId) - // Synchronize on the returned array because, on the driver, it gets mutated in place - statuses.synchronized { - return MapOutputTracker.convertMapStatuses(shuffleId, startPartition, endPartition, statuses) - } - } + : Seq[(BlockManagerId, Seq[(BlockId, Long)])] /** - * Return statistics about all of the outputs for a given shuffle. + * Deletes map output status information for the specified shuffle stage. */ - def getStatistics(dep: ShuffleDependency[_, _, _]): MapOutputStatistics = { - val statuses = getStatuses(dep.shuffleId) - // Synchronize on the returned array because, on the driver, it gets mutated in place - statuses.synchronized { - val totalSizes = new Array[Long](dep.partitioner.numPartitions) - for (s <- statuses) { - for (i <- 0 until totalSizes.length) { - totalSizes(i) += s.getSizeForBlock(i) - } - } - new MapOutputStatistics(dep.shuffleId, totalSizes) - } - } - - /** - * Get or fetch the array of MapStatuses for a given shuffle ID. NOTE: clients MUST synchronize - * on this array when reading it, because on the driver, we may be changing it in place. - * - * (It would be nice to remove this restriction in the future.) - */ - private def getStatuses(shuffleId: Int): Array[MapStatus] = { - val statuses = mapStatuses.get(shuffleId).orNull - if (statuses == null) { - logInfo("Don't have map outputs for shuffle " + shuffleId + ", fetching them") - val startTime = System.currentTimeMillis - var fetchedStatuses: Array[MapStatus] = null - fetching.synchronized { - // Someone else is fetching it; wait for them to be done - while (fetching.contains(shuffleId)) { - try { - fetching.wait() - } catch { - case e: InterruptedException => - } - } - - // Either while we waited the fetch happened successfully, or - // someone fetched it in between the get and the fetching.synchronized. - fetchedStatuses = mapStatuses.get(shuffleId).orNull - if (fetchedStatuses == null) { - // We have to do the fetch, get others to wait for us. - fetching += shuffleId - } - } + def unregisterShuffle(shuffleId: Int): Unit - if (fetchedStatuses == null) { - // We won the race to fetch the statuses; do so - logInfo("Doing the fetch; tracker endpoint = " + trackerEndpoint) - // This try-finally prevents hangs due to timeouts: - try { - val fetchedBytes = askTracker[Array[Byte]](GetMapOutputStatuses(shuffleId)) - fetchedStatuses = MapOutputTracker.deserializeMapStatuses(fetchedBytes) - logInfo("Got the output locations") - mapStatuses.put(shuffleId, fetchedStatuses) - } finally { - fetching.synchronized { - fetching -= shuffleId - fetching.notifyAll() - } - } - } - logDebug(s"Fetching map output statuses for shuffle $shuffleId took " + - s"${System.currentTimeMillis - startTime} ms") - - if (fetchedStatuses != null) { - return fetchedStatuses - } else { - logError("Missing all output locations for shuffle " + shuffleId) - throw new MetadataFetchFailedException( - shuffleId, -1, "Missing all output locations for shuffle " + shuffleId) - } - } else { - return statuses - } - } - - /** Called to get current epoch number. */ - def getEpoch: Long = { - epochLock.synchronized { - return epoch - } - } - - /** - * Called from executors to update the epoch number, potentially clearing old outputs - * because of a fetch failure. Each executor task calls this with the latest epoch - * number on the driver at the time it was created. - */ - def updateEpoch(newEpoch: Long) { - epochLock.synchronized { - if (newEpoch > epoch) { - logInfo("Updating epoch to " + newEpoch + " and clearing cache") - epoch = newEpoch - mapStatuses.clear() - } - } - } - - /** Unregister shuffle data. */ - def unregisterShuffle(shuffleId: Int) { - mapStatuses.remove(shuffleId) - } - - /** Stop the tracker. */ - def stop() { } + def stop() {} } /** - * MapOutputTracker for the driver. + * Driver-side class that keeps track of the location of the map output of a stage. + * + * The DAGScheduler uses this class to (de)register map output statuses and to look up statistics + * for performing locality-aware reduce task scheduling. + * + * ShuffleMapStage uses this class for tracking available / missing outputs in order to determine + * which tasks need to be run. */ -private[spark] class MapOutputTrackerMaster(conf: SparkConf, - broadcastManager: BroadcastManager, isLocal: Boolean) +private[spark] class MapOutputTrackerMaster( + conf: SparkConf, + broadcastManager: BroadcastManager, + isLocal: Boolean) extends MapOutputTracker(conf) { - /** Cache a serialized version of the output statuses for each shuffle to send them out faster */ - private var cacheEpoch = epoch - // The size at which we use Broadcast to send the map output statuses to the executors private val minSizeForBroadcast = conf.getSizeAsBytes("spark.shuffle.mapOutput.minSizeForBroadcast", "512k").toInt @@ -287,22 +315,12 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf, // can be read locally, but may lead to more delay in scheduling if those locations are busy. private val REDUCER_PREF_LOCS_FRACTION = 0.2 - // HashMaps for storing mapStatuses and cached serialized statuses in the driver. + // HashMap for storing shuffleStatuses in the driver. // Statuses are dropped only by explicit de-registering. - protected val mapStatuses = new ConcurrentHashMap[Int, Array[MapStatus]]().asScala - private val cachedSerializedStatuses = new ConcurrentHashMap[Int, Array[Byte]]().asScala + private val shuffleStatuses = new ConcurrentHashMap[Int, ShuffleStatus]().asScala private val maxRpcMessageSize = RpcUtils.maxMessageSizeBytes(conf) - // Kept in sync with cachedSerializedStatuses explicitly - // This is required so that the Broadcast variable remains in scope until we remove - // the shuffleId explicitly or implicitly. - private val cachedSerializedBroadcast = new HashMap[Int, Broadcast[Array[Byte]]]() - - // This is to prevent multiple serializations of the same shuffle - which happens when - // there is a request storm when shuffle start. - private val shuffleIdLocks = new ConcurrentHashMap[Int, AnyRef]() - // requests for map output statuses private val mapOutputRequests = new LinkedBlockingQueue[GetMapOutputMessage] @@ -348,8 +366,9 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf, val hostPort = context.senderAddress.hostPort logDebug("Handling request to send map output locations for shuffle " + shuffleId + " to " + hostPort) - val mapOutputStatuses = getSerializedMapOutputStatuses(shuffleId) - context.reply(mapOutputStatuses) + val shuffleStatus = shuffleStatuses.get(shuffleId).head + context.reply( + shuffleStatus.serializedMapStatus(broadcastManager, isLocal, minSizeForBroadcast)) } catch { case NonFatal(e) => logError(e.getMessage, e) } @@ -363,59 +382,77 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf, /** A poison endpoint that indicates MessageLoop should exit its message loop. */ private val PoisonPill = new GetMapOutputMessage(-99, null) - // Exposed for testing - private[spark] def getNumCachedSerializedBroadcast = cachedSerializedBroadcast.size + // Used only in unit tests. + private[spark] def getNumCachedSerializedBroadcast: Int = { + shuffleStatuses.valuesIterator.count(_.hasCachedSerializedBroadcast) + } def registerShuffle(shuffleId: Int, numMaps: Int) { - if (mapStatuses.put(shuffleId, new Array[MapStatus](numMaps)).isDefined) { + if (shuffleStatuses.put(shuffleId, new ShuffleStatus(numMaps)).isDefined) { throw new IllegalArgumentException("Shuffle ID " + shuffleId + " registered twice") } - // add in advance - shuffleIdLocks.putIfAbsent(shuffleId, new Object()) } def registerMapOutput(shuffleId: Int, mapId: Int, status: MapStatus) { - val array = mapStatuses(shuffleId) - array.synchronized { - array(mapId) = status - } - } - - /** Register multiple map output information for the given shuffle */ - def registerMapOutputs(shuffleId: Int, statuses: Array[MapStatus], changeEpoch: Boolean = false) { - mapStatuses.put(shuffleId, statuses.clone()) - if (changeEpoch) { - incrementEpoch() - } + shuffleStatuses(shuffleId).addMapOutput(mapId, status) } /** Unregister map output information of the given shuffle, mapper and block manager */ def unregisterMapOutput(shuffleId: Int, mapId: Int, bmAddress: BlockManagerId) { - val arrayOpt = mapStatuses.get(shuffleId) - if (arrayOpt.isDefined && arrayOpt.get != null) { - val array = arrayOpt.get - array.synchronized { - if (array(mapId) != null && array(mapId).location == bmAddress) { - array(mapId) = null - } - } - incrementEpoch() - } else { - throw new SparkException("unregisterMapOutput called for nonexistent shuffle ID") + shuffleStatuses.get(shuffleId) match { + case Some(shuffleStatus) => + shuffleStatus.removeMapOutput(mapId, bmAddress) + incrementEpoch() + case None => + throw new SparkException("unregisterMapOutput called for nonexistent shuffle ID") } } /** Unregister shuffle data */ - override def unregisterShuffle(shuffleId: Int) { - mapStatuses.remove(shuffleId) - cachedSerializedStatuses.remove(shuffleId) - cachedSerializedBroadcast.remove(shuffleId).foreach(v => removeBroadcast(v)) - shuffleIdLocks.remove(shuffleId) + def unregisterShuffle(shuffleId: Int) { + shuffleStatuses.remove(shuffleId).foreach { shuffleStatus => + shuffleStatus.invalidateSerializedMapOutputStatusCache() + } + } + + /** + * Removes all shuffle outputs associated with this executor. Note that this will also remove + * outputs which are served by an external shuffle server (if one exists), as they are still + * registered with this execId. + */ + def removeOutputsOnExecutor(execId: String): Unit = { + shuffleStatuses.valuesIterator.foreach { _.removeOutputsOnExecutor(execId) } + incrementEpoch() } /** Check if the given shuffle is being tracked */ - def containsShuffle(shuffleId: Int): Boolean = { - cachedSerializedStatuses.contains(shuffleId) || mapStatuses.contains(shuffleId) + def containsShuffle(shuffleId: Int): Boolean = shuffleStatuses.contains(shuffleId) + + def getNumAvailableOutputs(shuffleId: Int): Int = { + shuffleStatuses.get(shuffleId).map(_.numAvailableOutputs).getOrElse(0) + } + + /** + * Returns the sequence of partition ids that are missing (i.e. needs to be computed), or None + * if the MapOutputTrackerMaster doesn't know about this shuffle. + */ + def findMissingPartitions(shuffleId: Int): Option[Seq[Int]] = { + shuffleStatuses.get(shuffleId).map(_.findMissingPartitions()) + } + + /** + * Return statistics about all of the outputs for a given shuffle. + */ + def getStatistics(dep: ShuffleDependency[_, _, _]): MapOutputStatistics = { + shuffleStatuses(dep.shuffleId).withMapStatuses { statuses => + val totalSizes = new Array[Long](dep.partitioner.numPartitions) + for (s <- statuses) { + for (i <- 0 until totalSizes.length) { + totalSizes(i) += s.getSizeForBlock(i) + } + } + new MapOutputStatistics(dep.shuffleId, totalSizes) + } } /** @@ -459,9 +496,9 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf, fractionThreshold: Double) : Option[Array[BlockManagerId]] = { - val statuses = mapStatuses.get(shuffleId).orNull - if (statuses != null) { - statuses.synchronized { + val shuffleStatus = shuffleStatuses.get(shuffleId).orNull + if (shuffleStatus != null) { + shuffleStatus.withMapStatuses { statuses => if (statuses.nonEmpty) { // HashMap to add up sizes of all blocks at the same location val locs = new HashMap[BlockManagerId, Long] @@ -502,77 +539,24 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf, } } - private def removeBroadcast(bcast: Broadcast[_]): Unit = { - if (null != bcast) { - broadcastManager.unbroadcast(bcast.id, - removeFromDriver = true, blocking = false) + /** Called to get current epoch number. */ + def getEpoch: Long = { + epochLock.synchronized { + return epoch } } - private def clearCachedBroadcast(): Unit = { - for (cached <- cachedSerializedBroadcast) removeBroadcast(cached._2) - cachedSerializedBroadcast.clear() - } - - def getSerializedMapOutputStatuses(shuffleId: Int): Array[Byte] = { - var statuses: Array[MapStatus] = null - var retBytes: Array[Byte] = null - var epochGotten: Long = -1 - - // Check to see if we have a cached version, returns true if it does - // and has side effect of setting retBytes. If not returns false - // with side effect of setting statuses - def checkCachedStatuses(): Boolean = { - epochLock.synchronized { - if (epoch > cacheEpoch) { - cachedSerializedStatuses.clear() - clearCachedBroadcast() - cacheEpoch = epoch - } - cachedSerializedStatuses.get(shuffleId) match { - case Some(bytes) => - retBytes = bytes - true - case None => - logDebug("cached status not found for : " + shuffleId) - statuses = mapStatuses.getOrElse(shuffleId, Array.empty[MapStatus]) - epochGotten = epoch - false - } - } - } - - if (checkCachedStatuses()) return retBytes - var shuffleIdLock = shuffleIdLocks.get(shuffleId) - if (null == shuffleIdLock) { - val newLock = new Object() - // in general, this condition should be false - but good to be paranoid - val prevLock = shuffleIdLocks.putIfAbsent(shuffleId, newLock) - shuffleIdLock = if (null != prevLock) prevLock else newLock - } - // synchronize so we only serialize/broadcast it once since multiple threads call - // in parallel - shuffleIdLock.synchronized { - // double check to make sure someone else didn't serialize and cache the same - // mapstatus while we were waiting on the synchronize - if (checkCachedStatuses()) return retBytes - - // If we got here, we failed to find the serialized locations in the cache, so we pulled - // out a snapshot of the locations as "statuses"; let's serialize and return that - val (bytes, bcast) = MapOutputTracker.serializeMapStatuses(statuses, broadcastManager, - isLocal, minSizeForBroadcast) - logInfo("Size of output statuses for shuffle %d is %d bytes".format(shuffleId, bytes.length)) - // Add them into the table only if the epoch hasn't changed while we were working - epochLock.synchronized { - if (epoch == epochGotten) { - cachedSerializedStatuses(shuffleId) = bytes - if (null != bcast) cachedSerializedBroadcast(shuffleId) = bcast - } else { - logInfo("Epoch changed, not caching!") - removeBroadcast(bcast) + // This method is only called in local-mode. + def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, endPartition: Int) + : Seq[(BlockManagerId, Seq[(BlockId, Long)])] = { + logDebug(s"Fetching outputs for shuffle $shuffleId, partitions $startPartition-$endPartition") + shuffleStatuses.get(shuffleId) match { + case Some (shuffleStatus) => + shuffleStatus.withMapStatuses { statuses => + MapOutputTracker.convertMapStatuses(shuffleId, startPartition, endPartition, statuses) } - } - bytes + case None => + Seq.empty } } @@ -580,21 +564,121 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf, mapOutputRequests.offer(PoisonPill) threadpool.shutdown() sendTracker(StopMapOutputTracker) - mapStatuses.clear() trackerEndpoint = null - cachedSerializedStatuses.clear() - clearCachedBroadcast() - shuffleIdLocks.clear() + shuffleStatuses.clear() } } /** - * MapOutputTracker for the executors, which fetches map output information from the driver's - * MapOutputTrackerMaster. + * Executor-side client for fetching map output info from the driver's MapOutputTrackerMaster. + * Note that this is not used in local-mode; instead, local-mode Executors access the + * MapOutputTrackerMaster directly (which is possible because the master and worker share a comon + * superclass). */ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTracker(conf) { - protected val mapStatuses: Map[Int, Array[MapStatus]] = + + val mapStatuses: Map[Int, Array[MapStatus]] = new ConcurrentHashMap[Int, Array[MapStatus]]().asScala + + /** Remembers which map output locations are currently being fetched on an executor. */ + private val fetching = new HashSet[Int] + + override def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, endPartition: Int) + : Seq[(BlockManagerId, Seq[(BlockId, Long)])] = { + logDebug(s"Fetching outputs for shuffle $shuffleId, partitions $startPartition-$endPartition") + val statuses = getStatuses(shuffleId) + try { + MapOutputTracker.convertMapStatuses(shuffleId, startPartition, endPartition, statuses) + } catch { + case e: MetadataFetchFailedException => + // We experienced a fetch failure so our mapStatuses cache is outdated; clear it: + mapStatuses.clear() + throw e + } + } + + /** + * Get or fetch the array of MapStatuses for a given shuffle ID. NOTE: clients MUST synchronize + * on this array when reading it, because on the driver, we may be changing it in place. + * + * (It would be nice to remove this restriction in the future.) + */ + private def getStatuses(shuffleId: Int): Array[MapStatus] = { + val statuses = mapStatuses.get(shuffleId).orNull + if (statuses == null) { + logInfo("Don't have map outputs for shuffle " + shuffleId + ", fetching them") + val startTime = System.currentTimeMillis + var fetchedStatuses: Array[MapStatus] = null + fetching.synchronized { + // Someone else is fetching it; wait for them to be done + while (fetching.contains(shuffleId)) { + try { + fetching.wait() + } catch { + case e: InterruptedException => + } + } + + // Either while we waited the fetch happened successfully, or + // someone fetched it in between the get and the fetching.synchronized. + fetchedStatuses = mapStatuses.get(shuffleId).orNull + if (fetchedStatuses == null) { + // We have to do the fetch, get others to wait for us. + fetching += shuffleId + } + } + + if (fetchedStatuses == null) { + // We won the race to fetch the statuses; do so + logInfo("Doing the fetch; tracker endpoint = " + trackerEndpoint) + // This try-finally prevents hangs due to timeouts: + try { + val fetchedBytes = askTracker[Array[Byte]](GetMapOutputStatuses(shuffleId)) + fetchedStatuses = MapOutputTracker.deserializeMapStatuses(fetchedBytes) + logInfo("Got the output locations") + mapStatuses.put(shuffleId, fetchedStatuses) + } finally { + fetching.synchronized { + fetching -= shuffleId + fetching.notifyAll() + } + } + } + logDebug(s"Fetching map output statuses for shuffle $shuffleId took " + + s"${System.currentTimeMillis - startTime} ms") + + if (fetchedStatuses != null) { + fetchedStatuses + } else { + logError("Missing all output locations for shuffle " + shuffleId) + throw new MetadataFetchFailedException( + shuffleId, -1, "Missing all output locations for shuffle " + shuffleId) + } + } else { + statuses + } + } + + + /** Unregister shuffle data. */ + def unregisterShuffle(shuffleId: Int): Unit = { + mapStatuses.remove(shuffleId) + } + + /** + * Called from executors to update the epoch number, potentially clearing old outputs + * because of a fetch failure. Each executor task calls this with the latest epoch + * number on the driver at the time it was created. + */ + def updateEpoch(newEpoch: Long): Unit = { + epochLock.synchronized { + if (newEpoch > epoch) { + logInfo("Updating epoch to " + newEpoch + " and clearing cache") + epoch = newEpoch + mapStatuses.clear() + } + } + } } private[spark] object MapOutputTracker extends Logging { @@ -683,7 +767,7 @@ private[spark] object MapOutputTracker extends Logging { * and the second item is a sequence of (shuffle block ID, shuffle block size) tuples * describing the shuffle blocks that are stored at that block manager. */ - private def convertMapStatuses( + def convertMapStatuses( shuffleId: Int, startPartition: Int, endPartition: Int, diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 5b396687dd11a..19e7eb086f413 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -322,8 +322,14 @@ private[spark] class Executor( throw new TaskKilledException(killReason.get) } - logDebug("Task " + taskId + "'s epoch is " + task.epoch) - env.mapOutputTracker.updateEpoch(task.epoch) + // The purpose of updating the epoch here is to invalidate executor map output status cache + // in case FetchFailures have occurred. In local mode `env.mapOutputTracker` will be + // MapOutputTrackerMaster and its cache invalidation is not based on epoch numbers so + // we don't need to make any special calls here. + if (!isLocal) { + logDebug("Task " + taskId + "'s epoch is " + task.epoch) + env.mapOutputTracker.asInstanceOf[MapOutputTrackerWorker].updateEpoch(task.epoch) + } // Run the actual task and measure its runtime. taskStart = System.currentTimeMillis() diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index ab2255f8a6654..932e6c138e1c4 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -328,25 +328,14 @@ class DAGScheduler( val numTasks = rdd.partitions.length val parents = getOrCreateParentStages(rdd, jobId) val id = nextStageId.getAndIncrement() - val stage = new ShuffleMapStage(id, rdd, numTasks, parents, jobId, rdd.creationSite, shuffleDep) + val stage = new ShuffleMapStage( + id, rdd, numTasks, parents, jobId, rdd.creationSite, shuffleDep, mapOutputTracker) stageIdToStage(id) = stage shuffleIdToMapStage(shuffleDep.shuffleId) = stage updateJobIdStageIdMaps(jobId, stage) - if (mapOutputTracker.containsShuffle(shuffleDep.shuffleId)) { - // A previously run stage generated partitions for this shuffle, so for each output - // that's still available, copy information about that output location to the new stage - // (so we don't unnecessarily re-compute that data). - val serLocs = mapOutputTracker.getSerializedMapOutputStatuses(shuffleDep.shuffleId) - val locs = MapOutputTracker.deserializeMapStatuses(serLocs) - (0 until locs.length).foreach { i => - if (locs(i) ne null) { - // locs(i) will be null if missing - stage.addOutputLoc(i, locs(i)) - } - } - } else { + if (!mapOutputTracker.containsShuffle(shuffleDep.shuffleId)) { // Kind of ugly: need to register RDDs with the cache and map output tracker here // since we can't do it in the RDD constructor because # of partitions is unknown logInfo("Registering RDD " + rdd.id + " (" + rdd.getCreationSite + ")") @@ -1217,7 +1206,8 @@ class DAGScheduler( // The epoch of the task is acceptable (i.e., the task was launched after the most // recent failure we're aware of for the executor), so mark the task's output as // available. - shuffleStage.addOutputLoc(smt.partitionId, status) + mapOutputTracker.registerMapOutput( + shuffleStage.shuffleDep.shuffleId, smt.partitionId, status) // Remove the task's partition from pending partitions. This may have already been // done above, but will not have been done yet in cases where the task attempt was // from an earlier attempt of the stage (i.e., not the attempt that's currently @@ -1234,16 +1224,14 @@ class DAGScheduler( logInfo("waiting: " + waitingStages) logInfo("failed: " + failedStages) - // We supply true to increment the epoch number here in case this is a - // recomputation of the map outputs. In that case, some nodes may have cached - // locations with holes (from when we detected the error) and will need the - // epoch incremented to refetch them. - // TODO: Only increment the epoch number if this is not the first time - // we registered these map outputs. - mapOutputTracker.registerMapOutputs( - shuffleStage.shuffleDep.shuffleId, - shuffleStage.outputLocInMapOutputTrackerFormat(), - changeEpoch = true) + // This call to increment the epoch may not be strictly necessary, but it is retained + // for now in order to minimize the changes in behavior from an earlier version of the + // code. This existing behavior of always incrementing the epoch following any + // successful shuffle map stage completion may have benefits by causing unneeded + // cached map outputs to be cleaned up earlier on executors. In the future we can + // consider removing this call, but this will require some extra investigation. + // See https://github.com/apache/spark/pull/17955/files#r117385673 for more details. + mapOutputTracker.incrementEpoch() clearCacheLocs() @@ -1343,7 +1331,6 @@ class DAGScheduler( } // Mark the map whose fetch failed as broken in the map stage if (mapId != -1) { - mapStage.removeOutputLoc(mapId, bmAddress) mapOutputTracker.unregisterMapOutput(shuffleId, mapId, bmAddress) } @@ -1393,17 +1380,7 @@ class DAGScheduler( if (filesLost || !env.blockManager.externalShuffleServiceEnabled) { logInfo("Shuffle files lost for executor: %s (epoch %d)".format(execId, currentEpoch)) - // TODO: This will be really slow if we keep accumulating shuffle map stages - for ((shuffleId, stage) <- shuffleIdToMapStage) { - stage.removeOutputsOnExecutor(execId) - mapOutputTracker.registerMapOutputs( - shuffleId, - stage.outputLocInMapOutputTrackerFormat(), - changeEpoch = true) - } - if (shuffleIdToMapStage.isEmpty) { - mapOutputTracker.incrementEpoch() - } + mapOutputTracker.removeOutputsOnExecutor(execId) clearCacheLocs() } } else { diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala index db4d9efa2270c..05f650fbf5df9 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala @@ -19,9 +19,8 @@ package org.apache.spark.scheduler import scala.collection.mutable.HashSet -import org.apache.spark.ShuffleDependency +import org.apache.spark.{MapOutputTrackerMaster, ShuffleDependency, SparkEnv} import org.apache.spark.rdd.RDD -import org.apache.spark.storage.BlockManagerId import org.apache.spark.util.CallSite /** @@ -42,13 +41,12 @@ private[spark] class ShuffleMapStage( parents: List[Stage], firstJobId: Int, callSite: CallSite, - val shuffleDep: ShuffleDependency[_, _, _]) + val shuffleDep: ShuffleDependency[_, _, _], + mapOutputTrackerMaster: MapOutputTrackerMaster) extends Stage(id, rdd, numTasks, parents, firstJobId, callSite) { private[this] var _mapStageJobs: List[ActiveJob] = Nil - private[this] var _numAvailableOutputs: Int = 0 - /** * Partitions that either haven't yet been computed, or that were computed on an executor * that has since been lost, so should be re-computed. This variable is used by the @@ -60,13 +58,6 @@ private[spark] class ShuffleMapStage( */ val pendingPartitions = new HashSet[Int] - /** - * List of [[MapStatus]] for each partition. The index of the array is the map partition id, - * and each value in the array is the list of possible [[MapStatus]] for a partition - * (a single task might run multiple times). - */ - private[this] val outputLocs = Array.fill[List[MapStatus]](numPartitions)(Nil) - override def toString: String = "ShuffleMapStage " + id /** @@ -88,69 +79,18 @@ private[spark] class ShuffleMapStage( /** * Number of partitions that have shuffle outputs. * When this reaches [[numPartitions]], this map stage is ready. - * This should be kept consistent as `outputLocs.filter(!_.isEmpty).size`. */ - def numAvailableOutputs: Int = _numAvailableOutputs + def numAvailableOutputs: Int = mapOutputTrackerMaster.getNumAvailableOutputs(shuffleDep.shuffleId) /** * Returns true if the map stage is ready, i.e. all partitions have shuffle outputs. - * This should be the same as `outputLocs.contains(Nil)`. */ - def isAvailable: Boolean = _numAvailableOutputs == numPartitions + def isAvailable: Boolean = numAvailableOutputs == numPartitions /** Returns the sequence of partition ids that are missing (i.e. needs to be computed). */ override def findMissingPartitions(): Seq[Int] = { - val missing = (0 until numPartitions).filter(id => outputLocs(id).isEmpty) - assert(missing.size == numPartitions - _numAvailableOutputs, - s"${missing.size} missing, expected ${numPartitions - _numAvailableOutputs}") - missing - } - - def addOutputLoc(partition: Int, status: MapStatus): Unit = { - val prevList = outputLocs(partition) - outputLocs(partition) = status :: prevList - if (prevList == Nil) { - _numAvailableOutputs += 1 - } - } - - def removeOutputLoc(partition: Int, bmAddress: BlockManagerId): Unit = { - val prevList = outputLocs(partition) - val newList = prevList.filterNot(_.location == bmAddress) - outputLocs(partition) = newList - if (prevList != Nil && newList == Nil) { - _numAvailableOutputs -= 1 - } - } - - /** - * Returns an array of [[MapStatus]] (index by partition id). For each partition, the returned - * value contains only one (i.e. the first) [[MapStatus]]. If there is no entry for the partition, - * that position is filled with null. - */ - def outputLocInMapOutputTrackerFormat(): Array[MapStatus] = { - outputLocs.map(_.headOption.orNull) - } - - /** - * Removes all shuffle outputs associated with this executor. Note that this will also remove - * outputs which are served by an external shuffle server (if one exists), as they are still - * registered with this execId. - */ - def removeOutputsOnExecutor(execId: String): Unit = { - var becameUnavailable = false - for (partition <- 0 until numPartitions) { - val prevList = outputLocs(partition) - val newList = prevList.filterNot(_.location.executorId == execId) - outputLocs(partition) = newList - if (prevList != Nil && newList == Nil) { - becameUnavailable = true - _numAvailableOutputs -= 1 - } - } - if (becameUnavailable) { - logInfo("%s is now unavailable on executor %s (%d/%d, %s)".format( - this, execId, _numAvailableOutputs, numPartitions, isAvailable)) - } + mapOutputTrackerMaster + .findMissingPartitions(shuffleDep.shuffleId) + .getOrElse(0 until numPartitions) } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index f3033e28b47d0..629cfc7c7a8ce 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -129,7 +129,7 @@ private[spark] class TaskSchedulerImpl private[scheduler]( var backend: SchedulerBackend = null - val mapOutputTracker = SparkEnv.get.mapOutputTracker + val mapOutputTracker = SparkEnv.get.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster] private var schedulableBuilder: SchedulableBuilder = null // default scheduler is FIFO diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala index 4fe5c5e4fee4a..bc3d23e3fbb29 100644 --- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala @@ -139,21 +139,21 @@ class MapOutputTrackerSuite extends SparkFunSuite { slaveRpcEnv.setupEndpointRef(rpcEnv.address, MapOutputTracker.ENDPOINT_NAME) masterTracker.registerShuffle(10, 1) - masterTracker.incrementEpoch() slaveTracker.updateEpoch(masterTracker.getEpoch) + // This is expected to fail because no outputs have been registered for the shuffle. intercept[FetchFailedException] { slaveTracker.getMapSizesByExecutorId(10, 0) } val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L)) masterTracker.registerMapOutput(10, 0, MapStatus( BlockManagerId("a", "hostA", 1000), Array(1000L))) - masterTracker.incrementEpoch() slaveTracker.updateEpoch(masterTracker.getEpoch) assert(slaveTracker.getMapSizesByExecutorId(10, 0) === Seq((BlockManagerId("a", "hostA", 1000), ArrayBuffer((ShuffleBlockId(10, 0, 0), size1000))))) assert(0 == masterTracker.getNumCachedSerializedBroadcast) + val masterTrackerEpochBeforeLossOfMapOutput = masterTracker.getEpoch masterTracker.unregisterMapOutput(10, 0, BlockManagerId("a", "hostA", 1000)) - masterTracker.incrementEpoch() + assert(masterTracker.getEpoch > masterTrackerEpochBeforeLossOfMapOutput) slaveTracker.updateEpoch(masterTracker.getEpoch) intercept[FetchFailedException] { slaveTracker.getMapSizesByExecutorId(10, 0) } diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala index 622f7985ba444..3931d53b4ae0a 100644 --- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala +++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala @@ -359,6 +359,7 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC val shuffleMapRdd = new MyRDD(sc, 1, Nil) val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(1)) val shuffleHandle = manager.registerShuffle(0, 1, shuffleDep) + mapTrackerMaster.registerShuffle(0, 1) // first attempt -- its successful val writer1 = manager.getWriter[Int, Int](shuffleHandle, 0, @@ -393,7 +394,7 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC // register one of the map outputs -- doesn't matter which one mapOutput1.foreach { case mapStatus => - mapTrackerMaster.registerMapOutputs(0, Array(mapStatus)) + mapTrackerMaster.registerMapOutput(0, 0, mapStatus) } val reader = manager.getReader[Int, Int](shuffleHandle, 0, 1, diff --git a/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala index 2b18ebee79a2b..571c6bbb4585d 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala @@ -86,7 +86,8 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M sc = new SparkContext(conf) val scheduler = mock[TaskSchedulerImpl] when(scheduler.sc).thenReturn(sc) - when(scheduler.mapOutputTracker).thenReturn(SparkEnv.get.mapOutputTracker) + when(scheduler.mapOutputTracker).thenReturn( + SparkEnv.get.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster]) scheduler } From d1409180932f2658daad2c6dbf5d80fdf4606dc5 Mon Sep 17 00:00:00 2001 From: liuxian Date: Sun, 11 Jun 2017 22:29:09 -0700 Subject: [PATCH 117/133] [SPARK-20665][SQL][FOLLOW-UP] Move test case to MathExpressionsSuite ## What changes were proposed in this pull request? add test case to MathExpressionsSuite as #17906 ## How was this patch tested? unit test cases Author: liuxian Closes #18082 from 10110346/wip-lx-0524. --- .../expressions/MathExpressionsSuite.scala | 64 +++++++++++++++---- 1 file changed, 52 insertions(+), 12 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala index 6af0cde73538b..f4d5a4471d896 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala @@ -23,6 +23,7 @@ import com.google.common.math.LongMath import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.TypeCoercion.ImplicitTypeCasts.implicitCast import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection import org.apache.spark.sql.catalyst.optimizer.SimpleTestOptimizer @@ -223,6 +224,14 @@ class MathExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { def f: (Double) => Double = (x: Double) => 1 / math.tan(x) testUnary(Cot, f) checkConsistencyBetweenInterpretedAndCodegen(Cot, DoubleType) + val nullLit = Literal.create(null, NullType) + val intNullLit = Literal.create(null, IntegerType) + val intLit = Literal.create(1, IntegerType) + checkEvaluation(checkDataTypeAndCast(Cot(nullLit)), null, EmptyRow) + checkEvaluation(checkDataTypeAndCast(Cot(intNullLit)), null, EmptyRow) + checkEvaluation(checkDataTypeAndCast(Cot(intLit)), 1 / math.tan(1), EmptyRow) + checkEvaluation(checkDataTypeAndCast(Cot(-intLit)), 1 / math.tan(-1), EmptyRow) + checkEvaluation(checkDataTypeAndCast(Cot(0)), 1 / math.tan(0), EmptyRow) } test("atan") { @@ -250,6 +259,11 @@ class MathExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkConsistencyBetweenInterpretedAndCodegen(Cbrt, DoubleType) } + def checkDataTypeAndCast(expression: UnaryMathExpression): Expression = { + val expNew = implicitCast(expression.child, expression.inputTypes(0)).getOrElse(expression) + expression.withNewChildren(Seq(expNew)) + } + test("ceil") { testUnary(Ceil, (d: Double) => math.ceil(d).toLong) checkConsistencyBetweenInterpretedAndCodegen(Ceil, DoubleType) @@ -262,12 +276,22 @@ class MathExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { val doublePi: Double = 3.1415 val floatPi: Float = 3.1415f val longLit: Long = 12345678901234567L - checkEvaluation(Ceil(doublePi), 4L, EmptyRow) - checkEvaluation(Ceil(floatPi.toDouble), 4L, EmptyRow) - checkEvaluation(Ceil(longLit), longLit, EmptyRow) - checkEvaluation(Ceil(-doublePi), -3L, EmptyRow) - checkEvaluation(Ceil(-floatPi.toDouble), -3L, EmptyRow) - checkEvaluation(Ceil(-longLit), -longLit, EmptyRow) + val nullLit = Literal.create(null, NullType) + val floatNullLit = Literal.create(null, FloatType) + checkEvaluation(checkDataTypeAndCast(Ceil(doublePi)), 4L, EmptyRow) + checkEvaluation(checkDataTypeAndCast(Ceil(floatPi)), 4L, EmptyRow) + checkEvaluation(checkDataTypeAndCast(Ceil(longLit)), longLit, EmptyRow) + checkEvaluation(checkDataTypeAndCast(Ceil(-doublePi)), -3L, EmptyRow) + checkEvaluation(checkDataTypeAndCast(Ceil(-floatPi)), -3L, EmptyRow) + checkEvaluation(checkDataTypeAndCast(Ceil(-longLit)), -longLit, EmptyRow) + + checkEvaluation(checkDataTypeAndCast(Ceil(nullLit)), null, EmptyRow) + checkEvaluation(checkDataTypeAndCast(Ceil(floatNullLit)), null, EmptyRow) + checkEvaluation(checkDataTypeAndCast(Ceil(0)), 0L, EmptyRow) + checkEvaluation(checkDataTypeAndCast(Ceil(1)), 1L, EmptyRow) + checkEvaluation(checkDataTypeAndCast(Ceil(1234567890123456L)), 1234567890123456L, EmptyRow) + checkEvaluation(checkDataTypeAndCast(Ceil(0.01)), 1L, EmptyRow) + checkEvaluation(checkDataTypeAndCast(Ceil(-0.10)), 0L, EmptyRow) } test("floor") { @@ -282,12 +306,22 @@ class MathExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { val doublePi: Double = 3.1415 val floatPi: Float = 3.1415f val longLit: Long = 12345678901234567L - checkEvaluation(Floor(doublePi), 3L, EmptyRow) - checkEvaluation(Floor(floatPi.toDouble), 3L, EmptyRow) - checkEvaluation(Floor(longLit), longLit, EmptyRow) - checkEvaluation(Floor(-doublePi), -4L, EmptyRow) - checkEvaluation(Floor(-floatPi.toDouble), -4L, EmptyRow) - checkEvaluation(Floor(-longLit), -longLit, EmptyRow) + val nullLit = Literal.create(null, NullType) + val floatNullLit = Literal.create(null, FloatType) + checkEvaluation(checkDataTypeAndCast(Floor(doublePi)), 3L, EmptyRow) + checkEvaluation(checkDataTypeAndCast(Floor(floatPi)), 3L, EmptyRow) + checkEvaluation(checkDataTypeAndCast(Floor(longLit)), longLit, EmptyRow) + checkEvaluation(checkDataTypeAndCast(Floor(-doublePi)), -4L, EmptyRow) + checkEvaluation(checkDataTypeAndCast(Floor(-floatPi)), -4L, EmptyRow) + checkEvaluation(checkDataTypeAndCast(Floor(-longLit)), -longLit, EmptyRow) + + checkEvaluation(checkDataTypeAndCast(Floor(nullLit)), null, EmptyRow) + checkEvaluation(checkDataTypeAndCast(Floor(floatNullLit)), null, EmptyRow) + checkEvaluation(checkDataTypeAndCast(Floor(0)), 0L, EmptyRow) + checkEvaluation(checkDataTypeAndCast(Floor(1)), 1L, EmptyRow) + checkEvaluation(checkDataTypeAndCast(Floor(1234567890123456L)), 1234567890123456L, EmptyRow) + checkEvaluation(checkDataTypeAndCast(Floor(0.01)), 0L, EmptyRow) + checkEvaluation(checkDataTypeAndCast(Floor(-0.10)), -1L, EmptyRow) } test("factorial") { @@ -541,10 +575,14 @@ class MathExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { val intPi: Int = 314159265 val longPi: Long = 31415926535897932L val bdPi: BigDecimal = BigDecimal(31415927L, 7) + val floatPi: Float = 3.1415f val doubleResults: Seq[Double] = Seq(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 3.0, 3.1, 3.14, 3.142, 3.1416, 3.14159, 3.141593) + val floatResults: Seq[Float] = Seq(0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 3.0f, 3.1f, 3.14f, + 3.141f, 3.1415f, 3.1415f, 3.1415f) + val shortResults: Seq[Short] = Seq[Short](0, 0, 30000, 31000, 31400, 31420) ++ Seq.fill[Short](7)(31415) @@ -563,10 +601,12 @@ class MathExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Round(shortPi, scale), shortResults(i), EmptyRow) checkEvaluation(Round(intPi, scale), intResults(i), EmptyRow) checkEvaluation(Round(longPi, scale), longResults(i), EmptyRow) + checkEvaluation(Round(floatPi, scale), floatResults(i), EmptyRow) checkEvaluation(BRound(doublePi, scale), doubleResults(i), EmptyRow) checkEvaluation(BRound(shortPi, scale), shortResults(i), EmptyRow) checkEvaluation(BRound(intPi, scale), intResultsB(i), EmptyRow) checkEvaluation(BRound(longPi, scale), longResults(i), EmptyRow) + checkEvaluation(BRound(floatPi, scale), floatResults(i), EmptyRow) } val bdResults: Seq[BigDecimal] = Seq(BigDecimal(3.0), BigDecimal(3.1), BigDecimal(3.14), From e6eb02df1540764ef2a4f0edb45c48df8de18c13 Mon Sep 17 00:00:00 2001 From: Ziyue Huang Date: Mon, 12 Jun 2017 10:59:33 +0100 Subject: [PATCH 118/133] [DOCS] Fix error: ambiguous reference to overloaded definition ## What changes were proposed in this pull request? `df.groupBy.count()` should be `df.groupBy().count()` , otherwise there is an error : ambiguous reference to overloaded definition, both method groupBy in class Dataset of type (col1: String, cols: String*) and method groupBy in class Dataset of type (cols: org.apache.spark.sql.Column*) ## How was this patch tested? ```scala val df = spark.readStream.schema(...).json(...) val dfCounts = df.groupBy().count() ``` Author: Ziyue Huang Closes #18272 from ZiyueHuang/master. --- docs/structured-streaming-programming-guide.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/structured-streaming-programming-guide.md b/docs/structured-streaming-programming-guide.md index 6a25c9939c264..9b9177d44145f 100644 --- a/docs/structured-streaming-programming-guide.md +++ b/docs/structured-streaming-programming-guide.md @@ -1056,7 +1056,7 @@ Some of them are as follows. In addition, there are some Dataset methods that will not work on streaming Datasets. They are actions that will immediately run queries and return results, which does not make sense on a streaming Dataset. Rather, those functionalities can be done by explicitly starting a streaming query (see the next section regarding that). -- `count()` - Cannot return a single count from a streaming Dataset. Instead, use `ds.groupBy.count()` which returns a streaming Dataset containing a running count. +- `count()` - Cannot return a single count from a streaming Dataset. Instead, use `ds.groupBy().count()` which returns a streaming Dataset containing a running count. - `foreach()` - Instead use `ds.writeStream.foreach(...)` (see next section). From a92e095e705155ea10c8311f7856b964d654626a Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Mon, 12 Jun 2017 20:58:27 +0800 Subject: [PATCH 119/133] [SPARK-21041][SQL] SparkSession.range should be consistent with SparkContext.range ## What changes were proposed in this pull request? This PR fixes the inconsistency in `SparkSession.range`. **BEFORE** ```scala scala> spark.range(java.lang.Long.MAX_VALUE - 3, java.lang.Long.MIN_VALUE + 2, 1).collect res2: Array[Long] = Array(9223372036854775804, 9223372036854775805, 9223372036854775806) ``` **AFTER** ```scala scala> spark.range(java.lang.Long.MAX_VALUE - 3, java.lang.Long.MIN_VALUE + 2, 1).collect res2: Array[Long] = Array() ``` ## How was this patch tested? Pass the Jenkins with newly added test cases. Author: Dongjoon Hyun Closes #18257 from dongjoon-hyun/SPARK-21041. --- .../spark/sql/execution/basicPhysicalOperators.scala | 10 +++++++--- .../org/apache/spark/sql/DataFrameRangeSuite.scala | 11 +++++++++++ 2 files changed, 18 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala index f69a688555bbf..04c130314388a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala @@ -21,7 +21,7 @@ import scala.concurrent.{ExecutionContext, Future} import scala.concurrent.duration.Duration import org.apache.spark.{InterruptibleIterator, SparkException, TaskContext} -import org.apache.spark.rdd.{PartitionwiseSampledRDD, RDD} +import org.apache.spark.rdd.{EmptyRDD, PartitionwiseSampledRDD, RDD} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, ExpressionCanonicalizer} @@ -347,8 +347,12 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range) } override def inputRDDs(): Seq[RDD[InternalRow]] = { - sqlContext.sparkContext.parallelize(0 until numSlices, numSlices) - .map(i => InternalRow(i)) :: Nil + val rdd = if (start == end || (start < end ^ 0 < step)) { + new EmptyRDD[InternalRow](sqlContext.sparkContext) + } else { + sqlContext.sparkContext.parallelize(0 until numSlices, numSlices).map(i => InternalRow(i)) + } + rdd :: Nil } protected override def doProduce(ctx: CodegenContext): String = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala index 7b495656b93d7..45afbd29d1907 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala @@ -191,6 +191,17 @@ class DataFrameRangeSuite extends QueryTest with SharedSQLContext with Eventuall checkAnswer(sql("SELECT * FROM range(3)"), Row(0) :: Row(1) :: Row(2) :: Nil) } } + + test("SPARK-21041 SparkSession.range()'s behavior is inconsistent with SparkContext.range()") { + val start = java.lang.Long.MAX_VALUE - 3 + val end = java.lang.Long.MIN_VALUE + 2 + Seq("false", "true").foreach { value => + withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> value) { + assert(spark.range(start, end, 1).collect.length == 0) + assert(spark.range(start, start, 1).collect.length == 0) + } + } + } } object DataFrameRangeSuite { From 22dd65f58e12cb3a883d106fcccdff25a2a00fe8 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 13 Jun 2017 00:12:34 +0800 Subject: [PATCH 120/133] [SPARK-21046][SQL] simplify the array offset and length in ColumnVector ## What changes were proposed in this pull request? Currently when a `ColumnVector` stores array type elements, we will use 2 arrays for lengths and offsets and implement them individually in on-heap and off-heap column vector. In this PR, we use one array to represent both offsets and lengths, so that we can treat it as `ColumnVector` and all the logic can go to the base class `ColumnVector` ## How was this patch tested? existing tests. Author: Wenchen Fan Closes #18260 from cloud-fan/put. --- .../execution/vectorized/ColumnVector.java | 35 ++++++------- .../vectorized/OffHeapColumnVector.java | 47 ++---------------- .../vectorized/OnHeapColumnVector.java | 49 +++---------------- .../vectorized/ColumnarBatchSuite.scala | 17 ++++--- 4 files changed, 38 insertions(+), 110 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java index 24260a60197f2..e50799eeb27ba 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java @@ -14,6 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.apache.spark.sql.execution.vectorized; import java.math.BigDecimal; @@ -518,19 +519,13 @@ private void throwUnsupportedException(int requiredCapacity, Throwable cause) { public abstract double getDouble(int rowId); /** - * Puts a byte array that already exists in this column. - */ - public abstract void putArray(int rowId, int offset, int length); - - /** - * Returns the length of the array at rowid. + * After writing array elements to the child column vector, call this method to set the offset and + * size of the written array. */ - public abstract int getArrayLength(int rowId); - - /** - * Returns the offset of the array at rowid. - */ - public abstract int getArrayOffset(int rowId); + public void putArrayOffsetAndSize(int rowId, int offset, int size) { + long offsetAndSize = (((long) offset) << 32) | size; + putLong(rowId, offsetAndSize); + } /** * Returns a utility object to get structs. @@ -553,8 +548,9 @@ public ColumnarBatch.Row getStruct(int rowId, int size) { * Returns the array at rowid. */ public final Array getArray(int rowId) { - resultArray.length = getArrayLength(rowId); - resultArray.offset = getArrayOffset(rowId); + long offsetAndSize = getLong(rowId); + resultArray.offset = (int) (offsetAndSize >> 32); + resultArray.length = (int) offsetAndSize; return resultArray; } @@ -566,7 +562,12 @@ public final Array getArray(int rowId) { /** * Sets the value at rowId to `value`. */ - public abstract int putByteArray(int rowId, byte[] value, int offset, int count); + public int putByteArray(int rowId, byte[] value, int offset, int length) { + int result = arrayData().appendBytes(length, value, offset); + putArrayOffsetAndSize(rowId, result, length); + return result; + } + public final int putByteArray(int rowId, byte[] value) { return putByteArray(rowId, value, 0, value.length); } @@ -829,13 +830,13 @@ public final int appendDoubles(int length, double[] src, int offset) { public final int appendByteArray(byte[] value, int offset, int length) { int copiedOffset = arrayData().appendBytes(length, value, offset); reserve(elementsAppended + 1); - putArray(elementsAppended, copiedOffset, length); + putArrayOffsetAndSize(elementsAppended, copiedOffset, length); return elementsAppended++; } public final int appendArray(int length) { reserve(elementsAppended + 1); - putArray(elementsAppended, arrayData().elementsAppended, length); + putArrayOffsetAndSize(elementsAppended, arrayData().elementsAppended, length); return elementsAppended++; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java index a7d3744d00e91..4dc4d34db37fb 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java @@ -34,19 +34,15 @@ public final class OffHeapColumnVector extends ColumnVector { // The data stored in these two allocations need to maintain binary compatible. We can // directly pass this buffer to external components. private long nulls; + // The actually data of this column vector will be stored here. If it's an array column vector, + // we will store the offsets and lengths here, and store the element data in child column vector. private long data; - // Set iff the type is array. - private long lengthData; - private long offsetData; - protected OffHeapColumnVector(int capacity, DataType type) { super(capacity, type, MemoryMode.OFF_HEAP); nulls = 0; data = 0; - lengthData = 0; - offsetData = 0; reserveInternal(capacity); reset(); @@ -66,12 +62,8 @@ public long nullsNativeAddress() { public void close() { Platform.freeMemory(nulls); Platform.freeMemory(data); - Platform.freeMemory(lengthData); - Platform.freeMemory(offsetData); nulls = 0; data = 0; - lengthData = 0; - offsetData = 0; } // @@ -395,35 +387,6 @@ public double getDouble(int rowId) { } } - // - // APIs dealing with Arrays. - // - @Override - public void putArray(int rowId, int offset, int length) { - assert(offset >= 0 && offset + length <= childColumns[0].capacity); - Platform.putInt(null, lengthData + 4 * rowId, length); - Platform.putInt(null, offsetData + 4 * rowId, offset); - } - - @Override - public int getArrayLength(int rowId) { - return Platform.getInt(null, lengthData + 4 * rowId); - } - - @Override - public int getArrayOffset(int rowId) { - return Platform.getInt(null, offsetData + 4 * rowId); - } - - // APIs dealing with ByteArrays - @Override - public int putByteArray(int rowId, byte[] value, int offset, int length) { - int result = arrayData().appendBytes(length, value, offset); - Platform.putInt(null, lengthData + 4 * rowId, length); - Platform.putInt(null, offsetData + 4 * rowId, result); - return result; - } - @Override public void loadBytes(ColumnVector.Array array) { if (array.tmpByteArray.length < array.length) array.tmpByteArray = new byte[array.length]; @@ -438,10 +401,8 @@ public void loadBytes(ColumnVector.Array array) { protected void reserveInternal(int newCapacity) { int oldCapacity = (this.data == 0L) ? 0 : capacity; if (this.resultArray != null) { - this.lengthData = - Platform.reallocateMemory(lengthData, oldCapacity * 4, newCapacity * 4); - this.offsetData = - Platform.reallocateMemory(offsetData, oldCapacity * 4, newCapacity * 4); + // need a long as offset and length for each array. + this.data = Platform.reallocateMemory(data, oldCapacity * 8, newCapacity * 8); } else if (type instanceof ByteType || type instanceof BooleanType) { this.data = Platform.reallocateMemory(data, oldCapacity, newCapacity); } else if (type instanceof ShortType) { diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java index 94ed32294cfae..4d23405dc7b17 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java @@ -43,14 +43,12 @@ public final class OnHeapColumnVector extends ColumnVector { private byte[] byteData; private short[] shortData; private int[] intData; + // This is not only used to store data for long column vector, but also can store offsets and + // lengths for array column vector. private long[] longData; private float[] floatData; private double[] doubleData; - // Only set if type is Array. - private int[] arrayLengths; - private int[] arrayOffsets; - protected OnHeapColumnVector(int capacity, DataType type) { super(capacity, type, MemoryMode.ON_HEAP); reserveInternal(capacity); @@ -366,55 +364,22 @@ public double getDouble(int rowId) { } } - // - // APIs dealing with Arrays - // - - @Override - public int getArrayLength(int rowId) { - return arrayLengths[rowId]; - } - @Override - public int getArrayOffset(int rowId) { - return arrayOffsets[rowId]; - } - - @Override - public void putArray(int rowId, int offset, int length) { - arrayOffsets[rowId] = offset; - arrayLengths[rowId] = length; - } - @Override public void loadBytes(ColumnVector.Array array) { array.byteArray = byteData; array.byteArrayOffset = array.offset; } - // - // APIs dealing with Byte Arrays - // - - @Override - public int putByteArray(int rowId, byte[] value, int offset, int length) { - int result = arrayData().appendBytes(length, value, offset); - arrayOffsets[rowId] = result; - arrayLengths[rowId] = length; - return result; - } - // Spilt this function out since it is the slow path. @Override protected void reserveInternal(int newCapacity) { if (this.resultArray != null || DecimalType.isByteArrayDecimalType(type)) { - int[] newLengths = new int[newCapacity]; - int[] newOffsets = new int[newCapacity]; - if (this.arrayLengths != null) { - System.arraycopy(this.arrayLengths, 0, newLengths, 0, capacity); - System.arraycopy(this.arrayOffsets, 0, newOffsets, 0, capacity); + // need 1 long as offset and length for each array. + if (longData == null || longData.length < newCapacity) { + long[] newData = new long[newCapacity]; + if (longData != null) System.arraycopy(longData, 0, newData, 0, capacity); + longData = newData; } - arrayLengths = newLengths; - arrayOffsets = newOffsets; } else if (type instanceof BooleanType) { if (byteData == null || byteData.length < newCapacity) { byte[] newData = new byte[newCapacity]; diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala index e48e3f6402901..5c4128a70dd86 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala @@ -631,7 +631,7 @@ class ColumnarBatchSuite extends SparkFunSuite { assert(column.arrayData().elementsAppended == 17) // Put the same "ll" at offset. This should not allocate more memory in the column. - column.putArray(idx, offset, 2) + column.putArrayOffsetAndSize(idx, offset, 2) reference += "ll" idx += 1 assert(column.arrayData().elementsAppended == 17) @@ -644,7 +644,8 @@ class ColumnarBatchSuite extends SparkFunSuite { assert(column.arrayData().elementsAppended == 17 + (s + s).length) reference.zipWithIndex.foreach { v => - assert(v._1.length == column.getArrayLength(v._2), "MemoryMode=" + memMode) + val offsetAndLength = column.getLong(v._2) + assert(v._1.length == offsetAndLength.toInt, "MemoryMode=" + memMode) assert(v._1 == column.getUTF8String(v._2).toString, "MemoryMode" + memMode) } @@ -659,7 +660,7 @@ class ColumnarBatchSuite extends SparkFunSuite { val column = ColumnVector.allocate(10, new ArrayType(IntegerType, true), memMode) // Fill the underlying data with all the arrays back to back. - val data = column.arrayData(); + val data = column.arrayData() var i = 0 while (i < 6) { data.putInt(i, i) @@ -667,10 +668,10 @@ class ColumnarBatchSuite extends SparkFunSuite { } // Populate it with arrays [0], [1, 2], [], [3, 4, 5] - column.putArray(0, 0, 1) - column.putArray(1, 1, 2) - column.putArray(2, 2, 0) - column.putArray(3, 3, 3) + column.putArrayOffsetAndSize(0, 0, 1) + column.putArrayOffsetAndSize(1, 1, 2) + column.putArrayOffsetAndSize(2, 3, 0) + column.putArrayOffsetAndSize(3, 3, 3) val a1 = ColumnVectorUtils.toPrimitiveJavaArray(column.getArray(0)).asInstanceOf[Array[Int]] val a2 = ColumnVectorUtils.toPrimitiveJavaArray(column.getArray(1)).asInstanceOf[Array[Int]] @@ -703,7 +704,7 @@ class ColumnarBatchSuite extends SparkFunSuite { data.reserve(array.length) assert(data.capacity == array.length * 2) data.putInts(0, array.length, array, 0) - column.putArray(0, 0, array.length) + column.putArrayOffsetAndSize(0, 0, array.length) assert(ColumnVectorUtils.toPrimitiveJavaArray(column.getArray(0)).asInstanceOf[Array[Int]] === array) }} From ca4e960aec1a5f8cdc1e1344a25840d2670de391 Mon Sep 17 00:00:00 2001 From: aokolnychyi Date: Mon, 12 Jun 2017 13:06:14 -0700 Subject: [PATCH 121/133] [SPARK-17914][SQL] Fix parsing of timestamp strings with nanoseconds The PR contains a tiny change to fix the way Spark parses string literals into timestamps. Currently, some timestamps that contain nanoseconds are corrupted during the conversion from internal UTF8Strings into the internal representation of timestamps. Consider the following example: ``` spark.sql("SELECT cast('2015-01-02 00:00:00.000000001' as TIMESTAMP)").show(false) +------------------------------------------------+ |CAST(2015-01-02 00:00:00.000000001 AS TIMESTAMP)| +------------------------------------------------+ |2015-01-02 00:00:00.000001 | +------------------------------------------------+ ``` The fix was tested with existing tests. Also, there is a new test to cover cases that did not work previously. Author: aokolnychyi Closes #18252 from aokolnychyi/spark-17914. --- .../spark/sql/catalyst/util/DateTimeUtils.scala | 13 +++++++------ .../sql/catalyst/util/DateTimeUtilsSuite.scala | 16 ++++++++++++++++ 2 files changed, 23 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala index efb42292634ad..746c3e8950f7b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala @@ -32,7 +32,7 @@ import org.apache.spark.unsafe.types.UTF8String * Helper functions for converting between internal and external date and time representations. * Dates are exposed externally as java.sql.Date and are represented internally as the number of * dates since the Unix epoch (1970-01-01). Timestamps are exposed externally as java.sql.Timestamp - * and are stored internally as longs, which are capable of storing timestamps with 100 nanosecond + * and are stored internally as longs, which are capable of storing timestamps with microsecond * precision. */ object DateTimeUtils { @@ -399,13 +399,14 @@ object DateTimeUtils { digitsMilli += 1 } - if (!justTime && isInvalidDate(segments(0), segments(1), segments(2))) { - return None + // We are truncating the nanosecond part, which results in loss of precision + while (digitsMilli > 6) { + segments(6) /= 10 + digitsMilli -= 1 } - // Instead of return None, we truncate the fractional seconds to prevent inserting NULL - if (segments(6) > 999999) { - segments(6) = segments(6).toString.take(6).toInt + if (!justTime && isInvalidDate(segments(0), segments(1), segments(2))) { + return None } if (segments(3) < 0 || segments(3) > 23 || segments(4) < 0 || segments(4) > 59 || diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala index 9799817494f15..c8cf16d937352 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala @@ -34,6 +34,22 @@ class DateTimeUtilsSuite extends SparkFunSuite { ((timestamp + tz.getOffset(timestamp)) / MILLIS_PER_DAY).toInt } + test("nanoseconds truncation") { + def checkStringToTimestamp(originalTime: String, expectedParsedTime: String) { + val parsedTimestampOp = DateTimeUtils.stringToTimestamp(UTF8String.fromString(originalTime)) + assert(parsedTimestampOp.isDefined, "timestamp with nanoseconds was not parsed correctly") + assert(DateTimeUtils.timestampToString(parsedTimestampOp.get) === expectedParsedTime) + } + + checkStringToTimestamp("2015-01-02 00:00:00.123456789", "2015-01-02 00:00:00.123456") + checkStringToTimestamp("2015-01-02 00:00:00.100000009", "2015-01-02 00:00:00.1") + checkStringToTimestamp("2015-01-02 00:00:00.000050000", "2015-01-02 00:00:00.00005") + checkStringToTimestamp("2015-01-02 00:00:00.12005", "2015-01-02 00:00:00.12005") + checkStringToTimestamp("2015-01-02 00:00:00.100", "2015-01-02 00:00:00.1") + checkStringToTimestamp("2015-01-02 00:00:00.000456789", "2015-01-02 00:00:00.000456") + checkStringToTimestamp("1950-01-02 00:00:00.000456789", "1950-01-02 00:00:00.000456") + } + test("timestamp and us") { val now = new Timestamp(System.currentTimeMillis()) now.setNanos(1000) From 32818d9b378f98556cff529af8645382cd5c6d16 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Mon, 12 Jun 2017 14:05:03 -0700 Subject: [PATCH 122/133] [SPARK-20345][SQL] Fix STS error handling logic on HiveSQLException ## What changes were proposed in this pull request? [SPARK-5100](https://github.com/apache/spark/commit/343d3bfafd449a0371feb6a88f78e07302fa7143) added Spark Thrift Server(STS) UI and the following logic to handle exceptions on case `Throwable`. ```scala HiveThriftServer2.listener.onStatementError( statementId, e.getMessage, SparkUtils.exceptionString(e)) ``` However, there occurred a missed case after implementing [SPARK-6964](https://github.com/apache/spark/commit/eb19d3f75cbd002f7e72ce02017a8de67f562792)'s `Support Cancellation in the Thrift Server` by adding case `HiveSQLException` before case `Throwable`. ```scala case e: HiveSQLException => if (getStatus().getState() == OperationState.CANCELED) { return } else { setState(OperationState.ERROR) throw e } // Actually do need to catch Throwable as some failures don't inherit from Exception and // HiveServer will silently swallow them. case e: Throwable => val currentState = getStatus().getState() logError(s"Error executing query, currentState $currentState, ", e) setState(OperationState.ERROR) HiveThriftServer2.listener.onStatementError( statementId, e.getMessage, SparkUtils.exceptionString(e)) throw new HiveSQLException(e.toString) ``` Logically, we had better add `HiveThriftServer2.listener.onStatementError` on case `HiveSQLException`, too. ## How was this patch tested? N/A Author: Dongjoon Hyun Closes #17643 from dongjoon-hyun/SPARK-20345. --- .../sql/hive/thriftserver/SparkExecuteStatementOperation.scala | 2 ++ 1 file changed, 2 insertions(+) diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala index ff3784cab9e26..1d1074a2a7387 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala @@ -253,6 +253,8 @@ private[hive] class SparkExecuteStatementOperation( return } else { setState(OperationState.ERROR) + HiveThriftServer2.listener.onStatementError( + statementId, e.getMessage, SparkUtils.exceptionString(e)) throw e } // Actually do need to catch Throwable as some failures don't inherit from Exception and From b1436c7496161da1f4fa6950a06de62c9007c40c Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Mon, 12 Jun 2017 14:07:51 -0700 Subject: [PATCH 123/133] [SPARK-21059][SQL] LikeSimplification can NPE on null pattern ## What changes were proposed in this pull request? This patch fixes a bug that can cause NullPointerException in LikeSimplification, when the pattern for like is null. ## How was this patch tested? Added a new unit test case in LikeSimplificationSuite. Author: Reynold Xin Closes #18273 from rxin/SPARK-21059. --- .../sql/catalyst/optimizer/expressions.scala | 37 +++++++++++-------- .../optimizer/LikeSimplificationSuite.scala | 8 +++- 2 files changed, 28 insertions(+), 17 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 51f749a8bf857..66b8ca62e5e4c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -383,22 +383,27 @@ object LikeSimplification extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { case Like(input, Literal(pattern, StringType)) => - pattern.toString match { - case startsWith(prefix) if !prefix.endsWith("\\") => - StartsWith(input, Literal(prefix)) - case endsWith(postfix) => - EndsWith(input, Literal(postfix)) - // 'a%a' pattern is basically same with 'a%' && '%a'. - // However, the additional `Length` condition is required to prevent 'a' match 'a%a'. - case startsAndEndsWith(prefix, postfix) if !prefix.endsWith("\\") => - And(GreaterThanOrEqual(Length(input), Literal(prefix.size + postfix.size)), - And(StartsWith(input, Literal(prefix)), EndsWith(input, Literal(postfix)))) - case contains(infix) if !infix.endsWith("\\") => - Contains(input, Literal(infix)) - case equalTo(str) => - EqualTo(input, Literal(str)) - case _ => - Like(input, Literal.create(pattern, StringType)) + if (pattern == null) { + // If pattern is null, return null value directly, since "col like null" == null. + Literal(null, BooleanType) + } else { + pattern.toString match { + case startsWith(prefix) if !prefix.endsWith("\\") => + StartsWith(input, Literal(prefix)) + case endsWith(postfix) => + EndsWith(input, Literal(postfix)) + // 'a%a' pattern is basically same with 'a%' && '%a'. + // However, the additional `Length` condition is required to prevent 'a' match 'a%a'. + case startsAndEndsWith(prefix, postfix) if !prefix.endsWith("\\") => + And(GreaterThanOrEqual(Length(input), Literal(prefix.length + postfix.length)), + And(StartsWith(input, Literal(prefix)), EndsWith(input, Literal(postfix)))) + case contains(infix) if !infix.endsWith("\\") => + Contains(input, Literal(infix)) + case equalTo(str) => + EqualTo(input, Literal(str)) + case _ => + Like(input, Literal.create(pattern, StringType)) + } } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LikeSimplificationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LikeSimplificationSuite.scala index fdde89d079bc0..50398788c605c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LikeSimplificationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LikeSimplificationSuite.scala @@ -17,13 +17,13 @@ package org.apache.spark.sql.catalyst.optimizer -/* Implicit conversions */ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.rules._ +import org.apache.spark.sql.types.{BooleanType, StringType} class LikeSimplificationSuite extends PlanTest { @@ -100,4 +100,10 @@ class LikeSimplificationSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + + test("null pattern") { + val originalQuery = testRelation.where('a like Literal(null, StringType)).analyze + val optimized = Optimize.execute(originalQuery) + comparePlans(optimized, testRelation.where(Literal(null, BooleanType)).analyze) + } } From ff318c0d2f283c3f46491f229f82d93714da40c7 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Mon, 12 Jun 2017 14:27:57 -0700 Subject: [PATCH 124/133] [SPARK-21050][ML] Word2vec persistence overflow bug fix ## What changes were proposed in this pull request? The method calculateNumberOfPartitions() uses Int, not Long (unlike the MLlib version), so it is very easily to have an overflow in calculating the number of partitions for ML persistence. This modifies the calculations to use Long. ## How was this patch tested? New unit test. I verified that the test fails before this patch. Author: Joseph K. Bradley Closes #18265 from jkbradley/word2vec-save-fix. --- .../apache/spark/ml/feature/Word2Vec.scala | 38 ++++++++++++++----- .../spark/ml/feature/Word2VecSuite.scala | 10 +++++ 2 files changed, 38 insertions(+), 10 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala index 4ca062c0b5adf..b6909b3386b71 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala @@ -19,6 +19,7 @@ package org.apache.spark.ml.feature import org.apache.hadoop.fs.Path +import org.apache.spark.SparkContext import org.apache.spark.annotation.Since import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.linalg.{BLAS, Vector, Vectors, VectorUDT} @@ -339,25 +340,42 @@ object Word2VecModel extends MLReadable[Word2VecModel] { val wordVectors = instance.wordVectors.getVectors val dataSeq = wordVectors.toSeq.map { case (word, vector) => Data(word, vector) } val dataPath = new Path(path, "data").toString + val bufferSizeInBytes = Utils.byteStringAsBytes( + sc.conf.get("spark.kryoserializer.buffer.max", "64m")) + val numPartitions = Word2VecModelWriter.calculateNumberOfPartitions( + bufferSizeInBytes, instance.wordVectors.wordIndex.size, instance.getVectorSize) sparkSession.createDataFrame(dataSeq) - .repartition(calculateNumberOfPartitions) + .repartition(numPartitions) .write .parquet(dataPath) } + } - def calculateNumberOfPartitions(): Int = { - val floatSize = 4 + private[feature] + object Word2VecModelWriter { + /** + * Calculate the number of partitions to use in saving the model. + * [SPARK-11994] - We want to partition the model in partitions smaller than + * spark.kryoserializer.buffer.max + * @param bufferSizeInBytes Set to spark.kryoserializer.buffer.max + * @param numWords Vocab size + * @param vectorSize Vector length for each word + */ + def calculateNumberOfPartitions( + bufferSizeInBytes: Long, + numWords: Int, + vectorSize: Int): Int = { + val floatSize = 4L // Use Long to help avoid overflow val averageWordSize = 15 - // [SPARK-11994] - We want to partition the model in partitions smaller than - // spark.kryoserializer.buffer.max - val bufferSizeInBytes = Utils.byteStringAsBytes( - sc.conf.get("spark.kryoserializer.buffer.max", "64m")) // Calculate the approximate size of the model. // Assuming an average word size of 15 bytes, the formula is: // (floatSize * vectorSize + 15) * numWords - val numWords = instance.wordVectors.wordIndex.size - val approximateSizeInBytes = (floatSize * instance.getVectorSize + averageWordSize) * numWords - ((approximateSizeInBytes / bufferSizeInBytes) + 1).toInt + val approximateSizeInBytes = (floatSize * vectorSize + averageWordSize) * numWords + val numPartitions = (approximateSizeInBytes / bufferSizeInBytes) + 1 + require(numPartitions < 10e8, s"Word2VecModel calculated that it needs $numPartitions " + + s"partitions to save this model, which is too large. Try increasing " + + s"spark.kryoserializer.buffer.max so that Word2VecModel can use fewer partitions.") + numPartitions.toInt } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala index a6a1c2b4f32bd..6183606a7b2ac 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala @@ -25,6 +25,7 @@ import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.feature.{Word2VecModel => OldWord2VecModel} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.Row +import org.apache.spark.util.Utils class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { @@ -188,6 +189,15 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul assert(math.abs(similarity(5) - similarityLarger(5) / similarity(5)) > 1E-5) } + test("Word2Vec read/write numPartitions calculation") { + val smallModelNumPartitions = Word2VecModel.Word2VecModelWriter.calculateNumberOfPartitions( + Utils.byteStringAsBytes("64m"), numWords = 10, vectorSize = 5) + assert(smallModelNumPartitions === 1) + val largeModelNumPartitions = Word2VecModel.Word2VecModelWriter.calculateNumberOfPartitions( + Utils.byteStringAsBytes("64m"), numWords = 1000000, vectorSize = 5000) + assert(largeModelNumPartitions > 1) + } + test("Word2Vec read/write") { val t = new Word2Vec() .setInputCol("myInputCol") From 74a432d3a39ee6fe16889b76d9729926299d5492 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Mon, 12 Jun 2017 14:58:08 -0700 Subject: [PATCH 125/133] [SPARK-20979][SS] Add RateSource to generate values for tests and benchmark ## What changes were proposed in this pull request? This PR adds RateSource for Structured Streaming so that the user can use it to generate data for tests and benchmark easily. This source generates increment long values with timestamps. Each generated row has two columns: a timestamp column for the generated time and an auto increment long column starting with 0L. It supports the following options: - `rowsPerSecond` (e.g. 100, default: 1): How many rows should be generated per second. - `rampUpTime` (e.g. 5s, default: 0s): How long to ramp up before the generating speed becomes `rowsPerSecond`. Using finer granularities than seconds will be truncated to integer seconds. - `numPartitions` (e.g. 10, default: Spark's default parallelism): The partition number for the generated rows. The source will try its best to reach `rowsPerSecond`, but the query may be resource constrained, and `numPartitions` can be tweaked to help reach the desired speed. Here is a simple example that prints 10 rows per seconds: ``` spark.readStream .format("rate") .option("rowsPerSecond", "10") .load() .writeStream .format("console") .start() ``` The idea came from marmbrus and he did the initial work. ## How was this patch tested? The added tests. Author: Shixiong Zhu Author: Michael Armbrust Closes #18199 from zsxwing/rate. --- ...pache.spark.sql.sources.DataSourceRegister | 1 + .../streaming/RateSourceProvider.scala | 243 ++++++++++++++++++ .../execution/streaming/RateSourceSuite.scala | 182 +++++++++++++ .../spark/sql/streaming/StreamTest.scala | 3 + 4 files changed, 429 insertions(+) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceSuite.scala diff --git a/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister index 27d32b5dca431..0c5f3f22e31e8 100644 --- a/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister +++ b/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister @@ -5,3 +5,4 @@ org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat org.apache.spark.sql.execution.datasources.text.TextFileFormat org.apache.spark.sql.execution.streaming.ConsoleSinkProvider org.apache.spark.sql.execution.streaming.TextSocketSourceProvider +org.apache.spark.sql.execution.streaming.RateSourceProvider diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala new file mode 100644 index 0000000000000..e61a8eb628891 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala @@ -0,0 +1,243 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +import java.io._ +import java.nio.charset.StandardCharsets +import java.util.concurrent.TimeUnit + +import org.apache.commons.io.IOUtils + +import org.apache.spark.internal.Logging +import org.apache.spark.network.util.JavaUtils +import org.apache.spark.sql.{DataFrame, SQLContext} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils} +import org.apache.spark.sql.sources.{DataSourceRegister, StreamSourceProvider} +import org.apache.spark.sql.types._ +import org.apache.spark.util.{ManualClock, SystemClock} + +/** + * A source that generates increment long values with timestamps. Each generated row has two + * columns: a timestamp column for the generated time and an auto increment long column starting + * with 0L. + * + * This source supports the following options: + * - `rowsPerSecond` (e.g. 100, default: 1): How many rows should be generated per second. + * - `rampUpTime` (e.g. 5s, default: 0s): How long to ramp up before the generating speed + * becomes `rowsPerSecond`. Using finer granularities than seconds will be truncated to integer + * seconds. + * - `numPartitions` (e.g. 10, default: Spark's default parallelism): The partition number for the + * generated rows. The source will try its best to reach `rowsPerSecond`, but the query may + * be resource constrained, and `numPartitions` can be tweaked to help reach the desired speed. + */ +class RateSourceProvider extends StreamSourceProvider with DataSourceRegister { + + override def sourceSchema( + sqlContext: SQLContext, + schema: Option[StructType], + providerName: String, + parameters: Map[String, String]): (String, StructType) = + (shortName(), RateSourceProvider.SCHEMA) + + override def createSource( + sqlContext: SQLContext, + metadataPath: String, + schema: Option[StructType], + providerName: String, + parameters: Map[String, String]): Source = { + val params = CaseInsensitiveMap(parameters) + + val rowsPerSecond = params.get("rowsPerSecond").map(_.toLong).getOrElse(1L) + if (rowsPerSecond <= 0) { + throw new IllegalArgumentException( + s"Invalid value '${params("rowsPerSecond")}'. The option 'rowsPerSecond' " + + "must be positive") + } + + val rampUpTimeSeconds = + params.get("rampUpTime").map(JavaUtils.timeStringAsSec(_)).getOrElse(0L) + if (rampUpTimeSeconds < 0) { + throw new IllegalArgumentException( + s"Invalid value '${params("rampUpTime")}'. The option 'rampUpTime' " + + "must not be negative") + } + + val numPartitions = params.get("numPartitions").map(_.toInt).getOrElse( + sqlContext.sparkContext.defaultParallelism) + if (numPartitions <= 0) { + throw new IllegalArgumentException( + s"Invalid value '${params("numPartitions")}'. The option 'numPartitions' " + + "must be positive") + } + + new RateStreamSource( + sqlContext, + metadataPath, + rowsPerSecond, + rampUpTimeSeconds, + numPartitions, + params.get("useManualClock").map(_.toBoolean).getOrElse(false) // Only for testing + ) + } + override def shortName(): String = "rate" +} + +object RateSourceProvider { + val SCHEMA = + StructType(StructField("timestamp", TimestampType) :: StructField("value", LongType) :: Nil) + + val VERSION = 1 +} + +class RateStreamSource( + sqlContext: SQLContext, + metadataPath: String, + rowsPerSecond: Long, + rampUpTimeSeconds: Long, + numPartitions: Int, + useManualClock: Boolean) extends Source with Logging { + + import RateSourceProvider._ + import RateStreamSource._ + + val clock = if (useManualClock) new ManualClock else new SystemClock + + private val maxSeconds = Long.MaxValue / rowsPerSecond + + if (rampUpTimeSeconds > maxSeconds) { + throw new ArithmeticException( + s"Integer overflow. Max offset with $rowsPerSecond rowsPerSecond" + + s" is $maxSeconds, but 'rampUpTimeSeconds' is $rampUpTimeSeconds.") + } + + private val startTimeMs = { + val metadataLog = + new HDFSMetadataLog[LongOffset](sqlContext.sparkSession, metadataPath) { + override def serialize(metadata: LongOffset, out: OutputStream): Unit = { + val writer = new BufferedWriter(new OutputStreamWriter(out, StandardCharsets.UTF_8)) + writer.write("v" + VERSION + "\n") + writer.write(metadata.json) + writer.flush + } + + override def deserialize(in: InputStream): LongOffset = { + val content = IOUtils.toString(new InputStreamReader(in, StandardCharsets.UTF_8)) + // HDFSMetadataLog guarantees that it never creates a partial file. + assert(content.length != 0) + if (content(0) == 'v') { + val indexOfNewLine = content.indexOf("\n") + if (indexOfNewLine > 0) { + val version = parseVersion(content.substring(0, indexOfNewLine), VERSION) + LongOffset(SerializedOffset(content.substring(indexOfNewLine + 1))) + } else { + throw new IllegalStateException( + s"Log file was malformed: failed to detect the log file version line.") + } + } else { + throw new IllegalStateException( + s"Log file was malformed: failed to detect the log file version line.") + } + } + } + + metadataLog.get(0).getOrElse { + val offset = LongOffset(clock.getTimeMillis()) + metadataLog.add(0, offset) + logInfo(s"Start time: $offset") + offset + }.offset + } + + /** When the system time runs backward, "lastTimeMs" will make sure we are still monotonic. */ + @volatile private var lastTimeMs = startTimeMs + + override def schema: StructType = RateSourceProvider.SCHEMA + + override def getOffset: Option[Offset] = { + val now = clock.getTimeMillis() + if (lastTimeMs < now) { + lastTimeMs = now + } + Some(LongOffset(TimeUnit.MILLISECONDS.toSeconds(lastTimeMs - startTimeMs))) + } + + override def getBatch(start: Option[Offset], end: Offset): DataFrame = { + val startSeconds = start.flatMap(LongOffset.convert(_).map(_.offset)).getOrElse(0L) + val endSeconds = LongOffset.convert(end).map(_.offset).getOrElse(0L) + assert(startSeconds <= endSeconds, s"startSeconds($startSeconds) > endSeconds($endSeconds)") + if (endSeconds > maxSeconds) { + throw new ArithmeticException("Integer overflow. Max offset with " + + s"$rowsPerSecond rowsPerSecond is $maxSeconds, but it's $endSeconds now.") + } + // Fix "lastTimeMs" for recovery + if (lastTimeMs < TimeUnit.SECONDS.toMillis(endSeconds) + startTimeMs) { + lastTimeMs = TimeUnit.SECONDS.toMillis(endSeconds) + startTimeMs + } + val rangeStart = valueAtSecond(startSeconds, rowsPerSecond, rampUpTimeSeconds) + val rangeEnd = valueAtSecond(endSeconds, rowsPerSecond, rampUpTimeSeconds) + logDebug(s"startSeconds: $startSeconds, endSeconds: $endSeconds, " + + s"rangeStart: $rangeStart, rangeEnd: $rangeEnd") + + if (rangeStart == rangeEnd) { + return sqlContext.internalCreateDataFrame(sqlContext.sparkContext.emptyRDD, schema) + } + + val localStartTimeMs = startTimeMs + TimeUnit.SECONDS.toMillis(startSeconds) + val relativeMsPerValue = + TimeUnit.SECONDS.toMillis(endSeconds - startSeconds).toDouble / (rangeEnd - rangeStart) + + val rdd = sqlContext.sparkContext.range(rangeStart, rangeEnd, 1, numPartitions).map { v => + val relative = math.round((v - rangeStart) * relativeMsPerValue) + InternalRow(DateTimeUtils.fromMillis(relative + localStartTimeMs), v) + } + sqlContext.internalCreateDataFrame(rdd, schema) + } + + override def stop(): Unit = {} + + override def toString: String = s"RateSource[rowsPerSecond=$rowsPerSecond, " + + s"rampUpTimeSeconds=$rampUpTimeSeconds, numPartitions=$numPartitions]" +} + +object RateStreamSource { + + /** Calculate the end value we will emit at the time `seconds`. */ + def valueAtSecond(seconds: Long, rowsPerSecond: Long, rampUpTimeSeconds: Long): Long = { + // E.g., rampUpTimeSeconds = 4, rowsPerSecond = 10 + // Then speedDeltaPerSecond = 2 + // + // seconds = 0 1 2 3 4 5 6 + // speed = 0 2 4 6 8 10 10 (speedDeltaPerSecond * seconds) + // end value = 0 2 6 12 20 30 40 (0 + speedDeltaPerSecond * seconds) * (seconds + 1) / 2 + val speedDeltaPerSecond = rowsPerSecond / (rampUpTimeSeconds + 1) + if (seconds <= rampUpTimeSeconds) { + // Calculate "(0 + speedDeltaPerSecond * seconds) * (seconds + 1) / 2" in a special way to + // avoid overflow + if (seconds % 2 == 1) { + (seconds + 1) / 2 * speedDeltaPerSecond * seconds + } else { + seconds / 2 * speedDeltaPerSecond * (seconds + 1) + } + } else { + // rampUpPart is just a special case of the above formula: rampUpTimeSeconds == seconds + val rampUpPart = valueAtSecond(rampUpTimeSeconds, rowsPerSecond, rampUpTimeSeconds) + rampUpPart + (seconds - rampUpTimeSeconds) * rowsPerSecond + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceSuite.scala new file mode 100644 index 0000000000000..bdba536425a43 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceSuite.scala @@ -0,0 +1,182 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +import java.util.concurrent.TimeUnit + +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.streaming.{StreamingQueryException, StreamTest} +import org.apache.spark.util.ManualClock + +class RateSourceSuite extends StreamTest { + + import testImplicits._ + + case class AdvanceRateManualClock(seconds: Long) extends AddData { + override def addData(query: Option[StreamExecution]): (Source, Offset) = { + assert(query.nonEmpty) + val rateSource = query.get.logicalPlan.collect { + case StreamingExecutionRelation(source, _) if source.isInstanceOf[RateStreamSource] => + source.asInstanceOf[RateStreamSource] + }.head + rateSource.clock.asInstanceOf[ManualClock].advance(TimeUnit.SECONDS.toMillis(seconds)) + (rateSource, rateSource.getOffset.get) + } + } + + test("basic") { + val input = spark.readStream + .format("rate") + .option("rowsPerSecond", "10") + .option("useManualClock", "true") + .load() + testStream(input)( + AdvanceRateManualClock(seconds = 1), + CheckLastBatch((0 until 10).map(v => new java.sql.Timestamp(v * 100L) -> v): _*), + StopStream, + StartStream(), + // Advance 2 seconds because creating a new RateSource will also create a new ManualClock + AdvanceRateManualClock(seconds = 2), + CheckLastBatch((10 until 20).map(v => new java.sql.Timestamp(v * 100L) -> v): _*) + ) + } + + test("uniform distribution of event timestamps") { + val input = spark.readStream + .format("rate") + .option("rowsPerSecond", "1500") + .option("useManualClock", "true") + .load() + .as[(java.sql.Timestamp, Long)] + .map(v => (v._1.getTime, v._2)) + val expectedAnswer = (0 until 1500).map { v => + (math.round(v * (1000.0 / 1500)), v) + } + testStream(input)( + AdvanceRateManualClock(seconds = 1), + CheckLastBatch(expectedAnswer: _*) + ) + } + + test("valueAtSecond") { + import RateStreamSource._ + + assert(valueAtSecond(seconds = 0, rowsPerSecond = 5, rampUpTimeSeconds = 0) === 0) + assert(valueAtSecond(seconds = 1, rowsPerSecond = 5, rampUpTimeSeconds = 0) === 5) + + assert(valueAtSecond(seconds = 0, rowsPerSecond = 5, rampUpTimeSeconds = 2) === 0) + assert(valueAtSecond(seconds = 1, rowsPerSecond = 5, rampUpTimeSeconds = 2) === 1) + assert(valueAtSecond(seconds = 2, rowsPerSecond = 5, rampUpTimeSeconds = 2) === 3) + assert(valueAtSecond(seconds = 3, rowsPerSecond = 5, rampUpTimeSeconds = 2) === 8) + + assert(valueAtSecond(seconds = 0, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 0) + assert(valueAtSecond(seconds = 1, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 2) + assert(valueAtSecond(seconds = 2, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 6) + assert(valueAtSecond(seconds = 3, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 12) + assert(valueAtSecond(seconds = 4, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 20) + assert(valueAtSecond(seconds = 5, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 30) + } + + test("rampUpTime") { + val input = spark.readStream + .format("rate") + .option("rowsPerSecond", "10") + .option("rampUpTime", "4s") + .option("useManualClock", "true") + .load() + .as[(java.sql.Timestamp, Long)] + .map(v => (v._1.getTime, v._2)) + testStream(input)( + AdvanceRateManualClock(seconds = 1), + CheckLastBatch((0 until 2).map(v => v * 500 -> v): _*), // speed = 2 + AdvanceRateManualClock(seconds = 1), + CheckLastBatch((2 until 6).map(v => 1000 + (v - 2) * 250 -> v): _*), // speed = 4 + AdvanceRateManualClock(seconds = 1), + CheckLastBatch({ + Seq(2000 -> 6, 2167 -> 7, 2333 -> 8, 2500 -> 9, 2667 -> 10, 2833 -> 11) + }: _*), // speed = 6 + AdvanceRateManualClock(seconds = 1), + CheckLastBatch((12 until 20).map(v => 3000 + (v - 12) * 125 -> v): _*), // speed = 8 + AdvanceRateManualClock(seconds = 1), + // Now we should reach full speed + CheckLastBatch((20 until 30).map(v => 4000 + (v - 20) * 100 -> v): _*), // speed = 10 + AdvanceRateManualClock(seconds = 1), + CheckLastBatch((30 until 40).map(v => 5000 + (v - 30) * 100 -> v): _*), // speed = 10 + AdvanceRateManualClock(seconds = 1), + CheckLastBatch((40 until 50).map(v => 6000 + (v - 40) * 100 -> v): _*) // speed = 10 + ) + } + + test("numPartitions") { + val input = spark.readStream + .format("rate") + .option("rowsPerSecond", "10") + .option("numPartitions", "6") + .option("useManualClock", "true") + .load() + .select(spark_partition_id()) + .distinct() + testStream(input)( + AdvanceRateManualClock(1), + CheckLastBatch((0 until 6): _*) + ) + } + + testQuietly("overflow") { + val input = spark.readStream + .format("rate") + .option("rowsPerSecond", Long.MaxValue.toString) + .option("useManualClock", "true") + .load() + .select(spark_partition_id()) + .distinct() + testStream(input)( + AdvanceRateManualClock(2), + ExpectFailure[ArithmeticException](t => { + Seq("overflow", "rowsPerSecond").foreach { msg => + assert(t.getMessage.contains(msg)) + } + }) + ) + } + + testQuietly("illegal option values") { + def testIllegalOptionValue( + option: String, + value: String, + expectedMessages: Seq[String]): Unit = { + val e = intercept[StreamingQueryException] { + spark.readStream + .format("rate") + .option(option, value) + .load() + .writeStream + .format("console") + .start() + .awaitTermination() + } + assert(e.getCause.isInstanceOf[IllegalArgumentException]) + for (msg <- expectedMessages) { + assert(e.getCause.getMessage.contains(msg)) + } + } + + testIllegalOptionValue("rowsPerSecond", "-1", Seq("-1", "rowsPerSecond", "positive")) + testIllegalOptionValue("numPartitions", "-1", Seq("-1", "numPartitions", "positive")) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index 5bc36dd30f6d1..2a4039cc5831a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -172,8 +172,10 @@ trait StreamTest extends QueryTest with SharedSQLContext with Timeouts { * * @param isFatalError if this is a fatal error. If so, the error should also be caught by * UncaughtExceptionHandler. + * @param assertFailure a function to verify the error. */ case class ExpectFailure[T <: Throwable : ClassTag]( + assertFailure: Throwable => Unit = _ => {}, isFatalError: Boolean = false) extends StreamAction { val causeClass: Class[T] = implicitly[ClassTag[T]].runtimeClass.asInstanceOf[Class[T]] override def toString(): String = @@ -455,6 +457,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with Timeouts { s"\tExpected: ${ef.causeClass}\n\tReturned: $streamThreadDeathCause") streamThreadDeathCause = null } + ef.assertFailure(exception.getCause) } catch { case _: InterruptedException => case e: org.scalatest.exceptions.TestFailedDueToTimeoutException => From fc0e6944a531c806cf4fd83141646f264fa19af3 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 13 Jun 2017 09:15:14 +0800 Subject: [PATCH 126/133] Revert "[SPARK-21046][SQL] simplify the array offset and length in ColumnVector" This reverts commit 22dd65f58e12cb3a883d106fcccdff25a2a00fe8. --- .../execution/vectorized/ColumnVector.java | 35 +++++++------ .../vectorized/OffHeapColumnVector.java | 47 ++++++++++++++++-- .../vectorized/OnHeapColumnVector.java | 49 ++++++++++++++++--- .../vectorized/ColumnarBatchSuite.scala | 17 +++---- 4 files changed, 110 insertions(+), 38 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java index e50799eeb27ba..24260a60197f2 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java @@ -14,7 +14,6 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package org.apache.spark.sql.execution.vectorized; import java.math.BigDecimal; @@ -519,13 +518,19 @@ private void throwUnsupportedException(int requiredCapacity, Throwable cause) { public abstract double getDouble(int rowId); /** - * After writing array elements to the child column vector, call this method to set the offset and - * size of the written array. + * Puts a byte array that already exists in this column. */ - public void putArrayOffsetAndSize(int rowId, int offset, int size) { - long offsetAndSize = (((long) offset) << 32) | size; - putLong(rowId, offsetAndSize); - } + public abstract void putArray(int rowId, int offset, int length); + + /** + * Returns the length of the array at rowid. + */ + public abstract int getArrayLength(int rowId); + + /** + * Returns the offset of the array at rowid. + */ + public abstract int getArrayOffset(int rowId); /** * Returns a utility object to get structs. @@ -548,9 +553,8 @@ public ColumnarBatch.Row getStruct(int rowId, int size) { * Returns the array at rowid. */ public final Array getArray(int rowId) { - long offsetAndSize = getLong(rowId); - resultArray.offset = (int) (offsetAndSize >> 32); - resultArray.length = (int) offsetAndSize; + resultArray.length = getArrayLength(rowId); + resultArray.offset = getArrayOffset(rowId); return resultArray; } @@ -562,12 +566,7 @@ public final Array getArray(int rowId) { /** * Sets the value at rowId to `value`. */ - public int putByteArray(int rowId, byte[] value, int offset, int length) { - int result = arrayData().appendBytes(length, value, offset); - putArrayOffsetAndSize(rowId, result, length); - return result; - } - + public abstract int putByteArray(int rowId, byte[] value, int offset, int count); public final int putByteArray(int rowId, byte[] value) { return putByteArray(rowId, value, 0, value.length); } @@ -830,13 +829,13 @@ public final int appendDoubles(int length, double[] src, int offset) { public final int appendByteArray(byte[] value, int offset, int length) { int copiedOffset = arrayData().appendBytes(length, value, offset); reserve(elementsAppended + 1); - putArrayOffsetAndSize(elementsAppended, copiedOffset, length); + putArray(elementsAppended, copiedOffset, length); return elementsAppended++; } public final int appendArray(int length) { reserve(elementsAppended + 1); - putArrayOffsetAndSize(elementsAppended, arrayData().elementsAppended, length); + putArray(elementsAppended, arrayData().elementsAppended, length); return elementsAppended++; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java index 4dc4d34db37fb..a7d3744d00e91 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java @@ -34,15 +34,19 @@ public final class OffHeapColumnVector extends ColumnVector { // The data stored in these two allocations need to maintain binary compatible. We can // directly pass this buffer to external components. private long nulls; - // The actually data of this column vector will be stored here. If it's an array column vector, - // we will store the offsets and lengths here, and store the element data in child column vector. private long data; + // Set iff the type is array. + private long lengthData; + private long offsetData; + protected OffHeapColumnVector(int capacity, DataType type) { super(capacity, type, MemoryMode.OFF_HEAP); nulls = 0; data = 0; + lengthData = 0; + offsetData = 0; reserveInternal(capacity); reset(); @@ -62,8 +66,12 @@ public long nullsNativeAddress() { public void close() { Platform.freeMemory(nulls); Platform.freeMemory(data); + Platform.freeMemory(lengthData); + Platform.freeMemory(offsetData); nulls = 0; data = 0; + lengthData = 0; + offsetData = 0; } // @@ -387,6 +395,35 @@ public double getDouble(int rowId) { } } + // + // APIs dealing with Arrays. + // + @Override + public void putArray(int rowId, int offset, int length) { + assert(offset >= 0 && offset + length <= childColumns[0].capacity); + Platform.putInt(null, lengthData + 4 * rowId, length); + Platform.putInt(null, offsetData + 4 * rowId, offset); + } + + @Override + public int getArrayLength(int rowId) { + return Platform.getInt(null, lengthData + 4 * rowId); + } + + @Override + public int getArrayOffset(int rowId) { + return Platform.getInt(null, offsetData + 4 * rowId); + } + + // APIs dealing with ByteArrays + @Override + public int putByteArray(int rowId, byte[] value, int offset, int length) { + int result = arrayData().appendBytes(length, value, offset); + Platform.putInt(null, lengthData + 4 * rowId, length); + Platform.putInt(null, offsetData + 4 * rowId, result); + return result; + } + @Override public void loadBytes(ColumnVector.Array array) { if (array.tmpByteArray.length < array.length) array.tmpByteArray = new byte[array.length]; @@ -401,8 +438,10 @@ public void loadBytes(ColumnVector.Array array) { protected void reserveInternal(int newCapacity) { int oldCapacity = (this.data == 0L) ? 0 : capacity; if (this.resultArray != null) { - // need a long as offset and length for each array. - this.data = Platform.reallocateMemory(data, oldCapacity * 8, newCapacity * 8); + this.lengthData = + Platform.reallocateMemory(lengthData, oldCapacity * 4, newCapacity * 4); + this.offsetData = + Platform.reallocateMemory(offsetData, oldCapacity * 4, newCapacity * 4); } else if (type instanceof ByteType || type instanceof BooleanType) { this.data = Platform.reallocateMemory(data, oldCapacity, newCapacity); } else if (type instanceof ShortType) { diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java index 4d23405dc7b17..94ed32294cfae 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java @@ -43,12 +43,14 @@ public final class OnHeapColumnVector extends ColumnVector { private byte[] byteData; private short[] shortData; private int[] intData; - // This is not only used to store data for long column vector, but also can store offsets and - // lengths for array column vector. private long[] longData; private float[] floatData; private double[] doubleData; + // Only set if type is Array. + private int[] arrayLengths; + private int[] arrayOffsets; + protected OnHeapColumnVector(int capacity, DataType type) { super(capacity, type, MemoryMode.ON_HEAP); reserveInternal(capacity); @@ -364,22 +366,55 @@ public double getDouble(int rowId) { } } + // + // APIs dealing with Arrays + // + + @Override + public int getArrayLength(int rowId) { + return arrayLengths[rowId]; + } + @Override + public int getArrayOffset(int rowId) { + return arrayOffsets[rowId]; + } + + @Override + public void putArray(int rowId, int offset, int length) { + arrayOffsets[rowId] = offset; + arrayLengths[rowId] = length; + } + @Override public void loadBytes(ColumnVector.Array array) { array.byteArray = byteData; array.byteArrayOffset = array.offset; } + // + // APIs dealing with Byte Arrays + // + + @Override + public int putByteArray(int rowId, byte[] value, int offset, int length) { + int result = arrayData().appendBytes(length, value, offset); + arrayOffsets[rowId] = result; + arrayLengths[rowId] = length; + return result; + } + // Spilt this function out since it is the slow path. @Override protected void reserveInternal(int newCapacity) { if (this.resultArray != null || DecimalType.isByteArrayDecimalType(type)) { - // need 1 long as offset and length for each array. - if (longData == null || longData.length < newCapacity) { - long[] newData = new long[newCapacity]; - if (longData != null) System.arraycopy(longData, 0, newData, 0, capacity); - longData = newData; + int[] newLengths = new int[newCapacity]; + int[] newOffsets = new int[newCapacity]; + if (this.arrayLengths != null) { + System.arraycopy(this.arrayLengths, 0, newLengths, 0, capacity); + System.arraycopy(this.arrayOffsets, 0, newOffsets, 0, capacity); } + arrayLengths = newLengths; + arrayOffsets = newOffsets; } else if (type instanceof BooleanType) { if (byteData == null || byteData.length < newCapacity) { byte[] newData = new byte[newCapacity]; diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala index 5c4128a70dd86..e48e3f6402901 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala @@ -631,7 +631,7 @@ class ColumnarBatchSuite extends SparkFunSuite { assert(column.arrayData().elementsAppended == 17) // Put the same "ll" at offset. This should not allocate more memory in the column. - column.putArrayOffsetAndSize(idx, offset, 2) + column.putArray(idx, offset, 2) reference += "ll" idx += 1 assert(column.arrayData().elementsAppended == 17) @@ -644,8 +644,7 @@ class ColumnarBatchSuite extends SparkFunSuite { assert(column.arrayData().elementsAppended == 17 + (s + s).length) reference.zipWithIndex.foreach { v => - val offsetAndLength = column.getLong(v._2) - assert(v._1.length == offsetAndLength.toInt, "MemoryMode=" + memMode) + assert(v._1.length == column.getArrayLength(v._2), "MemoryMode=" + memMode) assert(v._1 == column.getUTF8String(v._2).toString, "MemoryMode" + memMode) } @@ -660,7 +659,7 @@ class ColumnarBatchSuite extends SparkFunSuite { val column = ColumnVector.allocate(10, new ArrayType(IntegerType, true), memMode) // Fill the underlying data with all the arrays back to back. - val data = column.arrayData() + val data = column.arrayData(); var i = 0 while (i < 6) { data.putInt(i, i) @@ -668,10 +667,10 @@ class ColumnarBatchSuite extends SparkFunSuite { } // Populate it with arrays [0], [1, 2], [], [3, 4, 5] - column.putArrayOffsetAndSize(0, 0, 1) - column.putArrayOffsetAndSize(1, 1, 2) - column.putArrayOffsetAndSize(2, 3, 0) - column.putArrayOffsetAndSize(3, 3, 3) + column.putArray(0, 0, 1) + column.putArray(1, 1, 2) + column.putArray(2, 2, 0) + column.putArray(3, 3, 3) val a1 = ColumnVectorUtils.toPrimitiveJavaArray(column.getArray(0)).asInstanceOf[Array[Int]] val a2 = ColumnVectorUtils.toPrimitiveJavaArray(column.getArray(1)).asInstanceOf[Array[Int]] @@ -704,7 +703,7 @@ class ColumnarBatchSuite extends SparkFunSuite { data.reserve(array.length) assert(data.capacity == array.length * 2) data.putInts(0, array.length, array, 0) - column.putArrayOffsetAndSize(0, 0, array.length) + column.putArray(0, 0, array.length) assert(ColumnVectorUtils.toPrimitiveJavaArray(column.getArray(0)).asInstanceOf[Array[Int]] === array) }} From 2639c3ed03075d37f07042a03d93a4237366c6a5 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Mon, 12 Jun 2017 21:18:43 -0700 Subject: [PATCH 127/133] [SPARK-19910][SQL] `stack` should not reject NULL values due to type mismatch ## What changes were proposed in this pull request? Since `stack` function generates a table with nullable columns, it should allow mixed null values. ```scala scala> sql("select stack(3, 1, 2, 3)").printSchema root |-- col0: integer (nullable = true) scala> sql("select stack(3, 1, 2, null)").printSchema org.apache.spark.sql.AnalysisException: cannot resolve 'stack(3, 1, 2, NULL)' due to data type mismatch: Argument 1 (IntegerType) != Argument 3 (NullType); line 1 pos 7; ``` ## How was this patch tested? Pass the Jenkins with a new test case. Author: Dongjoon Hyun Closes #17251 from dongjoon-hyun/SPARK-19910. --- .../sql/catalyst/analysis/TypeCoercion.scala | 17 ++++++ .../sql/catalyst/expressions/generators.scala | 19 +++++++ .../catalyst/analysis/TypeCoercionSuite.scala | 57 +++++++++++++++++++ .../spark/sql/GeneratorFunctionSuite.scala | 4 ++ 4 files changed, 97 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index e1dd010d37a95..1f217390518a6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -54,6 +54,7 @@ object TypeCoercion { FunctionArgumentConversion :: CaseWhenCoercion :: IfCoercion :: + StackCoercion :: Division :: PropagateTypes :: ImplicitTypeCasts :: @@ -648,6 +649,22 @@ object TypeCoercion { } } + /** + * Coerces NullTypes in the Stack expression to the column types of the corresponding positions. + */ + object StackCoercion extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + case s @ Stack(children) if s.childrenResolved && s.hasFoldableNumRows => + Stack(children.zipWithIndex.map { + // The first child is the number of rows for stack. + case (e, 0) => e + case (Literal(null, NullType), index: Int) => + Literal.create(null, s.findDataType(index)) + case (e, _) => e + }) + } + } + /** * Turns Add/Subtract of DateType/TimestampType/StringType and CalendarIntervalType * to TimeAdd/TimeSub diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala index e84796f2edad0..e023f0567ea87 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala @@ -138,6 +138,13 @@ case class Stack(children: Seq[Expression]) extends Generator { private lazy val numRows = children.head.eval().asInstanceOf[Int] private lazy val numFields = Math.ceil((children.length - 1.0) / numRows).toInt + /** + * Return true iff the first child exists and has a foldable IntegerType. + */ + def hasFoldableNumRows: Boolean = { + children.nonEmpty && children.head.dataType == IntegerType && children.head.foldable + } + override def checkInputDataTypes(): TypeCheckResult = { if (children.length <= 1) { TypeCheckResult.TypeCheckFailure(s"$prettyName requires at least 2 arguments.") @@ -156,6 +163,18 @@ case class Stack(children: Seq[Expression]) extends Generator { } } + def findDataType(index: Int): DataType = { + // Find the first data type except NullType. + val firstDataIndex = ((index - 1) % numFields) + 1 + for (i <- firstDataIndex until children.length by numFields) { + if (children(i).dataType != NullType) { + return children(i).dataType + } + } + // If all values of the column are NullType, use it. + NullType + } + override def elementSchema: StructType = StructType(children.tail.take(numFields).zipWithIndex.map { case (e, index) => StructField(s"col$index", e.dataType) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala index 2ac11598e63d1..7358f401ed520 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala @@ -768,6 +768,63 @@ class TypeCoercionSuite extends PlanTest { ) } + test("type coercion for Stack") { + val rule = TypeCoercion.StackCoercion + + ruleTest(rule, + Stack(Seq(Literal(3), Literal(1), Literal(2), Literal(null))), + Stack(Seq(Literal(3), Literal(1), Literal(2), Literal.create(null, IntegerType)))) + ruleTest(rule, + Stack(Seq(Literal(3), Literal(1.0), Literal(null), Literal(3.0))), + Stack(Seq(Literal(3), Literal(1.0), Literal.create(null, DoubleType), Literal(3.0)))) + ruleTest(rule, + Stack(Seq(Literal(3), Literal(null), Literal("2"), Literal("3"))), + Stack(Seq(Literal(3), Literal.create(null, StringType), Literal("2"), Literal("3")))) + ruleTest(rule, + Stack(Seq(Literal(3), Literal(null), Literal(null), Literal(null))), + Stack(Seq(Literal(3), Literal(null), Literal(null), Literal(null)))) + + ruleTest(rule, + Stack(Seq(Literal(2), + Literal(1), Literal("2"), + Literal(null), Literal(null))), + Stack(Seq(Literal(2), + Literal(1), Literal("2"), + Literal.create(null, IntegerType), Literal.create(null, StringType)))) + + ruleTest(rule, + Stack(Seq(Literal(2), + Literal(1), Literal(null), + Literal(null), Literal("2"))), + Stack(Seq(Literal(2), + Literal(1), Literal.create(null, StringType), + Literal.create(null, IntegerType), Literal("2")))) + + ruleTest(rule, + Stack(Seq(Literal(2), + Literal(null), Literal(1), + Literal("2"), Literal(null))), + Stack(Seq(Literal(2), + Literal.create(null, StringType), Literal(1), + Literal("2"), Literal.create(null, IntegerType)))) + + ruleTest(rule, + Stack(Seq(Literal(2), + Literal(null), Literal(null), + Literal(1), Literal("2"))), + Stack(Seq(Literal(2), + Literal.create(null, IntegerType), Literal.create(null, StringType), + Literal(1), Literal("2")))) + + ruleTest(rule, + Stack(Seq(Subtract(Literal(3), Literal(1)), + Literal(1), Literal("2"), + Literal(null), Literal(null))), + Stack(Seq(Subtract(Literal(3), Literal(1)), + Literal(1), Literal("2"), + Literal.create(null, IntegerType), Literal.create(null, StringType)))) + } + test("BooleanEquality type cast") { val be = TypeCoercion.BooleanEquality // Use something more than a literal to avoid triggering the simplification rules. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala index 539c63d3cb288..6b98209fd49b8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala @@ -43,6 +43,10 @@ class GeneratorFunctionSuite extends QueryTest with SharedSQLContext { checkAnswer(df.selectExpr("stack(3, 1, 1.1, 'a', 2, 2.2, 'b', 3, 3.3, 'c')"), Row(1, 1.1, "a") :: Row(2, 2.2, "b") :: Row(3, 3.3, "c") :: Nil) + // Null values + checkAnswer(df.selectExpr("stack(3, 1, 1.1, null, 2, null, 'b', null, 3.3, 'c')"), + Row(1, 1.1, null) :: Row(2, null, "b") :: Row(null, 3.3, "c") :: Nil) + // Repeat generation at every input row checkAnswer(spark.range(2).selectExpr("stack(2, 1, 2, 3)"), Row(1, 2) :: Row(3, null) :: Row(1, 2) :: Row(3, null) :: Nil) From 278ba7a2c62b2cbb7bcfe79ce10d35ab57bb1950 Mon Sep 17 00:00:00 2001 From: Felix Cheung Date: Mon, 12 Jun 2017 22:08:49 -0700 Subject: [PATCH 128/133] [TEST][SPARKR][CORE] Fix broken SparkSubmitSuite ## What changes were proposed in this pull request? Fix test file path. This is broken in #18264 and undetected since R-only changes don't build core and subsequent post-commit with the change built fine (again because it wasn't building core) actually appveyor builds everything but it's not running scala suites ... ## How was this patch tested? jenkins srowen gatorsmile Author: Felix Cheung Closes #18283 from felixcheung/rsubmitsuite. --- .../scala/org/apache/spark/deploy/SparkSubmitSuite.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index de719990cf47a..b089357e7b868 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -505,8 +505,8 @@ class SparkSubmitSuite assume(RUtils.isSparkRInstalled, "SparkR is not installed in this build.") val main = MavenCoordinate("my.great.lib", "mylib", "0.1") val sparkHome = sys.props.getOrElse("spark.test.home", fail("spark.test.home is not set!")) - val rScriptDir = - Seq(sparkHome, "R", "pkg", "inst", "tests", "packageInAJarTest.R").mkString(File.separator) + val rScriptDir = Seq( + sparkHome, "R", "pkg", "tests", "fulltests", "packageInAJarTest.R").mkString(File.separator) assert(new File(rScriptDir).exists) IvyTestUtils.withRepository(main, None, None, withR = true) { repo => val args = Seq( @@ -527,7 +527,7 @@ class SparkSubmitSuite // Check if the SparkR package is installed assume(RUtils.isSparkRInstalled, "SparkR is not installed in this build.") val rScriptDir = - Seq(sparkHome, "R", "pkg", "inst", "tests", "testthat", "jarTest.R").mkString(File.separator) + Seq(sparkHome, "R", "pkg", "tests", "fulltests", "jarTest.R").mkString(File.separator) assert(new File(rScriptDir).exists) // compile a small jar containing a class that will be called from R code. From 7b7c85ede398996aafffb126440e5f0c67f67210 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Tue, 13 Jun 2017 10:48:07 +0100 Subject: [PATCH 129/133] [SPARK-20920][SQL] ForkJoinPool pools are leaked when writing hive tables with many partitions ## What changes were proposed in this pull request? Don't leave thread pool running from AlterTableRecoverPartitionsCommand DDL command ## How was this patch tested? Existing tests. Author: Sean Owen Closes #18216 from srowen/SPARK-20920. --- .../spark/sql/execution/command/ddl.scala | 21 ++++++++++++------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala index 793fb9b795596..5a7f8cf1eb59e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala @@ -21,7 +21,6 @@ import java.util.Locale import scala.collection.{GenMap, GenSeq} import scala.collection.parallel.ForkJoinTaskSupport -import scala.concurrent.forkjoin.ForkJoinPool import scala.util.control.NonFatal import org.apache.hadoop.conf.Configuration @@ -36,7 +35,7 @@ import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} import org.apache.spark.sql.execution.datasources.PartitioningUtils import org.apache.spark.sql.types._ -import org.apache.spark.util.SerializableConfiguration +import org.apache.spark.util.{SerializableConfiguration, ThreadUtils} // Note: The definition of these commands are based on the ones described in // https://cwiki.apache.org/confluence/display/Hive/LanguageManual+DDL @@ -588,8 +587,15 @@ case class AlterTableRecoverPartitionsCommand( val threshold = spark.conf.get("spark.rdd.parallelListingThreshold", "10").toInt val hadoopConf = spark.sparkContext.hadoopConfiguration val pathFilter = getPathFilter(hadoopConf) - val partitionSpecsAndLocs = scanPartitions(spark, fs, pathFilter, root, Map(), - table.partitionColumnNames, threshold, spark.sessionState.conf.resolver) + + val evalPool = ThreadUtils.newForkJoinPool("AlterTableRecoverPartitionsCommand", 8) + val partitionSpecsAndLocs: Seq[(TablePartitionSpec, Path)] = + try { + scanPartitions(spark, fs, pathFilter, root, Map(), table.partitionColumnNames, threshold, + spark.sessionState.conf.resolver, new ForkJoinTaskSupport(evalPool)).seq + } finally { + evalPool.shutdown() + } val total = partitionSpecsAndLocs.length logInfo(s"Found $total partitions in $root") @@ -610,8 +616,6 @@ case class AlterTableRecoverPartitionsCommand( Seq.empty[Row] } - @transient private lazy val evalTaskSupport = new ForkJoinTaskSupport(new ForkJoinPool(8)) - private def scanPartitions( spark: SparkSession, fs: FileSystem, @@ -620,7 +624,8 @@ case class AlterTableRecoverPartitionsCommand( spec: TablePartitionSpec, partitionNames: Seq[String], threshold: Int, - resolver: Resolver): GenSeq[(TablePartitionSpec, Path)] = { + resolver: Resolver, + evalTaskSupport: ForkJoinTaskSupport): GenSeq[(TablePartitionSpec, Path)] = { if (partitionNames.isEmpty) { return Seq(spec -> path) } @@ -644,7 +649,7 @@ case class AlterTableRecoverPartitionsCommand( val value = ExternalCatalogUtils.unescapePathName(ps(1)) if (resolver(columnName, partitionNames.head)) { scanPartitions(spark, fs, filter, st.getPath, spec ++ Map(partitionNames.head -> value), - partitionNames.drop(1), threshold, resolver) + partitionNames.drop(1), threshold, resolver, evalTaskSupport) } else { logWarning( s"expected partition column ${partitionNames.head}, but got ${ps(0)}, ignoring it") From 2aaed0a4db84e99186b52a2c49d532702b575406 Mon Sep 17 00:00:00 2001 From: liuxian Date: Tue, 13 Jun 2017 12:29:50 +0100 Subject: [PATCH 130/133] [SPARK-21006][TESTS][FOLLOW-UP] Some Worker's RpcEnv is leaked in WorkerSuite ## What changes were proposed in this pull request? Create rpcEnv and run later needs shutdown. as #18226 ## How was this patch tested? unit test Author: liuxian Closes #18259 from 10110346/wip-lx-0610. --- .../spark/deploy/worker/WorkerSuite.scala | 39 ++++++++++++------- 1 file changed, 25 insertions(+), 14 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerSuite.scala index 101a44edd8ee2..ce212a7513310 100644 --- a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.deploy.worker -import org.scalatest.Matchers +import org.scalatest.{BeforeAndAfter, Matchers} import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite} import org.apache.spark.deploy.{Command, ExecutorState} @@ -25,7 +25,7 @@ import org.apache.spark.deploy.DeployMessages.{DriverStateChanged, ExecutorState import org.apache.spark.deploy.master.DriverState import org.apache.spark.rpc.{RpcAddress, RpcEnv} -class WorkerSuite extends SparkFunSuite with Matchers { +class WorkerSuite extends SparkFunSuite with Matchers with BeforeAndAfter { import org.apache.spark.deploy.DeployTestUtils._ @@ -34,6 +34,25 @@ class WorkerSuite extends SparkFunSuite with Matchers { } def conf(opts: (String, String)*): SparkConf = new SparkConf(loadDefaults = false).setAll(opts) + private var _worker: Worker = _ + + private def makeWorker(conf: SparkConf): Worker = { + assert(_worker === null, "Some Worker's RpcEnv is leaked in tests") + val securityMgr = new SecurityManager(conf) + val rpcEnv = RpcEnv.create("test", "localhost", 12345, conf, securityMgr) + _worker = new Worker(rpcEnv, 50000, 20, 1234 * 5, Array.fill(1)(RpcAddress("1.2.3.4", 1234)), + "Worker", "/tmp", conf, securityMgr) + _worker + } + + after { + if (_worker != null) { + _worker.rpcEnv.shutdown() + _worker.rpcEnv.awaitTermination() + _worker = null + } + } + test("test isUseLocalNodeSSLConfig") { Worker.isUseLocalNodeSSLConfig(cmd("-Dasdf=dfgh")) shouldBe false Worker.isUseLocalNodeSSLConfig(cmd("-Dspark.ssl.useNodeLocalConf=true")) shouldBe true @@ -65,9 +84,7 @@ class WorkerSuite extends SparkFunSuite with Matchers { test("test clearing of finishedExecutors (small number of executors)") { val conf = new SparkConf() conf.set("spark.worker.ui.retainedExecutors", 2.toString) - val rpcEnv = RpcEnv.create("test", "localhost", 12345, conf, new SecurityManager(conf)) - val worker = new Worker(rpcEnv, 50000, 20, 1234 * 5, Array.fill(1)(RpcAddress("1.2.3.4", 1234)), - "Worker", "/tmp", conf, new SecurityManager(conf)) + val worker = makeWorker(conf) // initialize workers for (i <- 0 until 5) { worker.executors += s"app1/$i" -> createExecutorRunner(i) @@ -91,9 +108,7 @@ class WorkerSuite extends SparkFunSuite with Matchers { test("test clearing of finishedExecutors (more executors)") { val conf = new SparkConf() conf.set("spark.worker.ui.retainedExecutors", 30.toString) - val rpcEnv = RpcEnv.create("test", "localhost", 12345, conf, new SecurityManager(conf)) - val worker = new Worker(rpcEnv, 50000, 20, 1234 * 5, Array.fill(1)(RpcAddress("1.2.3.4", 1234)), - "Worker", "/tmp", conf, new SecurityManager(conf)) + val worker = makeWorker(conf) // initialize workers for (i <- 0 until 50) { worker.executors += s"app1/$i" -> createExecutorRunner(i) @@ -126,9 +141,7 @@ class WorkerSuite extends SparkFunSuite with Matchers { test("test clearing of finishedDrivers (small number of drivers)") { val conf = new SparkConf() conf.set("spark.worker.ui.retainedDrivers", 2.toString) - val rpcEnv = RpcEnv.create("test", "localhost", 12345, conf, new SecurityManager(conf)) - val worker = new Worker(rpcEnv, 50000, 20, 1234 * 5, Array.fill(1)(RpcAddress("1.2.3.4", 1234)), - "Worker", "/tmp", conf, new SecurityManager(conf)) + val worker = makeWorker(conf) // initialize workers for (i <- 0 until 5) { val driverId = s"driverId-$i" @@ -152,9 +165,7 @@ class WorkerSuite extends SparkFunSuite with Matchers { test("test clearing of finishedDrivers (more drivers)") { val conf = new SparkConf() conf.set("spark.worker.ui.retainedDrivers", 30.toString) - val rpcEnv = RpcEnv.create("test", "localhost", 12345, conf, new SecurityManager(conf)) - val worker = new Worker(rpcEnv, 50000, 20, 1234 * 5, Array.fill(1)(RpcAddress("1.2.3.4", 1234)), - "Worker", "/tmp", conf, new SecurityManager(conf)) + val worker = makeWorker(conf) // initialize workers for (i <- 0 until 50) { val driverId = s"driverId-$i" From 9b2c877beccf34fc7c063574496be7e6281227ad Mon Sep 17 00:00:00 2001 From: Rishabh Bhardwaj Date: Tue, 13 Jun 2017 15:09:12 +0100 Subject: [PATCH 131/133] [SPARK-21039][SPARK CORE] Use treeAggregate instead of aggregate in DataFrame.stat.bloomFilter ## What changes were proposed in this pull request? To use treeAggregate instead of aggregate in DataFrame.stat.bloomFilter to parallelize the operation of merging the bloom filters (Please fill in changes proposed in this fix) ## How was this patch tested? unit tests passed (Please explain how this patch was tested. E.g. unit tests, integration tests, manual tests) (If this patch involves UI changes, please attach a screenshot; otherwise, remove this) Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Rishabh Bhardwaj Author: Rishabh Bhardwaj Author: Rishabh Bhardwaj Author: Rishabh Bhardwaj Author: Rishabh Bhardwaj Closes #18263 from rishabhbhardwaj/SPARK-21039. --- .../scala/org/apache/spark/sql/DataFrameStatFunctions.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala index c856d3099f6ee..531c613afb0dd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala @@ -551,7 +551,7 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { ) } - singleCol.queryExecution.toRdd.aggregate(zero)( + singleCol.queryExecution.toRdd.treeAggregate(zero)( (filter: BloomFilter, row: InternalRow) => { updater(filter, row) filter From b7304f25590fb1bc65cba6440d59a3322fd09947 Mon Sep 17 00:00:00 2001 From: guoxiaolong Date: Tue, 13 Jun 2017 15:38:11 +0100 Subject: [PATCH 132/133] [SPARK-21060][WEB-UI] Css style about paging function is error in the executor page. Css style about paging function is error in the executor page. It is different of history server ui paging function css style. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? Css style about paging function is error in the executor page. It is different of history server ui paging function css style. **But their style should be consistent**. There are three reasons. 1. The first reason: 'Previous', 'Next' and number should be the button format. 2. The second reason: when you are on the first page, 'Previous' and '1' should be gray and can not be clicked. ![1](https://user-images.githubusercontent.com/26266482/27026667-1fe745ee-4f91-11e7-8b34-150819d22bd3.png) 3. The third reason: when you are on the last page, 'Previous' and 'Max number' should be gray and can not be clicked. ![2](https://user-images.githubusercontent.com/26266482/27026811-9d8d6fa0-4f91-11e7-8b51-7816c3feb381.png) before fix: ![fix_before](https://user-images.githubusercontent.com/26266482/27026428-47ec5c56-4f90-11e7-9dd5-d52c22d7bd36.png) after fix: ![fix_after](https://user-images.githubusercontent.com/26266482/27026439-50d17072-4f90-11e7-8405-6f81da5ab32c.png) The style of history server ui: ![history](https://user-images.githubusercontent.com/26266482/27026528-9c90f780-4f90-11e7-91e6-90d32651fe03.png) ## How was this patch tested? manual tests Please review http://spark.apache.org/contributing.html before opening a pull request. Author: guoxiaolong Author: 郭小龙 10207633 Author: guoxiaolongzte Closes #18275 from guoxiaolongzte/SPARK-21060. --- .../src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala index b7cbed468517c..d63381c78bc3b 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala @@ -82,7 +82,7 @@ private[ui] class ExecutorsPage(
++ -
++ + ++ ++ ++ From b36ce2a2469ff923a3367a530d4a14899ecf9238 Mon Sep 17 00:00:00 2001 From: DjvuLee Date: Tue, 13 Jun 2017 15:56:03 +0100 Subject: [PATCH 133/133] [SPARK-21064][CORE][TEST] Fix the default value bug in NettyBlockTransferServiceSuite ## What changes were proposed in this pull request? The default value for `spark.port.maxRetries` is 100, but we use 10 in the suite file. So we change it to 100 to avoid test failure. ## How was this patch tested? No test Author: DjvuLee Closes #18280 from djvulee/NettyTestBug. --- .../spark/network/netty/NettyBlockTransferServiceSuite.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferServiceSuite.scala b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferServiceSuite.scala index 271ab8b148831..98259300381eb 100644 --- a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferServiceSuite.scala +++ b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferServiceSuite.scala @@ -80,7 +80,8 @@ class NettyBlockTransferServiceSuite private def verifyServicePort(expectedPort: Int, actualPort: Int): Unit = { actualPort should be >= expectedPort // avoid testing equality in case of simultaneous tests - actualPort should be <= (expectedPort + 10) + // the default value for `spark.port.maxRetries` is 100 under test + actualPort should be <= (expectedPort + 100) } private def createService(port: Int): NettyBlockTransferService = {